From 61dbf4b926a3f911b55b7d23e35d7ddc4f39a414 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 24 Oct 2017 13:45:39 +0800 Subject: [PATCH] fix im2col interface --- paddle/operators/conv2dtranspose_op.h | 6 +++--- paddle/operators/math/im2col.h | 5 ++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/paddle/operators/conv2dtranspose_op.h b/paddle/operators/conv2dtranspose_op.h index 293b7ce9ba..8c70b3dcec 100644 --- a/paddle/operators/conv2dtranspose_op.h +++ b/paddle/operators/conv2dtranspose_op.h @@ -120,7 +120,7 @@ class GemmConv2DTransposeKernel : public framework::OpKernel { math::matmul(context.device_context(), filter, true, input_batch, false, T(1.0), &col_matrix, T(0.0)); col2im(context.device_context(), output_batch, col, strides[0], - strides[1], 0, 0); + strides[1], 0, 0, 0, 0); } } }; @@ -206,7 +206,7 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel { // im2col: dy from (c, o_h, o_w) -> (c * k_h * k_w, h * w) im2col(context.device_context(), output_grad_batch, col, strides[0], - strides[1], paddings[0], paddings[1]); + strides[1], paddings[0], paddings[0], paddings[1], paddings[1]); // gemm: dx = filter * dy // (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, c, h) @@ -238,7 +238,7 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel { // im2col: (c * h * w, k_h * k_w) im2col(context.device_context(), output_grad_batch, col, strides[0], - strides[1], paddings[0], paddings[1]); + strides[1], paddings[0], paddings[0], paddings[1], paddings[1]); // gemm: d_filter = x * y_grad^T // (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, c, h) diff --git a/paddle/operators/math/im2col.h b/paddle/operators/math/im2col.h index a1cb956c51..c736d4fa52 100644 --- a/paddle/operators/math/im2col.h +++ b/paddle/operators/math/im2col.h @@ -75,8 +75,7 @@ class Im2ColFunctor { void operator()(const platform::DeviceContext& context, const framework::Tensor& im, framework::Tensor& col, int stride_height, int stride_width, int padding_up, - int padding_down, int padding_left = 0, - int padding_right = 0); + int padding_down, int padding_left, int padding_right); }; template @@ -85,7 +84,7 @@ class Col2ImFunctor { void operator()(const platform::DeviceContext& context, framework::Tensor& im, const framework::Tensor& col, int stride_height, int stride_width, int padding_up, int padding_down, - int padding_left = 0, int padding_right = 0); + int padding_left, int padding_right); }; } // namespace math -- GitLab