source: python/tests/eval_pitch @ 633400d

feature/cnn feature/crepe feature/pitchshift feature/timestretch fix/ffmpeg5
Last change on this file since 633400d was a1bf01d, checked in by Paul Brossier <piem@piem.org>, 8 years ago

python/tests/: use local import, create __init__.py

  • Property mode set to 100755
File size: 5.5 KB
Line 
1#! /usr/bin/env python
2
3"""
4Script to evaluate pitch algorithms against TONAS database.
5
6See http://mtg.upf.edu/download/datasets/tonas/
7
8Example run:
9
10    $ ./eval_pitch /path/to/TONAS/*/*.wav
11    OK:  94.74% vx r:  96.87% vx f:  15.83% f0:  96.02% %12:   0.50% /path/to/TONAS/Deblas/01-D_AMairena.wav
12    OK:  89.89% vx r:  93.21% vx f:  13.81% f0:  90.74% %12:   1.51% /path/to/TONAS/Deblas/02-D_ChanoLobato.wav
13    OK:  96.02% vx r:  96.73% vx f:  10.91% f0:  96.42% %12:   0.00% /path/to/TONAS/Deblas/03-D_Chocolate.wav
14    [...]
15    OK:  82.35% vx r:  95.52% vx f:  67.09% f0:  89.80% %12:   0.95% /path/to/TONAS/Martinetes2/80-M2_Rancapinos.wav
16    OK:  61.97% vx r:  85.71% vx f:  22.03% f0:  55.63% %12:   8.57% /path/to/TONAS/Martinetes2/81-M2_SDonday.wav
17    OK:  75.26% vx r:  91.63% vx f:  27.27% f0:  75.99% %12:   5.05% /path/to/TONAS/Martinetes2/82-M2_TiaAnicalaPiriniaca.wav
18    OK:  82.77% vx r:  92.74% vx f:  38.27% f0:  87.33% %12:   1.67% 69 files, total_length: 1177.69s, total runtime: 25.91s
19
20
21"""
22
23import sys
24import time
25import os.path
26import numpy
27from .utils import array_from_text_file, array_from_yaml_file
28from aubio import source, pitch, freqtomidi
29
# wall-clock reference for the "total runtime" figure of the summary line
start = time.time()

# maximum distance, in semitones (MIDI units), between estimated and
# reference pitch for a frame to count as correct
freq_tol = .50 # more or less half a tone

# pitch detection methods this script knows about; only `method` is evaluated
methods = ["default", "yinfft", "mcomb", "yin", "fcomb", "schmitt", "specacf"]
method = methods[0]

# analysis settings; some methods override the defaults just below
downsample = 1      # divisor applied to the sample rate, hop and window sizes
tolerance =  0.35   # passed to pitch.set_tolerance()
silence = -40.      # passed to pitch.set_silence() (presumably dB - confirm)
skip = 1            # leading estimates dropped before aligning with the reference
if method in ["yinfft", "default"]:
    downsample = 1
    tolerance = 0.45
elif method == "mcomb":
    downsample = 4
elif method == "yin":
    downsample = 4
    tolerance = 0.2

# Floor division keeps these integers on Python 3 too (on Python 2, `/`
# between ints already floors): they are used as the source sample rate and
# as the hop and window sizes, which must be whole numbers of samples.
samplerate = 44100 // downsample
hop_s = 512 // downsample
win_s = 2048 // downsample
53
def get_pitches(filename, samplerate=samplerate, win_s=win_s, hop_s=hop_s):
    """Run the configured aubio pitch detector over one audio file.

    filename   -- path of the audio file to analyse
    samplerate -- sample rate requested from the aubio source; the rate the
                  source actually reports is used for the detector and for
                  the timestamps
    win_s      -- analysis window size, in samples
    hop_s      -- hop size, in samples

    The detector is built from the module-level `method`, `tolerance` and
    `silence` settings.

    Returns a numpy array with one [time, pitch] row per hop. `time` is the
    position, in seconds, of the first sample of that hop. `pitch` is the
    estimate in the "freq" unit (Hz).
    """
    s = source(filename, samplerate, hop_s)
    try:
        samplerate = s.samplerate

        p = pitch(method, win_s, hop_s, samplerate)
        p.set_unit("freq")
        p.set_tolerance(tolerance)
        p.set_silence(silence)

        # rows of [time in seconds, pitch estimate]
        pitches = []

        # total number of samples read so far
        total_frames = 0
        while True:
            samples, read = s()
            new_pitch = p(samples)[0]
            pitches.append([total_frames / float(samplerate), new_pitch])
            total_frames += read
            # a short read means the end of the file was reached
            if read < hop_s:
                break
    finally:
        # release the file handle now rather than when the object is
        # collected; the script opens one source per evaluated file
        s.close()
    return numpy.array(pitches)
75
def _ground_truth_path(audio_path):
    """Return the path of the TONAS `.f0.Corrected` file for *audio_path*.

    Only the file extension is replaced. A '.wav' appearing elsewhere in the
    path (e.g. in a directory name) is left untouched. A path not ending in
    '.wav' never maps back onto itself, so an audio file cannot be mistaken
    for its own annotation.
    """
    return os.path.splitext(audio_path)[0] + '.f0.Corrected'


