source: python/demos/demo_yin_compare.py

Last change on this file was ccd0327, checked in by Paul Brossier <piem@piem.org>, 7 years ago

python/demos/demo_yin_compare.py: fix indentation

  • Property mode set to 100755
File size: 4.8 KB
Line 
#! /usr/bin/env python
# -*- coding: utf8 -*-

""" Pure python (numpy) implementations of the sum of squared difference

    sqd_yin: original sum of squared difference [0]
        d_t(tau) = sum_{j=0}^{W-1} (x_j - x_{j+tau})**2,  W = len(x) // 2
    sqd_yinfast: the same sum, computed with an FFT-based convolution [0]
    sqd_yintapered: tapered squared diff, summing only pairs with
        j + tau < W [1]
    sqd_yinfft: modified squared diff computed from the power spectrum of
        a Hann-windowed input [1]
    cumdiff: cumulative mean normalized difference [0]

Running this file as a script times and plots all of them on test sine
waves (frequencies in Hz may be given as command-line arguments).

[0]:http://audition.ens.fr/adc/pdf/2002_JASA_YIN.pdf
[1]:https://aubio.org/phd/
"""
15
16import sys
17import numpy as np
18import matplotlib.pyplot as plt
19
def sqd_yin(samples):
    """ compute original sum of squared difference

    With W = len(samples) // 2, entry tau (1 <= tau < W) accumulates
    (samples[j] - samples[j + tau])**2 over every j in [0, W); entry 0
    stays zero. Returns an array of W values.

    Brute-force computation (cost o(N**2), slow)."""
    half = len(samples) // 2
    yin = np.zeros(half)
    for lag in range(1, half):
        # pair each of the first `half` samples with the one `lag` later
        for left, right in zip(samples[:half], samples[lag:lag + half]):
            yin[lag] += (left - right) ** 2
    return yin
31
def sqd_yinfast(samples):
    """ compute approximate sum of squared difference

    Using complex convolution (fast, cost o(n*log(n)) ).

    The squared difference is expanded as
        yin_t(tau) = r_t(0) + r_(t+tau)(0) - 2 r_t(tau)
    with r_t(tau) = sum_{j=0}^{W-1} x_j x_{j+tau} and W = len(samples) // 2,
    so the result equals sqd_yin up to floating-point rounding (entry 0 is
    only approximately zero).

    samples: 1-D sequence of numbers, converted to float64 (this also
        avoids overflow when squaring integer input).
    Returns an array of W values; empty when len(samples) < 2.
    """
    x = np.asarray(samples, dtype=float)
    B = len(x)
    W = B // 2
    if W == 0:
        return np.zeros(0)
    # r_(t+tau)(0): energy of each window x[tau:tau+W], read off a running
    # sum in O(B) instead of re-summing every window (O(W**2))
    csum = np.concatenate(([0.], np.cumsum(x ** 2)))
    energy = csum[W:2 * W] - csum[:W]
    # add r_t(0), the energy of the first window
    sqdiff = energy + energy[0]
    # r_t(tau) by circular convolution with the reversed first half, placed
    # from index 1 so that output index W + tau holds lag tau
    kernel = np.zeros(B)
    kernel[1:W + 1] = x[W - 1::-1]
    conv = np.fft.ifft(np.fft.fft(x) * np.fft.fft(kernel)).real
    # keep exactly W lags: for odd B there are W + 1 values from index W on
    r_t_tau = conv[W:2 * W]
    # compute yin_t(tau)
    return sqdiff - 2 * r_t_tau
56
def sqd_yintapered(samples):
    """ compute tapered sum of squared difference

    With W = len(samples) // 2, entry tau (1 <= tau < W) accumulates
    (samples[j] - samples[j + tau])**2 only over the W - tau values of j
    for which j + tau < W, so fewer terms are summed as tau grows; entry 0
    stays zero. Returns an array of W values.

    Brute-force computation (cost o(N**2), slow)."""
    half = len(samples) // 2
    yin = np.zeros(half)
    for lag in range(1, half):
        # both members of each pair lie inside the first `half` samples
        for left, right in zip(samples[:half - lag], samples[lag:half]):
            yin[lag] += (left - right) ** 2
    return yin
68
def sqd_yinfft(samples):
    """ compute yinfft modified sum of squared differences

    Very fast, improved performance in transients.

    The input is multiplied by a periodic Hann window, its squared magnitude
    spectrum is taken, and a second FFT of that spectrum gives the
    autocorrelation term. That term is subtracted from twice the summed
    squared magnitudes of bins 0..B//2. Entry 0 is set to 0, and the result
    is divided by B = len(samples).

    Works for even and odd B: only an even-length spectrum has a Nyquist
    bin (treated as purely real). Returns an array of B // 2 values, empty
    when B < 2.

    FIXME: biased."""
    B = len(samples)
    W = B // 2
    if W == 0:
        return np.zeros(0)
    # periodic Hann window ("hanningz"): win[0] == 0
    win = .5 * (1. - np.cos(2. * np.pi * np.arange(B) / B))
    fftout = np.fft.fft(win * samples)
    # squared magnitude of the bins that have a distinct mirror image
    # (1 .. n_pos-1), then mirrored: a real signal's spectrum is symmetric
    n_pos = (B + 1) // 2
    sqrmag = np.zeros(B)
    sqrmag[0] = fftout[0].real ** 2
    sqrmag[1:n_pos] = fftout.real[1:n_pos] ** 2 + fftout.imag[1:n_pos] ** 2
    sqrmag[B - n_pos + 1:] = sqrmag[n_pos - 1:0:-1]
    if B % 2 == 0:
        # Nyquist bin, present only for even B
        sqrmag[W] = fftout[W].real ** 2
    fftout = np.fft.fft(sqrmag)
    # NOTE: doubling bins 0..W counts bin 0 (and the Nyquist bin) twice;
    # presumably part of the bias flagged above
    sqrsum = 2. * sqrmag[:W + 1].sum()
    yin = np.zeros(W)
    yin[1:] = sqrsum - fftout.real[1:W]
    return yin / B
