From aa770198c72c115310e6075ebd403878154fbf0f Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 8 Dec 2017 17:36:46 +0800 Subject: [PATCH] add dilation in c++ code --- paddle/operators/conv_transpose_op.cc | 7 ++++++- paddle/operators/conv_transpose_op.h | 14 ++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/paddle/operators/conv_transpose_op.cc b/paddle/operators/conv_transpose_op.cc index e900ad452ea..c31a2e4a708 100644 --- a/paddle/operators/conv_transpose_op.cc +++ b/paddle/operators/conv_transpose_op.cc @@ -29,6 +29,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { auto filter_dims = ctx->GetInputDim("Filter"); std::vector strides = ctx->Attrs().Get>("strides"); std::vector paddings = ctx->Attrs().Get>("paddings"); + std::vector dilations = ctx->Attrs().Get>("dilations"); PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5, "ConvTransposeOp intput should be 4-D or 5-D tensor."); @@ -41,14 +42,18 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE_EQ(paddings.size(), strides.size(), "ConvTransposeOp paddings dimension and strides " "dimension should be the same."); + PADDLE_ENFORCE_EQ(paddings.size(), dilations.size(), + "ConvTransposeOp paddings dimension and dilations " + "dimension should be the same."); PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0], "In ConvTransposeOp, The input channel should be the same " "as the number of filters."); std::vector output_shape({in_dims[0], filter_dims[1]}); for (size_t i = 0; i < strides.size(); ++i) { + auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1; output_shape.push_back((in_dims[i + 2] - 1) * strides[i] - 2 * paddings[i] + - filter_dims[i + 2]); + filter_extent); } ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); } diff --git a/paddle/operators/conv_transpose_op.h b/paddle/operators/conv_transpose_op.h index 1cacb770e6a..65a0076d9ca 100644 --- a/paddle/operators/conv_transpose_op.h +++ b/paddle/operators/conv_transpose_op.h @@ -63,6 +63,7 @@ class GemmConvTransposeKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); + std::vector dilations = context.Attr>("dilations"); // groups will alway be disabled in conv2dtranspose. const int batch_size = static_cast(input->dims()[0]); @@ -114,7 +115,6 @@ class GemmConvTransposeKernel : public framework::OpKernel { math::Col2ImFunctor col2im; math::Col2VolFunctor col2vol; - std::vector dilations({1, 1, 1}); // convolution transpose: gemm + col2im or col2vol (similar to conv-backward // on input) @@ -134,8 +134,7 @@ class GemmConvTransposeKernel : public framework::OpKernel { if (data_dim == 2U) { // col2im: col_matrix -> dy // from (c * k_h * k_w, h * w) to (c, o_h, o_w) - col2im(context.device_context(), col, - std::vector{dilations[0], dilations[1]}, strides, + col2im(context.device_context(), col, dilations, strides, std::vector{paddings[0], paddings[1], paddings[0], paddings[1]}, &output_batch); @@ -168,6 +167,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); + std::vector dilations = context.Attr>("dilations"); const int batch_size = static_cast(input->dims()[0]); @@ -221,7 +221,6 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { math::Im2ColFunctor im2col; math::Vol2ColFunctor vol2col; - std::vector dilations({1, 1, 1}); if (input_grad) { input_grad->mutable_data(context.GetPlace()); @@ -242,10 +241,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { if (data_dim == 2U) { // im2col: dy -> col matrix // from (c, o_h, o_w) to (c * k_h * k_w, h * w) - im2col(context.device_context(), output_grad_batch, - std::vector{dilations[0], dilations[1]}, strides, - std::vector{paddings[0], paddings[1], paddings[0], - paddings[1]}, + im2col(context.device_context(), output_grad_batch, dilations, + strides, std::vector{paddings[0], paddings[1], + paddings[0], paddings[1]}, &col); } else if (data_dim == 3U) { // vol2col: dy -> col_matrix -- GitLab