From 2995f742e8ab8f17499a857e13861740ccc815a1 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Tue, 22 Nov 2022 11:08:58 +0800 Subject: [PATCH] fix error of QuantizationTransformPassV2 when has condition block (#48190) * fix error of QuantizationTransformPassV2 when has condition block * fix error --- .../fluid/contrib/slim/quantization/quantization_pass.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index f0caabd6f4..8902b40aa6 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -2481,11 +2481,6 @@ class QuantizationTransformPassV2(QuantizationTransformPass): self.create_var_map = {} self.create_op_map = {} - # marked the variable which has been dequantized. - self.dequantized_vars = collections.OrderedDict() - self.persistable_vars = [] - self.processed_vars = [] - def _quant_preprocess(self, op_node): user_skipped = False if isinstance(self._skip_pattern, list): @@ -2627,6 +2622,10 @@ class QuantizationTransformPassV2(QuantizationTransformPass): ), 'graph must be the instance of IrGraph.' if self._is_test is None: self._is_test = graph.is_test() + # marked the variable which has been dequantized. + self.dequantized_vars = collections.OrderedDict() + self.persistable_vars = [] + self.processed_vars = [] self.persistable_vars = [ p.name() for p in graph.all_persistable_nodes() -- GitLab