From 64c5ecbedba5bfb5eea3a5fbed63ed628968a042 Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Fri, 20 Oct 2017 14:46:30 -0700 Subject: [PATCH] deconv --- paddle/operators/deconv2d_op.cc | 52 +++++++------- paddle/operators/deconv2d_op.cu | 7 +- paddle/operators/deconv2d_op.h | 118 ++++++++++++++++---------------- 3 files changed, 92 insertions(+), 85 deletions(-) diff --git a/paddle/operators/deconv2d_op.cc b/paddle/operators/deconv2d_op.cc index 8481aefdc1a..98a47f02b41 100644 --- a/paddle/operators/deconv2d_op.cc +++ b/paddle/operators/deconv2d_op.cc @@ -18,13 +18,13 @@ namespace paddle { namespace operators { -void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const { +void Conv2DTransposeOp::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input(Input) of Deconv2DOp should not be null."); + "Input(Input) of Conv2DTransposeOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Filter"), - "Input(Filter) of Deconv2DOp should not be null."); + "Input(Filter) of Conv2DTransposeOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Output"), - "Output(Output) of Deconv2DOp should not be null."); + "Output(Output) of Conv2DTransposeOp should not be null."); auto in_dims = ctx->GetInputDim("Input"); auto filter_dims = ctx->GetInputDim("Filter"); @@ -32,13 +32,14 @@ void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const { std::vector paddings = ctx->Attrs().Get>("paddings"); for (size_t i = 0; i < paddings.size(); ++i) { - PADDLE_ENFORCE_EQ(paddings[i], 0, "No Padding allowed in deconv op."); + PADDLE_ENFORCE_EQ(paddings[i], 0, + "No Padding allowed in conv transpose op."); } PADDLE_ENFORCE_EQ(in_dims.size(), 4, - "Deconv2DOp input should be 4-D tensor."); + "Conv2DTransposeOp input should be 4-D tensor."); PADDLE_ENFORCE_EQ(filter_dims.size(), 4, - "Deconv2DOp filter should be 4-D tensor."); + "Conv2DTransposeOp filter should be 4-D tensor."); PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0], "input and kernel input dimension should be equal."); @@ -48,36 +49,39 @@ void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const { {in_dims[0], filter_dims[1], output_height, output_width}); } -Deconv2DOpMaker::Deconv2DOpMaker(framework::OpProto* proto, - framework::OpAttrChecker* op_checker) +Conv2DTransposeOpMaker::Conv2DTransposeOpMaker( + framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput( "Input", - "The input tensor of deconvolution operator. " + "The input tensor of convolution transpose operator. " "The format of input tensor is NCHW. Where N is batch size, C is the " "number of input channels, H and W is the height and width of image."); AddInput("Filter", - "The filter tensor of deconvolution operator." - "The format of the filter tensor is MCHW, where C is the number of " + "The filter tensor of convolution transpose operator." + "The format of the filter tensor is CMHW, where C is the number of " "output image channels, M is the number of input image channels, " "H and W is height and width of filter. " "We enforce groups number == 1 and padding == 0 in " - "deconvolution Scenario."); + "convolution transpose Scenario."); AddOutput("Output", - "The output tensor of deconvolution operator." + "The output tensor of convolution transpose operator." "The format of output tensor is also NCHW."); - AddAttr>("strides", "strides of deconvolution operator.") + AddAttr>("strides", + "strides of convolution transpose operator.") .SetDefault({1, 1}); - AddAttr>("paddings", "paddings of deconvolution operator.") + AddAttr>("paddings", + "paddings of convolution transpose operator.") .SetDefault({0, 0}); AddComment(R"DOC( -The deconvolution operation calculates the output based on the input, filter +The convolution transpose operation calculates the output based on the input, filter and strides, paddings, groups parameters. The size of each dimension of the parameters is checked in the infer-shape. )DOC"); } -void Deconv2DOpGrad::InferShape(framework::InferShapeContext* ctx) const { +void Conv2DTransposeOpGrad::InferShape( + framework::InferShapeContext* ctx) const { auto in_dims = ctx->GetInputDim("Input"); auto filter_dims = ctx->GetInputDim("Filter"); if (ctx->HasOutput(framework::GradVarName("Input"))) { @@ -92,11 +96,13 @@ void Deconv2DOpGrad::InferShape(framework::InferShapeContext* ctx) const { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(deconv2d, ops::Deconv2DOp, ops::Deconv2DOpMaker, deconv2d_grad, - ops::Deconv2DOpGrad); +REGISTER_OP(conv2dtranspose, ops::Conv2DTransposeOp, + ops::Conv2DTransposeOpMaker, conv2dtranspose_grad, + ops::Conv2DTransposeOpGrad); REGISTER_OP_CPU_KERNEL( - deconv2d, ops::GemmDeconv2DKernel); + conv2dtranspose, + ops::GemmConv2DTransposeKernel); REGISTER_OP_CPU_KERNEL( - deconv2d_grad, - ops::GemmDeconvGrad2DKernel); + conv2dtranspose_grad, + ops::GemmConv2DTransposeGradKernel); diff --git a/paddle/operators/deconv2d_op.cu b/paddle/operators/deconv2d_op.cu index b117e7eeef8..660ec32e353 100644 --- a/paddle/operators/deconv2d_op.cu +++ b/paddle/operators/deconv2d_op.cu @@ -17,7 +17,8 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( - deconv2d, ops::GemmDeconv2DKernel); + conv2dtranspose, + ops::GemmConv2DTransposeKernel); REGISTER_OP_GPU_KERNEL( - deconv2d_grad, - ops::GemmDeconvGrad2DKernel); + conv2dtranspose_grad, + ops::GemmConv2DTransposeGradKernel); diff --git a/paddle/operators/deconv2d_op.h b/paddle/operators/deconv2d_op.h index 973190efab4..91bf6193b21 100644 --- a/paddle/operators/deconv2d_op.h +++ b/paddle/operators/deconv2d_op.h @@ -26,15 +26,15 @@ namespace operators { using Tensor = framework::Tensor; using DDim = framework::DDim; -// Define Op classes in .h file so that other deconv +// Define Op classes in .h file so that other conv transpose // operator implementations can reuse the code. -class Deconv2DOpMaker : public framework::OpProtoAndCheckerMaker { +class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker { public: - Deconv2DOpMaker(framework::OpProto* proto, - framework::OpAttrChecker* op_checker); + Conv2DTransposeOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker); }; -class Deconv2DOp : public framework::OperatorWithKernel { +class Conv2DTransposeOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -42,7 +42,7 @@ class Deconv2DOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override; }; -class Deconv2DOpGrad : public framework::OperatorWithKernel { +class Conv2DTransposeOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -51,7 +51,7 @@ class Deconv2DOpGrad : public framework::OperatorWithKernel { }; template -class GemmDeconv2DKernel : public framework::OpKernel { +class GemmConv2DTransposeKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const Tensor* input = context.Input("Input"); @@ -64,27 +64,27 @@ class GemmDeconv2DKernel : public framework::OpKernel { // no paddings and groups allowed in deconv - int N = input->dims()[0]; - int M = input->dims()[1]; - int H = input->dims()[2]; - int W = input->dims()[3]; + const int batch_size = input->dims()[0]; + const int m = input->dims()[1]; + const int h = input->dims()[2]; + const int w = input->dims()[3]; - int K_H = filter.dims()[2]; - int K_W = filter.dims()[3]; + const int k_h = filter.dims()[2]; + const int k_w = filter.dims()[3]; - int C = output->dims()[1]; // output channels - int O_H = output->dims()[2]; - int O_W = output->dims()[3]; + const int c = output->dims()[1]; // output channels + const int o_h = output->dims()[2]; + const int o_w = output->dims()[3]; paddle::operators::math::Col2ImFunctor< paddle::operators::math::ColFormat::kCFO, Place, T> col2im; // use col_shape in the im2col and col2im calculation - DDim col_shape = {C, K_H, K_W, H, W}; + DDim col_shape = {c, k_h, k_w, h, w}; // use col_matrix_shape in the gemm calculation - DDim col_matrix_shape = {C * K_H * K_W, H * W}; + DDim col_matrix_shape = {c * k_h * k_w, h * w}; Tensor col; col.mutable_data(col_shape, context.GetPlace()); @@ -94,10 +94,10 @@ class GemmDeconv2DKernel : public framework::OpKernel { Tensor col_matrix = col; col_matrix.Resize(col_matrix_shape); - DDim output_shape = {C, O_H, O_W}; - DDim input_matrix_shape = {M, H * W}; + DDim output_shape = {c, o_h, o_w}; + DDim input_matrix_shape = {m, h * w}; - DDim filter_matrix_shape = {M, C * K_H * K_W}; + DDim filter_matrix_shape = {m, c * k_h * k_w}; filter.Resize(filter_matrix_shape); // deconvolution: gemm + col2im (similar to conv-backward on input) @@ -106,16 +106,16 @@ class GemmDeconv2DKernel : public framework::OpKernel { auto t = framework::EigenVector::Flatten(*output); t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); - for (int i = 0; i < N; i++) { - // batch with size (M, H * W) - Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape); - // filter size: (M, C * K_H * K_W) + for (int i = 0; i < batch_size; i++) { + // batch with size (M, h * w) + Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape); + // filter size: (M, c * k_h * k_w) - // output size: (C, O_H, O_W) - Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape); + // output size: (c, o_h, o_w) + Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape); // col_matrix = filter * input_batch - // of shape (C * K_H * K_W, H * W) + // of shape (c * k_h * k_w, h * w) math::matmul(context.device_context(), filter, true, input_batch, false, T(1.0), &col_matrix, T(0.0)); col2im(context.device_context(), output_batch, col, strides[0], @@ -125,7 +125,7 @@ class GemmDeconv2DKernel : public framework::OpKernel { }; template -class GemmDeconvGrad2DKernel : public framework::OpKernel { +class GemmConv2DTransposeGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const Tensor* input = context.Input("Input"); @@ -145,17 +145,17 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel { // Actually, no paddings and groups allowed in deconv. std::vector paddings = context.Attr>("paddings"); - int N = input->dims()[0]; - int M = input->dims()[1]; - int H = input->dims()[2]; - int W = input->dims()[3]; + const int batch_size = input->dims()[0]; + const int m = input->dims()[1]; + const int h = input->dims()[2]; + const int w = input->dims()[3]; - int K_H = filter.dims()[2]; - int K_W = filter.dims()[3]; + const int k_h = filter.dims()[2]; + const int k_w = filter.dims()[3]; - int C = output_grad->dims()[1]; // output channels - int O_H = output_grad->dims()[2]; - int O_W = output_grad->dims()[3]; + const int c = output_grad->dims()[1]; // output channels + const int o_h = output_grad->dims()[2]; + const int o_w = output_grad->dims()[3]; // Only im2col functor required for bp to get to the right shape paddle::operators::math::Im2ColFunctor< @@ -163,10 +163,10 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel { im2col; // use col_shape in the im2col and col2im calculation - DDim col_shape = {C, K_H, K_W, H, W}; + DDim col_shape = {c, k_h, k_w, h, w}; // use col_matrix_shape in the gemm calculation - DDim col_matrix_shape_f = {C * H * W, K_H * K_W}; + DDim col_matrix_shape_f = {c * h * w, k_h * k_w}; Tensor col; col.mutable_data(col_shape, context.GetPlace()); @@ -174,10 +174,10 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel { // but will be reshaped into a two-dimensional matrix shape // to call the matrix multiplication interface. - DDim output_shape = {C, O_H, O_W}; - DDim input_matrix_shape = {M, H * W}; + DDim output_shape = {c, o_h, o_w}; + DDim input_matrix_shape = {m, h * w}; - DDim filter_matrix_shape = {M, C * K_H * K_W}; + DDim filter_matrix_shape = {m, c * k_h * k_w}; filter.Resize(filter_matrix_shape); // deconvolution grad on input: @@ -185,29 +185,29 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel { // input need to compute gradient if (input_grad) { Tensor col_matrix = col; - DDim col_matrix_shape = {C * K_H * K_W, H * W}; + DDim col_matrix_shape = {c * k_h * k_w, h * w}; col_matrix.Resize(col_matrix_shape); input_grad->mutable_data(context.GetPlace()); auto t = framework::EigenVector::Flatten(*input_grad); t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); - for (int i = 0; i < N; i++) { - // batch with size (C, O_H * O_W) + for (int i = 0; i < batch_size; i++) { + // batch with size (c, o_h * o_w) Tensor output_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_shape); - // filter of size (M, C * K_H * K_W) + output_grad->Slice(i, i + 1).Resize(output_shape); + // filter of size (m, c * k_h * k_w) - // batch with size (M, H, W) + // batch with size (m, h, w) Tensor input_grad_batch = - input_grad->Slice(i, i + 1).Resize(input_matrix_shape); + input_grad->Slice(i, i + 1).Resize(input_matrix_shape); - // im2col: dy from (C, O_H, O_W) -> (C * K_H * K_W, H * W) + // im2col: dy from (c, o_h, o_w) -> (c * k_h * k_w, h * w) im2col(context.device_context(), output_grad_batch, col, strides[0], strides[1], paddings[0], paddings[1]); // gemm: dx = filter * dy - // (M, C * K_H * K_W) * (C * K_H * K_W, H * W) -> (M, C, H) + // (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, c, h) math::matmul(context.device_context(), filter, false, col_matrix, false, T(1.0), &input_grad_batch, T(0.0)); @@ -217,7 +217,7 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel { // filter gradient required if (filter_grad) { Tensor col_matrix_f = col; - DDim col_matrix_shape_f = {C * H * W, K_H * K_W}; + DDim col_matrix_shape_f = {c * h * w, k_h * k_w}; col_matrix_f.Resize(col_matrix_shape_f); filter_grad->mutable_data(context.GetPlace()); @@ -226,19 +226,19 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel { auto t = framework::EigenVector::Flatten(filter_grad_); t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); - for (int i = 0; i < N; ++i) { - // batch with size (C, O_H, O_W) + for (int i = 0; i < batch_size; ++i) { + // batch with size (c, o_h, o_w) Tensor output_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_shape); + output_grad->Slice(i, i + 1).Resize(output_shape); // input batch - Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape); + Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape); - // im2col: (C * H * W, K_H * K_W) + // im2col: (c * h * w, k_h * k_w) im2col(context.device_context(), output_grad_batch, col, strides[0], strides[1], paddings[0], paddings[1]); // gemm: d_filter = x * y_grad^T - // (M, C * H * W) * (K_H * K_W, C * H * W) -> (M, C, H) + // (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, c, h) math::matmul(context.device_context(), in_batch, false, col_matrix_f, true, T(1.0), &filter_grad_, T(1.0)); -- GitLab