From 90326198e929772fe3e87fe5c067f057927f7f64 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Mon, 5 Jun 2017 21:35:10 +0800 Subject: [PATCH] Bug fix & add test of GemmConvGradFilter. --- paddle/function/ConvOpTest.cpp | 53 +++++++++++++++++++++++++--------- paddle/function/GemmConvOp.cpp | 4 +-- 2 files changed, 42 insertions(+), 15 deletions(-) diff --git a/paddle/function/ConvOpTest.cpp b/paddle/function/ConvOpTest.cpp index d9de2114488..e2997df0128 100644 --- a/paddle/function/ConvOpTest.cpp +++ b/paddle/function/ConvOpTest.cpp @@ -19,11 +19,18 @@ limitations under the License. */ namespace paddle { +enum TestType { + FORWARD_TEST = 0, + BACKWARD_INPUT_TEST = 1, + BACKWARD_FILTER_TEST = 2, +}; + template class ConvolutionTest { public: ConvolutionTest(const std::string& conv1, const std::string& conv2, + TestType type, std::string algo = "auto") { for (size_t batchSize : {1, 32}) { for (size_t inputSize : {7, 14, 54}) { @@ -58,16 +65,31 @@ public: .set("groups", (size_t)1) .set("algo", algo)); - TensorShape shape0{ + TensorShape input{ batchSize, inputChannels, inputSize, inputSize}; - TensorShape shape1{ + TensorShape filter{ outputChannels, inputChannels, filterSize, filterSize}; - TensorShape shape2{ + TensorShape output{ batchSize, outputChannels, outputSize, outputSize}; - test.addInputs(BufferArg(VALUE_TYPE_FLOAT, shape0)); - test.addInputs(BufferArg(VALUE_TYPE_FLOAT, shape1)); - test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, shape2)); - test.run(); + + if (type == FORWARD_TEST) { + test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input)); + test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter)); + test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, output)); + test.run(); + } else if (type == BACKWARD_INPUT_TEST) { +#if 0 + test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output)); + test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter)); + test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input)); + test.run(); +#endif + } else if (type == BACKWARD_FILTER_TEST) { + test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output)); + test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input)); + test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter)); + test.run(); + } } } } @@ -78,15 +100,20 @@ public: } }; -TEST(Convolution, GEMM) { - ConvolutionTest test("NaiveConv-CPU", - "GemmConv-CPU"); +TEST(Forward, GEMM) { + ConvolutionTest test( + "NaiveConv-CPU", "GemmConv-CPU", FORWARD_TEST); } #ifndef PADDLE_ONLY_CPU -TEST(Convolution, GEMM2) { - ConvolutionTest test("GemmConv-CPU", - "GemmConv-GPU"); +TEST(Forward, GEMM2) { + ConvolutionTest test( + "GemmConv-CPU", "GemmConv-GPU", FORWARD_TEST); +} + +TEST(BackwardFilter, GEMM) { + ConvolutionTest test( + "GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", BACKWARD_FILTER_TEST); } #endif diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index 6b5db1d62ed..414c7a885b6 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -255,9 +255,9 @@ public: filterGrad + g * filterOffset, N); } + inputData += inputChannels * inputHeight * inputWidth; + outputGrad += outputChannels * outputHeight * outputWidth; } - inputData += inputChannels * inputHeight * inputWidth; - outputGrad += outputChannels * outputHeight * outputWidth; } }; -- GitLab