1 | #! /usr/bin/env python |
---|
2 | # -*- coding: utf8 -*- |
---|
3 | |
---|
4 | """ Pure python implementation of the sum of squared difference |
---|
5 | |
---|
6 | sqd_yin: original sum of squared difference [0] |
---|
7 | d_t(tau) = x ⊗ kernel |
---|
8 | sqd_yinfast: sum of squared diff using complex domain [0] |
---|
9 | sqd_yintapered: tapered squared diff [1] |
---|
10 | sqd_yinfft: modified squared diff using complex domain [1] |
---|
11 | |
---|
12 | [0]:http://audition.ens.fr/adc/pdf/2002_JASA_YIN.pdf |
---|
13 | [1]:https://aubio.org/phd/ |
---|
14 | """ |
---|
15 | |
---|
16 | import sys |
---|
17 | import numpy as np |
---|
18 | import matplotlib.pyplot as plt |
---|
19 | |
---|
def sqd_yin(samples):
    """Return the original YIN sum of squared differences of `samples`.

    With ``W = len(samples) // 2``, entry ``tau`` (for ``tau`` in
    ``1 .. W-1``) is the sum over ``j = 0 .. W-1`` of
    ``(samples[j] - samples[j + tau]) ** 2``.  Entry 0 is left at zero.

    Plain double loop, quadratic in the window length (slow).
    """
    half = len(samples) // 2
    diffs = np.zeros(half)
    for lag in range(1, half):
        # Same j-ascending accumulation per lag as a sample-major loop.
        for start in range(half):
            delta = samples[start] - samples[start + lag]
            diffs[lag] += delta ** 2
    return diffs
---|
31 | |
---|
def sqd_yinfast(samples):
    """Compute the sum of squared differences using FFT-based correlation.

    Relies on the identity

        yin_t(tau) = r_t(0) + r_(t+tau)(0) - 2 * r_t(tau)

    where ``r_(t+tau)(0)`` is the energy of the ``W``-sample window that
    starts at ``tau`` and ``r_t(tau)`` is the correlation of the first
    ``W`` samples with the window shifted by ``tau``.

    Args:
        samples: 1-D real sequence (array or list) of length ``B``.

    Returns:
        numpy array of length ``W = B // 2`` holding the squared
        difference for lags ``0 .. W-1``; empty when ``B < 2``.  Values
        match the brute-force sum up to floating-point rounding.
    """
    samples = np.asarray(samples, dtype=float)
    B = len(samples)
    W = B // 2
    if W == 0:
        # Nothing to compare; also avoids np.fft.fft on an empty array.
        return np.zeros(0)

    # r_(t+tau)(0): window energies from one running sum instead of W
    # separate slice sums.
    energy = np.concatenate(([0.], np.cumsum(samples ** 2)))
    sqdiff = energy[W:2 * W] - energy[:W]
    # add r_t(0)
    sqdiff += sqdiff[0]

    # r_t(tau): circular convolution of the signal with the reversed first
    # half; the lags of interest land at output indices W .. 2W-1.
    kernel = np.zeros(B)
    kernel[1:W + 1] = samples[W - 1::-1]
    product = np.fft.fft(samples) * np.fft.fft(kernel)
    # Slice to exactly W values: for odd B the tail has W + 1 entries.
    r_t_tau = np.fft.ifft(product).real[W:2 * W]

    return sqdiff - 2 * r_t_tau
---|
56 | |
---|
def sqd_yintapered(samples):
    """Return the tapered sum of squared differences of `samples`.

    With ``W = len(samples) // 2``, entry ``tau`` (for ``tau`` in
    ``1 .. W-1``) sums ``(samples[j] - samples[j + tau]) ** 2`` over
    ``j = 0 .. W-tau-1`` only, so larger lags use fewer terms.  Entry 0
    is left at zero.

    Plain double loop, quadratic in the window length (slow).
    """
    half = len(samples) // 2
    tapered = np.zeros(half)
    for lag in range(1, half):
        overlap = half - lag
        # Pair each of the first `overlap` samples with the one `lag` later.
        for early, late in zip(samples[:overlap], samples[lag:half]):
            tapered[lag] += (early - late) ** 2
    return tapered
---|
68 | |
---|
def sqd_yinfft(samples):
    """Compute the yinfft modified sum of squared differences.

    The input is multiplied by the window ``.5 * (1 - cos(2*pi*n/B))``,
    its squared spectral magnitude is built symmetrically around bin
    ``W = B // 2``, and that magnitude is transformed a second time.  The
    result has length ``W``, with entry 0 set to zero, and is divided by
    ``B``.

    Original author's notes: very fast, improved performance in
    transients.

    FIXME: biased.
    """
    B = len(samples)
    half = B // 2

    window = .5 * (1. - np.cos(2. * np.pi * np.arange(B) / B))
    spectrum = np.fft.fft(window * samples)

    # Squared magnitude; bins 0 and `half` use the real part only.
    power = np.zeros(B)
    power[0] = spectrum[0].real**2
    for k in range(1, half):
        coeff = spectrum[k]
        power[k] = coeff.real**2 + coeff.imag**2
        power[B - k] = power[k]  # mirror onto the upper half
    power[half] = spectrum[half].real**2

    total = 2. * power[:half + 1].sum()
    transformed = np.fft.fft(power)

    result = np.zeros(half)
    result[0] = 0
    result[1:] = total - transformed.real[1:half]
    return result / B
