Unverified commit 81fd2fff authored by: H handiz, committed by: GitHub

fix quant scale name (#44116)

Parent d7f4599d
......@@ -963,10 +963,10 @@ class PostTrainingQuantization(object):
         else:
             scale_dict = self._quantized_threshold
         for key, val in scale_dict.items():
-            utils.set_variable_data(self._scope, self._place, key + ".scale",
+            utils.set_variable_data(self._scope, self._place, key + "@scale",
                                     np.array([val], dtype=np.float32))
             utils.set_variable_data(self._scope, self._place,
-                                    key + ".quant_dequant.scale",
+                                    key + ".quant_dequant@scale",
                                     np.array([val], dtype=np.float32))
         if not self._onnx_format:
......
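As a side note, here is a minimal, self-contained sketch of the naming convention applied in the hunk above: each calibration threshold is written once under `<var>@scale` and once under `<var>.quant_dequant@scale`. The dict-backed scope and the `write_scale_vars` helper are hypothetical stand-ins for the real Paddle scope and `utils.set_variable_data`; only the key suffixes come from the diff.

import numpy as np

# Hypothetical stand-in for a Paddle scope: variable name -> tensor value.
fake_scope = {}

def write_scale_vars(scope, scale_dict):
    # Mirrors the loop in the hunk above, with the renamed '@scale' suffixes.
    for key, val in scale_dict.items():
        scope[key + "@scale"] = np.array([val], dtype=np.float32)
        scope[key + ".quant_dequant@scale"] = np.array([val], dtype=np.float32)

# Example thresholds (made-up variable names and values).
write_scale_vars(fake_scope, {"conv2d_0.w_0": 0.12, "relu_1.tmp_0": 5.8})
print(sorted(fake_scope.keys()))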
......@@ -906,7 +906,7 @@ class QuantizationTransformPass(object):
"""
Return the scale name of quantized variable for the input `var_name`.
"""
return "%s.scale" % (var_name)
return "%s@scale" % (var_name)
def _is_skip_quant(self, graph, op_node):
"""
......@@ -1246,8 +1246,8 @@ class QuantizationFreezePass(object):
             return var_name[:-len('.quantized')]
         if var_name.endswith('.dequantized'):
             return var_name[:-len('.dequantized')]
-        if var_name.endswith('.scale'):
-            return var_name[:-len('.scale')]
+        if var_name.endswith('@scale'):
+            return var_name[:-len('@scale')]
         else:
             return var_name
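A short sketch of how the renamed suffix round-trips between the two passes shown above, assuming only the suffixes visible in these hunks (the real `_original_var_name` may handle additional cases). `scale_name` mirrors `QuantizationTransformPass._scale_name` and `original_var_name` mirrors the stripping logic in `QuantizationFreezePass`:

def scale_name(var_name):
    # Mirrors QuantizationTransformPass._scale_name after this change.
    return "%s@scale" % (var_name)

def original_var_name(var_name):
    # Partial mirror of the suffix stripping in QuantizationFreezePass,
    # covering only the suffixes shown in the hunk above.
    if var_name.endswith('.quantized'):
        return var_name[:-len('.quantized')]
    if var_name.endswith('.dequantized'):
        return var_name[:-len('.dequantized')]
    if var_name.endswith('@scale'):
        return var_name[:-len('@scale')]
    return var_name

# Round trip: the freeze pass recovers the original name from the scale
# variable named by the transform pass.
assert original_var_name(scale_name("conv2d_0.w_0")) == "conv2d_0.w_0"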
......@@ -1705,7 +1705,7 @@ class AddQuantDequantPass(object):
             shape=var_node.shape(),
             var_dtype=var_node.dtype())
         scale_in_node = graph.create_persistable_node(
-            name="{}.quant_dequant.scale".format(var_node.name()),
+            name="{}.quant_dequant@scale".format(var_node.name()),
             var_type=core.VarDesc.VarType.LOD_TENSOR,
             shape=[1],
             var_dtype=var_node.dtype())
......@@ -1922,7 +1922,7 @@ class InsertQuantizeLinear(object):
"""
Return the scale name of quantized variable for the input `var_name`.
"""
return "%s.scale" % (var_name)
return "%s@scale" % (var_name)
def _zero_point_name(self, var_name):
"""
......
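Finally, a small consistency check on the renamed suffixes, using hypothetical helper names; only the format strings come from the diff. After this commit, the key written by `PostTrainingQuantization`, the persistable node created by `AddQuantDequantPass`, and `InsertQuantizeLinear._scale_name` all derive their names with the same `@scale` suffix:

def post_training_quant_dequant_key(var_name):
    # Key written by PostTrainingQuantization in the first hunk.
    return var_name + ".quant_dequant@scale"

def add_quant_dequant_scale_node_name(var_name):
    # Node name created by AddQuantDequantPass in the hunk above.
    return "{}.quant_dequant@scale".format(var_name)

def insert_quantize_linear_scale_name(var_name):
    # Mirrors InsertQuantizeLinear._scale_name after this change.
    return "%s@scale" % (var_name)

# Both derivations of the activation quant-dequant scale name agree, and the
# weight/activation scale name uses the same '@scale' suffix everywhere.
assert (post_training_quant_dequant_key("relu_1.tmp_0")
        == add_quant_dequant_scale_node_name("relu_1.tmp_0"))
assert insert_quantize_linear_scale_name("conv2d_0.w_0") == "conv2d_0.w_0@scale"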