1 | /* |
---|
2 | Copyright (C) 2018 Paul Brossier <piem@aubio.org> |
---|
3 | |
---|
4 | This file is part of aubio. |
---|
5 | |
---|
6 | aubio is free software: you can redistribute it and/or modify |
---|
7 | it under the terms of the GNU General Public License as published by |
---|
8 | the Free Software Foundation, either version 3 of the License, or |
---|
9 | (at your option) any later version. |
---|
10 | |
---|
11 | aubio is distributed in the hope that it will be useful, |
---|
12 | but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
13 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
14 | GNU General Public License for more details. |
---|
15 | |
---|
16 | You should have received a copy of the GNU General Public License |
---|
17 | along with aubio. If not, see <http://www.gnu.org/licenses/>. |
---|
18 | |
---|
19 | */ |
---|
20 | |
---|
21 | |
---|
22 | #include "aubio_priv.h" |
---|
23 | #include "fmat.h" |
---|
24 | #include "tensor.h" |
---|
25 | #include "dense.h" |
---|
26 | |
---|
27 | struct _aubio_dense_t { |
---|
28 | uint_t n_units; |
---|
29 | fmat_t *weights; |
---|
30 | fvec_t *bias; |
---|
31 | }; |
---|
32 | |
---|
33 | aubio_dense_t *new_aubio_dense(uint_t n_units) { |
---|
34 | aubio_dense_t *c = AUBIO_NEW(aubio_dense_t); |
---|
35 | |
---|
36 | AUBIO_GOTO_FAILURE((sint_t)n_units >= 1); |
---|
37 | |
---|
38 | c->n_units = n_units; |
---|
39 | |
---|
40 | return c; |
---|
41 | failure: |
---|
42 | del_aubio_dense(c); |
---|
43 | return NULL; |
---|
44 | } |
---|
45 | |
---|
46 | void del_aubio_dense(aubio_dense_t *c) { |
---|
47 | AUBIO_ASSERT(c); |
---|
48 | if (c->weights) |
---|
49 | del_fmat(c->weights); |
---|
50 | if (c->bias) |
---|
51 | del_fvec(c->bias); |
---|
52 | AUBIO_FREE(c); |
---|
53 | } |
---|
54 | |
---|
#if defined(DEBUG)
/* Print a one-line summary of the layer through AUBIO_DBG: input shape,
 * output shape, parameter count, and weight/bias dimensions.
 * Dereferences c->weights and c->bias, so it must only be called after
 * aubio_dense_get_output_shape() has allocated them. */
void aubio_dense_debug(aubio_dense_t *c, aubio_tensor_t *input_tensor)
{
  // one weight per (input, unit) pair plus one bias per unit
  uint_t n_params = input_tensor->shape[0] * c->n_units + c->n_units;
  // every count printed here is a uint_t, so use %u rather than %d
  AUBIO_DBG("dense: %15s -> (%u,) (%u params)"
      " (n_units=%u, weights=(%u, %u), bias=(%u,))\n",
      aubio_tensor_get_shape_string(input_tensor), c->n_units, n_params,
      c->n_units, c->weights->height, c->weights->length, c->bias->length);
}
#endif
---|
65 | |
---|
66 | uint_t aubio_dense_get_output_shape(aubio_dense_t *c, |
---|
67 | aubio_tensor_t *input, uint_t *shape) |
---|
68 | { |
---|
69 | AUBIO_ASSERT (c && input && shape); |
---|
70 | AUBIO_ASSERT (input->ndim == 1); |
---|
71 | shape[0] = c->n_units; |
---|
72 | |
---|
73 | if (c->weights) del_fmat(c->weights); |
---|
74 | c->weights = new_fmat(input->shape[0], c->n_units); |
---|
75 | if (!c->weights) return AUBIO_FAIL; |
---|
76 | |
---|
77 | if (c->bias) del_fvec(c->bias); |
---|
78 | c->bias = new_fvec(c->n_units); |
---|
79 | if (!c->bias) return AUBIO_FAIL; |
---|
80 | |
---|
81 | #if defined(DEBUG) |
---|
82 | aubio_dense_debug(c, input); |
---|
83 | #endif |
---|
84 | |
---|
85 | return AUBIO_OK; |
---|
86 | } |
---|
87 | |
---|
88 | fmat_t *aubio_dense_get_weights(aubio_dense_t *c) { |
---|
89 | return c->weights; |
---|
90 | } |
---|
91 | |
---|
92 | fvec_t *aubio_dense_get_bias(aubio_dense_t *c) { |
---|
93 | return c->bias; |
---|
94 | } |
---|
95 | |
---|
96 | void aubio_dense_do(aubio_dense_t *c, aubio_tensor_t *input_tensor, |
---|
97 | aubio_tensor_t *activations) { |
---|
98 | AUBIO_ASSERT(c && input_tensor && activations); |
---|
99 | AUBIO_ASSERT(input_tensor->ndim == 1); |
---|
100 | AUBIO_ASSERT(activations->ndim == 1); |
---|
101 | AUBIO_ASSERT(input_tensor->shape[0] == c->weights->height); |
---|
102 | AUBIO_ASSERT(activations->shape[0] == c->weights->length); |
---|
103 | |
---|
104 | fvec_t input_vec; |
---|
105 | aubio_tensor_as_fvec(input_tensor, &input_vec); |
---|
106 | fvec_t output_vec; |
---|
107 | aubio_tensor_as_fvec(activations, &output_vec); |
---|
108 | |
---|
109 | // compute x.W |
---|
110 | fvec_matmul(&input_vec, c->weights, &output_vec); |
---|
111 | // add bias |
---|
112 | fvec_vecadd(&output_vec, c->bias); |
---|
113 | } |
---|