From 10881b6ee800c2442b4d24df3572527eba171035 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Thu, 20 Oct 2022 14:46:32 +0800 Subject: [PATCH] fix problem of persistable var saving in QAT (#47178) --- .../slim/quantization/imperative/qat.py | 3 +++ .../quantization/post_training_quantization.py | 16 +--------------- .../fluid/contrib/slim/quantization/utils.py | 18 ++++++++++++++++++ 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index b9f10abcd8..423fb0fcd5 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -32,6 +32,7 @@ from paddle.fluid.io import load_inference_model, save_inference_model from ..quantization_pass import ReplaceFakeQuantDequantPass, QuantWeightPass from paddle.fluid.log_helper import get_logger from .. import quantization_pass +from ..utils import move_persistable_var_to_global_block from . import utils from . import fuse_utils @@ -552,6 +553,8 @@ class ImperativeQuantizeOutputs(object): clip_extra = True + move_persistable_var_to_global_block(infer_program) + save_inference_model(dirname=dirname, feeded_var_names=feed_target_names, target_vars=fetch_targets, diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index 4d538cc75b..5c0fff9abe 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -449,21 +449,7 @@ class PostTrainingQuantization(object): self._collect_dynamic_quantize_op_threshold( self._dynamic_quantize_op_type) - # Move sub blocks persistable var to global block - global_block = self._program.global_block() - for _op in global_block.ops: - if _op.type == "while": - _block_id = _op.attr("sub_block").id - _block = self._program.block(_block_id) - persistables = [] - for _name, _var in _block.vars.items(): - if _var.persistable: - global_block._clone_variable(_var) - persistables.append(_name) - for _name in persistables: - _block._remove_var(_name) - persistables.extend(_op.input('X')) - _op.desc.set_input("X", persistables) + utils.move_persistable_var_to_global_block(self._program) if not self._return_graph: return self._program diff --git a/python/paddle/fluid/contrib/slim/quantization/utils.py b/python/paddle/fluid/contrib/slim/quantization/utils.py index 2af400ec4f..158f7e07a0 100644 --- a/python/paddle/fluid/contrib/slim/quantization/utils.py +++ b/python/paddle/fluid/contrib/slim/quantization/utils.py @@ -435,6 +435,24 @@ def calculate_quant_cos_error(orig_tensor, qdq_tensor): return cos_sim +def move_persistable_var_to_global_block(program): + # Move sub blocks persistable var to global block + global_block = program.global_block() + for _op in global_block.ops: + if _op.type == "while": + _block_id = _op.attr("sub_block").id + _block = program.block(_block_id) + persistables = [] + for _name, _var in _block.vars.items(): + if _var.persistable: + global_block._clone_variable(_var) + persistables.append(_name) + for _name in persistables: + _block._remove_var(_name) + persistables.extend(_op.input('X')) + _op.desc.set_input("X", persistables) + + def l2_loss(gt, pred): return ((gt - pred)**2).mean() -- GitLab