diff --git a/paddle/operators/reduce_op.cc b/paddle/operators/reduce_op.cc index 61b33d4bbd59e259fcb89f0c22bdc992eea15947..3ef443d1c7f475cbd578078db02fb5e0d500d060 100644 --- a/paddle/operators/reduce_op.cc +++ b/paddle/operators/reduce_op.cc @@ -24,20 +24,20 @@ class ReduceOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of ReduceOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(Out) of ReduceOp should not be null."); - auto x_dims = ctx.Input("X")->dims(); + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of ReduceOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of ReduceOp should not be null."); + auto x_dims = ctx->GetInputDim("X"); auto x_rank = x_dims.size(); PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported."); - int dim = ctx.Attr("dim"); + int dim = ctx->Attrs().Get("dim"); if (dim < 0) dim = x_rank + dim; PADDLE_ENFORCE_LT( dim, x_rank, "The dim should be in the range [-rank(input), rank(input))."); - bool keep_dim = ctx.Attr("keep_dim"); + bool keep_dim = ctx->Attrs().Get("keep_dim"); auto dims_vector = vectorize(x_dims); if (keep_dim || x_rank == 1) { dims_vector[dim] = 1; @@ -45,10 +45,10 @@ class ReduceOp : public framework::OperatorWithKernel { dims_vector.erase(dims_vector.begin() + dim); } auto out_dims = framework::make_ddim(dims_vector); - ctx.Output("Out")->Resize(out_dims); + ctx->SetOutputDim("Out", out_dims); if (dim != 0) { - // Only pass LoD when not reducing on the first dim - ctx.ShareLoD("X", /*->*/ "Out"); + // Only pass LoD when not reducing on the first dim. + ctx->ShareLoD("X", /*->*/ "Out"); } } }; @@ -58,21 +58,22 @@ class ReduceGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Input(Out@GRAD) should not be null."); - auto x_dims = ctx.Input("X")->dims(); + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null."); + auto x_dims = ctx->GetInputDim("X"); auto x_rank = x_dims.size(); PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported."); - int dim = ctx.Attr("dim"); + int dim = ctx->Attrs().Get("dim"); if (dim < 0) dim = x_rank + dim; PADDLE_ENFORCE_LT( dim, x_rank, "The dim should be in the range [-rank(input), rank(input))."); - auto *x_grad = - ctx.Output(framework::GradVarName("X")); - if (x_grad) x_grad->Resize(x_dims); + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } } };