diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index 08eb6a54902c44bdf84bb082598f36d20d0c8822..75a5b4fe8491d980d27bab51046482cb35a372c3 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -206,8 +206,7 @@ public: colData = reinterpret_cast(memory_->getBuf()); } - Im2ColFunctor im2col; - GemmFunctor gemm; + Im2ColMobileFunctor im2col; size_t inputOffset = imShape.getElements(); size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; @@ -241,19 +240,20 @@ public: // gemm int M = outputChannels / groups_; - gemm(CblasNoTrans, - CblasNoTrans, - M, - N, - K, - 1.0f, - filterData + g * filterOffset + colHeightStart, - kStride, - colData, - N, - beta_, - outputData + g * outputOffset + colWidthStart, - nStride); + BlasGemm::compute( + false, + false, + M, + N, + K, + 1.0f, + filterData + g * filterOffset + colHeightStart, + kStride, + colData, + N, + beta_, + outputData + g * outputOffset + colWidthStart, + nStride); } beta_ = 1.0; } @@ -261,19 +261,19 @@ public: int M = outputChannels / groups_; int N = outputHeight * outputWidth; int K = inputChannels / groups_ * filterHeight * filterWidth; - gemm(CblasNoTrans, - CblasNoTrans, - M, - N, - K, - 1.0f, - filterData + g * filterOffset, - K, - inputData + g * inputOffset, - N, - beta, - outputData + g * outputOffset, - N); + BlasGemm::compute(false, + false, + M, + N, + K, + 1.0f, + filterData + g * filterOffset, + K, + inputData + g * inputOffset, + N, + beta, + outputData + g * outputOffset, + N); } } inputData += inputChannels * inputHeight * inputWidth; diff --git a/paddle/function/Im2Col.h b/paddle/function/Im2Col.h index 0c37fc972484bfbede01d23652e384071bf883af..f43ca465a21253cc863e24603eaec844fa13bfc9 100644 --- a/paddle/function/Im2Col.h +++ b/paddle/function/Im2Col.h @@ -98,4 +98,52 @@ public: int dilationWidth = 1); }; +template +class Im2ColMobileFunctor { +public: + void operator()(const T* imData, + const TensorShape& imShape, + T* colData, + const TensorShape& colShape, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth, + int colHeightStart, + int colHeightSize, + int colWidthStart, + int colWidthSize) { + int inputHeight = imShape[1]; + int inputWidth = imShape[2]; + int filterHeight = colShape[1]; + int filterWidth = colShape[2]; + int outputWidth = colShape[4]; + + for (int colh = 0; colh < colHeightSize; colh++) { + int wOffset = (colHeightStart + colh) % filterWidth; + int hOffset = ((colHeightStart + colh) / filterWidth) % filterHeight; + int c_im = (colHeightStart + colh) / filterWidth / filterHeight; + + for (int colw = 0; colw < colWidthSize; colw++) { + int h = (colWidthStart + colw) / outputWidth; + int w = (colWidthStart + colw) % outputWidth; + + int imRowIdx = h * strideHeight + hOffset; + int imColIdx = w * strideWidth + wOffset; + if ((imRowIdx - paddingHeight) < 0 || + (imRowIdx - paddingHeight) >= inputHeight || + (imColIdx - paddingWidth) < 0 || + (imColIdx - paddingWidth) >= inputWidth) { + colData[colh * colWidthSize + colw] = T(0); + } else { + imRowIdx += c_im * inputHeight - paddingHeight; + imColIdx -= paddingWidth; + colData[colh * colWidthSize + colw] = + imData[imRowIdx * inputWidth + imColIdx]; + } + } + } + } +}; + } // namespace paddle