[4cc9fe5] | 1 | #! /usr/bin/python |
---|
| 2 | |
---|
| 3 | from aubio.bench.node import * |
---|
[4f4a8a4] | 4 | from aubio.tasks import * |
---|
[75139a9] | 5 | |
---|
class benchonset(bench):
    """Onset detection benchmark.

    Per-file counters (orig, expc, missed, merged, bad, doubled) are
    accumulated by file_exec(), summarised into one table row by
    dir_eval(), and printed by the run_bench()/auto_learn*() drivers.

    dir_exec() and pretty_print() are not defined here; they are presumably
    inherited from `bench`. The attributes `params`, `task` and `titles`
    must be set on the instance before a driver is called.
    """

    def _percent(self, num, den):
        """Return 100 * num / den as a float, or 0.0 when den is zero."""
        if den == 0:
            return 0.0
        return 100 * float(num) / den

    def dir_eval(self):
        """Compute self.P, self.R, self.F and the self.values table row.

        Uses the counters accumulated on self. Any ratio whose denominator
        is zero is reported as 0.0 instead of raising ZeroDivisionError
        (e.g. when nothing was detected or nothing was annotated).
        """
        hits = self.expc - self.missed - self.merged
        self.P = self._percent(hits, hits + self.bad + self.doubled)
        self.R = self._percent(hits, hits + self.missed + self.merged)
        if self.R < 0:
            self.R = 0
        total = self.P + self.R
        if total == 0:
            # both scores are zero: the harmonic-mean style score is 0 too
            self.F = 0.0
        else:
            self.F = 2 * self.P * self.R / total

        # one entry per column of self.titles / self.formats
        self.values = [
            self.params.onsetmode,
            "%2.3f" % self.params.threshold,
            self.orig,
            self.expc,
            self.missed,
            self.merged,
            self.bad,
            self.doubled,
            (self.orig - self.missed - self.merged),
            "%2.3f" % self._percent(self.orig - self.missed - self.merged,
                                    self.orig),
            "%2.3f" % self._percent(self.bad + self.doubled, self.orig),
            "%2.3f" % self._percent(self.orig - self.missed, self.orig),
            "%2.3f" % self._percent(self.bad, self.orig),
            "%2.3f" % self.P,
            "%2.3f" % self.R,
            "%2.3f" % self.F,
        ]

    def file_exec(self, input, output):
        """Run the task on one file and add its counters to the totals.

        input  -- passed to self.task() together with self.params
        output -- unused here; kept for the caller's calling convention
        """
        filetask = self.task(input, params=self.params)
        computed_data = filetask.compute_all()
        # the return value is not needed; the counters are read from the
        # task object afterwards
        filetask.eval(computed_data)
        self.orig += filetask.orig
        self.missed += filetask.missed
        self.merged += filetask.merged
        self.expc += filetask.expc
        self.bad += filetask.bad
        self.doubled += filetask.doubled

    def run_bench(self, modes=None, thresholds=None):
        """Evaluate every (mode, threshold) pair and print one row for each.

        modes      -- onset modes to test (default: ['dual'])
        thresholds -- thresholds to test (default: [0.5])
        """
        # fresh lists per call: the values are stored on self, so a shared
        # default list could be modified through the instance
        if modes is None:
            modes = ['dual']
        if thresholds is None:
            thresholds = [0.5]
        self.modes = modes
        self.thresholds = thresholds

        self.pretty_print(self.titles)
        for mode in self.modes:
            self.params.onsetmode = mode
            for threshold in self.thresholds:
                self.params.threshold = threshold
                self.dir_exec()
                self.dir_eval()
                self.pretty_print(self.values)

    def auto_learn(self, modes=None, thresholds=None):
        """Simple dichotomy-like search for the threshold with the best F.

        modes      -- onset modes to test (default: ['dual'])
        thresholds -- [low, high] starting thresholds (default: [0.1, 1.5])

        For each mode both bounds are evaluated first, then up to 10
        midpoints. A midpoint replaces the high bound when it beats its F,
        otherwise the low bound when it beats that one.
        """
        if modes is None:
            modes = ['dual']
        if thresholds is None:
            thresholds = [0.1, 1.5]
        self.modes = modes
        self.pretty_print(self.titles)
        for mode in self.modes:
            steps = 10
            lesst = thresholds[0]
            topt = thresholds[1]
            self.params.onsetmode = mode

            # score at the high bound
            self.params.threshold = topt
            self.dir_exec()
            self.dir_eval()
            self.pretty_print(self.values)
            topF = self.F

            # score at the low bound
            self.params.threshold = lesst
            self.dir_exec()
            self.dir_eval()
            self.pretty_print(self.values)
            lessF = self.F

            for _ in range(steps):
                self.params.threshold = (lesst + topt) * .5
                self.dir_exec()
                self.dir_eval()
                self.pretty_print(self.values)
                if self.F == 100.0 or self.F == topF:
                    print("assuming we converged, stopping")
                    break
                if topF < self.F:
                    topF = self.F
                    topt = self.params.threshold
                elif lessF < self.F:
                    lessF = self.F
                    lesst = self.params.threshold
                if topt == lesst:
                    # bounds collapsed: reopen the interval downwards
                    lesst /= 2.

    def auto_learn2(self, modes=None, thresholds=None):
        """Step the threshold until as many onsets are found as annotated.

        modes      -- onset modes to test (default: ['dual'])
        thresholds -- [start threshold, initial step] (default: [0.1, 1.0])

        For each mode, at most 10 rounds: the threshold is lowered when
        fewer onsets were found than annotated (expc < orig), raised when
        more were found, and the step is halved after every move.
        """
        if modes is None:
            modes = ['dual']
        if thresholds is None:
            thresholds = [0.1, 1.0]
        self.modes = modes
        self.pretty_print(self.titles)
        for mode in self.modes:
            steps = 10
            step = thresholds[1]
            curt = thresholds[0]
            self.params.onsetmode = mode

            self.params.threshold = curt
            self.dir_exec()
            self.dir_eval()
            self.pretty_print(self.values)
            curexp = self.expc

            for _ in range(steps):
                if curexp < self.orig:
                    # fewer onsets found than annotated
                    self.params.threshold -= step
                    # float divisor: an integer step must not truncate to 0
                    step /= 2.
                elif curexp > self.orig:
                    # more onsets found than annotated
                    self.params.threshold += step
                    step /= 2.
                self.dir_exec()
                self.dir_eval()
                curexp = self.expc
                self.pretty_print(self.values)
                # NOTE(review): `self.orig == 100.0` looks like a leftover of
                # auto_learn's `self.F == 100.0` test; kept unchanged --
                # confirm the intended condition.
                if self.orig == 100.0 or self.orig == self.expc:
                    print("assuming we converged, stopping")
                    break
