未验证 提交 2cf27260 编写于 作者: Z Zhen Wang 提交者: GitHub

OP(fake_quantize) error message enhancement (#23550)

* improve error messages of fake_quantize op. test=develop

* update the bit_length error info. test=develop
上级 1cf64e00
......@@ -180,12 +180,11 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FakeQuantizeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FakeQuantizeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("OutScale"),
"Output(Scale) of FakeQuantizeOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FakeQuantizeAbsMax");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
"FakeQuantizeAbsMax");
OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
"FakeQuantizeAbsMax");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->SetOutputDim("OutScale", {1});
ctx->ShareLoD("X", /*->*/ "Out");
......@@ -211,8 +210,11 @@ class FakeQuantizeAbsMaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("bit_length", "(int, default 8)")
.SetDefault(8)
.AddCustomChecker([](const int& bit_length) {
PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
"'bit_length' should be between 1 and 16.");
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
platform::errors::InvalidArgument(
"'bit_length' should be between 1 and 16, but "
"the received is %d",
bit_length));
});
AddComment(R"DOC(
FakeQuantize operator
......@@ -230,14 +232,12 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FakeChannelWiseQuantizeOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of FakeChannelWiseQuantizeOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("OutScale"),
"Output(Scale) of FakeChannelWiseQuantizeOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
"FakeChannelWiseQuantizeAbsMax");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
"FakeChannelWiseQuantizeAbsMax");
OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
"FakeChannelWiseQuantizeAbsMax");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[0]});
ctx->ShareLoD("X", /*->*/ "Out");
......@@ -263,8 +263,11 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
AddAttr<int>("bit_length", "(int, default 8)")
.SetDefault(8)
.AddCustomChecker([](const int& bit_length) {
PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
"'bit_length' should be between 1 and 16.");
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
platform::errors::InvalidArgument(
"'bit_length' should be between 1 and 16, but "
"the received is %d",
bit_length));
});
AddComment(R"DOC(
The scale of FakeChannelWiseQuantize operator is a vector.
......@@ -288,14 +291,11 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FakeQuantizeRangeAbsMaxOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of FakeQuantizeRangeAbsMaxOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("OutScale"),
"Output(OutScale) of FakeQuantizeRangeAbsMaxOp should not be null");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FakeQuantizeRangeAbsMax");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
"FakeQuantizeRangeAbsMax");
OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
"FakeQuantizeRangeAbsMax");
if (ctx->HasOutput("OutScales")) {
int window_size = ctx->Attrs().Get<int>("window_size");
ctx->SetOutputDim("OutScales", {window_size});
......@@ -329,8 +329,11 @@ class FakeQuantizeRangeAbsMaxOpMaker
AddAttr<int>("bit_length", "(int, default 8), quantization bit number.")
.SetDefault(8)
.AddCustomChecker([](const int& bit_length) {
PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
"'bit_length' should be between 1 and 16.");
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
platform::errors::InvalidArgument(
"'bit_length' should be between 1 and 16, but "
"the received is %d",
bit_length));
});
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
......@@ -357,16 +360,12 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FakeQuantOrWithDequantMovingAverageAbsMaxOp "
"should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FakeQuantOrWithDequantMovingAverageAbsMaxOp "
"should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("OutScale"),
"Output(OutScale) of FakeQuantOrWithDequantMovingAverageAbsMaxOp "
"should not be null");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
"FakeQuantOrWithDequantMovingAverageAbsMax");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
"FakeQuantOrWithDequantMovingAverageAbsMax");
OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
"FakeQuantOrWithDequantMovingAverageAbsMax");
if (ctx->HasOutput("OutState")) {
ctx->SetOutputDim("OutState", {1});
}
......@@ -404,8 +403,11 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker
AddAttr<int>("bit_length", "(int, default 8), quantization bit number.")
.SetDefault(8)
.AddCustomChecker([](const int& bit_length) {
PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
"'bit_length' should be between 1 and 16.");
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
platform::errors::InvalidArgument(
"'bit_length' should be between 1 and 16, but "
"the received is %d",
bit_length));
});
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
......@@ -434,15 +436,12 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(
ctx->HasInput("X"),
"Input(X) of MovingAverageAbsMaxScaleOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of MovingAverageAbsMaxScaleOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("OutScale"),
"Output(OutScale) of MovingAverageAbsMaxScaleOp"
"should not be null");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
"MovingAverageAbsMaxScale");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
"MovingAverageAbsMaxScale");
OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
"MovingAverageAbsMaxScale");
if (ctx->HasOutput("OutState")) {
ctx->SetOutputDim("OutState", {1});
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册