From 75144f134876e78cb0284b334a431bdfb83a7007 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Tue, 21 Jun 2022 19:12:05 +0800 Subject: [PATCH] Update quantization round and clip calculation rules (#42695) --- .../ir/delete_quant_dequant_filter_op_pass.cc | 8 + .../ir/delete_quant_dequant_linear_op_pass.cc | 8 + .../delete_weight_dequant_linear_op_pass.cc | 8 + .../ir/quant_conv2d_dequant_fuse_pass.cc | 8 + paddle/fluid/operators/fake_quantize_op.cc | 127 ++++++++++---- paddle/fluid/operators/fake_quantize_op.cu.h | 136 ++++++++++----- paddle/fluid/operators/fake_quantize_op.h | 92 +++++++--- paddle/fluid/operators/quantize_linear_op.cc | 16 +- paddle/fluid/operators/quantize_linear_op.h | 9 +- .../contrib/slim/quantization/adaround.py | 12 +- .../post_training_quantization.py | 46 +++-- .../slim/quantization/quantization_pass.py | 93 ++++++++-- .../fluid/contrib/slim/quantization/utils.py | 34 ++-- .../fluid/contrib/slim/tests/CMakeLists.txt | 2 +- .../contrib/slim/tests/test_imperative_ptq.py | 2 +- ...t_post_training_quantization_lstm_model.py | 16 +- .../test_post_training_quantization_mnist.py | 83 +++++---- ..._post_training_quantization_mobilenetv1.py | 33 ++-- ...est_post_training_quantization_resnet50.py | 8 +- .../tests/unittests/test_fake_quantize_op.py | 165 +++++++++++++----- 20 files changed, 653 insertions(+), 253 deletions(-) diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc index a02efc0a7c..e168266c9a 100644 --- a/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc @@ -45,6 +45,10 @@ DeleteQuantDequantFilterOpPass::DeleteQuantDequantFilterOpPass() { .End() .AddAttr("bit_length") .IsIntIn({8, 16}) + .End() + .AddAttr("round_type") + .IsOptional() + .IsIntIn({0, 1}) .End(); AddOpCompat(OpCompat("fake_channel_wise_quantize_dequantize_abs_max")) .AddInput("X") @@ -61,6 +65,10 @@ DeleteQuantDequantFilterOpPass::DeleteQuantDequantFilterOpPass() { .End() .AddAttr("quant_axis") .IsIntIn({0, 1}) + .End() + .AddAttr("round_type") + .IsOptional() + .IsIntIn({0, 1}) .End(); } // Delete quant_dequant_op, then quantize and dequantize weight diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc index 8f2b58ed51..aa265a43ba 100644 --- a/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc @@ -54,6 +54,10 @@ DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() { .End() .AddAttr("quant_axis") .IsType() + .End() + .AddAttr("round_type") + .IsOptional() + .IsType() .End(); AddOpCompat(OpCompat("dequantize_linear")) .AddInput("X") @@ -74,6 +78,10 @@ DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() { .End() .AddAttr("quant_axis") .IsType() + .End() + .AddAttr("round_type") + .IsOptional() + .IsType() .End(); } // Delete quantize_linear_op dequantize_linear_op, then add input_scales diff --git a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc index 8ebea231e7..47ad986fe8 100644 --- a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc @@ -52,6 +52,10 @@ DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() { 
.End() .AddAttr("quant_axis") .IsType() + .End() + .AddAttr("round_type") + .IsOptional() + .IsType() .End(); AddOpCompat(OpCompat("dequantize_linear")) .AddInput("X") @@ -72,6 +76,10 @@ DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() { .End() .AddAttr("quant_axis") .IsType() + .End() + .AddAttr("round_type") + .IsOptional() + .IsType() .End(); AddOpCompat(OpCompat("conv2d")) .AddInput("Input") diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc index e436bee035..af8043e155 100644 --- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc +++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc @@ -49,6 +49,10 @@ QuantDequantFusePass::QuantDequantFusePass() { .End() .AddAttr("bit_length") .IsIntIn({8, 16}) + .End() + .AddAttr("round_type") + .IsOptional() + .IsIntIn({0, 1}) .End(); AddOpCompat(OpCompat("fake_quantize_moving_average_abs_max")) .AddInput("X") @@ -85,6 +89,10 @@ QuantDequantFusePass::QuantDequantFusePass() { .End() .AddAttr("bit_length") .IsIntIn({8, 16}) + .End() + .AddAttr("round_type") + .IsOptional() + .IsIntIn({0, 1}) .End(); AddOpCompat(OpCompat("fake_dequantize_max_abs")) .AddInput("X") diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc index 855c78d299..94badfb1c2 100644 --- a/paddle/fluid/operators/fake_quantize_op.cc +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -88,14 +88,14 @@ template struct ClipAndFakeQuantFunctor { void operator()(const platform::CPUDeviceContext& ctx, const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, framework::Tensor* out) { + const int bin_cnt, const int round_type, + framework::Tensor* out) { T s = scale.data()[0]; T inv_s = inverse(s); platform::Transform trans; trans(ctx, in.data(), in.data() + in.numel(), - out->mutable_data(ctx.GetPlace()), phi::ClipFunctor(-s, s)); - auto out_e = framework::EigenVector::Flatten(*out); - out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round(); + out->mutable_data(ctx.GetPlace()), + QuantTensorFunctor(static_cast(bin_cnt), round_type, inv_s)); } }; @@ -105,16 +105,17 @@ template struct ClipAndFakeQuantDequantFunctor { void operator()(const platform::CPUDeviceContext& ctx, const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, framework::Tensor* out) { + const int bin_cnt, const int round_type, + framework::Tensor* out) { T s = scale.data()[0]; T inv_s = inverse(s); platform::Transform trans; trans(ctx, in.data(), in.data() + in.numel(), - out->mutable_data(ctx.GetPlace()), phi::ClipFunctor(-s, s)); + out->mutable_data(ctx.GetPlace()), + QuantTensorFunctor(static_cast(bin_cnt), round_type, inv_s)); auto out_e = framework::EigenVector::Flatten(*out); - out_e.device(*ctx.eigen_device()) = - (bin_cnt * inv_s * out_e).round() * s / static_cast(bin_cnt); + out_e.device(*ctx.eigen_device()) = out_e * s / static_cast(bin_cnt); } }; template struct ClipAndFakeQuantDequantFunctor struct ChannelClipAndFakeQuantFunctor { void operator()(const platform::CPUDeviceContext& ctx, const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, const int quant_axis, + const int bin_cnt, const int round_type, const int quant_axis, framework::Tensor* out) { // At present, channelwise quantization supports conv2d, depthwise_conv2d // conv2d_transpose and mul @@ -145,15 +146,10 @@ struct ChannelClipAndFakeQuantFunctor { T s = scale_data[i]; auto* start = 
in_data + i * channel_size; auto* end = in_data + (i + 1) * channel_size; - trans(ctx, start, end, out_data + i * channel_size, - phi::ClipFunctor(-s, s)); - } - for (int64_t i = 0; i < channel; i++) { - T s = scale_data[i]; T inv_s = inverse(s); - framework::Tensor one_channel_out = out->Slice(i, i + 1); - auto out_e = framework::EigenVector::Flatten(one_channel_out); - out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round(); + trans( + ctx, start, end, out_data + i * channel_size, + QuantTensorFunctor(static_cast(bin_cnt), round_type, inv_s)); } } else if (quant_axis == 1) { const int64_t step_i = in.numel() / in_dims[0]; @@ -165,10 +161,9 @@ struct ChannelClipAndFakeQuantFunctor { auto* start = in_data + i * step_i + j * step_j; auto* end = in_data + i * step_i + (j + 1) * step_j; auto* cur_out_data = out_data + i * step_i + j * step_j; - trans(ctx, start, end, cur_out_data, phi::ClipFunctor(-s, s)); - for (int k = 0; k < step_j; k++) { - cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]); - } + trans(ctx, start, end, cur_out_data, + QuantTensorFunctor(static_cast(bin_cnt), round_type, + inv_s)); } } } @@ -181,7 +176,7 @@ template struct ChannelClipFakeQuantDequantFunctor { void operator()(const platform::CPUDeviceContext& ctx, const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, const int quant_axis, + const int bin_cnt, const int round_type, const int quant_axis, framework::Tensor* out) { PADDLE_ENFORCE_EQ( quant_axis == 0 || quant_axis == 1, true, @@ -201,16 +196,13 @@ struct ChannelClipFakeQuantDequantFunctor { T s = scale_data[i]; auto* start = in_data + i * channel_size; auto* end = in_data + (i + 1) * channel_size; - trans(ctx, start, end, out_data + i * channel_size, - phi::ClipFunctor(-s, s)); - } - for (int i = 0; i < channel; i++) { - T s = scale_data[i]; T inv_s = inverse(s); + trans( + ctx, start, end, out_data + i * channel_size, + QuantTensorFunctor(static_cast(bin_cnt), round_type, inv_s)); framework::Tensor one_channel_out = out->Slice(i, i + 1); auto out_e = framework::EigenVector::Flatten(one_channel_out); - out_e.device(*ctx.eigen_device()) = - (bin_cnt * inv_s * out_e).round() * s / static_cast(bin_cnt); + out_e.device(*ctx.eigen_device()) = out_e * s / static_cast(bin_cnt); } } else if (quant_axis == 1) { const int64_t step_i = in.numel() / in_dims[0]; @@ -222,10 +214,11 @@ struct ChannelClipFakeQuantDequantFunctor { auto* start = in_data + i * step_i + j * step_j; auto* end = in_data + i * step_i + (j + 1) * step_j; auto* cur_out_data = out_data + i * step_i + j * step_j; - trans(ctx, start, end, cur_out_data, phi::ClipFunctor(-s, s)); + trans(ctx, start, end, cur_out_data, + QuantTensorFunctor(static_cast(bin_cnt), round_type, + inv_s)); for (int k = 0; k < step_j; k++) { - cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]) * - s / static_cast(bin_cnt); + cur_out_data[k] = cur_out_data[k] * s / static_cast(bin_cnt); } } } @@ -334,6 +327,20 @@ class FakeQuantOrWithDequantAbsMaxOpMaker "the received is %d", bit_length)); }); + AddAttr( + "round_type", + "(int, default 0) The round type of fp32 to int." + "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2" + "1: rounding to nearest ties away from zero. 
Eg: round(1.5)=2, " + "round(2.5)=3") + .SetDefault(0) + .AddCustomChecker([](const int& round_type) { + PADDLE_ENFORCE_EQ(round_type >= 0 && round_type <= 1, true, + platform::errors::InvalidArgument( + "'round_type' should be between 0 and 1, but " + "the received is %d", + round_type)); + }); AddComment(R"DOC( This is a Base Op which supports FakeQuantAbsMaxOpMaker and FakeQuantDequantAbsMaxOpMaker. FakeQuantAbsMaxOp operator is used in the dynamic quantization. @@ -407,6 +414,20 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker "the received is %d", bit_length)); }); + AddAttr( + "round_type", + "(int, default 0) The round type of fp32 to int." + "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2" + "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, " + "round(2.5)=3") + .SetDefault(0) + .AddCustomChecker([](const int& round_type) { + PADDLE_ENFORCE_EQ(round_type >= 0 && round_type <= 1, true, + platform::errors::InvalidArgument( + "'round_type' should be between 0 and 1, but " + "the received is %d", + round_type)); + }); AddAttr("is_test", "(bool, default false) Set to true for inference only, false " "for training. Some layers may run faster when this is true.") @@ -480,6 +501,20 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker "the received is %d", bit_length)); }); + AddAttr( + "round_type", + "(int, default 0) The round type of fp32 to int." + "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2" + "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, " + "round(2.5)=3") + .SetDefault(0) + .AddCustomChecker([](const int& round_type) { + PADDLE_ENFORCE_EQ(round_type >= 0 && round_type <= 1, true, + platform::errors::InvalidArgument( + "'round_type' should be between 0 and 1, but " + "the received is %d", + round_type)); + }); AddComment(R"DOC( The scale of FakeChannelWiseQuantize operator is a vector. In detail, each channel of the input X has a scale value. @@ -546,6 +581,20 @@ class FakeQuantizeRangeAbsMaxOpMaker "the received is %d", bit_length)); }); + AddAttr( + "round_type", + "(int, default 0) The round type of fp32 to int." + "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2" + "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, " + "round(2.5)=3") + .SetDefault(0) + .AddCustomChecker([](const int& round_type) { + PADDLE_ENFORCE_EQ(round_type >= 0 && round_type <= 1, true, + platform::errors::InvalidArgument( + "'round_type' should be between 0 and 1, but " + "the received is %d", + round_type)); + }); AddAttr("is_test", "(bool, default false) Set to true for inference only, false " "for training. Some layers may run faster when this is true.") @@ -620,6 +669,20 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker "the received is %d", bit_length)); }); + AddAttr( + "round_type", + "(int, default 0) The round type of fp32 to int." + "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2" + "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, " + "round(2.5)=3") + .SetDefault(0) + .AddCustomChecker([](const int& round_type) { + PADDLE_ENFORCE_EQ(round_type >= 0 && round_type <= 1, true, + platform::errors::InvalidArgument( + "'round_type' should be between 0 and 1, but " + "the received is %d", + round_type)); + }); AddAttr("is_test", "(bool, default false) Set to true for inference only, false " "for training. 
Some layers may run faster when this is true.") diff --git a/paddle/fluid/operators/fake_quantize_op.cu.h b/paddle/fluid/operators/fake_quantize_op.cu.h index 580521183c..46aa3fbfe3 100644 --- a/paddle/fluid/operators/fake_quantize_op.cu.h +++ b/paddle/fluid/operators/fake_quantize_op.cu.h @@ -214,7 +214,8 @@ template struct FindChannelAbsMaxFunctor; template __global__ void ClipAndQuantKernel(const T* in, const T* scale, - const int bin_cnt, const int n, T* out) { + const int bin_cnt, const int round_type, + const int n, T* out) { int bid = threadIdx.x + blockIdx.x * blockDim.x; int tid = threadIdx.x; @@ -226,16 +227,24 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale, for (int i = bid; i < n; i += blockDim.x * gridDim.x) { ComputeDataType x = static_cast(in[i]); - ComputeDataType v = x > s ? s : x; - v = v < -s ? -s : v; - v = bin_cnt_t * inv_s * v; - out[i] = static_cast(round(v)); + x = bin_cnt_t * inv_s * x; + if (round_type == 0) { + x = roundWithTiesToEven(x); + } else { + x = round(x); + } + ComputeDataType max_bound = bin_cnt_t; + ComputeDataType min_bound = -bin_cnt_t - static_cast(1); + x = x > max_bound ? max_bound : x; + x = x < min_bound ? min_bound : x; + out[i] = static_cast(x); } } template __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale, - const int bin_cnt, const int n, + const int bin_cnt, + const int round_type, const int n, T* out) { int bid = threadIdx.x + blockIdx.x * blockDim.x; int tid = threadIdx.x; @@ -248,10 +257,16 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale, for (int i = bid; i < n; i += blockDim.x * gridDim.x) { ComputeDataType x = static_cast(in[i]); - x = x > s ? s : x; - x = x < -s ? -s : x; x = bin_cnt_t * inv_s * x; - x = round(x); + if (round_type == 0) { + x = roundWithTiesToEven(x); + } else { + x = round(x); + } + ComputeDataType max_bound = bin_cnt_t; + ComputeDataType min_bound = -bin_cnt_t - static_cast(1); + x = x > max_bound ? max_bound : x; + x = x < min_bound ? 
min_bound : x; out[i] = static_cast((x * s) / bin_cnt_t); } } @@ -260,7 +275,8 @@ template struct ClipAndFakeQuantFunctor { void operator()(const platform::CUDADeviceContext& ctx, const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, framework::Tensor* out) { + const int bin_cnt, const int round_type, + framework::Tensor* out) { int num = in.numel(); int block = 1024; int grid = (block - 1 + num) / block; @@ -270,7 +286,7 @@ struct ClipAndFakeQuantFunctor { T* out_data = out->mutable_data(ctx.GetPlace()); ClipAndQuantKernel<<>>( - in_data, scale_data, bin_cnt, num, out_data); + in_data, scale_data, bin_cnt, round_type, num, out_data); } }; @@ -280,7 +296,8 @@ template struct ClipAndFakeQuantDequantFunctor { void operator()(const platform::CUDADeviceContext& ctx, const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, framework::Tensor* out) { + const int bin_cnt, const int round_type, + framework::Tensor* out) { int num = in.numel(); int block = 1024; int grid = (block - 1 + num) / block; @@ -290,7 +307,7 @@ struct ClipAndFakeQuantDequantFunctor { T* out_data = out->mutable_data(ctx.GetPlace()); ClipAndQuantDequantKernel<<>>( - in_data, scale_data, bin_cnt, num, out_data); + in_data, scale_data, bin_cnt, round_type, num, out_data); } }; @@ -298,6 +315,7 @@ struct ClipAndFakeQuantDequantFunctor { template __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale, const int bin_cnt, + const int round_type, const int64_t n, const int c, T* out) { int tid = threadIdx.x; @@ -314,18 +332,25 @@ __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale, for (int64_t i = tid; i < channel_size; i += blockDim.x) { ComputeDataType x = static_cast(in_c[i]); - ComputeDataType v = x > s ? s : x; - v = v < -s ? -s : v; - v = bin_cnt_t * inv_s * v; - out_c[i] = static_cast(round(v)); + x = bin_cnt_t * inv_s * x; + if (round_type == 0) { + x = roundWithTiesToEven(x); + } else { + x = round(x); + } + ComputeDataType max_bound = bin_cnt_t; + ComputeDataType min_bound = -bin_cnt_t - static_cast(1); + x = x > max_bound ? max_bound : x; + x = x < min_bound ? min_bound : x; + out_c[i] = static_cast(x); } } // ChannelClipAndQuantKernel for quant_axis is N template __global__ void ChannelClipAndQuantKernelQuantAxisN( - const T* in, const T* scale, const int bin_cnt, const int64_t n, - const int nScale, const int quant_stride, T* out) { + const T* in, const T* scale, const int bin_cnt, const int round_type, + const int64_t n, const int nScale, const int quant_stride, T* out) { int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; using ComputeDataType = typename QuantizeDataType::type; ComputeDataType bin_cnt_t = static_cast(bin_cnt); @@ -334,10 +359,17 @@ __global__ void ChannelClipAndQuantKernelQuantAxisN( static_cast(scale[(i / quant_stride) % nScale]); ComputeDataType inv_s = inverse(s); ComputeDataType x = static_cast(in[i]); - ComputeDataType v = x > s ? s : x; - v = v < -s ? -s : v; - v = bin_cnt_t * inv_s * v; - out[i] = static_cast(round(v)); + x = bin_cnt_t * inv_s * x; + if (round_type == 0) { + x = roundWithTiesToEven(x); + } else { + x = round(x); + } + ComputeDataType max_bound = bin_cnt_t; + ComputeDataType min_bound = -bin_cnt_t - static_cast(1); + x = x > max_bound ? max_bound : x; + x = x < min_bound ? 
min_bound : x; + out[i] = static_cast(x); } } @@ -345,7 +377,7 @@ template struct ChannelClipAndFakeQuantFunctor { void operator()(const platform::CUDADeviceContext& ctx, const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, const int quant_axis, + const int bin_cnt, const int round_type, const int quant_axis, framework::Tensor* out) { PADDLE_ENFORCE_EQ( quant_axis == 0 || quant_axis == 1, true, @@ -363,7 +395,7 @@ struct ChannelClipAndFakeQuantFunctor { int grid = in_dims[0]; int block = 1024; ChannelClipAndQuantKernelQuantAxis0<<>>( - in_data, scale_data, bin_cnt, num, in_dims[0], out_data); + in_data, scale_data, bin_cnt, round_type, num, in_dims[0], out_data); } else { int quant_stride = 1; for (int i = quant_axis + 1; i < in_dims.size(); i++) { @@ -380,8 +412,8 @@ struct ChannelClipAndFakeQuantFunctor { std::min(max_blocks, (num + block_size - 1) / block_size); ChannelClipAndQuantKernelQuantAxisN<<>>( - in_data, scale_data, bin_cnt, num, in_dims[quant_axis], quant_stride, - out_data); + in_data, scale_data, bin_cnt, round_type, num, in_dims[quant_axis], + quant_stride, out_data); } } }; @@ -485,8 +517,8 @@ struct FindMovingAverageAbsMaxFunctor { // ChannelClipAndQuantDequantKernel for quant_axis is 0 template __global__ void ChannelClipAndQuantDequantKernelQuantAxis0( - const T* in, const T* scale, const int bin_cnt, const int n, const int c, - T* out) { + const T* in, const T* scale, const int bin_cnt, const int round_type, + const int n, const int c, T* out) { int tid = threadIdx.x; int channel_size = n / c; @@ -498,18 +530,25 @@ __global__ void ChannelClipAndQuantDequantKernelQuantAxis0( for (int i = tid; i < channel_size; i += blockDim.x) { T x = in_c[i]; - T v = x > s ? s : x; - v = v < -s ? -s : v; - v = bin_cnt * inv_s * v; - out_c[i] = round(v) * s / bin_cnt; + x = bin_cnt * inv_s * x; + if (round_type == 0) { + x = roundWithTiesToEven(x); + } else { + x = round(x); + } + T max_bound = bin_cnt; + T min_bound = -bin_cnt - static_cast(1); + x = x > max_bound ? max_bound : x; + x = x < min_bound ? min_bound : x; + out_c[i] = (x * s) / bin_cnt; } } // ChannelClipAndQuantDequantKernel for quant_axis is 1 template __global__ void ChannelClipAndQuantDequantKernelQuantAxis1( - const T* in, const T* scale, const int bin_cnt, const int n, const int cin, - const int cout, T* out) { + const T* in, const T* scale, const int bin_cnt, const int round_type, + const int n, const int cin, const int cout, T* out) { T s = scale[blockIdx.x % cout]; T inv_s = inverse(s); @@ -519,10 +558,17 @@ __global__ void ChannelClipAndQuantDequantKernelQuantAxis1( for (int i = threadIdx.x; i < wh_size; i += blockDim.x) { T x = in_c[i]; - T v = x > s ? s : x; - v = v < -s ? -s : v; - v = bin_cnt * inv_s * v; - out_c[i] = round(v) * s / bin_cnt; + x = bin_cnt * inv_s * x; + if (round_type == 0) { + x = roundWithTiesToEven(x); + } else { + x = round(x); + } + T max_bound = bin_cnt; + T min_bound = -bin_cnt - static_cast(1); + x = x > max_bound ? max_bound : x; + x = x < min_bound ? 
min_bound : x; + out_c[i] = (x * s) / bin_cnt; } } @@ -530,7 +576,7 @@ template <typename T> struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext, T> { void operator()(const platform::CUDADeviceContext& ctx, const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, const int quant_axis, + const int bin_cnt, const int round_type, const int quant_axis, framework::Tensor* out) { // At present, channelwise quantization supports conv2d, depthwise_conv2d // conv2d_transpose and mul @@ -551,15 +597,17 @@ struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext, T> { int grid = in_dims[0]; int block = 1024; ChannelClipAndQuantDequantKernelQuantAxis0<T> - <<<grid, block, 0, ctx.stream()>>>(in_data, scale_data, bin_cnt, num, - in_dims[0], out_data); + <<<grid, block, 0, ctx.stream()>>>(in_data, scale_data, bin_cnt, + round_type, num, in_dims[0], + out_data); } else if (quant_axis == 1) { int grid = in_dims[0] * in_dims[1]; int block = 1024; ChannelClipAndQuantDequantKernelQuantAxis1<T> - <<<grid, block, 0, ctx.stream()>>>(in_data, scale_data, bin_cnt, num, - in_dims[0], in_dims[1], out_data); + <<<grid, block, 0, ctx.stream()>>>(in_data, scale_data, bin_cnt, + round_type, num, in_dims[0], + in_dims[1], out_data); } } }; diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h index 182db11ed8..2956478f44 100644 --- a/paddle/fluid/operators/fake_quantize_op.h +++ b/paddle/fluid/operators/fake_quantize_op.h @@ -34,6 +34,46 @@ inline HOSTDEVICE T inverse(T s) { return s <= static_cast<T>(1e-30) ? one / (s + eps) : one / s; } +template <typename T> +inline HOSTDEVICE T roundWithTiesToEven(T x) { + T xLower = floor(x); + T xUpper = ceil(x); + // x is in interval [xLower, xUpper]. Choose the closest of the two bounds, + // breaking ties to even. + T dLower = x - xLower; + T dUpper = xUpper - x; + return static_cast<T>( + (dLower == dUpper ? fmod(xLower, 2.0F) == 0.0F : dLower < dUpper) + ? xLower + : xUpper); +} + +template <typename T> +class QuantTensorFunctor { + public: + explicit QuantTensorFunctor(const T bin_cnt, const int round_type, + const T inv_s) + : bin_cnt_(bin_cnt), round_type_(round_type), inv_s_(inv_s) {} + HOSTDEVICE T operator()(const T x) const { + T out = bin_cnt_ * inv_s_ * x; + if (round_type_ == 0) { + out = roundWithTiesToEven(out); + } else if (round_type_ == 1) { + out = std::round(out); + } + T max_bound = bin_cnt_; + T min_bound = -bin_cnt_ - static_cast<T>(1); + out = out > max_bound ? max_bound : out; + out = out < min_bound ?
min_bound : out; + return out; + } + + private: + T bin_cnt_; + int round_type_; + T inv_s_; +}; + template struct FindAbsMaxFunctor { void operator()(const DeviceContext& ctx, const T* in, const int num, T* out); @@ -43,14 +83,14 @@ template struct ClipAndFakeQuantFunctor { void operator()(const DeviceContext& ctx, const framework::Tensor& in, const framework::Tensor& scale, const int bin_cnt, - framework::Tensor* out); + const int round_type, framework::Tensor* out); }; template struct ClipAndFakeQuantDequantFunctor { void operator()(const DeviceContext& ctx, const framework::Tensor& in, const framework::Tensor& scale, const int bin_cnt, - framework::Tensor* out); + int round_type, framework::Tensor* out); }; template @@ -71,14 +111,15 @@ template struct ChannelClipAndFakeQuantFunctor { void operator()(const DeviceContext& ctx, const framework::Tensor& in, const framework::Tensor& scale, const int bin_cnt, - const int quant_axis, framework::Tensor* out); + const int round_type, const int quant_axis, + framework::Tensor* out); }; template struct ChannelClipFakeQuantDequantFunctor { void operator()(const DeviceContext& ctx, const framework::Tensor& in, const framework::Tensor& scale, const int bin_cnt, - const int quant_axis, framework::Tensor* out); + int round_type, const int quant_axis, framework::Tensor* out); }; template @@ -100,12 +141,13 @@ class FakeAbsMaxKernelBase : public framework::OpKernel { T* out_s = out_scale->mutable_data(context.GetPlace()); int bit_length = context.Attr("bit_length"); + int round_type = context.Attr("round_type"); int bin_cnt = std::pow(2, bit_length - 1) - 1; auto& dev_ctx = context.template device_context(); const T* in_data = in->data(); FindAbsMaxFunctor()(dev_ctx, in_data, in->numel(), out_s); - RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, out); + RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, round_type, out); } virtual ~FakeAbsMaxKernelBase() = default; @@ -114,7 +156,7 @@ class FakeAbsMaxKernelBase : public framework::OpKernel { virtual void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in, const framework::Tensor& scale, int bin_cnt, - framework::Tensor* out) const = 0; + int round_type, framework::Tensor* out) const = 0; }; template @@ -122,9 +164,9 @@ class FakeQuantizeAbsMaxKernel : public FakeAbsMaxKernelBase { protected: void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in, const framework::Tensor& scale, int bin_cnt, - framework::Tensor* out) const override { + int round_type, framework::Tensor* out) const override { ClipAndFakeQuantFunctor()(dev_ctx, in, scale, bin_cnt, - out); + round_type, out); } }; @@ -134,9 +176,9 @@ class FakeQuantizeDequantizeAbsMaxKernel protected: void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in, const framework::Tensor& scale, int bin_cnt, - framework::Tensor* out) const override { - ClipAndFakeQuantDequantFunctor()(dev_ctx, in, scale, - bin_cnt, out); + int round_type, framework::Tensor* out) const override { + ClipAndFakeQuantDequantFunctor()( + dev_ctx, in, scale, bin_cnt, round_type, out); } }; @@ -151,6 +193,7 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel { out->mutable_data(context.GetPlace()); int bit_length = context.Attr("bit_length"); + int round_type = context.Attr("round_type"); int bin_cnt = std::pow(2, bit_length - 1) - 1; int quant_axis = context.Attr("quant_axis"); bool is_test = context.Attr("is_test"); @@ -162,7 +205,7 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public 
framework::OpKernel { out_scale_data); } ChannelClipAndFakeQuantFunctor()( - dev_ctx, *in, *out_scale, bin_cnt, quant_axis, out); + dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out); } }; @@ -179,6 +222,7 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxKernel out->mutable_data(dev_ctx.GetPlace()); int bit_length = context.Attr("bit_length"); + int round_type = context.Attr("round_type"); int bin_cnt = std::pow(2, bit_length - 1) - 1; int quant_axis = context.Attr("quant_axis"); @@ -186,7 +230,7 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxKernel out_scale_data); ChannelClipFakeQuantDequantFunctor()( - dev_ctx, *in, *out_scale, bin_cnt, quant_axis, out); + dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out); } }; @@ -202,13 +246,14 @@ class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel { bool is_test = context.Attr("is_test"); int bit_length = context.Attr("bit_length"); + int round_type = context.Attr("round_type"); int bin_cnt = std::pow(2, bit_length - 1) - 1; auto& dev_ctx = context.template device_context(); // testing if (is_test) { ClipAndFakeQuantFunctor()(dev_ctx, *in, *in_scale, - bin_cnt, out); + bin_cnt, round_type, out); return; } @@ -228,7 +273,7 @@ class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel { *iter, window_size, out_scales, out_scale); ClipAndFakeQuantFunctor()(dev_ctx, *in, *out_scale, - bin_cnt, out); + bin_cnt, round_type, out); } }; @@ -243,12 +288,13 @@ class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel { bool is_test = context.Attr("is_test"); int bit_length = context.Attr("bit_length"); + int round_type = context.Attr("round_type"); int bin_cnt = std::pow(2, bit_length - 1) - 1; auto& dev_ctx = context.template device_context(); // testing if (is_test) { - RunClipFunctor(dev_ctx, *in, *in_scale, bin_cnt, out); + RunClipFunctor(dev_ctx, *in, *in_scale, bin_cnt, round_type, out); return; } @@ -273,7 +319,7 @@ class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel { dev_ctx, *in_accum, *in_state, cur_scale_data, moving_rate, out_state, out_accum, out_scale); - RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, out); + RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, round_type, out); } virtual ~FakeMovingAverageAbsMaxKernelBase() = default; @@ -282,7 +328,7 @@ class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel { virtual void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in, const framework::Tensor& in_scale, int bin_cnt, - framework::Tensor* out) const = 0; + int round_type, framework::Tensor* out) const = 0; }; template @@ -291,9 +337,9 @@ class FakeQuantizeMovingAverageAbsMaxKernel protected: void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in, const framework::Tensor& in_scale, int bin_cnt, - framework::Tensor* out) const override { + int round_type, framework::Tensor* out) const override { ClipAndFakeQuantFunctor()(dev_ctx, in, in_scale, bin_cnt, - out); + round_type, out); } }; @@ -303,9 +349,9 @@ class FakeQuantizeDequantizeMovingAverageAbsMaxKernel protected: void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in, const framework::Tensor& in_scale, int bin_cnt, - framework::Tensor* out) const override { - ClipAndFakeQuantDequantFunctor()(dev_ctx, in, in_scale, - bin_cnt, out); + int round_type, framework::Tensor* out) const override { + ClipAndFakeQuantDequantFunctor()( + dev_ctx, in, in_scale, bin_cnt, round_type, out); } }; diff --git 
a/paddle/fluid/operators/quantize_linear_op.cc b/paddle/fluid/operators/quantize_linear_op.cc index edd2a06a50..7aaebb8f92 100644 --- a/paddle/fluid/operators/quantize_linear_op.cc +++ b/paddle/fluid/operators/quantize_linear_op.cc @@ -69,8 +69,6 @@ struct ChannelDequantizeFunctorV2 { } }; -template struct DequantizeFunctor; -template struct DequantizeFunctor; template struct ChannelDequantizeFunctorV2; template struct ChannelDequantizeFunctorV2; @@ -135,6 +133,20 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker { "the received is %d", bit_length)); }); + AddAttr( + "round_type", + "(int, default 0) The round type of fp32 to int." + "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2" + "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, " + "round(2.5)=3") + .SetDefault(0) + .AddCustomChecker([](const int& round_type) { + PADDLE_ENFORCE_EQ(round_type >= 0 && round_type <= 1, true, + platform::errors::InvalidArgument( + "'round_type' should be between 0 and 1, but " + "the received is %d", + round_type)); + }); AddAttr("is_test", "(bool, default false) Set to true for inference only, false " "for training. Some layers may run faster when this is true.") diff --git a/paddle/fluid/operators/quantize_linear_op.h b/paddle/fluid/operators/quantize_linear_op.h index df1a93ba63..b56ac61c1a 100644 --- a/paddle/fluid/operators/quantize_linear_op.h +++ b/paddle/fluid/operators/quantize_linear_op.h @@ -45,6 +45,7 @@ class QuantizeLinearKernel : public framework::OpKernel { auto* out = context.Output("Y"); out->mutable_data(context.GetPlace()); int bit_length = context.Attr("bit_length"); + int round_type = context.Attr("round_type"); int bin_cnt = std::pow(2, bit_length - 1) - 1; int quant_axis = context.Attr("quant_axis"); bool is_test = context.Attr("is_test"); @@ -57,10 +58,10 @@ class QuantizeLinearKernel : public framework::OpKernel { FindAbsMaxFunctor()(dev_ctx, in->data(), in->numel(), out_s); ClipAndFakeQuantFunctor()(dev_ctx, *in, *out_scale, - bin_cnt, out); + bin_cnt, round_type, out); } else { ClipAndFakeQuantFunctor()(dev_ctx, *in, *in_scale, - bin_cnt, out); + bin_cnt, round_type, out); } } else { if (!is_test) { @@ -69,10 +70,10 @@ class QuantizeLinearKernel : public framework::OpKernel { FindChannelAbsMaxFunctor()(dev_ctx, *in, quant_axis, out_scale_data); ChannelClipAndFakeQuantFunctor()( - dev_ctx, *in, *out_scale, bin_cnt, quant_axis, out); + dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out); } else { ChannelClipAndFakeQuantFunctor()( - dev_ctx, *in, *in_scale, bin_cnt, quant_axis, out); + dev_ctx, *in, *in_scale, bin_cnt, round_type, quant_axis, out); } } } diff --git a/python/paddle/fluid/contrib/slim/quantization/adaround.py b/python/paddle/fluid/contrib/slim/quantization/adaround.py index be3201044f..04d894b055 100644 --- a/python/paddle/fluid/contrib/slim/quantization/adaround.py +++ b/python/paddle/fluid/contrib/slim/quantization/adaround.py @@ -20,7 +20,7 @@ import logging import paddle.fluid as fluid from ....log_helper import get_logger -from .utils import load_variable_data, set_variable_data, stable_sigmoid, quant_tensor, dequant_tensor, _channelwise_quant_axis1_ops, calculate_quant_cos_error +from .utils import load_variable_data, set_variable_data, stable_sigmoid, quant_tensor, dequant_tensor, _channelwise_quant_axis1_ops, calculate_quant_cos_error, bias_correction_w _logger = get_logger(__name__, logging.INFO, @@ -209,6 +209,7 @@ def run_adaround(data_loader, scale_dict, num_iterations=1000, lr=0.001, 
+ bias_correction=False, fast_mode=True): fetch_op_name = fetch_list[0].name final_weight_tensor_quant_dict = {} @@ -307,6 +308,15 @@ def run_adaround(data_loader, break final_weight_tensor_quant_dict[ weight_var_name] = adaround.update_final_weights() + + if bias_correction: + final_weight_tensor_quant_dict[weight_var_name] = bias_correction_w( + weight_var_tensor, + final_weight_tensor_quant_dict[weight_var_name], + scale, + adaround.quant_axis, + weight_bits=adaround.weight_bits) + del adaround # update adarounded calibrated weights diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index a4888e6f90..9bcf3af134 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -121,7 +121,8 @@ class PostTrainingQuantization(object): algo="KL", hist_percent=0.99999, quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"], - round_type='round', + weight_round_algo='round', + round_type='TiesToEven', learning_rate=0.001, is_full_quantize=False, bias_correction=False, @@ -180,9 +181,14 @@ quantizable_op_type(list[str], optional): List the type of ops that will be quantized. Default is ["conv2d", "depthwise_conv2d", "mul"]. - round_type(str, optional): The method of converting the quantized weights + weight_round_algo(str, optional): The method of converting the quantized weights value float->int. Currently supports ['round', 'adaround'] methods. - Default is `round`, which is rounding nearest to the nearest whole number. + Default is `round`, which is rounding to the nearest integer. + 'adaround' refers to https://arxiv.org/abs/2004.10568. + round_type(str, optional): The method of converting the tensor value float->int. + Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods. + Default is `TiesToEven`, which is rounding to nearest ties to even. + 'TiesAwayFromZero' is rounding to nearest ties away from zero. learning_rate(float, optional): The learning rate of adaround method. is_full_quantized(bool, optional): If set is_full_quantized as True, apply quantization to all supported quantizable op type.
If set @@ -263,8 +269,10 @@ class PostTrainingQuantization(object): self._support_algo_type = [ 'KL', 'hist', 'avg', 'mse', 'emd', 'abs_max', 'min_max' ] - assert round_type in ['adaround', 'round'] + assert round_type in ['TiesToEven', 'TiesAwayFromZero'] self._round_type = round_type + assert weight_round_algo in ['adaround', 'round'] + self._weight_round_algo = weight_round_algo self._learning_rate = learning_rate self._dynamic_quantize_op_type = ['lstm'] self._support_quantize_op_type = \ @@ -406,7 +414,7 @@ class PostTrainingQuantization(object): if self._algo in ["KL", "hist"]: self._calculate_kl_hist_threshold() - if self._round_type == 'adaround': + if self._weight_round_algo == 'adaround': self._adaround_apply() self._reset_activation_persistable() @@ -459,6 +467,7 @@ class PostTrainingQuantization(object): self._weight_op_pairs, scale_dict, num_iterations=self._batch_nums, + bias_correction=self._bias_correction, lr=self._learning_rate) def save_quantized_model(self, @@ -642,6 +651,7 @@ class PostTrainingQuantization(object): float(np.max(np.abs(var_tensor[i])))) self._quantized_threshold[var_name] = abs_max_value _logger.info("MSE searching stage ...") + distribution = np.round if self._round_type == 'TiesToEven' else utils.round_c for var_name in self._quantized_act_var_name: var_tensor = utils.load_variable_data(self._scope, var_name) var_tensor = var_tensor.flatten() @@ -654,9 +664,9 @@ class PostTrainingQuantization(object): scale = s * abs_max_value s += 0.02 bins = 2**(self._activation_bits - 1) - 1 - quant_dequant_var = np.round( - np.clip(var_tensor, 0.0, scale) / scale * - bins) / bins * scale + quant_var = np.clip(distribution(var_tensor / scale * bins), + -bins - 1, bins) + quant_dequant_var = quant_var / bins * scale mse_loss = ((var_tensor - quant_dequant_var)**2).mean() if mse_loss <= self._best_calibration_loss[var_name]: self._best_calibration_loss[var_name] = mse_loss @@ -681,6 +691,7 @@ class PostTrainingQuantization(object): float(np.max(np.abs(var_tensor[i])))) self._quantized_threshold[var_name] = abs_max_value _logger.info("EMD searching stage ...") + distribution = np.round if self._round_type == 'TiesToEven' else utils.round_c for var_name in self._quantized_act_var_name: var_tensor = utils.load_variable_data(self._scope, var_name) var_tensor = var_tensor.flatten() @@ -693,9 +704,9 @@ class PostTrainingQuantization(object): scale = s * abs_max_value s += 0.02 bins = 2**(self._activation_bits - 1) - 1 - quant_dequant_var = np.round( - np.clip(var_tensor, 0.0, scale) / scale * - bins) / bins * scale + quant_var = np.clip(distribution(var_tensor / scale * bins), + -bins - 1, bins) + quant_dequant_var = quant_var / bins * scale emd_loss = np.abs( np.mean(var_tensor) - np.mean(quant_dequant_var)) + np.abs( np.std(var_tensor) - np.std(quant_dequant_var)) @@ -907,7 +918,8 @@ class PostTrainingQuantization(object): activation_bits=self._activation_bits, activation_quantize_type=self._activation_quantize_type, weight_quantize_type=self._weight_quantize_type, - quantizable_op_type=major_quantizable_op_types) + quantizable_op_type=major_quantizable_op_types, + round_type=self._round_type) else: transform_pass = QuantizationTransformPassV2( scope=self._scope, @@ -916,7 +928,8 @@ class PostTrainingQuantization(object): activation_bits=self._activation_bits, activation_quantize_type=self._activation_quantize_type, weight_quantize_type=self._weight_quantize_type, - quantizable_op_type=major_quantizable_op_types) + quantizable_op_type=major_quantizable_op_types, + 
round_type=self._round_type) for sub_graph in graph.all_sub_graphs(): # Insert fake_quant/fake_dequantize op must in test graph, so @@ -933,13 +946,15 @@ class PostTrainingQuantization(object): add_quant_dequant_pass = AddQuantDequantPass( scope=self._scope, place=self._place, - quantizable_op_type=minor_quantizable_op_types) + quantizable_op_type=minor_quantizable_op_types, + round_type=self._round_type) else: add_quant_dequant_pass = AddQuantDequantPassV2( scope=self._scope, place=self._place, quantizable_op_type=minor_quantizable_op_types, - is_full_quantized=self._is_full_quantize) + is_full_quantized=self._is_full_quantize, + round_type=self._round_type) for sub_graph in graph.all_sub_graphs(): sub_graph._for_test = True @@ -964,6 +979,7 @@ class PostTrainingQuantization(object): place=self._place, bias_correction=self._bias_correction, weight_bits=self._weight_bits, + weight_round_algo=self._weight_round_algo, round_type=self._round_type, activation_bits=self._activation_bits, weight_quantize_type=self._weight_quantize_type, diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 0dd79992eb..08d507284e 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -119,6 +119,7 @@ class QuantizationTransformPass(object): moving_rate=0.9, skip_pattern=['skip_quant'], quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'], + round_type='TiesToEven', weight_quantize_func=None, act_quantize_func=None, weight_preprocess_func=None, @@ -156,6 +157,10 @@ class QuantizationTransformPass(object): quantizable_op_type(list[str]): List the type of ops that will be quantized. Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in QuantizationFreezePass and ConvertToInt8Pass must be the same as this. + round_type(str, optional): The method of converting the tensor value float->int. + Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods. + Default is `TiesToEven`, which is rounding to nearest ties to even. + 'TiesAwayFromZero' is rounding to nearest ties away from zero. weight_quantize_func(function): Function that defines how to quantize weight. Using this can quickly test if user's quantization method works or not. 
In this function, user should both define quantization function and @@ -206,6 +211,7 @@ class QuantizationTransformPass(object): self._weight_bits = weight_bits self._activation_bits = activation_bits self._skip_pattern = skip_pattern + self._round_type = round_type self._weight_quantize_func = weight_quantize_func self._act_quantize_func = act_quantize_func self._weight_preprocess_func = weight_preprocess_func @@ -459,10 +465,12 @@ class QuantizationTransformPass(object): _init_var_node(scale_var_node, np.zeros(scale_var_node.shape(), dtype=data_type), self._scope, self._place) + round_type = 0 if self._round_type == 'TiesToEven' else 1 quant_op_node = graph.create_op_node( op_type='fake_quantize_abs_max', attrs={ 'bit_length': quant_bits, + 'round_type': round_type, 'op_role': core.op_proto_and_checker_maker.OpRole.Forward }, inputs={'X': var_node}, @@ -517,9 +525,11 @@ class QuantizationTransformPass(object): inputs['Iter'] = self._global_step outputs['OutScales'] = scales_node + round_type = 0 if self._round_type == 'TiesToEven' else 1 attrs = { 'window_size': self._window_size, 'bit_length': quant_bits, + 'round_type': round_type, 'is_test': self._is_test, 'op_role': core.op_proto_and_checker_maker.OpRole.Forward } @@ -590,8 +600,10 @@ class QuantizationTransformPass(object): outs['OutState'] = state_out_node outs['OutAccum'] = accum_out_node + round_type = 0 if self._round_type == 'TiesToEven' else 1 attrs = { 'bit_length': quant_bits, + 'round_type': round_type, 'moving_rate': self._moving_rate, 'is_test': self._is_test, 'op_role': core.op_proto_and_checker_maker.OpRole.Forward @@ -638,10 +650,12 @@ class QuantizationTransformPass(object): _init_var_node(scale_var_node, np.zeros(scale_var_node.shape(), dtype=data_type), self._scope, self._place) + round_type = 0 if self._round_type == 'TiesToEven' else 1 quant_op_node = graph.create_op_node( op_type='fake_channel_wise_quantize_abs_max', attrs={ 'bit_length': quant_bits, + 'round_type': round_type, 'quant_axis': quant_axis, 'is_test': self._is_test, 'op_role': core.op_proto_and_checker_maker.OpRole.Forward @@ -935,7 +949,8 @@ class QuantizationFreezePass(object): bias_correction=False, weight_bits=8, activation_bits=8, - round_type='round', + weight_round_algo='round', + round_type='TiesToEven', weight_quantize_type='abs_max', quantizable_op_type=None): """ @@ -953,9 +968,14 @@ https://arxiv.org/abs/1810.05723. weight_bits(int): quantization bit number for weights. activation_bits(int): quantization bit number for activation. - round_type(str, optional): The method of converting the quantized weights - value from float to int. Currently supports ['round', 'adaround'] methods. - Default is `round`, which is rounding nearest to the nearest whole number. + weight_round_algo(str, optional): The method of converting the quantized weights + value float->int. Currently supports ['round', 'adaround'] methods. + Default is `round`, which is rounding to the nearest integer. + 'adaround' refers to https://arxiv.org/abs/2004.10568. + round_type(str, optional): The method of converting the tensor value float->int. + Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods. + Default is `TiesToEven`, which is rounding to nearest ties to even. + 'TiesAwayFromZero' is rounding to nearest ties away from zero. weight_quantize_type(str): quantization type for weights, support 'abs_max' and 'channel_wise_abs_max'.
The 'range_abs_max' usually is not used for weight, since weights are fixed once the model is well trained. @@ -971,6 +991,7 @@ class QuantizationFreezePass(object): self._place = _get_paddle_place(place) self._weight_bits = weight_bits self._activation_bits = activation_bits + self._weight_round_algo = weight_round_algo self._round_type = round_type self._weight_quantize_type = weight_quantize_type self._fake_quant_op_names = _fake_quant_op_list @@ -1018,8 +1039,8 @@ class QuantizationFreezePass(object): scale_v = scale_v.tolist() self._quant_var_scale_map[input_arg_name] = scale_v # Quantize weight and restore - param_v = self._load_var(input_arg_name) - if self._round_type == 'round': + if self._weight_round_algo == 'round': + param_v = self._load_var(input_arg_name) if any( _check_grandchild_op_node(op_node, op) for op in utils._channelwise_quant_axis1_ops): @@ -1028,8 +1049,8 @@ class QuantizationFreezePass(object): quant_axis = 0 quantized_param_v = utils.quant_tensor( param_v.copy(), scale_v, quant_axis, - self._weight_bits) - quantized_param_v = np.round(quantized_param_v) + self._weight_bits, self._round_type) + # Weight bias correction if self._bias_correction == True: quantized_param_v = utils.bias_correction_w( param_v, @@ -1037,7 +1058,6 @@ class QuantizationFreezePass(object): scale_v, quant_axis, weight_bits=self._weight_bits) - quantized_param_v = np.round(quantized_param_v) self._restore_var(input_arg_name, quantized_param_v) self._remove_fake_quant_and_dequant_op(graph, op_node) @@ -1580,7 +1600,8 @@ class AddQuantDequantPass(object): quant_bits=8, skip_pattern=["skip_quant"], quantizable_op_type=["elementwise_add", "pool2d"], - is_full_quantized=False): + is_full_quantized=False, + round_type='TiesToEven'): """ Constructor. @@ -1602,6 +1623,10 @@ class AddQuantDequantPass(object): quantization to all supported quantizable op type. If set is_full_quantized as False, only apply quantization to the op type according to the input quantizable_op_type. + round_type(str, optional): The method of converting the tensor value float->int. + Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods. + Default is `TiesToEven`, which is rounding to nearest ties to even. + 'TiesAwayFromZero' is rounding to nearest ties away from zero. """ self._scope = scope self._place = _get_paddle_place(place) @@ -1609,6 +1634,7 @@ class AddQuantDequantPass(object): self._quant_bits = quant_bits self._is_test = None self._skip_pattern = skip_pattern + self._round_type = round_type if is_full_quantized: self._quantizable_op_type = utils._act_supported_quantizable_op_type @@ -1743,8 +1769,10 @@ class AddQuantDequantPass(object): outs['OutState'] = state_out_node outs['OutAccum'] = accum_out_node + round_type = 0 if self._round_type == 'TiesToEven' else 1 attrs = { 'bit_length': quant_bits, + 'round_type': round_type, 'moving_rate': self._moving_rate, 'is_test': self._is_test, 'op_role': core.op_proto_and_checker_maker.OpRole.Forward @@ -1784,6 +1812,10 @@ class InsertQuantizeLinear(object): Default is -1. channel_wise(bool, optional): Whether quantization with per channel or not. Default is False. is_test(bool, optional): Whether quantization with training or not. Default is True. + round_type(str, optional): The method of converting the tensor value float->int. + Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods. + Default is `TiesToEven`, which is rounding to nearest ties to even. + 'TiesAwayFromZero' is rounding to nearest ties away from zero. 
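For intuition, the quantize-dequantize round trip that these passes wire into the graph reduces to a few lines of numpy. This is a minimal illustrative sketch, not an API from the patch; `fake_quant_dequant`, the tensor-wise scale, and `bit_length=8` are assumptions:

    import numpy as np

    def fake_quant_dequant(x, scale, bit_length=8, round_type='TiesToEven'):
        # Illustrative mirror of quantize_linear followed by dequantize_linear.
        bnt = (1 << (bit_length - 1)) - 1  # 127 for 8-bit
        if round_type == 'TiesToEven':
            rnd = np.round  # ties to even: 2.5 -> 2.0
        else:
            rnd = lambda v: np.sign(v) * np.floor(np.abs(v) + 0.5)  # 2.5 -> 3.0
        # New rule: scale and round first, then saturate to [-bnt - 1, bnt].
        q = np.clip(rnd(x / scale * bnt), -bnt - 1, bnt)
        return q * scale / bnt

Note the order of operations: the old kernels clipped the input to [-scale, scale] before rounding, while this patch rounds first and then clamps the integer result to [-bnt - 1, bnt].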
""" def __init__(self, @@ -1792,13 +1824,15 @@ class InsertQuantizeLinear(object): quant_bits=8, quant_axis=-1, channel_wise=False, - is_test=True): + is_test=True, + round_type='TiesToEven'): self._place = place self._scope = scope self.quant_bits = quant_bits self.quant_axis = quant_axis self.channel_wise = channel_wise self._is_test = is_test + self._round_type = round_type def insert_quant_op(self, graph, var_node): assert var_node.is_var(), '{} is not a var'.format(var_node.name()) @@ -1841,7 +1875,12 @@ class InsertQuantizeLinear(object): if zero_point_node is not None: inputs["ZeroPoint"] = zero_point_node - attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits} + round_type = 0 if self._round_type == 'TiesToEven' else 1 + attrs = { + "quant_axis": self.quant_axis, + "bit_length": self.quant_bits, + "round_type": round_type + } outputs = {"Y": quant_var_node} if not self._is_test: attrs["is_test"] = self._is_test @@ -1946,6 +1985,7 @@ class QuantizationTransformPassV2(object): moving_rate=0.9, skip_pattern=['skip_quant'], quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'], + round_type='TiesToEven', weight_quantize_func=None, act_quantize_func=None, weight_preprocess_func=None, @@ -1981,6 +2021,10 @@ class QuantizationTransformPassV2(object): quantizable_op_type(list[str]): List the type of ops that will be quantized. Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in QuantizationFreezePass and ConvertToInt8Pass must be the same as this. + round_type(str, optional): The method of converting the tensor value float->int. + Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods. + Default is `TiesToEven`, which is rounding to nearest ties to even. + 'TiesAwayFromZero' is rounding to nearest ties away from zero. weight_quantize_func(function): Function that defines how to quantize weight. Using this can quickly test if user's quantization method works or not. In this function, user should both define quantization function and @@ -2030,6 +2074,7 @@ class QuantizationTransformPassV2(object): self._weight_bits = weight_bits self._activation_bits = activation_bits self._skip_pattern = skip_pattern + self._round_type = round_type self._weight_quantize_func = weight_quantize_func self._act_quantize_func = act_quantize_func self._weight_preprocess_func = weight_preprocess_func @@ -2153,7 +2198,8 @@ class QuantizationTransformPassV2(object): quant_bits=quant_bits, quant_axis=quant_axis, channel_wise=channel_wise, - is_test=self._is_test) + is_test=self._is_test, + round_type=self._round_type) quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op( graph, var_node) dequant_var_node = insert_quant_pass.insert_dequant_op( @@ -2261,7 +2307,8 @@ class AddQuantDequantPassV2(object): quant_bits=8, skip_pattern=["skip_quant"], quantizable_op_type=["elementwise_add", "pool2d"], - is_full_quantized=False): + is_full_quantized=False, + round_type='TiesToEven'): """ Args: scope(paddle.Scope): The scope is used to initialize these new parameters. @@ -2281,6 +2328,10 @@ class AddQuantDequantPassV2(object): quantization to all supported quantizable op type. If set is_full_quantized as False, only apply quantization to the op type according to the input quantizable_op_type. + round_type(str, optional): The method of converting the tensor value float->int. + Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods. + Default is `TiesToEven`, which is rounding to nearest ties to even. 
+ 'TiesAwayFromZero' is rounding to nearest ties away from zero. Examples: .. code-block:: python @@ -2303,6 +2354,7 @@ class AddQuantDequantPassV2(object): self._quant_bits = quant_bits self._is_test = None self._skip_pattern = skip_pattern + self._round_type = round_type if is_full_quantized: self._quantizable_op_type = utils._act_supported_quantizable_op_type @@ -2375,7 +2427,8 @@ class AddQuantDequantPassV2(object): quant_bits=self._quant_bits, quant_axis=-1, channel_wise=False, - is_test=self._is_test) + is_test=self._is_test, + round_type=self._round_type) quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op( graph, in_node) dequant_var_node = insert_quant_pass.insert_dequant_op( @@ -2458,6 +2511,8 @@ class ReplaceFakeQuantDequantPass(object): "quant_axis") else -1 bit_length = op.op().attr("bit_length") if op.op().has_attr( "bit_length") else 8 + round_type = op.op().attr("round_type") if op.op().has_attr( + "round_type") else 0 zero_point_node = None quanted_node = x_node @@ -2479,7 +2534,8 @@ class ReplaceFakeQuantDequantPass(object): quant_op_node = graph.create_op_node(op_type="quantize_linear", attrs={ "quant_axis": quant_axis, - "bit_length": bit_length + "bit_length": bit_length, + "round_type": round_type }, inputs={ "X": x_node, @@ -2598,8 +2654,11 @@ class QuantWeightPass(object): param_v = self._load_var(x_node.name()) quant_axis = _op.op().attr("quant_axis") bits_length = _op.op().attr("bit_length") + round_type = _op.op().attr("round_type") if _op.op().has_attr( + "round_type") else 0 quantized_param_v = utils.quant_tensor(param_v.copy(), scale_v, - quant_axis, bits_length) + quant_axis, bits_length, + round_type) if self._bias_correction == True: quantized_param_v = utils.bias_correction_w( param_v, diff --git a/python/paddle/fluid/contrib/slim/quantization/utils.py b/python/paddle/fluid/contrib/slim/quantization/utils.py index b9c304df5b..e396ce9dee 100644 --- a/python/paddle/fluid/contrib/slim/quantization/utils.py +++ b/python/paddle/fluid/contrib/slim/quantization/utils.py @@ -321,29 +321,39 @@ def set_variable_data(scope, place, var_name, np_value): tensor.set(np_value, place) -def quant_tensor(x, scale, quant_axis=0, weight_bits=8): - # symmetry quant - def _clip(x, scale): - x[x > scale] = scale - x[x < -scale] = -scale - return x +def round_c_single_element(val): + dtype = type(val) + if val >= 0: + return dtype(np.floor(val + 0.5)) + return dtype(np.ceil(val - 0.5)) + +# rounding to nearest ties away from zero +round_c = np.vectorize(round_c_single_element) + + +def quant_tensor(x, + scale, + quant_axis=0, + weight_bits=8, + round_type='TiesToEven'): assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.' 
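The `distribution` helper selected just below differs from `np.round` only on exact .5 ties. A quick illustrative check, reusing the `round_c` helper defined above:

    import numpy as np

    ticks = np.array([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5])
    np.round(ticks)  # ties to even:        [-2., -2., -0.,  0.,  2.,  2.]
    round_c(ticks)   # ties away from zero: [-3., -2., -1.,  1.,  2.,  3.]

Everywhere else the two rules agree, so the choice only matters for values that land exactly halfway between two integers.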
+ distribution = np.round if round_type == 'TiesToEven' else round_c bnt = (1 << (weight_bits - 1)) - 1 if isinstance(scale, list): for i, s in enumerate(scale): if s == 0.0: s = 1e-8 if quant_axis == 0: - x[i] = _clip(x[i], s) - x[i] = x[i] / s * bnt + x[i] = distribution(x[i] / s * bnt) + x[i] = np.clip(x[i], -bnt - 1, bnt) else: - x[:, i] = _clip(x[:, i], s) - x[:, i] = x[:, i] / s * bnt + x[:, i] = distribution(x[:, i] / s * bnt) + x[:, i] = np.clip(x[:, i], -bnt - 1, bnt) else: scale = 1e-8 if scale == 0.0 else scale - x = _clip(x, scale) - x = x / scale * bnt + x = distribution(x / scale * bnt) + x = np.clip(x, -bnt - 1, bnt) return x diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt index 4a90ab2753..de373716d8 100644 --- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt @@ -558,7 +558,7 @@ if(LINUX AND WITH_MKLDNN) 120) set_tests_properties(test_quant2_int8_ernie_mkldnn PROPERTIES TIMEOUT 120) set_tests_properties(test_quant_int8_googlenet_mkldnn PROPERTIES TIMEOUT 120) - set_tests_properties(test_quant2_int8_resnet50_mkldnn PROPERTIES TIMEOUT 120) + set_tests_properties(test_quant2_int8_resnet50_mkldnn PROPERTIES TIMEOUT 200) set_tests_properties(test_quant2_int8_lstm_mkldnn PROPERTIES TIMEOUT 120) endif() diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py index cde739b2c9..2c56f9ad53 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py @@ -338,7 +338,7 @@ class TestImperativePTQKL(TestImperativePTQ): self.batch_num = 10 self.batch_size = 10 - self.eval_acc_top1 = 1.0 + self.eval_acc_top1 = 0.98 conv2d_1_wt_thresholds = [ 0.18116560578346252, 0.17079241573810577, 0.1702047884464264, diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py index 6100ed4f82..befc76c027 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py @@ -165,7 +165,7 @@ class TestPostTrainingQuantization(unittest.TestCase): model_path, data_path, algo="KL", - round_type="round", + weight_round_algo="round", quantizable_op_type=["conv2d"], is_full_quantize=False, is_use_cache_file=False, @@ -185,7 +185,7 @@ class TestPostTrainingQuantization(unittest.TestCase): batch_nums=batch_nums, algo=algo, quantizable_op_type=quantizable_op_type, - round_type=round_type, + weight_round_algo=weight_round_algo, is_full_quantize=is_full_quantize, optimize_model=is_optimize_model, onnx_format=onnx_format, @@ -201,7 +201,7 @@ class TestPostTrainingQuantization(unittest.TestCase): data_url, data_md5, algo, - round_type, + weight_round_algo, quantizable_op_type, is_full_quantize, is_use_cache_file, @@ -224,7 +224,7 @@ class TestPostTrainingQuantization(unittest.TestCase): print("Start post training quantization for {0} on {1} samples ...". 
format(model_name, quant_iterations)) self.generate_quantized_model(fp32_model_path, data_path, algo, - round_type, quantizable_op_type, + weight_round_algo, quantizable_op_type, is_full_quantize, is_use_cache_file, is_optimize_model, quant_iterations, onnx_format) @@ -255,7 +255,7 @@ class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization): data_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz" data_md5 = "add84c754e9b792fea1fbd728d134ab7" algo = "avg" - round_type = "round" + weight_round_algo = "round" quantizable_op_type = ["mul", "lstm"] is_full_quantize = False is_use_cache_file = False @@ -264,7 +264,7 @@ class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization): infer_iterations = 100 quant_iterations = 10 self.run_test(model_name, model_url, model_md5, data_name, data_url, - data_md5, algo, round_type, quantizable_op_type, + data_md5, algo, weight_round_algo, quantizable_op_type, is_full_quantize, is_use_cache_file, is_optimize_model, diff_threshold, infer_iterations, quant_iterations) @@ -279,7 +279,7 @@ class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization): data_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz" data_md5 = "add84c754e9b792fea1fbd728d134ab7" algo = "avg" - round_type = "round" + weight_round_algo = "round" quantizable_op_type = ["mul", "lstm"] is_full_quantize = False is_use_cache_file = False @@ -295,7 +295,7 @@ class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization): data_url, data_md5, algo, - round_type, + weight_round_algo, quantizable_op_type, is_full_quantize, is_use_cache_file, diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py index ca2bf80765..71e974f898 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py @@ -108,7 +108,7 @@ class TestPostTrainingQuantization(unittest.TestCase): def generate_quantized_model(self, model_path, algo="KL", - round_type="round", + weight_round_algo="round", quantizable_op_type=["conv2d"], is_full_quantize=False, is_use_cache_file=False, @@ -116,7 +116,8 @@ class TestPostTrainingQuantization(unittest.TestCase): batch_size=10, batch_nums=10, onnx_format=False, - skip_tensor_list=None): + skip_tensor_list=None, + bias_correction=False): place = fluid.CPUPlace() exe = fluid.Executor(place) @@ -129,9 +130,10 @@ class TestPostTrainingQuantization(unittest.TestCase): batch_nums=batch_nums, algo=algo, quantizable_op_type=quantizable_op_type, - round_type=round_type, + weight_round_algo=weight_round_algo, is_full_quantize=is_full_quantize, optimize_model=is_optimize_model, + bias_correction=bias_correction, onnx_format=onnx_format, skip_tensor_list=skip_tensor_list, is_use_cache_file=is_use_cache_file) @@ -143,7 +145,7 @@ class TestPostTrainingQuantization(unittest.TestCase): data_url, data_md5, algo, - round_type, + weight_round_algo, quantizable_op_type, is_full_quantize, is_use_cache_file, @@ -152,6 +154,7 @@ class TestPostTrainingQuantization(unittest.TestCase): batch_size=10, infer_iterations=10, quant_iterations=5, + bias_correction=False, onnx_format=False, skip_tensor_list=None): @@ -166,11 +169,12 @@ class TestPostTrainingQuantization(unittest.TestCase): print("Start INT8 post training quantization for {0} on {1} images 
...". format(model_name, quant_iterations * batch_size)) - self.generate_quantized_model(origin_model_path, algo, round_type, - quantizable_op_type, is_full_quantize, - is_use_cache_file, is_optimize_model, - batch_size, quant_iterations, onnx_format, - skip_tensor_list) + self.generate_quantized_model(origin_model_path, algo, + weight_round_algo, quantizable_op_type, + is_full_quantize, is_use_cache_file, + is_optimize_model, batch_size, + quant_iterations, onnx_format, + skip_tensor_list, bias_correction) print("Start INT8 inference for {0} on {1} images ...".format( model_name, infer_iterations * batch_size)) @@ -200,7 +204,7 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization): data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" algo = "KL" - round_type = "round" + weight_round_algo = "round" quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] is_full_quantize = False is_use_cache_file = False @@ -209,7 +213,7 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization): batch_size = 10 infer_iterations = 50 quant_iterations = 5 - self.run_test(model_name, data_url, data_md5, algo, round_type, + self.run_test(model_name, data_url, data_md5, algo, weight_round_algo, quantizable_op_type, is_full_quantize, is_use_cache_file, is_optimize_model, diff_threshold, batch_size, infer_iterations, quant_iterations) @@ -222,7 +226,7 @@ class TestPostTraininghistForMnist(TestPostTrainingQuantization): data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" algo = "hist" - round_type = "round" + weight_round_algo = "round" quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] is_full_quantize = False is_use_cache_file = False @@ -231,7 +235,7 @@ class TestPostTraininghistForMnist(TestPostTrainingQuantization): batch_size = 10 infer_iterations = 50 quant_iterations = 5 - self.run_test(model_name, data_url, data_md5, algo, round_type, + self.run_test(model_name, data_url, data_md5, algo, weight_round_algo, quantizable_op_type, is_full_quantize, is_use_cache_file, is_optimize_model, diff_threshold, batch_size, infer_iterations, quant_iterations) @@ -244,7 +248,7 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization): data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" algo = "mse" - round_type = "round" + weight_round_algo = "round" quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] is_full_quantize = False is_use_cache_file = False @@ -253,7 +257,7 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization): batch_size = 10 infer_iterations = 50 quant_iterations = 5 - self.run_test(model_name, data_url, data_md5, algo, round_type, + self.run_test(model_name, data_url, data_md5, algo, weight_round_algo, quantizable_op_type, is_full_quantize, is_use_cache_file, is_optimize_model, diff_threshold, batch_size, infer_iterations, quant_iterations) @@ -266,7 +270,7 @@ class TestPostTrainingemdForMnist(TestPostTrainingQuantization): data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" algo = "emd" - round_type = "round" + weight_round_algo = "round" quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] is_full_quantize = False is_use_cache_file = False @@ -275,7 +279,7 @@ class TestPostTrainingemdForMnist(TestPostTrainingQuantization): batch_size = 10 
infer_iterations = 50 quant_iterations = 5 - self.run_test(model_name, data_url, data_md5, algo, round_type, + self.run_test(model_name, data_url, data_md5, algo, weight_round_algo, quantizable_op_type, is_full_quantize, is_use_cache_file, is_optimize_model, diff_threshold, batch_size, infer_iterations, quant_iterations) @@ -288,7 +292,7 @@ class TestPostTrainingavgForMnist(TestPostTrainingQuantization): data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" algo = "avg" - round_type = "round" + weight_round_algo = "round" quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] is_full_quantize = False is_use_cache_file = False @@ -297,7 +301,7 @@ class TestPostTrainingavgForMnist(TestPostTrainingQuantization): batch_size = 10 infer_iterations = 50 quant_iterations = 5 - self.run_test(model_name, data_url, data_md5, algo, round_type, + self.run_test(model_name, data_url, data_md5, algo, weight_round_algo, quantizable_op_type, is_full_quantize, is_use_cache_file, is_optimize_model, diff_threshold, batch_size, infer_iterations, quant_iterations) @@ -310,7 +314,7 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization): data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" algo = "abs_max" - round_type = "round" + weight_round_algo = "round" quantizable_op_type = ["conv2d", "mul"] is_full_quantize = True is_use_cache_file = False @@ -319,7 +323,7 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization): batch_size = 10 infer_iterations = 50 quant_iterations = 10 - self.run_test(model_name, data_url, data_md5, algo, round_type, + self.run_test(model_name, data_url, data_md5, algo, weight_round_algo, quantizable_op_type, is_full_quantize, is_use_cache_file, is_optimize_model, diff_threshold, batch_size, infer_iterations, quant_iterations) @@ -332,7 +336,7 @@ class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization): data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" algo = "mse" - round_type = "adaround" + weight_round_algo = "adaround" quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] is_full_quantize = False is_use_cache_file = False @@ -341,10 +345,21 @@ class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization): batch_size = 10 infer_iterations = 50 quant_iterations = 5 - self.run_test(model_name, data_url, data_md5, algo, round_type, - quantizable_op_type, is_full_quantize, is_use_cache_file, - is_optimize_model, diff_threshold, batch_size, - infer_iterations, quant_iterations) + bias_correction = True + self.run_test(model_name, + data_url, + data_md5, + algo, + weight_round_algo, + quantizable_op_type, + is_full_quantize, + is_use_cache_file, + is_optimize_model, + diff_threshold, + batch_size, + infer_iterations, + quant_iterations, + bias_correction=bias_correction) class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization): @@ -354,7 +369,7 @@ class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization): data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" algo = "KL" - round_type = "adaround" + weight_round_algo = "adaround" quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] is_full_quantize = False is_use_cache_file = False @@ -363,7 +378,7 @@ class 
TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization): batch_size = 10 infer_iterations = 50 quant_iterations = 5 - self.run_test(model_name, data_url, data_md5, algo, round_type, + self.run_test(model_name, data_url, data_md5, algo, weight_round_algo, quantizable_op_type, is_full_quantize, is_use_cache_file, is_optimize_model, diff_threshold, batch_size, infer_iterations, quant_iterations) @@ -376,7 +391,7 @@ class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization): data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" algo = "mse" - round_type = "round" + weight_round_algo = "round" quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] is_full_quantize = False is_use_cache_file = False @@ -390,7 +405,7 @@ class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization): data_url, data_md5, algo, - round_type, + weight_round_algo, quantizable_op_type, is_full_quantize, is_use_cache_file, @@ -410,7 +425,7 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant( data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" algo = "mse" - round_type = "round" + weight_round_algo = "round" quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] is_full_quantize = True is_use_cache_file = False @@ -424,7 +439,7 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant( data_url, data_md5, algo, - round_type, + weight_round_algo, quantizable_op_type, is_full_quantize, is_use_cache_file, @@ -443,7 +458,7 @@ class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization): data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" algo = "avg" - round_type = "round" + weight_round_algo = "round" quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] is_full_quantize = False is_use_cache_file = False @@ -457,7 +472,7 @@ class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization): data_url, data_md5, algo, - round_type, + weight_round_algo, quantizable_op_type, is_full_quantize, is_use_cache_file, diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py index 9c076d85fd..fac0dcc341 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py @@ -242,7 +242,7 @@ class TestPostTrainingQuantization(unittest.TestCase): model_path, quantizable_op_type, algo="KL", - round_type="round", + weight_round_algo="round", is_full_quantize=False, is_use_cache_file=False, is_optimize_model=False, @@ -264,7 +264,7 @@ class TestPostTrainingQuantization(unittest.TestCase): model_dir=model_path, algo=algo, quantizable_op_type=quantizable_op_type, - round_type=round_type, + weight_round_algo=weight_round_algo, is_full_quantize=is_full_quantize, optimize_model=is_optimize_model, onnx_format=onnx_format, @@ -275,7 +275,7 @@ class TestPostTrainingQuantization(unittest.TestCase): def run_test(self, model, algo, - round_type, + weight_round_algo, data_urls, data_md5s, quantizable_op_type, @@ -299,9 +299,10 @@ class TestPostTrainingQuantization(unittest.TestCase): print("Start INT8 post training quantization for {0} on {1} images ...". 
format(model, sample_iterations * batch_size)) self.generate_quantized_model(model_cache_folder + "/model", - quantizable_op_type, algo, round_type, - is_full_quantize, is_use_cache_file, - is_optimize_model, onnx_format) + quantizable_op_type, algo, + weight_round_algo, is_full_quantize, + is_use_cache_file, is_optimize_model, + onnx_format) print("Start INT8 inference for {0} on {1} images ...".format( model, infer_iterations * batch_size)) @@ -329,7 +330,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization): def test_post_training_kl_mobilenetv1(self): model = "MobileNet-V1" algo = "KL" - round_type = "round" + weight_round_algo = "round" data_urls = [ 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' ] @@ -344,7 +345,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization): is_use_cache_file = False is_optimize_model = True diff_threshold = 0.025 - self.run_test(model, algo, round_type, data_urls, data_md5s, + self.run_test(model, algo, weight_round_algo, data_urls, data_md5s, quantizable_op_type, is_full_quantize, is_use_cache_file, is_optimize_model, diff_threshold) @@ -354,7 +355,7 @@ class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization): def test_post_training_avg_mobilenetv1(self): model = "MobileNet-V1" algo = "avg" - round_type = "round" + weight_round_algo = "round" data_urls = [ 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' ] @@ -368,7 +369,7 @@ class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization): is_use_cache_file = False is_optimize_model = True diff_threshold = 0.025 - self.run_test(model, algo, round_type, data_urls, data_md5s, + self.run_test(model, algo, weight_round_algo, data_urls, data_md5s, quantizable_op_type, is_full_quantize, is_use_cache_file, is_optimize_model, diff_threshold) @@ -378,7 +379,7 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization): def test_post_training_hist_mobilenetv1(self): model = "MobileNet-V1" algo = "hist" - round_type = "round" + weight_round_algo = "round" data_urls = [ 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' ] @@ -392,7 +393,7 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization): is_use_cache_file = False is_optimize_model = True diff_threshold = 0.03 - self.run_test(model, algo, round_type, data_urls, data_md5s, + self.run_test(model, algo, weight_round_algo, data_urls, data_md5s, quantizable_op_type, is_full_quantize, is_use_cache_file, is_optimize_model, diff_threshold) @@ -402,7 +403,7 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization): def test_post_training_abs_max_mobilenetv1(self): model = "MobileNet-V1" algo = "abs_max" - round_type = "round" + weight_round_algo = "round" data_urls = [ 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' ] @@ -416,7 +417,7 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization): is_optimize_model = False # The accuracy diff of post-training quantization (abs_max) may be bigger diff_threshold = 0.05 - self.run_test(model, algo, round_type, data_urls, data_md5s, + self.run_test(model, algo, weight_round_algo, data_urls, data_md5s, quantizable_op_type, is_full_quantize, is_use_cache_file, is_optimize_model, diff_threshold) @@ -426,7 +427,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization): def test_post_training_onnx_format_mobilenetv1(self): model = "MobileNet-V1" algo = "avg" -
round_type = "round" + weight_round_algo = "round" data_urls = [ 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' ] @@ -443,7 +444,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization): diff_threshold = 0.05 self.run_test(model, algo, - round_type, + weight_round_algo, data_urls, data_md5s, quantizable_op_type, diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py index c79499100c..78c5153b74 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py @@ -25,7 +25,7 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization): def test_post_training_resnet50(self): model = "ResNet-50" algo = "min_max" - round_type = "round" + weight_round_algo = "round" data_urls = [ 'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz' ] @@ -35,7 +35,7 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization): is_use_cache_file = False is_optimize_model = False diff_threshold = 0.025 - self.run_test(model, algo, round_type, data_urls, data_md5s, + self.run_test(model, algo, weight_round_algo, data_urls, data_md5s, quantizable_op_type, is_full_quantize, is_use_cache_file, is_optimize_model, diff_threshold) @@ -45,7 +45,7 @@ class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization): def test_post_training_resnet50(self): model = "ResNet-50" algo = "min_max" - round_type = "round" + weight_round_algo = "round" data_urls = [ 'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz' ] @@ -58,7 +58,7 @@ class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization): onnx_format = True self.run_test(model, algo, - round_type, + weight_round_algo, data_urls, data_md5s, quantizable_op_type, diff --git a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py index 3693ba615d..e76d5c49d9 100644 --- a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py +++ b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py @@ -21,8 +21,6 @@ import math from op_test import OpTest -# numpy.round has different behavior in comparision to c++ round function -# so we use round_c instead of numpy.round to align the output data def round_c_single_element(val): dtype = type(val) if val >= 0: @@ -30,6 +28,7 @@ def round_c_single_element(val): return dtype(np.ceil(val - 0.5)) +# rounding to nearest ties away from zero round_c = np.vectorize(round_c_single_element) @@ -46,13 +45,25 @@ class TestFakeQuantizeAbsMaxOp(OpTest): self.op_type = 'fake_quantize_abs_max' self.attrs = {'bit_length': 8} - def _fake_quantize_abs_max(self, dtype, input_shape, distribution): + def _fake_quantize_abs_max(self, + dtype, + input_shape, + distribution, + round_type='TiesToEven'): input_data = distribution(input_shape).astype(dtype) compute_type = get_compute_type(dtype) scale = np.max(np.abs(input_data)) bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 inv_scale = 1.0 / (scale + 1e-6) if scale < 1e-30 else 1.0 / scale - output_data = round_c(input_data.astype(compute_type) * inv_scale * bnt) + if round_type == 'TiesToEven': + round_out = np.round( + input_data.astype(compute_type) * inv_scale * bnt) + self.attrs['round_type'] = 0 + else: + round_out = round_c( + 
input_data.astype(compute_type) * inv_scale * bnt) + self.attrs['round_type'] = 1 + output_data = np.clip(round_out, -bnt - 1, bnt) self.inputs = {'X': input_data} self.outputs = {'Out': output_data, 'OutScale': scale} self.dtype = dtype @@ -61,6 +72,11 @@ class TestFakeQuantizeAbsMaxOp(OpTest): def test_fake_quantize_abs_max(self): self._fake_quantize_abs_max(np.float32, (124, 240), np.random.random) + def test_fake_quantize_abs_max_round1(self): + self._fake_quantize_abs_max(np.float32, (124, 240), + np.random.random, + round_type='TiesAwayFromZero') + def test_fake_quantize_abs_max_float16(self): self._fake_quantize_abs_max(np.float16, (124, 240), np.random.random) @@ -78,8 +94,12 @@ class TestFakeChannelWiseQuantizeAbsMaxOp(OpTest): self.op_type = 'fake_channel_wise_quantize_abs_max' self.attrs = {'bit_length': 8} - def _fake_channel_wise_quantize_abs_max(self, dtype, input_shape, - quant_axis, distribution): + def _fake_channel_wise_quantize_abs_max(self, + dtype, + input_shape, + quant_axis, + distribution, + round_type='TiesToEven'): assert quant_axis in [0, 1], 'quant_axis should be 0 or 1.' input_data = distribution(input_shape).astype(dtype) compute_type = get_compute_type(dtype) @@ -87,8 +107,15 @@ class TestFakeChannelWiseQuantizeAbsMaxOp(OpTest): compute_axis = tuple(i for i in range(len(input_shape)) if i != quant_axis) scale_broadcast = np.amax(input_data, axis=compute_axis, keepdims=True) - output_data = round_c(bnt * input_data.astype(compute_type) / - scale_broadcast) + if round_type == 'TiesToEven': + round_out = np.round( + input_data.astype(compute_type) / scale_broadcast * bnt) + self.attrs['round_type'] = 0 + else: + round_out = round_c( + input_data.astype(compute_type) / scale_broadcast * bnt) + self.attrs['round_type'] = 1 + output_data = np.clip(round_out, -bnt - 1, bnt) if quant_axis == 1: scale_broadcast = np.transpose(scale_broadcast, (1, ) + compute_axis) @@ -102,16 +129,20 @@ class TestFakeChannelWiseQuantizeAbsMaxOp(OpTest): def test_fake_channel_wise_quantize_abs_max(self): dtype_options = [np.float32, np.float16] input_shape_quant_axis_options = [[(20, 15, 6, 6), 0], - [(15, 20, 5, 5), 1], [(30, 15), 0], - [(30, 15), 1]] - for dtype, input_shape_quant_axis in itertools.product( - dtype_options, input_shape_quant_axis_options): + [(20, 15, 6, 6), 1], [(30, 30), 0], + [(30, 30), 1]] + round_type_options = ['TiesToEven', 'TiesAwayFromZero'] + for dtype, input_shape_quant_axis, round_type in itertools.product( + dtype_options, input_shape_quant_axis_options, + round_type_options): input_shape, quant_axis = input_shape_quant_axis with self.subTest(dtype=dtype, input_shape=input_shape, - quant_axis=quant_axis): + quant_axis=quant_axis, + round_type=round_type): self._fake_channel_wise_quantize_abs_max( - dtype, input_shape, quant_axis, np.random.random) + dtype, input_shape, quant_axis, np.random.random, + round_type) class TestFakeQuantizeRangeAbsMaxOp(OpTest): @@ -124,7 +155,8 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest): dtype, input_shape, distribution, - is_test=False): + is_test=False, + round_type='TiesToEven'): input_data = distribution(input_shape).astype(dtype) compute_type = get_compute_type(dtype) bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 @@ -133,11 +165,15 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest): out_scale[0] = np.max(np.abs(input_data)) if is_test: out_scale[0] = in_scale[0] = out_scale[0] - 1.0 - clip_data = np.clip(input_data, -in_scale, in_scale) + if round_type == 'TiesToEven': + round_out = np.round( + 
input_data.astype(compute_type) / out_scale[0] * bnt) + self.attrs['round_type'] = 0 else: - clip_data = input_data - output_data = round_c( - clip_data.astype(compute_type) / out_scale[0] * bnt) + round_out = round_c( + input_data.astype(compute_type) / out_scale[0] * bnt) + self.attrs['round_type'] = 1 + output_data = np.clip(round_out, -bnt - 1, bnt) self.inputs = { 'X': input_data, 'Iter': np.zeros(1).astype(np.int64), @@ -153,15 +189,20 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest): self.check_output() def test_fake_quantize_range_abs_max(self): - dtype_options = [np.float32, np.float16] + dtype_options = [np.float16, np.float32] is_test_options = [False, True] - for dtype, is_test in itertools.product(dtype_options, is_test_options): + round_type_options = ['TiesToEven', 'TiesAwayFromZero'] + for dtype, is_test, round_type in itertools.product( + dtype_options, is_test_options, round_type_options): self.attrs['bit_length'] = 8 if is_test else 5 - with self.subTest(dtype=dtype, is_test=is_test): + with self.subTest(dtype=dtype, + is_test=is_test, + round_type=round_type): self._fake_quantize_range_abs_max( - dtype, (8, 16, 7, 7), - lambda shape: (np.random.random(shape) - 0.5) * 10, - is_test=is_test) + dtype, (8, 16, 6, 6), + lambda shape: (np.random.random(shape) - 0.4) * 10, + is_test=is_test, + round_type=round_type) class TestMovingAverageAbsMaxScaleOp(OpTest): @@ -208,7 +249,8 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest): input_shape, distribution, dequantize=False, - with_gradient=False): + with_gradient=False, + round_type='TiesToEven'): input_data = distribution(input_shape).astype(dtype) compute_type = get_compute_type(dtype) bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 @@ -222,12 +264,20 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest): np.abs(input_data)) out_state[0] = self.attrs['moving_rate'] * in_state[0] + 1.0 out_scale = out_accum / out_state - round_data = round_c(input_data.astype(compute_type) / out_scale * bnt) + if round_type == 'TiesToEven': + round_out = np.round( + input_data.astype(compute_type) / out_scale * bnt) + self.attrs['round_type'] = 0 + else: + round_out = round_c( + input_data.astype(compute_type) / out_scale * bnt) + self.attrs['round_type'] = 1 + quant_data = np.clip(round_out, -bnt - 1, bnt) if dequantize: - output_data = (round_data * out_scale / bnt).astype(dtype) + output_data = (quant_data * out_scale / bnt).astype(dtype) self.op_type = 'fake_quantize_dequantize_moving_average_abs_max' else: - output_data = round_data.astype(dtype) + output_data = quant_data.astype(dtype) self.inputs = { 'X': input_data, 'InScale': in_scale, @@ -256,6 +306,12 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest): self._fake_quantize_moving_average_abs_max(np.float16, (8, 16, 7, 7), np.random.random) + def test_fake_quantize_moving_average_abs_max_round1(self): + self._fake_quantize_moving_average_abs_max( + np.float32, (8, 16, 7, 7), + np.random.random, + round_type='TiesAwayFromZero') + def test_fake_quantize_dequantize_moving_average_abs_max(self): self._fake_quantize_moving_average_abs_max(np.float32, (8, 16, 7, 7), np.random.random, @@ -269,12 +325,21 @@ class TestFakeQuantizeDequantizeAbsMaxOp(OpTest): self.op_type = 'fake_quantize_dequantize_abs_max' self.attrs = {'bit_length': 8} - def _fake_quantize_dequantize_abs_max(self, dtype, input_shape, - distribution): + def _fake_quantize_dequantize_abs_max(self, + dtype, + input_shape, + distribution, + round_type='TiesToEven'): input_data = distribution(input_shape).astype(dtype) 
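To make the reference computation these tests build concrete, a condensed standalone sketch (assuming float32 input, a nonzero scale, and 8 bits) of the quantize, clip, dequantize round trip:

    import numpy as np

    def fake_quant_dequant_abs_max(x, bit_length=8, ties_away_from_zero=False):
        bnt = (1 << (bit_length - 1)) - 1    # 127 for 8 bits
        scale = np.max(np.abs(x))            # assumed nonzero here
        scaled = x * bnt / scale
        round_out = (np.sign(scaled) * np.floor(np.abs(scaled) + 0.5)
                     if ties_away_from_zero else np.round(scaled))
        return np.clip(round_out, -bnt - 1, bnt) * scale / bnt

    x = np.random.uniform(-1, 1, (4, 4)).astype(np.float32)
    y = fake_quant_dequant_abs_max(x)
    # the round trip perturbs each element by at most half a quantization step
    assert np.max(np.abs(x - y)) <= np.max(np.abs(x)) / (2 * 127) + 1e-6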
scale = np.max(np.abs(input_data)).astype(dtype) bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 - output_data = round_c(input_data / scale * bnt) * scale / bnt + if round_type == 'TiesToEven': + round_out = np.round(input_data / scale * bnt) + self.attrs['round_type'] = 0 + else: + round_out = round_c(input_data / scale * bnt) + self.attrs['round_type'] = 1 + output_data = np.clip(round_out, -bnt - 1, bnt) * scale / bnt self.inputs = {'X': input_data} self.outputs = { 'Out': output_data, @@ -289,6 +354,11 @@ class TestFakeQuantizeDequantizeAbsMaxOp(OpTest): self._fake_quantize_dequantize_abs_max(np.float32, (124, 240), np.random.random) + def test_fake_quantize_dequantize_abs_max_round1(self): + self._fake_quantize_dequantize_abs_max(np.float32, (124, 240), + np.random.random, + round_type='TiesAwayFromZero') + class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest): @@ -296,9 +366,12 @@ class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest): self.op_type = 'fake_channel_wise_quantize_dequantize_abs_max' self.attrs = {'bit_length': 8} - def _fake_channel_wise_quantize_dequantize_abs_max(self, dtype, input_shape, + def _fake_channel_wise_quantize_dequantize_abs_max(self, + dtype, + input_shape, quant_axis, - distribution): + distribution, + round_type='TiesToEven'): assert quant_axis in [0, 1], 'quant_axis should be 0 or 1.' input_data = distribution(input_shape).astype(dtype) compute_type = get_compute_type(dtype) @@ -307,8 +380,13 @@ class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest): compute_axis = tuple(i for i in range(len(input_shape)) if i != quant_axis) scale_broadcast = np.amax(input_data, axis=compute_axis, keepdims=True) - output_data = round_c( - bnt * output_data / scale_broadcast) * scale_broadcast / bnt + if round_type == 'TiesToEven': + round_out = np.round(bnt * output_data / scale_broadcast) + self.attrs['round_type'] = 0 + else: + round_out = round_c(bnt * output_data / scale_broadcast) + self.attrs['round_type'] = 1 + output_data = np.clip(round_out, -bnt - 1, bnt) * scale_broadcast / bnt if quant_axis == 1: scale_broadcast = np.transpose(scale_broadcast, (1, ) + compute_axis) @@ -325,10 +403,19 @@ class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest): input_shape_quant_axis_options = [[(3, 4, 64, 64), 0], [(15, 20, 5, 5), 1], [(30, 15), 0], [(30, 15), 1]] - for input_shape, quant_axis in input_shape_quant_axis_options: - with self.subTest(input_shape=input_shape, quant_axis=quant_axis): + round_type_options = ['TiesToEven', 'TiesAwayFromZero'] + for input_shape_quant_axis, round_type in itertools.product( + input_shape_quant_axis_options, round_type_options): + input_shape, quant_axis = input_shape_quant_axis + with self.subTest(input_shape=input_shape, + quant_axis=quant_axis, + round_type=round_type): self._fake_channel_wise_quantize_dequantize_abs_max( - np.float32, input_shape, quant_axis, np.random.random) + np.float32, + input_shape, + quant_axis, + np.random.random, + round_type=round_type) def quantize_max_abs(x, max_range): -- GitLab
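Appendix, illustration only: a toy run of the per-channel branch of quant_tensor from utils.py above, assuming 8 bits, quant_axis=0 (one scale per output channel), and the TiesToEven default:

    import numpy as np

    bnt = (1 << (8 - 1)) - 1    # 127
    w = np.array([[0.5, -1.0, 0.25],
                  [2.0, 1.0, -2.0]], dtype=np.float32)

    scales = [float(np.max(np.abs(w[i]))) for i in range(w.shape[0])]  # [1.0, 2.0]
    q = w.copy()
    for i, s in enumerate(scales):
        q[i] = np.clip(np.round(q[i] / s * bnt), -bnt - 1, bnt)
    print(q)
    # [[  64. -127.   32.]
    #  [ 127.   64. -127.]]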