From 328169a95539fbe1371d0fa99f2f5fa865517c2e Mon Sep 17 00:00:00 2001 From: xzl Date: Fri, 27 Oct 2017 21:32:33 +0800 Subject: [PATCH] im2col cpu gpu dilation support --- paddle/function/Im2Col.h | 8 +++- paddle/function/Im2ColOp.cpp | 38 +++++++++++------- paddle/function/Im2ColOpGpu.cu | 71 ++++++++++++++++++++++++++-------- 3 files changed, 85 insertions(+), 32 deletions(-) diff --git a/paddle/function/Im2Col.h b/paddle/function/Im2Col.h index 1e0cff436..0c37fc972 100644 --- a/paddle/function/Im2Col.h +++ b/paddle/function/Im2Col.h @@ -78,7 +78,9 @@ public: int strideHeight, int strideWidth, int paddingHeight, - int paddingWidth); + int paddingWidth, + int dilationHeight = 1, + int dilationWidth = 1); }; template @@ -91,7 +93,9 @@ public: int strideHeight, int strideWidth, int paddingHeight, - int paddingWidth); + int paddingWidth, + int dilationHeight = 1, + int dilationWidth = 1); }; } // namespace paddle diff --git a/paddle/function/Im2ColOp.cpp b/paddle/function/Im2ColOp.cpp index b7d1eb1ed..f864d42f8 100644 --- a/paddle/function/Im2ColOp.cpp +++ b/paddle/function/Im2ColOp.cpp @@ -31,7 +31,9 @@ public: int strideHeight, int strideWidth, int paddingHeight, - int paddingWidth) { + int paddingWidth, + int dilationHeight, + int dilationWidth) { int inputChannels = imShape[0]; int inputHeight = imShape[1]; int inputWidth = imShape[2]; @@ -47,8 +49,8 @@ public: int c_im = c / filterWidth / filterHeight; for (int h = 0; h < outputHeight; ++h) { for (int w = 0; w < outputWidth; ++w) { - int imRowIdx = h * strideHeight + hOffset; - int imColIdx = w * strideWidth + wOffset; + int imRowIdx = h * strideHeight + hOffset * dilationHeight; + int imColIdx = w * strideWidth + wOffset * dilationWidth; if ((imRowIdx - paddingHeight) < 0 || (imRowIdx - paddingHeight) >= inputHeight || (imColIdx - paddingWidth) < 0 || @@ -81,7 +83,9 @@ public: int strideHeight, int strideWidth, int paddingHeight, - int paddingWidth) { + int paddingWidth, + int dilationHeight, + int dilationWidth) { int inputChannels = imShape[0]; int inputHeight = imShape[1]; int inputWidth = imShape[2]; @@ -97,8 +101,8 @@ public: int c_im = c / filterWidth / filterHeight; for (int h = 0; h < outputHeight; ++h) { for (int w = 0; w < outputWidth; ++w) { - int imRowIdx = h * strideHeight + hOffset; - int imColIdx = w * strideWidth + wOffset; + int imRowIdx = h * strideHeight + hOffset * dilationHeight; + int imColIdx = w * strideWidth + wOffset * dilationWidth; if ((imRowIdx - paddingHeight) >= 0 && (imRowIdx - paddingHeight) < inputHeight && (imColIdx - paddingWidth) >= 0 && @@ -134,7 +138,9 @@ public: int strideHeight, int strideWidth, int paddingHeight, - int paddingWidth) { + int paddingWidth, + int dilationHeight = 1, + int dilationWidth = 1) { int inputChannels = imShape[0]; int inputHeight = imShape[1]; int inputWidth = imShape[2]; @@ -147,9 +153,10 @@ public: for (int channel = 0; channel < inputChannels; ++channel) { for (int filterH = 0; filterH < filterHeight; ++filterH) { for (int filterW = 0; filterW < filterWidth; ++filterW) { - int imRowOffset = - outputH * strideHeight + filterH - paddingHeight; - int imColOffset = outputW * strideWidth + filterW - paddingWidth; + int imRowOffset = outputH * strideHeight + + filterH * dilationHeight - paddingHeight; + int imColOffset = outputW * strideWidth + + filterW * dilationWidth - paddingWidth; int colDataOffset = (((outputH * outputWidth + outputW) * inputChannels + channel) * @@ -189,7 +196,9 @@ public: int strideHeight, int strideWidth, int paddingHeight, - int paddingWidth) { + int paddingWidth, + int dilationHeight = 1, + int dilationWidth = 1) { int inputChannels = imShape[0]; int inputHeight = imShape[1]; int inputWidth = imShape[2]; @@ -202,9 +211,10 @@ public: for (int channel = 0; channel < inputChannels; ++channel) { for (int filterH = 0; filterH < filterHeight; ++filterH) { for (int filterW = 0; filterW < filterWidth; ++filterW) { - int imRowOffset = - outputH * strideHeight + filterH - paddingHeight; - int imColOffset = outputW * strideWidth + filterW - paddingWidth; + int imRowOffset = outputH * strideHeight + + filterH * dilationHeight - paddingHeight; + int imColOffset = outputW * strideWidth + + filterW * dilationWidth - paddingWidth; int colDataOffset = (((outputH * outputWidth + outputW) * inputChannels + channel) * diff --git a/paddle/function/Im2ColOpGpu.cu b/paddle/function/Im2ColOpGpu.cu index bd9861049..71da11b95 100644 --- a/paddle/function/Im2ColOpGpu.cu +++ b/paddle/function/Im2ColOpGpu.cu @@ -28,6 +28,8 @@ __global__ void im2col(const T* data_im, int strideW, int paddingH, int paddingW, + int dilationH, + int dilationW, int height_col, int width_col, T* data_col) { @@ -44,8 +46,8 @@ __global__ void im2col(const T* data_im, data_col += (channel_out * height_col + h_out) * width_col + w_out; for (int i = 0; i < blockH; ++i) { for (int j = 0; j < blockW; ++j) { - int rIdx = int(h_in + i); - int cIdx = int(w_in + j); + int rIdx = int(h_in + i * dilationH); + int cIdx = int(w_in + j * dilationW); if ((rIdx - (int)paddingH) >= (int)height || (rIdx - (int)paddingH) < 0 || (cIdx - (int)paddingW) >= (int)width || @@ -77,7 +79,9 @@ public: int strideHeight, int strideWidth, int paddingHeight, - int paddingWidth) { + int paddingWidth, + int dilationHeight, + int dilationWidth) { int inputChannels = imShape[0]; int inputHeight = imShape[1]; int inputWidth = imShape[2]; @@ -102,6 +106,8 @@ public: strideWidth, paddingHeight, paddingWidth, + dilationHeight, + dilationWidth, outputHeight, outputWidth, colData); @@ -121,6 +127,8 @@ __global__ void col2im(size_t n, size_t strideW, size_t paddingH, size_t paddingW, + size_t dilationH, + size_t dilationW, size_t height_col, size_t width_col, T* data_im) { @@ -131,23 +139,34 @@ __global__ void col2im(size_t n, int w = int(index % width); int h = int((index / width) % height); int c = int(index / (width * height)); + int filterH = (blockH - 1) * dilationH + 1; + int filterW = (blockW - 1) * dilationW + 1; + if ((w - (int)paddingW) >= 0 && (w - (int)paddingW) < (width - 2 * paddingW) && (h - (int)paddingH) >= 0 && (h - paddingH) < (height - 2 * paddingH)) { // compute the start and end of the output int w_col_start = - (w < (int)blockW) ? 0 : (w - int(blockW)) / (int)strideW + 1; + (w < (int)filterW) ? 0 : (w - int(filterW)) / (int)strideW + 1; int w_col_end = min((int)(w / (int)strideW + 1), (int)(width_col)); int h_col_start = - (h < (int)blockH) ? 0 : (h - (int)blockH) / (int)strideH + 1; + (h < (int)filterH) ? 0 : (h - (int)filterH) / (int)strideH + 1; int h_col_end = min(int(h / strideH + 1), int(height_col)); + for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { // the col location: [c * width * height + h_out, w_out] - int c_col = int(c * blockH * blockW) + - (h - h_col * (int)strideH) * (int)blockW + - (w - w_col * (int)strideW); - val += data_col[(c_col * height_col + h_col) * width_col + w_col]; + int h_k = (h - h_col * strideH); + int w_k = (w - w_col * strideW); + if (h_k % dilationH == 0 && w_k % dilationW == 0) { + h_k /= dilationH; + w_k /= dilationW; + int c_col = + (((c * blockH + h_k) * blockW + w_k) * height_col + h_col) * + width_col + + w_col; + val += data_col[c_col]; + } } } h -= paddingH; @@ -173,7 +192,9 @@ public: int strideHeight, int strideWidth, int paddingHeight, - int paddingWidth) { + int paddingWidth, + int dilationHeight, + int dilationWidth) { int inputChannels = imShape[0]; int inputHeight = imShape[1]; int inputWidth = imShape[2]; @@ -205,6 +226,8 @@ public: strideWidth, paddingHeight, paddingWidth, + dilationHeight, + dilationWidth, outputHeight, outputWidth, imData); @@ -229,6 +252,8 @@ __global__ void im2colOCF(const T* imData, int strideWidth, int paddingHeight, int paddingWidth, + int dilationHeight, + int dilationWidth, int outputHeight, int outputWidth) { int swId = blockIdx.x; @@ -237,8 +262,10 @@ __global__ void im2colOCF(const T* imData, channelId += blockDim.z) { for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) { for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) { - int widthOffset = idx + swId * strideWidth - paddingWidth; - int heightOffset = idy + shId * strideHeight - paddingHeight; + int widthOffset = + idx * dilationHeight + swId * strideWidth - paddingWidth; + int heightOffset = + idy * dilationWidth + shId * strideHeight - paddingHeight; int imOffset = widthOffset + heightOffset * inputWidth + channelId * inputHeight * inputWidth; @@ -273,7 +300,9 @@ public: int strideHeight, int strideWidth, int paddingHeight, - int paddingWidth) { + int paddingWidth, + int dilationHeight, + int dilationWidth) { int inputChannels = imShape[0]; int inputHeight = imShape[1]; int inputWidth = imShape[2]; @@ -312,6 +341,8 @@ public: strideWidth, paddingHeight, paddingWidth, + dilationHeight, + dilationWidth, outputHeight, outputWidth); CHECK_SYNC("Im2ColFunctor GPU failed"); @@ -330,6 +361,8 @@ __global__ void col2imOCF(T* imData, int strideWidth, int paddingHeight, int paddingWidth, + int dilationHeight, + int dilationWidth, int outputHeight, int outputWidth) { int swId = blockIdx.x; @@ -338,8 +371,10 @@ __global__ void col2imOCF(T* imData, channelId += blockDim.z) { for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) { for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) { - int widthOffset = idx + swId * strideWidth - paddingWidth; - int heightOffset = idy + shId * strideHeight - paddingHeight; + int widthOffset = + idx * dilationWidth + swId * strideWidth - paddingWidth; + int heightOffset = + idy * dilationHeight + shId * strideHeight - paddingHeight; int imOffset = widthOffset + heightOffset * inputWidth + channelId * inputHeight * inputWidth; @@ -372,7 +407,9 @@ public: int strideHeight, int strideWidth, int paddingHeight, - int paddingWidth) { + int paddingWidth, + int dilationHeight, + int dilationWidth) { int inputChannels = imShape[0]; int inputHeight = imShape[1]; int inputWidth = imShape[2]; @@ -411,6 +448,8 @@ public: strideWidth, paddingHeight, paddingWidth, + dilationHeight, + dilationWidth, outputHeight, outputWidth); CHECK_SYNC("Col2ImFunctor GPU failed"); -- GitLab