From 5d3d75148e8423fb3d7fc8271b8e059adb68852c Mon Sep 17 00:00:00 2001 From: Liangliang He Date: Thu, 14 Sep 2017 16:33:59 +0800 Subject: [PATCH] Update conv2d unit tests and add benchmark for bad alignments --- mace/ops/conv_2d_benchmark.cc | 3 +++ mace/ops/conv_2d_test.cc | 19 +++++++++++++------ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/mace/ops/conv_2d_benchmark.cc b/mace/ops/conv_2d_benchmark.cc index 772dd200..cc6ec092 100644 --- a/mace/ops/conv_2d_benchmark.cc +++ b/mace/ops/conv_2d_benchmark.cc @@ -61,7 +61,10 @@ static void Conv2d(int iters, int batch, int channels, int height, int width, BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, NEON); BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, VALID, 128, float); +BM_CONV_2D(1, 64, 33, 31, 1, 1, 1, VALID, 128, float); // Test bad alignments BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float); +BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 128, float); BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 128, float); +BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 128, float); } // namespace mace diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc index c516dd73..b169400f 100644 --- a/mace/ops/conv_2d_test.cc +++ b/mace/ops/conv_2d_test.cc @@ -192,8 +192,10 @@ TEST_F(Conv2dOpTest, Conv1x1) { } // TODO we need more tests -TEST_F(Conv2dOpTest, Conv3x3R1) { - auto func = [&](Padding type) { +TEST_F(Conv2dOpTest, ConvNxNS12) { + auto func = [&](int kernel_h, int kernel_w, + int stride_h, int stride_w, + Padding type) { srand(time(NULL)); // generate random input @@ -212,13 +214,14 @@ TEST_F(Conv2dOpTest, Conv3x3R1) { .Finalize(net.operator_def()); // Add args - net.AddIntsArg("strides", {1, 1}); + net.AddIntsArg("strides", {stride_h, stride_w}); net.AddIntArg("padding", type); net.AddIntsArg("dilations", {1, 1}); // Add input data net.AddRandomInput("Input", {batch, input_channels, height, width}); - net.AddRandomInput("Filter", {output_channels, input_channels, 3, 3}); + net.AddRandomInput("Filter", {output_channels, input_channels, + kernel_h, kernel_w}); net.AddRandomInput("Bias", {output_channels}); // run cpu net.RunOp(); @@ -233,6 +236,10 @@ TEST_F(Conv2dOpTest, Conv3x3R1) { }; - func(VALID); - func(SAME); + for (int kernel_size : {1, 3}) { + for (int stride : {1, 2}) { + func(kernel_size, kernel_size, stride, stride, VALID); + func(kernel_size, kernel_size, stride, stride, SAME); + } + } } -- GitLab