Changeset 337e70d


Ignore:
Timestamp:
Dec 29, 2021, 5:51:48 PM (2 years ago)
Author:
Paul Brossier <piem@piem.org>
Branches:
feature/cnn, feature/crepe
Children:
c97f7ed
Parents:
b6097ac
git-author:
Paul Brossier <piem@piem.org> (01/08/19 15:29:51)
git-committer:
Paul Brossier <piem@piem.org> (12/29/21 17:51:48)
Message:

[conv1d] add first blas optimisation using sdot

File:
1 edited

Legend:

Unmodified
Added
Removed
  • src/ai/conv1d.c

    rb6097ac r337e70d  
    4545  uint_t output_shape[2];     // shape of output
    4646  uint_t padding_start;    // {top, left} padding
     47
     48#if defined(HAVE_BLAS)
     49  aubio_tensor_t *padded_input;
     50#endif
    4751};
    4852
     
    8387  if (c->bias)
    8488    del_fvec(c->bias);
     89#if defined(HAVE_BLAS)
     90  if (c->padded_input) del_aubio_tensor(c->padded_input);
     91#endif
    8592  AUBIO_FREE(c);
    8693}
     
    104111{
    105112  uint_t output_shape[2] = {0, c->n_filters};
     113  uint_t padding_shape = 0;  // total amount of padding
    106114  uint_t padding_start = 0;
    107115
     
    120128          / (smpl_t)c->stride_shape);
    121129
    122       uint_t padding_shape;  // total amount of padding
    123130      padding_shape = (output_shape[0] - 1) * c->stride_shape +
    124131        c->kernel_shape - input_tensor->shape[0];
     
    154161  c->output_shape[0] = output_shape[0];
    155162  c->output_shape[1] = output_shape[1];
     163
     164#if defined(HAVE_BLAS)
     165  if (c->padded_input) del_aubio_tensor(c->padded_input);
     166  uint_t padded_shape[2] = {input_tensor->shape[0] + padding_shape,
     167    input_tensor->shape[1]};
     168  c->padded_input = new_aubio_tensor(2, padded_shape);
     169#endif
    156170
    157171  c->padding_start = padding_start;
     
    206220}
    207221
     222#if !defined(HAVE_BLAS)
    208223void aubio_conv1d_do(aubio_conv1d_t *c, aubio_tensor_t *input_tensor,
    209224    aubio_tensor_t *activations)
     
    256271}
    257272
     273#else /* HAVE_BLAS */
     274
     275// blas implementation
     276//
     277//  uses _sdot on the padded input to compute each output elements at once
     278//
     279// TODO
     280//  - switch to sgemv to factorise over activations->shape[j]
     281//  - avoid copy when padding_start == 0
     282//  - optimize copying using tensor helpers
     283
     284void aubio_conv1d_do(aubio_conv1d_t *c, aubio_tensor_t *input_tensor,
     285    aubio_tensor_t *activations)
     286{
     287  uint_t i, j;
     288  smpl_t bias, acc;
     289
     290  uint_t sdot_size = c->kernel->shape[0] * c->kernel->shape[1];
     291  uint_t input_stride = c->stride_shape * c->kernel->shape[1];
     292
     293  AUBIO_ASSERT(c && input_tensor && activations);
     294  if (aubio_conv1d_check_output_shape(c, input_tensor, activations))
     295  {
     296    AUBIO_ERR("conv1d: check_output_shape failed\n");
     297    return;
     298  }
     299
     300  // copy input to padded version
     301  for (j = 0; j < input_tensor->shape[0]; j++) {
     302    for (i = 0; i < input_tensor->shape[1]; i++) {
     303      c->padded_input->data[j + c->padding_start][i] =
     304        input_tensor->data[j][i];
     305    }
     306  }
     307
     308  // for each output
     309  for (j = 0; j < activations->shape[0]; j++) {
     310    // for each kernel filter k
     311    for (i = 0; i < activations->shape[1]; i++) {
     312      // get bias
     313      bias = c->bias->data[i];
     314
     315      // compute one activation output
     316      acc = aubio_cblas_dot(sdot_size, c->kernel->buffer + i,
     317          c->kernel->shape[2], c->padded_input->buffer + j * input_stride, 1);
     318
     319      // apply bias
     320      acc += bias;
     321
     322      // compute RELU
     323      activations->data[j][i] = MAX(acc, 0.);
     324    }
     325  }
     326}
     327#endif /* HAVE_BLAS */
     328
    258329uint_t aubio_conv1d_set_padding_mode(aubio_conv1d_t *c,
    259330    const char_t *padding_mode)
Note: See TracChangeset for help on using the changeset viewer.