diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index 125d9fa88d4aedf0e4418c4a109a023c9d5a0e87..5d29dc522b3ef6e008a769b3c899672fe3aa464b 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -484,7 +484,7 @@ class ImperativeQuantizeOutputs(object): model_filename=model_filename, params_filename=params_filename)) - self._gather_scales(infer_program, scope) + self._gather_scales(infer_program, scope, fetch_targets) self._set_skip_quant_attr(infer_program) @@ -520,10 +520,10 @@ class ImperativeQuantizeOutputs(object): return flag - def _gather_scales(self, program, scope): + def _gather_scales(self, program, scope, fetch_targets): """ Get all scales from fake ops, save them into the corresponding ops - and delete all moving_average_abs_max_scale ops. + and delete all moving_average_abs_max_scale ops. """ def _gather_input_scale(): @@ -580,6 +580,11 @@ class ImperativeQuantizeOutputs(object): for next_op in next_ops: next_op._rename_input(out_var_name, in_var_name) + # If next_op is `fetch` and out_var_name in fetch_targets, + # fetch_targets must update to in_var_name when rename input. + for i in range(len(fetch_targets)): + if fetch_targets[i].name == out_var_name: + fetch_targets[i] = block.var(in_var_name) _gather_input_scale() _gather_output_scale()