def _percent(count, population):
    """Return 100 * count / population, or NaN when *population* is 0.

    A file may have no voiced (or no unvoiced) reference frames, and the
    summary may cover no file at all. The ratio is then undefined and is
    reported as 'nan' instead of aborting with ZeroDivisionError.
    """
    if population == 0:
        return float('nan')
    return 100. * count / population


def _score_frames(truth_pitches, exper_pitches, tol):
    """Classify aligned frames of reference and estimated MIDI pitches.

    A pitch of 0 marks an unvoiced frame. Returns the counts
    (correct_sil, fp, missed, correct_f0, correct_chroma, incorrect).
    """
    correct_sil = fp = missed = correct_f0 = correct_chroma = incorrect = 0
    for a, b in zip(truth_pitches, exper_pitches):
        if a == 0:
            if b == 0:
                correct_sil += 1
            else:
                fp += 1
            continue
        if b == 0:
            missed += 1
            continue
        diff = abs(b - a)
        if diff < tol:
            correct_f0 += 1
            continue
        # Distance to the nearest whole number of octaves. Using `diff % 12`
        # alone would count 12.2 semitones as an octave error but not 11.8.
        octave_offset = diff % 12.
        if min(octave_offset, 12. - octave_offset) < tol:
            correct_chroma += 1
        else:
            incorrect += 1
    return correct_sil, fp, missed, correct_f0, correct_chroma, incorrect


def _format_scores(correct_f0, correct_sil, missed, fp, correct_chroma,
                   voiced, unvoiced, total):
    """Return the space-separated score fields of one report line.

    OK   -- frames with a correct pitch or correctly detected silence
    vx r -- voiced reference frames that were not reported unvoiced
    vx f -- unvoiced reference frames that were reported voiced
    f0   -- voiced reference frames whose pitch is within tolerance
    %12  -- voiced reference frames off by a whole number of octaves
    """
    fields = [
        ("OK: %6s%%", _percent(correct_f0 + correct_sil, total)),
        ("vx r: %6s%%", 100. - _percent(missed, voiced)),
        ("vx f: %6s%%", _percent(fp, unvoiced)),
        ("f0: %6s%%", _percent(correct_f0, voiced)),
        ("%%12: %6s%%", _percent(correct_chroma, voiced)),
    ]
    return " ".join(template % ("%.2f" % value) for template, value in fields)


# Evaluate every audio file given on the command line against its ground
# truth, print one line per file, then one summary line over all files.
# Each line is printed as a single string so the output is identical on
# Python 2 and 3.
total_correct_f0, total_correct_sil, total_missed, total_incorrect, total_fp, total_total = 0, 0, 0, 0, 0, 0
total_correct_chroma, total_voiced = 0, 0
for source_file in sys.argv[1:]:
    ground_truth_file = _ground_truth_path(source_file)
    if not os.path.isfile(ground_truth_file):
        print("ERR could not find ground_truth_file %s" % ground_truth_file)
        continue
    # keep the timestamp (column 0) and frequency (column 2) columns
    ground_truth = array_from_text_file(ground_truth_file)[:, [0, 2]]
    experiment = get_pitches(source_file)
    # check that we have the same length, more or less one frame
    assert abs(len(ground_truth) - len(experiment)) < 2, source_file
    # align experiment by skipping first results
    experiment = experiment[skip:]
    experiment[:, 0] -= experiment[0, 0]
    # trim to shortest list
    maxlen = min(len(ground_truth), len(experiment))
    experiment = experiment[:maxlen]
    ground_truth = ground_truth[:maxlen]
    # make sure we got the timing right: timestamps must agree within 1 ms
    time_error = abs(experiment[:, 0] - ground_truth[:, 0])
    assert max(time_error) < 1e-3, source_file
    truth_pitches = freqtomidi(ground_truth[:, 1])
    exper_pitches = freqtomidi(experiment[:, 1])

    total = len(truth_pitches)
    unvoiced = len(truth_pitches[truth_pitches == 0])
    voiced = total - unvoiced
    correct_sil, fp, missed, correct_f0, correct_chroma, incorrect = \
        _score_frames(truth_pitches, exper_pitches, freq_tol)
    assert correct_sil + fp + missed + correct_f0 + correct_chroma + incorrect == total
    assert unvoiced == correct_sil + fp
    assert voiced == missed + correct_f0 + correct_chroma + incorrect
    print(_format_scores(correct_f0, correct_sil, missed, fp, correct_chroma,
                         voiced, unvoiced, total) + " " + source_file)
    total_correct_sil += correct_sil
    total_correct_f0 += correct_f0
    total_correct_chroma += correct_chroma
    total_missed += missed
    total_incorrect += incorrect
    total_fp += fp
    total_voiced += voiced
    total_total += total

print(" ".join([
    _format_scores(total_correct_f0, total_correct_sil, total_missed, total_fp,
                   total_correct_chroma, total_voiced,
                   total_correct_sil + total_fp, total_total),
    "%d files," % len(sys.argv[1:]),
    "total_length: %.2fs," % ((total_total * hop_s) / float(samplerate)),
    "total runtime: %.2fs" % (time.time() - start),
]))
Note: See TracBrowser for help on using the repository browser.