未验证 提交 2c4b57e9 编写于 作者: G GaoWei8 提交者: GitHub

Op (concat) error message enhancement (#23523)

上级 45880f60
......@@ -30,18 +30,17 @@ class ConcatOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
"Inputs(X) of ConcatOp should not be empty.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of ConcatOp should not be null.");
OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "Concat");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Concat");
auto inputs_dims = ctx->GetInputsDim("X");
const size_t inputs_num = inputs_dims.size();
PADDLE_ENFORCE_GT(inputs_num, 0,
"ShapeError: Input tensors count should > 0. But "
"recevied inputs' length is 0.");
PADDLE_ENFORCE_GT(
inputs_num, static_cast<size_t>(0),
platform::errors::InvalidArgument(
"The number of input tensors in concat op should > 0. But "
"received inputs' length is 0."));
if (inputs_num == 1) {
VLOG(3) << "Warning: concat op have only one input, may waste memory";
}
......
......@@ -49,10 +49,11 @@ static inline framework::DDim ComputeAndCheckShape(
// check all shape in run time
PADDLE_ENFORCE_EQ(
inputs_dims[0][j], inputs_dims[i][j],
"ShapeError: Dimension %d in inputs' shapes must be equal. "
"But recevied input[0]'s shape = "
platform::errors::InvalidArgument(
"The shape of input[%d] must be equal to input[0]. "
"But received input[0]'s shape = "
"[%s], input[%d]'s shape = [%s].",
j, inputs_dims[0], i, inputs_dims[i]);
i, inputs_dims[0], i, inputs_dims[i]));
}
}
}
......@@ -78,7 +79,9 @@ class ConcatKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<framework::LoDTensor>("X");
framework::LoDTensor* out = ctx.Output<framework::LoDTensor>("Out");
PADDLE_ENFORCE_EQ(ins[0] != nullptr, true, "The input should not be null.");
PADDLE_ENFORCE_NOT_NULL(
ins[0], platform::errors::NotFound(
" The first input of concat should not be null."));
auto axis = ctx.Attr<int>("axis");
bool need_resize_out_dims = false;
if (ctx.HasInput("AxisTensor")) {
......@@ -178,7 +181,9 @@ class ConcatGradKernel : public framework::OpKernel<T> {
}
}
}
PADDLE_ENFORCE_EQ(ins[0] != nullptr, true, "The input should not be null.");
PADDLE_ENFORCE_NOT_NULL(
ins[0], platform::errors::NotFound(
"The first input of concat should not be null."));
auto axis = ctx.Attr<int>("axis");
if (ctx.HasInput("AxisTensor")) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册