Unverified commit 66a1df3c, authored by Chang Xu, committed by GitHub

Avoid Quant Weight Repeatedly (#47587)

Parent b3d52d46
......@@ -1119,6 +1119,7 @@ class QuantizationFreezePass(object):
self._op_input_rename_map = collections.OrderedDict()
self._op_output_rename_map = collections.OrderedDict()
self._quant_var_scale_map = collections.OrderedDict()
self._quantized_ops = set()
def apply(self, graph):
"""
......@@ -1173,6 +1174,8 @@ class QuantizationFreezePass(object):
quant_axis = 1
else:
quant_axis = 0
if input_arg_name not in self._quantized_ops:
self._quantized_ops.add(input_arg_name)
quantized_param_v = utils.quant_tensor(
param_v.copy(),
scale_v,
......@@ -1191,6 +1194,7 @@ class QuantizationFreezePass(object):
)
quantized_param_v = np.round(quantized_param_v)
self._restore_var(input_arg_name, quantized_param_v)
self._remove_fake_quant_and_dequant_op(graph, op_node)
# Remove all fake dequant op
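The guard added above is the core of the change: `QuantizationFreezePass` now records each weight name in `self._quantized_ops` and only runs `utils.quant_tensor` the first time that name is seen, so a parameter shared by several quantized ops is not quantized more than once. A minimal, self-contained sketch of the same pattern (hypothetical names, not Paddle's actual pass API) is:

```python
import numpy as np

class FreezePassSketch:
    """Sketch of the dedup guard: quantize each named weight at most once."""

    def __init__(self, weight_bits=8):
        self._weight_bits = weight_bits
        self._quantized_ops = set()  # names of weights already quantized

    def quantize_weight(self, name, param_v, scale_v):
        if name in self._quantized_ops:
            # The stored tensor already holds quantized values; quantizing it
            # again would apply the scale a second time and corrupt it.
            return param_v
        self._quantized_ops.add(name)
        bnt = (1 << (self._weight_bits - 1)) - 1  # 127 for int8
        return np.round(param_v / scale_v * bnt)
```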
......@@ -3029,6 +3033,7 @@ class QuantWeightPass(object):
self._save_int_weight = save_int_weight
assert self._scope is not None, "scope must not be None."
assert self._place is not None, "place must not be None."
self._quantized_ops = set()
def apply(self, graph):
assert isinstance(
......@@ -3066,6 +3071,8 @@ class QuantWeightPass(object):
param_v = self._load_var(x_node.name())
quant_axis = _op.op().attr("quant_axis")
bits_length = _op.op().attr("bit_length")
if x_node.name() not in self._quantized_ops:
self._quantized_ops.add(x_node.name())
quantized_param_v = utils.quant_tensor(
param_v.copy(),
scale_v,
......
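`QuantWeightPass` receives the same guard, keyed on `x_node.name()`. Why repeated quantization matters can be seen with a small numeric example in plain NumPy, assuming a simple symmetric per-tensor scheme for illustration (the exact transform in `utils.quant_tensor` may differ): applying the transform a second time scales the already-integer weight by the quantization range again and pushes it far outside the int8 range.

```python
import numpy as np

def fake_quant(w, scale, bits=8):
    # Symmetric per-tensor quantization sketch (illustration only).
    bnt = (1 << (bits - 1)) - 1  # 127 for int8
    return np.round(w / scale * bnt)

w = np.array([0.5, -0.25, 1.0])
scale = np.abs(w).max()          # 1.0

once = fake_quant(w, scale)      # [  64.,  -32.,  127.] -> valid int8 values
twice = fake_quant(once, scale)  # [8128., -4064., 16129.] -> corrupted
```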