diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py index eda62b4674d017635456389ebddd5b07e652b0e7..bdfd7cdef6f64a57ac31e44aa07c9f729f42d766 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py @@ -321,10 +321,11 @@ class FakeQAT2MkldnnINT8PerfPass(object): graph = self._gather_scales(graph) graph = self._remove_fake_ops(graph) - graph = self._update_pooling_scales(graph) graph = self._dequantize_weights(graph) graph = self._optimize_fp32_graph(graph) graph = self._compute_weight_scales(graph) + graph = self._update_conv_relu_scales(graph) + graph = self._update_pooling_scales(graph) graph = self._quantize_fp32_graph(graph) graph = self._remove_unused_var_nodes(graph) return graph @@ -350,6 +351,8 @@ class FakeQAT2MkldnnINT8PerfPass(object): use_unsigned_int = False self._var_quant_scales[input_name] = (use_unsigned_int, lod_tensor) + self._var_quant_scales[scale_name.replace(".scale", "")] = ( + use_unsigned_int, lod_tensor) if op.name() in self._fake_dequantize_types: input_name = op.input("X")[0] @@ -378,13 +381,13 @@ class FakeQAT2MkldnnINT8PerfPass(object): next_op = op_out.outputs[0] if next_op.name() not in self._mul_ops: self._remove_fake_quantize(graph, op) - else: - quant_op = self._transform_to_quantize_mkldnn(graph, op) - self._transform_to_mul_mkldnn(graph, next_op, quant_op) for op in graph.all_op_nodes(): if op.name() in self._fake_dequantize_types: - self._remove_fake_dequantize(graph, op) + op_in = graph._find_node_by_name(op.inputs, op.input("X")[0]) + prev_op = op_in.inputs[0] + if prev_op.name() not in self._mul_ops: + self._remove_fake_dequantize(graph, op) return graph def _remove_fake_quantize(self, graph, op): @@ -530,7 +533,7 @@ class FakeQAT2MkldnnINT8PerfPass(object): if op.name() in self._pool_ops: if op.op().attr("pooling_type") == "avg": ids.append(op.id()) - return set(ids) + return set(ids) if len(ids) else set([-1]) def _transform_to_quantize_mkldnn(self, graph, op_node): """ @@ -557,13 +560,16 @@ class FakeQAT2MkldnnINT8PerfPass(object): graph.safe_remove_nodes(op_node) return quant_op_node - def _transform_to_mul_mkldnn(self, graph, op_node, quantize_node): - input_name = op_node.input("X")[0] - scale_in = quantize_node.op().attr("Scale") - op_node.set_attr("scale_y", [1.0]) - op_node.set_attr("scale_x", scale_in) - op_node.set_attr("scale_out", 1.0) - op_node.set_attr("force_fp32_output", True) + def _update_conv_relu_scales(self, graph): + for op in graph.all_op_nodes(): + if op.name() in self._conv_ops: + out_name = op.output("Output")[0] + if out_name in self._var_quant_scales and \ + op.op().attr("fuse_activation") == 'relu' and \ + op.op().attr("fuse_residual_connection") == False: + _, tensor = self._var_quant_scales[out_name] + self._var_quant_scales[out_name] = (True, tensor) + return graph def _quantize_fp32_graph(self, graph): ir_pass = self._core.get_pass('cpu_quantize_placement_pass')