diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index de7b70e271b38ebe3a4c38704d0cced47d010788..cbdbf5335d32d55a0221728758025c9d2cb3e7d1 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -126,14 +126,165 @@ public: inputData += inputChannels * inputHeight * inputWidth; outputData += outputChannels * outputHeight * outputWidth; } + } +}; + #ifdef PADDLE_MOBILE_INFERENCE - if (Device == DEVICE_TYPE_CPU) { - memory_.reset(); + +/* + * \brief Forward calculation of convolution, optimized for mobile. + */ +template +class GemmConvMobileFunction : public ConvFunctionBase { +public: + void init(const FuncConfig& config) override { + ConvFunctionBase::init(config); + } + + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { + const TensorShape& input = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& output = outputs[0].shape(); + checkShape(input, filter, output); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); + // TODO(hedaoyuan): Need to define some index macros, + // to avoid useing 0 and 1. + const TensorShape& input = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& output = outputs[0].shape(); + + real beta; + if (outputs[0].getArgType() == ADD_TO) { + beta = 1.0; + } else { + beta = 0.0; } -#endif + + size_t batchSize = input[0]; + size_t inputChannels = input[1]; + size_t inputHeight = input[2]; + size_t inputWidth = input[3]; + size_t filterHeight = getFilterHeight(filter); + size_t filterWidth = getFilterWidth(filter); + size_t outputChannels = output[1]; + size_t outputHeight = output[2]; + size_t outputWidth = output[3]; + + real* inputData = inputs[0].data(); + real* filterData = inputs[1].data(); + real* outputData = outputs[0].data(); + bool needIm2col = isNeedIm2col(filter); + + TensorShape imShape = + TensorShape({inputChannels / groups_, inputHeight, inputWidth}); + + TensorShape colShape; + real* colData = NULL; + + size_t colHeight = inputChannels / groups_ * filterHeight * filterWidth; + size_t colWidth = outputHeight * outputWidth; + // Max col matrix height 256, Max col matrix width 1024 + size_t stepColHeight = std::min(colHeight, static_cast(256)); + size_t stepColWidth = std::min(colWidth, static_cast(2048)); + + if (needIm2col) { + colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + + resizeBuffer(stepColHeight * stepColWidth * sizeof(real)); + colData = reinterpret_cast(memory_->getBuf()); + } + + Im2ColMobileFunctor im2col; + size_t inputOffset = imShape.getElements(); + size_t outputOffset = + (outputChannels / groups_) * outputHeight * outputWidth; + size_t filterOffset = filter.getElements() / groups_; + + int nStride = colWidth; + int kStride = colHeight; + for (size_t i = 0; i < batchSize; i++) { + for (size_t g = 0; g < groups_; g++) { + if (needIm2col) { + real beta_ = beta; + for (size_t colHeightStart = 0; colHeightStart < colHeight; + colHeightStart += stepColHeight) { + for (size_t colWidthStart = 0; colWidthStart < colWidth; + colWidthStart += stepColWidth) { + int N = std::min(colWidth - colWidthStart, stepColWidth); + int K = std::min(colHeight - colHeightStart, stepColHeight); + // im2col + im2col(inputData + g * inputOffset, + imShape, + colData, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW(), + dilationH(), + dilationW(), + colHeightStart, + K, + colWidthStart, + N); + + // gemm + int M = outputChannels / groups_; + BlasGemm::compute( + false, + false, + M, + N, + K, + 1.0f, + filterData + g * filterOffset + colHeightStart, + kStride, + colData, + N, + beta_, + outputData + g * outputOffset + colWidthStart, + nStride); + } + beta_ = 1.0; + } + } else { + int M = outputChannels / groups_; + int N = outputHeight * outputWidth; + int K = inputChannels / groups_ * filterHeight * filterWidth; + BlasGemm::compute(false, + false, + M, + N, + K, + 1.0f, + filterData + g * filterOffset, + K, + inputData + g * inputOffset, + N, + beta, + outputData + g * outputOffset, + N); + } + } + inputData += inputChannels * inputHeight * inputWidth; + outputData += outputChannels * outputHeight * outputWidth; + } + + memory_.reset(); } }; +#endif + /* * \brief Backward input calculation of convolution. */ @@ -348,7 +499,11 @@ public: } }; +#ifdef PADDLE_MOBILE_INFERENCE +REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvMobileFunction); +#else REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction); +#endif REGISTER_TYPED_FUNC(GemmConvGradInput, CPU, GemmConvGradInputFunction); REGISTER_TYPED_FUNC(GemmConvGradFilter, CPU, GemmConvGradFilterFunction); #ifdef PADDLE_WITH_CUDA diff --git a/paddle/function/Im2Col.h b/paddle/function/Im2Col.h index 0c37fc972484bfbede01d23652e384071bf883af..36a9bcf84e4b14965c83627821b71d1c7c0da1b2 100644 --- a/paddle/function/Im2Col.h +++ b/paddle/function/Im2Col.h @@ -98,4 +98,54 @@ public: int dilationWidth = 1); }; +template +class Im2ColMobileFunctor { +public: + void operator()(const T* imData, + const TensorShape& imShape, + T* colData, + const TensorShape& colShape, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth, + int dilationHeight, + int dilationWidth, + int colHeightStart, + int colHeightSize, + int colWidthStart, + int colWidthSize) { + int inputHeight = imShape[1]; + int inputWidth = imShape[2]; + int filterHeight = colShape[1]; + int filterWidth = colShape[2]; + int outputWidth = colShape[4]; + + for (int colh = 0; colh < colHeightSize; colh++) { + int wOffset = (colHeightStart + colh) % filterWidth; + int hOffset = ((colHeightStart + colh) / filterWidth) % filterHeight; + int c_im = (colHeightStart + colh) / filterWidth / filterHeight; + + for (int colw = 0; colw < colWidthSize; colw++) { + int h = (colWidthStart + colw) / outputWidth; + int w = (colWidthStart + colw) % outputWidth; + + int imRowIdx = h * strideHeight + hOffset * dilationHeight; + int imColIdx = w * strideWidth + wOffset * dilationWidth; + if ((imRowIdx - paddingHeight) < 0 || + (imRowIdx - paddingHeight) >= inputHeight || + (imColIdx - paddingWidth) < 0 || + (imColIdx - paddingWidth) >= inputWidth) { + colData[colh * colWidthSize + colw] = static_cast(0); + } else { + imRowIdx += c_im * inputHeight - paddingHeight; + imColIdx -= paddingWidth; + colData[colh * colWidthSize + colw] = + imData[imRowIdx * inputWidth + imColIdx]; + } + } + } + } +}; + } // namespace paddle diff --git a/paddle/function/Im2ColTest.cpp b/paddle/function/Im2ColTest.cpp index 1f085538d81904dbd5b5d6bcd014adaed22e37d7..3ba866dcdd845403d52f7a85adfef08cbb11c305 100644 --- a/paddle/function/Im2ColTest.cpp +++ b/paddle/function/Im2ColTest.cpp @@ -138,4 +138,86 @@ TEST(Im2ColFunctor, GPU) { TestIm2ColFunctor(); } #endif +template +void TestIm2ColMobileFunctor() { + for (size_t channels : {32}) { + for (size_t inputHeight : {33, 100}) { + for (size_t inputWidth : {32, 96}) { + for (size_t filterHeight : {5}) { + for (size_t filterWidth : {7}) { + for (size_t stride : {2}) { + for (size_t padding : {1}) { + for (size_t dilation : {1, 3}) { + size_t filterSizeH = (filterHeight - 1) * dilation + 1; + size_t filterSizeW = (filterWidth - 1) * dilation + 1; + if (inputHeight + 2 * padding < filterSizeH || + inputWidth + 2 * padding < filterSizeW) + break; + if (padding >= filterSizeH || padding >= filterSizeW) break; + size_t outputHeight = + (inputHeight - filterSizeH + 2 * padding) / stride + 1; + size_t outputWidth = + (inputWidth - filterSizeW + 2 * padding) / stride + 1; + + TensorShape imShape = + TensorShape({channels, inputHeight, inputWidth}); + TensorShape colShape1 = TensorShape({channels, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + + size_t height = channels * filterHeight * filterWidth; + size_t width = outputHeight * outputWidth; + VectorPtr input1 = + Vector::create(imShape.getElements(), false); + VectorPtr input2 = + Vector::create(imShape.getElements(), false); + MatrixPtr output1 = + Matrix::create(height, width, false, false); + MatrixPtr output2 = + Matrix::create(height, width, false, false); + input1->uniform(0.001, 1); + input2->copyFrom(*input1); + + Im2ColFunctor im2Col1; + Im2ColMobileFunctor im2Col2; + im2Col1(input1->getData(), + imShape, + output1->getData(), + colShape1, + stride, + stride, + padding, + padding, + dilation, + dilation); + im2Col2(input2->getData(), + imShape, + output2->getData(), + colShape1, + stride, + stride, + padding, + padding, + dilation, + dilation, + 0, + height, + 0, + width); + + autotest::TensorCheckEqual(*output1, *output2); + } + } + } + } + } + } + } + } +} + +TEST(Im2ColFunctor, Mobile) { TestIm2ColMobileFunctor(); } + } // namespace paddle