未验证 提交 52979565 编写于 作者: Y Yuan Shuai 提交者: GitHub

API/OP (clip_by_norm/clip) error message enhancement (#23603)

* error message enhance for clip_by_norm. test=develop

* fix clip_by_norm. test=develop

* fix clip error message. test=develop
上级 b4b6763a
......@@ -67,7 +67,10 @@ class ClipByNormKernel : public framework::OpKernel<T> {
framework::ToTypeName(in_var->Type()));
}
PADDLE_ENFORCE_NOT_NULL(input);
PADDLE_ENFORCE_NOT_NULL(input,
platform::errors::InvalidArgument(
"Input(X) of ClipByNormOp should not be null. "
"Please check if it is created correctly."));
auto x = EigenVector<T>::Flatten(*input);
auto out = EigenVector<T>::Flatten(*output);
......@@ -89,12 +92,19 @@ class ClipByNormOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ClipByNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ClipByNormOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of ClipByNormOp should not be null. Please "
"check if it is created correctly."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of ClipByNormOp should not be null. "
"Please check if it is created correctly."));
auto max_norm = ctx->Attrs().Get<float>("max_norm");
PADDLE_ENFORCE_GT(max_norm, 0, "max_norm should be greater than 0.");
PADDLE_ENFORCE_GT(max_norm, 0, platform::errors::InvalidArgument(
"max_norm should be greater than 0. "
"Received max_norm is %f.",
max_norm));
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
......
......@@ -23,14 +23,21 @@ class ClipOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ClipOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ClipOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of ClipOp should not be null. Please check "
"if it is created correctly."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of ClipOp should not be null. Please "
"check if it is created correctly."));
auto x_dims = ctx->GetInputDim("X");
auto max = ctx->Attrs().Get<float>("max");
auto min = ctx->Attrs().Get<float>("min");
PADDLE_ENFORCE_LT(min, max, "max should be greater than min.");
PADDLE_ENFORCE_LT(min, max, platform::errors::InvalidArgument(
"Max of ClipOp should be greater than min. "
"Received max is %f, received min is %f.",
max, min));
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
......@@ -52,7 +59,7 @@ class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
Clip Operator.
The clip operator limits the value of given input within an interval [min, max],
The clip operator limits the value of given input within an interval [min, max],
just as the following equation,
$$
......@@ -68,9 +75,14 @@ class ClipOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::InvalidArgument("Input(X) should not be null. Please "
"check if it is created correctly."));
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::InvalidArgument(
"Input(Out@GRAD) should not be null. Please check if "
"it is created correctly."));
auto x_dims = ctx->GetInputDim("X");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册