From 8b86aad9cd2c99adf556513b1672ad10ace314d3 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Fri, 10 Dec 2021 13:56:21 +0800 Subject: [PATCH] fix fetch op rename_input bug in QAT export model (#38025) --- .../fluid/contrib/slim/quantization/imperative/qat.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index 125d9fa88d4..5d29dc522b3 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -484,7 +484,7 @@ class ImperativeQuantizeOutputs(object): model_filename=model_filename, params_filename=params_filename)) - self._gather_scales(infer_program, scope) + self._gather_scales(infer_program, scope, fetch_targets) self._set_skip_quant_attr(infer_program) @@ -520,10 +520,10 @@ class ImperativeQuantizeOutputs(object): return flag - def _gather_scales(self, program, scope): + def _gather_scales(self, program, scope, fetch_targets): """ Get all scales from fake ops, save them into the corresponding ops - and delete all moving_average_abs_max_scale ops. + and delete all moving_average_abs_max_scale ops. """ def _gather_input_scale(): @@ -580,6 +580,11 @@ class ImperativeQuantizeOutputs(object): for next_op in next_ops: next_op._rename_input(out_var_name, in_var_name) + # If next_op is `fetch` and out_var_name in fetch_targets, + # fetch_targets must update to in_var_name when rename input. + for i in range(len(fetch_targets)): + if fetch_targets[i].name == out_var_name: + fetch_targets[i] = block.var(in_var_name) _gather_input_scale() _gather_output_scale() -- GitLab