diff --git a/paddle/function/ConvOp.h b/paddle/function/ConvOp.h index 465db57ae7d82049d30973e643a12c27c39ec304..173ca228096d962f5bcffa18d66d2926295d0a7c 100644 --- a/paddle/function/ConvOp.h +++ b/paddle/function/ConvOp.h @@ -38,6 +38,7 @@ public: // function arguments strides_ = config.get>("strides"); paddings_ = config.get>("paddings"); + groups_ = config.get("groups"); // number of inputs and outputs numInputs_ = 2; @@ -62,6 +63,11 @@ public: protected: std::vector strides_; std::vector paddings_; + /// Group size, refer to grouped convolution in + /// Alex Krizhevsky's paper: when group=2, the first half of the + /// filters are only connected to the first half of the input channels, + /// and the second half only connected to the second half. + size_t groups_; inline int strideH() const { return strides_[0]; } inline int strideW() const { return strides_[1]; } diff --git a/paddle/function/ConvOpTest.cpp b/paddle/function/ConvOpTest.cpp index db8d9fa9da4609248078598257346245d8b92be9..eb0084804814cd188b90f8e0933cc633657a9d19 100644 --- a/paddle/function/ConvOpTest.cpp +++ b/paddle/function/ConvOpTest.cpp @@ -55,6 +55,7 @@ public: FuncConfig() .set("paddings", paddings) .set("strides", strides) + .set("groups", (size_t)1) .set("algo", algo)); TensorShape shape0{ diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index 42786e44e0e97a315bb5f71b9d3d389d9f743f85..b8e44cc60bce40182c5ada6c8d975d9f297c3025 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -101,49 +101,57 @@ public: size_t outputHeight = outputs[0].shape()[2]; size_t outputWidth = outputs[0].shape()[3]; + CHECK_EQ(inputChannels / groups_, inputs[1].shape()[1]); + real* inputData = inputs[0].data(); real* filterData = inputs[1].data(); real* outputData = outputs[0].data(); - size_t size = - inputChannels * filterHeight * filterWidth * outputHeight * outputWidth; + size_t size = inputChannels / groups_ * filterHeight * filterWidth * + outputHeight * outputWidth; resizeBuffer(size); real* colData = reinterpret_cast(memory_->getBuf()); Im2ColFunctor im2col; + size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth; + size_t outputOffset = + (outputChannels / groups_) * outputHeight * outputWidth; + size_t filterOffset = inputs[1].shape().getElements() / groups_; for (size_t i = 0; i < batchSize; i++) { - im2col(inputData, - inputChannels, - inputHeight, - inputWidth, - filterHeight, - filterWidth, - strideH(), - strideW(), - paddingH(), - paddingW(), - outputHeight, - outputWidth, - colData); - - int M = outputChannels; - int N = outputHeight * outputWidth; - int K = inputChannels * filterHeight * filterWidth; - gemm(CblasNoTrans, - CblasNoTrans, - M, - N, - K, - 1.0f, - filterData, - K, - colData, - N, - 0.0f, - outputData, - N); - inputData += inputChannels * inputHeight * inputWidth; - outputData += outputChannels * outputHeight * outputWidth; + for (int g = 0; g < groups_; g++) { + im2col(inputData + g * inputOffset, + inputChannels / groups_, + inputHeight, + inputWidth, + filterHeight, + filterWidth, + strideH(), + strideW(), + paddingH(), + paddingW(), + outputHeight, + outputWidth, + colData); + + int M = outputChannels; + int N = outputHeight * outputWidth; + int K = inputChannels * filterHeight * filterWidth; + gemm(CblasNoTrans, + CblasNoTrans, + M, + N, + K, + 1.0f, + filterData + g * filterOffset, + K, + colData, + N, + 0.0f, + outputData + g * outputOffset, + N); + inputData += inputChannels * inputHeight * inputWidth; + outputData += outputChannels * outputHeight * outputWidth; + } } }