未验证 提交 e1fb77d1 编写于 作者: R ruri 提交者: GitHub

[2.0RC]refine error message in shuffle channel OP (#27505)

* refine err msg in shuffle channel op
上级 42363674
......@@ -21,13 +21,13 @@ class ShuffleChannelOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ShuffleChannelOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ShuffleChannelOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ShuffleChannelOp");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ShuffleChannelOp");
auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
PADDLE_ENFORCE_EQ(
input_dims.size(), 4,
platform::errors::InvalidArgument("The layout of input is NCHW."));
ctx->SetOutputDim("Out", input_dims);
}
......@@ -53,7 +53,8 @@ class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("group", "the number of groups.")
.SetDefault(1)
.AddCustomChecker([](const int& group) {
PADDLE_ENFORCE_GE(group, 1, "group should be larger than 0.");
PADDLE_ENFORCE_GE(group, 1, platform::errors::InvalidArgument(
"group should be larger than 0."));
});
AddComment(R"DOC(
......@@ -76,7 +77,9 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
auto input_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
PADDLE_ENFORCE_EQ(
input_dims.size(), 4,
platform::errors::InvalidArgument("The layout of input is NCHW."));
ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册