From 532f38d3336d295792f161b223c8c25bae46b492 Mon Sep 17 00:00:00 2001 From: Zhuoyuan Date: Wed, 11 Oct 2017 17:34:01 -0700 Subject: [PATCH] deconv op --- paddle/operators/deconv2d_op.cc | 118 ++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 paddle/operators/deconv2d_op.cc diff --git a/paddle/operators/deconv2d_op.cc b/paddle/operators/deconv2d_op.cc new file mode 100644 index 00000000000..408e1f04521 --- /dev/null +++ b/paddle/operators/deconv2d_op.cc @@ -0,0 +1,118 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/gemm_conv2d_op.h" + +namespace paddle { +namespace operators { + + +class Deconv2DOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of Deconv2DOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Filter"), + "Input(Filter) of Deconv2DOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Output"), + "Output(Output) of Deconv2DOp should not be null."); + + auto in_dims = ctx->GetInputDim("Input"); + auto filter_dims = ctx->GetInputDim("Filter"); + std::vector strides = ctx->Attrs().Get>("strides"); + std::vector paddings = ctx->Attrs().Get>("paddings"); + int groups = ctx->Attrs().Get("groups"); + int input_channels = in_dims[1]; + int output_channels = filter_dims[0]; + + PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Conv2DOp input should be 4-D."); + PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Conv2DOp filter should be 4-D."); + PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups, + "The number of input channels should be equal to filter " + "channels * groups."); + PADDLE_ENFORCE_EQ( + output_channels % groups, 0, + "The number of output channels should be divided by groups."); + + auto output_height = (in_dims[2] - 1) * strides[0] + filter_dims[2]; + auto output_width = (in_dims[3] - 1) * strides[1] + filter_dims[3]; + ctx->SetOutputDim( + "Output", {in_dims[0], filter_dims[0], output_height, output_width}); + } +}; + +class Deconv2DOpMaker : public framework::OpProtoAndCheckerMaker { + public: + Deconv2DOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "Input", + "The input tensor of deconvolution operator. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of channels, H and W is the height and width of image."); + AddInput( + "Filter", + "The filter tensor of deconvolution operator." + "The format of the filter tensor is MCHW, where M is the number of " + "output image channels, C is the number of input image channels, " + "H and W is height and width of filter. " + "We enforce groups number == 1 and padding == 0 in our deconvolution + Scenario."); + AddOutput("Output", + "The output tensor of deconvolution operator." + "The format of output tensor is also NCHW."); + AddAttr>("strides", "strides of deconvolution operator.") + .SetDefault({1, 1}); + AddAttr>("paddings", "paddings of deconvolution operator.") + .SetDefault({0, 0}); + AddComment(R"DOC( +The deconvolution operation calculates the output based on the input, filter +and strides, paddings, groups parameters. The size of each dimension of the +parameters is checked in the infer-shape. +)DOC"); + } +}; + +class Deconv2DOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + auto in_dims = ctx->GetInputDim("Input"); + auto filter_dims = ctx->GetInputDim("Filter"); + if (ctx->HasOutput(framework::GradVarName("Input"))) { + ctx->SetOutputDim(framework::GradVarName("Input"), in_dims); + } + if (ctx->HasOutput(framework::GradVarName("Filter"))) { + ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(deconv2d, ops::Deconv2DOp, ops::Deconv2DOpMaker, deconv2d_grad, + ops::Deconv2DOpGrad); + +REGISTER_OP_CPU_KERNEL( + deconv2d, ops::GemmConvGrad2DKernel); +REGISTER_OP_CPU_KERNEL( + deconv2d_grad, ops::GemmConv2DKernel); -- GitLab