From 652f182dc02023a04218d1020275dccaf78a92cc Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Fri, 13 Oct 2017 14:05:40 -0700 Subject: [PATCH] deconv --- paddle/operators/deconv2d_op.cc | 147 ++++++++++++++------------------ paddle/operators/deconv2d_op.cu | 23 +++++ paddle/operators/deconv2d_op.h | 52 +++++++++++ 3 files changed, 141 insertions(+), 81 deletions(-) create mode 100644 paddle/operators/deconv2d_op.cu create mode 100644 paddle/operators/deconv2d_op.h diff --git a/paddle/operators/deconv2d_op.cc b/paddle/operators/deconv2d_op.cc index ce95db05e7..6b71a1fea7 100644 --- a/paddle/operators/deconv2d_op.cc +++ b/paddle/operators/deconv2d_op.cc @@ -12,97 +12,82 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/gemm_conv2d_op.h" +#include "paddle/operators/deconv2d_op.h" +#include "paddle/operators/conv2d_op.h" namespace paddle { namespace operators { -class Deconv2DOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input(Input) of Deconv2DOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Filter"), - "Input(Filter) of Deconv2DOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Output"), - "Output(Output) of Deconv2DOp should not be null."); - - auto in_dims = ctx->GetInputDim("Input"); - auto filter_dims = ctx->GetInputDim("Filter"); - std::vector strides = ctx->Attrs().Get>("strides"); - std::vector paddings = ctx->Attrs().Get>("paddings"); - int groups = ctx->Attrs().Get("groups"); - int input_channels = in_dims[1]; - int output_channels = filter_dims[0]; - - PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Conv2DOp input should be 4-D."); - PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Conv2DOp filter should be 4-D."); - PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups, - "The number of input channels should be equal to filter " - "channels * groups."); - PADDLE_ENFORCE_EQ( - output_channels % groups, 0, - "The number of output channels should be divided by groups."); - - auto output_height = (in_dims[2] - 1) * strides[0] + filter_dims[2]; - auto output_width = (in_dims[3] - 1) * strides[1] + filter_dims[3]; - ctx->SetOutputDim( - "Output", {in_dims[0], filter_dims[0], output_height, output_width}); - } -}; - -class Deconv2DOpMaker : public framework::OpProtoAndCheckerMaker { - public: - Deconv2DOpMaker(framework::OpProto* proto, - framework::OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput( - "Input", - "The input tensor of deconvolution operator. " - "The format of input tensor is NCHW. Where N is batch size, C is the " - "number of channels, H and W is the height and width of image."); - AddInput( - "Filter", - "The filter tensor of deconvolution operator." - "The format of the filter tensor is MCHW, where M is the number of " - "output image channels, C is the number of input image channels, " - "H and W is height and width of filter. " - "We enforce groups number == 1 and padding == 0 in our deconvolution - Scenario."); - AddOutput("Output", - "The output tensor of deconvolution operator." - "The format of output tensor is also NCHW."); - AddAttr>("strides", "strides of deconvolution operator.") - .SetDefault({1, 1}); - AddAttr>("paddings", "paddings of deconvolution operator.") - .SetDefault({0, 0}); - AddComment(R"DOC( +void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of Deconv2DOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Filter"), + "Input(Filter) of Deconv2DOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Output"), + "Output(Output) of Deconv2DOp should not be null."); + + auto in_dims = ctx->GetInputDim("Input"); + auto filter_dims = ctx->GetInputDim("Filter"); + std::vector strides = ctx->Attrs().Get>("strides"); + std::vector paddings = ctx->Attrs().Get>("paddings"); + int groups = ctx->Attrs().Get("groups"); + int input_channels = in_dims[1]; + int output_channels = filter_dims[0]; + + PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Conv2DOp input should be 4-D."); + PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Conv2DOp filter should be 4-D."); + PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups, + "The number of input channels should be equal to filter " + "channels * groups."); + PADDLE_ENFORCE_EQ( + output_channels % groups, 0, + "The number of output channels should be divided by groups."); + + auto output_height = (in_dims[2] - 1) * strides[0] + filter_dims[2]; + auto output_width = (in_dims[3] - 1) * strides[1] + filter_dims[3]; + ctx->SetOutputDim("Output", + {in_dims[0], filter_dims[0], output_height, output_width}); +} + +Deconv2DOpMaker::Deconv2DOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "Input", + "The input tensor of deconvolution operator. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of channels, H and W is the height and width of image."); + AddInput("Filter", + "The filter tensor of deconvolution operator." + "The format of the filter tensor is MCHW, where M is the number of " + "output image channels, C is the number of input image channels, " + "H and W is height and width of filter. " + "We enforce groups number == 1 and padding == 0 in our " + "deconvolution Scenario."); + AddOutput("Output", + "The output tensor of deconvolution operator." + "The format of output tensor is also NCHW."); + AddAttr>("strides", "strides of deconvolution operator.") + .SetDefault({1, 1}); + AddAttr>("paddings", "paddings of deconvolution operator.") + .SetDefault({0, 0}); + AddComment(R"DOC( The deconvolution operation calculates the output based on the input, filter and strides, paddings, groups parameters. The size of each dimension of the parameters is checked in the infer-shape. )DOC"); - } -}; +} -class Deconv2DOpGrad : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext* ctx) const override { - auto in_dims = ctx->GetInputDim("Input"); - auto filter_dims = ctx->GetInputDim("Filter"); - if (ctx->HasOutput(framework::GradVarName("Input"))) { - ctx->SetOutputDim(framework::GradVarName("Input"), in_dims); - } - if (ctx->HasOutput(framework::GradVarName("Filter"))) { - ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims); - } +void Deconv2DOpGrad::InferShape(framework::InferShapeContext* ctx) const { + auto in_dims = ctx->GetInputDim("Input"); + auto filter_dims = ctx->GetInputDim("Filter"); + if (ctx->HasOutput(framework::GradVarName("Input"))) { + ctx->SetOutputDim(framework::GradVarName("Input"), in_dims); + } + if (ctx->HasOutput(framework::GradVarName("Filter"))) { + ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims); } -}; +} } // namespace operators } // namespace paddle diff --git a/paddle/operators/deconv2d_op.cu b/paddle/operators/deconv2d_op.cu new file mode 100644 index 0000000000..9286a18153 --- /dev/null +++ b/paddle/operators/deconv2d_op.cu @@ -0,0 +1,23 @@ +/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/conv2d_op.h" +#include "paddle/operators/deconv2d_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL( + deconv2d, ops::GemmConvGrad2DKernel); +REGISTER_OP_GPU_KERNEL( + deconv2d_grad, ops::GemmConv2DKernel); diff --git a/paddle/operators/deconv2d_op.h b/paddle/operators/deconv2d_op.h new file mode 100644 index 0000000000..4f5a0242b1 --- /dev/null +++ b/paddle/operators/deconv2d_op.h @@ -0,0 +1,52 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/im2col.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +// Define Op classes in .h file so that other deconv +// operator implementations can reuse the code. +class Deconv2DOpMaker : public framework::OpProtoAndCheckerMaker { + public: + Deconv2DOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker); +}; + +class Deconv2DOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override; +}; + +class Deconv2DOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override; +}; + +} // namespace operators +} // namespace paddle -- GitLab