diff --git a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc
index f7ee6a96dce047f3569503b6561a88a9c584270e..99eaab49b7926f5cbc3d2975cc14e98e5326c54a 100644
--- a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc
@@ -390,6 +390,13 @@ std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales(
       } else if (out_iter != var_quant_scales->end()) {
         (*var_quant_scales)[input_name] = out_iter->second;
       }
+    } else if (op_name == "concat") {
+      auto out_iter = var_quant_scales->find(op_node->Op()->Output("Out")[0]);
+      if (out_iter != var_quant_scales->end()) {
+        std::vector<std::string> input_names = op_node->Op()->Input("X");
+        for (auto input_name : input_names)
+          (*var_quant_scales)[input_name] = out_iter->second;
+      }
     } else if (op_name == "scale") {
       const std::string output_name = op_node->Op()->Output("Out")[0];
       auto out_iter = var_quant_scales->find(output_name);
diff --git a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc
index 40c6050a3c3f1659a9c86ad3a2eafbb37488cb64..42c54fcb36242f548c16e21c22cee52835faaf90 100644
--- a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc
@@ -55,23 +55,6 @@ void QuantDequantMkldnnPass::MarkSkipQuantizedOps(
   }
 }
 
-void QuantDequantMkldnnPass::MarkSkipQuantizedPool2d(ir::Graph* graph) const {
-  VLOG(3) << "mark avg pool2d as skip quantized op";
-  for (auto* op_node :
-       ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
-    if (!op_node->IsOp()) continue;
-
-    if (op_node->Name() == "pool2d") {
-      auto* op_desc = op_node->Op();
-      auto pool_type =
-          BOOST_GET_CONST(std::string, op_desc->GetAttr("pooling_type"));
-      if (pool_type == "avg") {
-        op_node->Op()->SetAttr("skip_quant", 1);
-      }
-    }
-  }
-}
-
 void QuantDequantMkldnnPass::CollectInfoFromFake(
     ir::Graph* graph,
     Scope* scope,
@@ -548,7 +531,6 @@ void QuantDequantMkldnnPass::ApplyImpl(ir::Graph* graph) const {
   auto* scope = param_scope();
 
   MarkSkipQuantizedOps(graph, skip_ops);
-  MarkSkipQuantizedPool2d(graph);
   CollectInfoFromFake(graph, scope, fake_dequantize_types, &weight_thresholds);
   CollectInputScalesFromFake(
       graph, scope, fake_quantize_types, &var_quant_scales);
diff --git a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
index 0d17673a2d522df6be09525cf88c6e117c0450a4..2f155ca0edfc22fc633da26ab98b21c526c3a67e 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
@@ -264,6 +264,14 @@ class Quant2Int8MkldnnPass(object):
                 elif output_name in self._var_quant_scales:
                     self._var_quant_scales[
                         input_name] = self._var_quant_scales[output_name]
+            elif op.name() == 'concat':
+                output_name = op.output("Out")[0]
+                if output_name in self._var_quant_scales:
+                    input_names = op.input("X")
+                    for input_name in input_names:
+                        self._var_quant_scales[
+                            input_name] = self._var_quant_scales[
+                                output_name]
             elif op.name() in self._scale_ops:
                 input_name = op.input("X")[0]
                 output_name = op.output("Out")[0]
@@ -595,13 +603,6 @@ class Quant2Int8MkldnnPass(object):
             _compute_lstm_weight_scales("WeightX", "WeightH")
         return graph
 
-    def _find_avg_pooling_ids(self, graph):
-        for op in graph.all_op_nodes():
-            if op.name() in self._pool_ops:
-                if op.op().attr("pooling_type") == "avg":
-                    self._op_ids_to_skip.add(op.id())
-        return self._op_ids_to_skip
-
     def _update_relu_output_scales(self, graph):
 
         def _set_unsigned_scale(graph, ops, op_out_name, predicate):
@@ -651,11 +652,9 @@
             'reshape_transpose_matmul_mkldnn_fuse_pass')
         graph = self._apply_pass(
             graph, 'reshape_transpose_matmul_v2_mkldnn_fuse_pass')
-        graph = self._apply_pass(
-            graph, 'cpu_quantize_placement_pass',
-            ['quantize_enabled_op_types', 'quantize_excluded_op_ids'],
-            [self._ops_to_quantize,
-             self._find_avg_pooling_ids(graph)])
+        graph = self._apply_pass(graph, 'cpu_quantize_placement_pass',
+                                 ['quantize_enabled_op_types'],
+                                 [self._ops_to_quantize])
         graph = self._apply_pass(
             graph, 'cpu_quantize_pass', ['quant_var_scales', 'data_layout'],
             [self._var_quant_scales,
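
Note: as a standalone illustration of the concat branch added above (not the pass code itself), the scale-propagation step can be sketched in plain Python, assuming a simple dict of per-variable scales and a hypothetical ConcatOp record standing in for the IR op node:

# Minimal sketch of the concat scale-propagation idea from this patch.
# ConcatOp and propagate_concat_scales are illustrative names only; the
# real passes walk Paddle IR op nodes rather than plain records.
from typing import Dict, List, NamedTuple


class ConcatOp(NamedTuple):
    input_names: List[str]  # variables feeding the concat ("X")
    output_name: str        # concat result variable ("Out")


def propagate_concat_scales(op: ConcatOp,
                            var_quant_scales: Dict[str, float]) -> None:
    # If the concat output already has a quantization scale, reuse it for
    # every input so all concatenated tensors share one scale.
    if op.output_name in var_quant_scales:
        out_scale = var_quant_scales[op.output_name]
        for name in op.input_names:
            var_quant_scales[name] = out_scale


# Usage with made-up variable names:
scales = {"concat_0.tmp_0": 0.05}
propagate_concat_scales(ConcatOp(["x0", "x1"], "concat_0.tmp_0"), scales)
assert scales["x0"] == scales["x1"] == 0.05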