diff --git a/paddle/operators/conv2d_op.h b/paddle/operators/conv2d_op.h index f629728f68d65ce81b4910cae7f89ab06d6d94b8..0621389a79eee6b5e75b1eab309b49f8aa4a97ca 100644 --- a/paddle/operators/conv2d_op.h +++ b/paddle/operators/conv2d_op.h @@ -114,7 +114,7 @@ class GemmConv2DKernel : public framework::OpKernel { // im2col Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); im2col(context.device_context(), in_slice, col, strides[0], strides[1], - paddings[0], paddings[1]); + paddings[0], paddings[0], paddings[1], paddings[1]); // gemm Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); @@ -213,7 +213,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel { Tensor in_grad_slice = in_grad_batch.Slice(g * in_step, (g + 1) * in_step); col2im(context.device_context(), in_grad_slice, col, strides[0], - strides[1], paddings[0], paddings[1]); + strides[1], paddings[0], paddings[0], paddings[1], + paddings[1]); } } } @@ -235,7 +236,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel { out_grad_batch.Slice(g * out_step, (g + 1) * out_step); Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); im2col(context.device_context(), in_slice, col, strides[0], - strides[1], paddings[0], paddings[1]); + strides[1], paddings[0], paddings[0], paddings[1], + paddings[1]); // gemm Tensor filter_grad_slice = diff --git a/paddle/operators/math/im2col.cc b/paddle/operators/math/im2col.cc index c08a3380f042886cd400df0d840e61856274619c..3b1b0bd71dd3768b932864e185af8dc839b4653e 100644 --- a/paddle/operators/math/im2col.cc +++ b/paddle/operators/math/im2col.cc @@ -29,8 +29,8 @@ class Im2ColFunctor(); @@ -52,16 +68,14 @@ class Im2ColFunctor= input_height || - (im_col_idx - padding_width) < 0 || - (im_col_idx - padding_width) >= input_width) { + int im_row_idx = h * stride_height + h_offset - padding_up; + int im_col_idx = w * stride_width + w_offset - padding_left; + + if (im_row_idx < 0 || im_row_idx >= input_height || im_col_idx < 0 || + im_col_idx >= input_width) { col_data[(c * output_height + h) * output_width + w] = T(0); } else { - im_row_idx += c_im * input_height - padding_height; - im_col_idx -= padding_width; + im_row_idx += c_im * input_height; col_data[(c * output_height + h) * output_width + w] = im_data[im_row_idx * input_width + im_col_idx]; } @@ -82,7 +96,8 @@ class Col2ImFunctor(); @@ -103,14 +134,12 @@ class Col2ImFunctor= 0 && - (im_row_idx - padding_height) < input_height && - (im_col_idx - padding_width) >= 0 && - (im_col_idx - padding_width) < input_width) { - im_row_idx += c_im * input_height - padding_height; - im_col_idx -= padding_width; + int im_row_idx = h * stride_height + h_offset - padding_up; + int im_col_idx = w * stride_width + w_offset - padding_left; + + if ((im_row_idx) >= 0 && (im_row_idx) < input_height && + (im_col_idx) >= 0 && (im_col_idx) < input_width) { + im_row_idx += c_im * input_height; im_data[im_row_idx * input_width + im_col_idx] += col_data[(c * output_height + h) * output_width + w]; } @@ -140,8 +169,8 @@ class Im2ColFunctor(); T* col_data = col.data(); @@ -163,10 +207,10 @@ class Im2ColFunctor(); const T* col_data = col.data(); @@ -223,9 +283,9 @@ class Col2ImFunctor(context) .stream()>>>( im.data(), num_outputs, input_height, input_width, filter_height, - filter_width, stride_height, stride_width, padding_height, - padding_width, output_height, output_width, col.data()); + filter_width, stride_height, stride_width, padding_up, padding_left, + output_height, output_width, col.data()); } }; @@ -152,7 +161,8 @@ class Col2ImFunctor<<(context) .stream()>>>( - num_kernels, col.data(), input_height + 2 * padding_height, - input_width + 2 * padding_width, input_channels, filter_height, - filter_width, stride_height, stride_width, padding_height, - padding_width, output_height, output_width, im.data()); + num_kernels, col.data(), input_height + padding_up + padding_down, + input_width + padding_left + padding_left, input_channels, + filter_height, filter_width, stride_height, stride_width, padding_up, + padding_left, output_height, output_width, im.data()); } }; @@ -238,8 +258,8 @@ class Im2ColFunctor(context) .stream()>>>( im.data(), col.data(), input_channels, input_height, input_width, - filter_height, filter_width, stride_height, stride_width, - padding_height, padding_width, output_height, output_width); + filter_height, filter_width, stride_height, stride_width, padding_up, + padding_left, output_height, output_width); } }; @@ -322,7 +351,8 @@ class Col2ImFunctor(context) .stream()>>>( im.data(), col.data(), input_channels, input_height, input_width, - filter_height, filter_width, stride_height, stride_width, - padding_height, padding_width, output_height, output_width); + filter_height, filter_width, stride_height, stride_width, padding_up, + padding_left, output_height, output_width); } }; diff --git a/paddle/operators/math/im2col.h b/paddle/operators/math/im2col.h index 7b717e1603c94cd77c74cb0d86f1d23e2692f9d8..c736d4fa523c2af3e3dd7a11114d7f84021bc5c1 100644 --- a/paddle/operators/math/im2col.h +++ b/paddle/operators/math/im2col.h @@ -74,8 +74,8 @@ class Im2ColFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& im, framework::Tensor& col, - int stride_height, int stride_width, int padding_height, - int padding_width); + int stride_height, int stride_width, int padding_up, + int padding_down, int padding_left, int padding_right); }; template @@ -83,7 +83,8 @@ class Col2ImFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& im, const framework::Tensor& col, int stride_height, - int stride_width, int padding_height, int padding_width); + int stride_width, int padding_up, int padding_down, + int padding_left, int padding_right); }; } // namespace math diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc index 443c94b83f0bf24837afe703b19e2ab47a0dd786..5763782c4edec87f44dabef2ccffe3097eeb2421 100644 --- a/paddle/operators/math/im2col_test.cc +++ b/paddle/operators/math/im2col_test.cc @@ -35,6 +35,12 @@ void testIm2col() { * * output_ocf = [0, 1, 3, 4 * 1, 2, 4, 5] + * + * col2im_cfo = [0, 2, 2 + * 3, 4, 5] + * + * col2im_ocf = [0, 2, 2 + * 3, 4, 5] */ int input_height = 2; int input_width = 3; @@ -59,7 +65,7 @@ void testIm2col() { new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace()); #else PADDLE_THROW("no GPU support"); -#endif // PADDLE_ONLY_CPU +#endif // PADDLE_WITH_CUDA } if (paddle::platform::is_cpu_place(*place)) { input = input_tmp; @@ -71,6 +77,7 @@ void testIm2col() { output_ocf.mutable_data( {output_height, output_width, 1, filter_size, filter_size}, *place); + // Im2Col paddle::operators::math::Im2ColFunctor< paddle::operators::math::ColFormat::kCFO, Place, float> im2col; @@ -78,8 +85,13 @@ void testIm2col() { paddle::operators::math::ColFormat::kOCF, Place, float> im2col_ocf; - im2col(*context, input, output_cfo, stride, stride, padding, padding); - im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding); + im2col(*context, input, output_cfo, stride, stride, padding, padding, padding, + padding); + im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding, + padding, padding); + + float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5}; + float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5}; float* out_cfo_ptr; if (paddle::platform::is_cpu_place(*place)) { @@ -88,14 +100,9 @@ void testIm2col() { output_tmp.CopyFrom(output_cfo, paddle::platform::CPUPlace(), *context); out_cfo_ptr = output_tmp.data(); } - EXPECT_EQ(out_cfo_ptr[0], 0); - EXPECT_EQ(out_cfo_ptr[1], 1); - EXPECT_EQ(out_cfo_ptr[2], 1); - EXPECT_EQ(out_cfo_ptr[3], 2); - EXPECT_EQ(out_cfo_ptr[4], 3); - EXPECT_EQ(out_cfo_ptr[5], 4); - EXPECT_EQ(out_cfo_ptr[6], 4); - EXPECT_EQ(out_cfo_ptr[7], 5); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(out_cfo_ptr[i], out_cfo_data[i]); + } float* out_ocf_ptr; if (paddle::platform::is_cpu_place(*place)) { @@ -104,14 +111,60 @@ void testIm2col() { output_tmp.CopyFrom(output_ocf, paddle::platform::CPUPlace(), *context); out_ocf_ptr = output_tmp.data(); } - EXPECT_EQ(out_ocf_ptr[0], 0); - EXPECT_EQ(out_ocf_ptr[1], 1); - EXPECT_EQ(out_ocf_ptr[2], 3); - EXPECT_EQ(out_ocf_ptr[3], 4); - EXPECT_EQ(out_ocf_ptr[4], 1); - EXPECT_EQ(out_ocf_ptr[5], 2); - EXPECT_EQ(out_ocf_ptr[6], 4); - EXPECT_EQ(out_ocf_ptr[7], 5); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(out_ocf_ptr[i], out_ocf_data[i]); + } + + // Col2Im: kCFO + paddle::operators::math::Col2ImFunctor< + paddle::operators::math::ColFormat::kCFO, Place, float> + col2im; + paddle::operators::math::Col2ImFunctor< + paddle::operators::math::ColFormat::kOCF, Place, float> + col2im_ocf; + float col2im_data[] = {0, 2, 2, 3, 8, 5}; + + memset(input_ptr, 0, 6 * sizeof(float)); + if (paddle::platform::is_cpu_place(*place)) { + input = input_tmp; + } else { + input.CopyFrom(input_tmp, *place, *context); + } + + col2im(*context, input, output_cfo, stride, stride, padding, padding, padding, + padding); + + float* in_ptr; + if (paddle::platform::is_cpu_place(*place)) { + in_ptr = input.data(); + } else { + input_tmp.CopyFrom(input, paddle::platform::CPUPlace(), *context); + in_ptr = input_tmp.data(); + } + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(in_ptr[i], col2im_data[i]); + } + + // Col2Im: kOCF + memset(input_ptr, 0, 6 * sizeof(float)); + if (paddle::platform::is_cpu_place(*place)) { + input = input_tmp; + } else { + input.CopyFrom(input_tmp, *place, *context); + } + + col2im_ocf(*context, input, output_ocf, stride, stride, padding, padding, + padding, padding); + + if (paddle::platform::is_cpu_place(*place)) { + in_ptr = input.data(); + } else { + input_tmp.CopyFrom(input, paddle::platform::CPUPlace(), *context); + in_ptr = input_tmp.data(); + } + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(in_ptr[i], col2im_data[i]); + } } TEST(math, im2col) {