From 31f8168506c50b96b7db1cba28ba1dfdad2516cd Mon Sep 17 00:00:00 2001 From: "joanna.wozna.intel" Date: Wed, 22 Mar 2023 13:33:09 +0100 Subject: [PATCH] Correct lstm qat test (#51499) --- .../compute_propagate_scales_mkldnn_pass.cc | 6 +- .../framework/ir/mkldnn/cpu_quantize_pass.cc | 6 +- .../framework/ir/mkldnn/mkldnn_pass_util.h | 77 ++++++++----------- .../ir/mkldnn/quant_dequant_mkldnn_pass.cc | 4 +- .../inference/api/paddle_pass_builder.cc | 2 +- .../tests/quant2_int8_lstm_model.py | 5 +- 6 files changed, 46 insertions(+), 54 deletions(-) diff --git a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc index 19b589ceecb..16dc293b060 100644 --- a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc @@ -37,7 +37,7 @@ void ComputePropagateScalesMkldnnPass::GetTensorFromVector( void ComputePropagateScalesMkldnnPass::GetQuantInfo( ir::Graph* graph, StringPairMap* var_quant_scales) const { std::unordered_map> info_map{}; - GetInfoFromTheFirstOp(graph, "has_quant_info", "var_quant_scales", &info_map); + GetInfoFromTheTmpOp(graph, "has_quant_info", "var_quant_scales", &info_map); for (auto iter = info_map.begin(); iter != info_map.end(); iter++) { phi::DenseTensor tensor; @@ -510,9 +510,9 @@ void ComputePropagateScalesMkldnnPass::ApplyImpl(ir::Graph* graph) const { UpdateReluOutputScales(graph, &var_quant_scales); PropagateScales(graph, &var_quant_scales, scale_immutable_ops); - // save var_quant_scales in the first op's attr + // save var_quant_scales in the temporary save op's attr // for cpu_quantize_pass - SaveInfoInTheFirstOp( + SaveInfoInTheTmpOp( graph, "has_quant_info", "var_quant_scales", var_quant_scales); } diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc index 33f25739f0f..2c7d93278ab 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc @@ -435,7 +435,7 @@ bool CPUQuantizePass::IsOpQuantized(const Node* node) const { } void CPUQuantizePass::GetQuantInfo(Graph* graph) const { - GetInfoFromTheFirstOp( + GetInfoFromTheTmpOp( graph, "has_quant_info", "var_quant_scales", var_quant_scales_); } @@ -1250,6 +1250,10 @@ void CPUQuantizePass::QuantizeFusionLSTM(Graph* graph) const { bool is_x_unsigned{false}; auto input_x_scale = GetScaleValueForNode(x, &is_x_unsigned); + // In the QAT process scales are prepared for only int8 data type, + // lstm scales should behave as input is int8 to get correct accuracy + is_x_unsigned = false; + double input_x_shift{128.}; if (is_x_unsigned) input_x_shift = 0.; diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h b/paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h index eb13f57c50f..0cf714af3b9 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h @@ -25,7 +25,7 @@ namespace ir { using StringPairMap = std::unordered_map>; -static void SaveInfoInTheFirstOp( +static void SaveInfoInTheTmpOp( ir::Graph* graph, const std::string& flag, const std::string& key_suffix, @@ -33,48 +33,39 @@ static void SaveInfoInTheFirstOp( VLOG(3) << "save variables in the first op's attr"; const std::string suffix = "_" + key_suffix + "_" + flag; - for (auto* op_node : - ir::TopologyVarientSort(*graph, static_cast(0))) { - if (!op_node->IsOp() || op_node->Op()->Type() == "feed" || - op_node->Op()->Type() == "fetch" || - op_node->Op()->Type() == "fill_constant") - continue; - - op_node->Op()->SetAttr(flag, true); - for (auto iter = info_map.begin(); iter != info_map.end(); ++iter) { - op_node->Op()->SetAttr(iter->first + suffix, iter->second); - } - break; + OpDesc op_desc; + op_desc.SetType("save"); + auto* op_node = graph->CreateOpNode(&op_desc); + + op_node->Op()->SetAttr(flag, true); + for (auto iter = info_map.begin(); iter != info_map.end(); ++iter) { + op_node->Op()->SetAttr(iter->first + suffix, iter->second); } } -static void SaveInfoInTheFirstOp(ir::Graph* graph, - const std::string& flag, - const std::string& key_suffix, - const StringPairMap& info_map) { +static void SaveInfoInTheTmpOp(ir::Graph* graph, + const std::string& flag, + const std::string& key_suffix, + const StringPairMap& info_map) { VLOG(3) << "save variables in the first op's attr"; const std::string suffix = "_" + key_suffix + "_" + flag; - for (auto* op_node : - ir::TopologyVarientSort(*graph, static_cast(0))) { - if (!op_node->IsOp() || op_node->Op()->Type() == "feed" || - op_node->Op()->Type() == "fetch" || - op_node->Op()->Type() == "fill_constant") - continue; - - op_node->Op()->SetAttr(flag, true); - for (auto iter = info_map.begin(); iter != info_map.end(); ++iter) { - auto* data = iter->second.second.data(); - std::vector data_v(data, data + iter->second.second.numel()); - op_node->Op()->SetAttr(iter->first + suffix + "_unsigned", - iter->second.first); - op_node->Op()->SetAttr(iter->first + suffix, data_v); - } - break; + + OpDesc op_desc; + op_desc.SetType("save"); + auto* op_node = graph->CreateOpNode(&op_desc); + + op_node->Op()->SetAttr(flag, true); + for (auto iter = info_map.begin(); iter != info_map.end(); ++iter) { + auto* data = iter->second.second.data(); + std::vector data_v(data, data + iter->second.second.numel()); + op_node->Op()->SetAttr(iter->first + suffix + "_unsigned", + iter->second.first); + op_node->Op()->SetAttr(iter->first + suffix, data_v); } } -static void GetInfoFromTheFirstOp( +static void GetInfoFromTheTmpOp( ir::Graph* graph, const std::string& flag, const std::string& key_suffix, @@ -84,9 +75,7 @@ static void GetInfoFromTheFirstOp( const std::string suffix = "_" + key_suffix + "_" + flag; for (auto* op_node : ir::TopologyVarientSort(*graph, static_cast(0))) { - if (!op_node->IsOp() || op_node->Op()->Type() == "feed" || - op_node->Op()->Type() == "fetch") - continue; + if (!op_node->IsOp() || op_node->Op()->Type() != "save") continue; auto* op_desc = op_node->Op(); if (op_desc->GetAttrIfExists(flag)) { @@ -102,24 +91,23 @@ static void GetInfoFromTheFirstOp( op_desc->RemoveAttr(fake_name); } } + graph->RemoveNode(op_node); break; } } } -static void GetInfoFromTheFirstOp(ir::Graph* graph, - const std::string& flag, - const std::string& key_suffix, - StringPairMap* info_map) { +static void GetInfoFromTheTmpOp(ir::Graph* graph, + const std::string& flag, + const std::string& key_suffix, + StringPairMap* info_map) { VLOG(3) << "get variables from the first op's attr"; const std::string unsigned_flag = "_unsigned"; const std::string suffix = "_" + key_suffix + "_" + flag; const std::string suffix_is_unsigned = suffix + unsigned_flag; for (auto* op_node : ir::TopologyVarientSort(*graph, static_cast(0))) { - if (!op_node->IsOp() || op_node->Op()->Type() == "feed" || - op_node->Op()->Type() == "fetch") - continue; + if (!op_node->IsOp() || op_node->Op()->Type() != "save") continue; auto* op_desc = op_node->Op(); if (op_desc->GetAttrIfExists(flag)) { @@ -150,6 +138,7 @@ static void GetInfoFromTheFirstOp(ir::Graph* graph, op_desc->RemoveAttr(vector_name); } } + graph->RemoveNode(op_node); break; } } diff --git a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc index 0d4b0c24f72..9f643199dbe 100644 --- a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc @@ -754,9 +754,9 @@ void QuantDequantMkldnnPass::ApplyImpl(ir::Graph* graph) const { UpdateActivations(graph); RemoveCtrlVars(graph); - // save var_quant_scales in the first op's attr + // save var_quant_scales in the temporary save op's attr // for compute_propagate_scales_mkldnn_pass - SaveInfoInTheFirstOp( + SaveInfoInTheTmpOp( graph, "has_quant_info", "var_quant_scales", var_quant_scales); } diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 35e0bab83e0..43fa40c4fa7 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -430,7 +430,6 @@ void CpuPassStrategy::EnableMkldnnInt8() { passes_.push_back("simplify_with_basic_ops_pass"); passes_.push_back("quant_dequant_mkldnn_pass"); passes_.push_back("mkldnn_placement_pass"); - passes_.push_back("constant_folding_pass"); passes_.push_back("squeeze2_transpose2_onednn_fuse_pass"); passes_.push_back("layer_norm_fuse_pass"); passes_.push_back("attention_lstm_fuse_pass"); @@ -485,6 +484,7 @@ void CpuPassStrategy::EnableMkldnnInt8() { passes_.push_back("quant_transpose2_dequant_onednn_fuse_pass"); passes_.push_back("int8_scale_calculation_mkldnn_pass"); passes_.push_back("params_quantization_mkldnn_pass"); + passes_.push_back("constant_folding_pass"); } use_mkldnn_int8_ = true; #else diff --git a/python/paddle/static/quantization/tests/quant2_int8_lstm_model.py b/python/paddle/static/quantization/tests/quant2_int8_lstm_model.py index 92ea480362c..01a205c3a99 100644 --- a/python/paddle/static/quantization/tests/quant2_int8_lstm_model.py +++ b/python/paddle/static/quantization/tests/quant2_int8_lstm_model.py @@ -116,10 +116,9 @@ class TestLstmModelPTQ(unittest.TestCase): config.switch_ir_optim(True) config.enable_mkldnn() config.disable_mkldnn_fc_passes() # fc passes caused dnnl error + config.pass_builder().insert_pass(5, "fc_lstm_fuse_pass") config.set_mkldnn_cache_capacity(mkldnn_cache_capacity) if mode == "ptq": - # This pass to work properly, must be added before fc_fuse_pass - config.pass_builder().insert_pass(5, "fc_lstm_fuse_pass") config.enable_quantizer() config.quantizer_config().set_quant_data(warmup_data) config.quantizer_config().set_quant_batch_size(1) @@ -244,7 +243,7 @@ class TestLstmModelPTQ(unittest.TestCase): ) (quant_hx_acc, quant_ctc_acc, quant_fps) = self.run_program( - quant_model + "_int8", + quant_model, infer_data, num_threads, mkldnn_cache_capacity, -- GitLab