diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index b9f10abcd86fb0a8d2a7f7082b7659e0b7e53997..423fb0fcd52f372bc6016d7138722eb17149af53 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -32,6 +32,7 @@ from paddle.fluid.io import load_inference_model, save_inference_model from ..quantization_pass import ReplaceFakeQuantDequantPass, QuantWeightPass from paddle.fluid.log_helper import get_logger from .. import quantization_pass +from ..utils import move_persistable_var_to_global_block from . import utils from . import fuse_utils @@ -552,6 +553,8 @@ class ImperativeQuantizeOutputs(object): clip_extra = True + move_persistable_var_to_global_block(infer_program) + save_inference_model(dirname=dirname, feeded_var_names=feed_target_names, target_vars=fetch_targets, diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index 4d538cc75b3af77e167ae3c33f9bf75597ab2625..5c0fff9abe4e4abfa704a3acd2224dc6b7e22e97 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -449,21 +449,7 @@ class PostTrainingQuantization(object): self._collect_dynamic_quantize_op_threshold( self._dynamic_quantize_op_type) - # Move sub blocks persistable var to global block - global_block = self._program.global_block() - for _op in global_block.ops: - if _op.type == "while": - _block_id = _op.attr("sub_block").id - _block = self._program.block(_block_id) - persistables = [] - for _name, _var in _block.vars.items(): - if _var.persistable: - global_block._clone_variable(_var) - persistables.append(_name) - for _name in persistables: - _block._remove_var(_name) - persistables.extend(_op.input('X')) - _op.desc.set_input("X", persistables) + utils.move_persistable_var_to_global_block(self._program) if not self._return_graph: return self._program diff --git a/python/paddle/fluid/contrib/slim/quantization/utils.py b/python/paddle/fluid/contrib/slim/quantization/utils.py index 2af400ec4f7acfe7bfc6e48fa35db539b8bef2d2..158f7e07a0d91a8c856fb386ddced4155f1fadc4 100644 --- a/python/paddle/fluid/contrib/slim/quantization/utils.py +++ b/python/paddle/fluid/contrib/slim/quantization/utils.py @@ -435,6 +435,24 @@ def calculate_quant_cos_error(orig_tensor, qdq_tensor): return cos_sim +def move_persistable_var_to_global_block(program): + # Move sub blocks persistable var to global block + global_block = program.global_block() + for _op in global_block.ops: + if _op.type == "while": + _block_id = _op.attr("sub_block").id + _block = program.block(_block_id) + persistables = [] + for _name, _var in _block.vars.items(): + if _var.persistable: + global_block._clone_variable(_var) + persistables.append(_name) + for _name in persistables: + _block._remove_var(_name) + persistables.extend(_op.input('X')) + _op.desc.set_input("X", persistables) + + def l2_loss(gt, pred): return ((gt - pred)**2).mean()