提交 328169a9 编写于 作者: X xzl

im2col cpu gpu dilation support

上级 7a5b3846
...@@ -78,7 +78,9 @@ public: ...@@ -78,7 +78,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth); int paddingWidth,
int dilationHeight = 1,
int dilationWidth = 1);
}; };
template <ColFormat Format, DeviceType Device, class T> template <ColFormat Format, DeviceType Device, class T>
...@@ -91,7 +93,9 @@ public: ...@@ -91,7 +93,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth); int paddingWidth,
int dilationHeight = 1,
int dilationWidth = 1);
}; };
} // namespace paddle } // namespace paddle
...@@ -31,7 +31,9 @@ public: ...@@ -31,7 +31,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
...@@ -47,8 +49,8 @@ public: ...@@ -47,8 +49,8 @@ public:
int c_im = c / filterWidth / filterHeight; int c_im = c / filterWidth / filterHeight;
for (int h = 0; h < outputHeight; ++h) { for (int h = 0; h < outputHeight; ++h) {
for (int w = 0; w < outputWidth; ++w) { for (int w = 0; w < outputWidth; ++w) {
int imRowIdx = h * strideHeight + hOffset; int imRowIdx = h * strideHeight + hOffset * dilationHeight;
int imColIdx = w * strideWidth + wOffset; int imColIdx = w * strideWidth + wOffset * dilationWidth;
if ((imRowIdx - paddingHeight) < 0 || if ((imRowIdx - paddingHeight) < 0 ||
(imRowIdx - paddingHeight) >= inputHeight || (imRowIdx - paddingHeight) >= inputHeight ||
(imColIdx - paddingWidth) < 0 || (imColIdx - paddingWidth) < 0 ||
...@@ -81,7 +83,9 @@ public: ...@@ -81,7 +83,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
...@@ -97,8 +101,8 @@ public: ...@@ -97,8 +101,8 @@ public:
int c_im = c / filterWidth / filterHeight; int c_im = c / filterWidth / filterHeight;
for (int h = 0; h < outputHeight; ++h) { for (int h = 0; h < outputHeight; ++h) {
for (int w = 0; w < outputWidth; ++w) { for (int w = 0; w < outputWidth; ++w) {
int imRowIdx = h * strideHeight + hOffset; int imRowIdx = h * strideHeight + hOffset * dilationHeight;
int imColIdx = w * strideWidth + wOffset; int imColIdx = w * strideWidth + wOffset * dilationWidth;
if ((imRowIdx - paddingHeight) >= 0 && if ((imRowIdx - paddingHeight) >= 0 &&
(imRowIdx - paddingHeight) < inputHeight && (imRowIdx - paddingHeight) < inputHeight &&
(imColIdx - paddingWidth) >= 0 && (imColIdx - paddingWidth) >= 0 &&
...@@ -134,7 +138,9 @@ public: ...@@ -134,7 +138,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight = 1,
int dilationWidth = 1) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
...@@ -147,9 +153,10 @@ public: ...@@ -147,9 +153,10 @@ public:
for (int channel = 0; channel < inputChannels; ++channel) { for (int channel = 0; channel < inputChannels; ++channel) {
for (int filterH = 0; filterH < filterHeight; ++filterH) { for (int filterH = 0; filterH < filterHeight; ++filterH) {
for (int filterW = 0; filterW < filterWidth; ++filterW) { for (int filterW = 0; filterW < filterWidth; ++filterW) {
int imRowOffset = int imRowOffset = outputH * strideHeight +
outputH * strideHeight + filterH - paddingHeight; filterH * dilationHeight - paddingHeight;
int imColOffset = outputW * strideWidth + filterW - paddingWidth; int imColOffset = outputW * strideWidth +
filterW * dilationWidth - paddingWidth;
int colDataOffset = int colDataOffset =
(((outputH * outputWidth + outputW) * inputChannels + (((outputH * outputWidth + outputW) * inputChannels +
channel) * channel) *
...@@ -189,7 +196,9 @@ public: ...@@ -189,7 +196,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight = 1,
int dilationWidth = 1) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
...@@ -202,9 +211,10 @@ public: ...@@ -202,9 +211,10 @@ public:
for (int channel = 0; channel < inputChannels; ++channel) { for (int channel = 0; channel < inputChannels; ++channel) {
for (int filterH = 0; filterH < filterHeight; ++filterH) { for (int filterH = 0; filterH < filterHeight; ++filterH) {
for (int filterW = 0; filterW < filterWidth; ++filterW) { for (int filterW = 0; filterW < filterWidth; ++filterW) {
int imRowOffset = int imRowOffset = outputH * strideHeight +
outputH * strideHeight + filterH - paddingHeight; filterH * dilationHeight - paddingHeight;
int imColOffset = outputW * strideWidth + filterW - paddingWidth; int imColOffset = outputW * strideWidth +
filterW * dilationWidth - paddingWidth;
int colDataOffset = int colDataOffset =
(((outputH * outputWidth + outputW) * inputChannels + (((outputH * outputWidth + outputW) * inputChannels +
channel) * channel) *
......
...@@ -28,6 +28,8 @@ __global__ void im2col(const T* data_im, ...@@ -28,6 +28,8 @@ __global__ void im2col(const T* data_im,
int strideW, int strideW,
int paddingH, int paddingH,
int paddingW, int paddingW,
int dilationH,
int dilationW,
int height_col, int height_col,
int width_col, int width_col,
T* data_col) { T* data_col) {
...@@ -44,8 +46,8 @@ __global__ void im2col(const T* data_im, ...@@ -44,8 +46,8 @@ __global__ void im2col(const T* data_im,
data_col += (channel_out * height_col + h_out) * width_col + w_out; data_col += (channel_out * height_col + h_out) * width_col + w_out;
for (int i = 0; i < blockH; ++i) { for (int i = 0; i < blockH; ++i) {
for (int j = 0; j < blockW; ++j) { for (int j = 0; j < blockW; ++j) {
int rIdx = int(h_in + i); int rIdx = int(h_in + i * dilationH);
int cIdx = int(w_in + j); int cIdx = int(w_in + j * dilationW);
if ((rIdx - (int)paddingH) >= (int)height || if ((rIdx - (int)paddingH) >= (int)height ||
(rIdx - (int)paddingH) < 0 || (rIdx - (int)paddingH) < 0 ||
(cIdx - (int)paddingW) >= (int)width || (cIdx - (int)paddingW) >= (int)width ||
...@@ -77,7 +79,9 @@ public: ...@@ -77,7 +79,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
...@@ -102,6 +106,8 @@ public: ...@@ -102,6 +106,8 @@ public:
strideWidth, strideWidth,
paddingHeight, paddingHeight,
paddingWidth, paddingWidth,
dilationHeight,
dilationWidth,
outputHeight, outputHeight,
outputWidth, outputWidth,
colData); colData);
...@@ -121,6 +127,8 @@ __global__ void col2im(size_t n, ...@@ -121,6 +127,8 @@ __global__ void col2im(size_t n,
size_t strideW, size_t strideW,
size_t paddingH, size_t paddingH,
size_t paddingW, size_t paddingW,
size_t dilationH,
size_t dilationW,
size_t height_col, size_t height_col,
size_t width_col, size_t width_col,
T* data_im) { T* data_im) {
...@@ -131,23 +139,34 @@ __global__ void col2im(size_t n, ...@@ -131,23 +139,34 @@ __global__ void col2im(size_t n,
int w = int(index % width); int w = int(index % width);
int h = int((index / width) % height); int h = int((index / width) % height);
int c = int(index / (width * height)); int c = int(index / (width * height));
int filterH = (blockH - 1) * dilationH + 1;
int filterW = (blockW - 1) * dilationW + 1;
if ((w - (int)paddingW) >= 0 && if ((w - (int)paddingW) >= 0 &&
(w - (int)paddingW) < (width - 2 * paddingW) && (w - (int)paddingW) < (width - 2 * paddingW) &&
(h - (int)paddingH) >= 0 && (h - paddingH) < (height - 2 * paddingH)) { (h - (int)paddingH) >= 0 && (h - paddingH) < (height - 2 * paddingH)) {
// compute the start and end of the output // compute the start and end of the output
int w_col_start = int w_col_start =
(w < (int)blockW) ? 0 : (w - int(blockW)) / (int)strideW + 1; (w < (int)filterW) ? 0 : (w - int(filterW)) / (int)strideW + 1;
int w_col_end = min((int)(w / (int)strideW + 1), (int)(width_col)); int w_col_end = min((int)(w / (int)strideW + 1), (int)(width_col));
int h_col_start = int h_col_start =
(h < (int)blockH) ? 0 : (h - (int)blockH) / (int)strideH + 1; (h < (int)filterH) ? 0 : (h - (int)filterH) / (int)strideH + 1;
int h_col_end = min(int(h / strideH + 1), int(height_col)); int h_col_end = min(int(h / strideH + 1), int(height_col));
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
// the col location: [c * width * height + h_out, w_out] // the col location: [c * width * height + h_out, w_out]
int c_col = int(c * blockH * blockW) + int h_k = (h - h_col * strideH);
(h - h_col * (int)strideH) * (int)blockW + int w_k = (w - w_col * strideW);
(w - w_col * (int)strideW); if (h_k % dilationH == 0 && w_k % dilationW == 0) {
val += data_col[(c_col * height_col + h_col) * width_col + w_col]; h_k /= dilationH;
w_k /= dilationW;
int c_col =
(((c * blockH + h_k) * blockW + w_k) * height_col + h_col) *
width_col +
w_col;
val += data_col[c_col];
}
} }
} }
h -= paddingH; h -= paddingH;
...@@ -173,7 +192,9 @@ public: ...@@ -173,7 +192,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
...@@ -205,6 +226,8 @@ public: ...@@ -205,6 +226,8 @@ public:
strideWidth, strideWidth,
paddingHeight, paddingHeight,
paddingWidth, paddingWidth,
dilationHeight,
dilationWidth,
outputHeight, outputHeight,
outputWidth, outputWidth,
imData); imData);
...@@ -229,6 +252,8 @@ __global__ void im2colOCF(const T* imData, ...@@ -229,6 +252,8 @@ __global__ void im2colOCF(const T* imData,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth, int paddingWidth,
int dilationHeight,
int dilationWidth,
int outputHeight, int outputHeight,
int outputWidth) { int outputWidth) {
int swId = blockIdx.x; int swId = blockIdx.x;
...@@ -237,8 +262,10 @@ __global__ void im2colOCF(const T* imData, ...@@ -237,8 +262,10 @@ __global__ void im2colOCF(const T* imData,
channelId += blockDim.z) { channelId += blockDim.z) {
for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) { for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) { for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) {
int widthOffset = idx + swId * strideWidth - paddingWidth; int widthOffset =
int heightOffset = idy + shId * strideHeight - paddingHeight; idx * dilationHeight + swId * strideWidth - paddingWidth;
int heightOffset =
idy * dilationWidth + shId * strideHeight - paddingHeight;
int imOffset = widthOffset + heightOffset * inputWidth + int imOffset = widthOffset + heightOffset * inputWidth +
channelId * inputHeight * inputWidth; channelId * inputHeight * inputWidth;
...@@ -273,7 +300,9 @@ public: ...@@ -273,7 +300,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
...@@ -312,6 +341,8 @@ public: ...@@ -312,6 +341,8 @@ public:
strideWidth, strideWidth,
paddingHeight, paddingHeight,
paddingWidth, paddingWidth,
dilationHeight,
dilationWidth,
outputHeight, outputHeight,
outputWidth); outputWidth);
CHECK_SYNC("Im2ColFunctor GPU failed"); CHECK_SYNC("Im2ColFunctor GPU failed");
...@@ -330,6 +361,8 @@ __global__ void col2imOCF(T* imData, ...@@ -330,6 +361,8 @@ __global__ void col2imOCF(T* imData,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth, int paddingWidth,
int dilationHeight,
int dilationWidth,
int outputHeight, int outputHeight,
int outputWidth) { int outputWidth) {
int swId = blockIdx.x; int swId = blockIdx.x;
...@@ -338,8 +371,10 @@ __global__ void col2imOCF(T* imData, ...@@ -338,8 +371,10 @@ __global__ void col2imOCF(T* imData,
channelId += blockDim.z) { channelId += blockDim.z) {
for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) { for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) { for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) {
int widthOffset = idx + swId * strideWidth - paddingWidth; int widthOffset =
int heightOffset = idy + shId * strideHeight - paddingHeight; idx * dilationWidth + swId * strideWidth - paddingWidth;
int heightOffset =
idy * dilationHeight + shId * strideHeight - paddingHeight;
int imOffset = widthOffset + heightOffset * inputWidth + int imOffset = widthOffset + heightOffset * inputWidth +
channelId * inputHeight * inputWidth; channelId * inputHeight * inputWidth;
...@@ -372,7 +407,9 @@ public: ...@@ -372,7 +407,9 @@ public:
int strideHeight, int strideHeight,
int strideWidth, int strideWidth,
int paddingHeight, int paddingHeight,
int paddingWidth) { int paddingWidth,
int dilationHeight,
int dilationWidth) {
int inputChannels = imShape[0]; int inputChannels = imShape[0];
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
...@@ -411,6 +448,8 @@ public: ...@@ -411,6 +448,8 @@ public:
strideWidth, strideWidth,
paddingHeight, paddingHeight,
paddingWidth, paddingWidth,
dilationHeight,
dilationWidth,
outputHeight, outputHeight,
outputWidth); outputWidth);
CHECK_SYNC("Col2ImFunctor GPU failed"); CHECK_SYNC("Col2ImFunctor GPU failed");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册