From 3408b4b2f409a5a8191248c7c17e1c882779de27 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Fri, 2 Jun 2017 19:52:08 +0800 Subject: [PATCH] Bug fix --- paddle/function/ConvOp.h | 2 +- paddle/function/ConvOpTest.cpp | 2 ++ paddle/function/GemmConvOp.cpp | 6 ++---- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/function/ConvOp.h b/paddle/function/ConvOp.h index 017d4e26f2b..14c20b74f20 100644 --- a/paddle/function/ConvOp.h +++ b/paddle/function/ConvOp.h @@ -58,7 +58,7 @@ public: CHECK_EQ(outputs[0].shape().ndims(), (size_t)4); CHECK(inputs[0].shape()[0] == outputs[0].shape()[0]); - CHECK(inputs[0].shape()[1] == inputs[1].shape()[1]); + CHECK(inputs[0].shape()[1] / groups_ == inputs[1].shape()[1]); CHECK(outputs[0].shape()[1] == inputs[1].shape()[0]); } diff --git a/paddle/function/ConvOpTest.cpp b/paddle/function/ConvOpTest.cpp index 89626714133..d9de2114488 100644 --- a/paddle/function/ConvOpTest.cpp +++ b/paddle/function/ConvOpTest.cpp @@ -83,9 +83,11 @@ TEST(Convolution, GEMM) { "GemmConv-CPU"); } +#ifndef PADDLE_ONLY_CPU TEST(Convolution, GEMM2) { ConvolutionTest test("GemmConv-CPU", "GemmConv-GPU"); } +#endif } // namespace paddle diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index e7a93ae676f..78aa8f14f34 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -101,8 +101,6 @@ public: size_t outputHeight = outputs[0].shape()[2]; size_t outputWidth = outputs[0].shape()[3]; - CHECK_EQ(inputChannels / groups_, inputs[1].shape()[1]); - real* inputData = inputs[0].data(); real* filterData = inputs[1].data(); real* outputData = outputs[0].data(); @@ -134,9 +132,9 @@ public: outputWidth, colData); - int M = outputChannels; + int M = outputChannels / groups_; int N = outputHeight * outputWidth; - int K = inputChannels * filterHeight * filterWidth; + int K = inputChannels / groups_ * filterHeight * filterWidth; gemm(M, N, K, -- GitLab