source: src/fmat.c @ befee7a

feature/crepe_org
Last change on this file since befee7a was 7048b56, checked in by Paul Brossier <piem@piem.org>, 6 years ago

[fmat] add matmul with blas implementation

  • Property mode set to 100644
File size: 6.5 KB
RevLine 
/*
  Copyright (C) 2009 Paul Brossier <piem@aubio.org>

  This file is part of aubio.

  aubio is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  (at your option) any later version.

  aubio is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with aubio.  If not, see <http://www.gnu.org/licenses/>.

*/
20
21#include "aubio_priv.h"
22#include "fmat.h"
23
[4ed0ed1]24fmat_t * new_fmat (uint_t height, uint_t length) {
[c21acb9]25  fmat_t * s;
[1fcd392]26  uint_t i;
[a9a8c04]27  if ((sint_t)height <= 0 || (sint_t)length <= 0 ) {
[767990e]28    return NULL;
29  }
[c21acb9]30  s = AUBIO_NEW(fmat_t);
[c7860af]31  s->height = height;
32  s->length = length;
[a9a8c04]33  // array of row pointers
[c7860af]34  s->data = AUBIO_ARRAY(smpl_t*,s->height);
[a9a8c04]35  // first row store the full height * length buffer
36  s->data[0] = AUBIO_ARRAY(smpl_t, s->height * s->length);
[1fcd392]37  for (i=1; i< s->height; i++) {
38    s->data[i] = s->data[0] + i * s->length;
[c7860af]39  }
40  return s;
41}
42
43void del_fmat (fmat_t *s) {
[1fcd392]44  AUBIO_ASSERT(s);
45  if (s->data[0])
46    AUBIO_FREE(s->data[0]);
47  if (s->data)
48    AUBIO_FREE(s->data);
[c7860af]49  AUBIO_FREE(s);
50}
51
[1ece4f8]52void fmat_set_sample(fmat_t *s, smpl_t data, uint_t channel, uint_t position) {
[c7860af]53  s->data[channel][position] = data;
54}
[1ece4f8]55
[1120f86]56smpl_t fmat_get_sample(const fmat_t *s, uint_t channel, uint_t position) {
[c7860af]57  return s->data[channel][position];
58}
[1ece4f8]59
[1120f86]60void fmat_get_channel(const fmat_t *s, uint_t channel, fvec_t *output) {
[3cd9fdd]61  output->data = s->data[channel];
62  output->length = s->length;
63  return;
[c7860af]64}
65
[1120f86]66smpl_t * fmat_get_channel_data(const fmat_t *s, uint_t channel) {
[1ece4f8]67  return s->data[channel];
68}
69
[1120f86]70smpl_t ** fmat_get_data(const fmat_t *s) {
[c7860af]71  return s->data;
72}
73
74/* helper functions */
75
[1120f86]76void fmat_print(const fmat_t *s) {
[c7860af]77  uint_t i,j;
78  for (i=0; i< s->height; i++) {
79    for (j=0; j< s->length; j++) {
[9498c88]80      AUBIO_MSG(AUBIO_SMPL_FMT " ", s->data[0][i * s->length + j]);
[c7860af]81    }
82    AUBIO_MSG("\n");
83  }
84}
85
86void fmat_set(fmat_t *s, smpl_t val) {
87  uint_t i,j;
88  for (i=0; i< s->height; i++) {
89    for (j=0; j< s->length; j++) {
90      s->data[i][j] = val;
91    }
92  }
93}
94
95void fmat_zeros(fmat_t *s) {
[7585822]96#ifdef HAVE_MEMCPY_HACKS
[c666a18]97  uint_t i;
98  for (i=0; i< s->height; i++) {
99    memset(s->data[i], 0, s->length * sizeof(smpl_t));
100  }
[7585822]101#else /* HAVE_MEMCPY_HACKS */
[c7860af]102  fmat_set(s, 0.);
[7585822]103#endif /* HAVE_MEMCPY_HACKS */
[c7860af]104}
105
106void fmat_ones(fmat_t *s) {
107  fmat_set(s, 1.);
108}
109
110void fmat_rev(fmat_t *s) {
111  uint_t i,j;
112  for (i=0; i< s->height; i++) {
[e141b23]113    for (j=0; j< FLOOR((smpl_t)s->length/2); j++) {
[c7860af]114      ELEM_SWAP(s->data[i][j], s->data[i][s->length-1-j]);
115    }
116  }
117}
118
[1120f86]119void fmat_weight(fmat_t *s, const fmat_t *weight) {
[c7860af]120  uint_t i,j;
121  uint_t length = MIN(s->length, weight->length);
122  for (i=0; i< s->height; i++) {
123    for (j=0; j< length; j++) {
124      s->data[i][j] *= weight->data[0][j];
125    }
126  }
127}
128
[1120f86]129void fmat_copy(const fmat_t *s, fmat_t *t) {
[c21acb9]130  uint_t i;
[7585822]131#ifndef HAVE_MEMCPY_HACKS
[bc6c1a7]132  uint_t j;
[7585822]133#endif /* HAVE_MEMCPY_HACKS */
[c7860af]134  if (s->height != t->height) {
[923a7a8]135    AUBIO_ERR("trying to copy %d rows to %d rows \n",
[c7860af]136            s->height, t->height);
[923a7a8]137    return;
[c7860af]138  }
139  if (s->length != t->length) {
[923a7a8]140    AUBIO_ERR("trying to copy %d columns to %d columns\n",
[c7860af]141            s->length, t->length);
[923a7a8]142    return;
[c7860af]143  }
[7585822]144#ifdef HAVE_MEMCPY_HACKS
[83d2948]145  for (i=0; i< s->height; i++) {
146    memcpy(t->data[i], s->data[i], t->length * sizeof(smpl_t));
147  }
[7585822]148#else /* HAVE_MEMCPY_HACKS */
[923a7a8]149  for (i=0; i< t->height; i++) {
150    for (j=0; j< t->length; j++) {
[c7860af]151      t->data[i][j] = s->data[i][j];
152    }
153  }
[7585822]154#endif /* HAVE_MEMCPY_HACKS */
[c7860af]155}
156
[1120f86]157void fmat_vecmul(const fmat_t *s, const fvec_t *scale, fvec_t *output) {
[630191c]158#if !defined(HAVE_ACCELERATE) && !defined(HAVE_BLAS)
[096a174]159  uint_t j, k;
160  AUBIO_ASSERT(s->height == output->length);
161  AUBIO_ASSERT(s->length == scale->length);
[a7348ca5]162  fvec_zeros(output);
163  for (j = 0; j < s->length; j++) {
164    for (k = 0; k < s->height; k++) {
[096a174]165      output->data[k] += scale->data[j] * s->data[k][j];
[a7348ca5]166    }
167  }
[b8c6c1e]168#elif defined(HAVE_BLAS)
[096a174]169#if 0
[a7348ca5]170  for (k = 0; k < s->height; k++) {
171    output->data[k] = aubio_cblas_dot( s->length, scale->data, 1, s->data[k], 1);
172  }
[096a174]173#else
174  aubio_cblas__gemv(CblasColMajor, CblasTrans,
175      s->length, s->height, 1.,
176      s->data[0], s->length,
177      scale->data, 1, 0.,
178      output->data, 1);
179#endif
[a7348ca5]180#elif defined(HAVE_ACCELERATE)
181#if 0
182  // seems slower and less precise (and dangerous?)
183  vDSP_mmul (s->data[0], 1, scale->data, 1, output->data, 1, s->height, 1, s->length);
184#else
[096a174]185  uint_t k;
[a7348ca5]186  for (k = 0; k < s->height; k++) {
187    aubio_vDSP_dotpr( scale->data, 1, s->data[k], 1, &(output->data[k]), s->length);
188  }
189#endif
190#endif
[147afba]191}
192
193void fvec_matmul(const fvec_t *scale, const fmat_t *s, fvec_t *output) {
194  AUBIO_ASSERT(s->height == scale->length);
195  AUBIO_ASSERT(s->length == output->length);
196#if !defined(HAVE_ACCELERATE) && !defined(HAVE_BLAS)
197  uint_t j, k;
198  fvec_zeros(output);
199  for (k = 0; k < s->height; k++) {
200    for (j = 0; j < s->length; j++) {
201      output->data[j] += s->data[k][j] * scale->data[k];
202    }
203  }
204#elif defined(HAVE_BLAS)
205#if 0
206  for (k = 0; k < s->length; k++) {
207    output->data[k] = aubio_cblas_dot( scale->length, scale->data, 1,
208        &s->data[0][0] + k, s->length);
209  }
210#else
211  aubio_cblas__gemv(CblasColMajor, CblasNoTrans,
212      s->length, s->height, 1.,
213      s->data[0], s->length,
214      scale->data, 1, 0.,
215      output->data, 1);
216#endif
217#elif defined(HAVE_ACCELERATE)
218#if 0
219  // seems slower and less precise (and dangerous?)
220  vDSP_mmul (s->data[0], 1, scale->data, 1, output->data, 1, s->height, 1, s->length);
221#else
222  uint_t k;
223  for (k = 0; k < s->height; k++) {
224    aubio_vDSP_dotpr( scale->data, 1, s->data[k], 1, &(output->data[k]), scale->length);
[096a174]225  }
226#endif
227#endif
[a7348ca5]228}
[7048b56]229
230void fmat_matmul(const fmat_t *a, const fmat_t *b, fmat_t *c)
231{
232  AUBIO_ASSERT (a->height == c->height);
233  AUBIO_ASSERT (a->length == b->height);
234  AUBIO_ASSERT (b->length == c->length);
235#if !defined(HAVE_BLAS)
236  uint_t i, j, k;
237  for (i = 0; i < c->height; i++) {
238    for (j = 0; j < c->length; j++) {
239      smpl_t sum = 0.;
240      for (k = 0; k < a->length; k++) {
241          sum += a->data[0][i * a->length + k]
242            * b->data[0][k * b->length + j];
243      }
244      c->data[0][i * c->length + j] = sum;
245    }
246  }
247#else
248  aubio_cblas__gemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, a->height,
249      b->length, b->height, 1.F, a->data[0], a->length, b->data[0],
250      b->length, 0.F, c->data[0], b->length);
251#endif
252}
Note: See TracBrowser for help on using the repository browser.