From 97e9dd72375258ed69fbbab39f340d23878002f5 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 8 Nov 2017 14:15:58 +0800 Subject: [PATCH] add dilation for im2col --- paddle/operators/conv_cudnn_op.cc | 2 - paddle/operators/conv_op.cc | 13 +- paddle/operators/conv_op.h | 29 +- paddle/operators/conv_transpose_op.h | 16 +- paddle/operators/math/context_project.h | 10 +- paddle/operators/math/im2col.cc | 281 +++++++++--------- paddle/operators/math/im2col.cu | 366 +++++++++++++----------- paddle/operators/math/im2col.h | 11 +- paddle/operators/math/im2col_test.cc | 18 +- 9 files changed, 395 insertions(+), 351 deletions(-) diff --git a/paddle/operators/conv_cudnn_op.cc b/paddle/operators/conv_cudnn_op.cc index 97f31bf22d7..4c65b60d234 100644 --- a/paddle/operators/conv_cudnn_op.cc +++ b/paddle/operators/conv_cudnn_op.cc @@ -22,8 +22,6 @@ class CudnnConvOpMaker : public Conv2DOpMaker { CudnnConvOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : Conv2DOpMaker(proto, op_checker) { - AddAttr>("dilations", "dilations of convolution operator.") - .SetDefault(std::vector{1, 1}); AddAttr("workspace_size_MB", "workspace size for cudnn, in MB, " "workspace is a section of GPU memory which will be " diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index a6f65f10165..852ac2ae37c 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -30,6 +30,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { std::vector strides = ctx->Attrs().Get>("strides"); std::vector paddings = ctx->Attrs().Get>("paddings"); int groups = ctx->Attrs().Get("groups"); + std::vector dilations = ctx->Attrs().Get>("dilations"); int input_channels = in_dims[1]; int output_channels = filter_dims[0]; @@ -54,7 +55,8 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { std::vector output_shape({in_dims[0], filter_dims[0]}); for (size_t i = 0; i < paddings.size(); ++i) { output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2], - paddings[i], strides[i])); + dilations[i], paddings[i], paddings[i], + strides[i])); } ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); } @@ -90,6 +92,10 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, "first half of the input channels, while the second half of the filters " "is only connected to the second half of the input channels.") .SetDefault(1); + AddAttr>("dilations", + "(vector default:{1, 1}), the dilations of " + "convolution operator.") + .SetDefault(std::vector{1, 1}); AddComment(R"DOC( Convolution Operator. @@ -151,6 +157,11 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, "first half of the input channels, while the second half of the filters " "is only connected to the second half of the input channels.") .SetDefault(1); + AddAttr>("dilations", + "(vector default:{1, 1, 1}), the dilations of " + "convolution operator. Currently, conv3d doesn't " + "support dilation.") + .SetDefault(std::vector{1, 1, 1}); AddComment(R"DOC( Convolution3D Operator. diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h index 7c1729213bf..2459f03a1a9 100644 --- a/paddle/operators/conv_op.h +++ b/paddle/operators/conv_op.h @@ -27,9 +27,12 @@ using Tensor = framework::Tensor; // Base convolution operator definations for other conv // like operators to reuse the implementation. -inline int OutputSize(int input_size, int filter_size, int padding, - int stride) { - int output_size = (input_size - filter_size + 2 * padding) / stride + 1; +inline int OutputSize(int input_size, int filter_size, int dilation, + int padding_up, int padding_down, int stride) { + int output_size = (input_size + padding_up + padding_down - + (dilation * (filter_size - 1) + 1)) / + stride + + 1; return output_size; } @@ -76,6 +79,7 @@ class GemmConvKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); int groups = context.Attr("groups"); + std::vector dilations = context.Attr>("dilations"); const int batch_size = static_cast(input->dims()[0]); @@ -139,9 +143,9 @@ class GemmConvKernel : public framework::OpKernel { if (filter_shape_vec.size() == 2) { // im2col math::Im2ColFunctor im2col; - im2col(context.device_context(), in_slice, col, strides[0], - strides[1], paddings[0], paddings[0], paddings[1], - paddings[1]); + im2col(context.device_context(), in_slice, col, dilations[0], + dilations[1], strides[0], strides[1], paddings[0], paddings[0], + paddings[1], paddings[1]); } else if (filter_shape_vec.size() == 3) { // vol2col math::Vol2ColFunctor vol2col; @@ -181,6 +185,7 @@ class GemmConvGradKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); int groups = context.Attr("groups"); + std::vector dilations = context.Attr>("dilations"); const int batch_size = static_cast(input->dims()[0]); @@ -263,9 +268,9 @@ class GemmConvGradKernel : public framework::OpKernel { if (filter_shape_vec.size() == 2) { math::Col2ImFunctor col2im; - col2im(context.device_context(), in_grad_slice, col, strides[0], - strides[1], paddings[0], paddings[0], paddings[1], - paddings[1]); + col2im(context.device_context(), in_grad_slice, col, dilations[0], + dilations[1], strides[0], strides[1], paddings[0], + paddings[0], paddings[1], paddings[1]); } else if (filter_shape_vec.size() == 3) { math::Col2VolFunctor col2vol; @@ -295,9 +300,9 @@ class GemmConvGradKernel : public framework::OpKernel { if (filter_shape_vec.size() == 2) { math::Im2ColFunctor im2col; - im2col(context.device_context(), in_slice, col, strides[0], - strides[1], paddings[0], paddings[0], paddings[1], - paddings[1]); + im2col(context.device_context(), in_slice, col, dilations[0], + dilations[1], strides[0], strides[1], paddings[0], + paddings[0], paddings[1], paddings[1]); } else if (filter_shape_vec.size() == 3) { math::Vol2ColFunctor vol2col; vol2col(context.device_context(), in_slice, col, strides[0], diff --git a/paddle/operators/conv_transpose_op.h b/paddle/operators/conv_transpose_op.h index 6c1a6220d78..cbfad88b398 100644 --- a/paddle/operators/conv_transpose_op.h +++ b/paddle/operators/conv_transpose_op.h @@ -69,6 +69,9 @@ class GemmConvTransposeKernel : public framework::OpKernel { // TODO(Zhuoyuan): Paddings can be added in future. // groups will alway be disabled in conv2dtranspose. + int dilation_h = 1; + int dilation_w = 1; + const int batch_size = static_cast(input->dims()[0]); // input_shape_vec: {h, w} or {d, h, w} @@ -140,8 +143,8 @@ class GemmConvTransposeKernel : public framework::OpKernel { // from (c * k_h * k_w, h * w) to (c, o_h, o_w) math::Col2ImFunctor col2im; - col2im(context.device_context(), output_batch, col, strides[0], - strides[1], 0, 0, 0, 0); + col2im(context.device_context(), output_batch, col, dilation_h, + dilation_w, strides[0], strides[1], 0, 0, 0, 0); } else if (filter_shape_vec.size() == 3) { // col2vol: col_matrix -> dy // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w) @@ -174,6 +177,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // Actually, no paddings and groups allowed in conv transpose. std::vector paddings = context.Attr>("paddings"); + int dilation_h = 1; + int dilation_w = 1; + const int batch_size = static_cast(input->dims()[0]); // input_shape_vec: {h, w} or {d, h, w} @@ -248,9 +254,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // im2col: dy -> col matrix // from (c, o_h, o_w) to (c * k_h * k_w, h * w) math::Im2ColFunctor im2col; - im2col(context.device_context(), output_grad_batch, col, strides[0], - strides[1], paddings[0], paddings[0], paddings[1], - paddings[1]); + im2col(context.device_context(), output_grad_batch, col, dilation_h, + dilation_w, strides[0], strides[1], paddings[0], paddings[0], + paddings[1], paddings[1]); } else if (filter_shape_vec.size() == 3) { // vol2col: dy -> col_matrix // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w) diff --git a/paddle/operators/math/context_project.h b/paddle/operators/math/context_project.h index e0283360414..c67d84528fd 100644 --- a/paddle/operators/math/context_project.h +++ b/paddle/operators/math/context_project.h @@ -95,6 +95,9 @@ class ContextProjectFunctor { math::Im2ColFunctor im2col_ocf; + int dilation_h = 1; + int dilation_w = 1; + int input_row_begin, input_row_end; int sequence_height, sequence_width; sequence_width = in.dims()[1]; @@ -124,7 +127,7 @@ class ContextProjectFunctor { sequence_width}); // input_channels, input_height, input_width in_t.Resize(framework::make_ddim(input_shape)); - im2col_ocf(context, in_t, out_t, + im2col_ocf(context, in_t, out_t, dilation_h, dilation_w, /*stride_height*/ context_stride, /*stride_width*/ 1, up_pad, down_pad, 0, 0); out_t.Resize({sequence_height, context_length * sequence_width}); @@ -204,6 +207,9 @@ class ContextProjectGradFunctor { math::Col2ImFunctor col2im_ocf; + int dilation_h = 1; + int dilation_w = 1; + int input_row_begin, input_row_end; int sequence_height, sequence_width; sequence_width = in.dims()[1]; @@ -234,7 +240,7 @@ class ContextProjectGradFunctor { sequence_width}); // input_channels, input_height, input_width in_t.Resize(framework::make_ddim(input_shape)); - col2im_ocf(context, in_t, out_t, + col2im_ocf(context, in_t, out_t, dilation_h, dilation_w, /*stride_height*/ context_stride, /*stride_width*/ 1, up_pad, down_pad, 0, 0); out_t.Resize({sequence_height, context_length * sequence_width}); diff --git a/paddle/operators/math/im2col.cc b/paddle/operators/math/im2col.cc index 3b1b0bd71dd..b248863b4e9 100644 --- a/paddle/operators/math/im2col.cc +++ b/paddle/operators/math/im2col.cc @@ -29,35 +29,36 @@ class Im2ColFunctor(); T* col_data = col.data(); @@ -66,19 +67,19 @@ class Im2ColFunctor= input_height || im_col_idx < 0 || - im_col_idx >= input_width) { - col_data[(c * output_height + h) * output_width + w] = T(0); - } else { - im_row_idx += c_im * input_height; - col_data[(c * output_height + h) * output_width + w] = - im_data[im_row_idx * input_width + im_col_idx]; - } + col_data[(c * col_height + h) * col_width + w] = + (im_row_idx < 0 || im_row_idx >= im_height || im_col_idx < 0 || + im_col_idx >= im_width) + ? static_cast(0) + : im_data[(im_row_idx + c_im * im_height) * im_width + + im_col_idx]; } } } @@ -95,35 +96,35 @@ class Col2ImFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right) { + const framework::Tensor& col, int dilation_h, int dilation_w, + int stride_height, int stride_width, int padding_up, + int padding_down, int padding_left, int padding_right) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int input_channels = im.dims()[0]; - int input_height = im.dims()[1]; - int input_width = im.dims()[2]; + int im_channels = im.dims()[0]; + int im_height = im.dims()[1]; + int im_width = im.dims()[2]; int filter_height = col.dims()[1]; int filter_width = col.dims()[2]; - int output_height = col.dims()[3]; - int output_width = col.dims()[4]; + int col_height = col.dims()[3]; + int col_width = col.dims()[4]; - PADDLE_ENFORCE_EQ( - (input_height + padding_up + padding_down - filter_height) / - stride_height + - 1, - output_height, - "Output_height and padding(padding_up, padding_down) are " - "inconsistent."); - PADDLE_ENFORCE_EQ( - (input_width + padding_left + padding_right - filter_width) / - stride_width + - 1, - output_width, - "output_width and padding(padding_left, padding_right) are " - "inconsistent."); + PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - + ((dilation_h * (filter_height - 1) + 1))) / + stride_height + + 1, + col_height, + "Output_height and padding(padding_up, padding_down) are " + "inconsistent."); + PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - + ((dilation_w * (filter_width - 1) + 1))) / + stride_width + + 1, + col_width, + "col_width and padding(padding_left, padding_right) are " + "inconsistent."); - int channels_col = input_channels * filter_height * filter_width; + int channels_col = im_channels * filter_height * filter_width; T* im_data = im.data(); const T* col_data = col.data(); @@ -132,16 +133,18 @@ class Col2ImFunctor= 0 && (im_row_idx) < input_height && - (im_col_idx) >= 0 && (im_col_idx) < input_width) { - im_row_idx += c_im * input_height; - im_data[im_row_idx * input_width + im_col_idx] += - col_data[(c * output_height + h) * output_width + w]; + if ((im_row_idx) >= 0 && (im_row_idx) < im_height && + (im_col_idx) >= 0 && (im_col_idx) < im_width) { + im_row_idx += c_im * im_height; + im_data[im_row_idx * im_width + im_col_idx] += + col_data[(c * col_height + h) * col_width + w]; } } } @@ -169,39 +172,38 @@ class Im2ColFunctor(); T* col_data = col.data(); - for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { - for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { - for (int channel = 0; channel < input_channels; ++channel) { + for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) { + for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) { + for (int channel = 0; channel < im_channels; ++channel) { for (int filter_row_idx = 0; filter_row_idx < filter_height; ++filter_row_idx) { for (int filter_col_idx = 0; filter_col_idx < filter_width; @@ -210,22 +212,21 @@ class Im2ColFunctor= input_height || - im_col_offset < 0 || im_col_offset >= input_width) { - col_data[col_offset] = T(0); - } else { - int im_offset = - (channel * input_height + im_row_offset) * input_width + - im_col_offset; - col_data[col_offset] = im_data[im_offset]; - } + int col_offset = + ((((col_row_idx)*col_width + col_col_idx) * im_channels + + channel) * + filter_height + + filter_row_idx) * + filter_width + + filter_col_idx; + + int im_offset = (channel * im_height + im_row_offset) * im_width + + im_col_offset; + col_data[col_offset] = + (im_row_offset < 0 || im_row_offset >= im_height || + im_col_offset < 0 || im_col_offset >= im_width) + ? static_cast(0) + : im_data[im_offset]; } } } @@ -244,40 +245,38 @@ class Col2ImFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right) { + const framework::Tensor& col, int dilation_h, int dilation_w, + int stride_height, int stride_width, int padding_up, + int padding_down, int padding_left, int padding_right) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int input_channels = im.dims()[0]; - int input_height = im.dims()[1]; - int input_width = im.dims()[2]; + int im_channels = im.dims()[0]; + int im_height = im.dims()[1]; + int im_width = im.dims()[2]; int filter_height = col.dims()[3]; int filter_width = col.dims()[4]; - int output_height = col.dims()[0]; - int output_width = col.dims()[1]; + int col_height = col.dims()[0]; + int col_width = col.dims()[1]; - PADDLE_ENFORCE_EQ( - (input_height + padding_up + padding_down - filter_height) / - stride_height + - 1, - output_height, - "Output_height and padding(padding_up, padding_down) are " - "inconsistent."); - PADDLE_ENFORCE_EQ( - (input_width + padding_left + padding_right - filter_width) / - stride_width + - 1, - output_width, - "output_width and padding(padding_left, padding_right) are " - "inconsistent."); + PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - filter_height) / + stride_height + + 1, + col_height, + "Output_height and padding(padding_up, padding_down) are " + "inconsistent."); + PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - filter_width) / + stride_width + + 1, + col_width, + "col_width and padding(padding_left, padding_right) are " + "inconsistent."); T* im_data = im.data(); const T* col_data = col.data(); - for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { - for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { - for (int channel = 0; channel < input_channels; ++channel) { + for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) { + for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) { + for (int channel = 0; channel < im_channels; ++channel) { for (int filter_row_idx = 0; filter_row_idx < filter_height; ++filter_row_idx) { for (int filter_col_idx = 0; filter_col_idx < filter_width; @@ -286,17 +285,17 @@ class Col2ImFunctor= 0 && im_row_offset < input_height && - im_col_offset >= 0 && im_col_offset < input_width) { + int col_offset = + (((col_row_idx * col_width + col_col_idx) * im_channels + + channel) * + filter_height + + filter_row_idx) * + filter_width + + filter_col_idx; + if (im_row_offset >= 0 && im_row_offset < im_height && + im_col_offset >= 0 && im_col_offset < im_width) { int im_offset = - (channel * input_height + im_row_offset) * input_width + + (channel * im_height + im_row_offset) * im_width + im_col_offset; im_data[im_offset] += col_data[col_offset]; } diff --git a/paddle/operators/math/im2col.cu b/paddle/operators/math/im2col.cu index 7b201fdbf3c..69e2abee03b 100644 --- a/paddle/operators/math/im2col.cu +++ b/paddle/operators/math/im2col.cu @@ -20,36 +20,32 @@ namespace operators { namespace math { template -__global__ void im2col(const T* data_im, int num_outs, int height, int width, +__global__ void im2col(const T* data_im, int num_outs, int im_height, + int im_width, int dilation_h, int dilation_w, int filter_height, int filter_width, int stride_height, int stride_width, int padding_height, int padding_width, - int output_height, int output_width, T* data_col) { - int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + int col_height, int col_width, T* data_col) { + const int index = + (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; if (index < num_outs) { - int w_out = index % output_width; - index /= output_width; - int h_out = index % output_height; - int channel_in = index / output_height; + int w_out = index % col_width; + int h_out = (index / col_width) % col_height; + int channel_in = index / col_width / col_height; int channel_out = channel_in * filter_height * filter_width; - int h_in = h_out * stride_height; - int w_in = w_out * stride_width; + int h_in = h_out * stride_height - padding_height; + int w_in = w_out * stride_width - padding_width; - data_col += (channel_out * output_height + h_out) * output_width + w_out; + data_col += (channel_out * col_height + h_out) * col_width + w_out; + data_im += (channel_in * im_height + h_in) * im_width + w_in; for (int i = 0; i < filter_height; ++i) { for (int j = 0; j < filter_width; ++j) { - int rIdx = int(h_in + i); - int cIdx = int(w_in + j); - if ((rIdx - (int)padding_height) >= (int)height || - (rIdx - (int)padding_height) < 0 || - (cIdx - (int)padding_width) >= (int)width || - (cIdx - (int)padding_width) < 0) { - *data_col = 0; - } else { - rIdx = rIdx + channel_in * height - padding_height; - cIdx = cIdx - padding_width; - *data_col = data_im[rIdx * width + cIdx]; - } - data_col += output_height * output_width; + int rIdx = h_in + i * dilation_h; + int cIdx = w_in + j * dilation_w; + *data_col = + (rIdx >= im_height || rIdx < 0 || cIdx >= im_width || cIdx < 0) + ? 0 + : data_im[i * dilation_h * im_width + j * dilation_w]; + data_col += col_height * col_width; } } } @@ -66,29 +62,36 @@ class Im2ColFunctor<<(context) .stream()>>>( - im.data(), num_outputs, input_height, input_width, filter_height, - filter_width, stride_height, stride_width, padding_up, padding_left, - output_height, output_width, col.data()); + im.data(), num_outputs, im_height, im_width, dilation_h, dilation_w, + filter_height, filter_width, stride_height, stride_width, padding_up, + padding_left, col_height, col_width, col.data()); } }; template -__global__ void col2im(size_t n, const T* data_col, size_t height, size_t width, - size_t channels, size_t filter_height, - size_t filter_width, size_t stride_height, - size_t stride_width, size_t padding_height, - size_t padding_width, size_t output_height, - size_t output_width, T* data_im) { - size_t index = +__global__ void col2im(int n, const T* data_col, int im_height, int im_width, + int dilation_h, int dilation_w, int filter_height, + int filter_width, int stride_height, int stride_width, + int padding_height, int padding_width, int col_height, + int col_width, T* data_im) { + const int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + + const int d_filter_height = dilation_h * (filter_height - 1) + 1; + const int d_filter_width = dilation_w * (filter_width - 1) + 1; + if (index < n) { T val = 0; - int w = int(index % width); - int h = int((index / width) % height); - int c = int(index / (width * height)); - if ((w - (int)padding_width) >= 0 && - (w - (int)padding_width) < (width - 2 * padding_width) && - (h - (int)padding_height) >= 0 && - (h - padding_height) < (height - 2 * padding_height)) { - // compute the start and end of the output - int w_col_start = (w < (int)filter_width) - ? 0 - : (w - int(filter_width)) / (int)stride_width + 1; - int w_col_end = - min((int)(w / (int)stride_width + 1), (int)(output_width)); - int h_col_start = (h < (int)filter_height) - ? 0 - : (h - (int)filter_height) / (int)stride_height + 1; - int h_col_end = min(int(h / stride_height + 1), int(output_height)); - for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { - for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { - // the col location: [c * width * height + h_out, w_out] - int c_col = int(c * filter_height * filter_width) + - (h - h_col * (int)stride_height) * (int)filter_width + - (w - w_col * (int)stride_width); - val += - data_col[(c_col * output_height + h_col) * output_width + w_col]; + int w = index % im_width; + int h = (index / im_width) % im_height; + int c = index / (im_width * im_height); + + // compute the start and end of the output + int w_col_start = + (w < d_filter_width) ? 0 : (w - d_filter_width) / stride_width + 1; + int w_col_end = min(w / stride_width + 1, col_width); + int h_col_start = + (h < d_filter_height) ? 0 : (h - d_filter_height) / stride_height + 1; + int h_col_end = min(h / stride_height + 1, col_height); + + for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { + for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { + int h_off = (h - h_col * stride_height); + int w_off = (w - w_col * stride_width); + if (h_off % dilation_h == 0 && w_off % dilation_w == 0) { + h_off /= dilation_h; + w_off /= dilation_w; + int data_col_index = + (((c * filter_height + h_off) * filter_width + w_off) * + col_height + + h_col) * + col_width + + w_col; + val += data_col[data_col_index]; } } - h -= padding_height; - w -= padding_width; - data_im[c * ((width - 2 * padding_width) * - (height - 2 * padding_height)) + - h * (width - 2 * padding_width) + w] += val; } + data_im[index] = val; } } @@ -160,32 +163,36 @@ class Col2ImFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right) { + const framework::Tensor& col, int dilation_h, int dilation_w, + int stride_height, int stride_width, int padding_up, + int padding_down, int padding_left, int padding_right) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int input_channels = im.dims()[0]; - int input_height = im.dims()[1]; - int input_width = im.dims()[2]; + int im_channels = im.dims()[0]; + int im_height = im.dims()[1]; + int im_width = im.dims()[2]; int filter_height = col.dims()[1]; int filter_width = col.dims()[2]; - int output_height = col.dims()[3]; - int output_width = col.dims()[4]; - - PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) / - stride_height + - 1 == - output_height); - PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) / - stride_width + - 1 == - output_width); - - size_t num_kernels = input_channels * - (input_height + padding_up + padding_down) * - (input_width + padding_left + padding_right); + int col_height = col.dims()[3]; + int col_width = col.dims()[4]; + + PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - + (dilation_h * (filter_height - 1) + 1)) / + stride_height + + 1, + col_height, + "Output_height and padding(padding_up, padding_down) are " + "inconsistent."); + PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - + (dilation_w * (filter_width - 1) + 1)) / + stride_width + + 1, + col_width, + "col_width and padding(padding_left, padding_right) are " + "inconsistent."); + + size_t num_kernels = im_channels * im_height * im_width; size_t blocks = (num_kernels + 1024 - 1) / 1024; size_t block_x = 512; @@ -198,10 +205,9 @@ class Col2ImFunctor<<(context) .stream()>>>( - num_kernels, col.data(), input_height + padding_up + padding_down, - input_width + padding_left + padding_left, input_channels, + num_kernels, col.data(), im_height, im_width, dilation_h, dilation_w, filter_height, filter_width, stride_height, stride_width, padding_up, - padding_left, output_height, output_width, im.data()); + padding_left, col_height, col_width, im.data()); } }; @@ -215,33 +221,32 @@ template class Col2ImFunctor; template -__global__ void im2colOCF(const T* im_data, T* col_data, int input_channels, - int input_height, int input_width, int filter_height, +__global__ void im2colOCF(const T* im_data, T* col_data, int im_channels, + int im_height, int im_width, int filter_height, int filter_width, int stride_height, int stride_width, - int padding_height, int padding_width, - int output_height, int output_width) { + int padding_height, int padding_width, int col_height, + int col_width) { int swid = blockIdx.x; int shid = blockIdx.y; - for (int channelid = threadIdx.z; channelid < input_channels; + for (int channelid = threadIdx.z; channelid < im_channels; channelid += blockDim.z) { for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) { for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) { int width_offset = idx + swid * stride_width - padding_width; int height_offset = idy + shid * stride_height - padding_height; - int im_offset = width_offset + height_offset * input_width + - channelid * input_height * input_width; + int im_offset = width_offset + height_offset * im_width + + channelid * im_height * im_width; int col_offset = idx + idy * filter_width + channelid * filter_height * filter_width + - (shid * output_width + swid) * - (input_channels * filter_height * filter_width); - - if (height_offset >= input_height || height_offset < 0 || - width_offset >= input_width || width_offset < 0) { - col_data[col_offset] = T(0); - } else { - col_data[col_offset] = im_data[im_offset]; - } + (shid * col_width + swid) * + (im_channels * filter_height * filter_width); + + col_data[col_offset] = + (height_offset >= im_height || height_offset < 0 || + width_offset >= im_width || width_offset < 0) + ? T(0) + : im_data[im_offset]; } } } @@ -258,26 +263,33 @@ class Im2ColFunctor<<(context) .stream()>>>( - im.data(), col.data(), input_channels, input_height, input_width, + im.data(), col.data(), im_channels, im_height, im_width, filter_height, filter_width, stride_height, stride_width, padding_up, - padding_left, output_height, output_width); + padding_left, col_height, col_width); } }; template -__global__ void col2imOCF(T* im_data, const T* col_data, int input_channels, - int input_height, int input_width, int filter_height, +__global__ void col2imOCF(T* im_data, const T* col_data, int im_channels, + int im_height, int im_width, int filter_height, int filter_width, int stride_height, int stride_width, - int padding_height, int padding_width, - int output_height, int output_width) { + int padding_height, int padding_width, int col_height, + int col_width) { int swid = blockIdx.x; int shid = blockIdx.y; - for (int channelid = threadIdx.z; channelid < input_channels; + for (int channelid = threadIdx.z; channelid < im_channels; channelid += blockDim.z) { for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) { for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) { int width_offset = idx + swid * stride_width - padding_width; int height_offset = idy + shid * stride_height - padding_height; - int im_offset = width_offset + height_offset * input_width + - channelid * input_height * input_width; + int im_offset = width_offset + height_offset * im_width + + channelid * im_height * im_width; int col_offset = idx + idy * filter_width + channelid * filter_height * filter_width + - (shid * output_width + swid) * - (input_channels * filter_height * filter_width); + (shid * col_width + swid) * + (im_channels * filter_height * filter_width); - if (height_offset >= 0 && height_offset < input_height && - width_offset >= 0 && width_offset < input_width) { + if (height_offset >= 0 && height_offset < im_height && + width_offset >= 0 && width_offset < im_width) { paddle::platform::CudaAtomicAdd(im_data + im_offset, col_data[col_offset]); } @@ -350,27 +361,33 @@ class Col2ImFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right) { + const framework::Tensor& col, int dilation_h, int dilation_w, + int stride_height, int stride_width, int padding_up, + int padding_down, int padding_left, int padding_right) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int input_channels = im.dims()[0]; - int input_height = im.dims()[1]; - int input_width = im.dims()[2]; + int im_channels = im.dims()[0]; + int im_height = im.dims()[1]; + int im_width = im.dims()[2]; int filter_height = col.dims()[3]; int filter_width = col.dims()[4]; - int output_height = col.dims()[0]; - int output_width = col.dims()[1]; - - PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) / - stride_height + - 1 == - output_height); - PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) / - stride_width + - 1 == - output_width); + int col_height = col.dims()[0]; + int col_width = col.dims()[1]; + + PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - + (dilation_h * (filter_height - 1) + 1)) / + stride_height + + 1, + col_height, + "Output_height and padding(padding_up, padding_down) are " + "inconsistent."); + PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - + (dilation_w * (filter_width - 1) + 1)) / + stride_width + + 1, + col_width, + "col_width and padding(padding_left, padding_right) are " + "inconsistent."); int block_dim_x = 0; int block_dim_y = 0; @@ -389,15 +406,14 @@ class Col2ImFunctor<<(context) .stream()>>>( - im.data(), col.data(), input_channels, input_height, input_width, + im.data(), col.data(), im_channels, im_height, im_width, filter_height, filter_width, stride_height, stride_width, padding_up, - padding_left, output_height, output_width); + padding_left, col_height, col_width); } }; diff --git a/paddle/operators/math/im2col.h b/paddle/operators/math/im2col.h index c736d4fa523..d1c9595a328 100644 --- a/paddle/operators/math/im2col.h +++ b/paddle/operators/math/im2col.h @@ -74,17 +74,18 @@ class Im2ColFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& im, framework::Tensor& col, - int stride_height, int stride_width, int padding_up, - int padding_down, int padding_left, int padding_right); + int dilation_h, int dilation_w, int stride_height, + int stride_width, int padding_up, int padding_down, + int padding_left, int padding_right); }; template class Col2ImFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right); + const framework::Tensor& col, int dilation_h, int dilation_w, + int stride_height, int stride_width, int padding_up, + int padding_down, int padding_left, int padding_right); }; } // namespace math diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc index 5763782c4ed..3385fe8721c 100644 --- a/paddle/operators/math/im2col_test.cc +++ b/paddle/operators/math/im2col_test.cc @@ -47,6 +47,8 @@ void testIm2col() { int filter_size = 2; int stride = 1; int padding = 0; + int dilation_h = 1; + int dilation_w = 1; int output_height = (input_height - filter_size + 2 * padding) / stride + 1; int output_width = (input_width - filter_size + 2 * padding) / stride + 1; float* input_ptr = input_tmp.mutable_data( @@ -85,10 +87,10 @@ void testIm2col() { paddle::operators::math::ColFormat::kOCF, Place, float> im2col_ocf; - im2col(*context, input, output_cfo, stride, stride, padding, padding, padding, - padding); - im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding, - padding, padding); + im2col(*context, input, output_cfo, dilation_h, dilation_w, stride, stride, + padding, padding, padding, padding); + im2col_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride, + stride, padding, padding, padding, padding); float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5}; float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5}; @@ -131,8 +133,8 @@ void testIm2col() { input.CopyFrom(input_tmp, *place, *context); } - col2im(*context, input, output_cfo, stride, stride, padding, padding, padding, - padding); + col2im(*context, input, output_cfo, dilation_h, dilation_w, stride, stride, + padding, padding, padding, padding); float* in_ptr; if (paddle::platform::is_cpu_place(*place)) { @@ -153,8 +155,8 @@ void testIm2col() { input.CopyFrom(input_tmp, *place, *context); } - col2im_ocf(*context, input, output_ocf, stride, stride, padding, padding, - padding, padding); + col2im_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride, + stride, padding, padding, padding, padding); if (paddle::platform::is_cpu_place(*place)) { in_ptr = input.data(); -- GitLab