未验证 提交 8b86aad9 编写于 作者: G Guanghua Yu 提交者: GitHub

fix fetch op rename_input bug in QAT export model (#38025)

上级 a4c0c71c
...@@ -484,7 +484,7 @@ class ImperativeQuantizeOutputs(object): ...@@ -484,7 +484,7 @@ class ImperativeQuantizeOutputs(object):
model_filename=model_filename, model_filename=model_filename,
params_filename=params_filename)) params_filename=params_filename))
self._gather_scales(infer_program, scope) self._gather_scales(infer_program, scope, fetch_targets)
self._set_skip_quant_attr(infer_program) self._set_skip_quant_attr(infer_program)
...@@ -520,10 +520,10 @@ class ImperativeQuantizeOutputs(object): ...@@ -520,10 +520,10 @@ class ImperativeQuantizeOutputs(object):
return flag return flag
def _gather_scales(self, program, scope): def _gather_scales(self, program, scope, fetch_targets):
""" """
Get all scales from fake ops, save them into the corresponding ops Get all scales from fake ops, save them into the corresponding ops
and delete all moving_average_abs_max_scale ops. and delete all moving_average_abs_max_scale ops.
""" """
def _gather_input_scale(): def _gather_input_scale():
...@@ -580,6 +580,11 @@ class ImperativeQuantizeOutputs(object): ...@@ -580,6 +580,11 @@ class ImperativeQuantizeOutputs(object):
for next_op in next_ops: for next_op in next_ops:
next_op._rename_input(out_var_name, in_var_name) next_op._rename_input(out_var_name, in_var_name)
# If next_op is `fetch` and out_var_name in fetch_targets,
# fetch_targets must update to in_var_name when rename input.
for i in range(len(fetch_targets)):
if fetch_targets[i].name == out_var_name:
fetch_targets[i] = block.var(in_var_name)
_gather_input_scale() _gather_input_scale()
_gather_output_scale() _gather_output_scale()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册