diff --git a/paddle/function/ConvOp.h b/paddle/function/ConvOp.h index 62c9fd9b2c21beed91e111907024bb677e6796bd..baf78bc6c88d0d294f4457b81c52b22e425d9fdb 100644 --- a/paddle/function/ConvOp.h +++ b/paddle/function/ConvOp.h @@ -110,10 +110,10 @@ protected: } // determine whether im2col needs to be performed - inline bool isSkipIm2col(const TensorShape& filter) const { - return (getFilterHeight(filter) == 1 && getFilterWidth(filter) == 1 && - strideH() == 1 && strideW() == 1 && paddingH() == 0 && - paddingW() == 0); + inline bool isNeedIm2col(const TensorShape& filter) const { + return !(getFilterHeight(filter) == 1 && getFilterWidth(filter) == 1 && + strideH() == 1 && strideW() == 1 && paddingH() == 0 && + paddingW() == 0); } std::vector strides_; diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index 263795596591c37bb7d3275a9b0fe2efae6a405d..0ada4d70a0c7d13f9b5fb1a42eac07fc4c775a87 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -66,15 +66,15 @@ public: real* inputData = inputs[0].data(); real* filterData = inputs[1].data(); real* outputData = outputs[0].data(); - bool skipIm2col = isSkipIm2col(filter); + bool needIm2col = isNeedIm2col(filter); TensorShape imShape = TensorShape({inputChannels / groups_, inputHeight, inputWidth}); TensorShape colShape; - real *colBuffer, *colData = NULL; + real* colData = NULL; - if (!skipIm2col) { + if (needIm2col) { colShape = TensorShape({inputChannels / groups_, filterHeight, filterWidth, @@ -93,8 +93,7 @@ public: for (size_t i = 0; i < batchSize; i++) { for (size_t g = 0; g < groups_; g++) { - colBuffer = inputData + g * inputOffset; - if (!skipIm2col) { + if (needIm2col) { im2col(inputData + g * inputOffset, imShape, colData, @@ -103,7 +102,8 @@ public: strideW(), paddingH(), paddingW()); - colBuffer = colData; + } else { + colData = inputData + g * inputOffset; } int M = outputChannels / groups_; int N = outputHeight * outputWidth; @@ -116,7 +116,7 @@ public: 1.0f, filterData + g * filterOffset, K, - colBuffer, + colData, N, beta, outputData + g * outputOffset, @@ -169,15 +169,15 @@ public: real* outputGrad = inputs[0].data(); real* filterData = inputs[1].data(); real* inputGrad = outputs[0].data(); - bool skipIm2col = isSkipIm2col(filter); + bool needIm2col = isNeedIm2col(filter); TensorShape imShape = TensorShape({inputChannels / groups_, inputHeight, inputWidth}); TensorShape colShape; - real *colBuffer, *colData = NULL; + real* colData = NULL; - if (!skipIm2col) { + if (needIm2col) { colShape = TensorShape({inputChannels / groups_, filterHeight, filterWidth, @@ -200,10 +200,9 @@ public: int K = outputChannels / groups_; int N = outputHeight * outputWidth; int M = inputChannels / groups_ * filterHeight * filterWidth; - colBuffer = colData; real scale = 0.0f; - if (skipIm2col) { - colBuffer = inputGrad + g * inputOffset; + if (!needIm2col) { + colData = inputGrad + g * inputOffset; scale = 1.0f; } gemm(CblasTrans, @@ -217,12 +216,12 @@ public: outputGrad + g * outputOffset, N, scale, - colBuffer, + colData, N); - if (!skipIm2col) { + if (needIm2col) { col2im(inputGrad + g * inputOffset, imShape, - colBuffer, + colData, colShape, strideH(), strideW(), @@ -281,15 +280,15 @@ public: real* outputGrad = inputs[0].data(); real* inputData = inputs[1].data(); real* filterGrad = outputs[0].data(); - bool skipIm2col = isSkipIm2col(filter); + bool needIm2col = isNeedIm2col(filter); TensorShape imShape = TensorShape({inputChannels / groups_, inputHeight, inputWidth}); TensorShape colShape; - real *colBuffer, *colData = NULL; + real* colData = NULL; - if (!skipIm2col) { + if (needIm2col) { colShape = TensorShape({inputChannels / groups_, filterHeight, filterWidth, @@ -307,8 +306,7 @@ public: size_t filterOffset = filter.getElements() / groups_; for (size_t i = 0; i < batchSize; i++) { for (size_t g = 0; g < groups_; g++) { - colBuffer = inputData + g * inputOffset; - if (!skipIm2col) { + if (needIm2col) { im2col(inputData + g * inputOffset, imShape, colData, @@ -317,7 +315,8 @@ public: strideW(), paddingH(), paddingW()); - colBuffer = colData; + } else { + colData = inputData + g * inputOffset; } int M = outputChannels / groups_; int K = outputHeight * outputWidth; @@ -330,7 +329,7 @@ public: 1.0f, outputGrad + g * outputOffset, K, - colBuffer, + colData, K, i == 0 ? beta : 1.0f, filterGrad + g * filterOffset,