diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index 9a7d870c8d2ddc3f70e0626506a64d187ec2072b..3a4b7721d55ffd82dc1cc13a44aef7b94e740bec 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -32,6 +32,7 @@ from paddle.fluid.io import load_inference_model, save_inference_model from ..quantization_pass import ReplaceFakeQuantDequantPass, QuantWeightPass from paddle.fluid.log_helper import get_logger from .. import quantization_pass +from ..utils import move_persistable_var_to_global_block from . import utils from . import fuse_utils @@ -550,6 +551,8 @@ class ImperativeQuantizeOutputs(object): clip_extra = True + move_persistable_var_to_global_block(infer_program) + save_inference_model(dirname=dirname, feeded_var_names=feed_target_names, target_vars=fetch_targets, diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index 9d70030847afce57a61848e321154cb5082d3f99..97cb732d5e6cebd0a927caef16058895c45febde 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -449,21 +449,7 @@ class PostTrainingQuantization(object): self._collect_dynamic_quantize_op_threshold( self._dynamic_quantize_op_type) - # Move sub blocks persistable var to global block - global_block = self._program.global_block() - for _op in global_block.ops: - if _op.type == "while": - _block_id = _op.attr("sub_block").id - _block = self._program.block(_block_id) - persistables = [] - for _name, _var in _block.vars.items(): - if _var.persistable: - global_block._clone_variable(_var) - persistables.append(_name) - for _name in persistables: - _block._remove_var(_name) - persistables.extend(_op.input('X')) - _op.desc.set_input("X", persistables) + utils.move_persistable_var_to_global_block(self._program) if not self._return_graph: return self._program diff --git a/python/paddle/fluid/contrib/slim/quantization/utils.py b/python/paddle/fluid/contrib/slim/quantization/utils.py index b04446b16aa755f2cd96499c6b1527ff98da9340..fe4446939a5546ad1b37b54c1fd2d8d5a086ce69 100644 --- a/python/paddle/fluid/contrib/slim/quantization/utils.py +++ b/python/paddle/fluid/contrib/slim/quantization/utils.py @@ -433,6 +433,24 @@ def calculate_quant_cos_error(orig_tensor, qdq_tensor): return cos_sim +def move_persistable_var_to_global_block(program): + # Move sub blocks persistable var to global block + global_block = program.global_block() + for _op in global_block.ops: + if _op.type == "while": + _block_id = _op.attr("sub_block").id + _block = program.block(_block_id) + persistables = [] + for _name, _var in _block.vars.items(): + if _var.persistable: + global_block._clone_variable(_var) + persistables.append(_name) + for _name in persistables: + _block._remove_var(_name) + persistables.extend(_op.input('X')) + _op.desc.set_input("X", persistables) + + def l2_loss(gt, pred): return ((gt - pred)**2).mean()