diff --git a/paddle/function/ConvOp.h b/paddle/function/ConvOp.h index 017d4e26f2b7b583737a8698ac87e63e69535ac2..14c20b74f2007314534b20d36cf588cef84d34cd 100644 --- a/paddle/function/ConvOp.h +++ b/paddle/function/ConvOp.h @@ -58,7 +58,7 @@ public: CHECK_EQ(outputs[0].shape().ndims(), (size_t)4); CHECK(inputs[0].shape()[0] == outputs[0].shape()[0]); - CHECK(inputs[0].shape()[1] == inputs[1].shape()[1]); + CHECK(inputs[0].shape()[1] / groups_ == inputs[1].shape()[1]); CHECK(outputs[0].shape()[1] == inputs[1].shape()[0]); } diff --git a/paddle/function/ConvOpTest.cpp b/paddle/function/ConvOpTest.cpp index 896267141337f1476c8b48a13b99cc928485f11c..d9de211448879a152351bdbcd9e2282274e42832 100644 --- a/paddle/function/ConvOpTest.cpp +++ b/paddle/function/ConvOpTest.cpp @@ -83,9 +83,11 @@ TEST(Convolution, GEMM) { "GemmConv-CPU"); } +#ifndef PADDLE_ONLY_CPU TEST(Convolution, GEMM2) { ConvolutionTest test("GemmConv-CPU", "GemmConv-GPU"); } +#endif } // namespace paddle diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index e7a93ae676ff7388425e9efee619bcddf21692af..78aa8f14f34457e131dd4126fc6b8b2b6b07485c 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -101,8 +101,6 @@ public: size_t outputHeight = outputs[0].shape()[2]; size_t outputWidth = outputs[0].shape()[3]; - CHECK_EQ(inputChannels / groups_, inputs[1].shape()[1]); - real* inputData = inputs[0].data(); real* filterData = inputs[1].data(); real* outputData = outputs[0].data(); @@ -134,9 +132,9 @@ public: outputWidth, colData); - int M = outputChannels; + int M = outputChannels / groups_; int N = outputHeight * outputWidth; - int K = inputChannels * filterHeight * filterWidth; + int K = inputChannels / groups_ * filterHeight * filterWidth; gemm(M, N, K,