source: src/ai/tensor.c @ e18c30e

feature/crepe_org
Last change on this file since e18c30e was d9a5466, checked in by Paul Brossier <piem@piem.org>, 6 years ago

[tensor] add warning in aubio_tensor_as_fmat, to be avoided

  • Property mode set to 100644
File size: 6.5 KB
RevLine 
[aa5cc08]1/*
2  Copyright (C) 2018 Paul Brossier <piem@aubio.org>
3
4  This file is part of aubio.
5
6  aubio is free software: you can redistribute it and/or modify
7  it under the terms of the GNU General Public License as published by
8  the Free Software Foundation, either version 3 of the License, or
9  (at your option) any later version.
10
11  aubio is distributed in the hope that it will be useful,
12  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  GNU General Public License for more details.
15
16  You should have received a copy of the GNU General Public License
17  along with aubio.  If not, see <http://www.gnu.org/licenses/>.
18
19*/
20
[4c33f81]21#include "aubio_priv.h"
22#include "fmat.h"
23#include "tensor.h"
24
/* size in bytes of the static buffer holding a formatted tensor shape */
#define STRN_LENGTH 40

/* printf conversion used to print a single tensor sample; the variant is
 * chosen from HAVE_AUBIO_DOUBLE */
#if HAVE_AUBIO_DOUBLE
#define AUBIO_SMPL_TFMT "% 9.4lf"
#else
#define AUBIO_SMPL_TFMT "% 9.4f"
#endif /* HAVE_AUBIO_DOUBLE */
31
[5010e61]32aubio_tensor_t *new_aubio_tensor(uint_t ndim, uint_t *shape)
[4c33f81]33{
34  aubio_tensor_t *c = AUBIO_NEW(aubio_tensor_t);
[f5ea4fb]35  uint_t items_per_row = 1;
[4c33f81]36  uint_t i;
37
[5010e61]38  if ((sint_t)ndim <= 0) goto failure;
39  for (i = 0; i < ndim; i++) {
40    if ((sint_t)shape[i] <= 0) goto failure;
[4c33f81]41  }
42
[5010e61]43  c->ndim = ndim;
44  c->shape[0] = shape[0];
45  for (i = 1; i < ndim; i++) {
46    c->shape[i] = shape[i];
[9ca7923]47    items_per_row *= shape[i];
[4c33f81]48  }
[9ca7923]49  c->size = items_per_row * shape[0];
[f5ea4fb]50  c->buffer = AUBIO_ARRAY(smpl_t, c->size);
[5010e61]51  c->data = AUBIO_ARRAY(smpl_t*, shape[0]);
[f5ea4fb]52  for (i = 0; i < c->shape[0]; i++) {
53    c->data[i] = c->buffer + i * items_per_row;
[4c33f81]54  }
55
56  return c;
57
58failure:
59  del_aubio_tensor(c);
60  return NULL;
61}
62
63void del_aubio_tensor(aubio_tensor_t *c)
64{
65  if (c->data) {
66    if (c->data[0]) {
67      AUBIO_FREE(c->data[0]);
68    }
69    AUBIO_FREE(c->data);
70  }
71  AUBIO_FREE(c);
72}
73
74uint_t aubio_tensor_as_fvec(aubio_tensor_t *c, fvec_t *o) {
[427a48c]75  if (!c || !o) return AUBIO_FAIL;
76  o->length = c->size;
[71655fee]77  o->data = c->buffer;
[4c33f81]78  return AUBIO_OK;
79}
80
81uint_t aubio_fvec_as_tensor(fvec_t *o, aubio_tensor_t *c) {
[427a48c]82  if (!o || !c) return AUBIO_FAIL;
[5010e61]83  c->ndim = 1;
84  c->shape[0] = o->length;
[4c33f81]85  c->data = &o->data;
[427a48c]86  c->buffer = o->data;
[849210c]87  c->size = o->length;
[4c33f81]88  return AUBIO_OK;
89}
90
91uint_t aubio_tensor_as_fmat(aubio_tensor_t *c, fmat_t *o) {
[427a48c]92  if (!c || !o) return AUBIO_FAIL;
[5010e61]93  o->height = c->shape[0];
[427a48c]94  o->length = c->size / c->shape[0];
[4c33f81]95  o->data = c->data;
[d9a5466]96  // o was allocated on the stack, data[1] may be NULL
97  AUBIO_WRN("aubio_tensor_as_fmat will not create a usable table of rows\n");
[4c33f81]98  return AUBIO_OK;
99}
100
101uint_t aubio_fmat_as_tensor(fmat_t *o, aubio_tensor_t *c) {
[427a48c]102  if (!o || !c) return AUBIO_FAIL;
[9d35014]103  c->ndim = 2;
104  c->shape[0] = o->height;
105  c->shape[1] = o->length;
[849210c]106  c->size = o->height * o->length;
[4c33f81]107  c->data = o->data;
[427a48c]108  c->buffer = o->data[0];
[4c33f81]109  return AUBIO_OK;
110}
111
[7e0b641]112uint_t aubio_tensor_get_subtensor(aubio_tensor_t *t, uint_t i,
113    aubio_tensor_t *st)
114{
115  uint_t j;
116  if (!t || !st) return AUBIO_FAIL;
117  if (i >= t->shape[0]) {
118    AUBIO_ERR("tensor: index %d out of range, only %d subtensors\n",
119        i, t->shape[0]);
120    return AUBIO_FAIL;
121  }
122  if(t->ndim > 1) {
123    st->ndim = t->ndim - 1;
124    for (j = 0; j < st->ndim; j++) {
125      st->shape[j] = t->shape[j + 1];
126    }
127    for (j = st->ndim; j < AUBIO_TENSOR_MAXDIM; j++) {
128      st->shape[j] = 0;
129    }
130    st->size = t->size / t->shape[0];
131  } else {
132    st->ndim = 1;
133    st->shape[0] = 1;
134    st->size = 1;
135  }
136  // st was allocated on the stack, row indices are lost
137  st->data = NULL;
138  st->buffer = &t->buffer[0] + st->size * i;
139  return AUBIO_OK;
140}
141
[4b2d174]142uint_t aubio_tensor_have_same_shape(aubio_tensor_t *a, aubio_tensor_t *b)
[dc257cc]143{
144  uint_t n;
[085696b]145  if (!a || !b)
[4b2d174]146    return 0;
[085696b]147  if (a->ndim != b->ndim)
148    return 0;
149
[4b2d174]150  for (n = 0; n < a->ndim; n++) {
151    if (a->shape[n] != b->shape[n]) {
[dc257cc]152      return 0;
153    }
154  }
155  return 1;
156}
157
[4c33f81]158smpl_t aubio_tensor_max(aubio_tensor_t *t)
159{
160  uint_t i;
[a85c7f3]161  smpl_t max = t->buffer[0];
[9ca7923]162  for (i = 0; i < t->size; i++) {
[a85c7f3]163    max = MAX(t->buffer[i], max);
[4c33f81]164  }
165  return max;
166}
[50d7afe]167
168const char_t *aubio_tensor_get_shape_string(aubio_tensor_t *t) {
169  uint_t i;
170  if (!t) return NULL;
171  size_t offset = 2;
172  static char_t shape_str[STRN_LENGTH];
173  char_t shape_str_previous[STRN_LENGTH] = "(";
174  for (i = 0; i < t->ndim; i++) {
[21631e9]175    // and space last if not the last one
176    int add_space = (i < t->ndim - 1);
177    // add coma first if this not last, or always if 1d
178    int add_coma = add_space || (t->ndim == 1);
179    int len = snprintf(shape_str, STRN_LENGTH, "%s%d%s%s",
180        shape_str_previous, t->shape[i],
181        add_coma ? "," : "", add_space ? " " : "");
[50d7afe]182    strncpy(shape_str_previous, shape_str, len);
183  }
184  snprintf(shape_str, strnlen(shape_str, STRN_LENGTH - offset - 1) + offset,
185      "%s)", shape_str_previous);
[9bffada]186  return shape_str;
[50d7afe]187}
[c4b6b59]188
189static void aubio_tensor_print_subtensor(aubio_tensor_t *t, uint_t depth)
190{
191  uint_t i;
192  AUBIO_MSG("[");
193  for (i = 0; i < t->shape[0]; i ++) {
194    AUBIO_MSG("%*s", i == 0 ? 0 : depth + 1, i == 0 ? "" : " ");
195    if (t->ndim == 1) {
196      AUBIO_MSG(AUBIO_SMPL_TFMT, t->buffer[i]);
197    } else {
198      aubio_tensor_t st;
199      aubio_tensor_get_subtensor(t, i, &st);
200      aubio_tensor_print_subtensor(&st, depth + 1); // recursive call
201    }
202    AUBIO_MSG("%s%s", (i < t->shape[0] - 1) ? "," : "",
203        t->ndim == 1 ? " " : ((i < t->shape[0] - 1) ? "\n" : ""));
204  }
205  AUBIO_MSG("]");
206}
207
208void aubio_tensor_print(aubio_tensor_t *t)
209{
210  AUBIO_MSG("tensor of shape %s\n", aubio_tensor_get_shape_string(t));
211  aubio_tensor_print_subtensor(t, 0);
212  AUBIO_MSG("\n");
213}
[6006760]214
215void aubio_tensor_matmul(aubio_tensor_t *a, aubio_tensor_t *b,
216    aubio_tensor_t *c)
217{
218  AUBIO_ASSERT (a->shape[0] == c->shape[0]);
219  AUBIO_ASSERT (a->shape[1] == b->shape[0]);
220  AUBIO_ASSERT (b->shape[1] == c->shape[1]);
221#if !defined(HAVE_BLAS)
222  uint_t i, j, k;
223  for (i = 0; i < c->shape[0]; i++) {
224    for (j = 0; j < c->shape[1]; j++) {
225      smpl_t sum = 0.;
226      for (k = 0; k < a->shape[1]; k++) {
227          sum += a->buffer[i * a->shape[1] + k]
228            * b->buffer[k * b->shape[1] + j];
229      }
230      c->buffer[i * c->shape[1] + j] = sum;
231    }
232  }
233#else
[1df9cd1]234  aubio_cblas__gemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, a->shape[0],
235      b->size/b->shape[0], b->shape[0], 1.F, a->buffer, a->size/a->shape[0],
236      b->buffer,
237      b->size/b->shape[0], 0.F, c->buffer, b->size/b->shape[0]);
[6006760]238#endif
239}
[cc74a29]240
241void aubio_tensor_copy(aubio_tensor_t *s, aubio_tensor_t *t)
242{
243  if (!aubio_tensor_have_same_shape(s, t)) {
244    AUBIO_ERR("tensor: not copying source tensor %s",
245        aubio_tensor_get_shape_string(s));
246    AUBIO_ERR(" to destination tensor %s",
247        aubio_tensor_get_shape_string(t));
248    return;
249  }
250  AUBIO_MEMCPY(t->buffer, s->buffer, s->size);
251}
Note: See TracBrowser for help on using the repository browser.