source: python/bench-onset @ af445db

feature/autosinkfeature/cnnfeature/cnn_orgfeature/constantqfeature/crepefeature/crepe_orgfeature/pitchshiftfeature/pydocstringsfeature/timestretchfix/ffmpeg5pitchshiftsamplertimestretchyinfft+
Last change on this file since af445db was af445db, checked in by Paul Brossier <piem@altern.org>, 19 years ago

add auto_learn2 method, which converges
add auto_learn2 method, which converges

  • Property mode set to 100755
File size: 4.6 KB
Line 
1#! /usr/bin/python
2
3from aubio.bench.node import *
4from aubio.tasks import *
5
6class benchonset(bench):
7       
8        def dir_eval(self):
9                self.P = 100*float(self.expc-self.missed-self.merged)/(self.expc-self.missed-self.merged + self.bad+self.doubled)
10                self.R = 100*float(self.expc-self.missed-self.merged)/(self.expc-self.missed-self.merged + self.missed+self.merged)
11                if self.R < 0: self.R = 0
12                self.F = 2* self.P*self.R / (self.P+self.R)
13
14                self.values = [self.params.onsetmode,
15                "%2.3f" % self.params.threshold,
16                self.orig,
17                self.expc,
18                self.missed,
19                self.merged,
20                self.bad,
21                self.doubled,
22                (self.orig-self.missed-self.merged),
23                "%2.3f" % (100*float(self.orig-self.missed-self.merged)/(self.orig)),
24                "%2.3f" % (100*float(self.bad+self.doubled)/(self.orig)),
25                "%2.3f" % (100*float(self.orig-self.missed)/(self.orig)),
26                "%2.3f" % (100*float(self.bad)/(self.orig)),
27                "%2.3f" % self.P,
28                "%2.3f" % self.R,
29                "%2.3f" % self.F  ]
30
31        def file_exec(self,input,output):
32                filetask = self.task(input,params=self.params)
33                computed_data = filetask.compute_all()
34                results = filetask.eval(computed_data)
35                self.orig    += filetask.orig
36                self.missed  += filetask.missed
37                self.merged  += filetask.merged
38                self.expc    += filetask.expc
39                self.bad     += filetask.bad
40                self.doubled += filetask.doubled
41
42
43        def run_bench(self,modes=['dual'],thresholds=[0.5]):
44                self.modes = modes
45                self.thresholds = thresholds
46
47                self.pretty_print(self.titles)
48                for mode in self.modes:
49                        self.params.onsetmode = mode
50                        for threshold in self.thresholds:
51                                self.params.threshold = threshold
52                                self.dir_exec()
53                                self.dir_eval()
54                                self.pretty_print(self.values)
55
56        def auto_learn(self,modes=['dual'],thresholds=[0.1,1.5]):
57                """ simple dichotomia like algorithm to optimise threshold """
58                self.modes = modes
59                self.pretty_print(self.titles)
60                for mode in self.modes:
61                        steps = 10
62                        lesst = thresholds[0]
63                        topt = thresholds[1]
64                        self.params.onsetmode = mode
65
66                        self.params.threshold = topt
67                        self.dir_exec()
68                        self.dir_eval()
69                        self.pretty_print(self.values)
70                        topF = self.F
71
72                        self.params.threshold = lesst
73                        self.dir_exec()
74                        self.dir_eval()
75                        self.pretty_print(self.values)
76                        lessF = self.F
77
78                        for i in range(steps):
79                                self.params.threshold = ( lesst + topt ) * .5
80                                self.dir_exec()
81                                self.dir_eval()
82                                self.pretty_print(self.values)
83                                if self.F == 100.0 or self.F == topF:
84                                        print "assuming we converged, stopping"
85                                        break
86                                #elif abs(self.F - topF) < 0.01 :
87                                #       print "done converging"
88                                #       break
89                                if topF < self.F:
90                                        #lessF = topF
91                                        #lesst = topt
92                                        topF = self.F
93                                        topt = self.params.threshold
94                                elif lessF < self.F:
95                                        lessF = self.F
96                                        lesst = self.params.threshold
97                                if topt == lesst:
98                                        lesst /= 2.
99
100        def auto_learn2(self,modes=['dual'],thresholds=[0.1,1.0]):
101                """ simple dichotomia like algorithm to optimise threshold """
102                self.modes = modes
103                self.pretty_print(self.titles)
104                for mode in self.modes:
105                        steps = 10
106                        step = thresholds[1]
107                        curt = thresholds[0]
108                        self.params.onsetmode = mode
109
110                        self.params.threshold = curt
111                        self.dir_exec()
112                        self.dir_eval()
113                        self.pretty_print(self.values)
114                        curexp = self.expc
115
116                        for i in range(steps):
117                                if curexp < self.orig:
118                                        #print "we found at most less onsets than annotated"
119                                        self.params.threshold -= step
120                                        step /= 2
121                                elif curexp > self.orig:
122                                        #print "we found more onsets than annotated"
123                                        self.params.threshold += step
124                                        step /= 2
125                                self.dir_exec()
126                                self.dir_eval()
127                                curexp = self.expc
128                                self.pretty_print(self.values)
129                                if self.orig == 100.0 or self.orig == self.expc:
130                                        print "assuming we converged, stopping"
131                                        break
132
133if __name__ == "__main__":
134        import sys
135        if len(sys.argv) > 1: datapath = sys.argv[1]
136        else: print "ERR: a path is required"; sys.exit(1)
137        modes = ['complex', 'energy', 'phase', 'specdiff', 'kl', 'mkl', 'dual']
138        #modes = [ 'complex' ]
139        thresholds = [ 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5]
140        #thresholds = [1.5]
141
142        #datapath = "%s%s" % (DATADIR,'/onset/DB/*/')
143        respath = '/var/tmp/DB-testings'
144
145        benchonset = benchonset(datapath,respath,checkres=True,checkanno=True)
146        benchonset.params = taskparams()
147        benchonset.task = taskonset
148
149        benchonset.titles = [ 'mode', 'thres', 'orig', 'expc', 'missd', 'mergd',
150        'bad', 'doubl', 'corrt', 'GD', 'FP', 'GD-merged', 'FP-pruned',
151        'prec', 'recl', 'dist' ]
152        benchonset.formats = ["%12s" , "| %6s", "| %6s", "| %6s", "| %6s", "| %6s",
153        "| %6s", "| %6s", "| %6s", "| %8s", "| %8s", "| %8s", "| %8s",
154        "| %6s", "| %6s", "| %6s"]
155
156        try:
157                benchonset.auto_learn2(modes=modes)
158                #benchonset.run_bench(modes=modes)
159        except KeyboardInterrupt:
160                sys.exit(1)
Note: See TracBrowser for help on using the repository browser.