From 638d924d89f860c386778c2055297b80d9ba6cb9 Mon Sep 17 00:00:00 2001 From: zhangchunle Date: Tue, 7 Apr 2020 11:30:22 +0800 Subject: [PATCH] Op (FusionSquaredMatSub) error message enhancement. (#23498) --- .../fused/fused_embedding_seq_pool_op.cc | 6 --- .../fused/fusion_squared_mat_sub_op.cc | 47 ++++++++++--------- 2 files changed, 26 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc index 8ebeb9cd26d..2db2ed09728 100644 --- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc +++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc @@ -38,12 +38,6 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel { "The dim size of the input tensor 'W' should be 2. " "But received W's size = %d.", table_dims.size())); - PADDLE_ENFORCE_GE( - ids_dims.size(), 1, - platform::errors::InvalidArgument( - "The dim size of the input tensor 'Ids' should be greater " - "than or equal to 1. But received Ids's size = %d.", - ids_dims.size())); PADDLE_ENFORCE_EQ( ids_dims[ids_dims.size() - 1], 1, platform::errors::InvalidArgument( diff --git a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc index 2d4a3977980..870f72b8c7f 100644 --- a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc +++ b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc @@ -22,29 +22,34 @@ namespace operators { void FusionSquaredMatSubOp::InferShape( framework::InferShapeContext* ctx) const { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of FusionSquaredMatSubOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Y"), - "Input(Y) of FusionSquaredMatSubOp should not be null."); - PADDLE_ENFORCE( - ctx->HasOutput("SquaredX"), - "Output(SquaredX) of FusionSquaredMatSubOp should not be null."); - PADDLE_ENFORCE( - ctx->HasOutput("SquaredY"), - "Output(SquaredY) of FusionSquaredMatSubOp should not be null."); - PADDLE_ENFORCE( - ctx->HasOutput("SquaredXY"), - "Output(SquaredXY) of FusionSquaredMatSubOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of FusionSquaredMatSubOp should not be null."); - + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusionSquaredMatSub"); + OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "FusionSquaredMatSub"); + OP_INOUT_CHECK(ctx->HasOutput("SquaredX"), "SquaredX", "Out", + "FusionSquaredMatSub"); + OP_INOUT_CHECK(ctx->HasOutput("SquaredY"), "SquaredY", "Out", + "FusionSquaredMatSub"); + OP_INOUT_CHECK(ctx->HasOutput("SquaredXY"), "SquaredXY", "Out", + "FusionSquaredMatSub"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Out", "Out", "FusionSquaredMatSub"); auto x_dims = ctx->GetInputDim("X"); auto y_dims = ctx->GetInputDim("Y"); - PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(), - "Input tensors dims size should be equal."); - PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input tensors should be a Matrix."); - PADDLE_ENFORCE_EQ(x_dims[1], y_dims[0], "Inputs Matrix should be multiply."); - + PADDLE_ENFORCE_EQ( + x_dims.size(), y_dims.size(), + platform::errors::InvalidArgument("The input tensor X's dims size should " + "be equal to Y's. But received X's " + "dims size = %d, Y's dims size = %d.", + x_dims.size(), y_dims.size())); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, + platform::errors::InvalidArgument( + "The input tensor X's dims size should be 2. But " + "received X's dims size = %d.", + x_dims.size())); + PADDLE_ENFORCE_EQ( + x_dims[1], y_dims[0], + platform::errors::InvalidArgument("The input tensor X's dims[1] should " + "be equal to Y's dims[0]. But received " + "X's dims[1] = %d, Y's dims[0] = %d.", + x_dims[1], y_dims[0])); ctx->SetOutputDim("SquaredX", x_dims); ctx->SetOutputDim("SquaredY", y_dims); ctx->SetOutputDim("SquaredXY", {x_dims[0], y_dims[1]}); -- GitLab