From 9551e4666733c120a46ed5e59623557bfeffde57 Mon Sep 17 00:00:00 2001 From: "joanna.wozna.intel" Date: Tue, 7 Jun 2022 15:06:41 +0200 Subject: [PATCH] Correct skip_quant condition (#43184) --- .../ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc | 4 ++-- paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc | 2 +- .../contrib/slim/quantization/quant2_int8_mkldnn_pass.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc index e19426d01d..bd945c139f 100644 --- a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc @@ -349,9 +349,9 @@ std::unordered_set ComputePropagateScalesMkldnnPass::UpdateScales( waiting_for_scale.insert(input_name); waiting_for_scale.insert(output_name); } else if (in_iter != var_quant_scales->end()) { - out_iter->second = in_iter->second; + (*var_quant_scales)[output_name] = in_iter->second; } else if (out_iter != var_quant_scales->end()) { - in_iter->second = out_iter->second; + (*var_quant_scales)[input_name] = out_iter->second; } } else if (op_name == "scale") { const std::string output_name = op_node->Op()->Output("Out")[0]; diff --git a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc index 15100b2340..5f92a4bb7f 100644 --- a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc @@ -38,7 +38,7 @@ void QuantDequantMkldnnPass::MarkSkipQuantizedOps( for (auto* node_input : op_node->inputs) { for (auto* node_input_input : node_input->inputs) { if (!node_input_input->IsOp()) continue; - if (node_input_input->Name().find("quantize_dequantize") == + if (node_input_input->Name().find("quantize") == std::string::npos) { is_quantized_op = false; break; diff --git a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py index 220016bd65..76feab207e 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py @@ -158,7 +158,7 @@ class Quant2Int8MkldnnPass(object): is_quantized_op = True for var_node in op_node.inputs: for front_op_node in var_node.inputs: - if "quantize_dequantize" not in front_op_node.name(): + if "quantize" not in front_op_node.name(): is_quantized_op = False if not is_quantized_op: op_node.op()._set_attr("skip_quant", True) -- GitLab