提交 90326198 编写于 作者: H hedaoyuan

Bug fix & add test of GemmConvGradFilter.

上级 6a93f0f3
...@@ -19,11 +19,18 @@ limitations under the License. */ ...@@ -19,11 +19,18 @@ limitations under the License. */
namespace paddle { namespace paddle {
enum TestType {
FORWARD_TEST = 0,
BACKWARD_INPUT_TEST = 1,
BACKWARD_FILTER_TEST = 2,
};
template <DeviceType DType1, DeviceType DType2> template <DeviceType DType1, DeviceType DType2>
class ConvolutionTest { class ConvolutionTest {
public: public:
ConvolutionTest(const std::string& conv1, ConvolutionTest(const std::string& conv1,
const std::string& conv2, const std::string& conv2,
TestType type,
std::string algo = "auto") { std::string algo = "auto") {
for (size_t batchSize : {1, 32}) { for (size_t batchSize : {1, 32}) {
for (size_t inputSize : {7, 14, 54}) { for (size_t inputSize : {7, 14, 54}) {
...@@ -58,16 +65,31 @@ public: ...@@ -58,16 +65,31 @@ public:
.set("groups", (size_t)1) .set("groups", (size_t)1)
.set("algo", algo)); .set("algo", algo));
TensorShape shape0{ TensorShape input{
batchSize, inputChannels, inputSize, inputSize}; batchSize, inputChannels, inputSize, inputSize};
TensorShape shape1{ TensorShape filter{
outputChannels, inputChannels, filterSize, filterSize}; outputChannels, inputChannels, filterSize, filterSize};
TensorShape shape2{ TensorShape output{
batchSize, outputChannels, outputSize, outputSize}; batchSize, outputChannels, outputSize, outputSize};
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, shape0));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, shape1)); if (type == FORWARD_TEST) {
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, shape2)); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.run(); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.run();
} else if (type == BACKWARD_INPUT_TEST) {
#if 0
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.run();
#endif
} else if (type == BACKWARD_FILTER_TEST) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.run();
}
} }
} }
} }
...@@ -78,15 +100,20 @@ public: ...@@ -78,15 +100,20 @@ public:
} }
}; };
TEST(Convolution, GEMM) { TEST(Forward, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test("NaiveConv-CPU", ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test(
"GemmConv-CPU"); "NaiveConv-CPU", "GemmConv-CPU", FORWARD_TEST);
} }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
TEST(Convolution, GEMM2) { TEST(Forward, GEMM2) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test("GemmConv-CPU", ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConv-GPU"); "GemmConv-CPU", "GemmConv-GPU", FORWARD_TEST);
}
TEST(BackwardFilter, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", BACKWARD_FILTER_TEST);
} }
#endif #endif
......
...@@ -255,9 +255,9 @@ public: ...@@ -255,9 +255,9 @@ public:
filterGrad + g * filterOffset, filterGrad + g * filterOffset,
N); N);
} }
inputData += inputChannels * inputHeight * inputWidth;
outputGrad += outputChannels * outputHeight * outputWidth;
} }
inputData += inputChannels * inputHeight * inputWidth;
outputGrad += outputChannels * outputHeight * outputWidth;
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册