From 66b84366f1e09366b28e41dbd0d3521152554115 Mon Sep 17 00:00:00 2001 From: sweetsky0901 Date: Thu, 23 Nov 2017 11:53:30 +0800 Subject: [PATCH] modify for code review by wangyi --- paddle/operators/unpool_op.cc | 26 +++++++++---------- paddle/operators/unpool_op.h | 47 ++++++++++++----------------------- 2 files changed, 28 insertions(+), 45 deletions(-) diff --git a/paddle/operators/unpool_op.cc b/paddle/operators/unpool_op.cc index add8f157368..b5f3d56e960 100644 --- a/paddle/operators/unpool_op.cc +++ b/paddle/operators/unpool_op.cc @@ -16,11 +16,9 @@ namespace paddle { namespace operators { -using framework::Tensor; - class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker { public: - Unpool2dOpMaker(framework::OpProto* proto, \ + Unpool2dOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", @@ -38,26 +36,26 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker { "the number of channels, H and W is the height and " "width of feature."); AddAttr>("ksize", - "(vector ), the unpooling window size(height, width) " + "(vector), the unpooling window size(height, width) " "of unpooling operator."); AddAttr>("strides", "(vector, default:{1, 1}), " - "strides(height, width) of unpooling operator.") + "strides (height, width) of unpooling operator.") .SetDefault({1, 1}); AddAttr>("paddings", "(vector defalut:{0,0}), " - "paddings(height, width) of unpooling operator.") + "paddings (height, width) of unpooling operator.") .SetDefault({0, 0}); AddAttr("unpoolingtype", "(string), unpooling type, can be \"max\" for max-unpooling ") .InEnum({"max"}); AddComment(R"DOC( - "input: the input Tensor to invert" - "indices: the indices given out by MaxPool2d" - "ksize – Size of the max pooling window." - "stride – Stride of the max pooling window." - "It is set to kernel_size by default." - "padding – Padding that was added to the input" + "input: the input Tensor to invert + indices: the indices given out by MaxPool2d + ksize – Size of the max pooling window. + stride – Stride of the max pooling window. + "It is set to kernel_size by default. + padding – Padding that was added to the input" )DOC"); } }; @@ -80,14 +78,14 @@ class UnpoolOp : public framework::OperatorWithKernel { auto in_x_dims = ctx->GetInputDim("X"); auto in_y_dims = ctx->GetInputDim("Y"); - std::string unpoolingtype = \ + std::string unpoolingtype = ctx->Attrs().Get("unpoolingtype"); std::vector ksize = ctx->Attrs().Get>("ksize"); std::vector strides = ctx->Attrs().Get>("strides"); std::vector paddings = ctx->Attrs().Get>("paddings"); PADDLE_ENFORCE(in_x_dims.size() == 4, - "Unpooling intput should be 4-D."); + "Unpooling intput must be of 4-dimensional."); for (int i = 0; i < 4; ++i) { PADDLE_ENFORCE(in_x_dims[i] == in_y_dims[i], "X size must be eq Y size!"); diff --git a/paddle/operators/unpool_op.h b/paddle/operators/unpool_op.h index e3a45ff9a71..e22171649eb 100644 --- a/paddle/operators/unpool_op.h +++ b/paddle/operators/unpool_op.h @@ -21,15 +21,13 @@ limitations under the License. */ namespace paddle { namespace operators { -using Tensor = framework::Tensor; - template class UnpoolKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - const Tensor* in_x = context.Input("X"); - const Tensor* in_y = context.Input("Y"); - auto * out = context.Output("Out"); + const framework::Tensor* in_x = context.Input("X"); + const framework::Tensor* in_y = context.Input("Y"); + auto * out = context.Output("Out"); std::string unpoolingtype = context.Attr("unpoolingtype"); std::vector ksize = context.Attr>("ksize"); std::vector strides = context.Attr>("strides"); @@ -39,15 +37,8 @@ class UnpoolKernel : public framework::OpKernel { math::SetConstant set_zero; set_zero(context.device_context(), out, static_cast(0)); } - switch (ksize.size()) { - case 2: { - if (unpoolingtype == "max") { - math::Unpool2dMaxFunctor unpool2d_max_forward; - unpool2d_max_forward(context.device_context(), *in_x, *in_y, out); - } - } break; - default: { PADDLE_THROW("Pool op only supports 2D input."); } - } + math::Unpool2dMaxFunctor unpool2d_max_forward; + unpool2d_max_forward(context.device_context(), *in_x, *in_y, out); } }; @@ -55,12 +46,13 @@ template class UnpoolGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - const Tensor* in_x = context.Input("X"); - const Tensor* in_y = context.Input("Y"); - const Tensor* out = context.Input("Out"); - const Tensor* out_grad = - context.Input(framework::GradVarName("Out")); - Tensor* in_x_grad = context.Output(framework::GradVarName("X")); + const framework::Tensor* in_x = context.Input("X"); + const framework::Tensor* in_y = context.Input("Y"); + const framework::Tensor* out = context.Input("Out"); + const framework::Tensor* out_grad = + context.Input(framework::GradVarName("Out")); + framework::Tensor* in_x_grad = + context.Output(framework::GradVarName("X")); std::string unpoolingtype = context.Attr("unpoolingtype"); std::vector ksize = context.Attr>("ksize"); std::vector strides = context.Attr>("strides"); @@ -70,18 +62,11 @@ class UnpoolGradKernel : public framework::OpKernel { math::SetConstant zero; if (in_x_grad) { in_x_grad->mutable_data(context.GetPlace()); - zero(device_ctx, in_x_grad, static_cast(0.0)); - } - switch (ksize.size()) { - case 2: { - if (unpoolingtype == "max") { - math::Unpool2dMaxGradFunctor unpool2d_max_backward; - unpool2d_max_backward(context.device_context(), *in_x, *in_y, in_x_grad, - *out, *out_grad); - } - } break; - default: { PADDLE_THROW("Unpool op only supports 2D input."); } + zero(device_ctx, in_x_grad, static_cast(0)); } + math::Unpool2dMaxGradFunctor unpool2d_max_backward; + unpool2d_max_backward(context.device_context(), *in_x, *in_y, in_x_grad, + *out, *out_grad); } }; -- GitLab