Unverified commit cbab0184, authored by ceci3, committed by GitHub

[Cherry pick] fix quant scale name (#44903)

* fix quant scale name (#44116)

* fix acc diff problem caused by pr #44116 (#44311)
Co-authored-by: handiz <35895648+ZhangHandi@users.noreply.github.com>
Parent 26762817
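This commit renames the quantization scale variables: the ".scale" suffix becomes "@scale", and ".quant_dequant.scale" becomes ".quant_dequant@scale". A minimal sketch of the resulting naming convention follows; the helper names are illustrative only and are not part of the Paddle API.

# Sketch of the scale-variable naming after this commit (illustrative helpers).
def scale_name(var_name):
    # before this commit: var_name + ".scale"
    return "%s@scale" % var_name

def quant_dequant_scale_name(var_name):
    # before this commit: var_name + ".quant_dequant.scale"
    return "%s.quant_dequant@scale" % var_name

assert scale_name("conv2d_0.w_0") == "conv2d_0.w_0@scale"
assert quant_dequant_scale_name("relu_0.tmp_0") == "relu_0.tmp_0.quant_dequant@scale"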
@@ -962,10 +962,10 @@ class PostTrainingQuantization(object):
         else:
             scale_dict = self._quantized_threshold
         for key, val in scale_dict.items():
-            utils.set_variable_data(self._scope, self._place, key + ".scale",
+            utils.set_variable_data(self._scope, self._place, key + "@scale",
                                     np.array([val], dtype=np.float32))
             utils.set_variable_data(self._scope, self._place,
-                                    key + ".quant_dequant.scale",
+                                    key + ".quant_dequant@scale",
                                     np.array([val], dtype=np.float32))
         if not self._onnx_format:
...
@@ -906,7 +906,7 @@ class QuantizationTransformPass(object):
         """
         Return the scale name of quantized variable for the input `var_name`.
         """
-        return "%s.scale" % (var_name)
+        return "%s@scale" % (var_name)

     def _is_skip_quant(self, graph, op_node):
         """
@@ -1246,8 +1246,8 @@ class QuantizationFreezePass(object):
             return var_name[:-len('.quantized')]
         if var_name.endswith('.dequantized'):
             return var_name[:-len('.dequantized')]
-        if var_name.endswith('.scale'):
-            return var_name[:-len('.scale')]
+        if var_name.endswith('@scale'):
+            return var_name[:-len('@scale')]
         else:
             return var_name
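The hunk above keeps the inverse mapping consistent: an original variable name is recovered by stripping the same "@scale" suffix that the transform pass now appends. A small self-contained sketch of that round trip, covering only the suffixes visible in this diff (the helper name is illustrative, not Paddle API):

# Recover the original variable name by stripping known quantization suffixes,
# mirroring the logic in the hunk above (illustrative helper only).
def original_var_name(var_name):
    for suffix in ('.quantized', '.dequantized', '@scale'):
        if var_name.endswith(suffix):
            return var_name[:-len(suffix)]
    return var_name

assert original_var_name("conv2d_0.w_0@scale") == "conv2d_0.w_0"
assert original_var_name("fc_0.tmp_1.quantized") == "fc_0.tmp_1"
assert original_var_name("fc_0.tmp_1") == "fc_0.tmp_1"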
@@ -1440,11 +1440,18 @@ class OutScaleForTrainingPass(object):
                         [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]:
                     continue
-                scale_node = graph.create_persistable_node(
-                    name=self._scale_name(in_node.name()),
-                    var_type=core.VarDesc.VarType.LOD_TENSOR,
-                    shape=[1],
-                    var_dtype=in_node.dtype())
+                try:
+                    graph._find_node_by_name(
+                        graph.all_var_nodes(),
+                        self._scale_name(in_node.name()))
+                    continue
+                except:
+                    scale_node = graph.create_persistable_node(
+                        name=self._scale_name(in_node.name()),
+                        var_type=core.VarDesc.VarType.LOD_TENSOR,
+                        shape=[1],
+                        var_dtype=in_node.dtype())
                 data_type = 'float64' if in_node.dtype() \
                     == core.VarDesc.VarType.FP64 else 'float32'
                 _init_var_node(scale_node, np.ones([1], dtype=data_type),
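The OutScaleForTrainingPass change above is the follow-up fix from #44311: instead of unconditionally creating a new out-scale node, the pass first looks it up and skips creation when it already exists. A minimal stand-alone sketch of the same get-or-create guard, using a plain dict in place of Paddle's IR graph (all names illustrative):

# Get-or-create guard analogous to the try/except in the hunk above,
# modeled with a dict instead of the IR graph (illustrative only).
def ensure_scale_node(var_nodes, name, init=1.0):
    if name in var_nodes:
        return var_nodes[name]  # already created: do not recreate or reinitialize
    var_nodes[name] = init      # create and initialize exactly once
    return var_nodes[name]

nodes = {}
ensure_scale_node(nodes, "relu_0.tmp_0@scale")
ensure_scale_node(nodes, "relu_0.tmp_0@scale")  # second call is a no-op
assert len(nodes) == 1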
@@ -1705,7 +1712,7 @@ class AddQuantDequantPass(object):
                 shape=var_node.shape(),
                 var_dtype=var_node.dtype())
             scale_in_node = graph.create_persistable_node(
-                name="{}.quant_dequant.scale".format(var_node.name()),
+                name="{}.quant_dequant@scale".format(var_node.name()),
                 var_type=core.VarDesc.VarType.LOD_TENSOR,
                 shape=[1],
                 var_dtype=var_node.dtype())
@@ -1954,7 +1961,7 @@ class InsertQuantizeLinear(object):
         """
         Return the scale name of quantized variable for the input `var_name`.
         """
-        return "%s.scale" % (var_name)
+        return "%s@scale" % (var_name)

     def _zero_point_name(self, var_name):
         """
...