未验证 提交 10edacec 编写于 作者: R ruri 提交者: GitHub

test=release/1.8 , Fix err message (#24507) (#24540)

* fix error message, test=develop
上级 37861815
......@@ -20,17 +20,27 @@ class PixelShuffleOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of PixelShuffleOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of PixelShuffleOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of PixelShuffleOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of PixelShuffleOp should not be null."));
auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
PADDLE_ENFORCE_EQ(
input_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W], but got %u.",
input_dims.size()));
auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");
PADDLE_ENFORCE(input_dims[1] % (upscale_factor * upscale_factor) == 0,
"Upscale_factor should devide the number of channel");
PADDLE_ENFORCE_EQ(input_dims[1] % (upscale_factor * upscale_factor), 0,
platform::errors::InvalidArgument(
"The square of upscale_factor[%u] should divide the "
"number of channel[%u]",
input_dims[1], upscale_factor * upscale_factor));
auto output_dims = input_dims;
output_dims[0] = input_dims[0];
......@@ -57,7 +67,8 @@ class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(1)
.AddCustomChecker([](const int& upscale_factor) {
PADDLE_ENFORCE_GE(upscale_factor, 1,
"upscale_factor should be larger than 0.");
platform::errors::InvalidArgument(
"upscale_factor should be larger than 0."));
});
AddComment(R"DOC(
......@@ -95,13 +106,19 @@ class PixelShuffleGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@Grad) should not be null");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@Grad) should not be null");
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::NotFound("Input(Out@Grad) should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::NotFound("Output(X@Grad) should not be null"));
auto do_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE(do_dims.size() == 4, "The layout of input is NCHW.");
PADDLE_ENFORCE_EQ(
do_dims.size(), 4,
platform::errors::InvalidArgument(
"Input should be a 4-D tensor of format [N, C, H, W], but got %u.",
do_dims.size()));
auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");
......
......@@ -14658,6 +14658,7 @@ def pixel_shuffle(x, upscale_factor):
"""
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'pixel_shuffle')
helper = LayerHelper("pixel_shuffle", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册