Commit 248149f4 authored by xzl

Add a depthwise convolution test and fix a small bug in the conv op test (ConvOpTest)

Parent 5b07d4e0
@@ -38,19 +38,20 @@ public:
       for (size_t filterSize : {1, 3, 5}) {
         for (size_t inputChannels : {3, 64}) {
           for (size_t outputChannels : {3, 64}) {
-            for (size_t groups : {1, 3, 64}) {
               if (inputChannels > outputChannels) break;
-              if (groups != 1 &&
-                  (inputChannels != groups || outputChannels % groups != 0))
-                continue;
-              if (!useGroups) groups = 1;
+              size_t groups;
+              if (!useGroups) {
+                groups = 1;
+              } else {
+                if (outputChannels % inputChannels != 0) continue;
+                groups = inputChannels;
+              }

               for (size_t stride : {1, 2}) {
                 for (size_t padding : {0, 1}) {
                   if (padding >= filterSize) break;
                   size_t outputSize =
-                      (inputSize - filterSize + 2 * padding + stride) /
-                      stride;
+                      (inputSize - filterSize + 2 * padding + stride) / stride;
                   VLOG(3) << " batchSize=" << batchSize
                           << " inputChannels=" << inputChannels
                           << " inputHeight=" << inputSize
@@ -99,13 +100,13 @@ public:
             } else if (type == kBackwardInputTest) {
               test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
               test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
-              test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input),
-                              ADD_TO);
+              test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input), ADD_TO);
               test.run();
             } else if (type == kBackwardFilterTest) {
               test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
               test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
-              test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter));
+              test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter),
+                              ADD_TO);
               test.run();
             }
           }
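This hunk appears to be the "small bug" from the commit message: in the kBackwardFilterTest branch the filter-gradient output was registered without ADD_TO, while the input-gradient branch already used it, so the two branches disagreed on whether the compared function overwrites or accumulates into the gradient buffer. Below is a minimal standalone sketch of the difference between the two semantics, assuming ADD_TO means "accumulate into the existing output"; the enum and function names here are hypothetical and do not come from the PaddlePaddle FunctionTest API.

#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical illustration of assign-vs-accumulate output semantics.
enum ArgSemantics { ASSIGN_TO_OUT, ADD_TO_OUT };

void writeResult(std::vector<float>& out,
                 const std::vector<float>& result,
                 ArgSemantics sem) {
  for (size_t i = 0; i < out.size(); ++i) {
    out[i] = (sem == ADD_TO_OUT) ? out[i] + result[i] : result[i];
  }
}

int main() {
  std::vector<float> grad = {1.0f, 1.0f};     // pre-existing gradient values
  std::vector<float> result = {0.5f, 0.25f};  // newly computed gradient
  writeResult(grad, result, ADD_TO_OUT);
  assert(grad[0] == 1.5f && grad[1] == 1.25f);  // accumulated, not overwritten
  return 0;
}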
@@ -116,7 +117,6 @@ public:
             }
           }
         }
-        }
 };

 // Mainly used to test cases where the height and width (input, filter)
@@ -136,11 +136,13 @@ public:
         for (size_t filterWidth : {3, 7}) {
           for (size_t inputChannels : {7}) {
             for (size_t outputChannels : {7}) {
-              for (size_t groups : {1, 7}) {
-                if (groups != 1 && (inputChannels != groups ||
-                                    outputChannels % groups != 0))
-                  continue;
-                if (!useGroups) groups = 1;
+              size_t groups;
+              if (!useGroups) {
+                groups = 1;
+              } else {
+                if (outputChannels % inputChannels != 0) continue;
+                groups = inputChannels;
+              }

               size_t stride = 1;
               size_t padding = 0;
@@ -198,13 +200,13 @@ public:
             } else if (type == kBackwardInputTest) {
               test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
               test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
-              test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input),
-                              ADD_TO);
+              test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input), ADD_TO);
               test.run();
             } else if (type == kBackwardFilterTest) {
               test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
               test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
-              test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter));
+              test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter),
+                              ADD_TO);
               test.run();
             }
           }
@@ -215,9 +217,10 @@ public:
           }
         }
       }
-      }
 };

+// ======Start Convolution TEST======
 TEST(Forward, GEMM) {
   ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test(
       "NaiveConv-CPU", "GemmConv-CPU", kForwardTest, false);
@@ -228,24 +231,76 @@ TEST(Forward, GEMM) {
 #ifndef PADDLE_ONLY_CPU
 TEST(Forward, GEMM2) {
   ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
-      "GemmConv-CPU", "GemmConv-GPU", kForwardTest);
+      "GemmConv-CPU", "GemmConv-GPU", kForwardTest, false);
   ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
-      "GemmConv-CPU", "GemmConv-GPU", kForwardTest);
+      "GemmConv-CPU", "GemmConv-GPU", kForwardTest, false);
 }

 TEST(BackwardInput, GEMM) {
   ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
-      "GemmConvGradInput-CPU", "GemmConvGradInput-GPU", kBackwardInputTest);
+      "GemmConvGradInput-CPU",
+      "GemmConvGradInput-GPU",
+      kBackwardInputTest,
+      false);
   ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
-      "GemmConvGradInput-CPU", "GemmConvGradInput-GPU", kBackwardInputTest);
+      "GemmConvGradInput-CPU",
+      "GemmConvGradInput-GPU",
+      kBackwardInputTest,
+      false);
 }

 TEST(BackwardFilter, GEMM) {
   ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
-      "GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", kBackwardFilterTest);
+      "GemmConvGradFilter-CPU",
+      "GemmConvGradFilter-GPU",
+      kBackwardFilterTest,
+      false);
   ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
-      "GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", kBackwardFilterTest);
+      "GemmConvGradFilter-CPU",
+      "GemmConvGradFilter-GPU",
+      kBackwardFilterTest,
+      false);
 }
+#endif
+// ======End Convolution TEST======
+
+// ======Start DepthwiseConvolution TEST======
+// TODO(zhaolong) The depthwise convolution cpu test will be added when the cpu
+// version of depthwiseConv is implemented.
+#ifndef PADDLE_ONLY_CPU
+TEST(DepthwiseConvForward, GEMM2) {
+  ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
+      "GemmConv-CPU", "DepthwiseConv-GPU", kForwardTest);
+  ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
+      "GemmConv-CPU", "DepthwiseConv-GPU", kForwardTest);
+}
+
+TEST(DepthwiseConvBackwardInput, GEMM) {
+  ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
+      "GemmConvGradInput-CPU",
+      "DepthwiseConvGradInput-GPU",
+      kBackwardInputTest);
+  ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
+      "GemmConvGradInput-CPU",
+      "DepthwiseConvGradInput-GPU",
+      kBackwardInputTest);
+}
+
+TEST(DepthwiseConvBackwardFilter, GEMM) {
+  ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
+      "GemmConvGradFilter-CPU",
+      "DepthwiseConvGradFilter-GPU",
+      kBackwardFilterTest);
+  ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
+      "GemmConvGradFilter-CPU",
+      "DepthwiseConvGradFilter-GPU",
+      kBackwardFilterTest);
+}
 #endif
+// ======End DepthwiseConvolution TEST======

 }  // namespace paddle
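The new depthwise tests compare the GPU kernels (DepthwiseConv-GPU and its gradient variants) against the generic GemmConv-CPU functions because, per the TODO above, no CPU depthwise implementation exists yet. These tests omit the trailing false argument, so the group selection presumably defaults to useGroups with groups = inputChannels, and a grouped convolution with one group per input channel computes exactly a depthwise convolution. For reference, here is a naive, self-contained sketch of a depthwise convolution forward pass (NCHW layout, zero padding); it only illustrates what the kernel computes and is not the PaddlePaddle implementation.

#include <vector>

// Naive depthwise convolution forward, for illustration only (not the
// PaddlePaddle DepthwiseConv kernel). Assumed layouts: input NCHW,
// filter (outC, 1, fH, fW), output NCHW, with outC = inC * multiplier,
// i.e. each input channel has its own filters.
void depthwiseConvForward(const std::vector<float>& input,
                          const std::vector<float>& filter,
                          std::vector<float>& output,
                          int N, int inC, int inH, int inW,
                          int multiplier, int fH, int fW,
                          int stride, int padding) {
  const int outC = inC * multiplier;
  const int outH = (inH - fH + 2 * padding) / stride + 1;
  const int outW = (inW - fW + 2 * padding) / stride + 1;
  output.assign(static_cast<size_t>(N) * outC * outH * outW, 0.0f);

  for (int n = 0; n < N; ++n) {
    for (int oc = 0; oc < outC; ++oc) {
      const int ic = oc / multiplier;  // the single input channel this output channel reads
      for (int oh = 0; oh < outH; ++oh) {
        for (int ow = 0; ow < outW; ++ow) {
          float sum = 0.0f;
          for (int kh = 0; kh < fH; ++kh) {
            for (int kw = 0; kw < fW; ++kw) {
              const int ih = oh * stride + kh - padding;
              const int iw = ow * stride + kw - padding;
              if (ih < 0 || ih >= inH || iw < 0 || iw >= inW) continue;  // zero padding
              sum += input[((n * inC + ic) * inH + ih) * inW + iw] *
                     filter[(oc * fH + kh) * fW + kw];
            }
          }
          output[((n * outC + oc) * outH + oh) * outW + ow] = sum;
        }
      }
    }
  }
}

With multiplier = 1 this is the common depthwise case where outputChannels equals inputChannels; a grouped GEMM convolution with groups = inputChannels produces the same values, which is why the generic CPU path can serve as the reference here.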