From ff70a2694e264e1b5774ba6d91dc60f93098d206 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Mon, 27 Jun 2022 16:14:58 +0800 Subject: [PATCH] [cherry-pick]Update quantization round and clip calculation methods (#43829) * update quantization clip and round * fix quantization clip and round Attribute * fix typo --- .../ir/delete_quant_dequant_filter_op_pass.cc | 28 +- .../ir/delete_quant_dequant_linear_op_pass.cc | 11 +- .../delete_weight_dequant_linear_op_pass.cc | 26 +- .../ir/quant_conv2d_dequant_fuse_pass.cc | 83 +- paddle/fluid/operators/fake_quantize_op.cc | 539 ++++++++---- paddle/fluid/operators/fake_quantize_op.cu.h | 476 +++++++---- paddle/fluid/operators/fake_quantize_op.h | 362 +++++--- paddle/fluid/operators/quantize_linear_op.cc | 69 +- paddle/fluid/operators/quantize_linear_op.h | 44 +- .../contrib/slim/quantization/adaround.py | 128 +-- .../post_training_quantization.py | 233 +++--- .../slim/quantization/quantization_pass.py | 430 +++++----- .../fluid/contrib/slim/quantization/utils.py | 27 +- .../fluid/contrib/slim/tests/CMakeLists.txt | 776 +++++++++++------- .../contrib/slim/tests/test_imperative_ptq.py | 74 +- ...t_post_training_quantization_lstm_model.py | 96 +-- .../test_post_training_quantization_mnist.py | 177 ++-- ..._post_training_quantization_mobilenetv1.py | 136 ++- ...est_post_training_quantization_resnet50.py | 25 +- .../tests/unittests/test_fake_quantize_op.py | 204 +++-- 20 files changed, 2406 insertions(+), 1538 deletions(-) diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc index 2fc133edb7a..86639e4ff42 100644 --- a/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc @@ -45,6 +45,10 @@ DeleteQuantDequantFilterOpPass::DeleteQuantDequantFilterOpPass() { .End() .AddAttr("bit_length") .IsIntIn({8, 16}) + .End() + .AddAttr("round_type") + .IsOptional() + .IsIntIn({0, 1}) .End(); AddOpCompat(OpCompat("fake_channel_wise_quantize_dequantize_abs_max")) .AddInput("X") @@ -61,6 +65,10 @@ DeleteQuantDequantFilterOpPass::DeleteQuantDequantFilterOpPass() { .End() .AddAttr("quant_axis") .IsIntIn({0, 1}) + .End() + .AddAttr("round_type") + .IsOptional() + .IsIntIn({0, 1}) .End(); } // Delete quant_dequant_op, then quantize and dequantize weight @@ -96,15 +104,18 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { auto var_map = any_op2_desc->Inputs(); std::string arg_name = ""; for (auto& name_m : var_map) { - if (std::find(name_m.second.begin(), name_m.second.end(), + if (std::find(name_m.second.begin(), + name_m.second.end(), quant_dequant_op_out_name) != name_m.second.end()) { arg_name = name_m.first; break; } } - PADDLE_ENFORCE_GT(arg_name.size(), 0, platform::errors::InvalidArgument( - "can not find the input %s.", - quant_dequant_op_out_name)); + PADDLE_ENFORCE_GT( + arg_name.size(), + 0, + platform::errors::InvalidArgument("can not find the input %s.", + quant_dequant_op_out_name)); // any_op2_desc->SetAttr("enable_int8", true); any_op2_desc->SetAttr("bit_length", bit_length); @@ -123,7 +134,8 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") { int quant_axis = BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("quant_axis")); - PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true, + PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, + true, 
platform::errors::InvalidArgument( "'quant_axis' should be 0 or 1, but " "the received is %d", @@ -176,7 +188,8 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { } } for (int i = 0; i < channel; i++) { - PADDLE_ENFORCE_NE(weight_scale[i], 0, + PADDLE_ENFORCE_NE(weight_scale[i], + 0, platform::errors::InvalidArgument( "Weight scale should be nonzero, but get zero.")); weight_scale[i] = weight_scale[i] / range; @@ -188,7 +201,8 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { abs_max_weight = std::max(abs_max_weight, std::abs(quantized_weight_data[j])); } - PADDLE_ENFORCE_NE(abs_max_weight, 0, + PADDLE_ENFORCE_NE(abs_max_weight, + 0, platform::errors::InvalidArgument( "Weight scale should be nonzero, but get zero")); weight_scale.push_back(abs_max_weight / range); diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc index 8f2b58ed51b..08e8aa3b360 100644 --- a/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc @@ -54,6 +54,10 @@ DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() { .End() .AddAttr("quant_axis") .IsType() + .End() + .AddAttr("round_type") + .IsOptional() + .IsType() .End(); AddOpCompat(OpCompat("dequantize_linear")) .AddInput("X") @@ -74,6 +78,10 @@ DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() { .End() .AddAttr("quant_axis") .IsType() + .End() + .AddAttr("round_type") + .IsOptional() + .IsType() .End(); } // Delete quantize_linear_op dequantize_linear_op, then add input_scales @@ -112,7 +120,8 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { const LoDTensor& input_scale_tensor = scope->GetVar(quantize_linear_op_scale->Name())->Get(); PADDLE_ENFORCE_EQ( - paddle::platform::is_cpu_place(input_scale_tensor.place()), true, + paddle::platform::is_cpu_place(input_scale_tensor.place()), + true, platform::errors::InvalidArgument( "Input scale tensor's place should be CPU.")); const float* input_scale_data = input_scale_tensor.data(); diff --git a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc index 8ebea231e7a..b37d7978a8e 100644 --- a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc @@ -52,6 +52,10 @@ DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() { .End() .AddAttr("quant_axis") .IsType() + .End() + .AddAttr("round_type") + .IsOptional() + .IsType() .End(); AddOpCompat(OpCompat("dequantize_linear")) .AddInput("X") @@ -72,6 +76,10 @@ DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() { .End() .AddAttr("quant_axis") .IsType() + .End() + .AddAttr("round_type") + .IsOptional() + .IsType() .End(); AddOpCompat(OpCompat("conv2d")) .AddInput("Input") @@ -322,7 +330,8 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { int quant_axis = BOOST_GET_CONST( int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis")); if (quant_axis == -1) { // per_layer quant_dequant: all OP - PADDLE_ENFORCE_EQ(weight_scale_nums, 1, + PADDLE_ENFORCE_EQ(weight_scale_nums, + 1, platform::errors::InvalidArgument( "When quant_axis == -1 means use per_layer " "quant_dequant, weight_scale'number should be 1.")); @@ -335,11 +344,13 @@ void 
DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { } else if (quant_axis == 0) { // per_channel quant_dequant: conv2d, // depthwise_conv2d, conv2d_fusion PADDLE_ENFORCE_EQ( - weight_scale_nums, w_dims[quant_axis], + weight_scale_nums, + w_dims[quant_axis], platform::errors::InvalidArgument( "When quant_axis == 0 means use per_channel quant_dequant, " "weight_scale'numbers should be equal channels.")); - PADDLE_ENFORCE_EQ(w_dims.size(), 4, + PADDLE_ENFORCE_EQ(w_dims.size(), + 4, platform::errors::InvalidArgument( "When quant_axis == 0 means use per_channel " "quant_dequant, (conv2d, depthwise_conv2d, " @@ -352,7 +363,8 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { } } else if (quant_axis == 1) { PADDLE_ENFORCE_EQ( - weight_scale_nums, w_dims[quant_axis], + weight_scale_nums, + w_dims[quant_axis], platform::errors::InvalidArgument( "When quant_axis == 1 means use per_channel quant_dequant, " "weight_scale'numbers should be equal channels.")); @@ -360,7 +372,8 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { if (w_dims.size() == 4) { // conv2d_transpose std::string quantized_op_type = any_op2->Op()->Type(); PADDLE_ENFORCE_EQ( - quantized_op_type, "conv2d_transpose", + quantized_op_type, + "conv2d_transpose", platform::errors::InvalidArgument( "When quant_axis == 1 means use per_channel quant_dequant, " "only conv2d_transpose weight dims equal 4.")); @@ -388,7 +401,8 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { weight_tensor->Resize(phi::make_ddim(phi::vectorize(w_dims))); float* new_quantized_weight_data = weight_tensor->mutable_data(platform::CPUPlace()); - memcpy(new_quantized_weight_data, weight_data_tmp.data(), + memcpy(new_quantized_weight_data, + weight_data_tmp.data(), weight_tensor->numel() * sizeof(float)); nodes2rm.insert(weight_dequantize_linear_op_scale); diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc index 281e0b99106..058f69aa58a 100644 --- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc +++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc @@ -49,6 +49,10 @@ QuantDequantFusePass::QuantDequantFusePass() { .End() .AddAttr("bit_length") .IsIntIn({8, 16}) + .End() + .AddAttr("round_type") + .IsOptional() + .IsIntIn({0, 1}) .End(); AddOpCompat(OpCompat("fake_quantize_moving_average_abs_max")) .AddInput("X") @@ -85,6 +89,10 @@ QuantDequantFusePass::QuantDequantFusePass() { .End() .AddAttr("bit_length") .IsIntIn({8, 16}) + .End() + .AddAttr("round_type") + .IsOptional() + .IsIntIn({0, 1}) .End(); AddOpCompat(OpCompat("fake_dequantize_max_abs")) .AddInput("X") @@ -309,7 +317,8 @@ QuantDequantFusePass::QuantDequantFusePass() { } // Delete quant op before quantized ops, and set input scale in the attr of // quantized ops -void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope, +void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, + Scope* scope, const std::string& quant_type) const { const std::string pattern_name = "delete_quant_fuse"; GraphPatternDetector gpd; @@ -331,7 +340,8 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope, return; } PADDLE_ENFORCE_EQ( - subgraph.count(input_act_node), true, + subgraph.count(input_act_node), + true, platform::errors::NotFound( "Input act node(%s) not found in QuantDequantFuse pass.", input_act_node->name())); @@ -345,12 +355,14 @@ void 
QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope, // Get input scale from tensor std::string input_scale_var_name = quant->Op()->Input("InScale").front(); PADDLE_ENFORCE_NOT_NULL( - scope, platform::errors::InvalidArgument( - "Scope in QuantDequantFuse pass should not be null.")); + scope, + platform::errors::InvalidArgument( + "Scope in QuantDequantFuse pass should not be null.")); const LoDTensor& input_scale_tensor = scope->FindVar(input_scale_var_name)->Get(); PADDLE_ENFORCE_EQ( - paddle::platform::is_cpu_place(input_scale_tensor.place()), true, + paddle::platform::is_cpu_place(input_scale_tensor.place()), + true, platform::errors::InvalidArgument( "Input scale tensor's place should be CPU.")); const float* input_scale_data = input_scale_tensor.data(); @@ -382,8 +394,8 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope, IR_NODE_LINK_TO(input_act, quantized_node); } // Delete nodes and edges - std::unordered_set nodes2rm = {input_scale, quant, - output_scale, output_act}; + std::unordered_set nodes2rm = { + input_scale, quant, output_scale, output_act}; GraphSafeRemoveNodes(graph, nodes2rm); }; gpd(graph, handler); @@ -391,7 +403,8 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope, // Delete dequant op after quantized ops, and convert weight from fp32 range to // int8 range -void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, +void QuantDequantFusePass::FuseDequant(ir::Graph* graph, + Scope* scope, const std::string& quantized_op_type, const std::string& dequant_type) const { std::string weight_name = ""; @@ -436,7 +449,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, return; } PADDLE_ENFORCE_EQ( - subgraph.count(quantized_op_input), true, + subgraph.count(quantized_op_input), + true, platform::errors::NotFound("Quantized op input node(%s) did not find " "in QuantDequantFuse pass.", quantized_op_input->name())); @@ -464,14 +478,16 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, subgraph.at(pattern.GetPDNode("dequant_channel_scale")); auto scales_name = dequant_op_node->Op()->Input("Scales"); PADDLE_ENFORCE_EQ( - scales_name.size(), 2, + scales_name.size(), + 2, platform::errors::InvalidArgument( "Scales size in channel-wise dequantize op should be 2, got %d.", scales_name.size())); const LoDTensor& channel_scale_tensor = scope->FindVar(scales_name[0])->Get(); PADDLE_ENFORCE_EQ( - paddle::platform::is_cpu_place(channel_scale_tensor.place()), true, + paddle::platform::is_cpu_place(channel_scale_tensor.place()), + true, platform::errors::InvalidArgument( "Channel scale tensor's place should be CPU.")); const float* channel_scale_data = channel_scale_tensor.data(); @@ -497,7 +513,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, if (quantized_op_type == "mul" || quantized_op_type == "matmul" || quantized_op_type == "matmul_v2" || quantized_op_type == "fc") { if (dequant_type == "fake_dequantize_max_abs") { - PADDLE_ENFORCE_EQ(weight_scale.size(), 1, + PADDLE_ENFORCE_EQ(weight_scale.size(), + 1, platform::errors::InvalidArgument( "mul/matmul/matmul_v2 op weight dequantized by " "[fake_dequantize_max_abs] " @@ -511,7 +528,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, if (quant_axis == 0) { } else { PADDLE_ENFORCE_EQ( - quant_axis == 1, true, + quant_axis == 1, + true, platform::errors::InvalidArgument( "'quant_axis' of mul/matmul/fc/matmul_v2 op weight " "dequantized by " @@ -520,14 +538,16 @@ void 
QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, quant_axis)); } PADDLE_ENFORCE_EQ( - weight_scale.size(), static_cast(w_dims[1]), + weight_scale.size(), + static_cast(w_dims[1]), platform::errors::InvalidArgument( "mul/matmul/matmul_v2 op weight dequantized by " "[fake_channel_wise_dequantize_max_abs] requires weight scale " "size = 2nd dim of mul/matmul/matmul_v2's weight, which is %d, " "but got " "%d.", - static_cast(w_dims[1]), weight_scale.size())); + static_cast(w_dims[1]), + weight_scale.size())); for (int j = 0; j < weight_tensor->numel(); j++) { quantized_weight_data[j] *= weight_scale[j % w_dims[1]]; } @@ -535,7 +555,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, } else if (quantized_op_type == "conv2d" || quantized_op_type == "depthwise_conv2d") { PADDLE_ENFORCE_EQ( - dequant_type, "fake_channel_wise_dequantize_max_abs", + dequant_type, + "fake_channel_wise_dequantize_max_abs", platform::errors::InvalidArgument( "conv2d op must be dequantized by " "[fake_channel_wise_dequantize_max_abs], but got %s. " @@ -546,7 +567,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, if (quant_axis == 0) { } else { PADDLE_ENFORCE_EQ( - quant_axis == 0, true, + quant_axis == 0, + true, platform::errors::InvalidArgument( "'quant_axis' of conv2d/depthwise_conv2d op weight dequantized " "by [fake_channel_wise_dequantize_max_abs]should be 0, but " @@ -554,18 +576,21 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, quant_axis)); } PADDLE_ENFORCE_EQ( - weight_scale.size(), static_cast(w_dims[0]), + weight_scale.size(), + static_cast(w_dims[0]), platform::errors::InvalidArgument( "conv2d op requires weight scale size = channel size of the " "weight, which is %d, but got %d.", - static_cast(w_dims[0]), weight_scale.size())); + static_cast(w_dims[0]), + weight_scale.size())); for (int j = 0; j < weight_tensor->numel(); j++) { int inner_size = w_dims[1] * w_dims[2] * w_dims[3]; quantized_weight_data[j] *= weight_scale[j / inner_size]; } } else if (quantized_op_type == "conv2d_transpose") { PADDLE_ENFORCE_EQ( - dequant_type, "fake_channel_wise_dequantize_max_abs", + dequant_type, + "fake_channel_wise_dequantize_max_abs", platform::errors::InvalidArgument( "conv2d_transpose must be dequantized by " "[fake_channel_wise_dequantize_max_abs], but got %s", @@ -573,7 +598,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, if (quant_axis == 0) { } else { PADDLE_ENFORCE_EQ( - quant_axis == 1, true, + quant_axis == 1, + true, platform::errors::InvalidArgument( "'quant_axis' of conv2d_transpose op weight dequantized by " "[fake_channel_wise_dequantize_max_abs]should be 1, but " @@ -581,11 +607,13 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, quant_axis)); } PADDLE_ENFORCE_EQ( - weight_scale.size(), static_cast(w_dims[1]), + weight_scale.size(), + static_cast(w_dims[1]), platform::errors::InvalidArgument( "conv2d_transpose op requires weight scale size = channel size " "of the weight, which is %d, but got %d.", - static_cast(w_dims[1]), weight_scale.size())); + static_cast(w_dims[1]), + weight_scale.size())); for (int j = 0; j < weight_tensor->numel(); j++) { int inner_size = w_dims[2] * w_dims[3]; quantized_weight_data[j] *= weight_scale[(j / inner_size) % w_dims[1]]; @@ -639,8 +667,13 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const { std::unordered_set quant_types = { "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"}; 
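// A small standalone sketch (not part of the patch) of the per-channel weight
// dequantization index math used in QuantDequantFusePass::FuseDequant above.
// Names and shapes are illustrative; w_dims follows the conv2d weight layout
// [Cout, Cin, H, W] and the conv2d_transpose layout [Cin, Cout, H, W].
#include <cstdint>
#include <vector>

void dequant_conv2d_weight(std::vector<float> *w, const std::vector<float> &scale,
                           const std::vector<int64_t> &w_dims) {
  // quant_axis == 0 (conv2d / depthwise_conv2d): one scale per output channel.
  const int64_t inner_size = w_dims[1] * w_dims[2] * w_dims[3];
  for (int64_t j = 0; j < static_cast<int64_t>(w->size()); ++j) {
    (*w)[j] *= scale[j / inner_size];
  }
}

void dequant_conv2d_transpose_weight(std::vector<float> *w,
                                     const std::vector<float> &scale,
                                     const std::vector<int64_t> &w_dims) {
  // quant_axis == 1 (conv2d_transpose): one scale per Cout, which is w_dims[1].
  const int64_t inner_size = w_dims[2] * w_dims[3];
  for (int64_t j = 0; j < static_cast<int64_t>(w->size()); ++j) {
    (*w)[j] *= scale[(j / inner_size) % w_dims[1]];
  }
}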
std::unordered_set quantized_op_types = { - "conv2d", "mul", "matmul", "depthwise_conv2d", - "conv2d_transpose", "fc", "matmul_v2", + "conv2d", + "mul", + "matmul", + "depthwise_conv2d", + "conv2d_transpose", + "fc", + "matmul_v2", }; auto* scope = param_scope(); diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc index ac72f23d46e..61ee9d49ebe 100644 --- a/paddle/fluid/operators/fake_quantize_op.cc +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/fake_quantize_op.h" + #include #include + #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/transform.h" @@ -31,8 +33,10 @@ struct Compare { template struct FindAbsMaxFunctor { - void operator()(const platform::CPUDeviceContext& ctx, const T* in, - const int num, T* out) { + void operator()(const platform::CPUDeviceContext &ctx, + const T *in, + const int num, + T *out) { *out = std::abs(*(std::max_element(in + 0, in + num, Compare()))); } }; @@ -41,24 +45,26 @@ template struct FindAbsMaxFunctor; template struct FindChannelAbsMaxFunctor { - void operator()(const platform::CPUDeviceContext& ctx, - const framework::Tensor& in_tensor, const int quant_axis, - T* out_abs_max) { + void operator()(const platform::CPUDeviceContext &ctx, + const framework::Tensor &in_tensor, + const int quant_axis, + T *out_abs_max) { // At present, channelwise quantization supports conv2d, depthwise_conv2d // conv2d_transpose and mul PADDLE_ENFORCE_EQ( - quant_axis == 0 || quant_axis == 1, true, + quant_axis == 0 || quant_axis == 1, + true, platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " "the received is %d", quant_axis)); - auto* in_data = in_tensor.data(); + auto *in_data = in_tensor.data(); auto in_dims = in_tensor.dims(); const int64_t channel = in_dims[quant_axis]; if (quant_axis == 0) { const int64_t channel_size = in_tensor.numel() / channel; for (int64_t i = 0; i < channel; i++) { - auto* start = in_data + i * channel_size; - auto* end = in_data + (i + 1) * channel_size; + auto *start = in_data + i * channel_size; + auto *end = in_data + (i + 1) * channel_size; out_abs_max[i] = std::abs(*(std::max_element(start, end, Compare()))); } @@ -70,8 +76,8 @@ struct FindChannelAbsMaxFunctor { const int64_t step_j = in_tensor.numel() / (in_dims[0] * in_dims[1]); for (int64_t i = 0; i < in_dims[0]; i++) { for (int64_t j = 0; j < in_dims[1]; j++) { - auto* start = in_data + i * step_i + j * step_j; - auto* end = in_data + i * step_i + (j + 1) * step_j; + auto *start = in_data + i * step_i + j * step_j; + auto *end = in_data + i * step_i + (j + 1) * step_j; T abs_max = std::abs(*(std::max_element(start, end, Compare()))); out_abs_max[j] = std::max(out_abs_max[j], abs_max); } @@ -84,16 +90,30 @@ template struct FindChannelAbsMaxFunctor; template struct ClipAndFakeQuantFunctor { - void operator()(const platform::CPUDeviceContext& ctx, - const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, framework::Tensor* out) { + void operator()(const platform::CPUDeviceContext &ctx, + const framework::Tensor &in, + const framework::Tensor &scale, + const int bin_cnt, + const int round_type, + framework::Tensor *out) { T s = scale.data()[0]; T inv_s = inverse(s); platform::Transform trans; - trans(ctx, in.data(), in.data() + in.numel(), - 
out->mutable_data(ctx.GetPlace()), phi::ClipFunctor(-s, s)); - auto out_e = framework::EigenVector::Flatten(*out); - out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round(); + if (round_type == 0) { + trans(ctx, + in.data(), + in.data() + in.numel(), + out->mutable_data(ctx.GetPlace()), + QuantTensorFunctor(static_cast(bin_cnt), inv_s)); + } else { + trans(ctx, + in.data(), + in.data() + in.numel(), + out->mutable_data(ctx.GetPlace()), + phi::ClipFunctor(-s, s)); + auto out_e = framework::EigenVector::Flatten(*out); + out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round(); + } } }; @@ -101,18 +121,34 @@ template struct ClipAndFakeQuantFunctor; template struct ClipAndFakeQuantDequantFunctor { - void operator()(const platform::CPUDeviceContext& ctx, - const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, framework::Tensor* out) { + void operator()(const platform::CPUDeviceContext &ctx, + const framework::Tensor &in, + const framework::Tensor &scale, + const int bin_cnt, + const int round_type, + framework::Tensor *out) { T s = scale.data()[0]; T inv_s = inverse(s); platform::Transform trans; - trans(ctx, in.data(), in.data() + in.numel(), - out->mutable_data(ctx.GetPlace()), phi::ClipFunctor(-s, s)); - auto out_e = framework::EigenVector::Flatten(*out); - out_e.device(*ctx.eigen_device()) = - (bin_cnt * inv_s * out_e).round() * s / static_cast(bin_cnt); + if (round_type == 0) { + trans(ctx, + in.data(), + in.data() + in.numel(), + out->mutable_data(ctx.GetPlace()), + QuantTensorFunctor(static_cast(bin_cnt), inv_s)); + auto out_e = framework::EigenVector::Flatten(*out); + out_e.device(*ctx.eigen_device()) = out_e * s / static_cast(bin_cnt); + } else { + trans(ctx, + in.data(), + in.data() + in.numel(), + out->mutable_data(ctx.GetPlace()), + phi::ClipFunctor(-s, s)); + auto out_e = framework::EigenVector::Flatten(*out); + out_e.device(*ctx.eigen_device()) = + (bin_cnt * inv_s * out_e).round() * s / static_cast(bin_cnt); + } } }; template struct ClipAndFakeQuantDequantFunctor struct ChannelClipAndFakeQuantFunctor { - void operator()(const platform::CPUDeviceContext& ctx, - const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, const int quant_axis, - framework::Tensor* out) { + void operator()(const platform::CPUDeviceContext &ctx, + const framework::Tensor &in, + const framework::Tensor &scale, + const int bin_cnt, + const int round_type, + const int quant_axis, + framework::Tensor *out) { // At present, channelwise quantization supports conv2d, depthwise_conv2d // conv2d_transpose and mul PADDLE_ENFORCE_EQ( - quant_axis == 0 || quant_axis == 1, true, + quant_axis == 0 || quant_axis == 1, + true, platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " "the received is %d", quant_axis)); - auto* scale_data = scale.data(); - auto* in_data = in.data(); - auto* out_data = out->mutable_data(ctx.GetPlace()); + auto *scale_data = scale.data(); + auto *in_data = in.data(); + auto *out_data = out->mutable_data(ctx.GetPlace()); auto in_dims = in.dims(); const int64_t channel = in_dims[quant_axis]; platform::Transform trans; @@ -141,17 +181,31 @@ struct ChannelClipAndFakeQuantFunctor { const int64_t channel_size = in.numel() / channel; for (int64_t i = 0; i < channel; i++) { T s = scale_data[i]; - auto* start = in_data + i * channel_size; - auto* end = in_data + (i + 1) * channel_size; - trans(ctx, start, end, out_data + i * channel_size, - phi::ClipFunctor(-s, s)); - } - for (int64_t i = 0; i < channel; 
i++) { - T s = scale_data[i]; + auto *start = in_data + i * channel_size; + auto *end = in_data + (i + 1) * channel_size; T inv_s = inverse(s); - framework::Tensor one_channel_out = out->Slice(i, i + 1); - auto out_e = framework::EigenVector::Flatten(one_channel_out); - out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round(); + if (round_type == 0) { + trans(ctx, + start, + end, + out_data + i * channel_size, + QuantTensorFunctor(static_cast(bin_cnt), inv_s)); + } else { + trans(ctx, + start, + end, + out_data + i * channel_size, + phi::ClipFunctor(-s, s)); + } + } + if (round_type == 1) { + for (int64_t i = 0; i < channel; i++) { + T s = scale_data[i]; + T inv_s = inverse(s); + framework::Tensor one_channel_out = out->Slice(i, i + 1); + auto out_e = framework::EigenVector::Flatten(one_channel_out); + out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round(); + } } } else if (quant_axis == 1) { const int64_t step_i = in.numel() / in_dims[0]; @@ -160,12 +214,20 @@ struct ChannelClipAndFakeQuantFunctor { for (int j = 0; j < in_dims[1]; j++) { T s = scale_data[j]; T inv_s = inverse(s); - auto* start = in_data + i * step_i + j * step_j; - auto* end = in_data + i * step_i + (j + 1) * step_j; - auto* cur_out_data = out_data + i * step_i + j * step_j; - trans(ctx, start, end, cur_out_data, phi::ClipFunctor(-s, s)); - for (int k = 0; k < step_j; k++) { - cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]); + auto *start = in_data + i * step_i + j * step_j; + auto *end = in_data + i * step_i + (j + 1) * step_j; + auto *cur_out_data = out_data + i * step_i + j * step_j; + if (round_type == 0) { + trans(ctx, + start, + end, + cur_out_data, + QuantTensorFunctor(static_cast(bin_cnt), inv_s)); + } else { + trans(ctx, start, end, cur_out_data, phi::ClipFunctor(-s, s)); + for (int k = 0; k < step_j; k++) { + cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]); + } } } } @@ -177,19 +239,23 @@ template struct ChannelClipAndFakeQuantFunctor; template struct ChannelClipFakeQuantDequantFunctor { - void operator()(const platform::CPUDeviceContext& ctx, - const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, const int quant_axis, - framework::Tensor* out) { + void operator()(const platform::CPUDeviceContext &ctx, + const framework::Tensor &in, + const framework::Tensor &scale, + const int bin_cnt, + const int round_type, + const int quant_axis, + framework::Tensor *out) { PADDLE_ENFORCE_EQ( - quant_axis == 0 || quant_axis == 1, true, + quant_axis == 0 || quant_axis == 1, + true, platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " "the received is %d", quant_axis)); - auto* scale_data = scale.data(); - auto* in_data = in.data(); - auto* out_data = out->mutable_data(ctx.GetPlace()); + auto *scale_data = scale.data(); + auto *in_data = in.data(); + auto *out_data = out->mutable_data(ctx.GetPlace()); auto in_dims = in.dims(); const int64_t channel = in_dims[quant_axis]; platform::Transform trans; @@ -197,18 +263,35 @@ struct ChannelClipFakeQuantDequantFunctor { const int64_t channel_size = in.numel() / channel; for (int i = 0; i < channel; i++) { T s = scale_data[i]; - auto* start = in_data + i * channel_size; - auto* end = in_data + (i + 1) * channel_size; - trans(ctx, start, end, out_data + i * channel_size, - phi::ClipFunctor(-s, s)); + auto *start = in_data + i * channel_size; + auto *end = in_data + (i + 1) * channel_size; + if (round_type == 0) { + T inv_s = inverse(s); + trans(ctx, + start, + end, + out_data + i 
* channel_size, + QuantTensorFunctor(static_cast(bin_cnt), inv_s)); + } else { + trans(ctx, + start, + end, + out_data + i * channel_size, + phi::ClipFunctor(-s, s)); + } } for (int i = 0; i < channel; i++) { T s = scale_data[i]; - T inv_s = inverse(s); framework::Tensor one_channel_out = out->Slice(i, i + 1); auto out_e = framework::EigenVector::Flatten(one_channel_out); - out_e.device(*ctx.eigen_device()) = - (bin_cnt * inv_s * out_e).round() * s / static_cast(bin_cnt); + if (round_type == 0) { + out_e.device(*ctx.eigen_device()) = + out_e * s / static_cast(bin_cnt); + } else { + T inv_s = inverse(s); + out_e.device(*ctx.eigen_device()) = + (bin_cnt * inv_s * out_e).round() * s / static_cast(bin_cnt); + } } } else if (quant_axis == 1) { const int64_t step_i = in.numel() / in_dims[0]; @@ -217,13 +300,25 @@ struct ChannelClipFakeQuantDequantFunctor { for (int j = 0; j < in_dims[1]; j++) { T s = scale_data[j]; T inv_s = inverse(s); - auto* start = in_data + i * step_i + j * step_j; - auto* end = in_data + i * step_i + (j + 1) * step_j; - auto* cur_out_data = out_data + i * step_i + j * step_j; - trans(ctx, start, end, cur_out_data, phi::ClipFunctor(-s, s)); + auto *start = in_data + i * step_i + j * step_j; + auto *end = in_data + i * step_i + (j + 1) * step_j; + auto *cur_out_data = out_data + i * step_i + j * step_j; + if (round_type == 0) { + trans(ctx, + start, + end, + cur_out_data, + QuantTensorFunctor(static_cast(bin_cnt), inv_s)); + } else { + trans(ctx, start, end, cur_out_data, phi::ClipFunctor(-s, s)); + } for (int k = 0; k < step_j; k++) { - cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]) * - s / static_cast(bin_cnt); + if (round_type == 0) { + cur_out_data[k] = cur_out_data[k] * s / static_cast(bin_cnt); + } else { + cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]) * + s / static_cast(bin_cnt); + } } } } @@ -235,12 +330,14 @@ template struct ChannelClipFakeQuantDequantFunctor; template struct FindRangeAbsMaxFunctor { - void operator()(const platform::CPUDeviceContext& ctx, - const framework::Tensor& cur_scale, - const framework::Tensor& last_scale, - const framework::Tensor& iter, const int window_size, - framework::Tensor* scales_arr, framework::Tensor* out_scale) { - T* scale_arr = scales_arr->mutable_data(ctx.GetPlace()); + void operator()(const platform::CPUDeviceContext &ctx, + const framework::Tensor &cur_scale, + const framework::Tensor &last_scale, + const framework::Tensor &iter, + const int window_size, + framework::Tensor *scales_arr, + framework::Tensor *out_scale) { + T *scale_arr = scales_arr->mutable_data(ctx.GetPlace()); int64_t it = iter.data()[0]; int idx = it % window_size; T removed = scale_arr[idx]; @@ -252,8 +349,8 @@ struct FindRangeAbsMaxFunctor { max = cur; } else if (fabs(removed - max) < 1e-6) { int size = (it > window_size) ? 
window_size : it; - FindAbsMaxFunctor()(ctx, scale_arr, size, - &max); + FindAbsMaxFunctor()( + ctx, scale_arr, size, &max); } out_scale->mutable_data(ctx.GetPlace())[0] = max; } @@ -263,11 +360,14 @@ template struct FindRangeAbsMaxFunctor; template struct FindMovingAverageAbsMaxFunctor { - void operator()(const platform::CPUDeviceContext& ctx, - const framework::Tensor& in_accum, - const framework::Tensor& in_state, const T* cur_scale, - const float rate, framework::Tensor* out_state, - framework::Tensor* out_accum, framework::Tensor* out_scale) { + void operator()(const platform::CPUDeviceContext &ctx, + const framework::Tensor &in_accum, + const framework::Tensor &in_state, + const T *cur_scale, + const float rate, + framework::Tensor *out_state, + framework::Tensor *out_accum, + framework::Tensor *out_scale) { T accum = in_accum.data()[0]; T state = in_state.data()[0]; T scale = cur_scale[0]; @@ -287,18 +387,22 @@ template struct FindMovingAverageAbsMaxFunctorHasInput("X"), "Input", "X", + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK( + ctx->HasInput("X"), "Input", "X", "FakeQuantOrWithDequantAbsMaxOp"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), + "Output", + "Out", "FakeQuantOrWithDequantAbsMaxOp"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", - "FakeQuantOrWithDequantAbsMaxOp"); - OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale", + OP_INOUT_CHECK(ctx->HasOutput("OutScale"), + "Output", + "OutScale", "FakeQuantOrWithDequantAbsMaxOp"); ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->SetOutputDim("OutScale", {1}); @@ -307,7 +411,7 @@ class FakeQuantOrWithDequantAbsMaxOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { + const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); @@ -325,13 +429,32 @@ class FakeQuantOrWithDequantAbsMaxOpMaker AddOutput("OutScale", "(Tensor) Current scale"); AddAttr("bit_length", "(int, default 8)") .SetDefault(8) - .AddCustomChecker([](const int& bit_length) { - PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true, + .AddCustomChecker([](const int &bit_length) { + PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, + true, platform::errors::InvalidArgument( "'bit_length' should be between 1 and 16, but " "the received is %d", bit_length)); }); + AddAttr( + "round_type", + "(int, default 1) The round type of fp32 to int." + "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2" + "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, " + "round(2.5)=3") + .SetDefault(1) + .AddCustomChecker([](const int &round_type) { + PADDLE_ENFORCE_EQ( + round_type == 0 || round_type == 1, + true, + platform::errors::InvalidArgument( + "'round_type' should be 0 or 1, 0 rounding to " + "nearest ties to even and 1 is rounding to nearest " + "ties away from zero.but the received is %d", + round_type)); + }) + .AsExtra(); AddComment(R"DOC( This is a Base Op which supports FakeQuantAbsMaxOpMaker and FakeQuantDequantAbsMaxOpMaker. FakeQuantAbsMaxOp operator is used in the dynamic quantization. 
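// A minimal standalone sketch (not part of the patch) illustrating the two rounding
// modes selected by the new "round_type" attribute documented above:
//   0 = round to nearest, ties to even;
//   1 = round to nearest, ties away from zero (the default, per SetDefault(1)).
// Helper names below are illustrative only.
#include <cmath>
#include <cstdio>

// Ties to even: relies on the default FE_TONEAREST floating-point rounding mode.
double round_ties_to_even(double x) { return std::nearbyint(x); }

// Ties away from zero: the behavior of std::round.
double round_ties_away(double x) { return std::round(x); }

int main() {
  // round(1.5): both modes give 2.
  std::printf("%.0f %.0f\n", round_ties_to_even(1.5), round_ties_away(1.5));  // 2 2
  // round(2.5): ties-to-even gives 2, ties-away-from-zero gives 3.
  std::printf("%.0f %.0f\n", round_ties_to_even(2.5), round_ties_away(2.5));  // 2 3
  return 0;
}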
@@ -354,12 +477,16 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", - "FakeChannelWiseQuantizeAbsMax"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK( + ctx->HasInput("X"), "Input", "X", "FakeChannelWiseQuantizeAbsMax"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), + "Output", + "Out", "FakeChannelWiseQuantizeAbsMax"); - OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale", + OP_INOUT_CHECK(ctx->HasOutput("OutScale"), + "Output", + "OutScale", "FakeChannelWiseQuantizeAbsMax"); int quant_axis = ctx->Attrs().Get("quant_axis"); ctx->SetOutputDim("Out", ctx->GetInputDim("X")); @@ -369,7 +496,7 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { + const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } @@ -389,8 +516,9 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker "For conv2d, depthwise_conv2d, conv2d_transpose " "and mul, the quant_axis is equal to the cout axis.") .SetDefault(0) - .AddCustomChecker([](const int& quant_axis) { - PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true, + .AddCustomChecker([](const int &quant_axis) { + PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, + true, platform::errors::InvalidArgument( "'quant_axis' should be 0 or 1, but " "the received is %d", @@ -398,13 +526,32 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker }); AddAttr("bit_length", "(int, default 8)") .SetDefault(8) - .AddCustomChecker([](const int& bit_length) { - PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true, + .AddCustomChecker([](const int &bit_length) { + PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, + true, platform::errors::InvalidArgument( "'bit_length' should be between 1 and 16, but " "the received is %d", bit_length)); }); + AddAttr( + "round_type", + "(int, default 1) The round type of fp32 to int." + "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2" + "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, " + "round(2.5)=3") + .SetDefault(1) + .AddCustomChecker([](const int &round_type) { + PADDLE_ENFORCE_EQ( + round_type == 0 || round_type == 1, + true, + platform::errors::InvalidArgument( + "'round_type' should be 0 or 1, 0 rounding to " + "nearest ties to even and 1 is rounding to nearest " + "ties away from zero.but the received is %d", + round_type)); + }) + .AsExtra(); AddAttr("is_test", "(bool, default false) Set to true for inference only, false " "for training. 
Some layers may run faster when this is true.") @@ -427,12 +574,18 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOp public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), + "Input", + "X", "FakeChannelWiseQuantizeDequantizeAbsMax"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", + OP_INOUT_CHECK(ctx->HasOutput("Out"), + "Output", + "Out", "FakeChannelWiseQuantizeDequantizeAbsMax"); - OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale", + OP_INOUT_CHECK(ctx->HasOutput("OutScale"), + "Output", + "OutScale", "FakeChannelWiseQuantizeDequantizeAbsMax"); int quant_axis = ctx->Attrs().Get("quant_axis"); ctx->SetOutputDim("Out", ctx->GetInputDim("X")); @@ -442,7 +595,7 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOp protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { + const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } @@ -462,8 +615,9 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker "For conv2d, depthwise_conv2d, conv2d_transpose " "and mul, the quant_axis is equal to the cout axis.") .SetDefault(0) - .AddCustomChecker([](const int& quant_axis) { - PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true, + .AddCustomChecker([](const int &quant_axis) { + PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, + true, platform::errors::InvalidArgument( "'quant_axis' should be 0 or 1, but " "the received is %d", @@ -471,13 +625,32 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker }); AddAttr("bit_length", "(int, default 8)") .SetDefault(8) - .AddCustomChecker([](const int& bit_length) { - PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true, + .AddCustomChecker([](const int &bit_length) { + PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, + true, platform::errors::InvalidArgument( "'bit_length' should be between 1 and 16, but " "the received is %d", bit_length)); }); + AddAttr( + "round_type", + "(int, default 1) The round type of fp32 to int." + "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2" + "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, " + "round(2.5)=3") + .SetDefault(1) + .AddCustomChecker([](const int &round_type) { + PADDLE_ENFORCE_EQ( + round_type == 0 || round_type == 1, + true, + platform::errors::InvalidArgument( + "'round_type' should be 0 or 1, 0 rounding to " + "nearest ties to even and 1 is rounding to nearest " + "ties away from zero.but the received is %d", + round_type)); + }) + .AsExtra(); AddComment(R"DOC( The scale of FakeChannelWiseQuantize operator is a vector. In detail, each channel of the input X has a scale value. 
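// A simplified, self-contained sketch of per-channel fake quantize-dequantize for
// quant_axis == 0, mirroring the idea behind ChannelClipFakeQuantDequantFunctor
// above: each output-channel slice gets its own abs-max scale. It assumes
// round_type == 1 (std::round, ties away from zero); names are illustrative.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

void fake_quant_dequant_channel0(const std::vector<float> &in, int channel,
                                 int bin_cnt, std::vector<float> *out) {
  const int64_t channel_size = static_cast<int64_t>(in.size()) / channel;
  out->assign(in.size(), 0.f);
  for (int c = 0; c < channel; ++c) {
    const float *start = in.data() + c * channel_size;
    // Per-channel scale: absolute maximum over this channel's slice.
    float s = 0.f;
    for (int64_t i = 0; i < channel_size; ++i) s = std::max(s, std::fabs(start[i]));
    if (s == 0.f) continue;  // all-zero channel stays zero
    for (int64_t i = 0; i < channel_size; ++i) {
      float v = std::min(std::max(start[i], -s), s);   // clip to [-s, s]
      float q = std::round(bin_cnt * v / s);           // quantize to the integer grid
      (*out)[c * channel_size + i] = q * s / bin_cnt;  // dequantize back to float
    }
  }
}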
@@ -493,17 +666,19 @@ $$0 \leq c \lt \ the\ channel\ number\ of\ X$$ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel { public: - FakeQuantizeRangeAbsMaxOp(const std::string& type, - const framework::VariableNameMap& inputs, - const framework::VariableNameMap& outputs, - const framework::AttributeMap& attrs) + FakeQuantizeRangeAbsMaxOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) : OperatorWithKernel(type, inputs, outputs, attrs) {} - void InferShape(framework::InferShapeContext* ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FakeQuantizeRangeAbsMax"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", - "FakeQuantizeRangeAbsMax"); - OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale", + OP_INOUT_CHECK( + ctx->HasOutput("Out"), "Output", "Out", "FakeQuantizeRangeAbsMax"); + OP_INOUT_CHECK(ctx->HasOutput("OutScale"), + "Output", + "OutScale", "FakeQuantizeRangeAbsMax"); if (ctx->HasOutput("OutScales")) { int window_size = ctx->Attrs().Get("window_size"); @@ -516,7 +691,7 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { + const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); @@ -537,13 +712,32 @@ class FakeQuantizeRangeAbsMaxOpMaker .SetDefault(10000); AddAttr("bit_length", "(int, default 8), quantization bit number.") .SetDefault(8) - .AddCustomChecker([](const int& bit_length) { - PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true, + .AddCustomChecker([](const int &bit_length) { + PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, + true, platform::errors::InvalidArgument( "'bit_length' should be between 1 and 16, but " "the received is %d", bit_length)); }); + AddAttr( + "round_type", + "(int, default 1) The round type of fp32 to int." + "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2" + "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, " + "round(2.5)=3") + .SetDefault(1) + .AddCustomChecker([](const int &round_type) { + PADDLE_ENFORCE_EQ( + round_type == 0 || round_type == 1, + true, + platform::errors::InvalidArgument( + "'round_type' should be 0 or 1, 0 rounding to " + "nearest ties to even and 1 is rounding to nearest " + "ties away from zero.but the received is %d", + round_type)); + }) + .AsExtra(); AddAttr("is_test", "(bool, default false) Set to true for inference only, false " "for training. 
Some layers may run faster when this is true.") @@ -563,17 +757,24 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp : public framework::OperatorWithKernel { public: FakeQuantOrWithDequantMovingAverageAbsMaxOp( - const std::string& type, const framework::VariableNameMap& inputs, - const framework::VariableNameMap& outputs, - const framework::AttributeMap& attrs) + const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) : OperatorWithKernel(type, inputs, outputs, attrs) {} - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), + "Input", + "X", "FakeQuantOrWithDequantMovingAverageAbsMax"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", + OP_INOUT_CHECK(ctx->HasOutput("Out"), + "Output", + "Out", "FakeQuantOrWithDequantMovingAverageAbsMax"); - OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale", + OP_INOUT_CHECK(ctx->HasOutput("OutScale"), + "Output", + "OutScale", "FakeQuantOrWithDequantMovingAverageAbsMax"); if (ctx->HasOutput("OutState")) { ctx->SetOutputDim("OutState", {1}); @@ -588,7 +789,7 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { + const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); @@ -611,13 +812,32 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker .SetDefault(0.9); AddAttr("bit_length", "(int, default 8), quantization bit number.") .SetDefault(8) - .AddCustomChecker([](const int& bit_length) { - PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true, + .AddCustomChecker([](const int &bit_length) { + PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, + true, platform::errors::InvalidArgument( "'bit_length' should be between 1 and 16, but " "the received is %d", bit_length)); }); + AddAttr( + "round_type", + "(int, default 1) The round type of fp32 to int." + "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2" + "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, " + "round(2.5)=3") + .SetDefault(1) + .AddCustomChecker([](const int &round_type) { + PADDLE_ENFORCE_EQ( + round_type == 0 || round_type == 1, + true, + platform::errors::InvalidArgument( + "'round_type' should be 0 or 1, 0 rounding to " + "nearest ties to even and 1 is rounding to nearest " + "ties away from zero.but the received is %d", + round_type)); + }) + .AsExtra(); AddAttr("is_test", "(bool, default false) Set to true for inference only, false " "for training. 
Some layers may run faster when this is true.") @@ -644,10 +864,12 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", - "MovingAverageAbsMaxScale"); - OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale", + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK( + ctx->HasInput("X"), "Input", "X", "MovingAverageAbsMaxScale"); + OP_INOUT_CHECK(ctx->HasOutput("OutScale"), + "Output", + "OutScale", "MovingAverageAbsMaxScale"); if (ctx->HasOutput("OutState")) { @@ -665,7 +887,7 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { + const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } @@ -705,19 +927,23 @@ class StrightThroughEstimatorGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { auto out_grad_name = framework::GradVarName("Out"); auto x_grad_name = framework::GradVarName("X"); - OP_INOUT_CHECK(ctx->HasInput(out_grad_name), "Input", out_grad_name, + OP_INOUT_CHECK(ctx->HasInput(out_grad_name), + "Input", + out_grad_name, "StrightThroughEstimatorGradOp"); - OP_INOUT_CHECK(ctx->HasOutput(x_grad_name), "Output", x_grad_name, + OP_INOUT_CHECK(ctx->HasOutput(x_grad_name), + "Output", + x_grad_name, "StrightThroughEstimatorGradOp"); ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(out_grad_name)); } framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { + const framework::ExecutionContext &ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); return framework::OpKernelType(input_data_type, ctx.GetPlace()); @@ -745,7 +971,8 @@ namespace ops = paddle::operators; using CPU = paddle::platform::CPUDeviceContext; REGISTER_OPERATOR( - fake_quantize_abs_max, ops::FakeQuantOrWithDequantAbsMaxOp, + fake_quantize_abs_max, + ops::FakeQuantOrWithDequantAbsMaxOp, ops::FakeQuantOrWithDequantAbsMaxOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); @@ -753,7 +980,8 @@ REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max, ops::FakeQuantizeAbsMaxKernel); REGISTER_OPERATOR( - fake_quantize_dequantize_abs_max, ops::FakeQuantOrWithDequantAbsMaxOp, + fake_quantize_dequantize_abs_max, + ops::FakeQuantOrWithDequantAbsMaxOp, ops::FakeQuantOrWithDequantAbsMaxOpMaker, ops::StrightThroughEstimatorMaker, ops::StrightThroughEstimatorMaker); @@ -761,7 +989,8 @@ REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_abs_max, ops::FakeQuantizeDequantizeAbsMaxKernel); REGISTER_OPERATOR( - fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxOp, + fake_quantize_range_abs_max, + ops::FakeQuantizeRangeAbsMaxOp, ops::FakeQuantizeRangeAbsMaxOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); @@ -788,7 +1017,8 @@ REGISTER_OP_CPU_KERNEL( ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel); REGISTER_OPERATOR( - fake_channel_wise_quantize_abs_max, 
ops::FakeChannelWiseQuantizeAbsMaxOp, + fake_channel_wise_quantize_abs_max, + ops::FakeChannelWiseQuantizeAbsMaxOp, ops::FakeChannelWiseQuantizeAbsMaxOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); @@ -796,7 +1026,8 @@ REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max, ops::FakeChannelWiseQuantizeAbsMaxKernel); REGISTER_OPERATOR( - moving_average_abs_max_scale, ops::MovingAverageAbsMaxScaleOp, + moving_average_abs_max_scale, + ops::MovingAverageAbsMaxScaleOp, ops::MovingAverageAbsMaxScaleOpMaker, ops::StrightThroughEstimatorMaker, ops::StrightThroughEstimatorMaker); @@ -832,7 +1063,7 @@ REGISTER_OP_VERSION(moving_average_abs_max_scale) "Delete output in order to make the inference model not " "save moving_average_abs_max_scale operator. This will " "make the quantitative model be correctly applied in inference.")) - .AddCheckpoint( - R"ROC(Incompatible upgrade of output [Out])ROC", - paddle::framework::compatible::OpVersionDesc().NewOutput( - "Out", "In order to support dygraph qat, add output again.")); + .AddCheckpoint(R"ROC(Incompatible upgrade of output [Out])ROC", + paddle::framework::compatible::OpVersionDesc().NewOutput( + "Out", + "In order to support dygraph qat, add output again.")); diff --git a/paddle/fluid/operators/fake_quantize_op.cu.h b/paddle/fluid/operators/fake_quantize_op.cu.h index e809c553684..3b1877f2bc8 100644 --- a/paddle/fluid/operators/fake_quantize_op.cu.h +++ b/paddle/fluid/operators/fake_quantize_op.cu.h @@ -36,12 +36,12 @@ struct QuantizeDataType { }; template -__global__ void FindAbsMaxKernel(const T* in, const int n, T* out) { +__global__ void FindAbsMaxKernel(const T *in, const int n, T *out) { int bid = threadIdx.x + blockIdx.x * blockDim.x; int tid = threadIdx.x; - extern __shared__ char* shared_max_data_tmp[]; - auto shared_max_data = reinterpret_cast(shared_max_data_tmp); + extern __shared__ char *shared_max_data_tmp[]; + auto shared_max_data = reinterpret_cast(shared_max_data_tmp); if (gridDim.x > 1) { T local_max_data = T(0); for (int i = bid; i < n; i += blockDim.x * gridDim.x) { @@ -73,18 +73,20 @@ __global__ void FindAbsMaxKernel(const T* in, const int n, T* out) { template struct FindAbsMaxFunctor { - void operator()(const platform::CUDADeviceContext& ctx, const T* in, - const int num, T* out) { + void operator()(const platform::CUDADeviceContext &ctx, + const T *in, + const int num, + T *out) { int block = 1024; int grid = (block - 1 + num) / block; grid = (grid > block) ? 
block : grid; framework::Tensor max; - T* max_data = max.mutable_data(phi::make_ddim({grid}), ctx.GetPlace()); - FindAbsMaxKernel<<>>( - in, num, max_data); - FindAbsMaxKernel<<<1, block, 1024 * sizeof(T), ctx.stream()>>>( - max_data, grid, out); + T *max_data = max.mutable_data(phi::make_ddim({grid}), ctx.GetPlace()); + FindAbsMaxKernel + <<>>(in, num, max_data); + FindAbsMaxKernel + <<<1, block, 1024 * sizeof(T), ctx.stream()>>>(max_data, grid, out); } }; @@ -93,13 +95,15 @@ template struct FindAbsMaxFunctor; template -__global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n, - const int c, T* out) { +__global__ void FindChannelAbsMaxKernelQuantAxis0(const T *in, + const int n, + const int c, + T *out) { int tid = threadIdx.x; int channel_size = n / c; - const T* in_c = in + blockIdx.x * channel_size; - extern __shared__ char* shared_max_data_tmp[]; - auto shared_max_data = reinterpret_cast(shared_max_data_tmp); + const T *in_c = in + blockIdx.x * channel_size; + extern __shared__ char *shared_max_data_tmp[]; + auto shared_max_data = reinterpret_cast(shared_max_data_tmp); T local_max_data = T(0); for (int i = tid; i < channel_size; i += blockDim.x) { T tmp = static_cast( @@ -122,17 +126,16 @@ __global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n, } template -__global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n, - const int cin, const int cout, - T* out) { - extern __shared__ char* shared_max_data_tmp[]; - auto shared_max_data = reinterpret_cast(shared_max_data_tmp); +__global__ void FindChannelAbsMaxKernelQuantAxis1( + const T *in, const int n, const int cin, const int cout, T *out) { + extern __shared__ char *shared_max_data_tmp[]; + auto shared_max_data = reinterpret_cast(shared_max_data_tmp); int cout_wh_size = n / cin; int wh_size = n / (cin * cout); int tid = threadIdx.x; int bid = blockIdx.x; - const T* in_current = in + tid * cout_wh_size + bid * wh_size; + const T *in_current = in + tid * cout_wh_size + bid * wh_size; T local_max_data = T(0); for (int i = 0; i < wh_size; i++) { T tmp = static_cast( @@ -162,24 +165,26 @@ __global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n, template struct FindChannelAbsMaxFunctor { - void operator()(const platform::CUDADeviceContext& ctx, - const framework::Tensor& in_tensor, const int quant_axis, - T* out_abs_max) { + void operator()(const platform::CUDADeviceContext &ctx, + const framework::Tensor &in_tensor, + const int quant_axis, + T *out_abs_max) { PADDLE_ENFORCE_EQ( - quant_axis == 0 || quant_axis == 1, true, + quant_axis == 0 || quant_axis == 1, + true, platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " "the received is %d", quant_axis)); const int num = in_tensor.numel(); auto in_dims = in_tensor.dims(); - const T* in_data = in_tensor.data(); + const T *in_data = in_tensor.data(); if (quant_axis == 0) { int cout = in_dims[0]; int grid = cout; int block = 1024; - FindChannelAbsMaxKernelQuantAxis0< - T><<>>( - in_data, num, cout, out_abs_max); + FindChannelAbsMaxKernelQuantAxis0 + <<>>( + in_data, num, cout, out_abs_max); } else if (quant_axis == 1) { int cin = in_dims[0]; int cout = in_dims[1]; @@ -194,17 +199,17 @@ struct FindChannelAbsMaxFunctor { for (int i = 0; i < cin / max_threads; i++) { int block = max_threads; - FindChannelAbsMaxKernelQuantAxis1< - T><<>>( - in_data, num, cin, cout, out_abs_max); + FindChannelAbsMaxKernelQuantAxis1 + <<>>( + in_data, num, cin, cout, out_abs_max); in_data += num / cin; } int block = cin % 
max_threads; if (block > 0) { - FindChannelAbsMaxKernelQuantAxis1< - T><<>>( - in_data, num, in_dims[0], in_dims[1], out_abs_max); + FindChannelAbsMaxKernelQuantAxis1 + <<>>( + in_data, num, in_dims[0], in_dims[1], out_abs_max); } } } @@ -213,8 +218,12 @@ struct FindChannelAbsMaxFunctor { template struct FindChannelAbsMaxFunctor; template -__global__ void ClipAndQuantKernel(const T* in, const T* scale, - const int bin_cnt, const int n, T* out) { +__global__ void ClipAndQuantKernel(const T *in, + const T *scale, + const int bin_cnt, + const int round_type, + const int n, + T *out) { int bid = threadIdx.x + blockIdx.x * blockDim.x; int tid = threadIdx.x; @@ -226,17 +235,30 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale, for (int i = bid; i < n; i += blockDim.x * gridDim.x) { ComputeDataType x = static_cast(in[i]); - ComputeDataType v = x > s ? s : x; - v = v < -s ? -s : v; - v = bin_cnt_t * inv_s * v; - out[i] = static_cast(round(v)); + if (round_type == 0) { + x = bin_cnt_t * inv_s * x; + x = roundWithTiesToEven(x); + ComputeDataType max_bound = bin_cnt_t; + ComputeDataType min_bound = -bin_cnt_t - static_cast(1); + x = x > max_bound ? max_bound : x; + x = x < min_bound ? min_bound : x; + out[i] = static_cast(x); + } else { + ComputeDataType v = x > s ? s : x; + v = v < -s ? -s : v; + v = bin_cnt_t * inv_s * v; + out[i] = static_cast(round(v)); + } } } template -__global__ void ClipAndQuantDequantKernel(const T* in, const T* scale, - const int bin_cnt, const int n, - T* out) { +__global__ void ClipAndQuantDequantKernel(const T *in, + const T *scale, + const int bin_cnt, + const int round_type, + const int n, + T *out) { int bid = threadIdx.x + blockIdx.x * blockDim.x; int tid = threadIdx.x; @@ -248,29 +270,42 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale, for (int i = bid; i < n; i += blockDim.x * gridDim.x) { ComputeDataType x = static_cast(in[i]); - x = x > s ? s : x; - x = x < -s ? -s : x; - x = bin_cnt_t * inv_s * x; - x = round(x); - out[i] = static_cast((x * s) / bin_cnt_t); + if (round_type == 0) { + x = bin_cnt_t * inv_s * x; + x = roundWithTiesToEven(x); + ComputeDataType max_bound = bin_cnt_t; + ComputeDataType min_bound = -bin_cnt_t - static_cast(1); + x = x > max_bound ? max_bound : x; + x = x < min_bound ? min_bound : x; + out[i] = static_cast((x * s) / bin_cnt_t); + } else { + x = x > s ? s : x; + x = x < -s ? 
-s : x; + x = bin_cnt_t * inv_s * x; + x = round(x); + out[i] = static_cast((x * s) / bin_cnt_t); + } } } template struct ClipAndFakeQuantFunctor { - void operator()(const platform::CUDADeviceContext& ctx, - const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, framework::Tensor* out) { + void operator()(const platform::CUDADeviceContext &ctx, + const framework::Tensor &in, + const framework::Tensor &scale, + const int bin_cnt, + const int round_type, + framework::Tensor *out) { int num = in.numel(); int block = 1024; int grid = (block - 1 + num) / block; - const T* in_data = in.data(); - const T* scale_data = scale.data(); - T* out_data = out->mutable_data(ctx.GetPlace()); + const T *in_data = in.data(); + const T *scale_data = scale.data(); + T *out_data = out->mutable_data(ctx.GetPlace()); ClipAndQuantKernel<<>>( - in_data, scale_data, bin_cnt, num, out_data); + in_data, scale_data, bin_cnt, round_type, num, out_data); } }; @@ -278,33 +313,39 @@ template struct ClipAndFakeQuantFunctor; template struct ClipAndFakeQuantDequantFunctor { - void operator()(const platform::CUDADeviceContext& ctx, - const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, framework::Tensor* out) { + void operator()(const platform::CUDADeviceContext &ctx, + const framework::Tensor &in, + const framework::Tensor &scale, + const int bin_cnt, + const int round_type, + framework::Tensor *out) { int num = in.numel(); int block = 1024; int grid = (block - 1 + num) / block; - const T* in_data = in.data(); - const T* scale_data = scale.data(); - T* out_data = out->mutable_data(ctx.GetPlace()); + const T *in_data = in.data(); + const T *scale_data = scale.data(); + T *out_data = out->mutable_data(ctx.GetPlace()); ClipAndQuantDequantKernel<<>>( - in_data, scale_data, bin_cnt, num, out_data); + in_data, scale_data, bin_cnt, round_type, num, out_data); } }; // ChannelClipAndQuantKernel for quant_axis is 0 template -__global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale, +__global__ void ChannelClipAndQuantKernelQuantAxis0(const T *in, + const T *scale, const int bin_cnt, + const int round_type, const int64_t n, - const int c, T* out) { + const int c, + T *out) { int tid = threadIdx.x; int64_t channel_size = n / c; - const T* in_c = in + blockIdx.x * channel_size; - T* out_c = out + blockIdx.x * channel_size; + const T *in_c = in + blockIdx.x * channel_size; + T *out_c = out + blockIdx.x * channel_size; using ComputeDataType = typename QuantizeDataType::type; @@ -314,18 +355,33 @@ __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale, for (int64_t i = tid; i < channel_size; i += blockDim.x) { ComputeDataType x = static_cast(in_c[i]); - ComputeDataType v = x > s ? s : x; - v = v < -s ? -s : v; - v = bin_cnt_t * inv_s * v; - out_c[i] = static_cast(round(v)); + if (round_type == 0) { + x = bin_cnt_t * inv_s * x; + x = roundWithTiesToEven(x); + ComputeDataType max_bound = bin_cnt_t; + ComputeDataType min_bound = -bin_cnt_t - static_cast(1); + x = x > max_bound ? max_bound : x; + x = x < min_bound ? min_bound : x; + out_c[i] = static_cast(x); + } else { + ComputeDataType v = x > s ? s : x; + v = v < -s ? 
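The rewritten clip-and-quant kernels branch on the new `round_type` attribute: `round_type == 0` scales first, rounds ties to even, and then clips the integer result to the full signed range `[-bin_cnt - 1, bin_cnt]`, while `round_type == 1` keeps the old behaviour of clipping to `[-scale, scale]` before rounding half away from zero. A minimal NumPy sketch of the two paths (the helper name `fake_quant` and the use of `np.rint` are illustrative, not part of the patch):

```python
import numpy as np

def fake_quant(x, scale, bit_length=8, round_type=0):
    # Illustrative sketch of the two rounding paths; not the CUDA kernel itself.
    bin_cnt = 2 ** (bit_length - 1) - 1          # 127 for 8-bit quantization
    if round_type == 0:
        # scale first, round half to even (np.rint), then clip to the full
        # signed range [-bin_cnt - 1, bin_cnt], i.e. [-128, 127] for 8 bits
        v = np.rint(np.asarray(x, dtype=np.float64) / scale * bin_cnt)
        return np.clip(v, -bin_cnt - 1, bin_cnt)
    # round_type == 1: legacy behaviour, clip to [-scale, scale] first,
    # then round half away from zero
    v = np.clip(x, -scale, scale) / scale * bin_cnt
    return np.sign(v) * np.floor(np.abs(v) + 0.5)
```

For example, with `scale = 1.0` and 8 bits, an input of `-1.01` maps to `-128` under `round_type == 0` but saturates at `-127` under the legacy path.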
-s : v; + v = bin_cnt_t * inv_s * v; + out_c[i] = static_cast(round(v)); + } } } // ChannelClipAndQuantKernel for quant_axis is N template -__global__ void ChannelClipAndQuantKernelQuantAxisN( - const T* in, const T* scale, const int bin_cnt, const int64_t n, - const int nScale, const int quant_stride, T* out) { +__global__ void ChannelClipAndQuantKernelQuantAxisN(const T *in, + const T *scale, + const int bin_cnt, + const int round_type, + const int64_t n, + const int nScale, + const int quant_stride, + T *out) { int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; using ComputeDataType = typename QuantizeDataType::type; ComputeDataType bin_cnt_t = static_cast(bin_cnt); @@ -334,36 +390,50 @@ __global__ void ChannelClipAndQuantKernelQuantAxisN( static_cast(scale[(i / quant_stride) % nScale]); ComputeDataType inv_s = inverse(s); ComputeDataType x = static_cast(in[i]); - ComputeDataType v = x > s ? s : x; - v = v < -s ? -s : v; - v = bin_cnt_t * inv_s * v; - out[i] = static_cast(round(v)); + if (round_type == 0) { + x = bin_cnt_t * inv_s * x; + x = roundWithTiesToEven(x); + ComputeDataType max_bound = bin_cnt_t; + ComputeDataType min_bound = -bin_cnt_t - static_cast(1); + x = x > max_bound ? max_bound : x; + x = x < min_bound ? min_bound : x; + out[i] = static_cast(x); + } else { + ComputeDataType v = x > s ? s : x; + v = v < -s ? -s : v; + v = bin_cnt_t * inv_s * v; + out[i] = static_cast(round(v)); + } } } template struct ChannelClipAndFakeQuantFunctor { - void operator()(const platform::CUDADeviceContext& ctx, - const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, const int quant_axis, - framework::Tensor* out) { + void operator()(const platform::CUDADeviceContext &ctx, + const framework::Tensor &in, + const framework::Tensor &scale, + const int bin_cnt, + const int round_type, + const int quant_axis, + framework::Tensor *out) { PADDLE_ENFORCE_EQ( - quant_axis == 0 || quant_axis == 1, true, + quant_axis == 0 || quant_axis == 1, + true, platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " "the received is %d", quant_axis)); int64_t num = in.numel(); auto in_dims = in.dims(); - const T* in_data = in.data(); - const T* scale_data = scale.data(); - T* out_data = out->mutable_data(ctx.GetPlace()); + const T *in_data = in.data(); + const T *scale_data = scale.data(); + T *out_data = out->mutable_data(ctx.GetPlace()); if (quant_axis == 0) { int grid = in_dims[0]; int block = 1024; ChannelClipAndQuantKernelQuantAxis0<<>>( - in_data, scale_data, bin_cnt, num, in_dims[0], out_data); + in_data, scale_data, bin_cnt, round_type, num, in_dims[0], out_data); } else { int quant_stride = 1; for (int i = quant_axis + 1; i < in_dims.size(); i++) { @@ -379,9 +449,15 @@ struct ChannelClipAndFakeQuantFunctor { const int64_t grid_size = std::min(max_blocks, (num + block_size - 1) / block_size); - ChannelClipAndQuantKernelQuantAxisN<<>>( - in_data, scale_data, bin_cnt, num, in_dims[quant_axis], quant_stride, - out_data); + ChannelClipAndQuantKernelQuantAxisN + <<>>(in_data, + scale_data, + bin_cnt, + round_type, + num, + in_dims[quant_axis], + quant_stride, + out_data); } } }; @@ -390,12 +466,14 @@ template struct ChannelClipAndFakeQuantFunctor; template -__global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale, - const T* last_scale, - const int64_t* iter, - const int window_size, T* scale_arr, - T* out_scale, int* need_find_max, - int* out_size) { +__global__ void FindRangeAbsMaxAndFillArray(const T *cur_scale, + const T *last_scale, + const int64_t *iter, 
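For `quant_axis != 0`, the kernel recovers each element's channel from its flat index with `(i / quant_stride) % nScale`, where `quant_stride` is the product of the dimensions after `quant_axis`. A hypothetical NumPy sketch of that mapping, assuming a C-contiguous tensor and the `round_type == 0` semantics:

```python
import numpy as np

def channel_fake_quant(x, scales, bin_cnt, quant_axis):
    # Sketch of the index mapping used by the quant_axis-N kernel, assuming a
    # C-contiguous tensor; scales has one entry per slice along quant_axis.
    quant_stride = int(np.prod(x.shape[quant_axis + 1:]))   # 1 if quant_axis is last
    n_scale = x.shape[quant_axis]
    flat = x.reshape(-1)
    out = np.empty_like(flat)
    for i, v in enumerate(flat):
        s = scales[(i // quant_stride) % n_scale]            # per-channel scale lookup
        out[i] = np.clip(np.rint(v / s * bin_cnt), -bin_cnt - 1, bin_cnt)
    return out.reshape(x.shape)
```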
+ const int window_size, + T *scale_arr, + T *out_scale, + int *need_find_max, + int *out_size) { int it = iter[0]; int idx = it % window_size; T removed = scale_arr[idx]; @@ -414,45 +492,63 @@ __global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale, template struct FindRangeAbsMaxFunctor { - void operator()(const platform::CUDADeviceContext& ctx, - const framework::Tensor& cur_scale, - const framework::Tensor& last_scale, - const framework::Tensor& iter, const int window_size, - framework::Tensor* scales_arr, framework::Tensor* out_scale) { + void operator()(const platform::CUDADeviceContext &ctx, + const framework::Tensor &cur_scale, + const framework::Tensor &last_scale, + const framework::Tensor &iter, + const int window_size, + framework::Tensor *scales_arr, + framework::Tensor *out_scale) { const auto gpu_place = ctx.GetPlace(); - T* scale_arr = scales_arr->mutable_data(gpu_place); - T* out_scale_data = out_scale->mutable_data(gpu_place); + T *scale_arr = scales_arr->mutable_data(gpu_place); + T *out_scale_data = out_scale->mutable_data(gpu_place); framework::Tensor need_find_max, out_size; - int* find_max = need_find_max.mutable_data({1}, gpu_place); - int* out_size_data = out_size.mutable_data({1}, gpu_place); - - FindRangeAbsMaxAndFillArray<<<1, 1, 0, ctx.stream()>>>( - cur_scale.data(), last_scale.data(), iter.data(), - window_size, scale_arr, out_scale_data, find_max, out_size_data); + int *find_max = need_find_max.mutable_data({1}, gpu_place); + int *out_size_data = out_size.mutable_data({1}, gpu_place); + + FindRangeAbsMaxAndFillArray + <<<1, 1, 0, ctx.stream()>>>(cur_scale.data(), + last_scale.data(), + iter.data(), + window_size, + scale_arr, + out_scale_data, + find_max, + out_size_data); int g_find_max; - memory::Copy(platform::CPUPlace(), &g_find_max, gpu_place, find_max, - sizeof(int), ctx.stream()); + memory::Copy(platform::CPUPlace(), + &g_find_max, + gpu_place, + find_max, + sizeof(int), + ctx.stream()); ctx.Wait(); if (g_find_max) { int len; - memory::Copy(platform::CPUPlace(), &len, gpu_place, out_size_data, - sizeof(int), ctx.stream()); + memory::Copy(platform::CPUPlace(), + &len, + gpu_place, + out_size_data, + sizeof(int), + ctx.stream()); ctx.Wait(); - FindAbsMaxFunctor()(ctx, scale_arr, len, - out_scale_data); + FindAbsMaxFunctor()( + ctx, scale_arr, len, out_scale_data); } } }; template -__global__ void FindMovingAverageAbsMaxKernel(const T* in_state, - const T* in_accum, - const T* cur_scale, const T rate, - T* out_state, T* out_accum, - T* out_scale) { +__global__ void FindMovingAverageAbsMaxKernel(const T *in_state, + const T *in_accum, + const T *cur_scale, + const T rate, + T *out_state, + T *out_accum, + T *out_scale) { T state = rate * (*in_state) + T(1.0f); T accum = rate * (*in_accum) + (*cur_scale); *out_state = state; @@ -464,78 +560,119 @@ template struct FindRangeAbsMaxFunctor; template struct FindMovingAverageAbsMaxFunctor { - void operator()(const platform::CUDADeviceContext& ctx, - const framework::Tensor& in_accum, - const framework::Tensor& in_state, const T* cur_scale, - const float rate, framework::Tensor* out_state, - framework::Tensor* out_accum, framework::Tensor* out_scale) { + void operator()(const platform::CUDADeviceContext &ctx, + const framework::Tensor &in_accum, + const framework::Tensor &in_state, + const T *cur_scale, + const float rate, + framework::Tensor *out_state, + framework::Tensor *out_accum, + framework::Tensor *out_scale) { const auto gpu_place = ctx.GetPlace(); T rate_t = static_cast(rate); - T* out_state_data 
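`FindRangeAbsMaxFunctor` keeps a sliding window of the last `window_size` per-batch abs-max values and only rescans the window when the outgoing entry was the current maximum. A rough Python sketch of that bookkeeping (the rescan condition and the default window size are assumptions for illustration; the device code copies `need_find_max`/`out_size` back to the host to make the same decision):

```python
import numpy as np

class RangeAbsMax:
    # Sketch of the sliding-window scale tracking: expose the max of the
    # last `window_size` per-batch abs-max values.
    def __init__(self, window_size=10000):
        self.window_size = window_size
        self.scales = np.zeros(window_size, dtype=np.float32)
        self.it = 0
        self.out_scale = 0.0

    def update(self, cur_scale):
        idx = self.it % self.window_size
        removed = self.scales[idx]
        self.scales[idx] = cur_scale
        self.it += 1
        if cur_scale >= self.out_scale:
            self.out_scale = float(cur_scale)    # new maximum
        elif removed == self.out_scale:
            # the previous maximum left the window; rescan the valid entries
            valid = self.scales[:min(self.it, self.window_size)]
            self.out_scale = float(valid.max())
        return self.out_scale
```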
= out_state->mutable_data(gpu_place); - T* out_accum_data = out_accum->mutable_data(gpu_place); - T* out_scale_data = out_scale->mutable_data(gpu_place); - - FindMovingAverageAbsMaxKernel<<<1, 1, 0, ctx.stream()>>>( - in_state.data(), in_accum.data(), cur_scale, rate_t, - out_state_data, out_accum_data, out_scale_data); + T *out_state_data = out_state->mutable_data(gpu_place); + T *out_accum_data = out_accum->mutable_data(gpu_place); + T *out_scale_data = out_scale->mutable_data(gpu_place); + + FindMovingAverageAbsMaxKernel + <<<1, 1, 0, ctx.stream()>>>(in_state.data(), + in_accum.data(), + cur_scale, + rate_t, + out_state_data, + out_accum_data, + out_scale_data); } }; // ChannelClipAndQuantDequantKernel for quant_axis is 0 template -__global__ void ChannelClipAndQuantDequantKernelQuantAxis0( - const T* in, const T* scale, const int bin_cnt, const int n, const int c, - T* out) { +__global__ void ChannelClipAndQuantDequantKernelQuantAxis0(const T *in, + const T *scale, + const int bin_cnt, + const int round_type, + const int n, + const int c, + T *out) { int tid = threadIdx.x; int channel_size = n / c; - const T* in_c = in + blockIdx.x * channel_size; - T* out_c = out + blockIdx.x * channel_size; + const T *in_c = in + blockIdx.x * channel_size; + T *out_c = out + blockIdx.x * channel_size; T s = scale[blockIdx.x]; T inv_s = inverse(s); for (int i = tid; i < channel_size; i += blockDim.x) { T x = in_c[i]; - T v = x > s ? s : x; - v = v < -s ? -s : v; - v = bin_cnt * inv_s * v; - out_c[i] = round(v) * s / bin_cnt; + if (round_type == 0) { + x = bin_cnt * inv_s * x; + x = roundWithTiesToEven(x); + T max_bound = bin_cnt; + T min_bound = -bin_cnt - static_cast(1); + x = x > max_bound ? max_bound : x; + x = x < min_bound ? min_bound : x; + out_c[i] = (x * s) / bin_cnt; + } else { + T v = x > s ? s : x; + v = v < -s ? -s : v; + v = bin_cnt * inv_s * v; + out_c[i] = round(v) * s / bin_cnt; + } } } // ChannelClipAndQuantDequantKernel for quant_axis is 1 template -__global__ void ChannelClipAndQuantDequantKernelQuantAxis1( - const T* in, const T* scale, const int bin_cnt, const int n, const int cin, - const int cout, T* out) { +__global__ void ChannelClipAndQuantDequantKernelQuantAxis1(const T *in, + const T *scale, + const int bin_cnt, + const int round_type, + const int n, + const int cin, + const int cout, + T *out) { T s = scale[blockIdx.x % cout]; T inv_s = inverse(s); int wh_size = n / (cin * cout); - const T* in_c = in + blockIdx.x * wh_size; - T* out_c = out + blockIdx.x * wh_size; + const T *in_c = in + blockIdx.x * wh_size; + T *out_c = out + blockIdx.x * wh_size; for (int i = threadIdx.x; i < wh_size; i += blockDim.x) { T x = in_c[i]; - T v = x > s ? s : x; - v = v < -s ? -s : v; - v = bin_cnt * inv_s * v; - out_c[i] = round(v) * s / bin_cnt; + if (round_type == 0) { + x = bin_cnt * inv_s * x; + x = roundWithTiesToEven(x); + T max_bound = bin_cnt; + T min_bound = -bin_cnt - static_cast(1); + x = x > max_bound ? max_bound : x; + x = x < min_bound ? min_bound : x; + out_c[i] = (x * s) / bin_cnt; + } else { + T v = x > s ? s : x; + v = v < -s ? 
-s : v; + v = bin_cnt * inv_s * v; + out_c[i] = round(v) * s / bin_cnt; + } } } template struct ChannelClipFakeQuantDequantFunctor { - void operator()(const platform::CUDADeviceContext& ctx, - const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, const int quant_axis, - framework::Tensor* out) { + void operator()(const platform::CUDADeviceContext &ctx, + const framework::Tensor &in, + const framework::Tensor &scale, + const int bin_cnt, + const int round_type, + const int quant_axis, + framework::Tensor *out) { // At present, channelwise quantization supports conv2d, depthwise_conv2d // conv2d_transpose and mul PADDLE_ENFORCE_EQ( - quant_axis == 0 || quant_axis == 1, true, + quant_axis == 0 || quant_axis == 1, + true, platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " "the received is %d", quant_axis)); @@ -543,23 +680,34 @@ struct ChannelClipFakeQuantDequantFunctor { int num = in.numel(); auto in_dims = in.dims(); - const T* in_data = in.data(); - const T* scale_data = scale.data(); - T* out_data = out->mutable_data(ctx.GetPlace()); + const T *in_data = in.data(); + const T *scale_data = scale.data(); + T *out_data = out->mutable_data(ctx.GetPlace()); if (quant_axis == 0) { int grid = in_dims[0]; int block = 1024; - ChannelClipAndQuantDequantKernelQuantAxis0< - T><<>>(in_data, scale_data, bin_cnt, - num, in_dims[0], out_data); + ChannelClipAndQuantDequantKernelQuantAxis0 + <<>>(in_data, + scale_data, + bin_cnt, + round_type, + num, + in_dims[0], + out_data); } else if (quant_axis == 1) { int grid = in_dims[0] * in_dims[1]; int block = 1024; - ChannelClipAndQuantDequantKernelQuantAxis1< - T><<>>( - in_data, scale_data, bin_cnt, num, in_dims[0], in_dims[1], out_data); + ChannelClipAndQuantDequantKernelQuantAxis1 + <<>>(in_data, + scale_data, + bin_cnt, + round_type, + num, + in_dims[0], + in_dims[1], + out_data); } } }; diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h index dc3f081cc9e..6931ac4325b 100644 --- a/paddle/fluid/operators/fake_quantize_op.h +++ b/paddle/fluid/operators/fake_quantize_op.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include + #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor_util.h" @@ -33,97 +34,158 @@ inline HOSTDEVICE T inverse(T s) { return s <= static_cast(1e-30) ? one / (s + eps) : one / s; } +template +inline HOSTDEVICE T roundWithTiesToEven(T x) { + T xLower = floor(x); + T xUpper = ceil(x); + // x is in interval [xl,xu]. Choose closest of two bounds, breaking ties to + // even. + T dLower = x - xLower; + T dUpper = xUpper - x; + return static_cast( + (dLower == dUpper ? fmod(xLower, 2.0F) == 0.0F : dLower < dUpper) + ? xLower + : xUpper); +} + +template +class QuantTensorFunctor { + public: + explicit QuantTensorFunctor(const T bin_cnt, const T inv_s) + : bin_cnt_(bin_cnt), inv_s_(inv_s) {} + HOSTDEVICE T operator()(const T x) const { + T out = bin_cnt_ * inv_s_ * x; + out = roundWithTiesToEven(out); + T max_bound = bin_cnt_; + T min_bound = -bin_cnt_ - static_cast(1); + out = out > max_bound ? max_bound : out; + out = out < min_bound ? 
min_bound : out; + return out; + } + + private: + T bin_cnt_; + T inv_s_; +}; + template struct FindAbsMaxFunctor { - void operator()(const DeviceContext& ctx, const T* in, const int num, T* out); + void operator()(const DeviceContext &ctx, const T *in, const int num, T *out); }; template struct ClipAndFakeQuantFunctor { - void operator()(const DeviceContext& ctx, const framework::Tensor& in, - const framework::Tensor& scale, const int bin_cnt, - framework::Tensor* out); + void operator()(const DeviceContext &ctx, + const framework::Tensor &in, + const framework::Tensor &scale, + const int bin_cnt, + const int round_type, + framework::Tensor *out); }; template struct ClipAndFakeQuantDequantFunctor { - void operator()(const DeviceContext& ctx, const framework::Tensor& in, - const framework::Tensor& scale, const int bin_cnt, - framework::Tensor* out); + void operator()(const DeviceContext &ctx, + const framework::Tensor &in, + const framework::Tensor &scale, + const int bin_cnt, + int round_type, + framework::Tensor *out); }; template struct FindRangeAbsMaxFunctor { - void operator()(const DeviceContext& ctx, const framework::Tensor& cur_scale, - const framework::Tensor& last_scale, - const framework::Tensor& iter, const int window_size, - framework::Tensor* scales_arr, framework::Tensor* out_scale); + void operator()(const DeviceContext &ctx, + const framework::Tensor &cur_scale, + const framework::Tensor &last_scale, + const framework::Tensor &iter, + const int window_size, + framework::Tensor *scales_arr, + framework::Tensor *out_scale); }; template struct FindChannelAbsMaxFunctor { - void operator()(const DeviceContext& ctx, const framework::Tensor& in_tensor, - const int quant_axis, T* out_abs_max); + void operator()(const DeviceContext &ctx, + const framework::Tensor &in_tensor, + const int quant_axis, + T *out_abs_max); }; template struct ChannelClipAndFakeQuantFunctor { - void operator()(const DeviceContext& ctx, const framework::Tensor& in, - const framework::Tensor& scale, const int bin_cnt, - const int quant_axis, framework::Tensor* out); + void operator()(const DeviceContext &ctx, + const framework::Tensor &in, + const framework::Tensor &scale, + const int bin_cnt, + const int round_type, + const int quant_axis, + framework::Tensor *out); }; template struct ChannelClipFakeQuantDequantFunctor { - void operator()(const DeviceContext& ctx, const framework::Tensor& in, - const framework::Tensor& scale, const int bin_cnt, - const int quant_axis, framework::Tensor* out); + void operator()(const DeviceContext &ctx, + const framework::Tensor &in, + const framework::Tensor &scale, + const int bin_cnt, + int round_type, + const int quant_axis, + framework::Tensor *out); }; template struct FindMovingAverageAbsMaxFunctor { - void operator()(const DeviceContext& ctx, const framework::Tensor& in_accum, - const framework::Tensor& in_state, - const framework::Tensor& cur_scale, - framework::Tensor* out_state, framework::Tensor* out_accum, - framework::Tensor* out_scale); + void operator()(const DeviceContext &ctx, + const framework::Tensor &in_accum, + const framework::Tensor &in_state, + const framework::Tensor &cur_scale, + framework::Tensor *out_state, + framework::Tensor *out_accum, + framework::Tensor *out_scale); }; template class FakeAbsMaxKernelBase : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { - auto* in = context.Input("X"); - auto* out = context.Output("Out"); - auto* out_scale = context.Output("OutScale"); - T* 
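`roundWithTiesToEven` picks the closer of `floor(x)` and `ceil(x)` and breaks exact ties toward the even neighbour; `QuantTensorFunctor` then clips the result to `[-bin_cnt - 1, bin_cnt]`. A hypothetical Python port of the same decision rule:

```python
import math

def round_ties_to_even(x: float) -> float:
    # Pick the closer of floor(x)/ceil(x); on an exact tie keep the even bound,
    # mirroring the decision rule of roundWithTiesToEven in fake_quantize_op.h.
    lo, hi = math.floor(x), math.ceil(x)
    d_lo, d_hi = x - lo, hi - x
    if d_lo == d_hi:
        return float(lo) if lo % 2 == 0 else float(hi)
    return float(lo) if d_lo < d_hi else float(hi)

assert round_ties_to_even(1.5) == 2.0 and round_ties_to_even(2.5) == 2.0
assert round_ties_to_even(-0.5) == 0.0 and round_ties_to_even(-1.5) == -2.0
```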
out_s = out_scale->mutable_data(context.GetPlace()); + void Compute(const framework::ExecutionContext &context) const override { + auto *in = context.Input("X"); + auto *out = context.Output("Out"); + auto *out_scale = context.Output("OutScale"); + T *out_s = out_scale->mutable_data(context.GetPlace()); int bit_length = context.Attr("bit_length"); + int round_type = context.Attr("round_type"); int bin_cnt = std::pow(2, bit_length - 1) - 1; - auto& dev_ctx = context.template device_context(); - const T* in_data = in->data(); + auto &dev_ctx = context.template device_context(); + const T *in_data = in->data(); FindAbsMaxFunctor()(dev_ctx, in_data, in->numel(), out_s); - RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, out); + RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, round_type, out); } virtual ~FakeAbsMaxKernelBase() = default; protected: - virtual void RunClipFunctor(const DeviceContext& dev_ctx, - const framework::Tensor& in, - const framework::Tensor& scale, int bin_cnt, - framework::Tensor* out) const = 0; + virtual void RunClipFunctor(const DeviceContext &dev_ctx, + const framework::Tensor &in, + const framework::Tensor &scale, + int bin_cnt, + int round_type, + framework::Tensor *out) const = 0; }; template class FakeQuantizeAbsMaxKernel : public FakeAbsMaxKernelBase { protected: - void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in, - const framework::Tensor& scale, int bin_cnt, - framework::Tensor* out) const override { - ClipAndFakeQuantFunctor()(dev_ctx, in, scale, bin_cnt, - out); + void RunClipFunctor(const DeviceContext &dev_ctx, + const framework::Tensor &in, + const framework::Tensor &scale, + int bin_cnt, + int round_type, + framework::Tensor *out) const override { + ClipAndFakeQuantFunctor()( + dev_ctx, in, scale, bin_cnt, round_type, out); } }; @@ -131,37 +193,41 @@ template class FakeQuantizeDequantizeAbsMaxKernel : public FakeAbsMaxKernelBase { protected: - void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in, - const framework::Tensor& scale, int bin_cnt, - framework::Tensor* out) const override { - ClipAndFakeQuantDequantFunctor()(dev_ctx, in, scale, - bin_cnt, out); + void RunClipFunctor(const DeviceContext &dev_ctx, + const framework::Tensor &in, + const framework::Tensor &scale, + int bin_cnt, + int round_type, + framework::Tensor *out) const override { + ClipAndFakeQuantDequantFunctor()( + dev_ctx, in, scale, bin_cnt, round_type, out); } }; template class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { - auto* in = context.Input("X"); + void Compute(const framework::ExecutionContext &context) const override { + auto *in = context.Input("X"); - auto* out = context.Output("Out"); - auto* out_scale = context.Output("OutScale"); + auto *out = context.Output("Out"); + auto *out_scale = context.Output("OutScale"); out->mutable_data(context.GetPlace()); int bit_length = context.Attr("bit_length"); + int round_type = context.Attr("round_type"); int bin_cnt = std::pow(2, bit_length - 1) - 1; int quant_axis = context.Attr("quant_axis"); bool is_test = context.Attr("is_test"); - auto& dev_ctx = context.template device_context(); + auto &dev_ctx = context.template device_context(); if (!is_test) { - T* out_scale_data = out_scale->mutable_data(context.GetPlace()); - FindChannelAbsMaxFunctor()(dev_ctx, *in, quant_axis, - out_scale_data); + T *out_scale_data = out_scale->mutable_data(context.GetPlace()); + 
FindChannelAbsMaxFunctor()( + dev_ctx, *in, quant_axis, out_scale_data); } ChannelClipAndFakeQuantFunctor()( - dev_ctx, *in, *out_scale, bin_cnt, quant_axis, out); + dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out); } }; @@ -169,130 +235,147 @@ template class FakeChannelWiseQuantizeDequantizeAbsMaxKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { - auto* in = context.Input("X"); - auto* out = context.Output("Out"); - auto* out_scale = context.Output("OutScale"); - T* out_scale_data = out_scale->mutable_data(context.GetPlace()); - auto& dev_ctx = context.template device_context(); + void Compute(const framework::ExecutionContext &context) const override { + auto *in = context.Input("X"); + auto *out = context.Output("Out"); + auto *out_scale = context.Output("OutScale"); + T *out_scale_data = out_scale->mutable_data(context.GetPlace()); + auto &dev_ctx = context.template device_context(); out->mutable_data(dev_ctx.GetPlace()); int bit_length = context.Attr("bit_length"); + int round_type = context.Attr("round_type"); int bin_cnt = std::pow(2, bit_length - 1) - 1; int quant_axis = context.Attr("quant_axis"); - FindChannelAbsMaxFunctor()(dev_ctx, *in, quant_axis, - out_scale_data); + FindChannelAbsMaxFunctor()( + dev_ctx, *in, quant_axis, out_scale_data); ChannelClipFakeQuantDequantFunctor()( - dev_ctx, *in, *out_scale, bin_cnt, quant_axis, out); + dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out); } }; template class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { - auto* in = context.Input("X"); - auto* in_scale = context.Input("InScale"); + void Compute(const framework::ExecutionContext &context) const override { + auto *in = context.Input("X"); + auto *in_scale = context.Input("InScale"); - auto* out = context.Output("Out"); + auto *out = context.Output("Out"); out->mutable_data(context.GetPlace()); bool is_test = context.Attr("is_test"); int bit_length = context.Attr("bit_length"); + int round_type = context.Attr("round_type"); int bin_cnt = std::pow(2, bit_length - 1) - 1; - auto& dev_ctx = context.template device_context(); + auto &dev_ctx = context.template device_context(); // testing if (is_test) { - ClipAndFakeQuantFunctor()(dev_ctx, *in, *in_scale, - bin_cnt, out); + ClipAndFakeQuantFunctor()( + dev_ctx, *in, *in_scale, bin_cnt, round_type, out); return; } // training - auto* out_scale = context.Output("OutScale"); - auto* out_scales = context.Output("OutScales"); - auto* iter = context.Input("Iter"); + auto *out_scale = context.Output("OutScale"); + auto *out_scales = context.Output("OutScales"); + auto *iter = context.Input("Iter"); int window_size = context.Attr("window_size"); out_scale->mutable_data(context.GetPlace()); framework::Tensor cur_scale; - T* cur_scale_data = cur_scale.mutable_data({1}, context.GetPlace()); - FindAbsMaxFunctor()(dev_ctx, in->data(), in->numel(), - cur_scale_data); - FindRangeAbsMaxFunctor()(dev_ctx, cur_scale, *in_scale, - *iter, window_size, out_scales, + T *cur_scale_data = cur_scale.mutable_data({1}, context.GetPlace()); + FindAbsMaxFunctor()( + dev_ctx, in->data(), in->numel(), cur_scale_data); + FindRangeAbsMaxFunctor()(dev_ctx, + cur_scale, + *in_scale, + *iter, + window_size, + out_scales, out_scale); - ClipAndFakeQuantFunctor()(dev_ctx, *in, *out_scale, - bin_cnt, out); + ClipAndFakeQuantFunctor()( + dev_ctx, *in, *out_scale, bin_cnt, round_type, 
out); } }; template class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { - auto* in = context.Input("X"); - auto* in_scale = context.Input("InScale"); - auto* out = context.Output("Out"); + void Compute(const framework::ExecutionContext &context) const override { + auto *in = context.Input("X"); + auto *in_scale = context.Input("InScale"); + auto *out = context.Output("Out"); out->mutable_data(context.GetPlace()); bool is_test = context.Attr("is_test"); int bit_length = context.Attr("bit_length"); + int round_type = context.Attr("round_type"); int bin_cnt = std::pow(2, bit_length - 1) - 1; - auto& dev_ctx = context.template device_context(); + auto &dev_ctx = context.template device_context(); // testing if (is_test) { - RunClipFunctor(dev_ctx, *in, *in_scale, bin_cnt, out); + RunClipFunctor(dev_ctx, *in, *in_scale, bin_cnt, round_type, out); return; } // training - auto* in_accum = context.Input("InAccum"); - auto* in_state = context.Input("InState"); + auto *in_accum = context.Input("InAccum"); + auto *in_state = context.Input("InState"); auto cur_scale = memory::Alloc(dev_ctx, sizeof(T)); - T* cur_scale_data = static_cast(cur_scale->ptr()); + T *cur_scale_data = static_cast(cur_scale->ptr()); - FindAbsMaxFunctor()(dev_ctx, in->data(), in->numel(), - cur_scale_data); + FindAbsMaxFunctor()( + dev_ctx, in->data(), in->numel(), cur_scale_data); - auto* out_state = context.Output("OutState"); - auto* out_accum = context.Output("OutAccum"); - auto* out_scale = context.Output("OutScale"); + auto *out_state = context.Output("OutState"); + auto *out_accum = context.Output("OutAccum"); + auto *out_scale = context.Output("OutScale"); out_state->mutable_data(context.GetPlace()); out_accum->mutable_data(context.GetPlace()); out_scale->mutable_data(context.GetPlace()); float moving_rate = context.Attr("moving_rate"); - FindMovingAverageAbsMaxFunctor()( - dev_ctx, *in_accum, *in_state, cur_scale_data, moving_rate, out_state, - out_accum, out_scale); + FindMovingAverageAbsMaxFunctor()(dev_ctx, + *in_accum, + *in_state, + cur_scale_data, + moving_rate, + out_state, + out_accum, + out_scale); - RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, out); + RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, round_type, out); } virtual ~FakeMovingAverageAbsMaxKernelBase() = default; protected: - virtual void RunClipFunctor(const DeviceContext& dev_ctx, - const framework::Tensor& in, - const framework::Tensor& in_scale, int bin_cnt, - framework::Tensor* out) const = 0; + virtual void RunClipFunctor(const DeviceContext &dev_ctx, + const framework::Tensor &in, + const framework::Tensor &in_scale, + int bin_cnt, + int round_type, + framework::Tensor *out) const = 0; }; template class FakeQuantizeMovingAverageAbsMaxKernel : public FakeMovingAverageAbsMaxKernelBase { protected: - void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in, - const framework::Tensor& in_scale, int bin_cnt, - framework::Tensor* out) const override { - ClipAndFakeQuantFunctor()(dev_ctx, in, in_scale, bin_cnt, - out); + void RunClipFunctor(const DeviceContext &dev_ctx, + const framework::Tensor &in, + const framework::Tensor &in_scale, + int bin_cnt, + int round_type, + framework::Tensor *out) const override { + ClipAndFakeQuantFunctor()( + dev_ctx, in, in_scale, bin_cnt, round_type, out); } }; @@ -300,23 +383,26 @@ template class FakeQuantizeDequantizeMovingAverageAbsMaxKernel : public FakeMovingAverageAbsMaxKernelBase 
{ protected: - void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in, - const framework::Tensor& in_scale, int bin_cnt, - framework::Tensor* out) const override { - ClipAndFakeQuantDequantFunctor()(dev_ctx, in, in_scale, - bin_cnt, out); + void RunClipFunctor(const DeviceContext &dev_ctx, + const framework::Tensor &in, + const framework::Tensor &in_scale, + int bin_cnt, + int round_type, + framework::Tensor *out) const override { + ClipAndFakeQuantDequantFunctor()( + dev_ctx, in, in_scale, bin_cnt, round_type, out); } }; template class MovingAverageAbsMaxScaleKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { - auto* in = context.Input("X"); - auto& dev_ctx = context.template device_context(); + void Compute(const framework::ExecutionContext &context) const override { + auto *in = context.Input("X"); + auto &dev_ctx = context.template device_context(); if (context.HasOutput("Out")) { - auto* out = context.Output("Out"); + auto *out = context.Output("Out"); out->mutable_data(context.GetPlace()); framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out); } @@ -328,40 +414,46 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel { } // training - auto* in_accum = context.Input("InAccum"); - auto* in_state = context.Input("InState"); + auto *in_accum = context.Input("InAccum"); + auto *in_state = context.Input("InState"); auto cur_scale = memory::Alloc(dev_ctx, sizeof(T)); - T* cur_scale_data = static_cast(cur_scale->ptr()); + T *cur_scale_data = static_cast(cur_scale->ptr()); - FindAbsMaxFunctor()(dev_ctx, in->data(), in->numel(), - cur_scale_data); + FindAbsMaxFunctor()( + dev_ctx, in->data(), in->numel(), cur_scale_data); - auto* out_state = context.Output("OutState"); - auto* out_accum = context.Output("OutAccum"); - auto* out_scale = context.Output("OutScale"); + auto *out_state = context.Output("OutState"); + auto *out_accum = context.Output("OutAccum"); + auto *out_scale = context.Output("OutScale"); out_state->mutable_data(context.GetPlace()); out_accum->mutable_data(context.GetPlace()); out_scale->mutable_data(context.GetPlace()); float moving_rate = context.Attr("moving_rate"); - FindMovingAverageAbsMaxFunctor()( - dev_ctx, *in_accum, *in_state, cur_scale_data, moving_rate, out_state, - out_accum, out_scale); + FindMovingAverageAbsMaxFunctor()(dev_ctx, + *in_accum, + *in_state, + cur_scale_data, + moving_rate, + out_state, + out_accum, + out_scale); } }; template class StrightThroughEstimatorGradKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { - auto* d_out = + void Compute(const framework::ExecutionContext &context) const override { + auto *d_out = context.Input(framework::GradVarName("Out")); auto x_grad_name = framework::GradVarName("X"); - auto* d_x = context.Output(x_grad_name); - PADDLE_ENFORCE_NOT_NULL(d_x, platform::errors::PreconditionNotMet( - "StrightThroughEstimatorGradKernel " - "doesn't have the output named %s.", - x_grad_name)); + auto *d_x = context.Output(x_grad_name); + PADDLE_ENFORCE_NOT_NULL(d_x, + platform::errors::PreconditionNotMet( + "StrightThroughEstimatorGradKernel " + "doesn't have the output named %s.", + x_grad_name)); // Initialize dx as same as d_out d_x->mutable_data(context.GetPlace()); diff --git a/paddle/fluid/operators/quantize_linear_op.cc b/paddle/fluid/operators/quantize_linear_op.cc index 4039f0e9d07..4580acbe3fc 100644 --- 
a/paddle/fluid/operators/quantize_linear_op.cc +++ b/paddle/fluid/operators/quantize_linear_op.cc @@ -10,9 +10,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/quantize_linear_op.h" + #include #include #include + #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/transform.h" @@ -24,14 +26,17 @@ namespace operators { template struct ChannelDequantizeFunctorV2 { - void operator()(const platform::CPUDeviceContext& dev_ctx, - const framework::Tensor* in, const framework::Tensor* scale, - T max_range, const int quant_axis, framework::Tensor* out) { + void operator()(const platform::CPUDeviceContext &dev_ctx, + const framework::Tensor *in, + const framework::Tensor *scale, + T max_range, + const int quant_axis, + framework::Tensor *out) { // Dequant op is before quantized op // Dequantize the weight of quantized op auto in_dims = in->dims(); const int64_t channel = in_dims[quant_axis]; - const T* scale_factor = scale->data(); + const T *scale_factor = scale->data(); if (quant_axis == 0) { for (int64_t i = 0; i < channel; i++) { T s = scale_factor[i]; @@ -39,7 +44,7 @@ struct ChannelDequantizeFunctorV2 { framework::Tensor one_channel_out = out->Slice(i, i + 1); auto in_e = framework::EigenVector::Flatten(one_channel_in); auto out_e = framework::EigenVector::Flatten(one_channel_out); - auto& dev = *dev_ctx.eigen_device(); + auto &dev = *dev_ctx.eigen_device(); out_e.device(dev) = in_e * s / max_range; } } else if (quant_axis == 1) { @@ -49,12 +54,12 @@ struct ChannelDequantizeFunctorV2 { } int64_t step_i = in->numel() / out_iter; int64_t step_j = in->numel() / (out_iter * channel); - auto* in_data = in->data(); - auto* out_data = out->mutable_data(dev_ctx.GetPlace()); + auto *in_data = in->data(); + auto *out_data = out->mutable_data(dev_ctx.GetPlace()); for (int64_t i = 0; i < out_iter; i++) { for (int64_t j = 0; j < channel; j++) { - auto* cur_in = in_data + i * step_i + j * step_j; - auto* cur_out = out_data + i * step_i + j * step_j; + auto *cur_in = in_data + i * step_i + j * step_j; + auto *cur_out = out_data + i * step_i + j * step_j; T s = scale_factor[j]; for (int64_t k = 0; k < step_j; k++) { *cur_out = (*cur_in) * s / max_range; @@ -67,19 +72,17 @@ struct ChannelDequantizeFunctorV2 { } }; -template struct DequantizeFunctor; -template struct DequantizeFunctor; template struct ChannelDequantizeFunctorV2; template struct ChannelDequantizeFunctorV2; class QuantizeLinearOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "QuantizeLinear"); OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "QuantizeLinear"); - OP_INOUT_CHECK(ctx->HasInput("ZeroPoint"), "Input", "ZeroPoint", - "QuantizeLinear"); + OP_INOUT_CHECK( + ctx->HasInput("ZeroPoint"), "Input", "ZeroPoint", "QuantizeLinear"); OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "QuantizeLinear"); ctx->SetOutputDim("Y", ctx->GetInputDim("X")); int quant_axis = ctx->Attrs().Get("quant_axis"); @@ -95,7 +98,7 @@ class QuantizeLinearOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { + const framework::ExecutionContext 
&ctx) const override { return framework::OpKernelType( OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } @@ -116,9 +119,10 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker { "For conv2d, depthwise_conv2d, conv2d_transpose " "and mul, the quant_axis is equal to the cout axis.") .SetDefault(0) - .AddCustomChecker([](const int& quant_axis) { + .AddCustomChecker([](const int &quant_axis) { PADDLE_ENFORCE_EQ( - quant_axis == 0 || quant_axis == 1 || quant_axis == -1, true, + quant_axis == 0 || quant_axis == 1 || quant_axis == -1, + true, platform::errors::InvalidArgument( "'quant_axis' should be 0 or 1, but " "the received is %d", @@ -126,13 +130,32 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker { }); AddAttr("bit_length", "(int, default 8)") .SetDefault(8) - .AddCustomChecker([](const int& bit_length) { - PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true, + .AddCustomChecker([](const int &bit_length) { + PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, + true, platform::errors::InvalidArgument( "'bit_length' should be between 1 and 16, but " "the received is %d", bit_length)); }); + AddAttr( + "round_type", + "(int, default 0) The round type of fp32 to int." + "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2" + "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, " + "round(2.5)=3") + .SetDefault(0) + .AddCustomChecker([](const int &round_type) { + PADDLE_ENFORCE_EQ( + round_type == 0 || round_type == 1, + true, + platform::errors::InvalidArgument( + "'round_type' should be 0 or 1, 0 rounding to " + "nearest ties to even and 1 is rounding to nearest " + "ties away from zero.but the received is %d", + round_type)); + }) + .AsExtra(); AddAttr("is_test", "(bool, default false) Set to true for inference only, false " "for training. 
Some layers may run faster when this is true.") @@ -156,14 +179,18 @@ namespace ops = paddle::operators; using CPU = paddle::platform::CPUDeviceContext; REGISTER_OPERATOR( - quantize_linear, ops::QuantizeLinearOp, ops::QuantizeLinearOpMaker, + quantize_linear, + ops::QuantizeLinearOp, + ops::QuantizeLinearOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL(quantize_linear, ops::QuantizeLinearKernel); REGISTER_OPERATOR( - dequantize_linear, ops::QuantizeLinearOp, ops::QuantizeLinearOpMaker, + dequantize_linear, + ops::QuantizeLinearOp, + ops::QuantizeLinearOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/operators/quantize_linear_op.h b/paddle/fluid/operators/quantize_linear_op.h index e20b99e85f0..47e65784b6b 100644 --- a/paddle/fluid/operators/quantize_linear_op.h +++ b/paddle/fluid/operators/quantize_linear_op.h @@ -29,9 +29,13 @@ namespace operators { template struct ChannelDequantizeFunctorV2 { - void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in, - const framework::Tensor** scales, const int scale_num, - T max_range, const int quant_axis, framework::Tensor* out); + void operator()(const DeviceContext& dev_ctx, + const framework::Tensor* in, + const framework::Tensor** scales, + const int scale_num, + T max_range, + const int quant_axis, + framework::Tensor* out); }; template @@ -44,6 +48,7 @@ class QuantizeLinearKernel : public framework::OpKernel { auto* out = context.Output("Y"); out->mutable_data(context.GetPlace()); int bit_length = context.Attr("bit_length"); + int round_type = context.Attr("round_type"); int bin_cnt = std::pow(2, bit_length - 1) - 1; int quant_axis = context.Attr("quant_axis"); bool is_test = context.Attr("is_test"); @@ -53,25 +58,25 @@ class QuantizeLinearKernel : public framework::OpKernel { if (!is_test) { auto* out_scale = context.Output("OutScale"); T* out_s = out_scale->mutable_data(context.GetPlace()); - FindAbsMaxFunctor()(dev_ctx, in->data(), - in->numel(), out_s); - ClipAndFakeQuantFunctor()(dev_ctx, *in, *out_scale, - bin_cnt, out); + FindAbsMaxFunctor()( + dev_ctx, in->data(), in->numel(), out_s); + ClipAndFakeQuantFunctor()( + dev_ctx, *in, *out_scale, bin_cnt, round_type, out); } else { - ClipAndFakeQuantFunctor()(dev_ctx, *in, *in_scale, - bin_cnt, out); + ClipAndFakeQuantFunctor()( + dev_ctx, *in, *in_scale, bin_cnt, round_type, out); } } else { if (!is_test) { auto* out_scale = context.Output("OutScale"); T* out_scale_data = out_scale->mutable_data(context.GetPlace()); - FindChannelAbsMaxFunctor()(dev_ctx, *in, quant_axis, - out_scale_data); + FindChannelAbsMaxFunctor()( + dev_ctx, *in, quant_axis, out_scale_data); ChannelClipAndFakeQuantFunctor()( - dev_ctx, *in, *out_scale, bin_cnt, quant_axis, out); + dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out); } else { ChannelClipAndFakeQuantFunctor()( - dev_ctx, *in, *in_scale, bin_cnt, quant_axis, out); + dev_ctx, *in, *in_scale, bin_cnt, round_type, quant_axis, out); } } } @@ -87,7 +92,8 @@ class DeQuantizeLinearKernel : public framework::OpKernel { auto in_tmp = phi::Cast( static_cast::TYPE&>(dev_ctx), - *in, experimental::CppTypeToDataType::Type()); + *in, + experimental::CppTypeToDataType::Type()); auto* scale = context.Input("Scale"); auto* out = context.Output("Y"); @@ -97,16 +103,18 @@ class DeQuantizeLinearKernel : public framework::OpKernel { if (quant_axis < 0) { float max_range = (std::pow(2, bit_length - 1) - 1); - 
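`DeQuantizeLinearKernel` divides by `max_range = 2^(bit_length - 1) - 1` and multiplies by the stored scale, using one scale for the whole tensor when `quant_axis < 0` and one scale per slice along `quant_axis` otherwise. A minimal NumPy sketch of that math (the helper name and the broadcasting scheme are illustrative):

```python
import numpy as np

def dequantize_linear(q, scale, bit_length=8, quant_axis=-1):
    # Sketch of the dequantize math: per-layer when quant_axis < 0,
    # otherwise one scale per slice along quant_axis (0 or 1 in the op).
    max_range = 2 ** (bit_length - 1) - 1
    if quant_axis < 0:
        return q.astype(np.float32) * scale / max_range
    shape = [1] * q.ndim
    shape[quant_axis] = -1
    return q.astype(np.float32) * np.asarray(scale).reshape(shape) / max_range
```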
DequantizeFunctor()(dev_ctx, &in_tmp, scale, - static_cast(max_range), out); + DequantizeFunctor()( + dev_ctx, &in_tmp, scale, static_cast(max_range), out); } else { PADDLE_ENFORCE_EQ( - scale->numel(), in_tmp.dims()[quant_axis], + scale->numel(), + in_tmp.dims()[quant_axis], platform::errors::PreconditionNotMet( "The number of first scale values must be the same with " "quant_axis dimension value of Input(X) when the `scale` has " "only one element, but %ld != %ld here.", - scale->numel(), in_tmp.dims()[quant_axis])); + scale->numel(), + in_tmp.dims()[quant_axis])); int max_range = (std::pow(2, bit_length - 1) - 1); ChannelDequantizeFunctorV2()( diff --git a/python/paddle/fluid/contrib/slim/quantization/adaround.py b/python/paddle/fluid/contrib/slim/quantization/adaround.py index f6908d7e836..04d894b055d 100644 --- a/python/paddle/fluid/contrib/slim/quantization/adaround.py +++ b/python/paddle/fluid/contrib/slim/quantization/adaround.py @@ -20,26 +20,31 @@ import logging import paddle.fluid as fluid from ....log_helper import get_logger -from .utils import load_variable_data, set_variable_data, stable_sigmoid, quant_tensor, dequant_tensor, _channelwise_quant_axis1_ops, calculate_quant_cos_error +from .utils import load_variable_data, set_variable_data, stable_sigmoid, quant_tensor, dequant_tensor, _channelwise_quant_axis1_ops, calculate_quant_cos_error, bias_correction_w -_logger = get_logger( - __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') +_logger = get_logger(__name__, + logging.INFO, + fmt='%(asctime)s-%(levelname)s: %(message)s') GAMMA = -0.1 ZETA = 1.1 def compute_soft_rounding(alpha_v): - return fluid.layers.clip( - fluid.layers.sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA, min=0, max=1) + return fluid.layers.clip(fluid.layers.sigmoid(alpha_v) * (ZETA - GAMMA) + + GAMMA, + min=0, + max=1) def compute_soft_rounding_np(alpha_v): - return np.clip( - stable_sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA, a_min=0, a_max=1) + return np.clip(stable_sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA, + a_min=0, + a_max=1) class AdaRoundLoss(object): + def __init__(self, reg_param=0.01, default_beta_range=(20, 2)): self.default_reg_param = reg_param self.default_beta_range = default_beta_range @@ -48,26 +53,29 @@ class AdaRoundLoss(object): square_cost = fluid.layers.square_error_cost(ada_quantized_output, orig_output) recon_loss = fluid.layers.reduce_mean( - fluid.layers.reduce_sum( - square_cost, dim=-1)) + fluid.layers.reduce_sum(square_cost, dim=-1)) return recon_loss def compute_round_loss(self, alpha_v, warm_start, beta): + def round_loss_fn(): # compute rectified sigmoid of parameter 'alpha' which maps it between zero and one h_v = compute_soft_rounding(alpha_v) # calculate regularization term - which ensures parameter to converge to exactly zeros and ones # at the end of optimization - reg_term = fluid.layers.reduce_sum(-fluid.layers.pow( - fluid.layers.abs(2 * h_v - 1), factor=beta) + 1) + reg_term = fluid.layers.reduce_sum( + -fluid.layers.pow(fluid.layers.abs(2 * h_v - 1), factor=beta) + + 1) # calculate the rounding loss round_loss = self.default_reg_param * reg_term return round_loss - round_loss = fluid.layers.cond(warm_start, lambda: fluid.layers.fill_constant(shape=[1], dtype='float32', value=0.0), round_loss_fn) + round_loss = fluid.layers.cond( + warm_start, lambda: fluid.layers.fill_constant( + shape=[1], dtype='float32', value=0.0), round_loss_fn) return round_loss @@ -80,15 +88,16 @@ class AdaRoundLoss(object): warm_start_end_iter = warm_start * max_iter # 
compute relative iteration of current iteration - rel_iter = (cur_iter - warm_start_end_iter) / ( - max_iter - warm_start_end_iter) - beta = end_beta + 0.5 * (start_beta - end_beta) * (1 + np.cos(rel_iter * - np.pi)) + rel_iter = (cur_iter - warm_start_end_iter) / (max_iter - + warm_start_end_iter) + beta = end_beta + 0.5 * (start_beta - + end_beta) * (1 + np.cos(rel_iter * np.pi)) return beta class AdaRound(object): + def __init__(self, scale, weight_tensor, @@ -145,10 +154,9 @@ class AdaRound(object): h_alpha = compute_soft_rounding_np(np_alpha) # Scale the tensor - tensor_scale = quant_tensor( - self.ori_weight_tensor.copy(), - self.scale, - quant_axis=self.quant_axis) + tensor_scale = quant_tensor(self.ori_weight_tensor.copy(), + self.scale, + quant_axis=self.quant_axis) weight_tensor = np.floor(tensor_scale) @@ -160,10 +168,10 @@ class AdaRound(object): weight_tensor_quant = self._calculate_quant_weight() # Dequantize the tensor - weight_tensor_dequant = dequant_tensor( - weight_tensor_quant + self.offset, - self.scale, - quant_axis=self.quant_axis) + weight_tensor_dequant = dequant_tensor(weight_tensor_quant + + self.offset, + self.scale, + quant_axis=self.quant_axis) return weight_tensor_dequant def update_final_weights(self): @@ -171,10 +179,10 @@ class AdaRound(object): return weight_tensor_quant def get_loss(self, beta, warm_start, adaround_out_tensor, orig_out_tensor): - round_loss = self.adaround_loss.compute_round_loss(self.alpha_v, - warm_start, beta) - recon_loss = self.adaround_loss.compute_recon_loss(adaround_out_tensor, - orig_out_tensor) + round_loss = self.adaround_loss.compute_round_loss( + self.alpha_v, warm_start, beta) + recon_loss = self.adaround_loss.compute_recon_loss( + adaround_out_tensor, orig_out_tensor) loss = round_loss + recon_loss losses = { 'loss': loss, @@ -201,6 +209,7 @@ def run_adaround(data_loader, scale_dict, num_iterations=1000, lr=0.001, + bias_correction=False, fast_mode=True): fetch_op_name = fetch_list[0].name final_weight_tensor_quant_dict = {} @@ -226,29 +235,29 @@ def run_adaround(data_loader, with fluid.program_guard(train_program, startup_program): with fluid.unique_name.guard(): # initialize adaround - adaround = AdaRound( - scale, - weight_var_tensor, - scope=scope, - weight_var_name=weight_var_name, - weight_op_type=weight_op_type, - num_iterations=num_iterations) - orig_out_tensor = fluid.data( - name='orig_out_tensor', - shape=fp32_fetch_list.shape, - dtype='float32') - adaround_out_tensor = fluid.data( - name='adaround_out_tensor', - shape=fp32_fetch_list.shape, - dtype='float32') - beta_tensor = fluid.data( - name='beta', shape=[1], dtype='float32') - warm_start_tensor = fluid.data( - name='warm_start', shape=[1], dtype='bool') - - train_fetches_loss = adaround.get_loss( - beta_tensor, warm_start_tensor, adaround_out_tensor, - orig_out_tensor) + adaround = AdaRound(scale, + weight_var_tensor, + scope=scope, + weight_var_name=weight_var_name, + weight_op_type=weight_op_type, + num_iterations=num_iterations) + orig_out_tensor = fluid.data(name='orig_out_tensor', + shape=fp32_fetch_list.shape, + dtype='float32') + adaround_out_tensor = fluid.data(name='adaround_out_tensor', + shape=fp32_fetch_list.shape, + dtype='float32') + beta_tensor = fluid.data(name='beta', + shape=[1], + dtype='float32') + warm_start_tensor = fluid.data(name='warm_start', + shape=[1], + dtype='bool') + + train_fetches_loss = adaround.get_loss(beta_tensor, + warm_start_tensor, + adaround_out_tensor, + orig_out_tensor) optimizer = 
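AdaRound's rounding mask is a rectified sigmoid clipped to `[0, 1]`, and the round-loss exponent `beta` is cosine-annealed from `start_beta` to `end_beta` once the warm-start phase ends. A NumPy sketch of both pieces (a plain sigmoid stands in for `stable_sigmoid`, and the default `warm_start` fraction here is an assumption for illustration):

```python
import numpy as np

GAMMA, ZETA = -0.1, 1.1

def soft_rounding(alpha):
    # rectified sigmoid h(alpha) in [0, 1], as in compute_soft_rounding_np
    return np.clip(1.0 / (1.0 + np.exp(-alpha)) * (ZETA - GAMMA) + GAMMA, 0.0, 1.0)

def annealed_beta(cur_iter, max_iter, warm_start=0.1, beta_range=(20, 2)):
    # cosine-annealed exponent for the round-loss regularizer after warm start;
    # beta_range=(20, 2) matches the AdaRoundLoss default, warm_start is assumed
    start_beta, end_beta = beta_range
    warm_end = warm_start * max_iter
    rel_iter = (cur_iter - warm_end) / (max_iter - warm_end)
    return end_beta + 0.5 * (start_beta - end_beta) * (1 + np.cos(rel_iter * np.pi))
```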
fluid.optimizer.Adam(learning_rate=lr) loss = train_fetches_loss['loss'] optimizer.minimize(loss) @@ -291,16 +300,23 @@ def run_adaround(data_loader, fetch_list=[v.name for v in train_fetches_loss.values()], return_numpy=True) _logger.info( - "Iter {:d}, lr {:.5f}, loss {:.5f}, loss_round {:.5f}, loss_recon {:.5f}, time {:.5f}s". - format(i, lr, - np.mean(out[0]), - np.mean(out[1]), - np.mean(out[2]), start_time - prev_start_time)) + "Iter {:d}, lr {:.5f}, loss {:.5f}, loss_round {:.5f}, loss_recon {:.5f}, time {:.5f}s" + .format(i, lr, np.mean(out[0]), np.mean(out[1]), + np.mean(out[2]), start_time - prev_start_time)) sys.stdout.flush() if i == num_iterations: break final_weight_tensor_quant_dict[ weight_var_name] = adaround.update_final_weights() + + if bias_correction: + final_weight_tensor_quant_dict[weight_var_name] = bias_correction_w( + weight_var_tensor, + final_weight_tensor_quant_dict[weight_var_name], + scale, + adaround.quant_axis, + weight_bits=adaround.weight_bits) + del adaround # update adarounded calibrated weights diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index bd4ecfb7b11..f1da3990a36 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -36,8 +36,9 @@ from . import utils __all__ = ['PostTrainingQuantization', 'WeightQuantization'] -_logger = get_logger( - __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') +_logger = get_logger(__name__, + logging.INFO, + fmt='%(asctime)s-%(levelname)s: %(message)s') def _all_persistable_var_names(program): @@ -88,7 +89,8 @@ def _apply_pass(scope, cpp_graph.set_not_owned('__param_scope__', scope) if attrs: assert attr_values and len(attrs) == len( - attr_values), "Different number of pass attributes and their values." + attr_values + ), "Different number of pass attributes and their values." for attr, value in zip(attrs, attr_values): ir_pass.set(attr, value) ir_pass.apply(cpp_graph) @@ -180,7 +182,8 @@ class PostTrainingQuantization(object): "mul"]. round_type(str, optional): The method of converting the quantized weights value float->int. Currently supports ['round', 'adaround'] methods. - Default is `round`, which is rounding nearest to the nearest whole number. + Default is `round`, which is rounding nearest to the integer. + 'adaround' is refer to https://arxiv.org/abs/2004.10568. learning_rate(float, optional): The learning rate of adaround method. is_full_quantized(bool, optional): If set is_full_quantized as True, apply quantization to all supported quantizable op type. 
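When `bias_correction=True`, the adarounded weights are additionally corrected by `utils.bias_correction_w`, whose body is not part of this hunk. One common formulation of weight bias correction, shown only as an illustration and not necessarily identical to the Paddle helper, rescales and shifts the dequantized weights so each channel's mean and standard deviation match the fp32 originals:

```python
import numpy as np

def bias_correct_weights(w_fp32, w_dequant, quant_axis=0):
    # Illustrative per-channel mean/std matching; not a drop-in replacement
    # for utils.bias_correction_w.
    axes = tuple(i for i in range(w_fp32.ndim) if i != quant_axis)
    mean_fp32 = w_fp32.mean(axis=axes, keepdims=True)
    mean_q = w_dequant.mean(axis=axes, keepdims=True)
    std_fp32 = w_fp32.std(axis=axes, keepdims=True)
    std_q = w_dequant.std(axis=axes, keepdims=True) + 1e-8   # avoid divide-by-zero
    return (w_dequant - mean_q) * (std_fp32 / std_q) + mean_fp32
```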
If set @@ -364,7 +367,8 @@ class PostTrainingQuantization(object): batch_id = 0 with tqdm( total=self._batch_nums, - bar_format='Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}', + bar_format= + 'Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}', ncols=80) as t: for data in self._data_loader(): self._executor.run(program=self._program, @@ -380,10 +384,10 @@ class PostTrainingQuantization(object): self._init_sampling_act_histogram() batch_id = 0 - with tqdm( - total=self._batch_nums, - bar_format='Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}', - ncols=80) as t: + with tqdm(total=self._batch_nums, + bar_format= + 'Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}', + ncols=80) as t: for data in self._data_loader(): self._executor.run(program=self._program, feed=data, @@ -446,18 +450,18 @@ class PostTrainingQuantization(object): scale_dict = self._quantized_var_threshold else: scale_dict = self._quantized_threshold - run_adaround( - self._data_loader, - self._program, - self._fetch_list, - self._executor, - self._scope, - self._place, - self._quantized_op_pairs, - self._weight_op_pairs, - scale_dict, - num_iterations=self._batch_nums, - lr=self._learning_rate) + run_adaround(self._data_loader, + self._program, + self._fetch_list, + self._executor, + self._scope, + self._place, + self._quantized_op_pairs, + self._weight_op_pairs, + scale_dict, + num_iterations=self._batch_nums, + bias_correction=self._bias_correction, + lr=self._learning_rate) def save_quantized_model(self, save_model_path, @@ -478,15 +482,14 @@ class PostTrainingQuantization(object): None ''' clip_extra = True if self._onnx_format else False - io.save_inference_model( - dirname=save_model_path, - model_filename=model_filename, - params_filename=params_filename, - feeded_var_names=self._feed_list, - target_vars=self._fetch_list, - executor=self._executor, - main_program=self._program, - clip_extra=clip_extra) + io.save_inference_model(dirname=save_model_path, + model_filename=model_filename, + params_filename=params_filename, + feeded_var_names=self._feed_list, + target_vars=self._fetch_list, + executor=self._executor, + main_program=self._program, + clip_extra=clip_extra) _logger.info("The quantized model is saved in " + save_model_path) def _load_model_data(self): @@ -508,17 +511,18 @@ class PostTrainingQuantization(object): if self._data_loader is not None: return - self._data_loader = io.DataLoader.from_generator( - feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True) + self._data_loader = io.DataLoader.from_generator(feed_list=feed_vars, + capacity=3 * + self._batch_size, + iterable=True) if self._sample_generator is not None: - self._data_loader.set_sample_generator( - self._sample_generator, - batch_size=self._batch_size, - drop_last=True, - places=self._place) + self._data_loader.set_sample_generator(self._sample_generator, + batch_size=self._batch_size, + drop_last=True, + places=self._place) elif self._batch_generator is not None: - self._data_loader.set_batch_generator( - self._batch_generator, places=self._place) + self._data_loader.set_batch_generator(self._batch_generator, + places=self._place) def _optimize_fp32_model(self): ''' @@ -569,12 +573,10 @@ class PostTrainingQuantization(object): " is not supported for quantization.") # For quantized ops, sample inputs and outputs if op_type in self._quantizable_op_type: - collect_var_name( - utils._get_op_input_var_names(op), - persistable_var_names, op_type) - collect_var_name( - utils._get_op_output_var_names(op), - 
persistable_var_names, op_type) + collect_var_name(utils._get_op_input_var_names(op), + persistable_var_names, op_type) + collect_var_name(utils._get_op_output_var_names(op), + persistable_var_names, op_type) # collect quanted op output var name for out_var_name in utils._get_op_output_var_names(op): for in_var_name in utils._get_op_input_var_names(op): @@ -583,9 +585,8 @@ class PostTrainingQuantization(object): in_var_name] = out_var_name # For other op, only sample output scale elif op_type in self._out_scale_op_list: - collect_var_name( - utils._get_op_output_var_names(op), - persistable_var_names, op_type) + collect_var_name(utils._get_op_output_var_names(op), + persistable_var_names, op_type) def _set_activation_persistable(self): ''' @@ -655,9 +656,14 @@ class PostTrainingQuantization(object): scale = s * abs_max_value s += 0.02 bins = 2**(self._activation_bits - 1) - 1 - quant_dequant_var = np.round( - np.clip(var_tensor, 0.0, scale) / scale * - bins) / bins * scale + if self._onnx_format: + quant_var = np.clip(np.round(var_tensor / scale * bins), + -bins - 1, bins) + quant_dequant_var = quant_var / bins * scale + else: + quant_dequant_var = np.round( + np.clip(var_tensor, 0.0, scale) / scale * + bins) / bins * scale mse_loss = ((var_tensor - quant_dequant_var)**2).mean() if mse_loss <= self._best_calibration_loss[var_name]: self._best_calibration_loss[var_name] = mse_loss @@ -694,9 +700,14 @@ class PostTrainingQuantization(object): scale = s * abs_max_value s += 0.02 bins = 2**(self._activation_bits - 1) - 1 - quant_dequant_var = np.round( - np.clip(var_tensor, 0.0, scale) / scale * - bins) / bins * scale + if self._onnx_format: + quant_var = np.clip(np.round(var_tensor / scale * bins), + -bins - 1, bins) + quant_dequant_var = quant_var / bins * scale + else: + quant_dequant_var = np.round( + np.clip(var_tensor, 0.0, scale) / scale * + bins) / bins * scale emd_loss = np.abs( np.mean(var_tensor) - np.mean(quant_dequant_var)) + np.abs( np.std(var_tensor) - np.std(quant_dequant_var)) @@ -846,8 +857,9 @@ class PostTrainingQuantization(object): if var_name not in self._sampling_act_histogram: min_val = self._sampling_act_abs_min_max[var_name][0] max_val = self._sampling_act_abs_min_max[var_name][1] - hist, hist_edeges = np.histogram( - [], bins=self._histogram_bins, range=(min_val, max_val)) + hist, hist_edeges = np.histogram([], + bins=self._histogram_bins, + range=(min_val, max_val)) self._sampling_act_histogram[var_name] = [hist, hist_edeges] def _calculate_kl_hist_threshold(self): @@ -951,18 +963,11 @@ class PostTrainingQuantization(object): else: scale_dict = self._quantized_threshold for key, val in scale_dict.items(): - utils.set_variable_data( - self._scope, - self._place, - key + ".scale", - np.array( - [val], dtype=np.float32)) - utils.set_variable_data( - self._scope, - self._place, - key + ".quant_dequant.scale", - np.array( - [val], dtype=np.float32)) + utils.set_variable_data(self._scope, self._place, key + ".scale", + np.array([val], dtype=np.float32)) + utils.set_variable_data(self._scope, self._place, + key + ".quant_dequant.scale", + np.array([val], dtype=np.float32)) if not self._onnx_format: # apply QuantizationFreezePass, and obtain the final quant model @@ -1038,8 +1043,8 @@ class PostTrainingQuantization(object): for block_id in range(len(self._program.blocks)): for op in self._program.blocks[block_id].ops: - if op.type in ( - self._quantizable_op_type + self._out_scale_op_list): + if op.type in (self._quantizable_op_type + + self._out_scale_op_list): 
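The `mse`/`emd` threshold search sweeps candidate scales as fractions of the activation's abs-max and keeps the one with the lowest reconstruction loss; with `onnx_format=True` the quant-dequant step now mirrors the symmetric `quantize_linear`/`dequantize_linear` pair (round first, then clip to `[-bins - 1, bins]`). A sketch of the MSE variant (the starting fraction `start` is an assumption; the `0.02` step matches the pass):

```python
import numpy as np

def search_mse_scale(var_tensor, activation_bits=8, onnx_format=False,
                     start=0.3, step=0.02):
    # Sketch of the MSE-driven scale search over candidate thresholds.
    abs_max_value = float(np.max(np.abs(var_tensor)))
    bins = 2 ** (activation_bits - 1) - 1
    best_scale, best_loss = abs_max_value, float("inf")
    s = start
    while s <= 1.0:
        scale = s * abs_max_value
        s += step
        if onnx_format:
            # symmetric quant-dequant matching quantize_linear/dequantize_linear
            quant_var = np.clip(np.round(var_tensor / scale * bins), -bins - 1, bins)
            quant_dequant_var = quant_var / bins * scale
        else:
            quant_dequant_var = np.round(
                np.clip(var_tensor, 0.0, scale) / scale * bins) / bins * scale
        mse_loss = ((var_tensor - quant_dequant_var) ** 2).mean()
        if mse_loss <= best_loss:
            best_loss, best_scale = mse_loss, scale
    return best_scale
```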
out_var_names = utils._get_op_output_var_names(op) for var_name in out_var_names: analysis_and_save_info(op, var_name) @@ -1175,10 +1180,11 @@ class WeightQuantization(object): if generate_test_model: test_model_dir = os.path.join(save_model_dir, "test_model") - self._quantize_weight_to_int( - test_model_dir, save_model_filename, save_params_filename, - quantizable_op_type, weight_bits, weight_quantize_type, True, - threshold_rate) + self._quantize_weight_to_int(test_model_dir, save_model_filename, + save_params_filename, + quantizable_op_type, weight_bits, + weight_quantize_type, True, + threshold_rate) def convert_weight_to_fp16(self, save_model_dir): """ @@ -1216,16 +1222,17 @@ class WeightQuantization(object): if self._params_filename is not None: save_var_map[new_var.name] = new_var else: - save_file_path = os.path.join( - os.path.normpath(save_model_dir), new_var.name) - save_block.append_op( - type='save', - inputs={'X': [new_var]}, - outputs={}, - attrs={ - 'file_path': os.path.normpath(save_file_path), - 'save_as_fp16': True - }) + save_file_path = os.path.join(os.path.normpath(save_model_dir), + new_var.name) + save_block.append_op(type='save', + inputs={'X': [new_var]}, + outputs={}, + attrs={ + 'file_path': + os.path.normpath(save_file_path), + 'save_as_fp16': + True + }) if self._params_filename is not None: save_var_list = [] @@ -1237,14 +1244,15 @@ class WeightQuantization(object): name=unique_name.generate("saved_params")) saved_params_var.desc.set_persistable(True) - save_path = os.path.join( - os.path.normpath(save_model_dir), self._params_filename) - save_block.append_op( - type='save_combine', - inputs={'X': save_var_list}, - outputs={'Y': saved_params_var}, - attrs={'file_path': save_path, - 'save_as_fp16': True}) + save_path = os.path.join(os.path.normpath(save_model_dir), + self._params_filename) + save_block.append_op(type='save_combine', + inputs={'X': save_var_list}, + outputs={'Y': saved_params_var}, + attrs={ + 'file_path': save_path, + 'save_as_fp16': True + }) save_program._sync_with_cpp() exe.run(save_program) @@ -1293,14 +1301,13 @@ class WeightQuantization(object): self._weight_channel_wise_abs_max_quantization( scope, place, weight_bits, op, var_name, for_test) - io.save_inference_model( - dirname=save_model_dir, - feeded_var_names=feed_list, - target_vars=fetch_list, - executor=exe, - main_program=program, - model_filename=save_model_filename, - params_filename=save_params_filename) + io.save_inference_model(dirname=save_model_dir, + feeded_var_names=feed_list, + target_vars=fetch_list, + executor=exe, + main_program=program, + model_filename=save_model_filename, + params_filename=save_params_filename) def _weight_abs_max_quantization(self, scope, place, weight_bits, threshold_rate, op, var_name, for_test): @@ -1339,8 +1346,9 @@ class WeightQuantization(object): op._set_attr(var_name + "_quant_scale", [scale]) # Save as list op._set_attr("with_quant_attr", True) - def _weight_channel_wise_abs_max_quantization( - self, scope, place, weight_bits, op, var_name, for_test): + def _weight_channel_wise_abs_max_quantization(self, scope, place, + weight_bits, op, var_name, + for_test): ''' Use channel_wise_abs_max method to quantize weight. ''' @@ -1390,8 +1398,8 @@ class WeightQuantization(object): and quantize the weights. 
''' scales = [] - quantized_weight_data = np.zeros_like( - weight_data, dtype=save_weight_dtype) + quantized_weight_data = np.zeros_like(weight_data, + dtype=save_weight_dtype) channel_num = weight_data.shape[0] for i in range(channel_num): scale = np.max(np.abs(weight_data[i])) / quantize_range @@ -1404,8 +1412,8 @@ class WeightQuantization(object): ''' For conv2d and depthwise_conv2d, dequantize the weights to fp32. ''' - dequantized_weight_data = np.zeros_like( - quantized_weight_data, dtype=np.float32) + dequantized_weight_data = np.zeros_like(quantized_weight_data, + dtype=np.float32) for i in range(len(scales)): dequantized_weight_data[i] = \ (quantized_weight_data[i] * scales[i]).astype(np.float32) @@ -1418,8 +1426,8 @@ class WeightQuantization(object): and quantize the weights. ''' scales = [] - quantized_weight_data = np.zeros_like( - weight_data, dtype=save_weight_dtype) + quantized_weight_data = np.zeros_like(weight_data, + dtype=save_weight_dtype) channel_num = weight_data.shape[-1] for i in range(channel_num): scale = np.max(np.abs(weight_data[:, i])) / quantize_range @@ -1432,8 +1440,8 @@ class WeightQuantization(object): ''' For mul, dequantize the weights to fp32. ''' - dequantized_weight_data = np.zeros_like( - quantized_weight_data, dtype=np.float32) + dequantized_weight_data = np.zeros_like(quantized_weight_data, + dtype=np.float32) for i in range(len(scales)): dequantized_weight_data[:, i] = \ (quantized_weight_data[:, i] * scales[i]).astype(np.float32) @@ -1441,8 +1449,9 @@ class WeightQuantization(object): def _calculate_threshold(self, input, threshold_rate, histogram_bins=5000): input_abs = np.abs(input) - hist, hist_edeges = np.histogram( - input_abs, bins=histogram_bins, range=(0, np.max(input_abs))) + hist, hist_edeges = np.histogram(input_abs, + bins=histogram_bins, + range=(0, np.max(input_abs))) hist = hist / float(sum(hist)) hist_sum = 0 hist_index = 0 diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index f298929a6e7..3a316e9192e 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -307,8 +307,9 @@ class QuantizationTransformPass(object): var_node = self._insert_func( graph, self._weight_preprocess_func, var_node, op) elif not is_weight and self._act_preprocess_func is not None: - var_node = self._insert_func( - graph, self._act_preprocess_func, var_node, op) + var_node = self._insert_func(graph, + self._act_preprocess_func, + var_node, op) # if var node is weight and weight_quantize_func is not None, # will insert weight quantize func to quantize and dequantize weight @@ -376,10 +377,10 @@ class QuantizationTransformPass(object): graph.out_node_mapping_table = dict() # The process of _transform_forward and _transform_backward is needed in two for loops. 
# The loop for transforming the forward graph: - with tqdm( - total=len(ops), - bar_format='Adding quant op for weight:|{bar}| {n_fmt}/{total_fmt}', - ncols=80) as t: + with tqdm(total=len(ops), + bar_format= + 'Adding quant op with weight:|{bar}| {n_fmt}/{total_fmt}', + ncols=80) as t: for op in ops: if op.name() in self._quantizable_ops: if not self._is_skip_quant(graph, op) and _has_weight(op): @@ -405,12 +406,8 @@ class QuantizationTransformPass(object): var_type=core.VarDesc.VarType.LOD_TENSOR, shape=[1], var_dtype=core.VarDesc.VarType.INT64) - _init_var_node( - global_step_in, - np.zeros( - [1], dtype='int64'), - self._scope, - self._place) + _init_var_node(global_step_in, np.zeros([1], dtype='int64'), + self._scope, self._place) global_step_out = graph.create_var_node_from_desc( global_step_in.var()) # The attribute of `op_role` is needed by ParallelExecutor. @@ -459,12 +456,9 @@ class QuantizationTransformPass(object): var_dtype=var_node.dtype()) data_type = 'float64' if var_node.dtype( ) == core.VarDesc.VarType.FP64 else 'float32' - _init_var_node( - scale_var_node, - np.zeros( - scale_var_node.shape(), dtype=data_type), - self._scope, - self._place) + _init_var_node(scale_var_node, + np.zeros(scale_var_node.shape(), dtype=data_type), + self._scope, self._place) quant_op_node = graph.create_op_node( op_type='fake_quantize_abs_max', attrs={ @@ -472,8 +466,10 @@ class QuantizationTransformPass(object): 'op_role': core.op_proto_and_checker_maker.OpRole.Forward }, inputs={'X': var_node}, - outputs={'Out': quant_var_node, - 'OutScale': scale_var_node}) + outputs={ + 'Out': quant_var_node, + 'OutScale': scale_var_node + }) graph.link_to(var_node, quant_op_node) graph.link_to(quant_op_node, quant_var_node) graph.link_to(quant_op_node, scale_var_node) @@ -498,12 +494,9 @@ class QuantizationTransformPass(object): var_dtype=var_node.dtype()) data_type = 'float64' if var_node.dtype( ) == core.VarDesc.VarType.FP64 else 'float32' - _init_var_node( - scale_in_node, - np.array( - [_SCALE_DEFAULT_VALUE], dtype=data_type), - self._scope, - self._place) + _init_var_node(scale_in_node, + np.array([_SCALE_DEFAULT_VALUE], dtype=data_type), + self._scope, self._place) scale_out_node = graph.create_var_node_from_desc(scale_in_node.var()) inputs = {'X': var_node, 'InScale': scale_in_node} @@ -518,12 +511,9 @@ class QuantizationTransformPass(object): var_dtype=var_node.dtype()) data_type = 'float64' if var_node.dtype( ) == core.VarDesc.VarType.FP64 else 'float32' - _init_var_node( - scales_node, - np.zeros( - [self._window_size], dtype=data_type), - self._scope, - self._place) + _init_var_node(scales_node, + np.zeros([self._window_size], dtype=data_type), + self._scope, self._place) inputs['Iter'] = self._global_step outputs['OutScales'] = scales_node @@ -566,12 +556,9 @@ class QuantizationTransformPass(object): var_dtype=var_node.dtype()) data_type = 'float64' if var_node.dtype( ) == core.VarDesc.VarType.FP64 else 'float32' - _init_var_node( - scale_in_node, - np.array( - [_SCALE_DEFAULT_VALUE], dtype=data_type), - self._scope, - self._place) + _init_var_node(scale_in_node, + np.array([_SCALE_DEFAULT_VALUE], dtype=data_type), + self._scope, self._place) scale_out_node = graph.create_var_node_from_desc(scale_in_node.var()) ins = {'X': var_node, 'InScale': scale_in_node} @@ -584,27 +571,19 @@ class QuantizationTransformPass(object): shape=[1]) data_type = 'float64' if var_node.dtype( ) == core.VarDesc.VarType.FP64 else 'float32' - _init_var_node( - state_in_node, - np.ones( - [1], dtype=data_type), - 
self._scope, - self._place) + _init_var_node(state_in_node, np.ones([1], dtype=data_type), + self._scope, self._place) accum_in_node = graph.create_persistable_node( name=unique_name.generate('accum'), var_type=core.VarDesc.VarType.LOD_TENSOR, var_dtype=var_node.dtype(), shape=[1]) - _init_var_node( - accum_in_node, - np.ones( - [1], dtype=data_type), - self._scope, - self._place) - state_out_node = graph.create_var_node_from_desc(state_in_node.var( - )) - accum_out_node = graph.create_var_node_from_desc(accum_in_node.var( - )) + _init_var_node(accum_in_node, np.ones([1], dtype=data_type), + self._scope, self._place) + state_out_node = graph.create_var_node_from_desc( + state_in_node.var()) + accum_out_node = graph.create_var_node_from_desc( + accum_in_node.var()) ins['InState'] = state_in_node ins['InAccum'] = accum_in_node @@ -656,12 +635,9 @@ class QuantizationTransformPass(object): var_dtype=var_node.dtype()) data_type = 'float64' if var_node.dtype( ) == core.VarDesc.VarType.FP64 else 'float32' - _init_var_node( - scale_var_node, - np.zeros( - scale_var_node.shape(), dtype=data_type), - self._scope, - self._place) + _init_var_node(scale_var_node, + np.zeros(scale_var_node.shape(), dtype=data_type), + self._scope, self._place) quant_op_node = graph.create_op_node( op_type='fake_channel_wise_quantize_abs_max', attrs={ @@ -671,8 +647,10 @@ class QuantizationTransformPass(object): 'op_role': core.op_proto_and_checker_maker.OpRole.Forward }, inputs={'X': var_node}, - outputs={'Out': quant_var_node, - 'OutScale': scale_var_node}) + outputs={ + 'Out': quant_var_node, + 'OutScale': scale_var_node + }) graph.link_to(var_node, quant_op_node) graph.link_to(quant_op_node, quant_var_node) graph.link_to(quant_op_node, scale_var_node) @@ -696,8 +674,10 @@ class QuantizationTransformPass(object): 'max_range': float(max_range), 'op_role': core.op_proto_and_checker_maker.OpRole.Forward }, - inputs={'X': var_node, - 'Scale': scale_var_node}, + inputs={ + 'X': var_node, + 'Scale': scale_var_node + }, outputs={'Out': dequant_var_node}) graph.link_to(var_node, dequant_op_node) graph.link_to(scale_var_node, dequant_op_node) @@ -723,8 +703,10 @@ class QuantizationTransformPass(object): 'quant_axis': quant_axis, 'op_role': core.op_proto_and_checker_maker.OpRole.Forward }, - inputs={'X': var_node, - 'Scales': scale_var_nodes}, + inputs={ + 'X': var_node, + 'Scales': scale_var_nodes + }, outputs={'Out': dequant_var_node}) graph.link_to(var_node, dequant_op_node) for scale_n in scale_var_nodes: @@ -812,10 +794,9 @@ class QuantizationTransformPass(object): startup_program = Program() with program_guard(tmp_program, startup_program): with unique_name.guard(var_node.name() + "_"): - in_node = data( - var_node.name() + '_tmp_input', - shape=var_node.shape(), - dtype='float32') + in_node = data(var_node.name() + '_tmp_input', + shape=var_node.shape(), + dtype='float32') out_node = func(in_node) graph.out_node_mapping_table[out_node.name] = var_node.name() # loss shape must be 1 when minimize @@ -828,8 +809,8 @@ class QuantizationTransformPass(object): with scope_guard(self._scope): self._exe.run(startup_program) - tmp_graph = IrGraph( - core.Graph(tmp_program.desc), for_test=graph._for_test) + tmp_graph = IrGraph(core.Graph(tmp_program.desc), + for_test=graph._for_test) in_node = tmp_graph._find_node_by_name(tmp_graph.all_var_nodes(), in_node.name) out_node = tmp_graph._find_node_by_name(tmp_graph.all_var_nodes(), @@ -870,9 +851,11 @@ class QuantizationTransformPass(object): # find op's gradient op, such as 
conv2d_grad op_grad = op_out_grad.outputs[0] target_out_grad_node = graph._find_node_by_name( - graph.all_var_nodes(), target_out_node.name() + "@GRAD") + graph.all_var_nodes(), + target_out_node.name() + "@GRAD") in_node_grad = graph._find_node_by_name( - graph.all_var_nodes(), target_in_node.name() + "@GRAD") + graph.all_var_nodes(), + target_in_node.name() + "@GRAD") in_node_grad_op = in_node_grad.inputs # update op_grad's input graph.update_input_link(var_node, target_out_node, op_grad) @@ -945,6 +928,7 @@ class QuantizationTransformPass(object): class QuantizationFreezePass(object): + def __init__(self, scope, place, @@ -970,8 +954,9 @@ class QuantizationFreezePass(object): weight_bits(int): quantization bit number for weights. activation_bits(int): quantization bit number for activation. round_type(str, optional): The method of converting the quantized weights - value from float to int. Currently supports ['round', 'adaround'] methods. - Default is `round`, which is rounding nearest to the nearest whole number. + value float->int. Currently supports ['round', 'adaround'] methods. + Default is `round`, which is rounding nearest to the integer. + 'adaround' is refer to https://arxiv.org/abs/2004.10568. weight_quantize_type(str): quantization type for weights, support 'abs_max' and 'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight, since weights are fixed once the model is well trained. @@ -1017,7 +1002,8 @@ class QuantizationFreezePass(object): input_arg_name] if input_arg_name not in persistable_vars: scale_v = graph._find_node_by_name( - op_node.outputs, op_node.output('OutScale')[0]) + op_node.outputs, + op_node.output('OutScale')[0]) self._quant_var_scale_map[input_arg_name] = scale_v else: # Obtain scale from OutScale var node @@ -1033,8 +1019,8 @@ class QuantizationFreezePass(object): scale_v = scale_v.tolist() self._quant_var_scale_map[input_arg_name] = scale_v # Quantize weight and restore - param_v = self._load_var(input_arg_name) if self._round_type == 'round': + param_v = self._load_var(input_arg_name) if any( _check_grandchild_op_node(op_node, op) for op in utils._channelwise_quant_axis1_ops): @@ -1045,6 +1031,7 @@ class QuantizationFreezePass(object): param_v.copy(), scale_v, quant_axis, self._weight_bits) quantized_param_v = np.round(quantized_param_v) + # Weight bias correction if self._bias_correction == True: quantized_param_v = utils.bias_correction_w( param_v, @@ -1072,8 +1059,8 @@ class QuantizationFreezePass(object): if self._weight_quantize_type == 'channel_wise_abs_max': quant_axis = 1 if op_node.name() in \ utils._channelwise_quant_axis1_ops else 0 - self._insert_post_channel_dequant_op(graph, op_node, - quant_axis) + self._insert_post_channel_dequant_op( + graph, op_node, quant_axis) else: self._insert_post_dequant_op(graph, op_node) @@ -1128,7 +1115,8 @@ class QuantizationFreezePass(object): " more than one output." 
% (op_node.name())) output_var_node = graph._find_node_by_name( - op_node.outputs, op_node.output_arg_names()[0]) + op_node.outputs, + op_node.output_arg_names()[0]) weight_scale_node = graph.create_persistable_node( name=unique_name.generate('channel_scale'), var_type=core.VarDesc.VarType.LOD_TENSOR, @@ -1136,9 +1124,8 @@ class QuantizationFreezePass(object): var_dtype=output_var_node.dtype()) data_type = 'float64' if output_var_node.dtype( ) == core.VarDesc.VarType.FP64 else 'float32' - _init_var_node(weight_scale_node, - channel_scale.astype(data_type), self._scope, - self._place) + _init_var_node(weight_scale_node, channel_scale.astype(data_type), + self._scope, self._place) dequant_var_node = graph.create_var_node( name=self._dequantized_var_name(output_var_node.name()), var_type=output_var_node.type(), @@ -1201,7 +1188,8 @@ class QuantizationFreezePass(object): " more than one output." % (op_node.name())) output_var_node = graph._find_node_by_name( - op_node.outputs, op_node.output_arg_names()[0]) + op_node.outputs, + op_node.output_arg_names()[0]) dequant_var_node = graph.create_var_node( name=self._dequantized_var_name(output_var_node.name()), var_type=output_var_node.type(), @@ -1213,8 +1201,10 @@ class QuantizationFreezePass(object): 'max_range': float(max_range), 'op_role': core.op_proto_and_checker_maker.OpRole.Forward }, - inputs={'X': output_var_node, - 'Scale': scale_var_node}, + inputs={ + 'X': output_var_node, + 'Scale': scale_var_node + }, outputs={'Out': dequant_var_node}) graph.link_to(output_var_node, dequant_op_node) graph.link_to(scale_var_node, dequant_op_node) @@ -1273,6 +1263,7 @@ class QuantizationFreezePass(object): class ConvertToInt8Pass(object): + def __init__(self, scope, place, quantizable_op_type=None): """ Convert the weights into int8_t type. @@ -1312,8 +1303,8 @@ class ConvertToInt8Pass(object): name = var_node.name() if name in persistable_vars: if name not in input_map: - int8_var_node = self._convert_to_int8(graph, - var_node) + int8_var_node = self._convert_to_int8( + graph, var_node) input_map[name] = int8_var_node graph.update_input_link(var_node, input_map[name], op_node) @@ -1361,6 +1352,7 @@ class ConvertToInt8Pass(object): class TransformForMobilePass(object): + def __init__(self): """ This pass is used to convert the frozen graph for paddle-mobile execution. @@ -1403,6 +1395,7 @@ class TransformForMobilePass(object): class OutScaleForTrainingPass(object): + def __init__(self, scope=None, place=None, moving_rate=0.9): """ This pass is used for calculating output scales of some operators. 
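The recurring change in these Python hunks is the order of round and clip in the simulated quantization: the legacy path clips the float tensor against the scale first and only then scales and rounds, while the new path taken when onnx_format is enabled computes np.round(x / scale * bins) first and then clips the integer result to [-bins - 1, bins] (for example [-128, 127] for 8 bits). A minimal NumPy sketch of the two orderings follows; the function names are illustrative only and are not part of Paddle, but the arithmetic mirrors the expressions added in post_training_quantization.py and utils.quant_tensor in this patch.

import numpy as np


def simulate_quant_dequant_legacy(x, scale, bits=8):
    # Legacy order: clip the float tensor to [0, scale] first, then scale,
    # round and rescale (as in the pre-patch calibration loss computation).
    bins = 2**(bits - 1) - 1
    return np.round(np.clip(x, 0.0, scale) / scale * bins) / bins * scale


def simulate_quant_dequant_onnx(x, scale, bits=8):
    # New order: scale and round first, then clip the integer value to the
    # signed range [-bins - 1, bins], i.e. [-128, 127] for int8.
    bins = 2**(bits - 1) - 1
    q = np.clip(np.round(x / scale * bins), -bins - 1, bins)
    return q / bins * scale


if __name__ == '__main__':
    x = np.array([-1.3, -0.6, 0.0, 0.49, 1.2], dtype=np.float32)
    print(simulate_quant_dequant_legacy(x, scale=1.0))  # negatives collapse to 0
    print(simulate_quant_dequant_onnx(x, scale=1.0))  # negatives clip at -128/127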
@@ -1436,10 +1429,9 @@ class OutScaleForTrainingPass(object): for op in graph.all_op_nodes(): if op.name() in self._teller_set: target_ops.append(op) - with tqdm( - total=len(target_ops), - bar_format='Adding OutScale op:|{bar}| {n_fmt}/{total_fmt}', - ncols=80) as t: + with tqdm(total=len(target_ops), + bar_format='Adding OutScale op:|{bar}| {n_fmt}/{total_fmt}', + ncols=80) as t: for op in target_ops: for output_var_name in utils._get_op_output_var_names(op): in_node = graph._find_node_by_name(op.outputs, @@ -1455,12 +1447,8 @@ class OutScaleForTrainingPass(object): var_dtype=in_node.dtype()) data_type = 'float64' if in_node.dtype() \ == core.VarDesc.VarType.FP64 else 'float32' - _init_var_node( - scale_node, - np.ones( - [1], dtype=data_type), - self._scope, - self._place) + _init_var_node(scale_node, np.ones([1], dtype=data_type), + self._scope, self._place) ins = {'X': in_node} outs = {'OutScale': scale_node} if not self._is_test: @@ -1469,23 +1457,17 @@ class OutScaleForTrainingPass(object): var_type=core.VarDesc.VarType.LOD_TENSOR, var_dtype=in_node.dtype(), shape=[1]) - _init_var_node( - state_in_node, - np.ones( - [1], dtype=data_type), - self._scope, - self._place) + _init_var_node(state_in_node, + np.ones([1], dtype=data_type), + self._scope, self._place) accum_in_node = graph.create_persistable_node( name=unique_name.generate('scale_accum@'), var_type=core.VarDesc.VarType.LOD_TENSOR, var_dtype=in_node.dtype(), shape=[1]) - _init_var_node( - accum_in_node, - np.ones( - [1], dtype=data_type), - self._scope, - self._place) + _init_var_node(accum_in_node, + np.ones([1], dtype=data_type), + self._scope, self._place) state_out_node = graph.create_var_node_from_desc( state_in_node.var()) accum_out_node = graph.create_var_node_from_desc( @@ -1525,6 +1507,7 @@ class OutScaleForTrainingPass(object): class OutScaleForInferencePass(object): + def __init__(self, scope=None): """ This pass is used for setting output scales of some operators. @@ -1566,8 +1549,8 @@ class OutScaleForInferencePass(object): # For compatibility, we save output threshold by two methods. 
op_node.op()._set_attr("out_threshold", float(scale_value)) - argname_index = utils._get_output_name_index(op_node, - var_name) + argname_index = utils._get_output_name_index( + op_node, var_name) assert argname_index is not None, \ var_name + " is not the output of the op" op_node.op()._set_attr(argname_index[0] + str(argname_index[1]) \ @@ -1660,10 +1643,10 @@ class AddQuantDequantPass(object): # Forward stage, insert quant_dequant op all_op_nodes = graph.all_op_nodes() - with tqdm( - total=len(all_op_nodes), - bar_format='Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}', - ncols=80) as t: + with tqdm(total=len(all_op_nodes), + bar_format= + 'Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}', + ncols=80) as t: for op_node in all_op_nodes: if op_node.name() in self._quantizable_op_type: is_skip = False @@ -1685,8 +1668,8 @@ class AddQuantDequantPass(object): op_node.op()._set_attr("with_quant_attr", True) arg_names = utils._get_op_input_var_names(op_node) for arg_name in arg_names: - in_node = graph._find_node_by_name(op_node.inputs, - arg_name) + in_node = graph._find_node_by_name( + op_node.inputs, arg_name) if arg_name in dequantized_vars_map: quant_var_node = dequantized_vars_map[arg_name] else: @@ -1703,8 +1686,8 @@ class AddQuantDequantPass(object): if op_node.name() in self._quantizable_grad_op_type: for input_name in op_node.input_arg_names(): if input_name in dequantized_vars_map: - in_node = graph._find_node_by_name(op_node.inputs, - input_name) + in_node = graph._find_node_by_name( + op_node.inputs, input_name) dequant_var_node = dequantized_vars_map[input_name] graph.update_input_link(in_node, dequant_var_node, op_node) @@ -1716,11 +1699,11 @@ class AddQuantDequantPass(object): quant_bits): """Insert fake_quantize_dequantize_moving_average_abs_max op. 
""" - quant_var_node = graph.create_var_node( - name="{}.quant_dequant".format(var_node.name()), - var_type=var_node.type(), - shape=var_node.shape(), - var_dtype=var_node.dtype()) + quant_var_node = graph.create_var_node(name="{}.quant_dequant".format( + var_node.name()), + var_type=var_node.type(), + shape=var_node.shape(), + var_dtype=var_node.dtype()) scale_in_node = graph.create_persistable_node( name="{}.quant_dequant.scale".format(var_node.name()), var_type=core.VarDesc.VarType.LOD_TENSOR, @@ -1728,12 +1711,9 @@ class AddQuantDequantPass(object): var_dtype=var_node.dtype()) data_type = 'float64' if var_node.dtype( ) == core.VarDesc.VarType.FP64 else 'float32' - _init_var_node( - scale_in_node, - np.array( - [_SCALE_DEFAULT_VALUE], dtype=data_type), - self._scope, - self._place) + _init_var_node(scale_in_node, + np.array([_SCALE_DEFAULT_VALUE], dtype=data_type), + self._scope, self._place) scale_out_node = graph.create_var_node_from_desc(scale_in_node.var()) ins = {'X': var_node, 'InScale': scale_in_node} @@ -1746,27 +1726,19 @@ class AddQuantDequantPass(object): shape=[1]) data_type = 'float64' if var_node.dtype( ) == core.VarDesc.VarType.FP64 else 'float32' - _init_var_node( - state_in_node, - np.ones( - [1], dtype=data_type), - self._scope, - self._place) + _init_var_node(state_in_node, np.ones([1], dtype=data_type), + self._scope, self._place) accum_in_node = graph.create_persistable_node( name=unique_name.generate('quant_dequant.accum'), var_type=core.VarDesc.VarType.LOD_TENSOR, var_dtype=var_node.dtype(), shape=[1]) - _init_var_node( - accum_in_node, - np.ones( - [1], dtype=data_type), - self._scope, - self._place) - state_out_node = graph.create_var_node_from_desc(state_in_node.var( - )) - accum_out_node = graph.create_var_node_from_desc(accum_in_node.var( - )) + _init_var_node(accum_in_node, np.ones([1], dtype=data_type), + self._scope, self._place) + state_out_node = graph.create_var_node_from_desc( + state_in_node.var()) + accum_out_node = graph.create_var_node_from_desc( + accum_in_node.var()) ins['InState'] = state_in_node ins['InAccum'] = accum_in_node @@ -1833,11 +1805,11 @@ class InsertQuantizeLinear(object): def insert_quant_op(self, graph, var_node): assert var_node.is_var(), '{} is not a var'.format(var_node.name()) - quant_var_node = graph.create_var_node( - name=self._quantized_var_name(var_node.name()), - var_type=var_node.type(), - shape=var_node.shape(), - var_dtype=var_node.dtype()) + quant_var_node = graph.create_var_node(name=self._quantized_var_name( + var_node.name()), + var_type=var_node.type(), + shape=var_node.shape(), + var_dtype=var_node.dtype()) data_type = 'float64' if var_node.dtype( ) == core.VarDesc.VarType.FP64 else 'float32' if self.channel_wise: @@ -1863,12 +1835,9 @@ class InsertQuantizeLinear(object): var_type=core.VarDesc.VarType.LOD_TENSOR, shape=scale_var_node.shape(), var_dtype=core.VarDesc.VarType.INT32) - _init_var_node( - zero_point_node, - np.zeros( - scale_var_node.shape(), dtype="int32"), - self._scope, - self._place) + _init_var_node(zero_point_node, + np.zeros(scale_var_node.shape(), dtype="int32"), + self._scope, self._place) inputs = {"X": var_node, "Scale": scale_var_node} if zero_point_node is not None: @@ -1879,15 +1848,14 @@ class InsertQuantizeLinear(object): if not self._is_test: attrs["is_test"] = self._is_test attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward - scale_out_node = graph.create_var_node_from_desc(scale_var_node.var( - )) + scale_out_node = graph.create_var_node_from_desc( + 
scale_var_node.var()) outputs["OutScale"] = scale_out_node - quant_op_node = graph.create_op_node( - op_type="quantize_linear", - attrs=attrs, - inputs=inputs, - outputs=outputs) + quant_op_node = graph.create_op_node(op_type="quantize_linear", + attrs=attrs, + inputs=inputs, + outputs=outputs) graph.link_to(var_node, quant_op_node) graph.link_to(scale_var_node, quant_op_node) @@ -1914,12 +1882,9 @@ class InsertQuantizeLinear(object): var_type=core.VarDesc.VarType.LOD_TENSOR, shape=scale_var_node.shape(), var_dtype=core.VarDesc.VarType.INT32) - _init_var_node( - zero_point_node, - np.zeros( - scale_var_node.shape(), dtype="int32"), - self._scope, - self._place) + _init_var_node(zero_point_node, + np.zeros(scale_var_node.shape(), dtype="int32"), + self._scope, self._place) inputs = {"X": var_node, "Scale": scale_var_node} if zero_point_node is not None: @@ -1929,11 +1894,10 @@ class InsertQuantizeLinear(object): if not self._is_test: attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward - quant_op_node = graph.create_op_node( - op_type="dequantize_linear", - attrs=attrs, - inputs=inputs, - outputs={"Y": dequant_var_node}) + quant_op_node = graph.create_op_node(op_type="dequantize_linear", + attrs=attrs, + inputs=inputs, + outputs={"Y": dequant_var_node}) graph.link_to(var_node, quant_op_node) graph.link_to(scale_var_node, quant_op_node) @@ -2151,11 +2115,13 @@ class QuantizationTransformPassV2(object): # will insert activation preprocess func # to preorocess activation before quantization if is_weight and self._weight_preprocess_func is not None: - var_node = self._insert_func( - graph, self._weight_preprocess_func, var_node, op) + var_node = self._insert_func(graph, + self._weight_preprocess_func, + var_node, op) elif not is_weight and self._act_preprocess_func is not None: - var_node = self._insert_func( - graph, self._act_preprocess_func, var_node, op) + var_node = self._insert_func(graph, + self._act_preprocess_func, + var_node, op) # if var node is weight and weight_quantize_func is not None, # will insert weight quantize func to quantize and dequantize weight @@ -2167,8 +2133,9 @@ class QuantizationTransformPassV2(object): processed_vars.append(name) continue elif not is_weight and self._act_quantize_func is not None: - target_out_node = self._insert_func( - graph, self._act_quantize_func, var_node, op) + target_out_node = self._insert_func(graph, + self._act_quantize_func, + var_node, op) processed_vars.append(name) continue @@ -2263,10 +2230,10 @@ class QuantizationTransformPassV2(object): graph.out_node_mapping_table = dict() # The process of _transform_forward and _transform_backward is needed in two for loops. 
# The loop for transforming the forward graph: - with tqdm( - total=len(ops), - bar_format='Adding quant op for weight:|{bar}| {n_fmt}/{total_fmt}', - ncols=80) as t: + with tqdm(total=len(ops), + bar_format= + 'Adding quant op with weight:|{bar}| {n_fmt}/{total_fmt}', + ncols=80) as t: for op in ops: if op.name() in self._quantizable_ops: if not self._is_skip_quant(graph, @@ -2375,10 +2342,10 @@ class AddQuantDequantPassV2(object): # Forward stage, insert quant_dequant op all_op_nodes = graph.all_op_nodes() - with tqdm( - total=len(all_op_nodes), - bar_format='Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}', - ncols=80) as t: + with tqdm(total=len(all_op_nodes), + bar_format= + 'Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}', + ncols=80) as t: for op_node in all_op_nodes: if op_node.name() in self._quantizable_op_type: is_skip = False @@ -2397,8 +2364,8 @@ class AddQuantDequantPassV2(object): "qat_without_weight") arg_names = utils._get_op_input_var_names(op_node) for arg_name in arg_names: - in_node = graph._find_node_by_name(op_node.inputs, - arg_name) + in_node = graph._find_node_by_name( + op_node.inputs, arg_name) if in_node.persistable(): continue if arg_name in dequantized_vars_map: @@ -2425,8 +2392,8 @@ class AddQuantDequantPassV2(object): if op_node.name() in self._quantizable_grad_op_type: for input_name in op_node.input_arg_names(): if input_name in dequantized_vars_map: - in_node = graph._find_node_by_name(op_node.inputs, - input_name) + in_node = graph._find_node_by_name( + op_node.inputs, input_name) dequant_var_node = dequantized_vars_map[input_name] graph.update_input_link(in_node, dequant_var_node, op_node) @@ -2502,43 +2469,42 @@ class ReplaceFakeQuantDequantPass(object): var_type=core.VarDesc.VarType.LOD_TENSOR, shape=scale_node.shape(), var_dtype=core.VarDesc.VarType.INT32) - _init_var_node( - zero_point_node, - np.zeros( - scale_node.shape(), dtype="int32"), - self._scope, - self._place) - - quant_var_node = graph.create_var_node( - name=self._quantized_var_name(x_node.name()), - var_type=x_node.type(), - shape=x_node.shape(), - var_dtype=x_node.dtype()) - quant_op_node = graph.create_op_node( - op_type="quantize_linear", - attrs={"quant_axis": quant_axis, - "bit_length": bit_length}, - inputs={ - "X": x_node, - "Scale": scale_node, - "ZeroPoint": zero_point_node - }, - outputs={"Y": quant_var_node}) + _init_var_node(zero_point_node, + np.zeros(scale_node.shape(), dtype="int32"), + self._scope, self._place) + + quant_var_node = graph.create_var_node(name=self._quantized_var_name( + x_node.name()), + var_type=x_node.type(), + shape=x_node.shape(), + var_dtype=x_node.dtype()) + quant_op_node = graph.create_op_node(op_type="quantize_linear", + attrs={ + "quant_axis": quant_axis, + "bit_length": bit_length + }, + inputs={ + "X": x_node, + "Scale": scale_node, + "ZeroPoint": zero_point_node + }, + outputs={"Y": quant_var_node}) graph.link_to(x_node, quant_op_node) graph.link_to(scale_node, quant_op_node) if zero_point_node is not None: graph.link_to(zero_point_node, quant_op_node) graph.link_to(quant_op_node, quant_var_node) - dequant_op_node = graph.create_op_node( - op_type="dequantize_linear", - attrs={"quant_axis": quant_axis, - "bit_length": bit_length}, - inputs={ - "X": quant_var_node, - "Scale": scale_node, - "ZeroPoint": zero_point_node - }, - outputs={"Y": out_node}) + dequant_op_node = graph.create_op_node(op_type="dequantize_linear", + attrs={ + "quant_axis": quant_axis, + "bit_length": bit_length + }, + inputs={ + "X": quant_var_node, + 
"Scale": scale_node, + "ZeroPoint": zero_point_node + }, + outputs={"Y": out_node}) graph.link_to(quant_var_node, dequant_op_node) graph.link_to(scale_node, dequant_op_node) if zero_point_node is not None: @@ -2617,7 +2583,8 @@ class QuantWeightPass(object): scale_node = graph._find_node_by_name(_op.inputs, _op.input("Scale")[0]) zero_point_node = graph._find_node_by_name( - _op.inputs, _op.input("ZeroPoint")[0]) + _op.inputs, + _op.input("ZeroPoint")[0]) out_node = graph._find_node_by_name(_op.outputs, _op.output("Y")[0]) @@ -2633,8 +2600,11 @@ class QuantWeightPass(object): param_v = self._load_var(x_node.name()) quant_axis = _op.op().attr("quant_axis") bits_length = _op.op().attr("bit_length") - quantized_param_v = utils.quant_tensor(param_v.copy(), scale_v, - quant_axis, bits_length) + quantized_param_v = utils.quant_tensor(param_v.copy(), + scale_v, + quant_axis, + bits_length, + onnx_format=True) if self._bias_correction == True: quantized_param_v = utils.bias_correction_w( param_v, diff --git a/python/paddle/fluid/contrib/slim/quantization/utils.py b/python/paddle/fluid/contrib/slim/quantization/utils.py index f8867a346fb..28efcd2d591 100644 --- a/python/paddle/fluid/contrib/slim/quantization/utils.py +++ b/python/paddle/fluid/contrib/slim/quantization/utils.py @@ -321,7 +321,7 @@ def set_variable_data(scope, place, var_name, np_value): tensor.set(np_value, place) -def quant_tensor(x, scale, quant_axis=0, weight_bits=8): +def quant_tensor(x, scale, quant_axis=0, weight_bits=8, onnx_format=False): # symmetry quant def _clip(x, scale): x[x > scale] = scale @@ -335,15 +335,27 @@ def quant_tensor(x, scale, quant_axis=0, weight_bits=8): if s == 0.0: s = 1e-8 if quant_axis == 0: - x[i] = _clip(x[i], s) - x[i] = x[i] / s * bnt + if onnx_format: + x[i] = np.round(x[i] / s * bnt) + x[i] = np.clip(x[i], -bnt - 1, bnt) + else: + x[i] = _clip(x[i], s) + x[i] = x[i] / s * bnt else: - x[:, i] = _clip(x[:, i], s) - x[:, i] = x[:, i] / s * bnt + if onnx_format: + x[:, i] = np.round(x[:, i] / s * bnt) + x[:, i] = np.clip(x[:, i], -bnt - 1, bnt) + else: + x[:, i] = _clip(x[:, i], s) + x[:, i] = x[:, i] / s * bnt else: scale = 1e-8 if scale == 0.0 else scale - x = _clip(x, scale) - x = x / scale * bnt + if onnx_format: + x = np.round(x / scale * bnt) + x = np.clip(x, -bnt - 1, bnt) + else: + x = _clip(x, scale) + x = x / scale * bnt return x @@ -416,6 +428,7 @@ def calculate_quant_cos_error(orig_tensor, qdq_tensor): class tqdm(object): + def __init__(self, total, bar_format='Loading|{bar}', ncols=80): self.total = total self.bar_format = bar_format diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt index 0140283b915..88dc33f581a 100644 --- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt @@ -1,352 +1,523 @@ -file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") +file( + GLOB TEST_OPS + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") -function(_inference_analysis_python_api_int8_test target model_dir data_path filename use_mkldnn) - py_test(${target} SRCS ${filename} - ENVS CPU_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} - FLAGS_use_mkldnn=${use_mkldnn} - ARGS --infer_model ${model_dir}/model - --infer_data ${data_path} - --int8_model_save_path int8_models/${target} - --warmup_batch_size ${WARMUP_BATCH_SIZE} - --batch_size 50) +function(_inference_analysis_python_api_int8_test target model_dir data_path 
+ filename use_mkldnn) + py_test( + ${target} + SRCS ${filename} + ENVS + CPU_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} + FLAGS_use_mkldnn=${use_mkldnn} + ARGS + --infer_model + ${model_dir}/model + --infer_data + ${data_path} + --int8_model_save_path + int8_models/${target} + --warmup_batch_size + ${WARMUP_BATCH_SIZE} + --batch_size + 50) endfunction() -function(inference_analysis_python_api_int8_test target model_dir data_path filename) - _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path} ${filename} False) +function(inference_analysis_python_api_int8_test target model_dir data_path + filename) + _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path} + ${filename} False) endfunction() -function(inference_analysis_python_api_int8_test_custom_warmup_batch_size target model_dir data_dir filename warmup_batch_size) - set(WARMUP_BATCH_SIZE ${warmup_batch_size}) - inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename}) +function(inference_analysis_python_api_int8_test_custom_warmup_batch_size + target model_dir data_dir filename warmup_batch_size) + set(WARMUP_BATCH_SIZE ${warmup_batch_size}) + inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} + ${filename}) endfunction() -function(inference_analysis_python_api_int8_test_mkldnn target model_dir data_path filename) - _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path} ${filename} True) +function(inference_analysis_python_api_int8_test_mkldnn target model_dir + data_path filename) + _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path} + ${filename} True) endfunction() function(download_data install_dir url data_file check_sum) - if (NOT EXISTS ${install_dir}/${data_file}) - inference_download_and_uncompress(${install_dir} ${url} ${data_file} ${check_sum}) - endif() + if(NOT EXISTS ${install_dir}/${data_file}) + inference_download_and_uncompress(${install_dir} ${url} ${data_file} + ${check_sum}) + endif() endfunction() function(download_quant_data install_dir data_file check_sum) - if (NOT EXISTS ${install_dir}/${data_file}) - inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 ${data_file} ${check_sum}) - endif() + if(NOT EXISTS ${install_dir}/${data_file}) + inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 + ${data_file} ${check_sum}) + endif() endfunction() function(download_quant_model install_dir data_file check_sum) - if (NOT EXISTS ${install_dir}/${data_file}) - inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file} ${check_sum}) - endif() + if(NOT EXISTS ${install_dir}/${data_file}) + inference_download_and_uncompress( + ${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file} ${check_sum}) + endif() endfunction() function(download_quant_fp32_model install_dir data_file check_sum) - if (NOT EXISTS ${install_dir}/${data_file}) - inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models/fp32 ${data_file} ${check_sum}) - endif() + if(NOT EXISTS ${install_dir}/${data_file}) + inference_download_and_uncompress( + ${install_dir} ${INFERENCE_URL}/int8/QAT_models/fp32 ${data_file} + ${check_sum}) + endif() endfunction() function(download_lstm_model install_dir data_file check_sum) - if (NOT EXISTS ${install_dir}/${data_file}) - inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/lstm ${data_file} ${check_sum}) - endif() + if(NOT EXISTS ${install_dir}/${data_file}) + 
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/lstm + ${data_file} ${check_sum}) + endif() endfunction() -function(inference_quant_int8_image_classification_test target quant_model_dir dataset_path) - py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant_int8_image_classification_comparison.py" - ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} - OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} - FLAGS_use_mkldnn=true - ARGS --quant_model ${quant_model_dir} - --infer_data ${dataset_path} - --batch_size 25 - --batch_num 2 - --acc_diff_threshold 0.1) +function(inference_quant_int8_image_classification_test target quant_model_dir + dataset_path) + py_test( + ${target} + SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant_int8_image_classification_comparison.py" + ENVS + FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} + OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} + FLAGS_use_mkldnn=true + ARGS + --quant_model + ${quant_model_dir} + --infer_data + ${dataset_path} + --batch_size + 25 + --batch_num + 2 + --acc_diff_threshold + 0.1) endfunction() - -# set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 25 -function(inference_quant2_int8_image_classification_test target quant_model_dir fp32_model_dir dataset_path) - py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant2_int8_image_classification_comparison.py" - ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} - OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} - FLAGS_use_mkldnn=true - ARGS --quant_model ${quant_model_dir} - --fp32_model ${fp32_model_dir} - --infer_data ${dataset_path} - --batch_size 50 - --batch_num 2 - --acc_diff_threshold 0.1) +# set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 25 +function(inference_quant2_int8_image_classification_test target quant_model_dir + fp32_model_dir dataset_path) + py_test( + ${target} + SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant2_int8_image_classification_comparison.py" + ENVS + FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} + OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} + FLAGS_use_mkldnn=true + ARGS + --quant_model + ${quant_model_dir} + --fp32_model + ${fp32_model_dir} + --infer_data + ${dataset_path} + --batch_size + 50 + --batch_num + 2 + --acc_diff_threshold + 0.1) endfunction() -# set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 20 -function(inference_quant2_int8_nlp_test target quant_model_dir fp32_model_dir dataset_path labels_path ops_to_quantize) - py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant2_int8_nlp_comparison.py" - ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} - OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} - FLAGS_use_mkldnn=true - ARGS --quant_model ${quant_model_dir} - --fp32_model ${fp32_model_dir} - --infer_data ${dataset_path} - --labels ${labels_path} - --batch_size 10 - --batch_num 2 - --acc_diff_threshold 0.1 - --ops_to_quantize ${ops_to_quantize}) +# set batch_size 10 for UT only (avoid OOM). 
For whole dataset, use batch_size 20 +function( + inference_quant2_int8_nlp_test + target + quant_model_dir + fp32_model_dir + dataset_path + labels_path + ops_to_quantize) + py_test( + ${target} + SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant2_int8_nlp_comparison.py" + ENVS + FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} + OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} + FLAGS_use_mkldnn=true + ARGS + --quant_model + ${quant_model_dir} + --fp32_model + ${fp32_model_dir} + --infer_data + ${dataset_path} + --labels + ${labels_path} + --batch_size + 10 + --batch_num + 2 + --acc_diff_threshold + 0.1 + --ops_to_quantize + ${ops_to_quantize}) endfunction() -function(inference_quant2_int8_lstm_model_test target fp32_model quant_model dataset_path) - py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant2_int8_lstm_model.py" - ARGS --fp32_model ${fp32_model} - --quant_model ${quant_model} - --infer_data ${dataset_path} - --num_threads 1 - --mkldnn_cache_capacity 100 - --warmup_iter 100 - --acc_diff_threshold 0.11) +function(inference_quant2_int8_lstm_model_test target fp32_model quant_model + dataset_path) + py_test( + ${target} + SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant2_int8_lstm_model.py" + ARGS + --fp32_model + ${fp32_model} + --quant_model + ${quant_model} + --infer_data + ${dataset_path} + --num_threads + 1 + --mkldnn_cache_capacity + 100 + --warmup_iter + 100 + --acc_diff_threshold + 0.11) endfunction() function(download_quant_data install_dir data_file check_sum) - if (NOT EXISTS ${install_dir}/${data_file}) - inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 ${data_file} ${check_sum}) - endif() + if(NOT EXISTS ${install_dir}/${data_file}) + inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 + ${data_file} ${check_sum}) + endif() endfunction() function(download_quant_model install_dir data_file check_sum) - if (NOT EXISTS ${install_dir}/${data_file}) - inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file} ${check_sum}) - endif() + if(NOT EXISTS ${install_dir}/${data_file}) + inference_download_and_uncompress( + ${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file} ${check_sum}) + endif() endfunction() function(save_quant_ic_model_test target quant_model_dir int8_model_save_path) - py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_quant_model.py - ARGS --quant_model_path ${quant_model_dir} - --int8_model_save_path ${int8_model_save_path} - --debug) + py_test( + ${target} + SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_quant_model.py + ARGS + --quant_model_path + ${quant_model_dir} + --int8_model_save_path + ${int8_model_save_path} + --debug) endfunction() -function(save_quant_nlp_model_test target quant_model_dir int8_model_save_path ops_to_quantize) - py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_quant_model.py - ARGS --quant_model_path ${quant_model_dir} - --int8_model_save_path ${int8_model_save_path} - --ops_to_quantize ${ops_to_quantize}) +function(save_quant_nlp_model_test target quant_model_dir int8_model_save_path + ops_to_quantize) + py_test( + ${target} + SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_quant_model.py + ARGS + --quant_model_path + ${quant_model_dir} + --int8_model_save_path + ${int8_model_save_path} + --ops_to_quantize + ${ops_to_quantize}) endfunction() -function(convert_model2dot_test target model_path save_graph_dir save_graph_name) - py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/convert_model2dot.py - ARGS --model_path ${model_path} - --save_graph_dir ${save_graph_dir} - 
--save_graph_name ${save_graph_name}) +function(convert_model2dot_test target model_path save_graph_dir + save_graph_name) + py_test( + ${target} + SRCS ${CMAKE_CURRENT_SOURCE_DIR}/convert_model2dot.py + ARGS + --model_path + ${model_path} + --save_graph_dir + ${save_graph_dir} + --save_graph_name + ${save_graph_name}) endfunction() if(WIN32) - list(REMOVE_ITEM TEST_OPS test_light_nas) - list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mnist) - list(REMOVE_ITEM TEST_OPS test_post_training_quantization_while) - list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1) - list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50) - list(REMOVE_ITEM TEST_OPS test_post_training_quantization_lstm_model) - list(REMOVE_ITEM TEST_OPS test_imperative_ptq) - list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1) - list(REMOVE_ITEM TEST_OPS test_quantize_transpiler_v2) - list(REMOVE_ITEM TEST_OPS test_imperative_qat_amp) + list(REMOVE_ITEM TEST_OPS test_light_nas) + list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mnist) + list(REMOVE_ITEM TEST_OPS test_post_training_quantization_while) + list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1) + list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50) + list(REMOVE_ITEM TEST_OPS test_post_training_quantization_lstm_model) + list(REMOVE_ITEM TEST_OPS test_imperative_ptq) + list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1) + list(REMOVE_ITEM TEST_OPS test_quantize_transpiler_v2) + list(REMOVE_ITEM TEST_OPS test_imperative_qat_amp) endif() if(LINUX AND WITH_MKLDNN) - #### Image classification dataset: ImageNet (small) - # The dataset should already be downloaded for INT8v2 unit tests - set(IMAGENET_DATA_PATH "${INFERENCE_DEMO_INSTALL_DIR}/imagenet/data.bin") - - #### INT8 image classification python api test - # Models should be already downloaded for INT8v2 unit tests - - set(INT8_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2") - - #### QUANT & INT8 comparison python api tests - - set(QUANT_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/quant") - - ### Quant1 for image classification - - # Quant ResNet50 - set(QUANT_RESNET50_MODEL_DIR "${QUANT_INSTALL_DIR}/ResNet50_quant") - set(QUANT_RESNET50_MODEL_ARCHIVE "ResNet50_qat_model.tar.gz") - download_quant_model(${QUANT_RESNET50_MODEL_DIR} ${QUANT_RESNET50_MODEL_ARCHIVE} ff89b934ab961c3a4a844193ece2e8a7) - inference_quant_int8_image_classification_test(test_quant_int8_resnet50_mkldnn ${QUANT_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) - - # Quant ResNet101 - set(QUANT_RESNET101_MODEL_DIR "${QUANT_INSTALL_DIR}/ResNet101_quant") - set(QUANT_RESNET101_MODEL_ARCHIVE "ResNet101_qat_model.tar.gz") - download_quant_model(${QUANT_RESNET101_MODEL_DIR} ${QUANT_RESNET101_MODEL_ARCHIVE} 95c6d01e3aeba31c13efb2ba8057d558) - # inference_quant_int8_image_classification_test(test_quant_int8_resnet101_mkldnn ${QUANT_RESNET101_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) - - # Quant GoogleNet - set(QUANT_GOOGLENET_MODEL_DIR "${QUANT_INSTALL_DIR}/GoogleNet_quant") - set(QUANT_GOOGLENET_MODEL_ARCHIVE "GoogleNet_qat_model.tar.gz") - download_quant_model(${QUANT_GOOGLENET_MODEL_DIR} ${QUANT_GOOGLENET_MODEL_ARCHIVE} 1d4a7383baa63e7d1c423e8db2b791d5) - inference_quant_int8_image_classification_test(test_quant_int8_googlenet_mkldnn ${QUANT_GOOGLENET_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) - - # Quant MobileNetV1 - set(QUANT_MOBILENETV1_MODEL_DIR "${QUANT_INSTALL_DIR}/MobileNetV1_quant") - set(QUANT_MOBILENETV1_MODEL_ARCHIVE 
"MobileNetV1_qat_model.tar.gz") - download_quant_model(${QUANT_MOBILENETV1_MODEL_DIR} ${QUANT_MOBILENETV1_MODEL_ARCHIVE} 3b774d94a9fcbb604d09bdb731fc1162) - inference_quant_int8_image_classification_test(test_quant_int8_mobilenetv1_mkldnn ${QUANT_MOBILENETV1_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) - - # Quant MobileNetV2 - set(QUANT_MOBILENETV2_MODEL_DIR "${QUANT_INSTALL_DIR}/MobileNetV2_quant") - set(QUANT_MOBILENETV2_MODEL_ARCHIVE "MobileNetV2_qat_model.tar.gz") - download_quant_model(${QUANT_MOBILENETV2_MODEL_DIR} ${QUANT_MOBILENETV2_MODEL_ARCHIVE} 758a99d9225d8b73e1a8765883f96cdd) - inference_quant_int8_image_classification_test(test_quant_int8_mobilenetv2_mkldnn ${QUANT_MOBILENETV2_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) - - # Quant VGG16 - set(QUANT_VGG16_MODEL_DIR "${QUANT_INSTALL_DIR}/VGG16_quant") - set(QUANT_VGG16_MODEL_ARCHIVE "VGG16_qat_model.tar.gz") - download_quant_model(${QUANT_VGG16_MODEL_DIR} ${QUANT_VGG16_MODEL_ARCHIVE} c37e63ca82a102f47be266f8068b0b55) - # inference_quant_int8_image_classification_test(test_quant_int8_vgg16_mkldnn ${QUANT_VGG16_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) - - # Quant VGG19 - set(QUANT_VGG19_MODEL_DIR "${QUANT_INSTALL_DIR}/VGG19_quant") - set(QUANT_VGG19_MODEL_ARCHIVE "VGG19_qat_model.tar.gz") - download_quant_model(${QUANT_VGG19_MODEL_DIR} ${QUANT_VGG19_MODEL_ARCHIVE} 62bcd4b6c3ca2af67e8251d1c96ea18f) - # inference_quant_int8_image_classification_test(test_quant_int8_vgg19_mkldnn ${QUANT_VGG19_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) - - ### Quant2 for image classification - - # Quant2 ResNet50 with input/output scales in `fake_quantize_moving_average_abs_max` operators, - # with weight scales in `fake_dequantize_max_abs` operators - set(QUANT2_RESNET50_MODEL_DIR "${QUANT_INSTALL_DIR}/ResNet50_quant2") - set(QUANT2_RESNET50_MODEL_ARCHIVE "ResNet50_qat_perf.tar.gz") - download_quant_model(${QUANT2_RESNET50_MODEL_DIR} ${QUANT2_RESNET50_MODEL_ARCHIVE} e87309457e8c462a579340607f064d66) - set(FP32_RESNET50_MODEL_DIR "${INT8_INSTALL_DIR}/resnet50") - inference_quant2_int8_image_classification_test(test_quant2_int8_resnet50_mkldnn ${QUANT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) - - # Quant2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes, - # with weight scales in `fake_dequantize_max_abs` operators - set(QUANT2_RESNET50_RANGE_MODEL_DIR "${QUANT_INSTALL_DIR}/ResNet50_quant2_range") - set(QUANT2_RESNET50_RANGE_MODEL_ARCHIVE "ResNet50_qat_range.tar.gz") - download_quant_model(${QUANT2_RESNET50_RANGE_MODEL_DIR} ${QUANT2_RESNET50_RANGE_MODEL_ARCHIVE} 2fdc8a139f041c0d270abec826b2d304) - inference_quant2_int8_image_classification_test(test_quant2_int8_resnet50_range_mkldnn ${QUANT2_RESNET50_RANGE_MODEL_DIR}/ResNet50_qat_range ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) - - # Quant2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes, - # with weight scales in `fake_channel_wise_dequantize_max_abs` operators - set(QUANT2_RESNET50_CHANNELWISE_MODEL_DIR "${QUANT_INSTALL_DIR}/ResNet50_quant2_channelwise") - set(QUANT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE "ResNet50_qat_channelwise.tar.gz") - download_quant_model(${QUANT2_RESNET50_CHANNELWISE_MODEL_DIR} ${QUANT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE} 887a1b1b0e9a4efd10f263a43764db26) - inference_quant2_int8_image_classification_test(test_quant2_int8_resnet50_channelwise_mkldnn 
${QUANT2_RESNET50_CHANNELWISE_MODEL_DIR}/ResNet50_qat_channelwise ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) - - # Quant2 MobileNetV1 - set(QUANT2_MOBILENETV1_MODEL_DIR "${QUANT_INSTALL_DIR}/MobileNetV1_quant2") - set(QUANT2_MOBILENETV1_MODEL_ARCHIVE "MobileNet_qat_perf.tar.gz") - download_quant_model(${QUANT2_MOBILENETV1_MODEL_DIR} ${QUANT2_MOBILENETV1_MODEL_ARCHIVE} 7f626e453db2d56fed6c2538621ffacf) - set(FP32_MOBILENETV1_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv1") - inference_quant2_int8_image_classification_test(test_quant2_int8_mobilenetv1_mkldnn ${QUANT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf/float ${FP32_MOBILENETV1_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) - - ### Quant2 for NLP - - set(NLP_DATA_ARCHIVE "Ernie_dataset.tar.gz") - set(NLP_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie_dataset") - set(NLP_DATA_PATH "${NLP_DATA_DIR}/Ernie_dataset/1.8w.bs1") - set(NLP_LABLES_PATH "${NLP_DATA_DIR}/Ernie_dataset/label.xnli.dev") - download_quant_data(${NLP_DATA_DIR} ${NLP_DATA_ARCHIVE} e650ce0cbc1fadbed5cc2c01d4e734dc) - - # Quant2 Ernie - set(QUANT2_ERNIE_MODEL_ARCHIVE "ernie_qat.tar.gz") - set(QUANT2_ERNIE_MODEL_DIR "${QUANT_INSTALL_DIR}/Ernie_quant2") - download_quant_model(${QUANT2_ERNIE_MODEL_DIR} ${QUANT2_ERNIE_MODEL_ARCHIVE} f7cdf4720755ecf66efbc8044e9922d9) - set(FP32_ERNIE_MODEL_ARCHIVE "ernie_fp32_model.tar.gz") - set(FP32_ERNIE_MODEL_DIR "${QUANT_INSTALL_DIR}/Ernie_float") - download_quant_fp32_model(${FP32_ERNIE_MODEL_DIR} ${FP32_ERNIE_MODEL_ARCHIVE} 114f38804a3ef8c45e7259e68bbd838b) - set(QUANT2_ERNIE_OPS_TO_QUANTIZE "fc,reshape2,transpose2,matmul,elementwise_add,slice") - inference_quant2_int8_nlp_test(test_quant2_int8_ernie_mkldnn ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH} ${NLP_LABLES_PATH} ${QUANT2_ERNIE_OPS_TO_QUANTIZE}) - - # Quant2 GRU - set(QUANT2_GRU_MODEL_ARCHIVE "GRU_quant_acc.tar.gz") - set(QUANT2_GRU_MODEL_DIR "${QUANT_INSTALL_DIR}/GRU_quant2") - download_quant_model(${QUANT2_GRU_MODEL_DIR} ${QUANT2_GRU_MODEL_ARCHIVE} cf207f8076dcfb8b74d8b6bdddf9090c) - set(QUANT2_GRU_OPS_TO_QUANTIZE "multi_gru") - - # Quant2 LSTM - set(QUANT2_LSTM_MODEL_ARCHIVE "lstm_quant.tar.gz") - set(QUANT2_LSTM_MODEL_DIR "${QUANT_INSTALL_DIR}/lstm_quant_test") - download_quant_model(${QUANT2_LSTM_MODEL_DIR} ${QUANT2_LSTM_MODEL_ARCHIVE} 40a693803b12ee9e251258f32559abcb) - set(QUANT2_LSTM_OPS_TO_QUANTIZE "fusion_lstm") - - ### Save FP32 model or INT8 model from Quant model - - set(QUANT2_INT8_RESNET50_SAVE_PATH "${QUANT_INSTALL_DIR}/ResNet50_quant2_int8") - save_quant_ic_model_test(save_quant2_model_resnet50 ${QUANT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QUANT2_INT8_RESNET50_SAVE_PATH}) - - set(QUANT2_INT8_ERNIE_SAVE_PATH "${QUANT_INSTALL_DIR}/Ernie_quant2_int8") - save_quant_nlp_model_test(save_quant2_model_ernie ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QUANT2_INT8_ERNIE_SAVE_PATH} ${QUANT2_ERNIE_OPS_TO_QUANTIZE}) - - set(QUANT2_INT8_GRU_SAVE_PATH "${QUANT_INSTALL_DIR}/GRU_quant2_int8") - save_quant_nlp_model_test(save_quant2_model_gru ${QUANT2_GRU_MODEL_DIR}/GRU_quant_acc ${QUANT2_INT8_GRU_SAVE_PATH} ${QUANT2_GRU_OPS_TO_QUANTIZE}) - - set(QUANT2_INT8_LSTM_SAVE_PATH "${QUANT_INSTALL_DIR}/lstm_quant2_int8") - save_quant_nlp_model_test(save_quant2_model_lstm ${QUANT2_LSTM_MODEL_DIR}/lstm_quant ${QUANT2_INT8_LSTM_SAVE_PATH} ${QUANT2_LSTM_OPS_TO_QUANTIZE}) - - # Convert Quant2 model to dot and pdf files - set(QUANT2_INT8_ERNIE_DOT_SAVE_PATH "${QUANT_INSTALL_DIR}/Ernie_quant2_int8_dot_file") - 
convert_model2dot_test(convert_model2dot_ernie ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QUANT2_INT8_ERNIE_DOT_SAVE_PATH} "Ernie_quant2_int8") - - ### PTQ INT8 - - # PTQ int8 lstm model - set(LSTM_DATA_FILE "quant_lstm_input_data.tar.gz") - set(LSTM_URL "${INFERENCE_URL}/int8/unittest_model_data") - download_data(${QUANT2_INT8_LSTM_SAVE_PATH} ${LSTM_URL} ${LSTM_DATA_FILE} add84c754e9b792fea1fbd728d134ab7) - set(QUANT2_FP32_LSTM_MODEL_ARCHIVE "lstm_fp32_model.tar.gz") - download_lstm_model(${QUANT2_INT8_LSTM_SAVE_PATH} ${QUANT2_FP32_LSTM_MODEL_ARCHIVE} eecd9f44d69a84acc1cf2235c4b8b743) - inference_quant2_int8_lstm_model_test(test_quant2_int8_lstm_mkldnn ${QUANT2_INT8_LSTM_SAVE_PATH}/lstm_fp32_model ${QUANT2_LSTM_MODEL_DIR}/lstm_quant ${QUANT2_INT8_LSTM_SAVE_PATH}/quant_lstm_input_data) + #### Image classification dataset: ImageNet (small) + # The dataset should already be downloaded for INT8v2 unit tests + set(IMAGENET_DATA_PATH "${INFERENCE_DEMO_INSTALL_DIR}/imagenet/data.bin") + + #### INT8 image classification python api test + # Models should be already downloaded for INT8v2 unit tests + + set(INT8_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2") + + #### QUANT & INT8 comparison python api tests + + set(QUANT_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/quant") + + ### Quant1 for image classification + + # Quant ResNet50 + set(QUANT_RESNET50_MODEL_DIR "${QUANT_INSTALL_DIR}/ResNet50_quant") + set(QUANT_RESNET50_MODEL_ARCHIVE "ResNet50_qat_model.tar.gz") + download_quant_model( + ${QUANT_RESNET50_MODEL_DIR} ${QUANT_RESNET50_MODEL_ARCHIVE} + ff89b934ab961c3a4a844193ece2e8a7) + inference_quant_int8_image_classification_test( + test_quant_int8_resnet50_mkldnn ${QUANT_RESNET50_MODEL_DIR}/model + ${IMAGENET_DATA_PATH}) + + # Quant ResNet101 + set(QUANT_RESNET101_MODEL_DIR "${QUANT_INSTALL_DIR}/ResNet101_quant") + set(QUANT_RESNET101_MODEL_ARCHIVE "ResNet101_qat_model.tar.gz") + download_quant_model( + ${QUANT_RESNET101_MODEL_DIR} ${QUANT_RESNET101_MODEL_ARCHIVE} + 95c6d01e3aeba31c13efb2ba8057d558) + # inference_quant_int8_image_classification_test(test_quant_int8_resnet101_mkldnn ${QUANT_RESNET101_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) + + # Quant GoogleNet + set(QUANT_GOOGLENET_MODEL_DIR "${QUANT_INSTALL_DIR}/GoogleNet_quant") + set(QUANT_GOOGLENET_MODEL_ARCHIVE "GoogleNet_qat_model.tar.gz") + download_quant_model( + ${QUANT_GOOGLENET_MODEL_DIR} ${QUANT_GOOGLENET_MODEL_ARCHIVE} + 1d4a7383baa63e7d1c423e8db2b791d5) + inference_quant_int8_image_classification_test( + test_quant_int8_googlenet_mkldnn ${QUANT_GOOGLENET_MODEL_DIR}/model + ${IMAGENET_DATA_PATH}) + + # Quant MobileNetV1 + set(QUANT_MOBILENETV1_MODEL_DIR "${QUANT_INSTALL_DIR}/MobileNetV1_quant") + set(QUANT_MOBILENETV1_MODEL_ARCHIVE "MobileNetV1_qat_model.tar.gz") + download_quant_model( + ${QUANT_MOBILENETV1_MODEL_DIR} ${QUANT_MOBILENETV1_MODEL_ARCHIVE} + 3b774d94a9fcbb604d09bdb731fc1162) + inference_quant_int8_image_classification_test( + test_quant_int8_mobilenetv1_mkldnn ${QUANT_MOBILENETV1_MODEL_DIR}/model + ${IMAGENET_DATA_PATH}) + + # Quant MobileNetV2 + set(QUANT_MOBILENETV2_MODEL_DIR "${QUANT_INSTALL_DIR}/MobileNetV2_quant") + set(QUANT_MOBILENETV2_MODEL_ARCHIVE "MobileNetV2_qat_model.tar.gz") + download_quant_model( + ${QUANT_MOBILENETV2_MODEL_DIR} ${QUANT_MOBILENETV2_MODEL_ARCHIVE} + 758a99d9225d8b73e1a8765883f96cdd) + inference_quant_int8_image_classification_test( + test_quant_int8_mobilenetv2_mkldnn ${QUANT_MOBILENETV2_MODEL_DIR}/model + ${IMAGENET_DATA_PATH}) + + # Quant VGG16 + set(QUANT_VGG16_MODEL_DIR 
"${QUANT_INSTALL_DIR}/VGG16_quant") + set(QUANT_VGG16_MODEL_ARCHIVE "VGG16_qat_model.tar.gz") + download_quant_model(${QUANT_VGG16_MODEL_DIR} ${QUANT_VGG16_MODEL_ARCHIVE} + c37e63ca82a102f47be266f8068b0b55) + # inference_quant_int8_image_classification_test(test_quant_int8_vgg16_mkldnn ${QUANT_VGG16_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) + + # Quant VGG19 + set(QUANT_VGG19_MODEL_DIR "${QUANT_INSTALL_DIR}/VGG19_quant") + set(QUANT_VGG19_MODEL_ARCHIVE "VGG19_qat_model.tar.gz") + download_quant_model(${QUANT_VGG19_MODEL_DIR} ${QUANT_VGG19_MODEL_ARCHIVE} + 62bcd4b6c3ca2af67e8251d1c96ea18f) + # inference_quant_int8_image_classification_test(test_quant_int8_vgg19_mkldnn ${QUANT_VGG19_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) + + ### Quant2 for image classification + + # Quant2 ResNet50 with input/output scales in `fake_quantize_moving_average_abs_max` operators, + # with weight scales in `fake_dequantize_max_abs` operators + set(QUANT2_RESNET50_MODEL_DIR "${QUANT_INSTALL_DIR}/ResNet50_quant2") + set(QUANT2_RESNET50_MODEL_ARCHIVE "ResNet50_qat_perf.tar.gz") + download_quant_model( + ${QUANT2_RESNET50_MODEL_DIR} ${QUANT2_RESNET50_MODEL_ARCHIVE} + e87309457e8c462a579340607f064d66) + set(FP32_RESNET50_MODEL_DIR "${INT8_INSTALL_DIR}/resnet50") + inference_quant2_int8_image_classification_test( + test_quant2_int8_resnet50_mkldnn + ${QUANT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float + ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) + + # Quant2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes, + # with weight scales in `fake_dequantize_max_abs` operators + set(QUANT2_RESNET50_RANGE_MODEL_DIR + "${QUANT_INSTALL_DIR}/ResNet50_quant2_range") + set(QUANT2_RESNET50_RANGE_MODEL_ARCHIVE "ResNet50_qat_range.tar.gz") + download_quant_model( + ${QUANT2_RESNET50_RANGE_MODEL_DIR} ${QUANT2_RESNET50_RANGE_MODEL_ARCHIVE} + 2fdc8a139f041c0d270abec826b2d304) + inference_quant2_int8_image_classification_test( + test_quant2_int8_resnet50_range_mkldnn + ${QUANT2_RESNET50_RANGE_MODEL_DIR}/ResNet50_qat_range + ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) + + # Quant2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes, + # with weight scales in `fake_channel_wise_dequantize_max_abs` operators + set(QUANT2_RESNET50_CHANNELWISE_MODEL_DIR + "${QUANT_INSTALL_DIR}/ResNet50_quant2_channelwise") + set(QUANT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE + "ResNet50_qat_channelwise.tar.gz") + download_quant_model( + ${QUANT2_RESNET50_CHANNELWISE_MODEL_DIR} + ${QUANT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE} + 887a1b1b0e9a4efd10f263a43764db26) + inference_quant2_int8_image_classification_test( + test_quant2_int8_resnet50_channelwise_mkldnn + ${QUANT2_RESNET50_CHANNELWISE_MODEL_DIR}/ResNet50_qat_channelwise + ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) + + # Quant2 MobileNetV1 + set(QUANT2_MOBILENETV1_MODEL_DIR "${QUANT_INSTALL_DIR}/MobileNetV1_quant2") + set(QUANT2_MOBILENETV1_MODEL_ARCHIVE "MobileNet_qat_perf.tar.gz") + download_quant_model( + ${QUANT2_MOBILENETV1_MODEL_DIR} ${QUANT2_MOBILENETV1_MODEL_ARCHIVE} + 7f626e453db2d56fed6c2538621ffacf) + set(FP32_MOBILENETV1_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv1") + inference_quant2_int8_image_classification_test( + test_quant2_int8_mobilenetv1_mkldnn + ${QUANT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf/float + ${FP32_MOBILENETV1_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) + + ### Quant2 for NLP + + set(NLP_DATA_ARCHIVE "Ernie_dataset.tar.gz") + 
set(NLP_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie_dataset") + set(NLP_DATA_PATH "${NLP_DATA_DIR}/Ernie_dataset/1.8w.bs1") + set(NLP_LABLES_PATH "${NLP_DATA_DIR}/Ernie_dataset/label.xnli.dev") + download_quant_data(${NLP_DATA_DIR} ${NLP_DATA_ARCHIVE} + e650ce0cbc1fadbed5cc2c01d4e734dc) + + # Quant2 Ernie + set(QUANT2_ERNIE_MODEL_ARCHIVE "ernie_qat.tar.gz") + set(QUANT2_ERNIE_MODEL_DIR "${QUANT_INSTALL_DIR}/Ernie_quant2") + download_quant_model(${QUANT2_ERNIE_MODEL_DIR} ${QUANT2_ERNIE_MODEL_ARCHIVE} + f7cdf4720755ecf66efbc8044e9922d9) + set(FP32_ERNIE_MODEL_ARCHIVE "ernie_fp32_model.tar.gz") + set(FP32_ERNIE_MODEL_DIR "${QUANT_INSTALL_DIR}/Ernie_float") + download_quant_fp32_model(${FP32_ERNIE_MODEL_DIR} ${FP32_ERNIE_MODEL_ARCHIVE} + 114f38804a3ef8c45e7259e68bbd838b) + set(QUANT2_ERNIE_OPS_TO_QUANTIZE + "fc,reshape2,transpose2,matmul,elementwise_add,slice") + inference_quant2_int8_nlp_test( + test_quant2_int8_ernie_mkldnn ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float + ${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH} + ${NLP_LABLES_PATH} ${QUANT2_ERNIE_OPS_TO_QUANTIZE}) + + # Quant2 GRU + set(QUANT2_GRU_MODEL_ARCHIVE "GRU_quant_acc.tar.gz") + set(QUANT2_GRU_MODEL_DIR "${QUANT_INSTALL_DIR}/GRU_quant2") + download_quant_model(${QUANT2_GRU_MODEL_DIR} ${QUANT2_GRU_MODEL_ARCHIVE} + cf207f8076dcfb8b74d8b6bdddf9090c) + set(QUANT2_GRU_OPS_TO_QUANTIZE "multi_gru") + + # Quant2 LSTM + set(QUANT2_LSTM_MODEL_ARCHIVE "lstm_quant.tar.gz") + set(QUANT2_LSTM_MODEL_DIR "${QUANT_INSTALL_DIR}/lstm_quant_test") + download_quant_model(${QUANT2_LSTM_MODEL_DIR} ${QUANT2_LSTM_MODEL_ARCHIVE} + 40a693803b12ee9e251258f32559abcb) + set(QUANT2_LSTM_OPS_TO_QUANTIZE "fusion_lstm") + + ### Save FP32 model or INT8 model from Quant model + + set(QUANT2_INT8_RESNET50_SAVE_PATH + "${QUANT_INSTALL_DIR}/ResNet50_quant2_int8") + save_quant_ic_model_test( + save_quant2_model_resnet50 + ${QUANT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float + ${QUANT2_INT8_RESNET50_SAVE_PATH}) + + set(QUANT2_INT8_ERNIE_SAVE_PATH "${QUANT_INSTALL_DIR}/Ernie_quant2_int8") + save_quant_nlp_model_test( + save_quant2_model_ernie ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float + ${QUANT2_INT8_ERNIE_SAVE_PATH} ${QUANT2_ERNIE_OPS_TO_QUANTIZE}) + + set(QUANT2_INT8_GRU_SAVE_PATH "${QUANT_INSTALL_DIR}/GRU_quant2_int8") + save_quant_nlp_model_test( + save_quant2_model_gru ${QUANT2_GRU_MODEL_DIR}/GRU_quant_acc + ${QUANT2_INT8_GRU_SAVE_PATH} ${QUANT2_GRU_OPS_TO_QUANTIZE}) + + set(QUANT2_INT8_LSTM_SAVE_PATH "${QUANT_INSTALL_DIR}/lstm_quant2_int8") + save_quant_nlp_model_test( + save_quant2_model_lstm ${QUANT2_LSTM_MODEL_DIR}/lstm_quant + ${QUANT2_INT8_LSTM_SAVE_PATH} ${QUANT2_LSTM_OPS_TO_QUANTIZE}) + + # Convert Quant2 model to dot and pdf files + set(QUANT2_INT8_ERNIE_DOT_SAVE_PATH + "${QUANT_INSTALL_DIR}/Ernie_quant2_int8_dot_file") + convert_model2dot_test( + convert_model2dot_ernie ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float + ${QUANT2_INT8_ERNIE_DOT_SAVE_PATH} "Ernie_quant2_int8") + + ### PTQ INT8 + + # PTQ int8 lstm model + set(LSTM_DATA_FILE "quant_lstm_input_data.tar.gz") + set(LSTM_URL "${INFERENCE_URL}/int8/unittest_model_data") + download_data(${QUANT2_INT8_LSTM_SAVE_PATH} ${LSTM_URL} ${LSTM_DATA_FILE} + add84c754e9b792fea1fbd728d134ab7) + set(QUANT2_FP32_LSTM_MODEL_ARCHIVE "lstm_fp32_model.tar.gz") + download_lstm_model( + ${QUANT2_INT8_LSTM_SAVE_PATH} ${QUANT2_FP32_LSTM_MODEL_ARCHIVE} + eecd9f44d69a84acc1cf2235c4b8b743) + inference_quant2_int8_lstm_model_test( + test_quant2_int8_lstm_mkldnn ${QUANT2_INT8_LSTM_SAVE_PATH}/lstm_fp32_model + 
${QUANT2_LSTM_MODEL_DIR}/lstm_quant + ${QUANT2_INT8_LSTM_SAVE_PATH}/quant_lstm_input_data) endif() -# Since the tests for Quant & INT8 comparison support only testing on Linux +# Since the tests for Quant & INT8 comparison support only testing on Linux # with MKL-DNN, we remove it here to not test it on other systems. -list(REMOVE_ITEM TEST_OPS - test_mkldnn_int8_quantization_strategy - quant_int8_image_classification_comparison - quant_int8_nlp_comparison) +list(REMOVE_ITEM TEST_OPS test_mkldnn_int8_quantization_strategy + quant_int8_image_classification_comparison quant_int8_nlp_comparison) #TODO(wanghaoshuang): Fix this unitest failed on GCC8. -LIST(REMOVE_ITEM TEST_OPS test_auto_pruning) -LIST(REMOVE_ITEM TEST_OPS test_filter_pruning) - +list(REMOVE_ITEM TEST_OPS test_auto_pruning) +list(REMOVE_ITEM TEST_OPS test_filter_pruning) + # fix if(WIN32) - SET(SINGLE_CARD_TEST_OPS - test_user_defined_quantization - test_quantization_scale_pass - test_quantization_pass - test_moving_average_abs_max_scale_op - test_imperative_qat_channelwise - test_imperative_qat - test_imperative_out_scale - test_graph) - LIST(REMOVE_ITEM TEST_OPS ${SINGLE_CARD_TEST_OPS}) - foreach(src ${SINGLE_CARD_TEST_OPS}) - py_test(${src} SRCS ${src}.py ENVS CUDA_VISIBLE_DEVICES=0) - endforeach() + set(SINGLE_CARD_TEST_OPS + test_user_defined_quantization + test_quantization_scale_pass + test_quantization_pass + test_moving_average_abs_max_scale_op + test_imperative_qat_channelwise + test_imperative_qat + test_imperative_out_scale + test_graph) + list(REMOVE_ITEM TEST_OPS ${SINGLE_CARD_TEST_OPS}) + foreach(src ${SINGLE_CARD_TEST_OPS}) + py_test(${src} SRCS ${src}.py ENVS CUDA_VISIBLE_DEVICES=0) + endforeach() endif() - foreach(src ${TEST_OPS}) - py_test(${src} SRCS ${src}.py) + py_test(${src} SRCS ${src}.py) endforeach() # setting timeout value for old unittests if(NOT WIN32) - set_tests_properties(test_post_training_quantization_lstm_model PROPERTIES TIMEOUT 120) - set_tests_properties(test_post_training_quantization_mobilenetv1 PROPERTIES TIMEOUT 600 LABELS "RUN_TYPE=NIGHTLY") - set_tests_properties(test_post_training_quantization_resnet50 PROPERTIES TIMEOUT 600 LABELS "RUN_TYPE=NIGHTLY") - set_tests_properties(test_post_training_quantization_mnist PROPERTIES TIMEOUT 120) - set_tests_properties(test_post_training_quantization_while PROPERTIES TIMEOUT 120) - set_tests_properties(test_imperative_ptq PROPERTIES TIMEOUT 120) - set_tests_properties(test_weight_quantization_mobilenetv1 PROPERTIES TIMEOUT 120) + set_tests_properties(test_post_training_quantization_lstm_model + PROPERTIES TIMEOUT 120) + set_tests_properties(test_post_training_quantization_mobilenetv1 + PROPERTIES TIMEOUT 600 LABELS "RUN_TYPE=NIGHTLY") + set_tests_properties(test_post_training_quantization_resnet50 + PROPERTIES TIMEOUT 600 LABELS "RUN_TYPE=NIGHTLY") + set_tests_properties(test_post_training_quantization_mnist PROPERTIES TIMEOUT + 120) + set_tests_properties(test_post_training_quantization_while PROPERTIES TIMEOUT + 120) + set_tests_properties(test_imperative_ptq PROPERTIES TIMEOUT 120) + set_tests_properties(test_weight_quantization_mobilenetv1 PROPERTIES TIMEOUT + 120) endif() set_tests_properties(test_graph PROPERTIES TIMEOUT 120) @@ -359,23 +530,30 @@ set_tests_properties(test_imperative_out_scale PROPERTIES TIMEOUT 200) set_tests_properties(test_imperative_qat_user_defined PROPERTIES TIMEOUT 200) if(LINUX AND WITH_MKLDNN) - set_tests_properties(test_quant2_int8_mobilenetv1_mkldnn PROPERTIES TIMEOUT 120) - 
set_tests_properties(convert_model2dot_ernie PROPERTIES TIMEOUT 120) - set_tests_properties(test_quant2_int8_resnet50_channelwise_mkldnn PROPERTIES TIMEOUT 120) - set_tests_properties(test_quant_int8_mobilenetv2_mkldnn PROPERTIES TIMEOUT 120) - set_tests_properties(test_quant2_int8_resnet50_range_mkldnn PROPERTIES TIMEOUT 120) - set_tests_properties(save_quant2_model_resnet50 PROPERTIES TIMEOUT 120) - set_tests_properties(test_quant_int8_resnet50_mkldnn PROPERTIES TIMEOUT 120) - set_tests_properties(test_quant_int8_mobilenetv1_mkldnn PROPERTIES TIMEOUT 120) - set_tests_properties(test_quant2_int8_ernie_mkldnn PROPERTIES TIMEOUT 120) - set_tests_properties(test_quant_int8_googlenet_mkldnn PROPERTIES TIMEOUT 120) - set_tests_properties(test_quant2_int8_resnet50_mkldnn PROPERTIES TIMEOUT 120) - set_tests_properties(test_quant2_int8_lstm_mkldnn PROPERTIES TIMEOUT 120) + set_tests_properties(test_quant2_int8_mobilenetv1_mkldnn PROPERTIES TIMEOUT + 120) + set_tests_properties(convert_model2dot_ernie PROPERTIES TIMEOUT 120) + set_tests_properties(test_quant2_int8_resnet50_channelwise_mkldnn + PROPERTIES TIMEOUT 120) + set_tests_properties(test_quant_int8_mobilenetv2_mkldnn PROPERTIES TIMEOUT + 120) + set_tests_properties(test_quant2_int8_resnet50_range_mkldnn PROPERTIES TIMEOUT + 120) + set_tests_properties(save_quant2_model_resnet50 PROPERTIES TIMEOUT 120) + set_tests_properties(test_quant_int8_resnet50_mkldnn PROPERTIES TIMEOUT 120) + set_tests_properties(test_quant_int8_mobilenetv1_mkldnn PROPERTIES TIMEOUT + 120) + set_tests_properties(test_quant2_int8_ernie_mkldnn PROPERTIES TIMEOUT 120) + set_tests_properties(test_quant_int8_googlenet_mkldnn PROPERTIES TIMEOUT 120) + set_tests_properties(test_quant2_int8_resnet50_mkldnn PROPERTIES TIMEOUT 120) + set_tests_properties(test_quant2_int8_lstm_mkldnn PROPERTIES TIMEOUT 120) endif() if(APPLE) - set_tests_properties(test_post_training_quantization_mnist PROPERTIES TIMEOUT 300) - set_tests_properties(test_post_training_quantization_while PROPERTIES TIMEOUT 300) - set_tests_properties(test_imperative_ptq PROPERTIES TIMEOUT 300) - set_tests_properties(test_imperative_skip_op PROPERTIES TIMEOUT 300) + set_tests_properties(test_post_training_quantization_mnist PROPERTIES TIMEOUT + 300) + set_tests_properties(test_post_training_quantization_while PROPERTIES TIMEOUT + 300) + set_tests_properties(test_imperative_ptq PROPERTIES TIMEOUT 300) + set_tests_properties(test_imperative_skip_op PROPERTIES TIMEOUT 300) endif() diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py index 0715fcf2a8b..2c56f9ad53d 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py @@ -35,8 +35,9 @@ from paddle.fluid.framework import _test_eager_guard from imperative_test_utils import fix_model_dict, ImperativeLenet, ImperativeLinearBn from imperative_test_utils import ImperativeLinearBn_hook -_logger = get_logger( - __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') +_logger = get_logger(__name__, + logging.INFO, + fmt='%(asctime)s-%(levelname)s: %(message)s') class TestFuseLinearBn(unittest.TestCase): @@ -55,15 +56,15 @@ class TestFuseLinearBn(unittest.TestCase): quant_h = ptq.quantize(model_h, fuse=True, fuse_list=f_l) for name, layer in quant_model.named_sublayers(): if name in f_l: - assert not (isinstance(layer, nn.BatchNorm1D) or - isinstance(layer, nn.BatchNorm2D)) + assert not 
(isinstance(layer, nn.BatchNorm1D) + or isinstance(layer, nn.BatchNorm2D)) out = model(inputs) out_h = model_h(inputs) out_quant = quant_model(inputs) out_quant_h = quant_h(inputs) cos_sim_func = nn.CosineSimilarity(axis=0) - print('fuse linear+bn', - cos_sim_func(out.flatten(), out_quant.flatten())) + print('fuse linear+bn', cos_sim_func(out.flatten(), + out_quant.flatten())) print(cos_sim_func(out_h.flatten(), out_quant_h.flatten())) @@ -87,8 +88,8 @@ class TestImperativePTQ(unittest.TestCase): def cache_unzipping(self, target_folder, zip_path): if not os.path.exists(target_folder): - cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder, - zip_path) + cmd = 'mkdir {0} && tar xf {1} -C {0}'.format( + target_folder, zip_path) os.system(cmd) def download_model(self, data_url, data_md5, folder_name): @@ -123,8 +124,8 @@ class TestImperativePTQ(unittest.TestCase): def model_test(self, model, batch_num=-1, batch_size=8): model.eval() - test_reader = paddle.batch( - paddle.dataset.mnist.test(), batch_size=batch_size) + test_reader = paddle.batch(paddle.dataset.mnist.test(), + batch_size=batch_size) eval_acc_top1_list = [] for batch_id, data in enumerate(test_reader()): @@ -157,8 +158,8 @@ class TestImperativePTQ(unittest.TestCase): [inference_program, feed_target_names, fetch_targets ] = (paddle.static.load_inference_model(program_path, exe)) - test_reader = paddle.batch( - paddle.dataset.mnist.test(), batch_size=batch_size) + test_reader = paddle.batch(paddle.dataset.mnist.test(), + batch_size=batch_size) top1_correct_num = 0. total_num = 0. @@ -203,13 +204,13 @@ class TestImperativePTQ(unittest.TestCase): self.batch_size) input_spec = [ - paddle.static.InputSpec( - shape=[None, 1, 28, 28], dtype='float32') + paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype='float32') ] with tempfile.TemporaryDirectory(prefix="imperative_ptq_") as tmpdir: save_path = os.path.join(tmpdir, "model") - self.ptq.save_quantized_model( - model=quant_model, path=save_path, input_spec=input_spec) + self.ptq.save_quantized_model(model=quant_model, + path=save_path, + input_spec=input_spec) print('Quantized model saved in {%s}' % save_path) after_acc_top1 = self.model_test(quant_model, self.batch_num, @@ -225,13 +226,11 @@ class TestImperativePTQ(unittest.TestCase): print('After converted acc_top1: %s' % after_acc_top1) print('Infer acc_top1: %s' % infer_acc_top1) - self.assertTrue( - after_acc_top1 >= self.eval_acc_top1, - msg="The test acc {%f} is less than {%f}." % - (after_acc_top1, self.eval_acc_top1)) - self.assertTrue( - infer_acc_top1 >= after_acc_top1, - msg='The acc is lower after converting model.') + self.assertTrue(after_acc_top1 >= self.eval_acc_top1, + msg="The test acc {%f} is less than {%f}." 
% + (after_acc_top1, self.eval_acc_top1)) + self.assertTrue(infer_acc_top1 >= after_acc_top1, + msg='The acc is lower after converting model.') end_time = time.time() print("total time: %ss \n" % (end_time - start_time)) @@ -243,6 +242,7 @@ class TestImperativePTQ(unittest.TestCase): class TestImperativePTQfuse(TestImperativePTQ): + def func_ptq(self): start_time = time.time() @@ -261,19 +261,19 @@ class TestImperativePTQfuse(TestImperativePTQ): quant_model = self.ptq.quantize(model, fuse=True, fuse_list=f_l) for name, layer in quant_model.named_sublayers(): if name in f_l: - assert not (isinstance(layer, nn.BatchNorm1D) or - isinstance(layer, nn.BatchNorm2D)) + assert not (isinstance(layer, nn.BatchNorm1D) + or isinstance(layer, nn.BatchNorm2D)) before_acc_top1 = self.model_test(quant_model, self.batch_num, self.batch_size) input_spec = [ - paddle.static.InputSpec( - shape=[None, 1, 28, 28], dtype='float32') + paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype='float32') ] with tempfile.TemporaryDirectory(prefix="imperative_ptq_") as tmpdir: save_path = os.path.join(tmpdir, "model") - self.ptq.save_quantized_model( - model=quant_model, path=save_path, input_spec=input_spec) + self.ptq.save_quantized_model(model=quant_model, + path=save_path, + input_spec=input_spec) print('Quantized model saved in {%s}' % save_path) after_acc_top1 = self.model_test(quant_model, self.batch_num, @@ -291,15 +291,13 @@ class TestImperativePTQfuse(TestImperativePTQ): #Check whether the quant_model is correct after converting. #The acc of quantized model should be higher than 0.95. - self.assertTrue( - after_acc_top1 >= self.eval_acc_top1, - msg="The test acc {%f} is less than {%f}." % - (after_acc_top1, self.eval_acc_top1)) + self.assertTrue(after_acc_top1 >= self.eval_acc_top1, + msg="The test acc {%f} is less than {%f}." % + (after_acc_top1, self.eval_acc_top1)) #Check the saved infer_model.The acc of infer model #should not be lower than the one of dygraph model. 
- self.assertTrue( - infer_acc_top1 >= after_acc_top1, - msg='The acc is lower after converting model.') + self.assertTrue(infer_acc_top1 >= after_acc_top1, + msg='The acc is lower after converting model.') end_time = time.time() print("total time: %ss \n" % (end_time - start_time)) @@ -311,6 +309,7 @@ class TestImperativePTQfuse(TestImperativePTQ): class TestImperativePTQHist(TestImperativePTQ): + def set_vars(self): config = PTQConfig(HistQuantizer(), AbsmaxQuantizer()) self.ptq = ImperativePTQ(config) @@ -332,13 +331,14 @@ class TestImperativePTQHist(TestImperativePTQ): class TestImperativePTQKL(TestImperativePTQ): + def set_vars(self): config = PTQConfig(KLQuantizer(), PerChannelAbsmaxQuantizer()) self.ptq = ImperativePTQ(config) self.batch_num = 10 self.batch_size = 10 - self.eval_acc_top1 = 1.0 + self.eval_acc_top1 = 0.98 conv2d_1_wt_thresholds = [ 0.18116560578346252, 0.17079241573810577, 0.1702047884464264, diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py index 4ea51233e40..6100ed4f82a 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py @@ -34,6 +34,7 @@ np.random.seed(0) class TestPostTrainingQuantization(unittest.TestCase): + def setUp(self): self.download_path = 'int8/download' self.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' + @@ -44,8 +45,8 @@ class TestPostTrainingQuantization(unittest.TestCase): try: os.system("mkdir -p " + self.int8_model_path) except Exception as e: - print("Failed to create {} due to {}".format(self.int8_model_path, - str(e))) + print("Failed to create {} due to {}".format( + self.int8_model_path, str(e))) sys.exit(-1) def tearDown(self): @@ -53,8 +54,8 @@ class TestPostTrainingQuantization(unittest.TestCase): def cache_unzipping(self, target_folder, zip_path): if not os.path.exists(target_folder): - cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder, - zip_path) + cmd = 'mkdir {0} && tar xf {1} -C {0}'.format( + target_folder, zip_path) os.system(cmd) def download_model(self, data_url, data_md5, folder_name): @@ -68,6 +69,7 @@ class TestPostTrainingQuantization(unittest.TestCase): return data_cache_folder def get_batch_reader(self, data_path, place): + def reader(): with open(data_path, 'rb') as in_file: while True: @@ -80,15 +82,14 @@ class TestPostTrainingQuantization(unittest.TestCase): seq_len = (alllen >> 16) & 0xFFFF label = in_file.read(4 * label_len) - label = np.frombuffer( - label, dtype=np.int32).reshape([len(label) // 4]) + label = np.frombuffer(label, dtype=np.int32).reshape( + [len(label) // 4]) if label.shape[0] != 1 or label[0] > 6350: continue feat = in_file.read(4 * seq_len * 8) - feat = np.frombuffer( - feat, - dtype=np.float32).reshape([len(feat) // 4 // 8, 8]) + feat = np.frombuffer(feat, dtype=np.float32).reshape( + [len(feat) // 4 // 8, 8]) lod_feat = [feat.shape[0]] minputs = fluid.create_lod_tensor(feat, [lod_feat], place) @@ -97,6 +98,7 @@ class TestPostTrainingQuantization(unittest.TestCase): return reader def get_simple_reader(self, data_path, place): + def reader(): with open(data_path, 'rb') as in_file: while True: @@ -109,15 +111,14 @@ class TestPostTrainingQuantization(unittest.TestCase): seq_len = (alllen >> 16) & 0xFFFF label = in_file.read(4 * label_len) - label = np.frombuffer( - label, dtype=np.int32).reshape([len(label) // 
4]) + label = np.frombuffer(label, dtype=np.int32).reshape( + [len(label) // 4]) if label.shape[0] != 1 or label[0] > 6350: continue feat = in_file.read(4 * seq_len * 8) - feat = np.frombuffer( - feat, - dtype=np.float32).reshape([len(feat) // 4 // 8, 8]) + feat = np.frombuffer(feat, dtype=np.float32).reshape( + [len(feat) // 4 // 8, 8]) lod_feat = [feat.shape[0]] minputs = fluid.create_lod_tensor(feat, [lod_feat], place) @@ -178,18 +179,17 @@ class TestPostTrainingQuantization(unittest.TestCase): scope = fluid.global_scope() batch_generator = self.get_batch_reader(data_path, place) - ptq = PostTrainingQuantization( - executor=exe, - model_dir=model_path, - batch_generator=batch_generator, - batch_nums=batch_nums, - algo=algo, - quantizable_op_type=quantizable_op_type, - round_type=round_type, - is_full_quantize=is_full_quantize, - optimize_model=is_optimize_model, - onnx_format=onnx_format, - is_use_cache_file=is_use_cache_file) + ptq = PostTrainingQuantization(executor=exe, + model_dir=model_path, + batch_generator=batch_generator, + batch_nums=batch_nums, + algo=algo, + quantizable_op_type=quantizable_op_type, + round_type=round_type, + is_full_quantize=is_full_quantize, + optimize_model=is_optimize_model, + onnx_format=onnx_format, + is_use_cache_file=is_use_cache_file) ptq.quantize() ptq.save_quantized_model(self.int8_model_path) @@ -223,10 +223,11 @@ class TestPostTrainingQuantization(unittest.TestCase): print("Start post training quantization for {0} on {1} samples ...". format(model_name, quant_iterations)) - self.generate_quantized_model( - fp32_model_path, data_path, algo, round_type, quantizable_op_type, - is_full_quantize, is_use_cache_file, is_optimize_model, - quant_iterations, onnx_format) + self.generate_quantized_model(fp32_model_path, data_path, algo, + round_type, quantizable_op_type, + is_full_quantize, is_use_cache_file, + is_optimize_model, quant_iterations, + onnx_format) print("Start INT8 inference for {0} on {1} samples ...".format( model_name, infer_iterations)) @@ -245,6 +246,7 @@ class TestPostTrainingQuantization(unittest.TestCase): class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization): + def test_post_training_avg(self): model_name = "nlp_lstm_fp32_model" model_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/nlp_lstm_fp32_model.tar.gz" @@ -268,6 +270,7 @@ class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization): class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization): + def test_post_training_avg_onnx_format(self): model_name = "nlp_lstm_fp32_model" model_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/nlp_lstm_fp32_model.tar.gz" @@ -285,23 +288,22 @@ class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization): infer_iterations = 100 quant_iterations = 10 onnx_format = True - self.run_test( - model_name, - model_url, - model_md5, - data_name, - data_url, - data_md5, - algo, - round_type, - quantizable_op_type, - is_full_quantize, - is_use_cache_file, - is_optimize_model, - diff_threshold, - infer_iterations, - quant_iterations, - onnx_format=onnx_format) + self.run_test(model_name, + model_url, + model_md5, + data_name, + data_url, + data_md5, + algo, + round_type, + quantizable_op_type, + is_full_quantize, + is_use_cache_file, + is_optimize_model, + diff_threshold, + infer_iterations, + quant_iterations, + onnx_format=onnx_format) if __name__ == '__main__': diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py 
b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py index ec1272a0480..807bdbf8a9a 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py @@ -33,6 +33,7 @@ np.random.seed(0) class TestPostTrainingQuantization(unittest.TestCase): + def setUp(self): self.root_path = tempfile.TemporaryDirectory() self.int8_model_path = os.path.join(self.root_path.name, @@ -43,8 +44,8 @@ class TestPostTrainingQuantization(unittest.TestCase): try: os.system("mkdir -p " + self.int8_model_path) except Exception as e: - print("Failed to create {} due to {}".format(self.int8_model_path, - str(e))) + print("Failed to create {} due to {}".format( + self.int8_model_path, str(e))) sys.exit(-1) def tearDown(self): @@ -52,8 +53,8 @@ class TestPostTrainingQuantization(unittest.TestCase): def cache_unzipping(self, target_folder, zip_path): if not os.path.exists(target_folder): - cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder, - zip_path) + cmd = 'mkdir {0} && tar xf {1} -C {0}'.format( + target_folder, zip_path) os.system(cmd) def download_model(self, data_url, data_md5, folder_name): @@ -115,26 +116,27 @@ class TestPostTrainingQuantization(unittest.TestCase): batch_size=10, batch_nums=10, onnx_format=False, - skip_tensor_list=None): + skip_tensor_list=None, + bias_correction=False): place = fluid.CPUPlace() exe = fluid.Executor(place) val_reader = paddle.dataset.mnist.train() - ptq = PostTrainingQuantization( - executor=exe, - model_dir=model_path, - sample_generator=val_reader, - batch_size=batch_size, - batch_nums=batch_nums, - algo=algo, - quantizable_op_type=quantizable_op_type, - round_type=round_type, - is_full_quantize=is_full_quantize, - optimize_model=is_optimize_model, - onnx_format=onnx_format, - skip_tensor_list=skip_tensor_list, - is_use_cache_file=is_use_cache_file) + ptq = PostTrainingQuantization(executor=exe, + model_dir=model_path, + sample_generator=val_reader, + batch_size=batch_size, + batch_nums=batch_nums, + algo=algo, + quantizable_op_type=quantizable_op_type, + round_type=round_type, + is_full_quantize=is_full_quantize, + optimize_model=is_optimize_model, + bias_correction=bias_correction, + onnx_format=onnx_format, + skip_tensor_list=skip_tensor_list, + is_use_cache_file=is_use_cache_file) ptq.quantize() ptq.save_quantized_model(self.int8_model_path) @@ -152,6 +154,7 @@ class TestPostTrainingQuantization(unittest.TestCase): batch_size=10, infer_iterations=10, quant_iterations=5, + bias_correction=False, onnx_format=False, skip_tensor_list=None): @@ -160,20 +163,23 @@ class TestPostTrainingQuantization(unittest.TestCase): print("Start FP32 inference for {0} on {1} images ...".format( model_name, infer_iterations * batch_size)) - (fp32_throughput, fp32_latency, fp32_acc1) = self.run_program( - origin_model_path, batch_size, infer_iterations) + (fp32_throughput, fp32_latency, + fp32_acc1) = self.run_program(origin_model_path, batch_size, + infer_iterations) print("Start INT8 post training quantization for {0} on {1} images ...". 
format(model_name, quant_iterations * batch_size)) - self.generate_quantized_model( - origin_model_path, algo, round_type, quantizable_op_type, - is_full_quantize, is_use_cache_file, is_optimize_model, batch_size, - quant_iterations, onnx_format, skip_tensor_list) + self.generate_quantized_model(origin_model_path, algo, round_type, + quantizable_op_type, is_full_quantize, + is_use_cache_file, is_optimize_model, + batch_size, quant_iterations, onnx_format, + skip_tensor_list, bias_correction) print("Start INT8 inference for {0} on {1} images ...".format( model_name, infer_iterations * batch_size)) - (int8_throughput, int8_latency, int8_acc1) = self.run_program( - self.int8_model_path, batch_size, infer_iterations) + (int8_throughput, int8_latency, + int8_acc1) = self.run_program(self.int8_model_path, batch_size, + infer_iterations) print("---Post training quantization of {} method---".format(algo)) print( @@ -191,6 +197,7 @@ class TestPostTrainingQuantization(unittest.TestCase): class TestPostTrainingKLForMnist(TestPostTrainingQuantization): + def test_post_training_kl(self): model_name = "mnist_model" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" @@ -212,6 +219,7 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization): class TestPostTraininghistForMnist(TestPostTrainingQuantization): + def test_post_training_hist(self): model_name = "mnist_model" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" @@ -233,6 +241,7 @@ class TestPostTraininghistForMnist(TestPostTrainingQuantization): class TestPostTrainingmseForMnist(TestPostTrainingQuantization): + def test_post_training_mse(self): model_name = "mnist_model" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" @@ -254,6 +263,7 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization): class TestPostTrainingemdForMnist(TestPostTrainingQuantization): + def test_post_training_mse(self): model_name = "mnist_model" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" @@ -275,6 +285,7 @@ class TestPostTrainingemdForMnist(TestPostTrainingQuantization): class TestPostTrainingavgForMnist(TestPostTrainingQuantization): + def test_post_training_avg(self): model_name = "mnist_model" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" @@ -296,6 +307,7 @@ class TestPostTrainingavgForMnist(TestPostTrainingQuantization): class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization): + def test_post_training_abs_max(self): model_name = "mnist_model" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" @@ -317,6 +329,7 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization): class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization): + def test_post_training_mse(self): model_name = "mnist_model" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" @@ -331,13 +344,25 @@ class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization): batch_size = 10 infer_iterations = 50 quant_iterations = 5 - self.run_test(model_name, data_url, data_md5, algo, round_type, - quantizable_op_type, is_full_quantize, is_use_cache_file, - is_optimize_model, diff_threshold, batch_size, - infer_iterations, quant_iterations) + bias_correction = True + self.run_test(model_name, + data_url, + data_md5, + algo, + round_type, + quantizable_op_type, + is_full_quantize, + is_use_cache_file, + is_optimize_model, + diff_threshold, + 
batch_size, + infer_iterations, + quant_iterations, + bias_correction=bias_correction) class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization): + def test_post_training_kl(self): model_name = "mnist_model" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" @@ -359,6 +384,7 @@ class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization): class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization): + def test_post_training_mse_onnx_format(self): model_name = "mnist_model" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" @@ -374,25 +400,25 @@ class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization): batch_size = 10 infer_iterations = 50 quant_iterations = 5 - self.run_test( - model_name, - data_url, - data_md5, - algo, - round_type, - quantizable_op_type, - is_full_quantize, - is_use_cache_file, - is_optimize_model, - diff_threshold, - batch_size, - infer_iterations, - quant_iterations, - onnx_format=onnx_format) + self.run_test(model_name, + data_url, + data_md5, + algo, + round_type, + quantizable_op_type, + is_full_quantize, + is_use_cache_file, + is_optimize_model, + diff_threshold, + batch_size, + infer_iterations, + quant_iterations, + onnx_format=onnx_format) class TestPostTrainingmseForMnistONNXFormatFullQuant( TestPostTrainingQuantization): + def test_post_training_mse_onnx_format_full_quant(self): model_name = "mnist_model" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" @@ -408,24 +434,24 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant( batch_size = 10 infer_iterations = 50 quant_iterations = 5 - self.run_test( - model_name, - data_url, - data_md5, - algo, - round_type, - quantizable_op_type, - is_full_quantize, - is_use_cache_file, - is_optimize_model, - diff_threshold, - batch_size, - infer_iterations, - quant_iterations, - onnx_format=onnx_format) + self.run_test(model_name, + data_url, + data_md5, + algo, + round_type, + quantizable_op_type, + is_full_quantize, + is_use_cache_file, + is_optimize_model, + diff_threshold, + batch_size, + infer_iterations, + quant_iterations, + onnx_format=onnx_format) class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization): + def test_post_training_avg_skip_op(self): model_name = "mnist_model" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" @@ -441,21 +467,20 @@ class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization): infer_iterations = 50 quant_iterations = 5 skip_tensor_list = ["fc_0.w_0"] - self.run_test( - model_name, - data_url, - data_md5, - algo, - round_type, - quantizable_op_type, - is_full_quantize, - is_use_cache_file, - is_optimize_model, - diff_threshold, - batch_size, - infer_iterations, - quant_iterations, - skip_tensor_list=skip_tensor_list) + self.run_test(model_name, + data_url, + data_md5, + algo, + round_type, + quantizable_op_type, + is_full_quantize, + is_use_cache_file, + is_optimize_model, + diff_threshold, + batch_size, + infer_iterations, + quant_iterations, + skip_tensor_list=skip_tensor_list) if __name__ == '__main__': diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py index 8d94c49e469..25707d0c8c9 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py +++ 
b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py @@ -83,6 +83,7 @@ def _reader_creator(file_list, color_jitter=False, rotate=False, data_dir=DATA_DIR): + def reader(): with open(file_list) as flist: full_lines = [line.strip() for line in flist] @@ -97,8 +98,10 @@ def _reader_creator(file_list, continue yield img_path, int(label) - mapper = functools.partial( - process_image, mode=mode, color_jitter=color_jitter, rotate=rotate) + mapper = functools.partial(process_image, + mode=mode, + color_jitter=color_jitter, + rotate=rotate) return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE) @@ -109,6 +112,7 @@ def val(data_dir=DATA_DIR): class TestPostTrainingQuantization(unittest.TestCase): + def setUp(self): self.int8_download = 'int8/download' self.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' + @@ -156,8 +160,8 @@ class TestPostTrainingQuantization(unittest.TestCase): def cache_unzipping(self, target_folder, zip_path): if not os.path.exists(target_folder): - cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder, - zip_path) + cmd = 'mkdir {0} && tar xf {1} -C {0}'.format( + target_folder, zip_path) os.system(cmd) def download_data(self, data_urls, data_md5s, folder_name, is_model=True): @@ -210,11 +214,12 @@ class TestPostTrainingQuantization(unittest.TestCase): label = label.reshape([-1, 1]) t1 = time.time() - _, acc1, _ = exe.run( - infer_program, - feed={feed_dict[0]: image, - feed_dict[1]: label}, - fetch_list=fetch_targets) + _, acc1, _ = exe.run(infer_program, + feed={ + feed_dict[0]: image, + feed_dict[1]: label + }, + fetch_list=fetch_targets) t2 = time.time() period = t2 - t1 periods.append(period) @@ -241,13 +246,12 @@ class TestPostTrainingQuantization(unittest.TestCase): is_full_quantize=False, is_use_cache_file=False, is_optimize_model=False, - onnx_format=False, - skip_tensor_list=None): + onnx_format=False): try: os.system("mkdir " + self.int8_model) except Exception as e: - print("Failed to create {} due to {}".format(self.int8_model, - str(e))) + print("Failed to create {} due to {}".format( + self.int8_model, str(e))) sys.exit(-1) place = fluid.CPUPlace() @@ -255,18 +259,16 @@ class TestPostTrainingQuantization(unittest.TestCase): scope = fluid.global_scope() val_reader = val() - ptq = PostTrainingQuantization( - executor=exe, - sample_generator=val_reader, - model_dir=model_path, - algo=algo, - quantizable_op_type=quantizable_op_type, - round_type=round_type, - is_full_quantize=is_full_quantize, - optimize_model=is_optimize_model, - onnx_format=onnx_format, - skip_tensor_list=skip_tensor_list, - is_use_cache_file=is_use_cache_file) + ptq = PostTrainingQuantization(executor=exe, + sample_generator=val_reader, + model_dir=model_path, + algo=algo, + quantizable_op_type=quantizable_op_type, + round_type=round_type, + is_full_quantize=is_full_quantize, + optimize_model=is_optimize_model, + onnx_format=onnx_format, + is_use_cache_file=is_use_cache_file) ptq.quantize() ptq.save_quantized_model(self.int8_model) @@ -281,8 +283,7 @@ class TestPostTrainingQuantization(unittest.TestCase): is_use_cache_file, is_optimize_model, diff_threshold, - onnx_format=False, - skip_tensor_list=None): + onnx_format=False): infer_iterations = self.infer_iterations batch_size = self.batch_size sample_iterations = self.sample_iterations @@ -291,20 +292,22 @@ class TestPostTrainingQuantization(unittest.TestCase): print("Start FP32 inference for {0} on {1} images ...".format( model, infer_iterations * batch_size)) - (fp32_throughput, 
fp32_latency, fp32_acc1) = self.run_program( - model_cache_folder + "/model", batch_size, infer_iterations) + (fp32_throughput, fp32_latency, + fp32_acc1) = self.run_program(model_cache_folder + "/model", + batch_size, infer_iterations) print("Start INT8 post training quantization for {0} on {1} images ...". format(model, sample_iterations * batch_size)) - self.generate_quantized_model( - model_cache_folder + "/model", quantizable_op_type, algo, - round_type, is_full_quantize, is_use_cache_file, is_optimize_model, - onnx_format, skip_tensor_list) + self.generate_quantized_model(model_cache_folder + "/model", + quantizable_op_type, algo, round_type, + is_full_quantize, is_use_cache_file, + is_optimize_model, onnx_format) print("Start INT8 inference for {0} on {1} images ...".format( model, infer_iterations * batch_size)) - (int8_throughput, int8_latency, int8_acc1) = self.run_program( - self.int8_model, batch_size, infer_iterations) + (int8_throughput, int8_latency, + int8_acc1) = self.run_program(self.int8_model, batch_size, + infer_iterations) print("---Post training quantization of {} method---".format(algo)) print( @@ -322,6 +325,7 @@ class TestPostTrainingQuantization(unittest.TestCase): class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization): + def test_post_training_kl_mobilenetv1(self): model = "MobileNet-V1" algo = "KL" @@ -346,6 +350,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization): class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization): + def test_post_training_avg_mobilenetv1(self): model = "MobileNet-V1" algo = "avg" @@ -369,6 +374,7 @@ class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization): class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization): + def test_post_training_hist_mobilenetv1(self): model = "MobileNet-V1" algo = "hist" @@ -392,6 +398,7 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization): class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization): + def test_post_training_abs_max_mobilenetv1(self): model = "MobileNet-V1" algo = "abs_max" @@ -415,9 +422,10 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization): class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization): + def test_post_training_onnx_format_mobilenetv1(self): model = "MobileNet-V1" - algo = "avg" + algo = "emd" round_type = "round" data_urls = [ 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' @@ -433,51 +441,17 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization): is_optimize_model = True onnx_format = True diff_threshold = 0.05 - self.run_test( - model, - algo, - round_type, - data_urls, - data_md5s, - quantizable_op_type, - is_full_quantize, - is_use_cache_file, - is_optimize_model, - diff_threshold, - onnx_format=onnx_format) - - -class TestPostTrainingForMobilenetv1SkipOP(TestPostTrainingQuantization): - def test_post_training_mobilenetv1_skip(self): - model = "MobileNet-V1" - algo = "avg" - round_type = "round" - data_urls = [ - 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' - ] - data_md5s = ['13892b0716d26443a8cdea15b3c6438b'] - quantizable_op_type = [ - "conv2d", - "depthwise_conv2d", - "mul", - ] - is_full_quantize = False - is_use_cache_file = False - is_optimize_model = True - diff_threshold = 0.025 - skip_tensor_list = ["fc_0.w_0"] - self.run_test( - model, - algo, - round_type, - data_urls, - data_md5s, - quantizable_op_type, - 
is_full_quantize, - is_use_cache_file, - is_optimize_model, - diff_threshold, - skip_tensor_list=skip_tensor_list) + self.run_test(model, + algo, + round_type, + data_urls, + data_md5s, + quantizable_op_type, + is_full_quantize, + is_use_cache_file, + is_optimize_model, + diff_threshold, + onnx_format=onnx_format) if __name__ == '__main__': diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py index dc12026a21a..c79499100ce 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py @@ -21,6 +21,7 @@ paddle.enable_static() class TestPostTrainingForResnet50(TestPostTrainingQuantization): + def test_post_training_resnet50(self): model = "ResNet-50" algo = "min_max" @@ -40,6 +41,7 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization): class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization): + def test_post_training_resnet50(self): model = "ResNet-50" algo = "min_max" @@ -54,18 +56,17 @@ class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization): is_optimize_model = False diff_threshold = 0.025 onnx_format = True - self.run_test( - model, - algo, - round_type, - data_urls, - data_md5s, - quantizable_op_type, - is_full_quantize, - is_use_cache_file, - is_optimize_model, - diff_threshold, - onnx_format=onnx_format) + self.run_test(model, + algo, + round_type, + data_urls, + data_md5s, + quantizable_op_type, + is_full_quantize, + is_use_cache_file, + is_optimize_model, + diff_threshold, + onnx_format=onnx_format) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py index df458f97d59..02fff35fec7 100644 --- a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py +++ b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py @@ -21,8 +21,6 @@ import math from op_test import OpTest -# numpy.round has different behavior in comparision to c++ round function -# so we use round_c instead of numpy.round to align the output data def round_c_single_element(val): dtype = type(val) if val >= 0: @@ -30,6 +28,7 @@ def round_c_single_element(val): return dtype(np.ceil(val - 0.5)) +# rounding to nearest ties away from zero round_c = np.vectorize(round_c_single_element) @@ -41,17 +40,30 @@ def get_compute_type(dtype): class TestFakeQuantizeAbsMaxOp(OpTest): + def setUp(self): self.op_type = 'fake_quantize_abs_max' self.attrs = {'bit_length': 8} - def _fake_quantize_abs_max(self, dtype, input_shape, distribution): + def _fake_quantize_abs_max(self, + dtype, + input_shape, + distribution, + round_type='TiesAwayFromZero'): input_data = distribution(input_shape).astype(dtype) compute_type = get_compute_type(dtype) scale = np.max(np.abs(input_data)) bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 inv_scale = 1.0 / (scale + 1e-6) if scale < 1e-30 else 1.0 / scale - output_data = round_c(input_data.astype(compute_type) * inv_scale * bnt) + if round_type == 'TiesToEven': + round_out = np.round( + input_data.astype(compute_type) * inv_scale * bnt) + output_data = np.clip(round_out, -bnt - 1, bnt) + self.attrs['round_type'] = 0 + else: + output_data = round_c( + input_data.astype(compute_type) * inv_scale * bnt) + self.attrs['round_type'] = 1 self.inputs = {'X': input_data} self.outputs = {'Out': output_data, 
'OutScale': scale} self.dtype = dtype @@ -60,6 +72,11 @@ class TestFakeQuantizeAbsMaxOp(OpTest): def test_fake_quantize_abs_max(self): self._fake_quantize_abs_max(np.float32, (124, 240), np.random.random) + def test_fake_quantize_abs_max_round1(self): + self._fake_quantize_abs_max(np.float32, (124, 240), + np.random.random, + round_type='TiesToEven') + def test_fake_quantize_abs_max_float16(self): self._fake_quantize_abs_max(np.float16, (124, 240), np.random.random) @@ -72,21 +89,33 @@ class TestFakeQuantizeAbsMaxOp(OpTest): class TestFakeChannelWiseQuantizeAbsMaxOp(OpTest): + def setUp(self): self.op_type = 'fake_channel_wise_quantize_abs_max' self.attrs = {'bit_length': 8} - def _fake_channel_wise_quantize_abs_max(self, dtype, input_shape, - quant_axis, distribution): + def _fake_channel_wise_quantize_abs_max(self, + dtype, + input_shape, + quant_axis, + distribution, + round_type='TiesToEven'): assert quant_axis in [0, 1], 'quant_axis should be 0 or 1.' input_data = distribution(input_shape).astype(dtype) compute_type = get_compute_type(dtype) bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 - compute_axis = tuple( - i for i in range(len(input_shape)) if i != quant_axis) + compute_axis = tuple(i for i in range(len(input_shape)) + if i != quant_axis) scale_broadcast = np.amax(input_data, axis=compute_axis, keepdims=True) - output_data = round_c(bnt * input_data.astype(compute_type) / - scale_broadcast) + if round_type == 'TiesToEven': + round_out = np.round( + input_data.astype(compute_type) / scale_broadcast * bnt) + output_data = np.clip(round_out, -bnt - 1, bnt) + self.attrs['round_type'] = 0 + else: + output_data = round_c(bnt * input_data.astype(compute_type) / + scale_broadcast) + self.attrs['round_type'] = 1 if quant_axis == 1: scale_broadcast = np.transpose(scale_broadcast, (1, ) + compute_axis) @@ -100,19 +129,24 @@ class TestFakeChannelWiseQuantizeAbsMaxOp(OpTest): def test_fake_channel_wise_quantize_abs_max(self): dtype_options = [np.float32, np.float16] input_shape_quant_axis_options = [[(20, 15, 6, 6), 0], - [(15, 20, 5, 5), 1], [(30, 15), 0], - [(30, 15), 1]] - for dtype, input_shape_quant_axis in itertools.product( - dtype_options, input_shape_quant_axis_options): + [(20, 15, 6, 6), 1], [(30, 30), 0], + [(30, 30), 1]] + round_type_options = ['TiesToEven', 'TiesAwayFromZero'] + for dtype, input_shape_quant_axis, round_type in itertools.product( + dtype_options, input_shape_quant_axis_options, + round_type_options): input_shape, quant_axis = input_shape_quant_axis - with self.subTest( - dtype=dtype, input_shape=input_shape, - quant_axis=quant_axis): + with self.subTest(dtype=dtype, + input_shape=input_shape, + quant_axis=quant_axis, + round_type=round_type): self._fake_channel_wise_quantize_abs_max( - dtype, input_shape, quant_axis, np.random.random) + dtype, input_shape, quant_axis, np.random.random, + round_type) class TestFakeQuantizeRangeAbsMaxOp(OpTest): + def setUp(self): self.op_type = 'fake_quantize_range_abs_max' self.attrs = {'bit_length': 5, 'window_size': 1} @@ -121,7 +155,8 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest): dtype, input_shape, distribution, - is_test=False): + is_test=False, + round_type='TiesToEven'): input_data = distribution(input_shape).astype(dtype) compute_type = get_compute_type(dtype) bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 @@ -130,11 +165,19 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest): out_scale[0] = np.max(np.abs(input_data)) if is_test: out_scale[0] = in_scale[0] = out_scale[0] - 1.0 - clip_data = np.clip(input_data, -in_scale, 
in_scale) + if round_type == 'TiesToEven': + round_out = np.round( + input_data.astype(compute_type) / out_scale[0] * bnt) + self.attrs['round_type'] = 0 + output_data = np.clip(round_out, -bnt - 1, bnt) else: - clip_data = input_data - output_data = round_c( - clip_data.astype(compute_type) / out_scale[0] * bnt) + if is_test: + clip_data = np.clip(input_data, -in_scale, in_scale) + else: + clip_data = input_data + output_data = round_c( + clip_data.astype(compute_type) / out_scale[0] * bnt) + self.attrs['round_type'] = 1 self.inputs = { 'X': input_data, 'Iter': np.zeros(1).astype(np.int64), @@ -150,18 +193,24 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest): self.check_output() def test_fake_quantize_range_abs_max(self): - dtype_options = [np.float32, np.float16] + dtype_options = [np.float16, np.float32] is_test_options = [False, True] - for dtype, is_test in itertools.product(dtype_options, is_test_options): + round_type_options = ['TiesToEven', 'TiesAwayFromZero'] + for dtype, is_test, round_type in itertools.product( + dtype_options, is_test_options, round_type_options): self.attrs['bit_length'] = 8 if is_test else 5 - with self.subTest(dtype=dtype, is_test=is_test): + with self.subTest(dtype=dtype, + is_test=is_test, + round_type=round_type): self._fake_quantize_range_abs_max( - dtype, (8, 16, 7, 7), - lambda shape: (np.random.random(shape) - 0.5) * 10, - is_test=is_test) + dtype, (8, 16, 6, 6), + lambda shape: (np.random.random(shape) - 0.4) * 10, + is_test=is_test, + round_type=round_type) class TestMovingAverageAbsMaxScaleOp(OpTest): + def setUp(self): self.op_type = 'moving_average_abs_max_scale' self.attrs = {'moving_rate': float(0.9), 'is_test': False} @@ -194,6 +243,7 @@ class TestMovingAverageAbsMaxScaleOp(OpTest): class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest): + def setUp(self): self.op_type = 'fake_quantize_moving_average_abs_max' self.attrs = {'bit_length': 5, 'moving_rate': 0.9, 'is_test': False} @@ -203,7 +253,8 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest): input_shape, distribution, dequantize=False, - with_gradient=False): + with_gradient=False, + round_type='TiesAwayFromZero'): input_data = distribution(input_shape).astype(dtype) compute_type = get_compute_type(dtype) bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 @@ -217,12 +268,20 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest): np.abs(input_data)) out_state[0] = self.attrs['moving_rate'] * in_state[0] + 1.0 out_scale = out_accum / out_state - round_data = round_c(input_data.astype(compute_type) / out_scale * bnt) + if round_type == 'TiesToEven': + round_out = np.round( + input_data.astype(compute_type) / out_scale * bnt) + quant_data = np.clip(round_out, -bnt - 1, bnt) + self.attrs['round_type'] = 0 + else: + quant_data = round_c( + input_data.astype(compute_type) / out_scale * bnt) + self.attrs['round_type'] = 1 if dequantize: - output_data = (round_data * out_scale / bnt).astype(dtype) + output_data = (quant_data * out_scale / bnt).astype(dtype) self.op_type = 'fake_quantize_dequantize_moving_average_abs_max' else: - output_data = round_data.astype(dtype) + output_data = quant_data.astype(dtype) self.inputs = { 'X': input_data, 'InScale': in_scale, @@ -251,25 +310,39 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest): self._fake_quantize_moving_average_abs_max(np.float16, (8, 16, 7, 7), np.random.random) + def test_fake_quantize_moving_average_abs_max_round1(self): + self._fake_quantize_moving_average_abs_max(np.float32, (8, 16, 7, 7), + np.random.random, + round_type='TiesToEven') + def 
test_fake_quantize_dequantize_moving_average_abs_max(self): - self._fake_quantize_moving_average_abs_max( - np.float32, (8, 16, 7, 7), - np.random.random, - dequantize=True, - with_gradient=True) + self._fake_quantize_moving_average_abs_max(np.float32, (8, 16, 7, 7), + np.random.random, + dequantize=True, + with_gradient=True) class TestFakeQuantizeDequantizeAbsMaxOp(OpTest): + def setUp(self): self.op_type = 'fake_quantize_dequantize_abs_max' self.attrs = {'bit_length': 8} - def _fake_quantize_dequantize_abs_max(self, dtype, input_shape, - distribution): + def _fake_quantize_dequantize_abs_max(self, + dtype, + input_shape, + distribution, + round_type='TiesAwayFromZero'): input_data = distribution(input_shape).astype(dtype) scale = np.max(np.abs(input_data)).astype(dtype) bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 - output_data = round_c(input_data / scale * bnt) * scale / bnt + if round_type == 'TiesToEven': + round_out = np.round(input_data / scale * bnt) + output_data = np.clip(round_out, -bnt - 1, bnt) * scale / bnt + self.attrs['round_type'] = 0 + else: + output_data = round_c(input_data / scale * bnt) * scale / bnt + self.attrs['round_type'] = 1 self.inputs = {'X': input_data} self.outputs = { 'Out': output_data, @@ -284,24 +357,41 @@ class TestFakeQuantizeDequantizeAbsMaxOp(OpTest): self._fake_quantize_dequantize_abs_max(np.float32, (124, 240), np.random.random) + def test_fake_quantize_dequantize_abs_max_round1(self): + self._fake_quantize_dequantize_abs_max(np.float32, (124, 240), + np.random.random, + round_type='TiesToEven') + class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest): + def setUp(self): self.op_type = 'fake_channel_wise_quantize_dequantize_abs_max' self.attrs = {'bit_length': 8} - def _fake_channel_wise_quantize_dequantize_abs_max( - self, dtype, input_shape, quant_axis, distribution): + def _fake_channel_wise_quantize_dequantize_abs_max(self, + dtype, + input_shape, + quant_axis, + distribution, + round_type='TiesToEven'): assert quant_axis in [0, 1], 'quant_axis should be 0 or 1.' 
input_data = distribution(input_shape).astype(dtype) compute_type = get_compute_type(dtype) bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 output_data = input_data.copy().astype(compute_type) - compute_axis = tuple( - i for i in range(len(input_shape)) if i != quant_axis) + compute_axis = tuple(i for i in range(len(input_shape)) + if i != quant_axis) scale_broadcast = np.amax(input_data, axis=compute_axis, keepdims=True) - output_data = round_c(bnt * output_data / - scale_broadcast) * scale_broadcast / bnt + if round_type == 'TiesToEven': + round_out = np.round(bnt * output_data / scale_broadcast) + output_data = np.clip(round_out, -bnt - 1, + bnt) * scale_broadcast / bnt + self.attrs['round_type'] = 0 + else: + output_data = round_c( + bnt * output_data / scale_broadcast) * scale_broadcast / bnt + self.attrs['round_type'] = 1 if quant_axis == 1: scale_broadcast = np.transpose(scale_broadcast, (1, ) + compute_axis) @@ -318,10 +408,19 @@ class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest): input_shape_quant_axis_options = [[(3, 4, 64, 64), 0], [(15, 20, 5, 5), 1], [(30, 15), 0], [(30, 15), 1]] - for input_shape, quant_axis in input_shape_quant_axis_options: - with self.subTest(input_shape=input_shape, quant_axis=quant_axis): + round_type_options = ['TiesToEven', 'TiesAwayFromZero'] + for input_shape_quant_axis, round_type in itertools.product( + input_shape_quant_axis_options, round_type_options): + input_shape, quant_axis = input_shape_quant_axis + with self.subTest(input_shape=input_shape, + quant_axis=quant_axis, + round_type=round_type): self._fake_channel_wise_quantize_dequantize_abs_max( - np.float32, input_shape, quant_axis, np.random.random) + np.float32, + input_shape, + quant_axis, + np.random.random, + round_type=round_type) def quantize_max_abs(x, max_range): @@ -349,6 +448,7 @@ def channel_wise_quantize_max_abs(x, quant_bit=8, quant_axis=0): class TestChannelWiseQuantizeOp(OpTest): + def set_args(self): self.bit_length = 8 self.data_type = "float32" @@ -375,6 +475,7 @@ class TestChannelWiseQuantizeOp(OpTest): class TestChannelWiseQuantizeOp1(TestChannelWiseQuantizeOp): + def set_args(self): self.bit_length = 8 self.data_type = "float32" @@ -382,6 +483,7 @@ class TestChannelWiseQuantizeOp1(TestChannelWiseQuantizeOp): class TestChannelWiseQuantizeOpTrain(OpTest): + def set_args(self): self.bit_length = 8 self.data_type = "float32" @@ -410,6 +512,7 @@ class TestChannelWiseQuantizeOpTrain(OpTest): class TestquantizeOp(OpTest): + def set_args(self): self.bit_length = 8 self.quant_axis = -1 @@ -436,6 +539,7 @@ class TestquantizeOp(OpTest): class TestquantizeOpTrain(TestquantizeOp): + def set_args(self): self.bit_length = 8 self.quant_axis = -1 -- GitLab
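
Note on the two rounding modes exercised by the updated tests above: round_type=0 uses numpy.round (ties to even) followed by clipping to the signed b-bit range, while round_type=1 keeps the previous behaviour of rounding ties away from zero via the round_c helper. The NumPy sketch below reproduces that reference computation for the per-tensor abs-max case as a minimal, illustrative example; the function names and sample values are mine, not part of the patch.

import numpy as np

def round_ties_away_from_zero(x):
    # round_type = 1: halfway cases move away from zero (2.5 -> 3, -2.5 -> -3)
    return np.where(x >= 0.0, np.floor(x + 0.5), np.ceil(x - 0.5))

# numpy.round rounds ties to even (2.5 -> 2, -2.5 -> -2), i.e. round_type = 0
ties = np.array([0.5, -0.5, 1.5, 2.5, -2.5])
print(np.round(ties))                    # [ 0. -0.  2.  2. -2.]
print(round_ties_away_from_zero(ties))   # [ 1. -1.  2.  3. -3.]

def fake_quantize_abs_max(x, bit_length=8, round_type=0):
    # Per-tensor abs-max fake quantization mirroring the reference
    # computation in the updated unit tests (illustrative sketch only).
    bnt = (1 << (bit_length - 1)) - 1        # 127 for 8 bits
    scale = np.max(np.abs(x))
    scaled = x / scale * bnt
    if round_type == 0:
        # ties-to-even rounding, then clip to the signed range [-bnt - 1, bnt]
        return np.clip(np.round(scaled), -bnt - 1, bnt), scale
    # legacy ties-away-from-zero rounding (no explicit clip in the reference)
    return round_ties_away_from_zero(scaled), scale

quantized, scale = fake_quantize_abs_max(np.random.random((4, 4)).astype(np.float32))
print(quantized, scale)

The clip to [-bnt - 1, bnt] in the round_type=0 branch mirrors the saturation the tests expect from the operator for signed b-bit integers; the channel-wise variants apply the same formula per slice along quant_axis with a per-channel scale.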