source: python/test/bench/onset/bench-onset @ e968939

feature/autosinkfeature/cnnfeature/cnn_orgfeature/constantqfeature/crepefeature/crepe_orgfeature/pitchshiftfeature/pydocstringsfeature/timestretchfix/ffmpeg5pitchshiftsamplertimestretchyinfft+
Last change on this file since e968939 was e968939, checked in by Paul Brossier <piem@altern.org>, 19 years ago

update to new bench onset
update to new bench onset

  • Property mode set to 100755
File size: 6.5 KB
Line 
1#! /usr/bin/python
2
3from aubio.bench.node import *
4from aubio.tasks import *
5
6
7
8
9def mmean(l):
10        return sum(l)/float(len(l))
11
12def stdev(l):
13        smean = 0
14        lmean = mmean(l)
15        for i in l:
16                smean += (i-lmean)**2
17        smean *= 1. / len(l)
18        return smean**.5
19
20class benchonset(bench):
21
22        valuenames = ['orig','missed','Tm','expc','bad','Td']
23        valuelists = ['l','labs']
24        printnames = [ 'mode', 'thres', 'dist', 'prec', 'recl', 'Ttrue', 'Tfp',  'Tfn',  'Tm',   'Td',
25                'aTtrue', 'aTfp', 'aTfn', 'aTm',  'aTd',  'mean', 'smean',  'amean', 'samean']
26
27        formats = {'mode': "%12s" ,
28        'thres': "%5.4s",
29        'dist':  "%5.4s",
30        'prec':  "%5.4s",
31        'recl':  "%5.4s",
32                 
33        'Ttrue': "%5.4s",
34        'Tfp':   "%5.4s",
35        'Tfn':   "%5.4s",
36        'Tm':    "%5.4s",
37        'Td':    "%5.4s",
38                 
39        'aTtrue':"%5.4s",
40        'aTfp':  "%5.4s",
41        'aTfn':  "%5.4s",
42        'aTm':   "%5.4s",
43        'aTd':   "%5.4s",
44                 
45        'mean':  "%5.40s",
46        'smean': "%5.40s",
47        'amean':  "%5.40s",
48        'samean': "%5.40s"}
49       
50        def file_gettruth(self,input):
51                from os.path import isfile
52                ftrulist = []
53                # search for match as filetask.input,".txt"
54                ftru = '.'.join(input.split('.')[:-1])
55                ftru = '.'.join((ftru,'txt'))
56                if isfile(ftru):
57                        ftrulist.append(ftru)
58                else:
59                        # search for matches for filetask.input in the list of results
60                        for i in range(len(self.reslist)):
61                                check = '.'.join(self.reslist[i].split('.')[:-1])
62                                check = '_'.join(check.split('_')[:-1])
63                                if check == '.'.join(input.split('.')[:-1]):
64                                        ftrulist.append(self.reslist[i])
65                return ftrulist
66
67        def file_exec(self,input,output):
68                filetask = self.task(input,params=self.params)
69                computed_data = filetask.compute_all()
70                ftrulist = self.file_gettruth(filetask.input)
71                for i in ftrulist:
72                        #print i
73                        filetask.eval(computed_data,i,mode='rocloc',vmode='')
74                        for i in self.valuenames:
75                                self.v[i] += filetask.v[i]
76                        for i in filetask.v['l']:
77                                self.v['l'].append(i)
78                        for i in filetask.v['labs']:
79                                self.v['labs'].append(i)
80       
81        def dir_exec(self):
82                """ run file_exec on every input file """
83                self.l , self.labs = [], []
84                self.v = {}
85                for i in self.valuenames:
86                        self.v[i] = 0.
87                for i in self.valuelists:
88                        self.v[i] = []
89                self.v['thres'] = self.params.threshold
90                act_on_files(self.file_exec,self.sndlist,self.reslist, \
91                        suffix='',filter=sndfile_filter)
92
93        def dir_eval(self):
94                totaltrue = self.v['expc']-self.v['bad']-self.v['Td']
95                totalfp = self.v['bad']+self.v['Td']
96                totalfn = self.v['missed']+self.v['Tm']
97                self.P = 100*float(totaltrue)/max(totaltrue + totalfp,1)
98                self.R = 100*float(totaltrue)/max(totaltrue + totalfn,1)
99                if self.R < 0: self.R = 0
100                self.F = 2.* self.P*self.R / max(float(self.P+self.R),1)
101               
102                N = float(len(self.reslist))
103
104                self.v['mode']      = self.params.onsetmode
105                self.v['thres']     = "%2.3f" % self.params.threshold
106                self.v['dist']      = "%2.3f" % self.F
107                self.v['prec']      = "%2.3f" % self.P
108                self.v['recl']      = "%2.3f" % self.R
109                self.v['Ttrue']     = totaltrue
110                self.v['Tfp']       = totalfp
111                self.v['Tfn']       = totalfn
112                self.v['aTtrue']    = totaltrue/N
113                self.v['aTfp']      = totalfp/N
114                self.v['aTfn']      = totalfn/N
115                self.v['aTm']       = self.v['Tm']/N
116                self.v['aTd']       = self.v['Td']/N
117                self.v['mean']      = mmean(self.v['l'])
118                self.v['smean']     = stdev(self.v['l'])
119                self.v['amean']     = mmean(self.v['labs'])
120                self.v['samean']    = stdev(self.v['labs'])
121
122        def run_bench(self,modes=['dual'],thresholds=[0.5]):
123                self.modes = modes
124                self.thresholds = thresholds
125
126                self.pretty_titles()
127                for mode in self.modes:
128                        self.params.onsetmode = mode
129                        for threshold in self.thresholds:
130                                self.params.threshold = threshold
131                                self.dir_exec()
132                                self.dir_eval()
133                                self.pretty_print()
134                                #print self.v
135
136        def pretty_print(self,sep='|'):
137                for i in self.printnames:
138                        print self.formats[i] % self.v[i], sep,
139                print
140
141        def pretty_titles(self,sep='|'):
142                for i in self.printnames:
143                        print self.formats[i] % i, sep,
144                print
145
146        def auto_learn(self,modes=['dual'],thresholds=[0.1,1.5]):
147                """ simple dichotomia like algorithm to optimise threshold """
148                self.modes = modes
149                self.pretty_titles()
150                for mode in self.modes:
151                        steps = 10
152                        lesst = thresholds[0]
153                        topt = thresholds[1]
154                        self.params.onsetmode = mode
155
156                        self.params.threshold = topt
157                        self.dir_exec()
158                        self.dir_eval()
159                        self.pretty_print()
160                        topF = self.F
161
162                        self.params.threshold = lesst
163                        self.dir_exec()
164                        self.dir_eval()
165                        self.pretty_print()
166                        lessF = self.F
167
168                        for i in range(steps):
169                                self.params.threshold = ( lesst + topt ) * .5
170                                self.dir_exec()
171                                self.dir_eval()
172                                self.pretty_print()
173                                if self.F == 100.0 or self.F == topF:
174                                        print "assuming we converged, stopping"
175                                        break
176                                #elif abs(self.F - topF) < 0.01 :
177                                #       print "done converging"
178                                #       break
179                                if topF < self.F:
180                                        #lessF = topF
181                                        #lesst = topt
182                                        topF = self.F
183                                        topt = self.params.threshold
184                                elif lessF < self.F:
185                                        lessF = self.F
186                                        lesst = self.params.threshold
187                                if topt == lesst:
188                                        lesst /= 2.
189
190        def auto_learn2(self,modes=['dual'],thresholds=[0.00001,1.0]):
191                """ simple dichotomia like algorithm to optimise threshold """
192                self.modes = modes
193                self.pretty_titles([])
194                for mode in self.modes:
195                        steps = 10
196                        step = 0.4
197                        self.params.onsetmode = mode
198                        self.params.threshold = thresholds[0]
199                        cur = 0
200
201                        for i in range(steps):
202                                self.dir_exec()
203                                self.dir_eval()
204                                self.pretty_print()
205                                new = self.P
206                                if self.R == 0.0:
207                                        #print "Found maximum, highering"
208                                        step /= 2.
209                                        self.params.threshold -= step
210                                elif new == 100.0:
211                                        #print "Found maximum, highering"
212                                        step *= .99
213                                        self.params.threshold += step
214                                elif cur > new:
215                                        #print "lower"
216                                        step /= 2.
217                                        self.params.threshold -= step
218                                elif cur < new:
219                                        #print "higher"
220                                        step *= .99
221                                        self.params.threshold += step
222                                else:
223                                        print "Assuming we converged"
224                                        break
225                                cur = new
226
227
228if __name__ == "__main__":
229        import sys
230        if len(sys.argv) > 1: datapath = sys.argv[1]
231        else: print "ERR: a path is required"; sys.exit(1)
232        modes = ['complex', 'energy', 'phase', 'specdiff', 'kl', 'mkl', 'dual']
233        #modes = [ 'phase' ]
234        thresholds = [ 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
235        #thresholds = [1.5]
236
237        #datapath = "%s%s" % (DATADIR,'/onset/DB/*/')
238        respath = '/var/tmp/DB-testings'
239
240        benchonset = benchonset(datapath,respath,checkres=True,checkanno=True)
241        benchonset.params = taskparams()
242        benchonset.task = taskonset
243        benchonset.valuesdict = {}
244
245
246        try:
247                #benchonset.auto_learn2(modes=modes)
248                benchonset.run_bench(modes=modes,thresholds=thresholds)
249        except KeyboardInterrupt:
250                sys.exit(1)
Note: See TracBrowser for help on using the repository browser.