Unverified commit 8fc9a19f, authored by Guanghua Yu, committed by GitHub

fix quantization int8 weight save bug (#51500)

Parent: c2b24166
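What changed, in brief: `QuantWeightPass` used to track quantized weights in a dict mapping each weight name to a freshly created `<name>.quantized` persistable node, then conditionally relinked downstream ops to that duplicate node. The pass now records names in a plain `set`, writes the quantized int8 values back into the original weight tensor through a new `_restore_var` helper, and always relinks downstream ops to the original `x_node`. A weight shared by several ops is therefore quantized exactly once and no duplicate `.quantized` variable is created, which appears to be the int8 weight save bug the title refers to.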
@@ -3150,7 +3150,7 @@ class QuantWeightPass:
         self._save_int_weight = save_int_weight
         assert self._scope is not None, "scope must not be None."
         assert self._place is not None, "place must not be None."
-        self._quantized_ops = {}
+        self._quantized_ops = set()

     def apply(self, graph):
         assert isinstance(
@@ -3189,6 +3189,7 @@ class QuantWeightPass:
                 quant_axis = _op.op().attr("quant_axis")
                 bits_length = _op.op().attr("bit_length")
                 if x_node.name() not in self._quantized_ops:
+                    self._quantized_ops.add(x_node.name())
                     quantized_param_v = utils.quant_tensor(
                         param_v.copy(),
                         scale_v,
@@ -3211,30 +3212,10 @@ class QuantWeightPass:
                         quantized_param_v = quantized_param_v.astype(
                             save_weight_dtype
                         )
-                    quant_weight_node = graph.create_persistable_node(
-                        name=self._quantized_var_name(x_node.name()),
-                        var_type=core.VarDesc.VarType.LOD_TENSOR,
-                        shape=x_node.shape(),
-                        var_dtype=core.VarDesc.VarType.INT8,
-                    )
-                    _init_var_node(
-                        quant_weight_node,
-                        quantized_param_v,
-                        self._scope,
-                        self._place,
-                    )
-                    self._quantized_ops[x_node.name()] = quant_weight_node
+                    self._restore_var(x_node.name(), quantized_param_v)

                 for next_op_node in out_node.outputs:
-                    if (
-                        self._quantized_ops[x_node.name()].node
-                        in graph.graph.nodes()
-                    ):
-                        graph.update_input_link(
-                            out_node,
-                            self._quantized_ops[x_node.name()],
-                            next_op_node,
-                        )
+                    graph.update_input_link(out_node, x_node, next_op_node)
                 graph.safe_remove_nodes(_op)
         self._remove_unused_var_nodes(graph)
@@ -3260,11 +3241,9 @@ class QuantWeightPass:
     def _load_var(self, name):
         return np.array(self._scope.find_var(name).get_tensor())

-    def _quantized_var_name(self, var_name):
-        """
-        Return quantized variable name for the input `var_name`.
-        """
-        return "%s.quantized" % (var_name)
+    def _restore_var(self, name, array):
+        tensor = self._scope.find_var(name).get_tensor()
+        tensor.set(array, self._place)

 class AddQuantDequantForInferencePass:
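To make the new bookkeeping concrete, here is a minimal, runnable NumPy-only sketch of the pattern the fix adopts. All names are hypothetical: scope is a plain dict standing in for Paddle's variable scope, and quantize_once and restore_var are illustrations, not Paddle APIs; the quantization formula is a generic symmetric scheme, not necessarily identical to utils.quant_tensor.

import numpy as np

# A plain dict stands in for Paddle's Scope: variable name -> weight array.
scope = {"conv1.w": np.random.randn(8, 3, 3, 3).astype("float32")}
quantized_ops = set()  # weight names already handled, mirroring self._quantized_ops

def restore_var(name, array):
    # Overwrite the stored weight in place, as the new _restore_var does;
    # no separate "<name>.quantized" variable is ever created.
    scope[name] = array

def quantize_once(name, scale, bits=8):
    # Generic symmetric per-tensor quantization, applied at most once per weight.
    if name in quantized_ops:
        return  # weight shared by several ops: already quantized, skip
    quantized_ops.add(name)
    bnt = (1 << (bits - 1)) - 1  # 127 for int8
    q = np.round(scope[name] / scale * bnt)
    restore_var(name, np.clip(q, -bnt - 1, bnt).astype("int8"))

s = float(np.abs(scope["conv1.w"]).max())
quantize_once("conv1.w", s)
quantize_once("conv1.w", s)    # no-op: the set prevents double quantization
print(scope["conv1.w"].dtype)  # int8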