未验证 提交 76c73226 编写于 作者: G Guanghua Yu 提交者: GitHub

fix fetch op rename_input bug in QAT export model (#38012)

上级 2567dfa4
......@@ -484,7 +484,7 @@ class ImperativeQuantizeOutputs(object):
model_filename=model_filename,
params_filename=params_filename))
self._gather_scales(infer_program, scope)
self._gather_scales(infer_program, scope, fetch_targets)
self._set_skip_quant_attr(infer_program)
......@@ -520,10 +520,10 @@ class ImperativeQuantizeOutputs(object):
return flag
def _gather_scales(self, program, scope):
def _gather_scales(self, program, scope, fetch_targets):
"""
Get all scales from fake ops, save them into the corresponding ops
and delete all moving_average_abs_max_scale ops.
and delete all moving_average_abs_max_scale ops.
"""
def _gather_input_scale():
......@@ -580,6 +580,11 @@ class ImperativeQuantizeOutputs(object):
for next_op in next_ops:
next_op._rename_input(out_var_name, in_var_name)
# If next_op is `fetch` and out_var_name in fetch_targets,
# fetch_targets must update to in_var_name when rename input.
for i in range(len(fetch_targets)):
if fetch_targets[i].name == out_var_name:
fetch_targets[i] = block.var(in_var_name)
_gather_input_scale()
_gather_output_scale()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册