提交 0cd44960 编写于 作者: 刘琦

Merge branch 'conv1x1_test' into 'master'

Update conv2d unit tests and add benchmark for bad alignments

See merge request !38
...@@ -61,7 +61,10 @@ static void Conv2d(int iters, int batch, int channels, int height, int width, ...@@ -61,7 +61,10 @@ static void Conv2d(int iters, int batch, int channels, int height, int width,
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, NEON); BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, NEON);
BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, VALID, 128, float); BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 33, 31, 1, 1, 1, VALID, 128, float); // Test bad alignments
BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float); BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 128, float); BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 128, float);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 128, float);
} // namespace mace } // namespace mace
...@@ -192,8 +192,10 @@ TEST_F(Conv2dOpTest, Conv1x1) { ...@@ -192,8 +192,10 @@ TEST_F(Conv2dOpTest, Conv1x1) {
} }
// TODO we need more tests // TODO we need more tests
TEST_F(Conv2dOpTest, Conv3x3R1) { TEST_F(Conv2dOpTest, ConvNxNS12) {
auto func = [&](Padding type) { auto func = [&](int kernel_h, int kernel_w,
int stride_h, int stride_w,
Padding type) {
srand(time(NULL)); srand(time(NULL));
// generate random input // generate random input
...@@ -212,13 +214,14 @@ TEST_F(Conv2dOpTest, Conv3x3R1) { ...@@ -212,13 +214,14 @@ TEST_F(Conv2dOpTest, Conv3x3R1) {
.Finalize(net.operator_def()); .Finalize(net.operator_def());
// Add args // Add args
net.AddIntsArg("strides", {1, 1}); net.AddIntsArg("strides", {stride_h, stride_w});
net.AddIntArg("padding", type); net.AddIntArg("padding", type);
net.AddIntsArg("dilations", {1, 1}); net.AddIntsArg("dilations", {1, 1});
// Add input data // Add input data
net.AddRandomInput<float>("Input", {batch, input_channels, height, width}); net.AddRandomInput<float>("Input", {batch, input_channels, height, width});
net.AddRandomInput<float>("Filter", {output_channels, input_channels, 3, 3}); net.AddRandomInput<float>("Filter", {output_channels, input_channels,
kernel_h, kernel_w});
net.AddRandomInput<float>("Bias", {output_channels}); net.AddRandomInput<float>("Bias", {output_channels});
// run cpu // run cpu
net.RunOp(); net.RunOp();
...@@ -233,6 +236,10 @@ TEST_F(Conv2dOpTest, Conv3x3R1) { ...@@ -233,6 +236,10 @@ TEST_F(Conv2dOpTest, Conv3x3R1) {
}; };
func(VALID); for (int kernel_size : {1, 3}) {
func(SAME); for (int stride : {1, 2}) {
func(kernel_size, kernel_size, stride, stride, VALID);
func(kernel_size, kernel_size, stride, stride, SAME);
}
}
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册