提交 2608c485 编写于 作者: H hedaoyuan

Add test cases where the height and width (input, filter) are not equal.

上级 01d52ebf
...@@ -98,25 +98,112 @@ public: ...@@ -98,25 +98,112 @@ public:
} }
}; };
// Mainly used to test cases where the height and width (input, filter)
// are not equal.
template <DeviceType DType1, DeviceType DType2>
class ConvolutionTest2 {
public:
ConvolutionTest2(const std::string& conv1,
const std::string& conv2,
TestType type,
std::string algo = "auto") {
for (size_t batchSize : {16}) {
for (size_t inputHeight : {7, 31}) {
for (size_t inputWidth : {10, 54}) {
for (size_t filterHeight : {1, 5}) {
for (size_t filterWidth : {3, 7}) {
for (size_t inputChannels : {7}) {
for (size_t outputChannels : {32}) {
size_t stride = 1;
size_t padding = 0;
size_t outputHeight =
(inputHeight - filterHeight + 2 * padding + stride) /
stride;
size_t outputWidth =
(inputWidth - filterWidth + 2 * padding + stride) /
stride;
VLOG(3) << " batchSize=" << batchSize
<< " inputChannels=" << inputChannels
<< " inputHeight=" << inputHeight
<< " inputWidth=" << inputWidth
<< " outputChannels=" << outputChannels
<< " filterHeight=" << filterHeight
<< " filterWidth=" << filterWidth
<< " outputHeight=" << outputHeight
<< " outputWidth=" << outputWidth
<< " stride=" << stride << " padding=" << padding;
std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride};
Compare2Function<DType1, DType2> test(
conv1,
conv2,
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)1)
.set("algo", algo));
TensorShape input{
batchSize, inputChannels, inputHeight, inputWidth};
TensorShape filter{
outputChannels, inputChannels, filterHeight, filterWidth};
TensorShape output{
batchSize, outputChannels, outputHeight, outputWidth};
if (type == kForwardTest) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.run();
} else if (type == kBackwardInputTest) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input), ADD_TO);
test.run();
} else if (type == kBackwardFilterTest) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.run();
}
}
}
}
}
}
}
}
}
};
TEST(Forward, GEMM) { TEST(Forward, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test( ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test(
"NaiveConv-CPU", "GemmConv-CPU", kForwardTest); "NaiveConv-CPU", "GemmConv-CPU", kForwardTest);
ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test2(
"NaiveConv-CPU", "GemmConv-CPU", kForwardTest);
} }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
TEST(Forward, GEMM2) { TEST(Forward, GEMM2) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test( ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConv-CPU", "GemmConv-GPU", kForwardTest); "GemmConv-CPU", "GemmConv-GPU", kForwardTest);
ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
"GemmConv-CPU", "GemmConv-GPU", kForwardTest);
} }
TEST(BackwardInput, GEMM) { TEST(BackwardInput, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test( ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConvGradInput-CPU", "GemmConvGradInput-GPU", kBackwardInputTest); "GemmConvGradInput-CPU", "GemmConvGradInput-GPU", kBackwardInputTest);
ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
"GemmConvGradInput-CPU", "GemmConvGradInput-GPU", kBackwardInputTest);
} }
TEST(BackwardFilter, GEMM) { TEST(BackwardFilter, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test( ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", kBackwardFilterTest); "GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", kBackwardFilterTest);
ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
"GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", kBackwardFilterTest);
} }
#endif #endif
......
...@@ -104,7 +104,7 @@ public: ...@@ -104,7 +104,7 @@ public:
size_t inputHeight = inputs[0].shape()[2]; size_t inputHeight = inputs[0].shape()[2];
size_t inputWidth = inputs[0].shape()[3]; size_t inputWidth = inputs[0].shape()[3];
size_t filterHeight = inputs[1].shape()[2]; size_t filterHeight = inputs[1].shape()[2];
size_t filterWidth = inputs[1].shape()[2]; size_t filterWidth = inputs[1].shape()[3];
size_t outputChannels = outputs[0].shape()[1]; size_t outputChannels = outputs[0].shape()[1];
size_t outputHeight = outputs[0].shape()[2]; size_t outputHeight = outputs[0].shape()[2];
size_t outputWidth = outputs[0].shape()[3]; size_t outputWidth = outputs[0].shape()[3];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册