diff --git a/paddle/function/ConvOp.h b/paddle/function/ConvOp.h index bb4f48364b9b454af7d37fe4d3c340666e53285c..baf78bc6c88d0d294f4457b81c52b22e425d9fdb 100644 --- a/paddle/function/ConvOp.h +++ b/paddle/function/ConvOp.h @@ -109,6 +109,13 @@ protected: return filter[filter.ndims() - 1]; } + // determine whether im2col needs to be performed + inline bool isNeedIm2col(const TensorShape& filter) const { + return !(getFilterHeight(filter) == 1 && getFilterWidth(filter) == 1 && + strideH() == 1 && strideW() == 1 && paddingH() == 0 && + paddingW() == 0); + } + std::vector strides_; std::vector paddings_; diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index 9deb2739fcfff935a98a0b5b31b5d11819d81227..0ada4d70a0c7d13f9b5fb1a42eac07fc4c775a87 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -66,16 +66,23 @@ public: real* inputData = inputs[0].data(); real* filterData = inputs[1].data(); real* outputData = outputs[0].data(); + bool needIm2col = isNeedIm2col(filter); + TensorShape imShape = TensorShape({inputChannels / groups_, inputHeight, inputWidth}); - TensorShape colShape = TensorShape({inputChannels / groups_, - filterHeight, - filterWidth, - outputHeight, - outputWidth}); - resizeBuffer(colShape.getElements()); - real* colData = reinterpret_cast(memory_->getBuf()); + TensorShape colShape; + real* colData = NULL; + + if (needIm2col) { + colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + resizeBuffer(colShape.getElements()); + colData = reinterpret_cast(memory_->getBuf()); + } Im2ColFunctor im2col; GemmFunctor gemm; @@ -86,15 +93,18 @@ public: for (size_t i = 0; i < batchSize; i++) { for (size_t g = 0; g < groups_; g++) { - im2col(inputData + g * inputOffset, - imShape, - colData, - colShape, - strideH(), - strideW(), - paddingH(), - paddingW()); - + if (needIm2col) { + im2col(inputData + g * inputOffset, + imShape, + colData, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW()); + } else { + colData = inputData + g * inputOffset; + } int M = outputChannels / groups_; int N = outputHeight * outputWidth; int K = inputChannels / groups_ * filterHeight * filterWidth; @@ -159,19 +169,27 @@ public: real* outputGrad = inputs[0].data(); real* filterData = inputs[1].data(); real* inputGrad = outputs[0].data(); + bool needIm2col = isNeedIm2col(filter); + TensorShape imShape = TensorShape({inputChannels / groups_, inputHeight, inputWidth}); - TensorShape colShape = TensorShape({inputChannels / groups_, - filterHeight, - filterWidth, - outputHeight, - outputWidth}); - resizeBuffer(colShape.getElements()); - real* colData = reinterpret_cast(memory_->getBuf()); + TensorShape colShape; + real* colData = NULL; + + if (needIm2col) { + colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + resizeBuffer(colShape.getElements()); + colData = reinterpret_cast(memory_->getBuf()); + } Col2ImFunctor col2im; GemmFunctor gemm; + size_t inputOffset = imShape.getElements(); size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; @@ -182,6 +200,11 @@ public: int K = outputChannels / groups_; int N = outputHeight * outputWidth; int M = inputChannels / groups_ * filterHeight * filterWidth; + real scale = 0.0f; + if (!needIm2col) { + colData = inputGrad + g * inputOffset; + scale = 1.0f; + } gemm(CblasTrans, CblasNoTrans, M, @@ -192,17 +215,19 @@ public: M, outputGrad + g * outputOffset, N, - 0.0f, + scale, colData, N); - col2im(inputGrad + g * inputOffset, - imShape, - colData, - colShape, - strideH(), - strideW(), - paddingH(), - paddingW()); + if (needIm2col) { + col2im(inputGrad + g * inputOffset, + imShape, + colData, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW()); + } } inputGrad += inputChannels * inputHeight * inputWidth; outputGrad += outputChannels * outputHeight * outputWidth; @@ -255,16 +280,23 @@ public: real* outputGrad = inputs[0].data(); real* inputData = inputs[1].data(); real* filterGrad = outputs[0].data(); + bool needIm2col = isNeedIm2col(filter); + TensorShape imShape = TensorShape({inputChannels / groups_, inputHeight, inputWidth}); - TensorShape colShape = TensorShape({inputChannels / groups_, - filterHeight, - filterWidth, - outputHeight, - outputWidth}); - resizeBuffer(colShape.getElements()); - real* colData = reinterpret_cast(memory_->getBuf()); + TensorShape colShape; + real* colData = NULL; + + if (needIm2col) { + colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + resizeBuffer(colShape.getElements()); + colData = reinterpret_cast(memory_->getBuf()); + } Im2ColFunctor im2col; GemmFunctor gemm; @@ -274,15 +306,18 @@ public: size_t filterOffset = filter.getElements() / groups_; for (size_t i = 0; i < batchSize; i++) { for (size_t g = 0; g < groups_; g++) { - im2col(inputData + g * inputOffset, - imShape, - colData, - colShape, - strideH(), - strideW(), - paddingH(), - paddingW()); - + if (needIm2col) { + im2col(inputData + g * inputOffset, + imShape, + colData, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW()); + } else { + colData = inputData + g * inputOffset; + } int M = outputChannels / groups_; int K = outputHeight * outputWidth; int N = inputChannels / groups_ * filterHeight * filterWidth;