提交 02bdfdba 编写于 作者: X xiemoyuan

Optimize the error message of GatherTreeOP.

上级 562945c9
...@@ -22,19 +22,16 @@ class GatherTreeOp : public framework::OperatorWithKernel { ...@@ -22,19 +22,16 @@ class GatherTreeOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Ids"), true, OP_INOUT_CHECK(ctx->HasInput("Ids"), "Input", "Ids", "GatherTree");
"Input(Ids) of GatherTreeOp should not be null."); OP_INOUT_CHECK(ctx->HasInput("Parents"), "Input", "Parents", "GatherTree");
PADDLE_ENFORCE_EQ(ctx->HasInput("Parents"), true, OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GatherTree");
"Input(Parents) of GatherTreeOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of GatherTreeOp should not be null.");
auto ids_dims = ctx->GetInputDim("Ids"); auto ids_dims = ctx->GetInputDim("Ids");
auto parents_dims = ctx->GetInputDim("Parents"); auto parents_dims = ctx->GetInputDim("Parents");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(ids_dims == parents_dims, true,
ids_dims == parents_dims, true, platform::errors::InvalidArgument(
"The shape of Input(Parents) must be same with the shape of " "The shape of Input(Parents) must be same with the "
"Input(Ids)."); "shape of Input(Ids)."));
ctx->SetOutputDim("Out", ids_dims); ctx->SetOutputDim("Out", ids_dims);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册