提交 90326198 编写于 作者: H hedaoyuan

Bug fix & add test of GemmConvGradFilter.

上级 6a93f0f3
......@@ -19,11 +19,18 @@ limitations under the License. */
namespace paddle {
enum TestType {
FORWARD_TEST = 0,
BACKWARD_INPUT_TEST = 1,
BACKWARD_FILTER_TEST = 2,
};
template <DeviceType DType1, DeviceType DType2>
class ConvolutionTest {
public:
ConvolutionTest(const std::string& conv1,
const std::string& conv2,
TestType type,
std::string algo = "auto") {
for (size_t batchSize : {1, 32}) {
for (size_t inputSize : {7, 14, 54}) {
......@@ -58,16 +65,31 @@ public:
.set("groups", (size_t)1)
.set("algo", algo));
TensorShape shape0{
TensorShape input{
batchSize, inputChannels, inputSize, inputSize};
TensorShape shape1{
TensorShape filter{
outputChannels, inputChannels, filterSize, filterSize};
TensorShape shape2{
TensorShape output{
batchSize, outputChannels, outputSize, outputSize};
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, shape0));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, shape1));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, shape2));
if (type == FORWARD_TEST) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.run();
} else if (type == BACKWARD_INPUT_TEST) {
#if 0
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.run();
#endif
} else if (type == BACKWARD_FILTER_TEST) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.run();
}
}
}
}
......@@ -78,15 +100,20 @@ public:
}
};
TEST(Convolution, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test("NaiveConv-CPU",
"GemmConv-CPU");
TEST(Forward, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test(
"NaiveConv-CPU", "GemmConv-CPU", FORWARD_TEST);
}
#ifndef PADDLE_ONLY_CPU
TEST(Convolution, GEMM2) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test("GemmConv-CPU",
"GemmConv-GPU");
TEST(Forward, GEMM2) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConv-CPU", "GemmConv-GPU", FORWARD_TEST);
}
TEST(BackwardFilter, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", BACKWARD_FILTER_TEST);
}
#endif
......
......@@ -255,10 +255,10 @@ public:
filterGrad + g * filterOffset,
N);
}
}
inputData += inputChannels * inputHeight * inputWidth;
outputGrad += outputChannels * outputHeight * outputWidth;
}
}
};
REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册