From e1fb77d123898e3d1d0a6e42e68e9c91dbcda746 Mon Sep 17 00:00:00 2001 From: ruri Date: Thu, 24 Sep 2020 14:25:40 +0800 Subject: [PATCH] [2.0RC]refine error message in shuffle channel OP (#27505) * refine err msg in shuffle channel op --- paddle/fluid/operators/shuffle_channel_op.cc | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/shuffle_channel_op.cc b/paddle/fluid/operators/shuffle_channel_op.cc index c01fed108f3..119d2e72369 100644 --- a/paddle/fluid/operators/shuffle_channel_op.cc +++ b/paddle/fluid/operators/shuffle_channel_op.cc @@ -21,13 +21,13 @@ class ShuffleChannelOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of ShuffleChannelOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of ShuffleChannelOp should not be null."); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ShuffleChannelOp"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ShuffleChannelOp"); auto input_dims = ctx->GetInputDim("X"); - PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW."); + PADDLE_ENFORCE_EQ( + input_dims.size(), 4, + platform::errors::InvalidArgument("The layout of input is NCHW.")); ctx->SetOutputDim("Out", input_dims); } @@ -53,7 +53,8 @@ class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("group", "the number of groups.") .SetDefault(1) .AddCustomChecker([](const int& group) { - PADDLE_ENFORCE_GE(group, 1, "group should be larger than 0."); + PADDLE_ENFORCE_GE(group, 1, platform::errors::InvalidArgument( + "group should be larger than 0.")); }); AddComment(R"DOC( @@ -76,7 +77,9 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { auto input_dims = ctx->GetInputDim(framework::GradVarName("Out")); - PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW."); + PADDLE_ENFORCE_EQ( + input_dims.size(), 4, + platform::errors::InvalidArgument("The layout of input is NCHW.")); ctx->SetOutputDim(framework::GradVarName("X"), input_dims); } -- GitLab