未验证 提交 2cf27260 编写于 作者: Z Zhen Wang 提交者: GitHub

OP(fake_quantize) error message enhancement (#23550)

* improve error messages of fake_quantize op. test=develop

* update the bit_length error info. test=develop
上级 1cf64e00
...@@ -180,12 +180,11 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel { ...@@ -180,12 +180,11 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FakeQuantizeAbsMax");
"Input(X) of FakeQuantizeOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
PADDLE_ENFORCE(ctx->HasOutput("Out"), "FakeQuantizeAbsMax");
"Output(Out) of FakeQuantizeOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
PADDLE_ENFORCE(ctx->HasOutput("OutScale"), "FakeQuantizeAbsMax");
"Output(Scale) of FakeQuantizeOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->SetOutputDim("OutScale", {1}); ctx->SetOutputDim("OutScale", {1});
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
...@@ -211,8 +210,11 @@ class FakeQuantizeAbsMaxOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -211,8 +210,11 @@ class FakeQuantizeAbsMaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("bit_length", "(int, default 8)") AddAttr<int>("bit_length", "(int, default 8)")
.SetDefault(8) .SetDefault(8)
.AddCustomChecker([](const int& bit_length) { .AddCustomChecker([](const int& bit_length) {
PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16, PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
"'bit_length' should be between 1 and 16."); platform::errors::InvalidArgument(
"'bit_length' should be between 1 and 16, but "
"the received is %d",
bit_length));
}); });
AddComment(R"DOC( AddComment(R"DOC(
FakeQuantize operator FakeQuantize operator
...@@ -230,14 +232,12 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel { ...@@ -230,14 +232,12 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
"Input(X) of FakeChannelWiseQuantizeOp should not be null."); "FakeChannelWiseQuantizeAbsMax");
PADDLE_ENFORCE( OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
ctx->HasOutput("Out"), "FakeChannelWiseQuantizeAbsMax");
"Output(Out) of FakeChannelWiseQuantizeOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
PADDLE_ENFORCE( "FakeChannelWiseQuantizeAbsMax");
ctx->HasOutput("OutScale"),
"Output(Scale) of FakeChannelWiseQuantizeOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[0]}); ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[0]});
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
...@@ -263,8 +263,11 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker ...@@ -263,8 +263,11 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
AddAttr<int>("bit_length", "(int, default 8)") AddAttr<int>("bit_length", "(int, default 8)")
.SetDefault(8) .SetDefault(8)
.AddCustomChecker([](const int& bit_length) { .AddCustomChecker([](const int& bit_length) {
PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16, PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
"'bit_length' should be between 1 and 16."); platform::errors::InvalidArgument(
"'bit_length' should be between 1 and 16, but "
"the received is %d",
bit_length));
}); });
AddComment(R"DOC( AddComment(R"DOC(
The scale of FakeChannelWiseQuantize operator is a vector. The scale of FakeChannelWiseQuantize operator is a vector.
...@@ -288,14 +291,11 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel { ...@@ -288,14 +291,11 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FakeQuantizeRangeAbsMax");
"Input(X) of FakeQuantizeRangeAbsMaxOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
PADDLE_ENFORCE( "FakeQuantizeRangeAbsMax");
ctx->HasOutput("Out"), OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
"Output(Out) of FakeQuantizeRangeAbsMaxOp should not be null."); "FakeQuantizeRangeAbsMax");
PADDLE_ENFORCE(
ctx->HasOutput("OutScale"),
"Output(OutScale) of FakeQuantizeRangeAbsMaxOp should not be null");
if (ctx->HasOutput("OutScales")) { if (ctx->HasOutput("OutScales")) {
int window_size = ctx->Attrs().Get<int>("window_size"); int window_size = ctx->Attrs().Get<int>("window_size");
ctx->SetOutputDim("OutScales", {window_size}); ctx->SetOutputDim("OutScales", {window_size});
...@@ -329,8 +329,11 @@ class FakeQuantizeRangeAbsMaxOpMaker ...@@ -329,8 +329,11 @@ class FakeQuantizeRangeAbsMaxOpMaker
AddAttr<int>("bit_length", "(int, default 8), quantization bit number.") AddAttr<int>("bit_length", "(int, default 8), quantization bit number.")
.SetDefault(8) .SetDefault(8)
.AddCustomChecker([](const int& bit_length) { .AddCustomChecker([](const int& bit_length) {
PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16, PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
"'bit_length' should be between 1 and 16."); platform::errors::InvalidArgument(
"'bit_length' should be between 1 and 16, but "
"the received is %d",
bit_length));
}); });
AddAttr<bool>("is_test", AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false " "(bool, default false) Set to true for inference only, false "
...@@ -357,16 +360,12 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp ...@@ -357,16 +360,12 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
"Input(X) of FakeQuantOrWithDequantMovingAverageAbsMaxOp " "FakeQuantOrWithDequantMovingAverageAbsMax");
"should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
PADDLE_ENFORCE(ctx->HasOutput("Out"), "FakeQuantOrWithDequantMovingAverageAbsMax");
"Output(Out) of FakeQuantOrWithDequantMovingAverageAbsMaxOp " OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
"should not be null."); "FakeQuantOrWithDequantMovingAverageAbsMax");
PADDLE_ENFORCE(
ctx->HasOutput("OutScale"),
"Output(OutScale) of FakeQuantOrWithDequantMovingAverageAbsMaxOp "
"should not be null");
if (ctx->HasOutput("OutState")) { if (ctx->HasOutput("OutState")) {
ctx->SetOutputDim("OutState", {1}); ctx->SetOutputDim("OutState", {1});
} }
...@@ -404,8 +403,11 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker ...@@ -404,8 +403,11 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker
AddAttr<int>("bit_length", "(int, default 8), quantization bit number.") AddAttr<int>("bit_length", "(int, default 8), quantization bit number.")
.SetDefault(8) .SetDefault(8)
.AddCustomChecker([](const int& bit_length) { .AddCustomChecker([](const int& bit_length) {
PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16, PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
"'bit_length' should be between 1 and 16."); platform::errors::InvalidArgument(
"'bit_length' should be between 1 and 16, but "
"the received is %d",
bit_length));
}); });
AddAttr<bool>("is_test", AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false " "(bool, default false) Set to true for inference only, false "
...@@ -434,15 +436,12 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel { ...@@ -434,15 +436,12 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE( OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
ctx->HasInput("X"), "MovingAverageAbsMaxScale");
"Input(X) of MovingAverageAbsMaxScaleOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
PADDLE_ENFORCE( "MovingAverageAbsMaxScale");
ctx->HasOutput("Out"), OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
"Output(Out) of MovingAverageAbsMaxScaleOp should not be null."); "MovingAverageAbsMaxScale");
PADDLE_ENFORCE(ctx->HasOutput("OutScale"),
"Output(OutScale) of MovingAverageAbsMaxScaleOp"
"should not be null");
if (ctx->HasOutput("OutState")) { if (ctx->HasOutput("OutState")) {
ctx->SetOutputDim("OutState", {1}); ctx->SetOutputDim("OutState", {1});
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册