diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc index 292a69e82b9e280ae601ba5eb5582586548c7c5a..8c07e445a6f7a3ff54a0919dd653d6d3615e30fc 100644 --- a/paddle/fluid/operators/fake_quantize_op.cc +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -180,12 +180,11 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of FakeQuantizeOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of FakeQuantizeOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("OutScale"), - "Output(Scale) of FakeQuantizeOp should not be null."); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FakeQuantizeAbsMax"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", + "FakeQuantizeAbsMax"); + OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale", + "FakeQuantizeAbsMax"); ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->SetOutputDim("OutScale", {1}); ctx->ShareLoD("X", /*->*/ "Out"); @@ -211,8 +210,11 @@ class FakeQuantizeAbsMaxOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("bit_length", "(int, default 8)") .SetDefault(8) .AddCustomChecker([](const int& bit_length) { - PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16, - "'bit_length' should be between 1 and 16."); + PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true, + platform::errors::InvalidArgument( + "'bit_length' should be between 1 and 16, but " + "the received is %d", + bit_length)); }); AddComment(R"DOC( FakeQuantize operator @@ -230,14 +232,12 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of FakeChannelWiseQuantizeOp should not be null."); - PADDLE_ENFORCE( - ctx->HasOutput("Out"), - "Output(Out) of FakeChannelWiseQuantizeOp should not be null."); - PADDLE_ENFORCE( - ctx->HasOutput("OutScale"), - "Output(Scale) of FakeChannelWiseQuantizeOp should not be null."); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", + "FakeChannelWiseQuantizeAbsMax"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", + "FakeChannelWiseQuantizeAbsMax"); + OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale", + "FakeChannelWiseQuantizeAbsMax"); ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[0]}); ctx->ShareLoD("X", /*->*/ "Out"); @@ -263,8 +263,11 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker AddAttr("bit_length", "(int, default 8)") .SetDefault(8) .AddCustomChecker([](const int& bit_length) { - PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16, - "'bit_length' should be between 1 and 16."); + PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true, + platform::errors::InvalidArgument( + "'bit_length' should be between 1 and 16, but " + "the received is %d", + bit_length)); }); AddComment(R"DOC( The scale of FakeChannelWiseQuantize operator is a vector. @@ -288,14 +291,11 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of FakeQuantizeRangeAbsMaxOp should not be null."); - PADDLE_ENFORCE( - ctx->HasOutput("Out"), - "Output(Out) of FakeQuantizeRangeAbsMaxOp should not be null."); - PADDLE_ENFORCE( - ctx->HasOutput("OutScale"), - "Output(OutScale) of FakeQuantizeRangeAbsMaxOp should not be null"); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FakeQuantizeRangeAbsMax"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", + "FakeQuantizeRangeAbsMax"); + OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale", + "FakeQuantizeRangeAbsMax"); if (ctx->HasOutput("OutScales")) { int window_size = ctx->Attrs().Get("window_size"); ctx->SetOutputDim("OutScales", {window_size}); @@ -329,8 +329,11 @@ class FakeQuantizeRangeAbsMaxOpMaker AddAttr("bit_length", "(int, default 8), quantization bit number.") .SetDefault(8) .AddCustomChecker([](const int& bit_length) { - PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16, - "'bit_length' should be between 1 and 16."); + PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true, + platform::errors::InvalidArgument( + "'bit_length' should be between 1 and 16, but " + "the received is %d", + bit_length)); }); AddAttr("is_test", "(bool, default false) Set to true for inference only, false " @@ -357,16 +360,12 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp : OperatorWithKernel(type, inputs, outputs, attrs) {} void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of FakeQuantOrWithDequantMovingAverageAbsMaxOp " - "should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of FakeQuantOrWithDequantMovingAverageAbsMaxOp " - "should not be null."); - PADDLE_ENFORCE( - ctx->HasOutput("OutScale"), - "Output(OutScale) of FakeQuantOrWithDequantMovingAverageAbsMaxOp " - "should not be null"); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", + "FakeQuantOrWithDequantMovingAverageAbsMax"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", + "FakeQuantOrWithDequantMovingAverageAbsMax"); + OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale", + "FakeQuantOrWithDequantMovingAverageAbsMax"); if (ctx->HasOutput("OutState")) { ctx->SetOutputDim("OutState", {1}); } @@ -404,8 +403,11 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker AddAttr("bit_length", "(int, default 8), quantization bit number.") .SetDefault(8) .AddCustomChecker([](const int& bit_length) { - PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16, - "'bit_length' should be between 1 and 16."); + PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true, + platform::errors::InvalidArgument( + "'bit_length' should be between 1 and 16, but " + "the received is %d", + bit_length)); }); AddAttr("is_test", "(bool, default false) Set to true for inference only, false " @@ -434,15 +436,12 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE( - ctx->HasInput("X"), - "Input(X) of MovingAverageAbsMaxScaleOp should not be null."); - PADDLE_ENFORCE( - ctx->HasOutput("Out"), - "Output(Out) of MovingAverageAbsMaxScaleOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("OutScale"), - "Output(OutScale) of MovingAverageAbsMaxScaleOp" - "should not be null"); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", + "MovingAverageAbsMaxScale"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", + "MovingAverageAbsMaxScale"); + OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale", + "MovingAverageAbsMaxScale"); if (ctx->HasOutput("OutState")) { ctx->SetOutputDim("OutState", {1}); }