未验证 提交 8fc9a19f 编写于 作者: G Guanghua Yu 提交者: GitHub

fix quantization int8 weight save bug (#51500)

上级 c2b24166
...@@ -3150,7 +3150,7 @@ class QuantWeightPass: ...@@ -3150,7 +3150,7 @@ class QuantWeightPass:
self._save_int_weight = save_int_weight self._save_int_weight = save_int_weight
assert self._scope is not None, "scope must not be None." assert self._scope is not None, "scope must not be None."
assert self._place is not None, "place must not be None." assert self._place is not None, "place must not be None."
self._quantized_ops = {} self._quantized_ops = set()
def apply(self, graph): def apply(self, graph):
assert isinstance( assert isinstance(
...@@ -3189,6 +3189,7 @@ class QuantWeightPass: ...@@ -3189,6 +3189,7 @@ class QuantWeightPass:
quant_axis = _op.op().attr("quant_axis") quant_axis = _op.op().attr("quant_axis")
bits_length = _op.op().attr("bit_length") bits_length = _op.op().attr("bit_length")
if x_node.name() not in self._quantized_ops: if x_node.name() not in self._quantized_ops:
self._quantized_ops.add(x_node.name())
quantized_param_v = utils.quant_tensor( quantized_param_v = utils.quant_tensor(
param_v.copy(), param_v.copy(),
scale_v, scale_v,
...@@ -3211,30 +3212,10 @@ class QuantWeightPass: ...@@ -3211,30 +3212,10 @@ class QuantWeightPass:
quantized_param_v = quantized_param_v.astype( quantized_param_v = quantized_param_v.astype(
save_weight_dtype save_weight_dtype
) )
quant_weight_node = graph.create_persistable_node( self._restore_var(x_node.name(), quantized_param_v)
name=self._quantized_var_name(x_node.name()),
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=x_node.shape(),
var_dtype=core.VarDesc.VarType.INT8,
)
_init_var_node(
quant_weight_node,
quantized_param_v,
self._scope,
self._place,
)
self._quantized_ops[x_node.name()] = quant_weight_node
for next_op_node in out_node.outputs: for next_op_node in out_node.outputs:
if ( graph.update_input_link(out_node, x_node, next_op_node)
self._quantized_ops[x_node.name()].node
in graph.graph.nodes()
):
graph.update_input_link(
out_node,
self._quantized_ops[x_node.name()],
next_op_node,
)
graph.safe_remove_nodes(_op) graph.safe_remove_nodes(_op)
self._remove_unused_var_nodes(graph) self._remove_unused_var_nodes(graph)
...@@ -3260,11 +3241,9 @@ class QuantWeightPass: ...@@ -3260,11 +3241,9 @@ class QuantWeightPass:
def _load_var(self, name): def _load_var(self, name):
return np.array(self._scope.find_var(name).get_tensor()) return np.array(self._scope.find_var(name).get_tensor())
def _quantized_var_name(self, var_name): def _restore_var(self, name, array):
""" tensor = self._scope.find_var(name).get_tensor()
Return quantized variable name for the input `var_name`. tensor.set(array, self._place)
"""
return "%s.quantized" % (var_name)
class AddQuantDequantForInferencePass: class AddQuantDequantForInferencePass:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册