未验证 提交 2c4b57e9 编写于 作者: G GaoWei8 提交者: GitHub

Op (concat) error message enhancement (#23523)

上级 45880f60
...@@ -30,18 +30,17 @@ class ConcatOp : public framework::OperatorWithKernel { ...@@ -30,18 +30,17 @@ class ConcatOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL, OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "Concat");
"Inputs(X) of ConcatOp should not be empty."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Concat");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of ConcatOp should not be null.");
auto inputs_dims = ctx->GetInputsDim("X"); auto inputs_dims = ctx->GetInputsDim("X");
const size_t inputs_num = inputs_dims.size(); const size_t inputs_num = inputs_dims.size();
PADDLE_ENFORCE_GT(inputs_num, 0, PADDLE_ENFORCE_GT(
"ShapeError: Input tensors count should > 0. But " inputs_num, static_cast<size_t>(0),
"recevied inputs' length is 0."); platform::errors::InvalidArgument(
"The number of input tensors in concat op should > 0. But "
"received inputs' length is 0."));
if (inputs_num == 1) { if (inputs_num == 1) {
VLOG(3) << "Warning: concat op have only one input, may waste memory"; VLOG(3) << "Warning: concat op have only one input, may waste memory";
} }
......
...@@ -49,10 +49,11 @@ static inline framework::DDim ComputeAndCheckShape( ...@@ -49,10 +49,11 @@ static inline framework::DDim ComputeAndCheckShape(
// check all shape in run time // check all shape in run time
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
inputs_dims[0][j], inputs_dims[i][j], inputs_dims[0][j], inputs_dims[i][j],
"ShapeError: Dimension %d in inputs' shapes must be equal. " platform::errors::InvalidArgument(
"But recevied input[0]'s shape = " "The shape of input[%d] must be equal to input[0]. "
"[%s], input[%d]'s shape = [%s].", "But received input[0]'s shape = "
j, inputs_dims[0], i, inputs_dims[i]); "[%s], input[%d]'s shape = [%s].",
i, inputs_dims[0], i, inputs_dims[i]));
} }
} }
} }
...@@ -78,7 +79,9 @@ class ConcatKernel : public framework::OpKernel<T> { ...@@ -78,7 +79,9 @@ class ConcatKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<framework::LoDTensor>("X"); auto ins = ctx.MultiInput<framework::LoDTensor>("X");
framework::LoDTensor* out = ctx.Output<framework::LoDTensor>("Out"); framework::LoDTensor* out = ctx.Output<framework::LoDTensor>("Out");
PADDLE_ENFORCE_EQ(ins[0] != nullptr, true, "The input should not be null."); PADDLE_ENFORCE_NOT_NULL(
ins[0], platform::errors::NotFound(
" The first input of concat should not be null."));
auto axis = ctx.Attr<int>("axis"); auto axis = ctx.Attr<int>("axis");
bool need_resize_out_dims = false; bool need_resize_out_dims = false;
if (ctx.HasInput("AxisTensor")) { if (ctx.HasInput("AxisTensor")) {
...@@ -178,7 +181,9 @@ class ConcatGradKernel : public framework::OpKernel<T> { ...@@ -178,7 +181,9 @@ class ConcatGradKernel : public framework::OpKernel<T> {
} }
} }
} }
PADDLE_ENFORCE_EQ(ins[0] != nullptr, true, "The input should not be null."); PADDLE_ENFORCE_NOT_NULL(
ins[0], platform::errors::NotFound(
"The first input of concat should not be null."));
auto axis = ctx.Attr<int>("axis"); auto axis = ctx.Attr<int>("axis");
if (ctx.HasInput("AxisTensor")) { if (ctx.HasInput("AxisTensor")) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册