Ignore:
Timestamp:
Feb 17, 2006, 5:07:36 PM (19 years ago)
Author:
Paul Brossier <piem@altern.org>
Branches:
feature/autosink, feature/cnn, feature/cnn_org, feature/constantq, feature/crepe, feature/crepe_org, feature/pitchshift, feature/pydocstrings, feature/timestretch, fix/ffmpeg5, master, pitchshift, sampler, timestretch, yinfft+
Children:
c912c67
Parents:
677b267
Message:

update to new bench onset
update to new bench onset

File:
1 edited

Legend:

Unmodified
Added
Removed
  • python/test/bench/onset/bench-onset

    r677b267 re968939  
    44from aubio.tasks import *
    55
     6
     7
     8
     9def mmean(l):
     10        return sum(l)/float(len(l))
     11
     12def stdev(l):
     13        smean = 0
     14        lmean = mmean(l)
     15        for i in l:
     16                smean += (i-lmean)**2
     17        smean *= 1. / len(l)
     18        return smean**.5
     19
    620class benchonset(bench):
     21
     22        valuenames = ['orig','missed','Tm','expc','bad','Td']
     23        valuelists = ['l','labs']
     24        printnames = [ 'mode', 'thres', 'dist', 'prec', 'recl', 'Ttrue', 'Tfp',  'Tfn',  'Tm',   'Td',
     25                'aTtrue', 'aTfp', 'aTfn', 'aTm',  'aTd',  'mean', 'smean',  'amean', 'samean']
     26
     27        formats = {'mode': "%12s" ,
     28        'thres': "%5.4s",
     29        'dist':  "%5.4s",
     30        'prec':  "%5.4s",
     31        'recl':  "%5.4s",
     32                 
     33        'Ttrue': "%5.4s",
     34        'Tfp':   "%5.4s",
     35        'Tfn':   "%5.4s",
     36        'Tm':    "%5.4s",
     37        'Td':    "%5.4s",
     38                 
     39        'aTtrue':"%5.4s",
     40        'aTfp':  "%5.4s",
     41        'aTfn':  "%5.4s",
     42        'aTm':   "%5.4s",
     43        'aTd':   "%5.4s",
     44                 
     45        'mean':  "%5.40s",
     46        'smean': "%5.40s",
     47        'amean':  "%5.40s",
     48        'samean': "%5.40s"}
    749       
    8         def dir_eval(self):
    9                 self.P = 100*float(self.expc-self.missed-self.merged)/(self.expc-self.missed-self.merged + self.bad+self.doubled)
    10                 self.R = 100*float(self.expc-self.missed-self.merged)/(self.expc-self.missed-self.merged + self.missed+self.merged)
    11                 if self.R < 0: self.R = 0
    12                 self.F = 2* self.P*self.R / (self.P+self.R)
    13 
    14                 self.values = [self.params.onsetmode,
    15                 "%2.3f" % self.params.threshold,
    16                 self.orig,
    17                 self.expc,
    18                 self.missed,
    19                 self.merged,
    20                 self.bad,
    21                 self.doubled,
    22                 (self.orig-self.missed-self.merged),
    23                 "%2.3f" % (100*float(self.orig-self.missed-self.merged)/(self.orig)),
    24                 "%2.3f" % (100*float(self.bad+self.doubled)/(self.orig)),
    25                 "%2.3f" % (100*float(self.orig-self.missed)/(self.orig)),
    26                 "%2.3f" % (100*float(self.bad)/(self.orig)),
    27                 "%2.3f" % self.P,
    28                 "%2.3f" % self.R,
    29                 "%2.3f" % self.F  ]
     50        def file_gettruth(self,input):
     51                from os.path import isfile
     52                ftrulist = []
     53                # search for match as filetask.input,".txt"
     54                ftru = '.'.join(input.split('.')[:-1])
     55                ftru = '.'.join((ftru,'txt'))
     56                if isfile(ftru):
     57                        ftrulist.append(ftru)
     58                else:
     59                        # search for matches for filetask.input in the list of results
     60                        for i in range(len(self.reslist)):
     61                                check = '.'.join(self.reslist[i].split('.')[:-1])
     62                                check = '_'.join(check.split('_')[:-1])
     63                                if check == '.'.join(input.split('.')[:-1]):
     64                                        ftrulist.append(self.reslist[i])
     65                return ftrulist
    3066
    3167        def file_exec(self,input,output):
    3268                filetask = self.task(input,params=self.params)
    3369                computed_data = filetask.compute_all()
    34                 results = filetask.eval(computed_data)
    35                 self.orig    += filetask.orig
    36                 self.missed  += filetask.missed
    37                 self.merged  += filetask.merged
    38                 self.expc    += filetask.expc
    39                 self.bad     += filetask.bad
    40                 self.doubled += filetask.doubled
    41 
     70                ftrulist = self.file_gettruth(filetask.input)
     71                for i in ftrulist:
     72                        #print i
     73                        filetask.eval(computed_data,i,mode='rocloc',vmode='')
     74                        for i in self.valuenames:
     75                                self.v[i] += filetask.v[i]
     76                        for i in filetask.v['l']:
     77                                self.v['l'].append(i)
     78                        for i in filetask.v['labs']:
     79                                self.v['labs'].append(i)
     80       
     81        def dir_exec(self):
     82                """ run file_exec on every input file """
     83                self.l , self.labs = [], []
     84                self.v = {}
     85                for i in self.valuenames:
     86                        self.v[i] = 0.
     87                for i in self.valuelists:
     88                        self.v[i] = []
     89                self.v['thres'] = self.params.threshold
     90                act_on_files(self.file_exec,self.sndlist,self.reslist, \
     91                        suffix='',filter=sndfile_filter)
     92
     93        def dir_eval(self):
     94                totaltrue = self.v['expc']-self.v['bad']-self.v['Td']
     95                totalfp = self.v['bad']+self.v['Td']
     96                totalfn = self.v['missed']+self.v['Tm']
     97                self.P = 100*float(totaltrue)/max(totaltrue + totalfp,1)
     98                self.R = 100*float(totaltrue)/max(totaltrue + totalfn,1)
     99                if self.R < 0: self.R = 0
     100                self.F = 2.* self.P*self.R / max(float(self.P+self.R),1)
     101               
     102                N = float(len(self.reslist))
     103
     104                self.v['mode']      = self.params.onsetmode
     105                self.v['thres']     = "%2.3f" % self.params.threshold
     106                self.v['dist']      = "%2.3f" % self.F
     107                self.v['prec']      = "%2.3f" % self.P
     108                self.v['recl']      = "%2.3f" % self.R
     109                self.v['Ttrue']     = totaltrue
     110                self.v['Tfp']       = totalfp
     111                self.v['Tfn']       = totalfn
     112                self.v['aTtrue']    = totaltrue/N
     113                self.v['aTfp']      = totalfp/N
     114                self.v['aTfn']      = totalfn/N
     115                self.v['aTm']       = self.v['Tm']/N
     116                self.v['aTd']       = self.v['Td']/N
     117                self.v['mean']      = mmean(self.v['l'])
     118                self.v['smean']     = stdev(self.v['l'])
     119                self.v['amean']     = mmean(self.v['labs'])
     120                self.v['samean']    = stdev(self.v['labs'])
    42121
    43122        def run_bench(self,modes=['dual'],thresholds=[0.5]):
     
    45124                self.thresholds = thresholds
    46125
    47                 self.pretty_print(self.titles)
     126                self.pretty_titles()
    48127                for mode in self.modes:
    49128                        self.params.onsetmode = mode
     
    52131                                self.dir_exec()
    53132                                self.dir_eval()
    54                                 self.pretty_print(self.values)
     133                                self.pretty_print()
     134                                #print self.v
     135
     136        def pretty_print(self,sep='|'):
     137                for i in self.printnames:
     138                        print self.formats[i] % self.v[i], sep,
     139                print
     140
     141        def pretty_titles(self,sep='|'):
     142                for i in self.printnames:
     143                        print self.formats[i] % i, sep,
     144                print
    55145
    56146        def auto_learn(self,modes=['dual'],thresholds=[0.1,1.5]):
    57147                """ simple dichotomia like algorithm to optimise threshold """
    58148                self.modes = modes
    59                 self.pretty_print(self.titles)
     149                self.pretty_titles()
    60150                for mode in self.modes:
    61151                        steps = 10
     
    67157                        self.dir_exec()
    68158                        self.dir_eval()
    69                         self.pretty_print(self.values)
     159                        self.pretty_print()
    70160                        topF = self.F
    71161
     
    73163                        self.dir_exec()
    74164                        self.dir_eval()
    75                         self.pretty_print(self.values)
     165                        self.pretty_print()
    76166                        lessF = self.F
    77167
     
    80170                                self.dir_exec()
    81171                                self.dir_eval()
    82                                 self.pretty_print(self.values)
     172                                self.pretty_print()
    83173                                if self.F == 100.0 or self.F == topF:
    84174                                        print "assuming we converged, stopping"
     
    98188                                        lesst /= 2.
    99189
    100         def auto_learn2(self,modes=['dual'],thresholds=[0.1,1.0]):
     190        def auto_learn2(self,modes=['dual'],thresholds=[0.00001,1.0]):
    101191                """ simple dichotomia like algorithm to optimise threshold """
    102192                self.modes = modes
    103                 self.pretty_print(self.titles)
     193                self.pretty_titles([])
    104194                for mode in self.modes:
    105195                        steps = 10
    106                         step = thresholds[1]
    107                         curt = thresholds[0]
     196                        step = 0.4
    108197                        self.params.onsetmode = mode
    109 
    110                         self.params.threshold = curt
    111                         self.dir_exec()
    112                         self.dir_eval()
    113                         self.pretty_print(self.values)
    114                         curexp = self.expc
     198                        self.params.threshold = thresholds[0]
     199                        cur = 0
    115200
    116201                        for i in range(steps):
    117                                 if curexp < self.orig:
    118                                         #print "we found at most less onsets than annotated"
    119                                         self.params.threshold -= step
    120                                         step /= 2
    121                                 elif curexp > self.orig:
    122                                         #print "we found more onsets than annotated"
    123                                         self.params.threshold += step
    124                                         step /= 2
    125202                                self.dir_exec()
    126203                                self.dir_eval()
    127                                 curexp = self.expc
    128                                 self.pretty_print(self.values)
    129                                 if self.orig == 100.0 or self.orig == self.expc:
    130                                         print "assuming we converged, stopping"
     204                                self.pretty_print()
     205                                new = self.P
     206                                if self.R == 0.0:
     207                                        #print "Found maximum, highering"
     208                                        step /= 2.
     209                                        self.params.threshold -= step
     210                                elif new == 100.0:
     211                                        #print "Found maximum, highering"
     212                                        step *= .99
     213                                        self.params.threshold += step
     214                                elif cur > new:
     215                                        #print "lower"
     216                                        step /= 2.
     217                                        self.params.threshold -= step
     218                                elif cur < new:
     219                                        #print "higher"
     220                                        step *= .99
     221                                        self.params.threshold += step
     222                                else:
     223                                        print "Assuming we converged"
    131224                                        break
     225                                cur = new
     226
    132227
    133228if __name__ == "__main__":
     
    136231        else: print "ERR: a path is required"; sys.exit(1)
    137232        modes = ['complex', 'energy', 'phase', 'specdiff', 'kl', 'mkl', 'dual']
    138         #modes = [ 'complex' ]
    139         thresholds = [ 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5]
     233        #modes = [ 'phase' ]
     234        thresholds = [ 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
    140235        #thresholds = [1.5]
    141236
     
    146241        benchonset.params = taskparams()
    147242        benchonset.task = taskonset
    148 
    149         benchonset.titles = [ 'mode', 'thres', 'orig', 'expc', 'missd', 'mergd',
    150         'bad', 'doubl', 'corrt', 'GD', 'FP', 'GD-merged', 'FP-pruned',
    151         'prec', 'recl', 'dist' ]
    152         benchonset.formats = ["%12s" , "| %6s", "| %6s", "| %6s", "| %6s", "| %6s",
    153         "| %6s", "| %6s", "| %6s", "| %8s", "| %8s", "| %8s", "| %8s",
    154         "| %6s", "| %6s", "| %6s"]
     243        benchonset.valuesdict = {}
     244
    155245
    156246        try:
    157                 benchonset.auto_learn2(modes=modes)
    158                 #benchonset.run_bench(modes=modes)
     247                #benchonset.auto_learn2(modes=modes)
     248                benchonset.run_bench(modes=modes,thresholds=thresholds)
    159249        except KeyboardInterrupt:
    160250                sys.exit(1)
Note: See TracChangeset for help on using the changeset viewer.