Unverified commit 66a1df3c, authored by Chang Xu, committed by GitHub

Avoid Quant Weight Repeatedly (#47587)

Parent b3d52d46
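What the change does: both QuantizationFreezePass and QuantWeightPass rewrite weight tensors in place while walking the graph, so a weight variable shared by several quantized ops used to be pushed through utils.quant_tensor once per consumer. The patch records each processed weight name in a _quantized_ops set and skips names already seen. A minimal sketch of the pattern, assuming a hypothetical quantize_weight helper (not Paddle's actual utils.quant_tensor):

import numpy as np

def quantize_weight(w, scale, bits=8):
    # Hypothetical stand-in for an abs-max weight quantizer: map floats
    # onto the signed integer grid [-(2**(bits-1) - 1), 2**(bits-1) - 1].
    bnt = (1 << (bits - 1)) - 1
    return np.round(np.clip(w / scale * bnt, -bnt, bnt))

quantized_ops = set()  # names of weight variables already rewritten

def maybe_quantize(name, params, scales):
    # Quantize each named weight at most once, even when several
    # ops in the graph consume the same parameter.
    if name in quantized_ops:
        return
    quantized_ops.add(name)
    params[name] = quantize_weight(params[name], scales[name])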
@@ -1119,6 +1119,7 @@ class QuantizationFreezePass(object):
         self._op_input_rename_map = collections.OrderedDict()
         self._op_output_rename_map = collections.OrderedDict()
         self._quant_var_scale_map = collections.OrderedDict()
+        self._quantized_ops = set()
 
     def apply(self, graph):
         """
@@ -1173,24 +1174,27 @@ class QuantizationFreezePass(object):
                             quant_axis = 1
                         else:
                             quant_axis = 0
-                    quantized_param_v = utils.quant_tensor(
-                        param_v.copy(),
-                        scale_v,
-                        quant_axis,
-                        self._weight_bits,
-                    )
-                    quantized_param_v = np.round(quantized_param_v)
-                    # Weight bias correction
-                    if self._bias_correction == True:
-                        quantized_param_v = utils.bias_correction_w(
-                            param_v,
-                            quantized_param_v,
-                            scale_v,
-                            quant_axis,
-                            weight_bits=self._weight_bits,
-                        )
-                        quantized_param_v = np.round(quantized_param_v)
-                    self._restore_var(input_arg_name, quantized_param_v)
+                    if input_arg_name not in self._quantized_ops:
+                        self._quantized_ops.add(input_arg_name)
+                        quantized_param_v = utils.quant_tensor(
+                            param_v.copy(),
+                            scale_v,
+                            quant_axis,
+                            self._weight_bits,
+                        )
+                        quantized_param_v = np.round(quantized_param_v)
+                        # Weight bias correction
+                        if self._bias_correction == True:
+                            quantized_param_v = utils.bias_correction_w(
+                                param_v,
+                                quantized_param_v,
+                                scale_v,
+                                quant_axis,
+                                weight_bits=self._weight_bits,
+                            )
+                            quantized_param_v = np.round(quantized_param_v)
+                        self._restore_var(input_arg_name, quantized_param_v)
                 self._remove_fake_quant_and_dequant_op(graph, op_node)
 
         # Remove all fake dequant op
@@ -3029,6 +3033,7 @@ class QuantWeightPass(object):
         self._save_int_weight = save_int_weight
         assert self._scope is not None, "scope must not be None."
         assert self._place is not None, "place must not be None."
+        self._quantized_ops = set()
 
     def apply(self, graph):
         assert isinstance(
@@ -3066,29 +3071,31 @@ class QuantWeightPass(object):
                 param_v = self._load_var(x_node.name())
                 quant_axis = _op.op().attr("quant_axis")
                 bits_length = _op.op().attr("bit_length")
-                quantized_param_v = utils.quant_tensor(
-                    param_v.copy(),
-                    scale_v,
-                    quant_axis,
-                    bits_length,
-                    onnx_format=True,
-                )
-                if self._bias_correction == True:
-                    quantized_param_v = utils.bias_correction_w(
-                        param_v,
-                        quantized_param_v,
-                        scale_v,
-                        quant_axis,
-                        weight_bits=bits_length,
-                    )
-                if self._save_int_weight:
-                    # cast weight type to int
-                    if self._quant_bits == 8:
-                        save_weight_dtype = np.int8
-                    quantized_param_v = quantized_param_v.astype(
-                        save_weight_dtype
-                    )
-                self._restore_var(x_node.name(), quantized_param_v)
+                if x_node.name() not in self._quantized_ops:
+                    self._quantized_ops.add(x_node.name())
+                    quantized_param_v = utils.quant_tensor(
+                        param_v.copy(),
+                        scale_v,
+                        quant_axis,
+                        bits_length,
+                        onnx_format=True,
+                    )
+                    if self._bias_correction == True:
+                        quantized_param_v = utils.bias_correction_w(
+                            param_v,
+                            quantized_param_v,
+                            scale_v,
+                            quant_axis,
+                            weight_bits=bits_length,
+                        )
+                    if self._save_int_weight:
+                        # cast weight type to int
+                        if self._quant_bits == 8:
+                            save_weight_dtype = np.int8
+                        quantized_param_v = quantized_param_v.astype(
+                            save_weight_dtype
+                        )
+                    self._restore_var(x_node.name(), quantized_param_v)
                 for next_op_node in out_node.outputs:
                     graph.update_input_link(out_node, x_node, next_op_node)
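Why the guard matters (an illustration, not taken from the PR itself): the passes write the quantized tensor back into the scope, so without the set a shared weight would be quantized again on its second visit, and the quantizer would then treat already-quantized integers as raw floats and saturate them. A rough numpy demonstration, using the same hypothetical abs-max quantizer as in the sketch above rather than Paddle's utils.quant_tensor:

import numpy as np

def fake_quant(w, scale, bits=8):
    # Illustrative abs-max quantizer; only approximates Paddle's mapping.
    bnt = (1 << (bits - 1)) - 1
    return np.round(np.clip(w / scale * bnt, -bnt, bnt))

w = np.array([0.5, -0.25, 1.0])
scale = np.abs(w).max()  # 1.0

once = fake_quant(w, scale)      # [ 64. -32. 127.]  -> correct int grid
twice = fake_quant(once, scale)  # [127. -127. 127.] -> saturated garbage
print(once, twice)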