提交 e33b4112 编写于 作者: G guosheng

Adapt reduce_op according to up-to-date dev

上级 be58c632
...@@ -24,20 +24,20 @@ class ReduceOp : public framework::OperatorWithKernel { ...@@ -24,20 +24,20 @@ class ReduceOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ReduceOp should not be null."); "Input(X) of ReduceOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ReduceOp should not be null."); "Output(Out) of ReduceOp should not be null.");
auto x_dims = ctx.Input<Tensor>("X")->dims(); auto x_dims = ctx->GetInputDim("X");
auto x_rank = x_dims.size(); auto x_rank = x_dims.size();
PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported."); PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
int dim = ctx.Attr<int>("dim"); int dim = ctx->Attrs().Get<int>("dim");
if (dim < 0) dim = x_rank + dim; if (dim < 0) dim = x_rank + dim;
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
dim, x_rank, dim, x_rank,
"The dim should be in the range [-rank(input), rank(input))."); "The dim should be in the range [-rank(input), rank(input)).");
bool keep_dim = ctx.Attr<bool>("keep_dim"); bool keep_dim = ctx->Attrs().Get<bool>("keep_dim");
auto dims_vector = vectorize(x_dims); auto dims_vector = vectorize(x_dims);
if (keep_dim || x_rank == 1) { if (keep_dim || x_rank == 1) {
dims_vector[dim] = 1; dims_vector[dim] = 1;
...@@ -45,10 +45,10 @@ class ReduceOp : public framework::OperatorWithKernel { ...@@ -45,10 +45,10 @@ class ReduceOp : public framework::OperatorWithKernel {
dims_vector.erase(dims_vector.begin() + dim); dims_vector.erase(dims_vector.begin() + dim);
} }
auto out_dims = framework::make_ddim(dims_vector); auto out_dims = framework::make_ddim(dims_vector);
ctx.Output<framework::Tensor>("Out")->Resize(out_dims); ctx->SetOutputDim("Out", out_dims);
if (dim != 0) { if (dim != 0) {
// Only pass LoD when not reducing on the first dim // Only pass LoD when not reducing on the first dim.
ctx.ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
} }
}; };
...@@ -58,21 +58,22 @@ class ReduceGradOp : public framework::OperatorWithKernel { ...@@ -58,21 +58,22 @@ class ReduceGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null."); "Input(Out@GRAD) should not be null.");
auto x_dims = ctx.Input<Tensor>("X")->dims(); auto x_dims = ctx->GetInputDim("X");
auto x_rank = x_dims.size(); auto x_rank = x_dims.size();
PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported."); PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
int dim = ctx.Attr<int>("dim"); int dim = ctx->Attrs().Get<int>("dim");
if (dim < 0) dim = x_rank + dim; if (dim < 0) dim = x_rank + dim;
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
dim, x_rank, dim, x_rank,
"The dim should be in the range [-rank(input), rank(input))."); "The dim should be in the range [-rank(input), rank(input)).");
auto *x_grad = auto x_grad_name = framework::GradVarName("X");
ctx.Output<framework::LoDTensor>(framework::GradVarName("X")); if (ctx->HasOutput(x_grad_name)) {
if (x_grad) x_grad->Resize(x_dims); ctx->SetOutputDim(x_grad_name, x_dims);
}
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册