---|
94 | |
---|
def cumdiff(yin):
    """Compute the cumulative mean normalized difference.

    Entry 0 of the result is 1.  For ``tau >= 1`` the result is
    ``yin[tau] * tau / sum(yin[1 .. tau])``, or 1 when that running sum
    is zero.

    The computation is done on a float copy: the caller's array is left
    untouched, so it can still be plotted or reused afterwards, and
    integer input is not truncated.

    Args:
        yin: 1-D sequence of squared-difference values.

    Returns:
        A new float numpy array with the same length as `yin`; empty
        input yields an empty array.
    """
    normalized = np.array(yin, dtype=float)  # always a copy
    if normalized.size == 0:
        return normalized

    normalized[0] = 1.
    tail = normalized[1:]  # view onto lags 1 .. W-1
    running = np.cumsum(tail)  # computed before any entry is rescaled
    lags = np.arange(1, normalized.size)

    nonzero = running != 0
    tail[nonzero] *= lags[nonzero] / running[nonzero]
    tail[~nonzero] = 1.
    return normalized
---|
107 | |
---|
def compute_all(x):
    """Run every squared-difference variant on `x` and print its timing.

    The variants run in the order sqd_yin, sqd_yinfast, sqd_yintapered,
    sqd_yinfft.  After each one, the wall-clock time elapsed since the
    previous measurement is printed in milliseconds.

    Returns the four results as a tuple, in the same order.
    """
    import time

    stages = (
        ("yin took %.2fms", sqd_yin),
        ("yinfast took: %.2fms", sqd_yinfast),
        ("yintapered took: %.2fms", sqd_yintapered),
        ("yinfft took: %.2fms", sqd_yinfft),
    )

    outputs = []
    previous = time.time()
    for message, func in stages:
        outputs.append(func(x))
        current = time.time()
        print(message % ((current - previous) * 1000.))
        previous = current
    return tuple(outputs)
---|
129 | |
---|
def plot_all(yin, yinfast, yintapered, yinfft):
    """Draw a 2x2 comparison figure of the four squared-difference curves.

    Left column: the curves as given (yin and yintapered on top, yinfast
    and yinfft below).  Right column: the same pairs after passing each
    curve through cumdiff().  The top row has its lower y-limit set to 0.
    """
    fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey='col')

    def draw(ax, curves, transform, clamp):
        # Each curve is transformed right before it is plotted.
        for data, fmt, label in curves:
            ax.plot(transform(data), *fmt, label=label)
        if clamp:
            ax.set_ylim(bottom=0)
        ax.legend()

    def unchanged(curve):
        return curve

    # Raw curves first, so they are plotted before cumdiff() is applied.
    draw(axes[0, 0], ((yin, (), 'yin'), (yintapered, (), 'yintapered')),
         unchanged, True)
    draw(axes[1, 0], ((yinfast, ('-',), 'yinfast'), (yinfft, (), 'yinfft')),
         unchanged, False)

    draw(axes[0, 1], ((yin, (), 'yin'), (yintapered, (), 'yin tapered')),
         cumdiff, True)
    draw(axes[1, 1], ((yinfast, ('-',), 'yinfast'), (yinfft, (), 'yinfft')),
         cumdiff, False)

    fig.tight_layout()
---|
150 | |
---|
def main():
    """Compare the YIN squared-difference variants on cosine test tones.

    Test frequencies (Hz) come from the command-line arguments when any
    are given, otherwise a built-in list is used.  For each frequency a
    4096-sample cosine at 44100 Hz is generated, every variant is run
    and timed via compute_all(), a difference plot is shown, and a 2x2
    comparison figure is prepared with plot_all().  The comparison
    figures still pending are shown at the end.
    """
    testfreqs = [441., 800., 10000., 40.]
    if len(sys.argv) > 1:
        testfreqs = [float(arg) for arg in sys.argv[1:]]

    samplerate = 44100.
    win_s = 4096
    n_times = 1  # raise (e.g. to 100) to repeat the timing runs

    # Toggles for the optional difference plots.
    plot_yin_vs_yinfast = False
    plot_tapered_vs_yinfft = True

    for f in testfreqs:
        print("Comparing yin implementations for sine wave at %.fHz" % f)

        x = np.cos(2. * np.pi * np.arange(win_s) * f / samplerate)

        for _ in range(n_times):
            yin, yinfast, yintapered, yinfft = compute_all(x)

        if plot_yin_vs_yinfast:
            # New figure, so the curve does not land on the comparison
            # figure left current by the previous frequency.
            plt.figure()
            plt.plot(yin - yinfast)
            plt.tight_layout()
            plt.show()
        if plot_tapered_vs_yinfft:
            plt.figure()
            plt.plot(yintapered - yinfft)
            plt.tight_layout()
            plt.show()

        plot_all(yin, yinfast, yintapered, yinfft)
    plt.show()


# Run the demo only when executed as a script, so importing this module to
# reuse the functions does not start the benchmark or open blocking windows.
if __name__ == "__main__":
    main()
---|