diff --git a/paddle/fluid/operators/shuffle_channel_op.cc b/paddle/fluid/operators/shuffle_channel_op.cc index c01fed108f3e2d7b399379520f6b0bec906f5d13..119d2e7236946e7243ef53c791f4bb7f48d21c91 100644 --- a/paddle/fluid/operators/shuffle_channel_op.cc +++ b/paddle/fluid/operators/shuffle_channel_op.cc @@ -21,13 +21,13 @@ class ShuffleChannelOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of ShuffleChannelOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of ShuffleChannelOp should not be null."); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ShuffleChannelOp"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ShuffleChannelOp"); auto input_dims = ctx->GetInputDim("X"); - PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW."); + PADDLE_ENFORCE_EQ( + input_dims.size(), 4, + platform::errors::InvalidArgument("The layout of input is NCHW.")); ctx->SetOutputDim("Out", input_dims); } @@ -53,7 +53,8 @@ class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("group", "the number of groups.") .SetDefault(1) .AddCustomChecker([](const int& group) { - PADDLE_ENFORCE_GE(group, 1, "group should be larger than 0."); + PADDLE_ENFORCE_GE(group, 1, platform::errors::InvalidArgument( + "group should be larger than 0.")); }); AddComment(R"DOC( @@ -76,7 +77,9 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { auto input_dims = ctx->GetInputDim(framework::GradVarName("Out")); - PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW."); + PADDLE_ENFORCE_EQ( + input_dims.size(), 4, + platform::errors::InvalidArgument("The layout of input is NCHW.")); ctx->SetOutputDim(framework::GradVarName("X"), input_dims); }