diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index 5d29dc522b3ef6e008a769b3c899672fe3aa464b..24caf1479543e3e1870c8e8e9283233d94dfb3bc 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -22,6 +22,7 @@ import warnings import paddle import paddle.nn.quant.quant_layers as quant_layers from paddle.fluid import dygraph, core, framework, unique_name +from paddle.fluid.framework import IrGraph from paddle.fluid.executor import Executor, global_scope from paddle.fluid.param_attr import ParamAttr from paddle.fluid.initializer import Constant @@ -486,6 +487,15 @@ class ImperativeQuantizeOutputs(object): self._gather_scales(infer_program, scope, fetch_targets) + # Remove `moving_average_abs_max_scale` node in sub graphs. + graph = IrGraph(core.Graph(infer_program.desc), for_test=False) + for sub_graph in graph.all_sub_graphs(): + for _op in sub_graph.all_op_nodes(): + if _op.name() == "moving_average_abs_max_scale": + sub_graph.safe_remove_nodes(_op) + sub_graph.resolve_hazard() + infer_program = graph.to_program() + self._set_skip_quant_attr(infer_program) save_inference_model(