未验证 提交 638d924d 编写于 作者: Z zhangchunle 提交者: GitHub

Op (FusionSquaredMatSub) error message enhancement. (#23498)

上级 660489ac
...@@ -38,12 +38,6 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel { ...@@ -38,12 +38,6 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
"The dim size of the input tensor 'W' should be 2. " "The dim size of the input tensor 'W' should be 2. "
"But received W's size = %d.", "But received W's size = %d.",
table_dims.size())); table_dims.size()));
PADDLE_ENFORCE_GE(
ids_dims.size(), 1,
platform::errors::InvalidArgument(
"The dim size of the input tensor 'Ids' should be greater "
"than or equal to 1. But received Ids's size = %d.",
ids_dims.size()));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ids_dims[ids_dims.size() - 1], 1, ids_dims[ids_dims.size() - 1], 1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
......
...@@ -22,29 +22,34 @@ namespace operators { ...@@ -22,29 +22,34 @@ namespace operators {
void FusionSquaredMatSubOp::InferShape( void FusionSquaredMatSubOp::InferShape(
framework::InferShapeContext* ctx) const { framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusionSquaredMatSub");
"Input(X) of FusionSquaredMatSubOp should not be null."); OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "FusionSquaredMatSub");
PADDLE_ENFORCE(ctx->HasInput("Y"), OP_INOUT_CHECK(ctx->HasOutput("SquaredX"), "SquaredX", "Out",
"Input(Y) of FusionSquaredMatSubOp should not be null."); "FusionSquaredMatSub");
PADDLE_ENFORCE( OP_INOUT_CHECK(ctx->HasOutput("SquaredY"), "SquaredY", "Out",
ctx->HasOutput("SquaredX"), "FusionSquaredMatSub");
"Output(SquaredX) of FusionSquaredMatSubOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("SquaredXY"), "SquaredXY", "Out",
PADDLE_ENFORCE( "FusionSquaredMatSub");
ctx->HasOutput("SquaredY"), OP_INOUT_CHECK(ctx->HasOutput("Out"), "Out", "Out", "FusionSquaredMatSub");
"Output(SquaredY) of FusionSquaredMatSubOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("SquaredXY"),
"Output(SquaredXY) of FusionSquaredMatSubOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FusionSquaredMatSubOp should not be null.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y"); auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(), PADDLE_ENFORCE_EQ(
"Input tensors dims size should be equal."); x_dims.size(), y_dims.size(),
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input tensors should be a Matrix."); platform::errors::InvalidArgument("The input tensor X's dims size should "
PADDLE_ENFORCE_EQ(x_dims[1], y_dims[0], "Inputs Matrix should be multiply."); "be equal to Y's. But received X's "
"dims size = %d, Y's dims size = %d.",
x_dims.size(), y_dims.size()));
PADDLE_ENFORCE_EQ(x_dims.size(), 2,
platform::errors::InvalidArgument(
"The input tensor X's dims size should be 2. But "
"received X's dims size = %d.",
x_dims.size()));
PADDLE_ENFORCE_EQ(
x_dims[1], y_dims[0],
platform::errors::InvalidArgument("The input tensor X's dims[1] should "
"be equal to Y's dims[0]. But received "
"X's dims[1] = %d, Y's dims[0] = %d.",
x_dims[1], y_dims[0]));
ctx->SetOutputDim("SquaredX", x_dims); ctx->SetOutputDim("SquaredX", x_dims);
ctx->SetOutputDim("SquaredY", y_dims); ctx->SetOutputDim("SquaredY", y_dims);
ctx->SetOutputDim("SquaredXY", {x_dims[0], y_dims[1]}); ctx->SetOutputDim("SquaredXY", {x_dims[0], y_dims[1]});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册