未验证 提交 52979565 编写于 作者: Y Yuan Shuai 提交者: GitHub

API/OP (clip_by_norm/clip) error message enhancement (#23603)

* error message enhance for clip_by_norm. test=develop

* fix clip_by_norm. test=develop

* fix clip error message. test=develop
上级 b4b6763a
......@@ -67,7 +67,10 @@ class ClipByNormKernel : public framework::OpKernel<T> {
framework::ToTypeName(in_var->Type()));
}
PADDLE_ENFORCE_NOT_NULL(input);
PADDLE_ENFORCE_NOT_NULL(input,
platform::errors::InvalidArgument(
"Input(X) of ClipByNormOp should not be null. "
"Please check if it is created correctly."));
auto x = EigenVector<T>::Flatten(*input);
auto out = EigenVector<T>::Flatten(*output);
......@@ -89,12 +92,19 @@ class ClipByNormOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ClipByNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ClipByNormOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of ClipByNormOp should not be null. Please "
"check if it is created correctly."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of ClipByNormOp should not be null. "
"Please check if it is created correctly."));
auto max_norm = ctx->Attrs().Get<float>("max_norm");
PADDLE_ENFORCE_GT(max_norm, 0, "max_norm should be greater than 0.");
PADDLE_ENFORCE_GT(max_norm, 0, platform::errors::InvalidArgument(
"max_norm should be greater than 0. "
"Received max_norm is %f.",
max_norm));
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
......
......@@ -23,14 +23,21 @@ class ClipOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ClipOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ClipOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of ClipOp should not be null. Please check "
"if it is created correctly."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of ClipOp should not be null. Please "
"check if it is created correctly."));
auto x_dims = ctx->GetInputDim("X");
auto max = ctx->Attrs().Get<float>("max");
auto min = ctx->Attrs().Get<float>("min");
PADDLE_ENFORCE_LT(min, max, "max should be greater than min.");
PADDLE_ENFORCE_LT(min, max, platform::errors::InvalidArgument(
"Max of ClipOp should be greater than min. "
"Received max is %f, received min is %f.",
max, min));
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
......@@ -68,9 +75,14 @@ class ClipOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::InvalidArgument("Input(X) should not be null. Please "
"check if it is created correctly."));
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::InvalidArgument(
"Input(Out@GRAD) should not be null. Please check if "
"it is created correctly."));
auto x_dims = ctx->GetInputDim("X");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
......
......@@ -11650,6 +11650,8 @@ def clip_by_norm(x, max_norm, name=None):
"""
helper = LayerHelper("clip_by_norm", **locals())
check_variable_and_dtype(x, 'X', ['float32'], 'clip_by_norm')
check_type(max_norm, 'max_norm', (float), 'clip_by_norm')
if name is None:
name = unique_name.generate_with_ignorable_key(".".join(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册