提交 02bdfdba 编写于 作者: X xiemoyuan

Optimize the error message of GatherTreeOP.

上级 562945c9
......@@ -22,19 +22,16 @@ class GatherTreeOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Ids"), true,
"Input(Ids) of GatherTreeOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Parents"), true,
"Input(Parents) of GatherTreeOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of GatherTreeOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("Ids"), "Input", "Ids", "GatherTree");
OP_INOUT_CHECK(ctx->HasInput("Parents"), "Input", "Parents", "GatherTree");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GatherTree");
auto ids_dims = ctx->GetInputDim("Ids");
auto parents_dims = ctx->GetInputDim("Parents");
PADDLE_ENFORCE_EQ(
ids_dims == parents_dims, true,
"The shape of Input(Parents) must be same with the shape of "
"Input(Ids).");
PADDLE_ENFORCE_EQ(ids_dims == parents_dims, true,
platform::errors::InvalidArgument(
"The shape of Input(Parents) must be same with the "
"shape of Input(Ids)."));
ctx->SetOutputDim("Out", ids_dims);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册