Unverified commit ca174eaf, authored by xiaoluomi, committed by GitHub

fix conv1d_transpose insert quant node bug (#53320)

Parent: 1d8c82b6
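For context (not part of the commit itself): Paddle typically lowers a `conv1d_transpose` layer into a 4-D op sequence, `unsqueeze2` → `conv2d_transpose` → `squeeze2`, in the saved inference program, which is why the quantization passes below match on `unsqueeze2` and `conv2d_transpose` rather than on a dedicated 1-D op. A minimal sketch of that pattern, using the dynamic-graph API purely for illustration:

```python
import paddle

# Emulating conv1d_transpose with 2-D ops, as the exported program does.
x = paddle.randn([1, 8, 32])     # [N, C_in, L]
w = paddle.randn([8, 16, 1, 3])  # conv2d_transpose weight: [C_in, C_out, kH, kW]

x4d = paddle.unsqueeze(x, axis=[2])                  # -> 'unsqueeze2' op
y4d = paddle.nn.functional.conv2d_transpose(x4d, w)  # -> 'conv2d_transpose' op
y = paddle.squeeze(y4d, axis=[2])                    # -> 'squeeze2' op
print(y.shape)                   # [1, 16, 34]
```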
@@ -684,15 +684,37 @@ class PostTrainingQuantization:
             op._set_attr("op_namescope", "skip_quant")
     op_type = op.type
+    # skip quant for conv2d_transpose ops lowered from conv1d_transpose
+    if op_type == 'conv2d_transpose':
+        in_name = op.input("Filter")[0]
+        for _op in self._program.blocks[block_id].ops:
+            var_name = utils._get_op_output_var_names(_op)
+            if in_name in var_name:
+                for name in utils._get_op_input_var_names(_op):
+                    if name not in persistable_var_names:
+                        op._set_attr("op_namescope", "skip_quant")
+                        _op._set_attr("op_namescope", "skip_quant")
     if self._is_full_quantize and op_type not in list(
         SUPPORT_QUANTIZATION_OP_DICT.keys()
     ):
         _logger.warning(
             op_type + " is not supported for quantization."
         )
-    is_conv1d_quant = (op_type == "unsqueeze2") and (
-        utils._get_op_input_var_names(op)[0]
-        in persistable_var_names
-    )
+    conv1d_persistable_var_names = []
+    for opname in persistable_var_names:
+        if 'conv1d' in opname:
+            conv1d_persistable_var_names.append(opname)
+    is_conv1d_quant = (op_type == "unsqueeze2") and (
+        utils._get_op_input_var_names(op)[0]
+        in conv1d_persistable_var_names
+    )
     # For quantized ops, sample inputs and outputs
     if (
......
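The hunk above does two things: it narrows the `unsqueeze2` match to weights whose names contain `'conv1d'`, and it propagates the `skip_quant` namescope to whichever op produces a `conv2d_transpose`'s `Filter` input. A simplified, self-contained sketch of that producer scan (the helper name and signature are made up for illustration; the pass does this inline and additionally filters on non-persistable inputs):

```python
def _mark_filter_producer_skip_quant(program, block_id, op, utils):
    """Hypothetical helper: mark the op that produces this
    conv2d_transpose's Filter (and the conv op itself) as skip_quant."""
    filter_name = op.input("Filter")[0]
    for producer in program.blocks[block_id].ops:
        if filter_name in utils._get_op_output_var_names(producer):
            op._set_attr("op_namescope", "skip_quant")
            producer._set_attr("op_namescope", "skip_quant")
```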
@@ -2469,6 +2469,7 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
         self._act_preprocess_func = act_preprocess_func
         self._optimizer = optimizer_func
         self._exe = executor
+        self._conv1dtranspose_flag = False
         quant_type = [
             'abs_max',
             'channel_wise_abs_max',
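The one added line above gives the pass a piece of state: `_conv1dtranspose_flag` is presumably raised when a conv1d_transpose-style weight is encountered, and consumed exactly once when the quantization axis is chosen (next hunk). A hedged sketch of that one-shot produce/consume pattern, not the pass itself:

```python
class _OneShotAxisFlag:
    """Illustration only: a flag reset as soon as it is read,
    so it affects a single quant-axis decision."""

    def __init__(self):
        self._conv1dtranspose_flag = False

    def raise_flag(self):
        self._conv1dtranspose_flag = True

    def choose_quant_axis(self, default_axis: int) -> int:
        if self._conv1dtranspose_flag:
            self._conv1dtranspose_flag = False  # consume the flag
            return 1  # channel axis of a transposed-conv weight
        return default_axis
```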
@@ -2597,11 +2598,15 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
             'trans_y'
         )
         op_type = op_type + '_trans_y' if trans_y else op_type
-        quant_axis = (
-            1
-            if op_type in utils._channelwise_quant_axis1_ops
-            else 0
-        )
+        if self._conv1dtranspose_flag:
+            quant_axis = 1
+            self._conv1dtranspose_flag = False
+        else:
+            quant_axis = (
+                1
+                if op.name() in utils._channelwise_quant_axis1_ops
+                else 0
+            )
         insert_quant_pass = InsertQuantizeLinear(
             self._place,
             self._scope,
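Why axis 1 matters here: a regular `conv2d` weight is laid out `[C_out, C_in, kH, kW]`, so channel-wise scales are taken along axis 0, whereas a `conv2d_transpose` weight is `[C_in, C_out/groups, kH, kW]`, putting the output channels on axis 1. A small numpy sketch of channel-wise abs-max scales (illustrative, not the pass's code):

```python
import numpy as np

def channelwise_abs_max(w: np.ndarray, quant_axis: int) -> np.ndarray:
    """One abs-max scale per channel, reducing over all other axes."""
    reduce_axes = tuple(i for i in range(w.ndim) if i != quant_axis)
    return np.abs(w).max(axis=reduce_axes)

w_conv = np.random.randn(16, 8, 3, 3)   # conv2d: [C_out, C_in, kH, kW]
w_convT = np.random.randn(8, 16, 3, 3)  # conv2d_transpose: [C_in, C_out, kH, kW]

assert channelwise_abs_max(w_conv, 0).shape == (16,)   # quant_axis = 0
assert channelwise_abs_max(w_convT, 1).shape == (16,)  # quant_axis = 1
```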
@@ -2657,7 +2662,11 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
         for _op in graph.all_op_nodes():
             var_names = utils._get_op_output_var_names(_op)
             if conv_weight_var_name in var_names and self._has_weight(_op):
-                weight_scale_node = self._transform_forward(graph, _op)
+                if op.name() == 'conv2d_transpose':
+                    if not self._is_skip_quant(graph, _op):
+                        weight_scale_node = self._transform_forward(graph, _op)
+                else:
+                    weight_scale_node = self._transform_forward(graph, _op)
         # insert qdq before conv2d
         for var_node in op.inputs:
             quant_bits = (
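The guard above means: when the consuming op is `conv2d_transpose`, the op that produces its weight only receives quant/dequant nodes if it was not marked skip-quant earlier; for every other consumer the behavior is unchanged. Extracted into a standalone function with simplified names for illustration:

```python
def transform_weight_producer(pass_obj, graph, consumer_op, producer_op):
    """Insert quant nodes for the weight producer unless it is skip-quant."""
    if consumer_op.name() == 'conv2d_transpose':
        if pass_obj._is_skip_quant(graph, producer_op):
            return None  # marked skip_quant earlier: leave it untouched
    return pass_obj._transform_forward(graph, producer_op)
```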
@@ -2677,6 +2686,10 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
         quant_axis = (
             1 if op.name() in utils._channelwise_quant_axis1_ops else 0
         )
+        if 'unsqueeze2' in utils._channelwise_quant_axis1_ops:
+            utils._channelwise_quant_axis1_ops.remove('unsqueeze2')
+        if self._is_skip_quant(graph, op):
+            return
         insert_quant_pass = InsertQuantizeLinear(
             self._place,
             self._scope,
......
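One detail worth noting in the last hunk: `utils._channelwise_quant_axis1_ops` is a shared, module-level list, so the `'unsqueeze2'` entry (presumably added elsewhere along the conv1d path) is removed again after use; otherwise it would leak into quant-axis decisions for unrelated ops. A toy illustration of that cleanup, with hypothetical list contents:

```python
# Hypothetical stand-in for utils._channelwise_quant_axis1_ops.
_channelwise_quant_axis1_ops = ['conv2d_transpose', 'mul', 'unsqueeze2']

def quant_axis_for(op_type: str) -> int:
    return 1 if op_type in _channelwise_quant_axis1_ops else 0

assert quant_axis_for('unsqueeze2') == 1  # while the temporary entry exists
if 'unsqueeze2' in _channelwise_quant_axis1_ops:
    _channelwise_quant_axis1_ops.remove('unsqueeze2')
assert quant_axis_for('unsqueeze2') == 0  # back to the default axis
```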