提交 3408b4b2 编写于 作者: H hedaoyuan

Bug fix

上级 c70d3e1a
...@@ -58,7 +58,7 @@ public: ...@@ -58,7 +58,7 @@ public:
CHECK_EQ(outputs[0].shape().ndims(), (size_t)4); CHECK_EQ(outputs[0].shape().ndims(), (size_t)4);
CHECK(inputs[0].shape()[0] == outputs[0].shape()[0]); CHECK(inputs[0].shape()[0] == outputs[0].shape()[0]);
CHECK(inputs[0].shape()[1] == inputs[1].shape()[1]); CHECK(inputs[0].shape()[1] / groups_ == inputs[1].shape()[1]);
CHECK(outputs[0].shape()[1] == inputs[1].shape()[0]); CHECK(outputs[0].shape()[1] == inputs[1].shape()[0]);
} }
......
...@@ -83,9 +83,11 @@ TEST(Convolution, GEMM) { ...@@ -83,9 +83,11 @@ TEST(Convolution, GEMM) {
"GemmConv-CPU"); "GemmConv-CPU");
} }
#ifndef PADDLE_ONLY_CPU
TEST(Convolution, GEMM2) { TEST(Convolution, GEMM2) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test("GemmConv-CPU", ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test("GemmConv-CPU",
"GemmConv-GPU"); "GemmConv-GPU");
} }
#endif
} // namespace paddle } // namespace paddle
...@@ -101,8 +101,6 @@ public: ...@@ -101,8 +101,6 @@ public:
size_t outputHeight = outputs[0].shape()[2]; size_t outputHeight = outputs[0].shape()[2];
size_t outputWidth = outputs[0].shape()[3]; size_t outputWidth = outputs[0].shape()[3];
CHECK_EQ(inputChannels / groups_, inputs[1].shape()[1]);
real* inputData = inputs[0].data<real>(); real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>(); real* filterData = inputs[1].data<real>();
real* outputData = outputs[0].data<real>(); real* outputData = outputs[0].data<real>();
...@@ -134,9 +132,9 @@ public: ...@@ -134,9 +132,9 @@ public:
outputWidth, outputWidth,
colData); colData);
int M = outputChannels; int M = outputChannels / groups_;
int N = outputHeight * outputWidth; int N = outputHeight * outputWidth;
int K = inputChannels * filterHeight * filterWidth; int K = inputChannels / groups_ * filterHeight * filterWidth;
gemm(M, gemm(M,
N, N,
K, K,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册