Changeset 1768cdb


Ignore:
Timestamp:
Jan 8, 2019, 12:00:16 AM (6 years ago)
Author:
Paul Brossier <piem@piem.org>
Branches:
feature/cnn_org, feature/crepe_org
Children:
f9f03ff
Parents:
791b436
Message:

[tests] add tensor_matmul test

File:
1 edited

Legend:

Unmodified
Added
Removed
  • tests/src/ai/test-tensor.c

    r791b436 r1768cdb  
    271271}
    272272
     273int test_matmul(void)
     274{
     275  uint_t m = 3, n = 2, k = 4;
     276  uint_t input_shape[2] =  {m, k};
     277  uint_t kernel_shape[2] = {k, n};
     278  uint_t output_shape[2] = {m, n};
     279
     280  aubio_tensor_t *input_tensor = new_aubio_tensor(2, input_shape);
     281  aubio_tensor_t *kernel_tensor = new_aubio_tensor(2, kernel_shape);
     282  aubio_tensor_t *output_tensor = new_aubio_tensor(2, output_shape);
     283
     284  input_tensor->data[0][0] = 1;
     285  input_tensor->data[1][1] = 1;
     286  input_tensor->data[2][0] = -1;
     287  input_tensor->data[2][1] = 1;
     288  uint_t i;
     289  for (i = 0; i < kernel_tensor->size; i++) {
     290    kernel_tensor->buffer[i] = (smpl_t)i + 1.;
     291  }
     292
     293  aubio_tensor_matmul(input_tensor, kernel_tensor, output_tensor);
     294
     295  PRINT_MSG("input: ");
     296  aubio_tensor_print(input_tensor);
     297  PRINT_MSG("kernel: ");
     298  aubio_tensor_print(kernel_tensor);
     299  PRINT_MSG("output: ");
     300  aubio_tensor_print(output_tensor);
     301
     302  assert (output_tensor->data[0][0] == kernel_tensor->data[0][0]);
     303  assert (output_tensor->data[0][1] == kernel_tensor->data[0][1]);
     304  assert (output_tensor->data[1][0] == kernel_tensor->data[1][0]);
     305  assert (output_tensor->data[1][1] == kernel_tensor->data[1][1]);
     306  assert (output_tensor->data[2][0] == 2);
     307  assert (output_tensor->data[2][1] == 2);
     308
     309  del_aubio_tensor(output_tensor);
     310  del_aubio_tensor(kernel_tensor);
     311  del_aubio_tensor(input_tensor);
     312  return 0;
     313}
    273314int main(void) {
    274315  PRINT_MSG("testing 1d tensors\n");
     
    292333  PRINT_MSG("testing get_shape_string\n");
    293334  assert (test_get_shape_string() == 0);
    294   return 0;
    295 }
     335  PRINT_MSG("testing matmul\n");
     336  assert (test_matmul() == 0);
     337  return 0;
     338}
Note: See TracChangeset for help on using the changeset viewer.