diff --git a/mace/ops/conv_2d_benchmark.cc b/mace/ops/conv_2d_benchmark.cc index 772dd200beeebf5b2cd1efae1346dbc20ec50931..cc6ec092723c6e6417855ef637901c83b5f2b3b6 100644 --- a/mace/ops/conv_2d_benchmark.cc +++ b/mace/ops/conv_2d_benchmark.cc @@ -61,7 +61,10 @@ static void Conv2d(int iters, int batch, int channels, int height, int width, BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, NEON); BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, VALID, 128, float); +BM_CONV_2D(1, 64, 33, 31, 1, 1, 1, VALID, 128, float); // Test bad alignments BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float); +BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 128, float); BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 128, float); +BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 128, float); } // namespace mace diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc index c516dd734df808657a13a92f593084fdd3cce70b..b169400f73ae8caa9553faecc4fb0c12c0b78cf2 100644 --- a/mace/ops/conv_2d_test.cc +++ b/mace/ops/conv_2d_test.cc @@ -192,8 +192,10 @@ TEST_F(Conv2dOpTest, Conv1x1) { } // TODO we need more tests -TEST_F(Conv2dOpTest, Conv3x3R1) { - auto func = [&](Padding type) { +TEST_F(Conv2dOpTest, ConvNxNS12) { + auto func = [&](int kernel_h, int kernel_w, + int stride_h, int stride_w, + Padding type) { srand(time(NULL)); // generate random input @@ -212,13 +214,14 @@ TEST_F(Conv2dOpTest, Conv3x3R1) { .Finalize(net.operator_def()); // Add args - net.AddIntsArg("strides", {1, 1}); + net.AddIntsArg("strides", {stride_h, stride_w}); net.AddIntArg("padding", type); net.AddIntsArg("dilations", {1, 1}); // Add input data net.AddRandomInput("Input", {batch, input_channels, height, width}); - net.AddRandomInput("Filter", {output_channels, input_channels, 3, 3}); + net.AddRandomInput("Filter", {output_channels, input_channels, + kernel_h, kernel_w}); net.AddRandomInput("Bias", {output_channels}); // run cpu net.RunOp(); @@ -233,6 +236,10 @@ TEST_F(Conv2dOpTest, Conv3x3R1) { }; - func(VALID); - func(SAME); + for (int kernel_size : {1, 3}) { + for (int stride : {1, 2}) { + func(kernel_size, kernel_size, stride, stride, VALID); + func(kernel_size, kernel_size, stride, stride, SAME); + } + } }