未验证 提交 9cbf6013 编写于 作者: H HappyAngel 提交者: GitHub

[Cherry-Pick] [2.0-beta] c++ error enhancement (#24189)

上级 53cb2206
...@@ -48,48 +48,33 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel { ...@@ -48,48 +48,33 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE( OP_INOUT_CHECK(ctx->HasInput("param"), "Input", "param",
ctx->HasInput("param"), "AverageAccumulates");
"Input (param) of average_accumulates op should not be null."); OP_INOUT_CHECK(ctx->HasInput("in_sum_1"), "Input", "in_sum_1",
PADDLE_ENFORCE( "AverageAccumulates");
ctx->HasInput("in_sum_1"), OP_INOUT_CHECK(ctx->HasInput("in_sum_2"), "Input", "in_sum_2",
"Input (sum_1) of average_accumulates op should not be null."); "AverageAccumulates");
PADDLE_ENFORCE( OP_INOUT_CHECK(ctx->HasInput("in_sum_3"), "Input", "in_sum_3",
ctx->HasInput("in_sum_2"), "AverageAccumulates");
"Input (sum_2) of average_accumulates op should not be null."); OP_INOUT_CHECK(ctx->HasInput("in_num_accumulates"), "Input",
PADDLE_ENFORCE( "in_num_accumulates", "AverageAccumulates");
ctx->HasInput("in_sum_3"), OP_INOUT_CHECK(ctx->HasInput("in_old_num_accumulates"), "Input",
"Input (sum_3) of average_accumulates op should not be null."); "in_old_num_accumulates", "AverageAccumulates");
PADDLE_ENFORCE( OP_INOUT_CHECK(ctx->HasInput("in_num_updates"), "Input", "in_num_updates",
ctx->HasInput("in_num_accumulates"), "AverageAccumulates");
"Input (in_num_accumulates) of average_accumulates op should "
"not be null."); OP_INOUT_CHECK(ctx->HasOutput("out_sum_1"), "Output", "out_sum_1",
PADDLE_ENFORCE(ctx->HasInput("in_old_num_accumulates"), "AverageAccumulates");
"Input (old_num_accumulates) of average_accumulates op " OP_INOUT_CHECK(ctx->HasOutput("out_sum_2"), "Output", "out_sum_2",
"should not be null."); "AverageAccumulates");
PADDLE_ENFORCE( OP_INOUT_CHECK(ctx->HasOutput("out_sum_3"), "Output", "out_sum_3",
ctx->HasInput("in_num_updates"), "AverageAccumulates");
"Input (num_updates) of average_accumulates op should not be null."); OP_INOUT_CHECK(ctx->HasOutput("out_num_accumulates"), "Output",
"out_num_accumulates", "AverageAccumulates");
PADDLE_ENFORCE( OP_INOUT_CHECK(ctx->HasOutput("out_old_num_accumulates"), "Output",
ctx->HasOutput("out_sum_1"), "out_old_num_accumulates", "AverageAccumulates");
"Output (sum_1) of average_accumulates op should not be null."); OP_INOUT_CHECK(ctx->HasOutput("out_num_updates"), "Output",
PADDLE_ENFORCE( "out_num_updates", "AverageAccumulates");
ctx->HasOutput("out_sum_2"),
"Output (sum_2) of average_accumulates op should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("out_sum_3"),
"Output (sum_3) of average_accumulates op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("out_num_accumulates"),
"Output (num_accumulates) of average_accumulates op should "
"not be null.");
PADDLE_ENFORCE(ctx->HasOutput("out_old_num_accumulates"),
"Output (old_num_accumulates) of average_accumulates op "
"should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("out_num_updates"),
"Output (num_updates) of average_accumulates op should not be null.");
auto in_dim = ctx->GetInputDim("param"); auto in_dim = ctx->GetInputDim("param");
ctx->SetOutputDim("out_sum_1", in_dim); ctx->SetOutputDim("out_sum_1", in_dim);
......
...@@ -23,21 +23,25 @@ class GetTensorFromSelectedRowsOp : public framework::OperatorWithKernel { ...@@ -23,21 +23,25 @@ class GetTensorFromSelectedRowsOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
"GetTensorFromSelectedRowsOp must have input X."); "GetTensorFromSelectedRows");
PADDLE_ENFORCE(ctx->HasOutput("Out"), OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
"GetTensorFromSelectedRowsOp must have output Out."); "GetTensorFromSelectedRows");
PADDLE_ENFORCE(
ctx->GetInputsVarType("X").front() == PADDLE_ENFORCE_EQ(
framework::proto::VarType::SELECTED_ROWS, ctx->GetInputsVarType("X").front(),
"The input X's type should be SelectedRows, but the received is %s", framework::proto::VarType::SELECTED_ROWS,
ctx->Inputs("X").front(), ctx->GetInputsVarType("X").front()); platform::errors::InvalidArgument(
PADDLE_ENFORCE( "The input X(%s)'s type should be SelectedRows, "
ctx->GetOutputsVarType("Out").front() == "but the received is %s",
framework::proto::VarType::LOD_TENSOR, ctx->Inputs("X").front(), ctx->GetInputsVarType("X").front()));
"The output Out's type should be LoDTensor, but the received is %s", PADDLE_ENFORCE_EQ(ctx->GetOutputsVarType("Out").front(),
ctx->Outputs("Out").front(), ctx->GetOutputsVarType("Out").front()); framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The output Out(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx->Outputs("Out").front(),
ctx->GetOutputsVarType("Out").front()));
ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
} }
......
...@@ -96,7 +96,10 @@ class LoDTensorArray2TensorOp : public framework::OperatorBase { ...@@ -96,7 +96,10 @@ class LoDTensorArray2TensorOp : public framework::OperatorBase {
*scope.FindVar(Output("OutIndex"))->GetMutable<framework::LoDTensor>(); *scope.FindVar(Output("OutIndex"))->GetMutable<framework::LoDTensor>();
const size_t n = inx.size(); const size_t n = inx.size();
PADDLE_ENFORCE_GT(n, 0, "Input tensorarray size should > 0."); PADDLE_ENFORCE_GT(n, 0, platform::errors::InvalidArgument(
"Input tensorarray size should > 0,"
"but the received is %d",
n));
std::string base_name = Inputs("X")[0]; std::string base_name = Inputs("X")[0];
std::vector<std::string> names; std::vector<std::string> names;
...@@ -232,7 +235,10 @@ class LoDTensorArray2TensorGradOp : public framework::OperatorBase { ...@@ -232,7 +235,10 @@ class LoDTensorArray2TensorGradOp : public framework::OperatorBase {
auto &inx = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>(); auto &inx = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
const size_t n = inx.size(); const size_t n = inx.size();
PADDLE_ENFORCE_GT(n, 0, "Input tensorarray size should > 0."); PADDLE_ENFORCE_GT(n, 0, platform::errors::InvalidArgument(
"Input tensorarray size should > 0, "
"but the received is: %d. ",
n));
std::string base_name = Inputs("X")[0]; std::string base_name = Inputs("X")[0];
std::vector<std::string> names; std::vector<std::string> names;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册