From 81fd2fff617e258e777dd520a6c826039b0eb222 Mon Sep 17 00:00:00 2001
From: handiz <35895648+ZhangHandi@users.noreply.github.com>
Date: Wed, 6 Jul 2022 16:38:48 +0800
Subject: [PATCH] fix quant scale name (#44116)

---
 .../slim/quantization/post_training_quantization.py |  4 ++--
 .../contrib/slim/quantization/quantization_pass.py  | 10 +++++-----
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
index f1da3990a3..a46a0d12fd 100644
--- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
+++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
@@ -963,10 +963,10 @@ class PostTrainingQuantization(object):
         else:
             scale_dict = self._quantized_threshold
         for key, val in scale_dict.items():
-            utils.set_variable_data(self._scope, self._place, key + ".scale",
+            utils.set_variable_data(self._scope, self._place, key + "@scale",
                                     np.array([val], dtype=np.float32))
             utils.set_variable_data(self._scope, self._place,
-                                    key + ".quant_dequant.scale",
+                                    key + ".quant_dequant@scale",
                                     np.array([val], dtype=np.float32))
 
         if not self._onnx_format:
diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
index 3a316e9192..d3ce543320 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
@@ -906,7 +906,7 @@ class QuantizationTransformPass(object):
         """
         Return the scale name of quantized variable for the input `var_name`.
         """
-        return "%s.scale" % (var_name)
+        return "%s@scale" % (var_name)
 
     def _is_skip_quant(self, graph, op_node):
         """
@@ -1246,8 +1246,8 @@ class QuantizationFreezePass(object):
             return var_name[:-len('.quantized')]
         if var_name.endswith('.dequantized'):
             return var_name[:-len('.dequantized')]
-        if var_name.endswith('.scale'):
-            return var_name[:-len('.scale')]
+        if var_name.endswith('@scale'):
+            return var_name[:-len('@scale')]
         else:
             return var_name
 
@@ -1705,7 +1705,7 @@ class AddQuantDequantPass(object):
                 shape=var_node.shape(),
                 var_dtype=var_node.dtype())
             scale_in_node = graph.create_persistable_node(
-                name="{}.quant_dequant.scale".format(var_node.name()),
+                name="{}.quant_dequant@scale".format(var_node.name()),
                 var_type=core.VarDesc.VarType.LOD_TENSOR,
                 shape=[1],
                 var_dtype=var_node.dtype())
@@ -1922,7 +1922,7 @@ class InsertQuantizeLinear(object):
         """
         Return the scale name of quantized variable for the input `var_name`.
         """
-        return "%s.scale" % (var_name)
+        return "%s@scale" % (var_name)
 
     def _zero_point_name(self, var_name):
         """
--
GitLab
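
Note (not part of the patch above): a minimal, self-contained Python sketch of the naming convention this change standardizes on, where quantization scale variables use an "@scale" suffix instead of ".scale", and that suffix is stripped again when recovering the original variable name. The helper names and the example variable name are illustrative only, not actual Paddle APIs.

# Illustrative sketch, not code from the patch. It mirrors the convention
# the diff switches to: scale variables carry an "@scale" suffix.

def quantized_scale_name(var_name):
    # New convention applied by this patch: "<var>@scale" instead of "<var>.scale".
    return "%s@scale" % var_name


def original_var_name(var_name):
    # Simplified version of the suffix stripping shown in the diff context:
    # drop the quantization suffixes to recover the original variable name.
    for suffix in ('.quantized', '.dequantized', '@scale'):
        if var_name.endswith(suffix):
            return var_name[:-len(suffix)]
    return var_name


if __name__ == '__main__':
    # Hypothetical variable name, used purely for demonstration.
    assert quantized_scale_name('conv2d_0.w_0') == 'conv2d_0.w_0@scale'
    assert original_var_name('conv2d_0.w_0@scale') == 'conv2d_0.w_0'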