提交 3408b4b2 编写于 作者: H hedaoyuan

Bug fix

上级 c70d3e1a
......@@ -58,7 +58,7 @@ public:
CHECK_EQ(outputs[0].shape().ndims(), (size_t)4);
CHECK(inputs[0].shape()[0] == outputs[0].shape()[0]);
CHECK(inputs[0].shape()[1] == inputs[1].shape()[1]);
CHECK(inputs[0].shape()[1] / groups_ == inputs[1].shape()[1]);
CHECK(outputs[0].shape()[1] == inputs[1].shape()[0]);
}
......
......@@ -83,9 +83,11 @@ TEST(Convolution, GEMM) {
"GemmConv-CPU");
}
#ifndef PADDLE_ONLY_CPU
TEST(Convolution, GEMM2) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test("GemmConv-CPU",
"GemmConv-GPU");
}
#endif
} // namespace paddle
......@@ -101,8 +101,6 @@ public:
size_t outputHeight = outputs[0].shape()[2];
size_t outputWidth = outputs[0].shape()[3];
CHECK_EQ(inputChannels / groups_, inputs[1].shape()[1]);
real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>();
real* outputData = outputs[0].data<real>();
......@@ -134,9 +132,9 @@ public:
outputWidth,
colData);
int M = outputChannels;
int M = outputChannels / groups_;
int N = outputHeight * outputWidth;
int K = inputChannels * filterHeight * filterWidth;
int K = inputChannels / groups_ * filterHeight * filterWidth;
gemm(M,
N,
K,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册