From a850dec991d7d6d28f2669a959b3198a7a796ce9 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 26 Dec 2017 16:07:09 +0800 Subject: [PATCH] Add dilation. --- paddle/function/GemmConvOp.cpp | 2 ++ paddle/function/Im2Col.h | 6 ++++-- paddle/function/Im2ColTest.cpp | 4 +++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index 75a5b4fe849..acf1415ebff 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -233,6 +233,8 @@ public: strideW(), paddingH(), paddingW(), + dilationH(), + dilationW(), colHeightStart, K, colWidthStart, diff --git a/paddle/function/Im2Col.h b/paddle/function/Im2Col.h index f43ca465a21..1053e4fd232 100644 --- a/paddle/function/Im2Col.h +++ b/paddle/function/Im2Col.h @@ -109,6 +109,8 @@ public: int strideWidth, int paddingHeight, int paddingWidth, + int dilationHeight, + int dilationWidth, int colHeightStart, int colHeightSize, int colWidthStart, @@ -128,8 +130,8 @@ public: int h = (colWidthStart + colw) / outputWidth; int w = (colWidthStart + colw) % outputWidth; - int imRowIdx = h * strideHeight + hOffset; - int imColIdx = w * strideWidth + wOffset; + int imRowIdx = h * strideHeight + hOffset * dilationHeight; + int imColIdx = w * strideWidth + wOffset * dilationWidth; if ((imRowIdx - paddingHeight) < 0 || (imRowIdx - paddingHeight) >= inputHeight || (imColIdx - paddingWidth) < 0 || diff --git a/paddle/function/Im2ColTest.cpp b/paddle/function/Im2ColTest.cpp index 0dc58696f74..c573469168d 100644 --- a/paddle/function/Im2ColTest.cpp +++ b/paddle/function/Im2ColTest.cpp @@ -147,7 +147,7 @@ void TestIm2ColMobileFunctor() { for (size_t filterWidth : {3, 7}) { for (size_t stride : {1, 2}) { for (size_t padding : {0, 1}) { - for (size_t dilation : {1 /*, 3*/}) { + for (size_t dilation : {1, 3}) { size_t filterSizeH = (filterHeight - 1) * dilation + 1; size_t filterSizeW = (filterWidth - 1) * dilation + 1; if (inputHeight + 2 * padding < filterSizeH || @@ -200,6 +200,8 @@ void TestIm2ColMobileFunctor() { stride, padding, padding, + dilation, + dilation, 0, height, 0, -- GitLab