From a85dedf978e05bcb1000aa1d30ec65fd2415d6bf Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Thu, 8 Dec 2022 10:04:27 +0800 Subject: [PATCH] Delete duplicate quant nodes in QAT (#48751) --- .../slim/quantization/quantization_pass.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 705b0e5e69..55e1dcacdc 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -2907,13 +2907,31 @@ class ReplaceFakeQuantDequantPass: graph, IrGraph ), 'graph must be the instance of IrGraph.' fake_quant_dequant_ops = [] + remove_fake_quant_ops = [] + observer_out_node_names = [] + for op in graph.all_op_nodes(): + # collect observer node + if op.name() == "moving_average_abs_max_scale": + observer_out_node_names.append(op.output("Out")[0]) for op in graph.all_op_nodes(): if ( op.name() in _fake_quant_dequant_op_list or op.name() == "moving_average_abs_max_scale" ): - fake_quant_dequant_ops.append(op) + var_name = op.input("X")[0] + if var_name in observer_out_node_names: + remove_fake_quant_ops.append(op) + else: + fake_quant_dequant_ops.append(op) + + for _op in remove_fake_quant_ops: + x_node = graph._find_node_by_name(_op.inputs, _op.input("X")[0]) + out_node = graph._find_node_by_name( + _op.outputs, _op.output("Out")[0] + ) + for next_op_node in out_node.outputs: + graph.update_input_link(out_node, x_node, next_op_node) for _op in fake_quant_dequant_ops: self._replace_op(graph, _op) -- GitLab