Commit e33b4112 authored by guosheng

Adapt reduce_op according to up-to-date dev

Parent be58c632
@@ -24,20 +24,20 @@ class ReduceOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
-                            "Input(X) of ReduceOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
-                            "Output(Out) of ReduceOp should not be null.");
-    auto x_dims = ctx.Input<Tensor>("X")->dims();
+  void InferShape(framework::InferShapeContextBase *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of ReduceOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of ReduceOp should not be null.");
+    auto x_dims = ctx->GetInputDim("X");
     auto x_rank = x_dims.size();
     PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
-    int dim = ctx.Attr<int>("dim");
+    int dim = ctx->Attrs().Get<int>("dim");
     if (dim < 0) dim = x_rank + dim;
     PADDLE_ENFORCE_LT(
         dim, x_rank,
         "The dim should be in the range [-rank(input), rank(input)).");
-    bool keep_dim = ctx.Attr<bool>("keep_dim");
+    bool keep_dim = ctx->Attrs().Get<bool>("keep_dim");
     auto dims_vector = vectorize(x_dims);
     if (keep_dim || x_rank == 1) {
       dims_vector[dim] = 1;
@@ -45,10 +45,10 @@ class ReduceOp : public framework::OperatorWithKernel {
       dims_vector.erase(dims_vector.begin() + dim);
     }
     auto out_dims = framework::make_ddim(dims_vector);
-    ctx.Output<framework::Tensor>("Out")->Resize(out_dims);
+    ctx->SetOutputDim("Out", out_dims);
     if (dim != 0) {
-      // Only pass LoD when not reducing on the first dim
-      ctx.ShareLoD("X", /*->*/ "Out");
+      // Only pass LoD when not reducing on the first dim.
+      ctx->ShareLoD("X", /*->*/ "Out");
     }
   }
 };
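To make the shape rule in ReduceOp::InferShape concrete: reducing {3, 4, 5} along dim 1 yields {3, 5}, or {3, 1, 5} with keep_dim. Below is a minimal standalone sketch of that logic; ReducedDims is a hypothetical helper written for illustration, not part of the operator or the Paddle API.

#include <cassert>
#include <vector>

// Hypothetical mirror of the InferShape logic above: reduce x_dims along
// `dim`, keeping that axis as size 1 when `keep_dim` is set or the input
// is 1-D, and dropping it otherwise.
std::vector<int> ReducedDims(std::vector<int> x_dims, int dim, bool keep_dim) {
  const int x_rank = static_cast<int>(x_dims.size());
  if (dim < 0) dim = x_rank + dim;  // normalize a negative dim
  assert(dim >= 0 && dim < x_rank);
  if (keep_dim || x_rank == 1) {
    x_dims[dim] = 1;                       // reduced axis collapses to size 1
  } else {
    x_dims.erase(x_dims.begin() + dim);    // reduced axis is dropped
  }
  return x_dims;
}

int main() {
  assert((ReducedDims({3, 4, 5}, 1, false) == std::vector<int>{3, 5}));
  assert((ReducedDims({3, 4, 5}, -1, true) == std::vector<int>{3, 4, 1}));
  return 0;
}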
@@ -58,21 +58,22 @@ class ReduceGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
-                            "Input(Out@GRAD) should not be null.");
-    auto x_dims = ctx.Input<Tensor>("X")->dims();
+  void InferShape(framework::InferShapeContextBase *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null.");
+    auto x_dims = ctx->GetInputDim("X");
     auto x_rank = x_dims.size();
     PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
-    int dim = ctx.Attr<int>("dim");
+    int dim = ctx->Attrs().Get<int>("dim");
     if (dim < 0) dim = x_rank + dim;
     PADDLE_ENFORCE_LT(
         dim, x_rank,
         "The dim should be in the range [-rank(input), rank(input)).");
-    auto *x_grad =
-        ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
-    if (x_grad) x_grad->Resize(x_dims);
+    auto x_grad_name = framework::GradVarName("X");
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, x_dims);
+    }
   }
 };
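ReduceGradOp only declares that X@GRAD has X's full shape. As a toy illustration of why that holds (independent of Paddle, and not the operator's actual kernel): the backward of a reduction broadcasts the upstream gradient back across the reduced axis, so the gradient tensor is always input-shaped.

#include <cstdio>
#include <vector>

// Toy forward/backward for summing over the columns of a rows x cols
// matrix: Out has shape {rows}, but dX is broadcast back to {rows, cols}.
int main() {
  const int rows = 2, cols = 3;
  std::vector<double> x = {1, 2, 3, 4, 5, 6};   // X, shape {rows, cols}
  std::vector<double> out(rows, 0.0);           // Out, dim 1 reduced away
  for (int i = 0; i < rows; ++i)
    for (int j = 0; j < cols; ++j) out[i] += x[i * cols + j];

  std::vector<double> dout = {0.1, 0.2};        // Out@GRAD, shape {rows}
  std::vector<double> dx(rows * cols);          // X@GRAD, shape {rows, cols}
  for (int i = 0; i < rows; ++i)
    for (int j = 0; j < cols; ++j)
      dx[i * cols + j] = dout[i];               // broadcast along reduced axis

  for (double v : dx) std::printf("%.1f ", v);  // 0.1 0.1 0.1 0.2 0.2 0.2
  std::printf("\n");
  return 0;
}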