未验证 提交 7b5065ab 编写于 作者: Z zhouzj 提交者: GitHub

fix the bug of quanting matmul (#52833)

上级 c3055d23
......@@ -701,6 +701,10 @@ class PostTrainingQuantization:
in self.quant_config.activation_quant_operation_types
or is_conv1d_quant
):
trans_y = (op_type == 'matmul_v2') and op.op().attr(
'trans_y'
)
op_type = op_type + '_trans_y' if trans_y else op_type
collect_var_name(
utils._get_op_input_var_names(op),
persistable_var_names,
......
......@@ -86,20 +86,6 @@ def _is_input_all_not_persistable(graph, op_node):
return is_input_all_not_persistable
def _check_grandchild_op_node(op_node, grandchild_op_name):
'''
Check whether the fake_quant node has a grandchild op node named
grandchild_op_name.
'''
for out1_var_node in op_node.outputs:
for out1_op_node in out1_var_node.outputs:
for out2_var_node in out1_op_node.outputs:
for out2_op_node in out2_var_node.outputs:
if out2_op_node.name() == grandchild_op_name:
return True
return False
class QuantizationTransformPass:
"""
Quantize the ops that have weights. Add quant and dequant ops for
......@@ -360,9 +346,14 @@ class QuantizationTransformPass:
if (
quant_type == 'channel_wise_abs_max'
): # Weight quantization
op_type = op.name()
trans_y = (op_type == 'matmul_v2') and op.op().attr(
'trans_y'
)
op_type = op_type + '_trans_y' if trans_y else op_type
quant_axis = (
1
if op.name() in utils._channelwise_quant_axis1_ops
if op_type in utils._channelwise_quant_axis1_ops
else 0
)
(
......@@ -1184,13 +1175,9 @@ class QuantizationFreezePass:
# Quantize weight and restore
if self._round_type == 'round':
param_v = self._load_var(input_arg_name)
if any(
_check_grandchild_op_node(op_node, op)
for op in utils._channelwise_quant_axis1_ops
):
quant_axis = 1
else:
quant_axis = 0
if op_node.op().has_attr('quant_axis'):
quant_axis = op_node.op().attr('quant_axis')
if input_arg_name not in self._quantized_ops:
self._quantized_ops.add(input_arg_name)
quantized_param_v = utils.quant_tensor(
......@@ -2605,9 +2592,14 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
channel_wise = False
if quant_type == 'channel_wise_abs_max': # Weight quantization
channel_wise = True
op_type = op.name()
trans_y = (op_type == 'matmul_v2') and op.op().attr(
'trans_y'
)
op_type = op_type + '_trans_y' if trans_y else op_type
quant_axis = (
1
if op.name() in utils._channelwise_quant_axis1_ops
if op_type in utils._channelwise_quant_axis1_ops
else 0
)
insert_quant_pass = InsertQuantizeLinear(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册