1 | #! /usr/bin/python |
---|
2 | |
---|
3 | from aubio.bench.node import * |
---|
4 | from aubio.tasks import * |
---|
5 | |
---|
class benchonset(bench):
    """Onset-detection benchmark built on the ``bench`` base class.

    The caller assigns ``self.task`` (the per-file task class),
    ``self.params``, ``self.titles`` and ``self.formats`` before running.
    ``dir_exec`` and ``pretty_print`` are inherited from ``bench``. They
    presumably run ``file_exec`` over every input file and print one
    formatted table row (TODO confirm against the base class).
    """

    @staticmethod
    def _ratio(num, den):
        """Return ``100 * num / den`` as a float, or 0.0 when ``den`` is 0."""
        if den == 0:
            return 0.0
        return 100 * float(num) / den

    def dir_eval(self):
        """Compute scores from the accumulated counts and fill ``self.values``.

        Sets ``self.P`` (precision), ``self.R`` (recall, clamped at 0) and
        ``self.F`` (2*P*R / (P+R)), all in percent. Then builds
        ``self.values``, one entry per column of the results table.
        Any ratio whose denominator is zero is reported as 0 rather than
        raising ZeroDivisionError. This covers no reference onsets, no
        detections, and P + R == 0 (nothing detected correctly).
        """
        # expected onsets that were neither missed nor merged
        found = self.expc - self.missed - self.merged
        self.P = self._ratio(found, found + self.bad + self.doubled)
        # found + missed + merged == expc
        self.R = self._ratio(found, self.expc)
        if self.R < 0:
            self.R = 0
        if self.P + self.R == 0:
            self.F = 0.0
        else:
            self.F = 2 * self.P * self.R / (self.P + self.R)

        self.values = [
            self.params.onsetmode,
            "%2.3f" % self.params.threshold,
            self.orig,
            self.expc,
            self.missed,
            self.merged,
            self.bad,
            self.doubled,
            self.orig - self.missed - self.merged,
            "%2.3f" % self._ratio(self.orig - self.missed - self.merged, self.orig),
            "%2.3f" % self._ratio(self.bad + self.doubled, self.orig),
            "%2.3f" % self._ratio(self.orig - self.missed, self.orig),
            "%2.3f" % self._ratio(self.bad, self.orig),
            "%2.3f" % self.P,
            "%2.3f" % self.R,
            "%2.3f" % self.F,
        ]

    def file_exec(self, input, output):
        """Run the onset task on one file and add its counts to the totals.

        ``output`` is not used here. Presumably it is part of the callback
        signature expected by ``bench`` (TODO confirm).
        """
        filetask = self.task(input, params=self.params)
        computed_data = filetask.compute_all()
        # The return value is not needed. The call presumably fills the
        # per-file counters read below (TODO confirm in the task class).
        filetask.eval(computed_data)
        self.orig += filetask.orig
        self.missed += filetask.missed
        self.merged += filetask.merged
        self.expc += filetask.expc
        self.bad += filetask.bad
        self.doubled += filetask.doubled

    def _eval_threshold(self, threshold):
        """Benchmark the current mode at ``threshold``, print the row, return F."""
        self.params.threshold = threshold
        self.dir_exec()
        self.dir_eval()
        self.pretty_print(self.values)
        return self.F

    def run_bench(self, modes=None, thresholds=None):
        """Benchmark every (mode, threshold) pair, printing one row per pair.

        modes: onset modes assigned to ``params.onsetmode``
            (default ['dual']).
        thresholds: values assigned to ``params.threshold`` (default [0.5]).
        """
        if modes is None:
            modes = ['dual']
        if thresholds is None:
            thresholds = [0.5]
        self.modes = modes
        self.thresholds = thresholds

        self.pretty_print(self.titles)
        for mode in self.modes:
            self.params.onsetmode = mode
            for threshold in self.thresholds:
                self._eval_threshold(threshold)

    def auto_learn(self, modes=None, thresholds=None, steps=10):
        """Simple dichotomy-like search for the threshold that maximises F.

        For each mode, evaluate both ends of the bracket and then
        repeatedly evaluate its midpoint. The midpoint replaces the end it
        beats, the upper end taking priority.

        modes: onset modes to tune (default ['dual']).
        thresholds: ``[low, high]`` initial bracket (default [0.1, 1.5]).
        steps: maximum number of midpoint evaluations per mode.
        """
        if modes is None:
            modes = ['dual']
        if thresholds is None:
            thresholds = [0.1, 1.5]
        self.modes = modes
        self.pretty_print(self.titles)
        for mode in self.modes:
            lesst = thresholds[0]
            topt = thresholds[1]
            self.params.onsetmode = mode

            topF = self._eval_threshold(topt)
            lessF = self._eval_threshold(lesst)

            for _ in range(steps):
                F = self._eval_threshold((lesst + topt) * .5)
                if F == 100.0 or F == topF:
                    print("assuming we converged, stopping")
                    break
                if topF < F:
                    topF = F
                    topt = self.params.threshold
                elif lessF < F:
                    lessF = F
                    lesst = self.params.threshold
                else:
                    # Midpoint is worse than both ends, so neither end moves.
                    # Every remaining step would re-run this same threshold.
                    print("no improvement at midpoint, stopping")
                    break
                if topt == lesst:
                    lesst /= 2.
---|
99 | |
---|
100 | |
---|
if __name__ == "__main__":
    # Usage: <script> DATAPATH [RESPATH]
    import sys
    if len(sys.argv) < 2:
        # sys.exit with a string prints it to stderr and exits with status 1
        sys.exit("ERR: a path is required")
    datapath = sys.argv[1]
    # optional second argument overrides the default results directory
    if len(sys.argv) > 2:
        respath = sys.argv[2]
    else:
        respath = '/var/tmp/DB-testings'

    modes = ['complex', 'energy', 'phase', 'specdiff', 'kl', 'mkl', 'dual']
    # only used by run_bench (see the commented alternative below)
    thresholds = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5]

    # named bench_obj so the class name benchonset is not shadowed
    bench_obj = benchonset(datapath, respath, checkres=True, checkanno=True)
    bench_obj.params = taskparams()
    bench_obj.task = taskonset

    bench_obj.titles = ['mode', 'thres', 'orig', 'expc', 'missd', 'mergd',
        'bad', 'doubl', 'corrt', 'GD', 'FP', 'GD-merged', 'FP-pruned',
        'prec', 'recl', 'dist']
    bench_obj.formats = ["%12s", "| %6s", "| %6s", "| %6s", "| %6s", "| %6s",
        "| %6s", "| %6s", "| %6s", "| %8s", "| %8s", "| %8s", "| %8s",
        "| %6s", "| %6s", "| %6s"]

    try:
        bench_obj.auto_learn(modes=modes)
        #bench_obj.run_bench(modes=modes, thresholds=thresholds)
    except KeyboardInterrupt:
        sys.exit(1)
---|