diff --git a/paddle/function/ConvOp.h b/paddle/function/ConvOp.h index bb4f48364b9b454af7d37fe4d3c340666e53285c..62c9fd9b2c21beed91e111907024bb677e6796bd 100644 --- a/paddle/function/ConvOp.h +++ b/paddle/function/ConvOp.h @@ -109,6 +109,13 @@ protected: return filter[filter.ndims() - 1]; } + // determine whether im2col needs to be performed + inline bool isSkipIm2col(const TensorShape& filter) const { + return (getFilterHeight(filter) == 1 && getFilterWidth(filter) == 1 && + strideH() == 1 && strideW() == 1 && paddingH() == 0 && + paddingW() == 0); + } + std::vector strides_; std::vector paddings_; diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index 9deb2739fcfff935a98a0b5b31b5d11819d81227..263795596591c37bb7d3275a9b0fe2efae6a405d 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -66,16 +66,23 @@ public: real* inputData = inputs[0].data(); real* filterData = inputs[1].data(); real* outputData = outputs[0].data(); + bool skipIm2col = isSkipIm2col(filter); + TensorShape imShape = TensorShape({inputChannels / groups_, inputHeight, inputWidth}); - TensorShape colShape = TensorShape({inputChannels / groups_, - filterHeight, - filterWidth, - outputHeight, - outputWidth}); - resizeBuffer(colShape.getElements()); - real* colData = reinterpret_cast(memory_->getBuf()); + TensorShape colShape; + real *colBuffer, *colData = NULL; + + if (!skipIm2col) { + colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + resizeBuffer(colShape.getElements()); + colData = reinterpret_cast(memory_->getBuf()); + } Im2ColFunctor im2col; GemmFunctor gemm; @@ -86,15 +93,18 @@ public: for (size_t i = 0; i < batchSize; i++) { for (size_t g = 0; g < groups_; g++) { - im2col(inputData + g * inputOffset, - imShape, - colData, - colShape, - strideH(), - strideW(), - paddingH(), - paddingW()); - + colBuffer = inputData + g * inputOffset; + if (!skipIm2col) { + im2col(inputData + g * inputOffset, + imShape, + colData, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW()); + colBuffer = colData; + } int M = outputChannels / groups_; int N = outputHeight * outputWidth; int K = inputChannels / groups_ * filterHeight * filterWidth; @@ -106,7 +116,7 @@ public: 1.0f, filterData + g * filterOffset, K, - colData, + colBuffer, N, beta, outputData + g * outputOffset, @@ -159,19 +169,27 @@ public: real* outputGrad = inputs[0].data(); real* filterData = inputs[1].data(); real* inputGrad = outputs[0].data(); + bool skipIm2col = isSkipIm2col(filter); + TensorShape imShape = TensorShape({inputChannels / groups_, inputHeight, inputWidth}); - TensorShape colShape = TensorShape({inputChannels / groups_, - filterHeight, - filterWidth, - outputHeight, - outputWidth}); - resizeBuffer(colShape.getElements()); - real* colData = reinterpret_cast(memory_->getBuf()); + TensorShape colShape; + real *colBuffer, *colData = NULL; + + if (!skipIm2col) { + colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + resizeBuffer(colShape.getElements()); + colData = reinterpret_cast(memory_->getBuf()); + } Col2ImFunctor col2im; GemmFunctor gemm; + size_t inputOffset = imShape.getElements(); size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; @@ -182,6 +200,12 @@ public: int K = outputChannels / groups_; int N = outputHeight * outputWidth; int M = inputChannels / groups_ * filterHeight * filterWidth; + colBuffer = colData; + real scale = 0.0f; + if (skipIm2col) { + colBuffer = inputGrad + g * inputOffset; + scale = 1.0f; + } gemm(CblasTrans, CblasNoTrans, M, @@ -192,17 +216,19 @@ public: M, outputGrad + g * outputOffset, N, - 0.0f, - colData, + scale, + colBuffer, N); - col2im(inputGrad + g * inputOffset, - imShape, - colData, - colShape, - strideH(), - strideW(), - paddingH(), - paddingW()); + if (!skipIm2col) { + col2im(inputGrad + g * inputOffset, + imShape, + colBuffer, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW()); + } } inputGrad += inputChannels * inputHeight * inputWidth; outputGrad += outputChannels * outputHeight * outputWidth; @@ -255,16 +281,23 @@ public: real* outputGrad = inputs[0].data(); real* inputData = inputs[1].data(); real* filterGrad = outputs[0].data(); + bool skipIm2col = isSkipIm2col(filter); + TensorShape imShape = TensorShape({inputChannels / groups_, inputHeight, inputWidth}); - TensorShape colShape = TensorShape({inputChannels / groups_, - filterHeight, - filterWidth, - outputHeight, - outputWidth}); - resizeBuffer(colShape.getElements()); - real* colData = reinterpret_cast(memory_->getBuf()); + TensorShape colShape; + real *colBuffer, *colData = NULL; + + if (!skipIm2col) { + colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + resizeBuffer(colShape.getElements()); + colData = reinterpret_cast(memory_->getBuf()); + } Im2ColFunctor im2col; GemmFunctor gemm; @@ -274,15 +307,18 @@ public: size_t filterOffset = filter.getElements() / groups_; for (size_t i = 0; i < batchSize; i++) { for (size_t g = 0; g < groups_; g++) { - im2col(inputData + g * inputOffset, - imShape, - colData, - colShape, - strideH(), - strideW(), - paddingH(), - paddingW()); - + colBuffer = inputData + g * inputOffset; + if (!skipIm2col) { + im2col(inputData + g * inputOffset, + imShape, + colData, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW()); + colBuffer = colData; + } int M = outputChannels / groups_; int K = outputHeight * outputWidth; int N = inputChannels / groups_ * filterHeight * filterWidth; @@ -294,7 +330,7 @@ public: 1.0f, outputGrad + g * outputOffset, K, - colData, + colBuffer, K, i == 0 ? beta : 1.0f, filterGrad + g * filterOffset,