未验证 提交 66a1df3c 编写于 作者: C Chang Xu 提交者: GitHub

Avoid quantizing the same weight repeatedly (#47587)

上级 b3d52d46
...@@ -1119,6 +1119,7 @@ class QuantizationFreezePass(object): ...@@ -1119,6 +1119,7 @@ class QuantizationFreezePass(object):
self._op_input_rename_map = collections.OrderedDict() self._op_input_rename_map = collections.OrderedDict()
self._op_output_rename_map = collections.OrderedDict() self._op_output_rename_map = collections.OrderedDict()
self._quant_var_scale_map = collections.OrderedDict() self._quant_var_scale_map = collections.OrderedDict()
self._quantized_ops = set()
def apply(self, graph): def apply(self, graph):
""" """
...@@ -1173,24 +1174,27 @@ class QuantizationFreezePass(object): ...@@ -1173,24 +1174,27 @@ class QuantizationFreezePass(object):
quant_axis = 1 quant_axis = 1
else: else:
quant_axis = 0 quant_axis = 0
quantized_param_v = utils.quant_tensor( if input_arg_name not in self._quantized_ops:
param_v.copy(), self._quantized_ops.add(input_arg_name)
scale_v, quantized_param_v = utils.quant_tensor(
quant_axis, param_v.copy(),
self._weight_bits,
)
quantized_param_v = np.round(quantized_param_v)
# Weight bias correction
if self._bias_correction == True:
quantized_param_v = utils.bias_correction_w(
param_v,
quantized_param_v,
scale_v, scale_v,
quant_axis, quant_axis,
weight_bits=self._weight_bits, self._weight_bits,
) )
quantized_param_v = np.round(quantized_param_v) quantized_param_v = np.round(quantized_param_v)
self._restore_var(input_arg_name, quantized_param_v) # Weight bias correction
if self._bias_correction == True:
quantized_param_v = utils.bias_correction_w(
param_v,
quantized_param_v,
scale_v,
quant_axis,
weight_bits=self._weight_bits,
)
quantized_param_v = np.round(quantized_param_v)
self._restore_var(input_arg_name, quantized_param_v)
self._remove_fake_quant_and_dequant_op(graph, op_node) self._remove_fake_quant_and_dequant_op(graph, op_node)
# Remove all fake dequant op # Remove all fake dequant op
...@@ -3029,6 +3033,7 @@ class QuantWeightPass(object): ...@@ -3029,6 +3033,7 @@ class QuantWeightPass(object):
self._save_int_weight = save_int_weight self._save_int_weight = save_int_weight
assert self._scope is not None, "scope must not be None." assert self._scope is not None, "scope must not be None."
assert self._place is not None, "place must not be None." assert self._place is not None, "place must not be None."
self._quantized_ops = set()
def apply(self, graph): def apply(self, graph):
assert isinstance( assert isinstance(
...@@ -3066,29 +3071,31 @@ class QuantWeightPass(object): ...@@ -3066,29 +3071,31 @@ class QuantWeightPass(object):
param_v = self._load_var(x_node.name()) param_v = self._load_var(x_node.name())
quant_axis = _op.op().attr("quant_axis") quant_axis = _op.op().attr("quant_axis")
bits_length = _op.op().attr("bit_length") bits_length = _op.op().attr("bit_length")
quantized_param_v = utils.quant_tensor( if x_node.name() not in self._quantized_ops:
param_v.copy(), self._quantized_ops.add(x_node.name())
scale_v, quantized_param_v = utils.quant_tensor(
quant_axis, param_v.copy(),
bits_length,
onnx_format=True,
)
if self._bias_correction == True:
quantized_param_v = utils.bias_correction_w(
param_v,
quantized_param_v,
scale_v, scale_v,
quant_axis, quant_axis,
weight_bits=bits_length, bits_length,
onnx_format=True,
) )
if self._save_int_weight: if self._bias_correction == True:
# cast weight type to int quantized_param_v = utils.bias_correction_w(
if self._quant_bits == 8: param_v,
save_weight_dtype = np.int8 quantized_param_v,
quantized_param_v = quantized_param_v.astype( scale_v,
save_weight_dtype quant_axis,
) weight_bits=bits_length,
self._restore_var(x_node.name(), quantized_param_v) )
if self._save_int_weight:
# cast weight type to int
if self._quant_bits == 8:
save_weight_dtype = np.int8
quantized_param_v = quantized_param_v.astype(
save_weight_dtype
)
self._restore_var(x_node.name(), quantized_param_v)
for next_op_node in out_node.outputs: for next_op_node in out_node.outputs:
graph.update_input_link(out_node, x_node, next_op_node) graph.update_input_link(out_node, x_node, next_op_node)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册