未验证 提交 6f8243f9 编写于 作者: G Guanghua Yu 提交者: GitHub

fix QAT export bug in while OP (#38122)

上级 29e540af
......@@ -22,6 +22,7 @@ import warnings
import paddle
import paddle.nn.quant.quant_layers as quant_layers
from paddle.fluid import dygraph, core, framework, unique_name
from paddle.fluid.framework import IrGraph
from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Constant
......@@ -486,6 +487,15 @@ class ImperativeQuantizeOutputs(object):
self._gather_scales(infer_program, scope, fetch_targets)
# Remove `moving_average_abs_max_scale` node in sub graphs.
graph = IrGraph(core.Graph(infer_program.desc), for_test=False)
for sub_graph in graph.all_sub_graphs():
for _op in sub_graph.all_op_nodes():
if _op.name() == "moving_average_abs_max_scale":
sub_graph.safe_remove_nodes(_op)
sub_graph.resolve_hazard()
infer_program = graph.to_program()
self._set_skip_quant_attr(infer_program)
save_inference_model(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册