source: src/ai/batchnorm.c @ c249f57

feature/crepe_org
Last change on this file since c249f57 was 689ba93, checked in by Paul Brossier <piem@piem.org>, 6 years ago

[batchnorm] accepts any input size, allocate weights in get_output_shape

  • Property mode set to 100644
File size: 4.8 KB
RevLine 
[2fec649]1/*
2  Copyright (C) 2018 Paul Brossier <piem@aubio.org>
3
4  This file is part of aubio.
5
6  aubio is free software: you can redistribute it and/or modify
7  it under the terms of the GNU General Public License as published by
8  the Free Software Foundation, either version 3 of the License, or
9  (at your option) any later version.
10
11  aubio is distributed in the hope that it will be useful,
12  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  GNU General Public License for more details.
15
16  You should have received a copy of the GNU General Public License
17  along with aubio.  If not, see <http://www.gnu.org/licenses/>.
18
19*/
20
21#include "aubio_priv.h"
22#include "fmat.h"
23#include "tensor.h"
24#include "batchnorm.h"
25
26struct _aubio_batchnorm_t {
27  uint_t n_outputs;
28  fvec_t *gamma;
29  fvec_t *beta;
30  fvec_t *moving_mean;
31  fvec_t *moving_variance;
32};
33
34static void aubio_batchnorm_debug(aubio_batchnorm_t *c,
35    aubio_tensor_t *input_tensor);
36
[689ba93]37aubio_batchnorm_t *new_aubio_batchnorm(void)
[2fec649]38{
39  aubio_batchnorm_t *c = AUBIO_NEW(aubio_batchnorm_t);
40  return c;
[689ba93]41#if 0 // no argument so no other possible failure
[2fec649]42failure:
43  del_aubio_batchnorm(c);
44  return NULL;
[689ba93]45#endif
[2fec649]46}
47
[689ba93]48static void aubio_batchnorm_reset(aubio_batchnorm_t *c) {
[2fec649]49  AUBIO_ASSERT(c);
50  if (c->gamma)
51    del_fvec(c->gamma);
52  if (c->beta)
53    del_fvec(c->beta);
54  if (c->moving_mean)
55    del_fvec(c->moving_mean);
56  if (c->moving_variance)
57    del_fvec(c->moving_variance);
[689ba93]58}
59
60void del_aubio_batchnorm(aubio_batchnorm_t* c) {
61  aubio_batchnorm_reset(c);
[2fec649]62  AUBIO_FREE(c);
63}
64
65void aubio_batchnorm_debug(aubio_batchnorm_t *c, aubio_tensor_t *input_tensor)
66{
[6ee6107]67  AUBIO_DBG("batchnorm: %15s -> %s (%d params) (4 * (%d,))\n",
68      aubio_tensor_get_shape_string(input_tensor),
69      aubio_tensor_get_shape_string(input_tensor), // same output shape
70      c->n_outputs, 4 * c->n_outputs);
[2fec649]71}
72
73uint_t aubio_batchnorm_get_output_shape(aubio_batchnorm_t *c,
74    aubio_tensor_t *input, uint_t *shape)
75{
[fa90792]76  uint_t i;
77
[2fec649]78  AUBIO_ASSERT(c && input && shape);
79
[fa90792]80  for (i = 0; i < input->ndim; i++) {
81    shape[i] = input->shape[i];
82  }
[2fec649]83
[689ba93]84  aubio_batchnorm_reset(c);
85
86  c->n_outputs = input->shape[input->ndim - 1];
87
88  c->gamma = new_fvec(c->n_outputs);
89  c->beta = new_fvec(c->n_outputs);
90  c->moving_mean = new_fvec(c->n_outputs);
91  c->moving_variance = new_fvec(c->n_outputs);
92
93  if (!c->gamma || !c->beta || !c->moving_mean || !c->moving_variance)
94  {
95    aubio_batchnorm_reset(c);
96    return AUBIO_FAIL;
97  }
98
[2fec649]99  aubio_batchnorm_debug(c, input);
100
101  return AUBIO_OK;
102}
103
104void aubio_batchnorm_do(aubio_batchnorm_t *c, aubio_tensor_t *input_tensor,
105    aubio_tensor_t *activations)
106{
107  smpl_t s;
[fa90792]108  uint_t i, j;
109  uint_t ii = 0;
110  uint_t length = activations->shape[activations->ndim - 1];
111  uint_t height = activations->size / length;
112
[2fec649]113  AUBIO_ASSERT(c);
114  AUBIO_ASSERT_EQUAL_SHAPE(input_tensor, activations);
[fa90792]115  AUBIO_ASSERT(length == c->n_outputs);
116  AUBIO_ASSERT(height * length == activations->size);
117
118  for (i = 0; i < height; i++) {
119    for (j = 0; j < length; j++) {
120      s = input_tensor->buffer[ii + j];
121      s -= c->moving_mean->data[j];
122      s *= c->gamma->data[j];
123      s /= SQRT(c->moving_variance->data[j] + 1.e-4);
124      s += c->beta->data[j];
125      activations->buffer[ii + j] = s;
[2fec649]126    }
[fa90792]127    ii += length;
[2fec649]128  }
129}
130
131uint_t aubio_batchnorm_set_gamma(aubio_batchnorm_t *t, fvec_t *gamma)
132{
133  AUBIO_ASSERT(t && t->gamma);
134  AUBIO_ASSERT(gamma);
135  if (t->gamma->length != gamma->length) return AUBIO_FAIL;
136  fvec_copy(gamma, t->gamma);
137  return AUBIO_OK;
138}
139
140uint_t aubio_batchnorm_set_beta(aubio_batchnorm_t *t, fvec_t *beta)
141{
[3977b4f]142  AUBIO_ASSERT(t && t->beta && beta);
143  if (t->beta->length != beta->length)
144    return AUBIO_FAIL;
[2fec649]145  fvec_copy(beta, t->beta);
146  return AUBIO_OK;
147}
148
[3977b4f]149uint_t aubio_batchnorm_set_moving_mean(aubio_batchnorm_t *t,
150    fvec_t *moving_mean)
[2fec649]151{
152  AUBIO_ASSERT(t && t->moving_mean);
153  AUBIO_ASSERT(moving_mean);
[3977b4f]154  if (t->moving_mean->length != moving_mean->length)
155    return AUBIO_FAIL;
[2fec649]156  fvec_copy(moving_mean, t->moving_mean);
157  return AUBIO_OK;
158}
159
[3977b4f]160uint_t aubio_batchnorm_set_moving_variance(aubio_batchnorm_t *t,
161    fvec_t *moving_variance)
[2fec649]162{
163  AUBIO_ASSERT(t && t->moving_variance);
164  AUBIO_ASSERT(moving_variance);
[3977b4f]165  if (t->moving_variance->length != moving_variance->length)
166    return AUBIO_FAIL;
[2fec649]167  fvec_copy(moving_variance, t->moving_variance);
168  return AUBIO_OK;
169}
170
171fvec_t *aubio_batchnorm_get_gamma(aubio_batchnorm_t *t)
172{
173  AUBIO_ASSERT(t && t->gamma);
174  return t->gamma;
175}
176
177fvec_t *aubio_batchnorm_get_beta(aubio_batchnorm_t *t)
178{
179  AUBIO_ASSERT(t && t->beta);
180  return t->beta;
181}
182
183fvec_t *aubio_batchnorm_get_moving_mean(aubio_batchnorm_t *t)
184{
185  AUBIO_ASSERT(t && t->moving_mean);
186  return t->moving_mean;
187}
188
189fvec_t *aubio_batchnorm_get_moving_variance(aubio_batchnorm_t *t)
190{
191  AUBIO_ASSERT(t && t->moving_variance);
192  return t->moving_variance;
193}
Note: See TracBrowser for help on using the repository browser.