From 271fc9c1198e90813fee647b7020ee752aae549a Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 10 Nov 2017 10:25:44 +0800 Subject: [PATCH] Add dilation for vol2col --- paddle/operators/conv_op.h | 15 +-- paddle/operators/conv_transpose_op.h | 13 ++- paddle/operators/math/im2col.cu | 1 + paddle/operators/math/vol2col.cc | 80 ++++++++++++--- paddle/operators/math/vol2col.cu | 139 +++++++++++++++++++------- paddle/operators/math/vol2col.h | 2 + paddle/operators/math/vol2col_test.cc | 9 +- 7 files changed, 189 insertions(+), 70 deletions(-) diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h index 8e9f3b0b0e..af2c8fb163 100644 --- a/paddle/operators/conv_op.h +++ b/paddle/operators/conv_op.h @@ -165,9 +165,9 @@ class GemmConvKernel : public framework::OpKernel { } else if (filter_shape_vec.size() == 3) { // vol2col math::Vol2ColFunctor vol2col; - vol2col(context.device_context(), in_slice, col, strides[0], - strides[1], strides[2], paddings[0], paddings[1], - paddings[2]); + vol2col(context.device_context(), in_slice, col, dilations[0], + dilations[1], dilations[2], strides[0], strides[1], + strides[2], paddings[0], paddings[1], paddings[2]); } // gemm @@ -314,7 +314,8 @@ class GemmConvGradKernel : public framework::OpKernel { } else if (filter_shape_vec.size() == 3) { math::Col2VolFunctor col2vol; - col2vol(context.device_context(), in_grad_slice, col, strides[0], + col2vol(context.device_context(), in_grad_slice, col, + dilations[0], dilations[1], dilations[2], strides[0], strides[1], strides[2], paddings[0], paddings[1], paddings[2]); } @@ -371,9 +372,9 @@ class GemmConvGradKernel : public framework::OpKernel { paddings[0], paddings[1], paddings[1]); } else if (filter_shape_vec.size() == 3) { math::Vol2ColFunctor vol2col; - vol2col(context.device_context(), in_slice, col, strides[0], - strides[1], strides[2], paddings[0], paddings[1], - paddings[2]); + vol2col(context.device_context(), in_slice, col, dilations[0], + dilations[1], dilations[2], strides[0], strides[1], + strides[2], paddings[0], paddings[1], paddings[2]); } // gemm diff --git a/paddle/operators/conv_transpose_op.h b/paddle/operators/conv_transpose_op.h index cbfad88b39..18ca6b20e0 100644 --- a/paddle/operators/conv_transpose_op.h +++ b/paddle/operators/conv_transpose_op.h @@ -69,6 +69,7 @@ class GemmConvTransposeKernel : public framework::OpKernel { // TODO(Zhuoyuan): Paddings can be added in future. // groups will alway be disabled in conv2dtranspose. + int dilaiton_d = 1; int dilation_h = 1; int dilation_w = 1; @@ -149,8 +150,9 @@ class GemmConvTransposeKernel : public framework::OpKernel { // col2vol: col_matrix -> dy // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w) math::Col2VolFunctor col2vol; - col2vol(context.device_context(), output_batch, col, strides[0], - strides[1], strides[2], 0, 0, 0); + col2vol(context.device_context(), output_batch, col, dilaiton_d, + dilation_h, dilation_w, strides[0], strides[1], strides[2], 0, + 0, 0); } } } @@ -177,6 +179,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // Actually, no paddings and groups allowed in conv transpose. std::vector paddings = context.Attr>("paddings"); + int dilaiton_d = 1; int dilation_h = 1; int dilation_w = 1; @@ -261,9 +264,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // vol2col: dy -> col_matrix // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w) math::Vol2ColFunctor vol2col; - vol2col(context.device_context(), output_grad_batch, col, strides[0], - strides[1], strides[2], paddings[0], paddings[1], - paddings[2]); + vol2col(context.device_context(), output_grad_batch, col, dilaiton_d, + dilation_h, dilation_w, strides[0], strides[1], strides[2], + paddings[0], paddings[1], paddings[2]); } if (input_grad) { diff --git a/paddle/operators/math/im2col.cu b/paddle/operators/math/im2col.cu index 69e2abee03..9da427fdf1 100644 --- a/paddle/operators/math/im2col.cu +++ b/paddle/operators/math/im2col.cu @@ -145,6 +145,7 @@ __global__ void col2im(int n, const T* data_col, int im_height, int im_width, h_col) * col_width + w_col; + val += data_col[data_col_index]; } } diff --git a/paddle/operators/math/vol2col.cc b/paddle/operators/math/vol2col.cc index e9718a0473..d383ee8152 100644 --- a/paddle/operators/math/vol2col.cc +++ b/paddle/operators/math/vol2col.cc @@ -29,6 +29,7 @@ class Vol2ColFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& vol, framework::Tensor& col, + int dilation_d, int dilation_h, int dilation_w, int stride_depth, int stride_height, int stride_width, int padding_depth, int padding_height, int padding_width) const { @@ -48,6 +49,28 @@ class Vol2ColFunctor { int channels_col = input_channels * filter_depth * filter_height * filter_width; + PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - + ((dilation_d * (filter_depth - 1) + 1))) / + stride_depth + + 1, + output_depth, + "input_depth and output_depth are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - + ((dilation_h * (filter_height - 1) + 1))) / + stride_height + + 1, + output_height, + "input_height and output_height are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - + ((dilation_w * (filter_width - 1) + 1))) / + stride_width + + 1, + output_width, + "input_width and output_width are " + "Mismatching."); + const T* vol_data = vol.data(); T* col_data = col.data(); @@ -57,24 +80,25 @@ class Vol2ColFunctor { int d_offset = (c / filter_width / filter_height) % filter_depth; int c_in = c / filter_width / filter_height / filter_depth; for (int d = 0; d < output_depth; ++d) { - int d_pad = d * stride_depth - padding_depth + d_offset; + int d_pad = d * stride_depth - padding_depth + d_offset * dilation_d; for (int h = 0; h < output_height; ++h) { - int h_pad = h * stride_height - padding_height + h_offset; + int h_pad = + h * stride_height - padding_height + h_offset * dilation_h; for (int w = 0; w < output_width; ++w) { - int w_pad = w * stride_width - padding_width + w_offset; + int w_pad = + w * stride_width - padding_width + w_offset * dilation_w; int col_idx = ((c * output_depth + d) * output_height + h) * output_width + w; - if (h_pad < 0 || h_pad >= input_height || w_pad < 0 || - w_pad >= input_width || d_pad < 0 || d_pad >= input_depth) { - col_data[col_idx] = static_cast(0); - } else { - int vol_idx = - ((c_in * input_depth + d_pad) * input_height + h_pad) * - input_width + - w_pad; - col_data[col_idx] = vol_data[vol_idx]; - } + int vol_idx = + ((c_in * input_depth + d_pad) * input_height + h_pad) * + input_width + + w_pad; + col_data[col_idx] = + (h_pad < 0 || h_pad >= input_height || w_pad < 0 || + w_pad >= input_width || d_pad < 0 || d_pad >= input_depth) + ? static_cast(0) + : vol_data[vol_idx]; } } } @@ -93,6 +117,7 @@ class Col2VolFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& vol, const framework::Tensor& col, + int dilation_d, int dilation_h, int dilation_w, int stride_depth, int stride_height, int stride_width, int padding_depth, int padding_height, int padding_width) const { @@ -112,6 +137,27 @@ class Col2VolFunctor { int channels_col = input_channels * filter_depth * filter_height * filter_width; + PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - + ((dilation_d * (filter_depth - 1) + 1))) / + stride_depth + + 1, + output_depth, + "input_depth and output_depth are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - + ((dilation_h * (filter_height - 1) + 1))) / + stride_height + + 1, + output_height, + "input_height and output_height are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - + ((dilation_w * (filter_width - 1) + 1))) / + stride_width + + 1, + output_width, + "input_width and output_width are " + "Mismatching."); T* vol_data = vol.data(); const T* col_data = col.data(); @@ -121,11 +167,13 @@ class Col2VolFunctor { int d_offset = (c / filter_width / filter_height) % filter_depth; int cIm = c / filter_width / filter_height / filter_depth; for (int d = 0; d < output_depth; ++d) { - int d_pad = d * stride_depth - padding_depth + d_offset; + int d_pad = d * stride_depth - padding_depth + d_offset * dilation_d; for (int h = 0; h < output_height; ++h) { - int h_pad = h * stride_height - padding_height + h_offset; + int h_pad = + h * stride_height - padding_height + h_offset * dilation_h; for (int w = 0; w < output_width; ++w) { - int w_pad = w * stride_width - padding_width + w_offset; + int w_pad = + w * stride_width - padding_width + w_offset * dilation_w; if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 && w_pad < input_width && d_pad >= 0 && d_pad < input_depth) { diff --git a/paddle/operators/math/vol2col.cu b/paddle/operators/math/vol2col.cu index 27b11fb237..080d3e5466 100644 --- a/paddle/operators/math/vol2col.cu +++ b/paddle/operators/math/vol2col.cu @@ -21,11 +21,12 @@ namespace math { template __global__ void vol2col(int num_kernels, const T* data_vol, int depth, - int height, int width, int filter_depth, - int filter_height, int filter_width, int stride_depth, - int stride_height, int stride_width, int padding_depth, - int padding_height, int padding_width, int output_detph, - int output_height, int output_width, T* data_col) { + int height, int width, int dilation_d, int dilation_h, + int dilation_w, int filter_depth, int filter_height, + int filter_width, int stride_depth, int stride_height, + int stride_width, int padding_depth, int padding_height, + int padding_width, int output_detph, int output_height, + int output_width, T* data_col) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; index += blockDim.x * gridDim.x) { int w_out = index % output_width; @@ -44,12 +45,14 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth, for (int k = 0; k < filter_depth; ++k) { for (int i = 0; i < filter_height; ++i) { for (int j = 0; j < filter_width; ++j) { - int d = d_in + k; - int h = h_in + i; - int w = w_in + j; + int d = d_in + k * dilation_d; + int h = h_in + i * dilation_h; + int w = w_in + j * dilation_w; + int col_idx = (k * dilation_d * height + i * dilation_h) * width + + j * dilation_w; *data_col = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 && w < width) - ? data_vol[(k * height + i) * width + j] + ? data_vol[col_idx] : 0; data_col += output_detph * output_height * output_width; } @@ -69,6 +72,7 @@ class Vol2ColFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& vol, framework::Tensor& col, + int dilation_d, int dilation_h, int dilation_w, int stride_depth, int stride_height, int stride_width, int padding_depth, int padding_height, int padding_width) const { @@ -86,6 +90,28 @@ class Vol2ColFunctor { int output_height = col.dims()[5]; int output_width = col.dims()[6]; + PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - + ((dilation_d * (filter_depth - 1) + 1))) / + stride_depth + + 1, + output_depth, + "input_depth and output_depth are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - + ((dilation_h * (filter_height - 1) + 1))) / + stride_height + + 1, + output_height, + "input_height and output_height are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - + ((dilation_w * (filter_width - 1) + 1))) / + stride_width + + 1, + output_width, + "input_width and output_width are " + "Mismatching."); + int num_outputs = input_channels * output_depth * output_height * output_width; @@ -95,19 +121,25 @@ class Vol2ColFunctor { reinterpret_cast(context) .stream()>>>( num_outputs, vol.data(), input_depth, input_height, input_width, - filter_depth, filter_height, filter_width, stride_depth, stride_height, - stride_width, padding_depth, padding_height, padding_width, - output_depth, output_height, output_width, col.data()); + dilation_d, dilation_h, dilation_w, filter_depth, filter_height, + filter_width, stride_depth, stride_height, stride_width, padding_depth, + padding_height, padding_width, output_depth, output_height, + output_width, col.data()); } }; template __global__ void col2vol(int num_kernels, const T* data_col, int depth, - int height, int width, int filter_depth, - int filter_height, int filter_width, int stride_depth, - int stride_height, int stride_width, int padding_depth, - int padding_height, int padding_width, int output_detph, - int output_height, int output_width, T* data_vol) { + int height, int width, int dilation_d, int dilation_h, + int dilation_w, int filter_depth, int filter_height, + int filter_width, int stride_depth, int stride_height, + int stride_width, int padding_depth, int padding_height, + int padding_width, int output_detph, int output_height, + int output_width, T* data_vol) { + const int d_filter_depth = dilation_d * (filter_depth - 1) + 1; + const int d_filter_height = dilation_h * (filter_height - 1) + 1; + const int d_filter_width = dilation_w * (filter_width - 1) + 1; + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; index += blockDim.x * gridDim.x) { T src_val = 0; @@ -115,35 +147,42 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth, int h = (index / width) % height + padding_height; int d = (index / width / height) % depth + padding_depth; int c = index / width / height / depth; + // compute the start and end of the output int w_col_start = - (w < filter_width) ? 0 : (w - filter_width) / stride_width + 1; + (w < d_filter_width) ? 0 : (w - d_filter_width) / stride_width + 1; int w_col_end = min(w / stride_width + 1, output_width); int h_col_start = - (h < filter_height) ? 0 : (h - filter_height) / stride_height + 1; + (h < d_filter_height) ? 0 : (h - d_filter_height) / stride_height + 1; int h_col_end = min(h / stride_height + 1, output_height); int d_col_start = - (d < filter_depth) ? 0 : (d - filter_depth) / stride_depth + 1; + (d < d_filter_depth) ? 0 : (d - d_filter_depth) / stride_depth + 1; int d_col_end = min(d / stride_depth + 1, output_detph); - int offset = (c * filter_depth * filter_height * filter_width + - d * filter_width * filter_height + h * filter_width + w) * - output_detph * output_height * output_width; - - int coeff_d_col = - (1 - stride_depth * filter_width * filter_height * output_detph) * - output_height * output_width; - int coeff_h_col = - (1 - stride_height * filter_width * output_detph * output_height) * - output_width; - int coeff_w_col = - (1 - stride_width * output_detph * output_height * output_width); - for (int d_col = d_col_start; d_col < d_col_end; ++d_col) { for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { - src_val += data_col[offset + d_col * coeff_d_col + - h_col * coeff_h_col + w_col * coeff_w_col]; + int d_off = (d - d_col * stride_depth); + int h_off = (h - h_col * stride_height); + int w_off = (w - w_col * stride_width); + if (d_off % dilation_d == 0 && h_off % dilation_h == 0 && + w_off % dilation_w == 0) { + d_off /= dilation_d; + h_off /= dilation_h; + w_off /= dilation_w; + + int data_col_index = + (((((c * filter_depth + d_off) * filter_height + h_off) * + filter_width + + w_off) * + output_detph + + d_col) * + output_height + + h_col) * + output_width + + w_col; + src_val += data_col[data_col_index]; + } } } } @@ -162,6 +201,7 @@ class Col2VolFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& vol, const framework::Tensor& col, + int dilation_d, int dilation_h, int dilation_w, int stride_depth, int stride_height, int stride_width, int padding_depth, int padding_height, int padding_width) const { @@ -179,6 +219,28 @@ class Col2VolFunctor { int output_height = col.dims()[5]; int output_width = col.dims()[6]; + PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - + ((dilation_d * (filter_depth - 1) + 1))) / + stride_depth + + 1, + output_depth, + "input_depth and output_depth are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - + ((dilation_h * (filter_height - 1) + 1))) / + stride_height + + 1, + output_height, + "input_height and output_height are " + "Mismatching."); + PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - + ((dilation_w * (filter_width - 1) + 1))) / + stride_width + + 1, + output_width, + "input_width and output_width are " + "Mismatching."); + int num_kernels = input_channels * input_depth * input_height * input_width; const int threads = 1024; @@ -188,9 +250,10 @@ class Col2VolFunctor { reinterpret_cast(context) .stream()>>>( num_kernels, col.data(), input_depth, input_height, input_width, - filter_depth, filter_height, filter_width, stride_depth, stride_height, - stride_width, padding_depth, padding_height, padding_width, - output_depth, output_height, output_width, vol.data()); + dilation_d, dilation_h, dilation_w, filter_depth, filter_height, + filter_width, stride_depth, stride_height, stride_width, padding_depth, + padding_height, padding_width, output_depth, output_height, + output_width, vol.data()); } }; diff --git a/paddle/operators/math/vol2col.h b/paddle/operators/math/vol2col.h index f022365a16..c2d8257c0b 100644 --- a/paddle/operators/math/vol2col.h +++ b/paddle/operators/math/vol2col.h @@ -58,6 +58,7 @@ class Vol2ColFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& vol, framework::Tensor& col, + int dilation_d, int dilation_h, int dilation_w, int stride_depth, int stride_height, int stride_width, int padding_depth, int padding_height, int padding_width) const; @@ -68,6 +69,7 @@ class Col2VolFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& vol, const framework::Tensor& col, + int dilation_d, int dilation_h, int dilation_w, int stride_depth, int stride_height, int stride_width, int padding_depth, int padding_height, int padding_width) const; diff --git a/paddle/operators/math/vol2col_test.cc b/paddle/operators/math/vol2col_test.cc index 74590d17cd..9d673ad36c 100644 --- a/paddle/operators/math/vol2col_test.cc +++ b/paddle/operators/math/vol2col_test.cc @@ -64,6 +64,7 @@ void testVol2col() { int filter_size = 2; int stride = 1; int padding = 0; + int dilation = 1; int output_depth = (input_depth - filter_size + 2 * padding) / stride + 1; int output_height = (input_height - filter_size + 2 * padding) / stride + 1; int output_width = (input_width - filter_size + 2 * padding) / stride + 1; @@ -85,8 +86,8 @@ void testVol2col() { *place); paddle::operators::math::Vol2ColFunctor vol2col; - vol2col(*context, input, output, stride, stride, stride, padding, padding, - padding); + vol2col(*context, input, output, dilation, dilation, dilation, stride, stride, + stride, padding, padding, padding); float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11}; float* out_cfo_ptr; @@ -111,8 +112,8 @@ void testVol2col() { } paddle::operators::math::Col2VolFunctor col2vol; - col2vol(*context, input, output, stride, stride, stride, padding, padding, - padding); + col2vol(*context, input, output, dilation, dilation, dilation, stride, stride, + stride, padding, padding, padding); float* in_ptr; if (paddle::platform::is_cpu_place(*place)) { -- GitLab