diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index a40e5d9d2e76605525f0956445fc43c693933cf8..3f10bb9c83754b92174401b38feea28eec8e8386 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -12,101 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "GemmConvOp.h" +#include "ConvOp.h" #include "GemmFunctor.h" +#include "Im2Col.h" #include "paddle/math/MemoryHandle.h" namespace paddle { -/* - * imData = [input_channels, input_height, input_width] - * colData = [input_channels, filter_height, filter_width, - * output_height, output_width] - */ -template -class Im2ColFunctor { -public: - void operator()(const T* imData, - int inputChannels, - int inputHeight, - int inputWidth, - int filterHeight, - int filterWidth, - int strideHeight, - int strideWidth, - int paddingHeight, - int paddingWidth, - int outputHeight, - int outputWidth, - T* colData) { - int channelsCol = inputChannels * filterHeight * filterWidth; - - for (int c = 0; c < channelsCol; ++c) { - int wOffset = c % filterWidth; - int hOffset = (c / filterWidth) % filterHeight; - int c_im = c / filterWidth / filterHeight; - for (int h = 0; h < outputHeight; ++h) { - for (int w = 0; w < outputWidth; ++w) { - int imRowIdx = h * strideHeight + hOffset; - int imColIdx = w * strideWidth + wOffset; - if ((imRowIdx - paddingHeight) < 0 || - (imRowIdx - paddingHeight) >= inputHeight || - (imColIdx - paddingWidth) < 0 || - (imColIdx - paddingWidth) >= inputWidth) { - colData[(c * outputHeight + h) * outputWidth + w] = T(0); - } else { - imRowIdx += c_im * inputHeight - paddingHeight; - imColIdx -= paddingWidth; - colData[(c * outputHeight + h) * outputWidth + w] = - imData[imRowIdx * inputWidth + imColIdx]; - } - } - } - } - } -}; - -template -class Col2ImFunctor { -public: - void operator()(const T* colData, - int inputChannels, - int inputHeight, - int inputWidth, - int filterHeight, - int filterWidth, - int strideHeight, - int strideWidth, - int paddingHeight, - int paddingWidth, - int outputHeight, - int outputWidth, - T* imData) { - int channelsCol = inputChannels * filterHeight * filterWidth; - - for (int c = 0; c < channelsCol; ++c) { - int wOffset = c % filterWidth; - int hOffset = (c / filterWidth) % filterHeight; - int c_im = c / filterWidth / filterHeight; - for (int h = 0; h < outputHeight; ++h) { - for (int w = 0; w < outputWidth; ++w) { - int imRowIdx = h * strideHeight + hOffset; - int imColIdx = w * strideWidth + wOffset; - if ((imRowIdx - paddingHeight) >= 0 && - (imRowIdx - paddingHeight) < inputHeight && - (imColIdx - paddingWidth) >= 0 && - (imColIdx - paddingWidth) < inputWidth) { - imRowIdx += c_im * inputHeight - paddingHeight; - imColIdx -= paddingWidth; - imData[imRowIdx * inputWidth + imColIdx] += - colData[(c * outputHeight + h) * outputWidth + w]; - } - } - } - } - } -}; - /* * \brief Forward calculation of convolution. */ @@ -155,15 +67,20 @@ public: real* inputData = inputs[0].data(); real* filterData = inputs[1].data(); real* outputData = outputs[0].data(); - - size_t size = inputChannels / groups_ * filterHeight * filterWidth * - outputHeight * outputWidth; - resizeBuffer(size); + TensorShape imShape = + TensorShape({inputChannels / groups_, inputHeight, inputWidth}); + TensorShape colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + + resizeBuffer(colShape.getElements()); real* colData = reinterpret_cast(memory_->getBuf()); - Im2ColFunctor im2col; + Im2ColFunctor im2col; GemmFunctor gemm; - size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth; + size_t inputOffset = imShape.getElements(); size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; size_t filterOffset = filter.getElements() / groups_; @@ -171,18 +88,13 @@ public: for (size_t i = 0; i < batchSize; i++) { for (size_t g = 0; g < groups_; g++) { im2col(inputData + g * inputOffset, - inputChannels / groups_, - inputHeight, - inputWidth, - filterHeight, - filterWidth, + imShape, + colData, + colShape, strideH(), strideW(), paddingH(), - paddingW(), - outputHeight, - outputWidth, - colData); + paddingW()); int M = outputChannels / groups_; int N = outputHeight * outputWidth; @@ -249,15 +161,20 @@ public: real* outputGrad = inputs[0].data(); real* filterData = inputs[1].data(); real* inputGrad = outputs[0].data(); - - size_t size = inputChannels / groups_ * filterHeight * filterWidth * - outputHeight * outputWidth; - resizeBuffer(size); + TensorShape imShape = + TensorShape({inputChannels / groups_, inputHeight, inputWidth}); + TensorShape colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + + resizeBuffer(colShape.getElements()); real* colData = reinterpret_cast(memory_->getBuf()); - Col2ImFunctor col2im; + Col2ImFunctor col2im; GemmFunctor gemm; - size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth; + size_t inputOffset = imShape.getElements(); size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; size_t filterOffset = filter.getElements() / groups_; @@ -280,20 +197,14 @@ public: 0.0f, colData, N); - - col2im(colData, - inputChannels / groups_, - inputHeight, - inputWidth, - filterHeight, - filterWidth, + col2im(inputGrad + g * inputOffset, + imShape, + colData, + colShape, strideH(), strideW(), paddingH(), - paddingW(), - outputHeight, - outputWidth, - inputGrad + g * inputOffset); + paddingW()); } inputGrad += inputChannels * inputHeight * inputWidth; outputGrad += outputChannels * outputHeight * outputWidth; @@ -347,33 +258,33 @@ public: real* outputGrad = inputs[0].data(); real* inputData = inputs[1].data(); real* filterGrad = outputs[0].data(); - - size_t size = inputChannels / groups_ * filterHeight * filterWidth * - outputHeight * outputWidth; - resizeBuffer(size); + TensorShape imShape = + TensorShape({inputChannels / groups_, inputHeight, inputWidth}); + TensorShape colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + + resizeBuffer(colShape.getElements()); real* colData = reinterpret_cast(memory_->getBuf()); - Im2ColFunctor im2col; + Im2ColFunctor im2col; GemmFunctor gemm; - size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth; + size_t inputOffset = imShape.getElements(); size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; size_t filterOffset = filter.getElements() / groups_; for (size_t i = 0; i < batchSize; i++) { for (size_t g = 0; g < groups_; g++) { im2col(inputData + g * inputOffset, - inputChannels / groups_, - inputHeight, - inputWidth, - filterHeight, - filterWidth, + imShape, + colData, + colShape, strideH(), strideW(), paddingH(), - paddingW(), - outputHeight, - outputWidth, - colData); + paddingW()); int M = outputChannels / groups_; int K = outputHeight * outputWidth;