diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index 75a5b4fe8491d980d27bab51046482cb35a372c3..acf1415ebffe16c93109aa07e44aa3f47992dbe3 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -233,6 +233,8 @@ public: strideW(), paddingH(), paddingW(), + dilationH(), + dilationW(), colHeightStart, K, colWidthStart, diff --git a/paddle/function/Im2Col.h b/paddle/function/Im2Col.h index f43ca465a21253cc863e24603eaec844fa13bfc9..1053e4fd232b8803c9661ece823da88715ffc220 100644 --- a/paddle/function/Im2Col.h +++ b/paddle/function/Im2Col.h @@ -109,6 +109,8 @@ public: int strideWidth, int paddingHeight, int paddingWidth, + int dilationHeight, + int dilationWidth, int colHeightStart, int colHeightSize, int colWidthStart, @@ -128,8 +130,8 @@ public: int h = (colWidthStart + colw) / outputWidth; int w = (colWidthStart + colw) % outputWidth; - int imRowIdx = h * strideHeight + hOffset; - int imColIdx = w * strideWidth + wOffset; + int imRowIdx = h * strideHeight + hOffset * dilationHeight; + int imColIdx = w * strideWidth + wOffset * dilationWidth; if ((imRowIdx - paddingHeight) < 0 || (imRowIdx - paddingHeight) >= inputHeight || (imColIdx - paddingWidth) < 0 || diff --git a/paddle/function/Im2ColTest.cpp b/paddle/function/Im2ColTest.cpp index 0dc58696f74ec3c402e3e4d6a6c56bc78c962bf9..c573469168d3c18cfc574570c0681beb60a52693 100644 --- a/paddle/function/Im2ColTest.cpp +++ b/paddle/function/Im2ColTest.cpp @@ -147,7 +147,7 @@ void TestIm2ColMobileFunctor() { for (size_t filterWidth : {3, 7}) { for (size_t stride : {1, 2}) { for (size_t padding : {0, 1}) { - for (size_t dilation : {1 /*, 3*/}) { + for (size_t dilation : {1, 3}) { size_t filterSizeH = (filterHeight - 1) * dilation + 1; size_t filterSizeW = (filterWidth - 1) * dilation + 1; if (inputHeight + 2 * padding < filterSizeH || @@ -200,6 +200,8 @@ void TestIm2ColMobileFunctor() { stride, padding, padding, + dilation, + dilation, 0, height, 0,