From 6f8243f956ab8fd33185faa46ff1d798c11afe95 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Thu, 23 Dec 2021 17:01:42 +0800 Subject: [PATCH] fix QAT export bug in while OP (#38122) --- .../fluid/contrib/slim/quantization/imperative/qat.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index 5d29dc522b3..24caf147954 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -22,6 +22,7 @@ import warnings import paddle import paddle.nn.quant.quant_layers as quant_layers from paddle.fluid import dygraph, core, framework, unique_name +from paddle.fluid.framework import IrGraph from paddle.fluid.executor import Executor, global_scope from paddle.fluid.param_attr import ParamAttr from paddle.fluid.initializer import Constant @@ -486,6 +487,15 @@ class ImperativeQuantizeOutputs(object): self._gather_scales(infer_program, scope, fetch_targets) + # Remove `moving_average_abs_max_scale` node in sub graphs. + graph = IrGraph(core.Graph(infer_program.desc), for_test=False) + for sub_graph in graph.all_sub_graphs(): + for _op in sub_graph.all_op_nodes(): + if _op.name() == "moving_average_abs_max_scale": + sub_graph.safe_remove_nodes(_op) + sub_graph.resolve_hazard() + infer_program = graph.to_program() + self._set_skip_quant_attr(infer_program) save_inference_model( -- GitLab