From cbab018413688067025f820727e5a87bfacd9750 Mon Sep 17 00:00:00 2001
From: ceci3
Date: Wed, 10 Aug 2022 16:16:33 +0800
Subject: [PATCH] [Cherry pick] fix quant scale name (#44903)

* fix quant scale name (#44116)

* fix acc diff problem caused by pr #44116 (#44311)

Co-authored-by: handiz <35895648+ZhangHandi@users.noreply.github.com>
---
 .../post_training_quantization.py             |  4 +--
 .../slim/quantization/quantization_pass.py    | 27 ++++++++++++-------
 2 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
index bfd76a44b4d..3d60e808951 100644
--- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
+++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
@@ -962,10 +962,10 @@ class PostTrainingQuantization(object):
         else:
             scale_dict = self._quantized_threshold
         for key, val in scale_dict.items():
-            utils.set_variable_data(self._scope, self._place, key + ".scale",
+            utils.set_variable_data(self._scope, self._place, key + "@scale",
                                     np.array([val], dtype=np.float32))
             utils.set_variable_data(self._scope, self._place,
-                                    key + ".quant_dequant.scale",
+                                    key + ".quant_dequant@scale",
                                     np.array([val], dtype=np.float32))
 
         if not self._onnx_format:
diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
index 5abb1d382b3..8213b779a6a 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
@@ -906,7 +906,7 @@ class QuantizationTransformPass(object):
         """
         Return the scale name of quantized variable for the input `var_name`.
         """
-        return "%s.scale" % (var_name)
+        return "%s@scale" % (var_name)
 
     def _is_skip_quant(self, graph, op_node):
         """
@@ -1246,8 +1246,8 @@ class QuantizationFreezePass(object):
             return var_name[:-len('.quantized')]
         if var_name.endswith('.dequantized'):
             return var_name[:-len('.dequantized')]
-        if var_name.endswith('.scale'):
-            return var_name[:-len('.scale')]
+        if var_name.endswith('@scale'):
+            return var_name[:-len('@scale')]
         else:
             return var_name
 
@@ -1440,11 +1440,18 @@ class OutScaleForTrainingPass(object):
                     [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]:
                 continue
 
-            scale_node = graph.create_persistable_node(
-                name=self._scale_name(in_node.name()),
-                var_type=core.VarDesc.VarType.LOD_TENSOR,
-                shape=[1],
-                var_dtype=in_node.dtype())
+            try:
+                graph._find_node_by_name(
+                    graph.all_var_nodes(),
+                    self._scale_name(in_node.name()))
+                continue
+            except:
+                scale_node = graph.create_persistable_node(
+                    name=self._scale_name(in_node.name()),
+                    var_type=core.VarDesc.VarType.LOD_TENSOR,
+                    shape=[1],
+                    var_dtype=in_node.dtype())
+
             data_type = 'float64' if in_node.dtype() \
                 == core.VarDesc.VarType.FP64 else 'float32'
             _init_var_node(scale_node, np.ones([1], dtype=data_type),
@@ -1705,7 +1712,7 @@ class AddQuantDequantPass(object):
                 shape=var_node.shape(),
                 var_dtype=var_node.dtype())
             scale_in_node = graph.create_persistable_node(
-                name="{}.quant_dequant.scale".format(var_node.name()),
+                name="{}.quant_dequant@scale".format(var_node.name()),
                 var_type=core.VarDesc.VarType.LOD_TENSOR,
                 shape=[1],
                 var_dtype=var_node.dtype())
@@ -1954,7 +1961,7 @@ class InsertQuantizeLinear(object):
         """
         Return the scale name of quantized variable for the input `var_name`.
""" - return "%s.scale" % (var_name) + return "%s@scale" % (var_name) def _zero_point_name(self, var_name): """ -- GitLab