From 0d9ba3da9a8db4b9f25d7814fcdc8eec80de9ab5 Mon Sep 17 00:00:00 2001
From: yangyaming
Date: Thu, 9 Nov 2017 11:08:39 +0800
Subject: [PATCH] Adapt to new interface.

---
 paddle/operators/expand_op.cc | 69 +++++++++++++++++++----------------
 paddle/operators/expand_op.h  | 42 +++++++++------------
 2 files changed, 55 insertions(+), 56 deletions(-)

diff --git a/paddle/operators/expand_op.cc b/paddle/operators/expand_op.cc
index 3990b3751d..5d83b1d9d2 100644
--- a/paddle/operators/expand_op.cc
+++ b/paddle/operators/expand_op.cc
@@ -24,26 +24,28 @@ class ExpandOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X must be initialized.");
-    std::vector<int> expand_times = Attr<std::vector<int>>("expandTimes");
-    auto x_dims = ctx.Input<framework::Tensor>("X")->dims();
-
-    PADDLE_ENFORCE_EQ(x_dims.size(), expand_times.size(),
-                      "The number of expandTimes's value must be equal "
-                      "to the rank of X.");
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must be initialized.");
+    std::vector<int> expand_times =
+        ctx->Attrs().Get<std::vector<int>>("expandTimes");
+    auto x_dims = ctx->GetInputDim("X");
+
+    PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims.size()), expand_times.size(),
+                      "The number of Attr(expandTimes)'s value must be equal "
+                      "to the rank of Input(X).");
     PADDLE_ENFORCE_LE(x_dims.size(), 6,
-                      "The rank of X must not be greater than 6.");
+                      "The rank of Input(X) must not be greater than 6.");
     std::vector<int64_t> out_shape(x_dims.size());
     for (size_t i = 0; i < expand_times.size(); ++i) {
       PADDLE_ENFORCE_GE(expand_times[i], 1,
-                        "Each value of expandTimes should not be "
+                        "Each value of Attr(expandTimes) should not be "
                         "less than 1.");
       out_shape[i] = x_dims[i] * expand_times[i];
     }
-    auto* out = ctx.Output<framework::LoDTensor>("Out");
-    out->Resize(framework::make_ddim(out_shape));
+
+    ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
+    ctx->ShareLoD("X", "Out");
   }
 };
 
@@ -52,20 +54,21 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
   ExpandOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
-             "The input tensor of expand op."
-             "The rank of X should be between in 1 and 6.");
+             "(Tensor, default Tensor<float>) A tensor with rank in [1, 6]."
+             "X is the input tensor to be expanded.");
     AddOutput("Out",
-              "Output tensor of expand op."
-              "The rank of Out is same as X except that each dimension size "
-              "of Out equals to corresponding dimension size of X multiplying "
-              "corresponding value of expandTimes.");
+              "(Tensor, default Tensor<float>) A tensor with rank in [1, 6]."
+              "The rank of Output(Out) is same as Input(X) except that each "
+              "dimension size of Output(Out) is equal to corresponding "
+              "dimension size of Input(X) multiplying corresponding value of "
+              "Attr(expandTimes).");
     AddAttr<std::vector<int>>("expandTimes",
                               "Expand times number for each dimension.");
     AddComment(R"DOC(
 Expand operator tiles the input by given times number. You should set times
 number for each dimension by providing attribute 'expandTimes'. The rank of X
-should be between in 1 and 6. Please notice that size of 'expandTimes' must be
-same with X's rank.
+should be in [1, 6]. Please notice that size of 'expandTimes' must be same with
+X's rank.
)DOC"); } }; @@ -75,25 +78,27 @@ class ExpandGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext& ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X must be initialized."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Input(Out@GRAD) should not be null."); - auto x_dims = ctx.Input("X")->dims(); - std::vector expand_times = Attr>("expandTimes"); - auto out_dims = - ctx.Input(framework::GradVarName("Out"))->dims(); - auto* x_grad = - ctx.Output(framework::GradVarName("X")); + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null."); + auto x_dims = ctx->GetInputDim("X"); + std::vector expand_times = + ctx->Attrs().Get>("expandTimes"); + auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); for (size_t i = 0; i < expand_times.size(); ++i) { PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i], "Each dimension size of Input(Out@GRAD) should be " "equal to multiplication of crroresponding dimension " - "size of Input(X) and expandTimes value."); + "size of Input(X) and Attr(expandTimes) value."); } - if (x_grad) x_grad->Resize(x_dims); + auto x_grad_name = framework::GradVarName("X"); + + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } } }; diff --git a/paddle/operators/expand_op.h b/paddle/operators/expand_op.h index f9cd519c70..bd17567c88 100644 --- a/paddle/operators/expand_op.h +++ b/paddle/operators/expand_op.h @@ -45,6 +45,8 @@ namespace paddle { namespace operators { +using Tensor = framework::Tensor; + template using EigenVector = framework::EigenVector; @@ -53,24 +55,24 @@ template ; template -class ExpandKernel : public framework::OpKernel { +class ExpandKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto rank = context.Input("X")->dims().size(); + auto rank = context.Input("X")->dims().size(); switch (rank) { REP_EXPAND_TEMPLATE(6) default: PADDLE_ENFORCE(false, "Only support tensor with rank being between 1 and 6."); - }; + } } protected: template void Expand(const framework::ExecutionContext& context) const { - auto* in0 = context.Input("X"); + auto* in0 = context.Input("X"); auto& expand_times = context.Attr>("expandTimes"); - auto* out0 = context.Output("Out"); + auto* out0 = context.Output("Out"); Eigen::DSizes bcast_dims; auto x_dims = in0->dims(); for (size_t i = 0; i < expand_times.size(); ++i) { @@ -85,10 +87,10 @@ class ExpandKernel : public framework::OpKernel { }; template -class ExpandGradKernel : public framework::OpKernel { +class ExpandGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* in0 = context.Input("X"); + auto* in0 = context.Input("X"); auto& expand_times = context.Attr>("expandTimes"); auto x_dims = in0->dims(); std::vector reshape_dims_vec; @@ -111,23 +113,17 @@ class ExpandGradKernel : public framework::OpKernel { int dims = reshape_dims_vec.size() * 6 + reduce_dims_vec.size() - 7; // no need reduce, just copy if (reduce_dims_vec.size() == 0) { - auto* in0 = - context.Input(framework::GradVarName("Out")); - auto* out0 = - context.Output(framework::GradVarName("X")); + auto* in0 = 
-          context.Input<framework::Tensor>(framework::GradVarName("Out"));
-      auto* out0 =
-          context.Output<framework::Tensor>(framework::GradVarName("X"));
+      auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
+      auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
       out0->mutable_data<T>(context.GetPlace());
-      if (platform::is_cpu_place(context.GetPlace())) {
-        out0->CopyFrom(*in0, platform::CPUPlace());
-      } else {
-        out0->CopyFrom(*in0, platform::GPUPlace());
-      }
+      out0->CopyFrom(*in0, context.GetPlace(), context.device_context());
     } else {
       switch (dims) {
         REP_EXPAND_GRAD_TEMPLATE(72)
         default:
           PADDLE_ENFORCE(
               false, "Only support tensor with rank being between 1 and 6.");
-      };
+      }
     }
   }
 
@@ -144,11 +140,9 @@ class ExpandGradKernel : public framework::OpKernel {
     PADDLE_ENFORCE_EQ(reduce_size, reduce_dims_vec.size(),
                       "Inconsistent size between template Dims and "
                       "reduce dimensions.");
-    auto* in0 =
-        context.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto* out0 =
-        context.Output<framework::Tensor>(framework::GradVarName("X"));
-    auto x = EigenVector<T>::Flatten(*(context.Input<framework::Tensor>("X")));
+    auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
+    auto x = EigenVector<T>::Flatten(*(context.Input<Tensor>("X")));
     out0->mutable_data<T>(context.GetPlace());
     auto x_grad = EigenVector<T>::Flatten(*out0);
     Eigen::DSizes<int, Dims / 6 + 1> reshape_dims;
@@ -165,5 +159,5 @@ class ExpandGradKernel : public framework::OpKernel {
   }
 };
 
-}  // operators
-}  // paddle
+}  // namespace operators
+}  // namespace paddle
-- 
GitLab
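
A note on what the patched kernels compute: the shape rule checked in ExpandOp::InferShape is simply out_shape[i] = x_dims[i] * expand_times[i], the forward kernel tiles the input that many times along each axis, and the backward kernel sums the output gradient over all copies of each input element. The standalone C++ sketch below illustrates this for the 2-D case. It is not part of the patch; the helper names expand_shape, tile_2d, and reduce_grad_2d are made up for illustration, and the plain loops stand in for what the kernels express with Eigen's broadcast and reshape-then-sum.

// Standalone illustration of the Expand op's semantics (hypothetical helpers,
// not Paddle code). Build with any C++11 compiler.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors ExpandOp::InferShape: out_shape[i] = x_dims[i] * expand_times[i].
std::vector<int64_t> expand_shape(const std::vector<int64_t>& x_dims,
                                  const std::vector<int>& expand_times) {
  assert(x_dims.size() == expand_times.size());
  std::vector<int64_t> out(x_dims.size());
  for (std::size_t i = 0; i < x_dims.size(); ++i) {
    assert(expand_times[i] >= 1);
    out[i] = x_dims[i] * expand_times[i];
  }
  return out;
}

// Naive 2-D version of the forward kernel's x.broadcast(bcast_dims):
// the output repeats the input times[0] times along rows, times[1] along cols.
std::vector<std::vector<float>> tile_2d(
    const std::vector<std::vector<float>>& x, const std::vector<int>& times) {
  std::size_t h = x.size(), w = x[0].size();
  std::vector<std::vector<float>> y(h * times[0],
                                    std::vector<float>(w * times[1]));
  for (std::size_t i = 0; i < y.size(); ++i)
    for (std::size_t j = 0; j < y[0].size(); ++j) y[i][j] = x[i % h][j % w];
  return y;
}

// Naive 2-D version of the backward kernel's reduction: every input element
// accumulates the gradients of all its tiled copies.
std::vector<std::vector<float>> reduce_grad_2d(
    const std::vector<std::vector<float>>& dy, std::size_t h, std::size_t w) {
  std::vector<std::vector<float>> dx(h, std::vector<float>(w, 0.f));
  for (std::size_t i = 0; i < dy.size(); ++i)
    for (std::size_t j = 0; j < dy[0].size(); ++j) dx[i % h][j % w] += dy[i][j];
  return dx;
}

int main() {
  // A 2x3 input with expandTimes = {2, 2} yields a 4x6 output shape.
  std::vector<int64_t> out = expand_shape({2, 3}, {2, 2});
  std::cout << out[0] << "x" << out[1] << "\n";  // prints 4x6

  // Forward: tile a 2x2 matrix twice along each axis.
  auto y = tile_2d({{1.f, 2.f}, {3.f, 4.f}}, {2, 2});
  std::cout << y[3][3] << "\n";  // bottom-right copy of x(1,1) -> 4

  // Backward: an all-ones gradient on the 4x4 output reduces to 4 per input
  // element, since each input element was copied 2 * 2 = 4 times.
  auto dy = std::vector<std::vector<float>>(4, std::vector<float>(4, 1.f));
  auto dx = reduce_grad_2d(dy, 2, 2);
  std::cout << dx[0][0] << "\n";  // prints 4
  return 0;
}

Seen this way, the reduce_dims_vec.size() == 0 branch in the patched grad kernel corresponds to the degenerate case where no axis is actually tiled, so there is nothing to sum over and the gradient is copied back unchanged, which is exactly what the single CopyFrom call introduced by the patch does.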