未验证 提交 638d924d 编写于 作者: Z zhangchunle 提交者: GitHub

Op (FusionSquaredMatSub) error message enhancement. (#23498)

上级 660489ac
......@@ -38,12 +38,6 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
"The dim size of the input tensor 'W' should be 2. "
"But received W's size = %d.",
table_dims.size()));
PADDLE_ENFORCE_GE(
ids_dims.size(), 1,
platform::errors::InvalidArgument(
"The dim size of the input tensor 'Ids' should be greater "
"than or equal to 1. But received Ids's size = %d.",
ids_dims.size()));
PADDLE_ENFORCE_EQ(
ids_dims[ids_dims.size() - 1], 1,
platform::errors::InvalidArgument(
......
......@@ -22,29 +22,34 @@ namespace operators {
void FusionSquaredMatSubOp::InferShape(
framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FusionSquaredMatSubOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of FusionSquaredMatSubOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("SquaredX"),
"Output(SquaredX) of FusionSquaredMatSubOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("SquaredY"),
"Output(SquaredY) of FusionSquaredMatSubOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("SquaredXY"),
"Output(SquaredXY) of FusionSquaredMatSubOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FusionSquaredMatSubOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusionSquaredMatSub");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "FusionSquaredMatSub");
OP_INOUT_CHECK(ctx->HasOutput("SquaredX"), "SquaredX", "Out",
"FusionSquaredMatSub");
OP_INOUT_CHECK(ctx->HasOutput("SquaredY"), "SquaredY", "Out",
"FusionSquaredMatSub");
OP_INOUT_CHECK(ctx->HasOutput("SquaredXY"), "SquaredXY", "Out",
"FusionSquaredMatSub");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Out", "Out", "FusionSquaredMatSub");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
"Input tensors dims size should be equal.");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input tensors should be a Matrix.");
PADDLE_ENFORCE_EQ(x_dims[1], y_dims[0], "Inputs Matrix should be multiply.");
PADDLE_ENFORCE_EQ(
x_dims.size(), y_dims.size(),
platform::errors::InvalidArgument("The input tensor X's dims size should "
"be equal to Y's. But received X's "
"dims size = %d, Y's dims size = %d.",
x_dims.size(), y_dims.size()));
PADDLE_ENFORCE_EQ(x_dims.size(), 2,
platform::errors::InvalidArgument(
"The input tensor X's dims size should be 2. But "
"received X's dims size = %d.",
x_dims.size()));
PADDLE_ENFORCE_EQ(
x_dims[1], y_dims[0],
platform::errors::InvalidArgument("The input tensor X's dims[1] should "
"be equal to Y's dims[0]. But received "
"X's dims[1] = %d, Y's dims[0] = %d.",
x_dims[1], y_dims[0]));
ctx->SetOutputDim("SquaredX", x_dims);
ctx->SetOutputDim("SquaredY", y_dims);
ctx->SetOutputDim("SquaredXY", {x_dims[0], y_dims[1]});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册