未验证 提交 e1fb77d1 编写于 作者: R ruri 提交者: GitHub

[2.0RC]refine error message in shuffle channel OP (#27505)

* refine err msg in shuffle channel op
上级 42363674
...@@ -21,13 +21,13 @@ class ShuffleChannelOp : public framework::OperatorWithKernel { ...@@ -21,13 +21,13 @@ class ShuffleChannelOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ShuffleChannelOp");
"Input(X) of ShuffleChannelOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ShuffleChannelOp");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ShuffleChannelOp should not be null.");
auto input_dims = ctx->GetInputDim("X"); auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW."); PADDLE_ENFORCE_EQ(
input_dims.size(), 4,
platform::errors::InvalidArgument("The layout of input is NCHW."));
ctx->SetOutputDim("Out", input_dims); ctx->SetOutputDim("Out", input_dims);
} }
...@@ -53,7 +53,8 @@ class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -53,7 +53,8 @@ class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("group", "the number of groups.") AddAttr<int>("group", "the number of groups.")
.SetDefault(1) .SetDefault(1)
.AddCustomChecker([](const int& group) { .AddCustomChecker([](const int& group) {
PADDLE_ENFORCE_GE(group, 1, "group should be larger than 0."); PADDLE_ENFORCE_GE(group, 1, platform::errors::InvalidArgument(
"group should be larger than 0."));
}); });
AddComment(R"DOC( AddComment(R"DOC(
...@@ -76,7 +77,9 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel { ...@@ -76,7 +77,9 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
auto input_dims = ctx->GetInputDim(framework::GradVarName("Out")); auto input_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW."); PADDLE_ENFORCE_EQ(
input_dims.size(), 4,
platform::errors::InvalidArgument("The layout of input is NCHW."));
ctx->SetOutputDim(framework::GradVarName("X"), input_dims); ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册