From 784e59406c541def813e75db5fd11bb0361eccef Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 11 Jan 2018 14:02:57 +0800 Subject: [PATCH] Bug fix of Im2ColMobileFunctor. --- paddle/function/Im2Col.h | 58 +++------------------------------- paddle/function/Im2ColTest.cpp | 6 ++-- 2 files changed, 7 insertions(+), 57 deletions(-) diff --git a/paddle/function/Im2Col.h b/paddle/function/Im2Col.h index 361ba4c18a2..915119e291c 100644 --- a/paddle/function/Im2Col.h +++ b/paddle/function/Im2Col.h @@ -98,58 +98,6 @@ public: int dilationWidth = 1); }; -#if 0 -template -class Im2ColMobileFunctor { -public: - void operator()(const T* imData, - const TensorShape& imShape, - T* colData, - const TensorShape& colShape, - int strideHeight, - int strideWidth, - int paddingHeight, - int paddingWidth, - int dilationHeight, - int dilationWidth, - int colHeightStart, - int colHeightSize, - int colWidthStart, - int colWidthSize) { - int inputHeight = imShape[1]; - int inputWidth = imShape[2]; - int filterHeight = colShape[1]; - int filterWidth = colShape[2]; - int outputWidth = colShape[4]; - - for (int colh = 0; colh < colHeightSize; colh++) { - int wOffset = (colHeightStart + colh) % filterWidth; - int hOffset = ((colHeightStart + colh) / filterWidth) % filterHeight; - int c_im = (colHeightStart + colh) / filterWidth / filterHeight; - - for (int colw = 0; colw < colWidthSize; colw++) { - int h = (colWidthStart + colw) / outputWidth; - int w = (colWidthStart + colw) % outputWidth; - - int imRowIdx = h * strideHeight + hOffset * dilationHeight; - int imColIdx = w * strideWidth + wOffset * dilationWidth; - if ((imRowIdx - paddingHeight) < 0 || - (imRowIdx - paddingHeight) >= inputHeight || - (imColIdx - paddingWidth) < 0 || - (imColIdx - paddingWidth) >= inputWidth) { - colData[colh * colWidthSize + colw] = static_cast(0); - } else { - imRowIdx += c_im * inputHeight - paddingHeight; - imColIdx -= paddingWidth; - colData[colh * colWidthSize + colw] = - imData[imRowIdx * inputWidth + imColIdx]; - } - } - } - } -}; -#endif - template class Im2ColMobileFunctor { public: @@ -178,12 +126,14 @@ public: T* dstData = colData + oh * outputWidth; for (int fh = 0; fh < filterHeight; fh++) { for (int fw = 0; fw < filterWidth; fw++) { - int imRowIdx = (oh + colOffset) * strideHeight + fh - paddingHeight; + int imRowIdx = (oh + colOffset) * strideHeight + + fh * dilationHeight - paddingHeight; if (imRowIdx < 0 || imRowIdx >= inputHeight) { memset(dstData, 0, outputWidth * sizeof(T)); } else { for (int ow = 0; ow < outputWidth; ow++) { - int imColIdx = ow * strideWidth + fw - paddingWidth; + int imColIdx = + ow * strideWidth + fw * dilationWidth - paddingWidth; if (imColIdx < 0 || imColIdx >= inputWidth) { dstData[ow] = T(0); } else { diff --git a/paddle/function/Im2ColTest.cpp b/paddle/function/Im2ColTest.cpp index 3ba866dcdd8..fe44a8bf790 100644 --- a/paddle/function/Im2ColTest.cpp +++ b/paddle/function/Im2ColTest.cpp @@ -202,10 +202,10 @@ void TestIm2ColMobileFunctor() { padding, dilation, dilation, + channels, 0, - height, - 0, - width); + outputHeight, + outputHeight * outputWidth); autotest::TensorCheckEqual(*output1, *output2); } -- GitLab