[fb3a9f5] | 1 | /* |
---|
| 2 | Copyright (C) 2018 Paul Brossier <piem@aubio.org> |
---|
| 3 | |
---|
| 4 | This file is part of aubio. |
---|
| 5 | |
---|
| 6 | aubio is free software: you can redistribute it and/or modify |
---|
| 7 | it under the terms of the GNU General Public License as published by |
---|
| 8 | the Free Software Foundation, either version 3 of the License, or |
---|
| 9 | (at your option) any later version. |
---|
| 10 | |
---|
| 11 | aubio is distributed in the hope that it will be useful, |
---|
| 12 | but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
| 13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
| 14 | GNU General Public License for more details. |
---|
| 15 | |
---|
| 16 | You should have received a copy of the GNU General Public License |
---|
| 17 | along with aubio. If not, see <http://www.gnu.org/licenses/>. |
---|
| 18 | |
---|
| 19 | */ |
---|
| 20 | |
---|
| 21 | |
---|
| 22 | #include "aubio_priv.h" |
---|
| 23 | #include "fmat.h" |
---|
| 24 | #include "tensor.h" |
---|
| 25 | #include "dense.h" |
---|
| 26 | |
---|
| 27 | struct _aubio_dense_t { |
---|
| 28 | uint_t n_units; |
---|
| 29 | fmat_t *weights; |
---|
| 30 | fvec_t *bias; |
---|
| 31 | }; |
---|
| 32 | |
---|
| 33 | aubio_dense_t *new_aubio_dense(uint_t n_units) { |
---|
| 34 | aubio_dense_t *c = AUBIO_NEW(aubio_dense_t); |
---|
| 35 | |
---|
| 36 | AUBIO_GOTO_FAILURE((sint_t)n_units >= 1); |
---|
| 37 | |
---|
| 38 | c->n_units = n_units; |
---|
| 39 | |
---|
| 40 | return c; |
---|
| 41 | failure: |
---|
| 42 | del_aubio_dense(c); |
---|
| 43 | return NULL; |
---|
| 44 | } |
---|
| 45 | |
---|
| 46 | void del_aubio_dense(aubio_dense_t *c) { |
---|
| 47 | AUBIO_ASSERT(c); |
---|
| 48 | if (c->weights) |
---|
| 49 | del_fmat(c->weights); |
---|
| 50 | if (c->bias) |
---|
| 51 | del_fvec(c->bias); |
---|
| 52 | AUBIO_FREE(c); |
---|
| 53 | } |
---|
| 54 | |
---|
| 55 | void aubio_dense_debug(aubio_dense_t *c, aubio_tensor_t *input_tensor) |
---|
| 56 | { |
---|
[4647d38] | 57 | uint_t n_params = input_tensor->shape[0] * c->n_units + c->n_units; |
---|
| 58 | AUBIO_DBG("dense: %15s -> (%d,) (%d params)" |
---|
| 59 | " (n_units=%d, weights=(%d, %d), bias=(%d,))\n", |
---|
| 60 | aubio_tensor_get_shape_string(input_tensor), c->n_units, n_params, |
---|
| 61 | c->n_units, c->weights->height, c->weights->length, c->bias->length); |
---|
[fb3a9f5] | 62 | } |
---|
| 63 | |
---|
| 64 | uint_t aubio_dense_get_output_shape(aubio_dense_t *c, |
---|
| 65 | aubio_tensor_t *input, uint_t *shape) |
---|
| 66 | { |
---|
| 67 | AUBIO_ASSERT (c && input && shape); |
---|
| 68 | AUBIO_ASSERT (input->ndim == 1); |
---|
| 69 | shape[0] = c->n_units; |
---|
| 70 | |
---|
| 71 | if (c->weights) del_fmat(c->weights); |
---|
| 72 | c->weights = new_fmat(input->shape[0], c->n_units); |
---|
| 73 | if (!c->weights) return AUBIO_FAIL; |
---|
| 74 | |
---|
| 75 | if (c->bias) del_fvec(c->bias); |
---|
| 76 | c->bias = new_fvec(c->n_units); |
---|
| 77 | if (!c->bias) return AUBIO_FAIL; |
---|
| 78 | |
---|
| 79 | aubio_dense_debug(c, input); |
---|
| 80 | |
---|
| 81 | return AUBIO_OK; |
---|
| 82 | } |
---|
| 83 | |
---|
| 84 | fmat_t *aubio_dense_get_weights(aubio_dense_t *c) { |
---|
| 85 | return c->weights; |
---|
| 86 | } |
---|
| 87 | |
---|
| 88 | fvec_t *aubio_dense_get_bias(aubio_dense_t *c) { |
---|
| 89 | return c->bias; |
---|
| 90 | } |
---|
| 91 | |
---|
| 92 | void aubio_dense_do(aubio_dense_t *c, aubio_tensor_t *input_tensor, |
---|
| 93 | aubio_tensor_t *activations) { |
---|
| 94 | AUBIO_ASSERT(c && input_tensor && activations); |
---|
| 95 | AUBIO_ASSERT(input_tensor->ndim == 1); |
---|
| 96 | AUBIO_ASSERT(activations->ndim == 1); |
---|
| 97 | AUBIO_ASSERT(input_tensor->shape[0] == c->weights->height); |
---|
| 98 | AUBIO_ASSERT(activations->shape[0] == c->weights->length); |
---|
| 99 | |
---|
| 100 | fvec_t input_vec; |
---|
| 101 | aubio_tensor_as_fvec(input_tensor, &input_vec); |
---|
| 102 | fvec_t output_vec; |
---|
| 103 | aubio_tensor_as_fvec(activations, &output_vec); |
---|
| 104 | |
---|
| 105 | // compute x.W |
---|
| 106 | fvec_matmul(&input_vec, c->weights, &output_vec); |
---|
| 107 | // add bias |
---|
| 108 | fvec_vecadd(&output_vec, c->bias); |
---|
| 109 | } |
---|