diff --git a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc index e19426d01d195c33eb6a58af4578d16c0b679a86..bd945c139f601ac933c361b0252bfee2363e4489 100644 --- a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc @@ -349,9 +349,9 @@ std::unordered_set ComputePropagateScalesMkldnnPass::UpdateScales( waiting_for_scale.insert(input_name); waiting_for_scale.insert(output_name); } else if (in_iter != var_quant_scales->end()) { - out_iter->second = in_iter->second; + (*var_quant_scales)[output_name] = in_iter->second; } else if (out_iter != var_quant_scales->end()) { - in_iter->second = out_iter->second; + (*var_quant_scales)[input_name] = out_iter->second; } } else if (op_name == "scale") { const std::string output_name = op_node->Op()->Output("Out")[0]; diff --git a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc index 15100b23407b018e9c3b6ec2e35de6605941b9d4..5f92a4bb7f15b3c571b7eb0d12d608a7f4375095 100644 --- a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc @@ -38,7 +38,7 @@ void QuantDequantMkldnnPass::MarkSkipQuantizedOps( for (auto* node_input : op_node->inputs) { for (auto* node_input_input : node_input->inputs) { if (!node_input_input->IsOp()) continue; - if (node_input_input->Name().find("quantize_dequantize") == + if (node_input_input->Name().find("quantize") == std::string::npos) { is_quantized_op = false; break; diff --git a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py index 220016bd653bcaea4f4f922bce76a698672a4725..76feab207eebdb78cf240c1d793aa7232ab08ff5 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py @@ -158,7 +158,7 @@ class Quant2Int8MkldnnPass(object): is_quantized_op = True for var_node in op_node.inputs: for front_op_node in var_node.inputs: - if "quantize_dequantize" not in front_op_node.name(): + if "quantize" not in front_op_node.name(): is_quantized_op = False if not is_quantized_op: op_node.op()._set_attr("skip_quant", True)