---|
[75139a9] | 132 | |
---|
if __name__ == "__main__":
    # Command line entry point: <script> DATAPATH
    # Runs the auto_learn2 threshold search for every onset mode on the
    # files found under DATAPATH.
    import sys

    if len(sys.argv) > 1:
        datapath = sys.argv[1]
    else:
        print("ERR: a path is required")
        sys.exit(1)

    modes = ['complex', 'energy', 'phase', 'specdiff', 'kl', 'mkl', 'dual']
    # fixed threshold grid; only needed by the run_bench() alternative below
    thresholds = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5]

    respath = '/var/tmp/DB-testings'

    # named `runner` so the instance does not shadow the benchonset class
    runner = benchonset(datapath, respath, checkres=True, checkanno=True)
    runner.params = taskparams()
    runner.task = taskonset

    # column headers and their print formats; one entry per item of the
    # `values` row built by dir_eval()
    runner.titles = ['mode', 'thres', 'orig', 'expc', 'missd', 'mergd',
                     'bad', 'doubl', 'corrt', 'GD', 'FP', 'GD-merged',
                     'FP-pruned', 'prec', 'recl', 'dist']
    runner.formats = ["%12s", "| %6s", "| %6s", "| %6s", "| %6s", "| %6s",
                      "| %6s", "| %6s", "| %6s", "| %8s", "| %8s", "| %8s",
                      "| %8s", "| %6s", "| %6s", "| %6s"]

    try:
        runner.auto_learn2(modes=modes)
        # alternative: runner.run_bench(modes=modes, thresholds=thresholds)
    except KeyboardInterrupt:
        sys.exit(1)
---|