diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 705b0e5e69ee6d4c902b224ddad90157ed5d6e52..55e1dcacdcb628513fde876c5767876237fc6d36 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -2907,13 +2907,31 @@ class ReplaceFakeQuantDequantPass: graph, IrGraph ), 'graph must be the instance of IrGraph.' fake_quant_dequant_ops = [] + remove_fake_quant_ops = [] + observer_out_node_names = [] + for op in graph.all_op_nodes(): + # collect observer node + if op.name() == "moving_average_abs_max_scale": + observer_out_node_names.append(op.output("Out")[0]) for op in graph.all_op_nodes(): if ( op.name() in _fake_quant_dequant_op_list or op.name() == "moving_average_abs_max_scale" ): - fake_quant_dequant_ops.append(op) + var_name = op.input("X")[0] + if var_name in observer_out_node_names: + remove_fake_quant_ops.append(op) + else: + fake_quant_dequant_ops.append(op) + + for _op in remove_fake_quant_ops: + x_node = graph._find_node_by_name(_op.inputs, _op.input("X")[0]) + out_node = graph._find_node_by_name( + _op.outputs, _op.output("Out")[0] + ) + for next_op_node in out_node.outputs: + graph.update_input_link(out_node, x_node, next_op_node) for _op in fake_quant_dequant_ops: self._replace_op(graph, _op)