94
def cumdiff(yin):
    """ compute the cumulative mean normalized difference

    out[0] = 1 and, for tau >= 1,
        out[tau] = yin[tau] * tau / sum(yin[1:tau + 1])
    or 1 where that running sum is exactly zero.

    yin: 1-D sequence of numbers, e.g. the output of one of the sqd_*
        functions. It is copied to a new float array and left unchanged,
        so the raw curve can still be used (e.g. plotted) afterwards.
    Returns the normalized array, same length as yin.
    """
    out = np.array(yin, dtype=float)
    if out.size == 0:
        return out
    # sequential prefix sums of yin[1:], the same order as a running total
    running = np.cumsum(out[1:])
    taus = np.arange(1, len(out))
    # zero running sums give inf/nan here; they are replaced by 1 below
    with np.errstate(divide='ignore', invalid='ignore'):
        scaled = out[1:] * (taus / running)
    out[1:] = np.where(running != 0, scaled, 1.)
    out[0] = 1.
    return out
107
def compute_all(x):
    """ run the four difference functions on x, printing each run time

    x: input signal, passed unchanged to sqd_yin, sqd_yinfast,
        sqd_yintapered and sqd_yinfft, in that order.
    Returns (yin, yinfast, yintapered, yinfft).
    """
    # default_timer is time.perf_counter on Python 3 (monotonic, high
    # resolution), unlike the wall-clock time.time()
    from timeit import default_timer as timer
    runs = (
        ("yin took %.2fms", sqd_yin),
        ("yinfast took: %.2fms", sqd_yinfast),
        ("yintapered took: %.2fms", sqd_yintapered),
        ("yinfft took: %.2fms", sqd_yinfft),
    )
    results = []
    for fmt, func in runs:
        # time only the call itself, not the printing of the previous line
        start = timer()
        results.append(func(x))
        elapsed = timer() - start
        print(fmt % (elapsed * 1000.))
    return tuple(results)
129
def plot_all(yin, yinfast, yintapered, yinfft):
    """ plot the four difference functions and their normalized versions

    Creates a new 2x2 figure. The left column shows the raw curves (yin and
    yintapered on top, yinfast and yinfft below). The right column shows the
    same curves passed through cumdiff. All panels share the x axis, and
    each column shares its y axis. The figure is laid out but not shown;
    the caller is responsible for plt.show()."""
    fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey='col')

    # raw curves; plotted before any cumdiff call below, which may modify
    # its argument in place
    axes[0, 0].plot(yin, label='yin')
    axes[0, 0].plot(yintapered, label='yintapered')
    # sharey='col': this lower limit also applies to the panel below
    axes[0, 0].set_ylim(bottom=0)
    axes[0, 0].legend()
    axes[1, 0].plot(yinfast, '-', label='yinfast')
    axes[1, 0].plot(yinfft, label='yinfft')
    axes[1, 0].legend()

    # cumulative mean normalized versions of the same curves
    axes[0, 1].plot(cumdiff(yin), label='yin')
    axes[0, 1].plot(cumdiff(yintapered), label='yin tapered')
    axes[0, 1].set_ylim(bottom=0)
    axes[0, 1].legend()
    axes[1, 1].plot(cumdiff(yinfast), '-', label='yinfast')
    axes[1, 1].plot(cumdiff(yinfft), label='yinfft')
    axes[1, 1].legend()

    fig.tight_layout()
150
# default test frequencies in Hz; command-line arguments replace them
testfreqs = [441., 800., 10000., 40.]

if len(sys.argv) > 1:
    # a one-shot iterator on Python 3, which is fine: it is iterated once
    testfreqs = map(float,sys.argv[1:])

for f in testfreqs:
    print ("Comparing yin implementations for sine wave at %.fHz" % f)
    samplerate = 44100.
    win_s = 4096

    # win_s samples of a unit-amplitude cosine at f Hz
    x = np.cos(2.*np.pi * np.arange(win_s) * f / samplerate)

    n_times = 1  # raise (e.g. to 100) to repeat the timing runs
    for n in range(n_times):
        # the third result comes from sqd_yintapered
        yin, yinfast, yinfftslow, yinfft = compute_all(x)
    if 0: # plot difference between yin and yinfast (disabled)
        # start a new figure so the curve does not land in the figure that
        # plot_all left open in the previous iteration
        plt.figure()
        plt.plot(yin-yinfast)
        plt.tight_layout()
        plt.show()
    if 1: # plot difference between the tapered variant and yinfft
        # new figure, for the same reason as above
        plt.figure()
        plt.plot(yinfftslow-yinfft)
        plt.tight_layout()
        plt.show()
    plot_all(yin, yinfast, yinfftslow, yinfft)
# show the figures created by plot_all (and any other still open)
plt.show()
Note: See TracBrowser for help on using the repository browser.