Unverified commit 10edacec, authored by ruri, committed by GitHub

test=release/1.8 , Fix err message (#24507) (#24540)

* fix error message, test=develop
Parent 37861815
@@ -20,17 +20,27 @@ class PixelShuffleOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of PixelShuffleOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of PixelShuffleOp should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                      platform::errors::NotFound(
+                          "Input(X) of PixelShuffleOp should not be null."));
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+                      platform::errors::NotFound(
+                          "Output(Out) of PixelShuffleOp should not be null."));
 
     auto input_dims = ctx->GetInputDim("X");
-    PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
+    PADDLE_ENFORCE_EQ(
+        input_dims.size(), 4,
+        platform::errors::InvalidArgument(
+            "Input should be a 4-D tensor of format [N, C, H, W], but got %u.",
+            input_dims.size()));
 
     auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");
 
-    PADDLE_ENFORCE(input_dims[1] % (upscale_factor * upscale_factor) == 0,
-                   "Upscale_factor should devide the number of channel");
+    PADDLE_ENFORCE_EQ(input_dims[1] % (upscale_factor * upscale_factor), 0,
+                      platform::errors::InvalidArgument(
+                          "The square of upscale_factor[%u] should divide the "
+                          "number of channel[%u]",
+                          input_dims[1], upscale_factor * upscale_factor));
 
     auto output_dims = input_dims;
     output_dims[0] = input_dims[0];
@@ -57,7 +67,8 @@ class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(1)
         .AddCustomChecker([](const int& upscale_factor) {
           PADDLE_ENFORCE_GE(upscale_factor, 1,
-                            "upscale_factor should be larger than 0.");
+                            platform::errors::InvalidArgument(
+                                "upscale_factor should be larger than 0."));
         });
 
     AddComment(R"DOC(
@@ -95,13 +106,19 @@ class PixelShuffleGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input(Out@Grad) should not be null");
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
-                   "Output(X@Grad) should not be null");
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput(framework::GradVarName("Out")), true,
+        platform::errors::NotFound("Input(Out@Grad) should not be null"));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutput(framework::GradVarName("X")), true,
+        platform::errors::NotFound("Output(X@Grad) should not be null"));
 
     auto do_dims = ctx->GetInputDim(framework::GradVarName("Out"));
-    PADDLE_ENFORCE(do_dims.size() == 4, "The layout of input is NCHW.");
+    PADDLE_ENFORCE_EQ(
+        do_dims.size(), 4,
+        platform::errors::InvalidArgument(
+            "Input should be a 4-D tensor of format [N, C, H, W], but got %u.",
+            do_dims.size()));
 
     auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");
...
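
Editor's note (illustration, not part of this diff): a minimal sketch of how the tightened InferShape checks above would surface from the Python side, assuming the release/1.8 static-graph API (fluid.data and fluid.layers.pixel_shuffle) and an illustrative input whose channel count is not divisible by the square of upscale_factor; the exact exception type and wording may differ by build.

# Hypothetical repro sketch; names, shapes, and error surfacing are illustrative.
import paddle.fluid as fluid

# 9 channels cannot be divided by upscale_factor^2 = 4, so PixelShuffleOp's
# InferShape check fails and should now report the InvalidArgument message
# added in this commit instead of the old terse one.
x = fluid.data(name='x', shape=[2, 9, 4, 4], dtype='float32')
try:
    y = fluid.layers.pixel_shuffle(x, upscale_factor=2)
except Exception as err:
    print(err)  # expected to mention the upscale_factor / channel mismatch
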
@@ -14658,6 +14658,7 @@ def pixel_shuffle(x, upscale_factor):
     """
 
+    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'pixel_shuffle')
     helper = LayerHelper("pixel_shuffle", **locals())
 
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
...
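
Editor's note (illustration, not part of this diff): with the new check_variable_and_dtype guard, an unsupported input dtype should be rejected on the Python side before the operator is created. A hedged sketch, again assuming the release/1.8 fluid API; the exception type and message text are paraphrased, not taken from this commit.

# Hypothetical sketch; the dtype check is expected to reject non-float inputs early.
import paddle.fluid as fluid

x_int = fluid.data(name='x_int', shape=[2, 4, 3, 3], dtype='int32')
try:
    y = fluid.layers.pixel_shuffle(x_int, upscale_factor=2)
except Exception as err:
    # expected to be a dtype error stating that 'x' must be float32 or float64
    print(err)
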