From dc7d07358c594b8f8ea81e33948ddf416686f64d Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sat, 21 Oct 2017 14:11:40 +0800 Subject: [PATCH] add padding up, down, left, right --- paddle/operators/conv2d_op.h | 8 +- paddle/operators/math/im2col.cc | 142 +++++++++++++++------------ paddle/operators/math/im2col.cu | 119 +++++++++++----------- paddle/operators/math/im2col.h | 7 +- paddle/operators/math/im2col_test.cc | 16 +-- 5 files changed, 158 insertions(+), 134 deletions(-) diff --git a/paddle/operators/conv2d_op.h b/paddle/operators/conv2d_op.h index 7ebdbe81cbb..046f8f5faca 100644 --- a/paddle/operators/conv2d_op.h +++ b/paddle/operators/conv2d_op.h @@ -116,7 +116,7 @@ class GemmConv2DKernel : public framework::OpKernel { // im2col Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); im2col(context.device_context(), in_slice, col, strides[0], strides[1], - paddings[0], paddings[1]); + paddings[0], paddings[0], paddings[1], paddings[1]); // gemm Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); @@ -217,7 +217,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel { Tensor in_grad_slice = in_grad_batch.Slice(g * in_step, (g + 1) * in_step); col2im(context.device_context(), in_grad_slice, col, strides[0], - strides[1], paddings[0], paddings[1]); + strides[1], paddings[0], paddings[0], paddings[1], + paddings[1]); } } } @@ -239,7 +240,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel { out_grad_batch.Slice(g * out_step, (g + 1) * out_step); Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); im2col(context.device_context(), in_slice, col, strides[0], - strides[1], paddings[0], paddings[1]); + strides[1], paddings[0], paddings[0], paddings[1], + paddings[1]); // gemm Tensor filter_grad_slice = diff --git a/paddle/operators/math/im2col.cc b/paddle/operators/math/im2col.cc index 729ba8665cf..441ae7c2292 100644 --- a/paddle/operators/math/im2col.cc +++ b/paddle/operators/math/im2col.cc @@ -29,8 +29,8 @@ class Im2ColFunctor(); @@ -54,14 +64,14 @@ class Im2ColFunctor= input_height || - (im_col_idx - padding_width) < 0 || - (im_col_idx - padding_width) >= input_width) { + if ((im_row_idx - padding_up) < 0 || + (im_row_idx - padding_up) >= input_height || + (im_col_idx - padding_left) < 0 || + (im_col_idx - padding_left) >= input_width) { col_data[(c * output_height + h) * output_width + w] = T(0); } else { - im_row_idx += c_im * input_height - padding_height; - im_col_idx -= padding_width; + im_row_idx += c_im * input_height - padding_up; + im_col_idx -= padding_left; col_data[(c * output_height + h) * output_width + w] = im_data[im_row_idx * input_width + im_col_idx]; } @@ -82,7 +92,8 @@ class Col2ImFunctor(); @@ -105,12 +126,12 @@ class Col2ImFunctor= 0 && - (im_row_idx - padding_height) < input_height && - (im_col_idx - padding_width) >= 0 && - (im_col_idx - padding_width) < input_width) { - im_row_idx += c_im * input_height - padding_height; - im_col_idx -= padding_width; + if ((im_row_idx - padding_up) >= 0 && + (im_row_idx - padding_up) < input_height && + (im_col_idx - padding_left) >= 0 && + (im_col_idx - padding_left) < input_width) { + im_row_idx += c_im * input_height - padding_up; + im_col_idx -= padding_left; im_data[im_row_idx * input_width + im_col_idx] += col_data[(c * output_height + h) * output_width + w]; } @@ -140,8 +161,8 @@ class Im2ColFunctor= down_pad) { - row_begin = 0; - } else { - row_begin = down_pad - up_pad; - } - row_end = row_begin + ((input_height + up_pad + down_pad - filter_height) / - stride_height + - 1); + PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) / + stride_height + + 1 == + output_height); + PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) / + stride_width + + 1 == + output_width); const T* im_data = im.data(); T* col_data = col.data(); - for (int col_row_idx = row_begin; col_row_idx < row_end; ++col_row_idx) { + for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { for (int channel = 0; channel < input_channels; ++channel) { for (int filter_row_idx = 0; filter_row_idx < filter_height; @@ -175,17 +193,16 @@ class Im2ColFunctor= input_height || im_col_offset < 0 || im_col_offset >= input_width) { col_data[col_offset] = T(0); @@ -214,7 +231,8 @@ class Col2ImFunctor= down_pad) { - row_begin = 0; - } else { - row_begin = down_pad - up_pad; - } - row_end = row_begin + ((input_height + up_pad + down_pad - filter_height) / - stride_height + - 1); + PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) / + stride_height + + 1 == + output_height); + PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) / + stride_width + + 1 == + output_width); T* im_data = im.data(); const T* col_data = col.data(); - for (int col_row_idx = row_begin; col_row_idx < row_end; ++col_row_idx) { + for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { for (int channel = 0; channel < input_channels; ++channel) { for (int filter_row_idx = 0; filter_row_idx < filter_height; @@ -248,17 +263,16 @@ class Col2ImFunctor= 0 && im_row_offset < input_height && im_col_offset >= 0 && im_col_offset < input_width) { int im_offset = diff --git a/paddle/operators/math/im2col.cu b/paddle/operators/math/im2col.cu index 24167586299..7b201fdbf3c 100644 --- a/paddle/operators/math/im2col.cu +++ b/paddle/operators/math/im2col.cu @@ -66,8 +66,8 @@ class Im2ColFunctor(context) .stream()>>>( im.data(), num_outputs, input_height, input_width, filter_height, - filter_width, stride_height, stride_width, padding_height, - padding_width, output_height, output_width, col.data()); + filter_width, stride_height, stride_width, padding_up, padding_left, + output_height, output_width, col.data()); } }; @@ -152,7 +161,8 @@ class Col2ImFunctor<<(context) .stream()>>>( - num_kernels, col.data(), input_height + 2 * padding_height, - input_width + 2 * padding_width, input_channels, filter_height, - filter_width, stride_height, stride_width, padding_height, - padding_width, output_height, output_width, im.data()); + num_kernels, col.data(), input_height + padding_up + padding_down, + input_width + padding_left + padding_left, input_channels, + filter_height, filter_width, stride_height, stride_width, padding_up, + padding_left, output_height, output_width, im.data()); } }; @@ -199,8 +219,7 @@ __global__ void im2colOCF(const T* im_data, T* col_data, int input_channels, int input_height, int input_width, int filter_height, int filter_width, int stride_height, int stride_width, int padding_height, int padding_width, - int output_height, int output_width, int row_begin, - int row_end) { + int output_height, int output_width) { int swid = blockIdx.x; int shid = blockIdx.y; for (int channelid = threadIdx.z; channelid < input_channels; @@ -208,8 +227,7 @@ __global__ void im2colOCF(const T* im_data, T* col_data, int input_channels, for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) { for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) { int width_offset = idx + swid * stride_width - padding_width; - int height_offset = - idy + (shid + row_begin) * stride_height - padding_height; + int height_offset = idy + shid * stride_height - padding_height; int im_offset = width_offset + height_offset * input_width + channelid * input_height * input_width; @@ -240,8 +258,8 @@ class Im2ColFunctor= down_pad) { - row_begin = 0; - } else { - row_begin = down_pad - up_pad; - } - row_end = row_begin + ((input_height + up_pad + down_pad - filter_height) / - stride_height + - 1); - - int output_height = row_end - row_begin; // col.dims()[0]; + int output_height = col.dims()[0]; int output_width = col.dims()[1]; + PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) / + stride_height + + 1 == + output_height); + PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) / + stride_width + + 1 == + output_width); + int block_dim_x = 0; int block_dim_y = 0; if (filter_height <= 4 && filter_width <= 4) { @@ -289,9 +303,8 @@ class Im2ColFunctor(context) .stream()>>>( im.data(), col.data(), input_channels, input_height, input_width, - filter_height, filter_width, stride_height, stride_width, - padding_height, padding_width, output_height, output_width, row_begin, - row_end); + filter_height, filter_width, stride_height, stride_width, padding_up, + padding_left, output_height, output_width); } }; @@ -300,8 +313,7 @@ __global__ void col2imOCF(T* im_data, const T* col_data, int input_channels, int input_height, int input_width, int filter_height, int filter_width, int stride_height, int stride_width, int padding_height, int padding_width, - int output_height, int output_width, int row_begin, - int row_end) { + int output_height, int output_width) { int swid = blockIdx.x; int shid = blockIdx.y; for (int channelid = threadIdx.z; channelid < input_channels; @@ -309,8 +321,7 @@ __global__ void col2imOCF(T* im_data, const T* col_data, int input_channels, for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) { for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) { int width_offset = idx + swid * stride_width - padding_width; - int height_offset = - idy + (shid + row_begin) * stride_height - padding_height; + int height_offset = idy + shid * stride_height - padding_height; int im_offset = width_offset + height_offset * input_width + channelid * input_height * input_width; @@ -340,7 +351,8 @@ class Col2ImFunctor= down_pad) { - row_begin = 0; - } else { - row_begin = down_pad - up_pad; - } - row_end = row_begin + ((input_height + up_pad + down_pad - filter_height) / - stride_height + - 1); - - int output_height = row_end - row_begin; // col.dims()[0]; + int output_height = col.dims()[0]; int output_width = col.dims()[1]; + PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) / + stride_height + + 1 == + output_height); + PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) / + stride_width + + 1 == + output_width); + int block_dim_x = 0; int block_dim_y = 0; if (filter_height <= 4 && filter_width <= 4) { @@ -388,9 +396,8 @@ class Col2ImFunctor(context) .stream()>>>( im.data(), col.data(), input_channels, input_height, input_width, - filter_height, filter_width, stride_height, stride_width, - padding_height, padding_width, output_height, output_width, row_begin, - row_end); + filter_height, filter_width, stride_height, stride_width, padding_up, + padding_left, output_height, output_width); } }; diff --git a/paddle/operators/math/im2col.h b/paddle/operators/math/im2col.h index 7b717e1603c..c736d4fa523 100644 --- a/paddle/operators/math/im2col.h +++ b/paddle/operators/math/im2col.h @@ -74,8 +74,8 @@ class Im2ColFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& im, framework::Tensor& col, - int stride_height, int stride_width, int padding_height, - int padding_width); + int stride_height, int stride_width, int padding_up, + int padding_down, int padding_left, int padding_right); }; template @@ -83,7 +83,8 @@ class Col2ImFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& im, const framework::Tensor& col, int stride_height, - int stride_width, int padding_height, int padding_width); + int stride_width, int padding_up, int padding_down, + int padding_left, int padding_right); }; } // namespace math diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc index 6406d43a9bc..6dfa61649d9 100644 --- a/paddle/operators/math/im2col_test.cc +++ b/paddle/operators/math/im2col_test.cc @@ -85,10 +85,10 @@ void testIm2col() { paddle::operators::math::ColFormat::kOCF, Place, float> im2col_ocf; - im2col(*context, input, output_cfo, stride, stride, padding, padding); - im2col_ocf(*context, input, output_ocf, /*stride_height*/ stride, - /*stride_width*/ stride, /*up_pad*/ padding, - /*down_pad*/ padding); + im2col(*context, input, output_cfo, stride, stride, padding, padding, padding, + padding); + im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding, + padding, padding); float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5}; float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5}; @@ -133,7 +133,8 @@ void testIm2col() { input.CopyFrom(input_tmp, *place, *context); } - col2im(*context, input, output_cfo, stride, stride, padding, padding); + col2im(*context, input, output_cfo, stride, stride, padding, padding, padding, + padding); float* in_ptr; if (paddle::platform::is_cpu_place(*place)) { @@ -154,9 +155,8 @@ void testIm2col() { input.CopyFrom(input_tmp, *place, *context); } - col2im_ocf(*context, input, output_ocf, /*stride_height*/ stride, - /*stride_width*/ stride, /*up_pad*/ padding, - /*down_pad*/ padding); + col2im_ocf(*context, input, output_ocf, stride, stride, padding, padding, + padding, padding); if (paddle::platform::is_cpu_place(*place)) { in_ptr = input.data(); -- GitLab