Unverified commit 9cbf6013, authored by HappyAngel and committed by GitHub

[Cherry-Pick] [2.0-beta] c++ error enhancement (#24189)

Parent 53cb2206
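
The change applies one pattern throughout the commit: input/output presence checks move from a bare PADDLE_ENFORCE(cond, "fixed message") to OP_INOUT_CHECK(cond, role, name, op_type), and value checks move to PADDLE_ENFORCE_EQ / PADDLE_ENFORCE_GT with a platform::errors::* object (InvalidArgument here) whose printf-style message reports the offending value. The before/after shape, distilled from the first hunk below rather than new code:

// Before: boolean condition plus a fixed message string.
PADDLE_ENFORCE(
    ctx->HasInput("param"),
    "Input (param) of average_accumulates op should not be null.");

// After: role, variable name, and operator type are passed explicitly,
// so the framework can format a uniform "not found" error.
OP_INOUT_CHECK(ctx->HasInput("param"), "Input", "param",
               "AverageAccumulates");
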
......@@ -48,48 +48,33 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(
ctx->HasInput("param"),
"Input (param) of average_accumulates op should not be null.");
PADDLE_ENFORCE(
ctx->HasInput("in_sum_1"),
"Input (sum_1) of average_accumulates op should not be null.");
PADDLE_ENFORCE(
ctx->HasInput("in_sum_2"),
"Input (sum_2) of average_accumulates op should not be null.");
PADDLE_ENFORCE(
ctx->HasInput("in_sum_3"),
"Input (sum_3) of average_accumulates op should not be null.");
PADDLE_ENFORCE(
ctx->HasInput("in_num_accumulates"),
"Input (in_num_accumulates) of average_accumulates op should "
"not be null.");
PADDLE_ENFORCE(ctx->HasInput("in_old_num_accumulates"),
"Input (old_num_accumulates) of average_accumulates op "
"should not be null.");
PADDLE_ENFORCE(
ctx->HasInput("in_num_updates"),
"Input (num_updates) of average_accumulates op should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("out_sum_1"),
"Output (sum_1) of average_accumulates op should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("out_sum_2"),
"Output (sum_2) of average_accumulates op should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("out_sum_3"),
"Output (sum_3) of average_accumulates op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("out_num_accumulates"),
"Output (num_accumulates) of average_accumulates op should "
"not be null.");
PADDLE_ENFORCE(ctx->HasOutput("out_old_num_accumulates"),
"Output (old_num_accumulates) of average_accumulates op "
"should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("out_num_updates"),
"Output (num_updates) of average_accumulates op should not be null.");
OP_INOUT_CHECK(ctx->HasInput("param"), "Input", "param",
"AverageAccumulates");
OP_INOUT_CHECK(ctx->HasInput("in_sum_1"), "Input", "in_sum_1",
"AverageAccumulates");
OP_INOUT_CHECK(ctx->HasInput("in_sum_2"), "Input", "in_sum_2",
"AverageAccumulates");
OP_INOUT_CHECK(ctx->HasInput("in_sum_3"), "Input", "in_sum_3",
"AverageAccumulates");
OP_INOUT_CHECK(ctx->HasInput("in_num_accumulates"), "Input",
"in_num_accumulates", "AverageAccumulates");
OP_INOUT_CHECK(ctx->HasInput("in_old_num_accumulates"), "Input",
"in_old_num_accumulates", "AverageAccumulates");
OP_INOUT_CHECK(ctx->HasInput("in_num_updates"), "Input", "in_num_updates",
"AverageAccumulates");
OP_INOUT_CHECK(ctx->HasOutput("out_sum_1"), "Output", "out_sum_1",
"AverageAccumulates");
OP_INOUT_CHECK(ctx->HasOutput("out_sum_2"), "Output", "out_sum_2",
"AverageAccumulates");
OP_INOUT_CHECK(ctx->HasOutput("out_sum_3"), "Output", "out_sum_3",
"AverageAccumulates");
OP_INOUT_CHECK(ctx->HasOutput("out_num_accumulates"), "Output",
"out_num_accumulates", "AverageAccumulates");
OP_INOUT_CHECK(ctx->HasOutput("out_old_num_accumulates"), "Output",
"out_old_num_accumulates", "AverageAccumulates");
OP_INOUT_CHECK(ctx->HasOutput("out_num_updates"), "Output",
"out_num_updates", "AverageAccumulates");
auto in_dim = ctx->GetInputDim("param");
ctx->SetOutputDim("out_sum_1", in_dim);
......
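
Note: OP_INOUT_CHECK folds the role ("Input"/"Output"), the variable name, and the operator type into a single uniformly formatted error, which is why the fourteen hand-written messages above collapse into one-line checks. A rough sketch of the shape such a macro could take, given here as an assumption for illustration rather than Paddle's actual definition:

// Hypothetical expansion, for illustration only; the real macro lives in
// the Paddle framework headers and may differ in wording and error type.
#define OP_INOUT_CHECK(EXPR, ROLE, NAME, OP_NAME)              \
  PADDLE_ENFORCE_EQ((EXPR), true,                              \
                    platform::errors::NotFound(                \
                        "No %s(%s) found for operator (%s).",  \
                        (ROLE), (NAME), (OP_NAME)))
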
......@@ -23,21 +23,25 @@ class GetTensorFromSelectedRowsOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"GetTensorFromSelectedRowsOp must have input X.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"GetTensorFromSelectedRowsOp must have output Out.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("X").front() ==
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
"GetTensorFromSelectedRows");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
"GetTensorFromSelectedRows");
PADDLE_ENFORCE_EQ(
ctx->GetInputsVarType("X").front(),
framework::proto::VarType::SELECTED_ROWS,
"The input X's type should be SelectedRows, but the received is %s",
ctx->Inputs("X").front(), ctx->GetInputsVarType("X").front());
PADDLE_ENFORCE(
ctx->GetOutputsVarType("Out").front() ==
platform::errors::InvalidArgument(
"The input X(%s)'s type should be SelectedRows, "
"but the received is %s",
ctx->Inputs("X").front(), ctx->GetInputsVarType("X").front()));
PADDLE_ENFORCE_EQ(ctx->GetOutputsVarType("Out").front(),
framework::proto::VarType::LOD_TENSOR,
"The output Out's type should be LoDTensor, but the received is %s",
ctx->Outputs("Out").front(), ctx->GetOutputsVarType("Out").front());
platform::errors::InvalidArgument(
"The output Out(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx->Outputs("Out").front(),
ctx->GetOutputsVarType("Out").front()));
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
......
......@@ -96,7 +96,10 @@ class LoDTensorArray2TensorOp : public framework::OperatorBase {
*scope.FindVar(Output("OutIndex"))->GetMutable<framework::LoDTensor>();
const size_t n = inx.size();
PADDLE_ENFORCE_GT(n, 0, "Input tensorarray size should > 0.");
PADDLE_ENFORCE_GT(n, 0, platform::errors::InvalidArgument(
"Input tensorarray size should > 0,"
"but the received is %d",
n));
std::string base_name = Inputs("X")[0];
std::vector<std::string> names;
......@@ -232,7 +235,10 @@ class LoDTensorArray2TensorGradOp : public framework::OperatorBase {
auto &inx = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
const size_t n = inx.size();
PADDLE_ENFORCE_GT(n, 0, "Input tensorarray size should > 0.");
PADDLE_ENFORCE_GT(n, 0, platform::errors::InvalidArgument(
"Input tensorarray size should > 0, "
"but the received is: %d. ",
n));
std::string base_name = Inputs("X")[0];
std::vector<std::string> names;
......
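
To show why the new value checks read better in practice, here is a minimal, self-contained sketch of the message style the GT checks above move to. The macro and the errors namespace below are simplified stand-ins written for this note, not Paddle's PADDLE_ENFORCE_* or platform::errors implementation:

// Stand-alone illustration: a value check whose error text carries the
// received value, mirroring PADDLE_ENFORCE_GT(n, 0, InvalidArgument(...)).
#include <cstdio>
#include <stdexcept>
#include <string>
#include <vector>

namespace errors {
// Stand-in for platform::errors::InvalidArgument: printf-style formatting.
template <typename... Args>
std::string InvalidArgument(const char* fmt, Args... args) {
  char buf[256];
  std::snprintf(buf, sizeof(buf), fmt, args...);
  return std::string("InvalidArgumentError: ") + buf;
}
}  // namespace errors

// Stand-in for PADDLE_ENFORCE_GT(a, b, error): throws when !(a > b).
#define ENFORCE_GT(a, b, err)                        \
  do {                                               \
    if (!((a) > (b))) throw std::runtime_error(err); \
  } while (0)

int main() {
  std::vector<int> tensor_array;  // deliberately empty
  const size_t n = tensor_array.size();
  try {
    ENFORCE_GT(n, static_cast<size_t>(0),
               errors::InvalidArgument("Input tensorarray size should > 0, "
                                       "but the received is: %zu.", n));
  } catch (const std::runtime_error& e) {
    // Prints the formatted message, including the received size (0).
    std::printf("%s\n", e.what());
  }
  return 0;
}
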