Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
5c19bfc8
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5c19bfc8
编写于
4月 06, 2023
作者:
C
ceci3
提交者:
GitHub
4月 06, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support hybrid parallel in qat (#52219)
上级
6c01ce8a
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
83 addition
and
43 deletion
+83
-43
python/paddle/static/quantization/quantization_pass.py
python/paddle/static/quantization/quantization_pass.py
+83
-43
未找到文件。
python/paddle/static/quantization/quantization_pass.py
浏览文件 @
5c19bfc8
...
@@ -298,6 +298,7 @@ class QuantizationTransformPass:
...
@@ -298,6 +298,7 @@ class QuantizationTransformPass:
def
_transform_forward
(
graph
,
op
):
def
_transform_forward
(
graph
,
op
):
op
.
op
().
_set_attr
(
"quantization_type"
,
"qat_with_weight"
)
op
.
op
().
_set_attr
(
"quantization_type"
,
"qat_with_weight"
)
op
.
op
().
_set_attr
(
"with_quant_attr"
,
True
)
op
.
op
().
_set_attr
(
"with_quant_attr"
,
True
)
op_role
=
op
.
op
().
attr
(
"op_role"
)
inputs
=
op
.
inputs
inputs
=
op
.
inputs
for
var_node
in
inputs
:
for
var_node
in
inputs
:
if
var_node
.
name
()
not
in
op
.
input_arg_names
():
if
var_node
.
name
()
not
in
op
.
input_arg_names
():
...
@@ -368,7 +369,12 @@ class QuantizationTransformPass:
...
@@ -368,7 +369,12 @@ class QuantizationTransformPass:
quant_var_node
,
quant_var_node
,
scale_var_node
,
scale_var_node
,
)
=
self
.
_insert_channel_quant_op
(
)
=
self
.
_insert_channel_quant_op
(
graph
,
var_node
,
name
,
quant_bits
,
quant_axis
graph
,
var_node
,
name
,
quant_bits
,
quant_axis
,
op_role
,
)
)
dequant_var_node
=
self
.
_insert_channel_dequant_op
(
dequant_var_node
=
self
.
_insert_channel_dequant_op
(
graph
,
graph
,
...
@@ -376,13 +382,23 @@ class QuantizationTransformPass:
...
@@ -376,13 +382,23 @@ class QuantizationTransformPass:
[
scale_var_node
],
[
scale_var_node
],
[
quant_bits
],
[
quant_bits
],
quant_axis
,
quant_axis
,
op_role
,
)
)
else
:
else
:
quant_var_node
,
scale_var_node
=
self
.
_insert_quant_op
(
quant_var_node
,
scale_var_node
=
self
.
_insert_quant_op
(
graph
,
var_node
,
name
,
quant_bits
,
quant_type
graph
,
var_node
,
name
,
quant_bits
,
quant_type
,
op_role
,
)
)
dequant_var_node
=
self
.
_insert_dequant_op
(
dequant_var_node
=
self
.
_insert_dequant_op
(
graph
,
quant_var_node
,
scale_var_node
,
quant_bits
graph
,
quant_var_node
,
scale_var_node
,
quant_bits
,
op_role
,
)
)
dequantized_vars
[
name
]
=
dequant_var_node
dequantized_vars
[
name
]
=
dequant_var_node
graph
.
update_input_link
(
var_node
,
dequant_var_node
,
op
)
graph
.
update_input_link
(
var_node
,
dequant_var_node
,
op
)
...
@@ -476,24 +492,28 @@ class QuantizationTransformPass:
...
@@ -476,24 +492,28 @@ class QuantizationTransformPass:
graph
.
link_to
(
increment_op
,
global_step_out
)
graph
.
link_to
(
increment_op
,
global_step_out
)
self
.
_global_step
=
global_step_out
self
.
_global_step
=
global_step_out
def
_insert_quant_op
(
self
,
graph
,
var_node
,
name
,
quant_bits
,
quant_type
):
def
_insert_quant_op
(
self
,
graph
,
var_node
,
name
,
quant_bits
,
quant_type
,
op_role
):
"""
"""
Insert fake_quantize_op in the graph.
Insert fake_quantize_op in the graph.
"""
"""
if
quant_type
==
'abs_max'
:
if
quant_type
==
'abs_max'
:
return
self
.
_insert_quant_abs_max_op
(
return
self
.
_insert_quant_abs_max_op
(
graph
,
var_node
,
name
,
quant_bits
graph
,
var_node
,
name
,
quant_bits
,
op_role
)
)
elif
quant_type
==
'range_abs_max'
:
elif
quant_type
==
'range_abs_max'
:
return
self
.
_insert_quant_range_abs_max_op
(
return
self
.
_insert_quant_range_abs_max_op
(
graph
,
var_node
,
name
,
quant_bits
graph
,
var_node
,
name
,
quant_bits
,
op_role
)
)
elif
quant_type
==
'moving_average_abs_max'
:
elif
quant_type
==
'moving_average_abs_max'
:
return
self
.
_insert_quant_moving_average_abs_max_op
(
return
self
.
_insert_quant_moving_average_abs_max_op
(
graph
,
var_node
,
name
,
quant_bits
graph
,
var_node
,
name
,
quant_bits
,
op_role
)
)
def
_insert_quant_abs_max_op
(
self
,
graph
,
var_node
,
name
,
quant_bits
):
def
_insert_quant_abs_max_op
(
self
,
graph
,
var_node
,
name
,
quant_bits
,
op_role
):
"""
"""
Insert fake_quantize_abs_max op in the graph.
Insert fake_quantize_abs_max op in the graph.
"""
"""
...
@@ -528,10 +548,7 @@ class QuantizationTransformPass:
...
@@ -528,10 +548,7 @@ class QuantizationTransformPass:
quant_op_node
=
graph
.
create_op_node
(
quant_op_node
=
graph
.
create_op_node
(
op_type
=
'fake_quantize_abs_max'
,
op_type
=
'fake_quantize_abs_max'
,
attrs
=
{
attrs
=
{
'bit_length'
:
quant_bits
,
'op_role'
:
op_role
},
'bit_length'
:
quant_bits
,
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
,
},
inputs
=
{
'X'
:
var_node
},
inputs
=
{
'X'
:
var_node
},
outputs
=
{
'Out'
:
quant_var_node
,
'OutScale'
:
scale_var_node
},
outputs
=
{
'Out'
:
quant_var_node
,
'OutScale'
:
scale_var_node
},
)
)
...
@@ -540,7 +557,9 @@ class QuantizationTransformPass:
...
@@ -540,7 +557,9 @@ class QuantizationTransformPass:
graph
.
link_to
(
quant_op_node
,
scale_var_node
)
graph
.
link_to
(
quant_op_node
,
scale_var_node
)
return
quant_var_node
,
scale_var_node
return
quant_var_node
,
scale_var_node
def
_insert_quant_range_abs_max_op
(
self
,
graph
,
var_node
,
name
,
quant_bits
):
def
_insert_quant_range_abs_max_op
(
self
,
graph
,
var_node
,
name
,
quant_bits
,
op_role
):
"""
"""
Insert fake_quantize_range_abs_max on the graph.
Insert fake_quantize_range_abs_max on the graph.
"""
"""
...
@@ -605,7 +624,7 @@ class QuantizationTransformPass:
...
@@ -605,7 +624,7 @@ class QuantizationTransformPass:
'window_size'
:
self
.
_window_size
,
'window_size'
:
self
.
_window_size
,
'bit_length'
:
quant_bits
,
'bit_length'
:
quant_bits
,
'is_test'
:
self
.
_is_test
,
'is_test'
:
self
.
_is_test
,
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
,
'op_role'
:
op_role
,
}
}
quant_op_node
=
graph
.
create_op_node
(
quant_op_node
=
graph
.
create_op_node
(
op_type
=
'fake_quantize_range_abs_max'
,
op_type
=
'fake_quantize_range_abs_max'
,
...
@@ -626,7 +645,7 @@ class QuantizationTransformPass:
...
@@ -626,7 +645,7 @@ class QuantizationTransformPass:
return
quant_var_node
,
scale_out_node
return
quant_var_node
,
scale_out_node
def
_insert_quant_moving_average_abs_max_op
(
def
_insert_quant_moving_average_abs_max_op
(
self
,
graph
,
var_node
,
name
,
quant_bits
self
,
graph
,
var_node
,
name
,
quant_bits
,
op_role
):
):
"""Insert fake_quantize_moving_average_abs_max"""
"""Insert fake_quantize_moving_average_abs_max"""
quant_var_node
=
graph
.
create_var_node
(
quant_var_node
=
graph
.
create_var_node
(
...
@@ -706,7 +725,7 @@ class QuantizationTransformPass:
...
@@ -706,7 +725,7 @@ class QuantizationTransformPass:
'bit_length'
:
quant_bits
,
'bit_length'
:
quant_bits
,
'moving_rate'
:
self
.
_moving_rate
,
'moving_rate'
:
self
.
_moving_rate
,
'is_test'
:
self
.
_is_test
,
'is_test'
:
self
.
_is_test
,
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
,
'op_role'
:
op_role
,
}
}
quant_op_node
=
graph
.
create_op_node
(
quant_op_node
=
graph
.
create_op_node
(
...
@@ -730,7 +749,7 @@ class QuantizationTransformPass:
...
@@ -730,7 +749,7 @@ class QuantizationTransformPass:
return
quant_var_node
,
scale_out_node
return
quant_var_node
,
scale_out_node
def
_insert_channel_quant_op
(
def
_insert_channel_quant_op
(
self
,
graph
,
var_node
,
name
,
quant_bits
,
quant_axis
self
,
graph
,
var_node
,
name
,
quant_bits
,
quant_axis
,
op_role
):
):
"""
"""
Insert fake_channel_wise_quantize_abs_max op in the graph.
Insert fake_channel_wise_quantize_abs_max op in the graph.
...
@@ -771,7 +790,7 @@ class QuantizationTransformPass:
...
@@ -771,7 +790,7 @@ class QuantizationTransformPass:
'bit_length'
:
quant_bits
,
'bit_length'
:
quant_bits
,
'quant_axis'
:
quant_axis
,
'quant_axis'
:
quant_axis
,
'is_test'
:
self
.
_is_test
,
'is_test'
:
self
.
_is_test
,
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
,
'op_role'
:
op_role
,
},
},
inputs
=
{
'X'
:
var_node
},
inputs
=
{
'X'
:
var_node
},
outputs
=
{
'Out'
:
quant_var_node
,
'OutScale'
:
scale_var_node
},
outputs
=
{
'Out'
:
quant_var_node
,
'OutScale'
:
scale_var_node
},
...
@@ -781,7 +800,9 @@ class QuantizationTransformPass:
...
@@ -781,7 +800,9 @@ class QuantizationTransformPass:
graph
.
link_to
(
quant_op_node
,
scale_var_node
)
graph
.
link_to
(
quant_op_node
,
scale_var_node
)
return
quant_var_node
,
scale_var_node
return
quant_var_node
,
scale_var_node
def
_insert_dequant_op
(
self
,
graph
,
var_node
,
scale_var_node
,
quant_bits
):
def
_insert_dequant_op
(
self
,
graph
,
var_node
,
scale_var_node
,
quant_bits
,
op_role
):
"""
"""
Insert fake_dequantize_op in the graph.
Insert fake_dequantize_op in the graph.
"""
"""
...
@@ -796,10 +817,7 @@ class QuantizationTransformPass:
...
@@ -796,10 +817,7 @@ class QuantizationTransformPass:
max_range
=
(
1
<<
(
quant_bits
-
1
))
-
1
max_range
=
(
1
<<
(
quant_bits
-
1
))
-
1
dequant_op_node
=
graph
.
create_op_node
(
dequant_op_node
=
graph
.
create_op_node
(
op_type
=
'fake_dequantize_max_abs'
,
op_type
=
'fake_dequantize_max_abs'
,
attrs
=
{
attrs
=
{
'max_range'
:
float
(
max_range
),
'op_role'
:
op_role
},
'max_range'
:
float
(
max_range
),
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
,
},
inputs
=
{
'X'
:
var_node
,
'Scale'
:
scale_var_node
},
inputs
=
{
'X'
:
var_node
,
'Scale'
:
scale_var_node
},
outputs
=
{
'Out'
:
dequant_var_node
},
outputs
=
{
'Out'
:
dequant_var_node
},
)
)
...
@@ -809,7 +827,7 @@ class QuantizationTransformPass:
...
@@ -809,7 +827,7 @@ class QuantizationTransformPass:
return
dequant_var_node
return
dequant_var_node
def
_insert_channel_dequant_op
(
def
_insert_channel_dequant_op
(
self
,
graph
,
var_node
,
scale_var_nodes
,
quant_bits
,
quant_axis
self
,
graph
,
var_node
,
scale_var_nodes
,
quant_bits
,
quant_axis
,
op_role
):
):
"""
"""
Insert fake_channel_wise_dequantize_max_abs in the graph.
Insert fake_channel_wise_dequantize_max_abs in the graph.
...
@@ -827,7 +845,7 @@ class QuantizationTransformPass:
...
@@ -827,7 +845,7 @@ class QuantizationTransformPass:
attrs
=
{
attrs
=
{
'quant_bits'
:
quant_bits
,
'quant_bits'
:
quant_bits
,
'quant_axis'
:
quant_axis
,
'quant_axis'
:
quant_axis
,
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
,
'op_role'
:
op_role
,
},
},
inputs
=
{
'X'
:
var_node
,
'Scales'
:
scale_var_nodes
},
inputs
=
{
'X'
:
var_node
,
'Scales'
:
scale_var_nodes
},
outputs
=
{
'Out'
:
dequant_var_node
},
outputs
=
{
'Out'
:
dequant_var_node
},
...
@@ -1628,11 +1646,15 @@ class OutScaleForTrainingPass:
...
@@ -1628,11 +1646,15 @@ class OutScaleForTrainingPass:
in_node
=
graph
.
_find_node_by_name
(
in_node
=
graph
.
_find_node_by_name
(
op
.
outputs
,
output_var_name
op
.
outputs
,
output_var_name
)
)
if
in_node
.
dtype
()
not
in
[
if
(
core
.
VarDesc
.
VarType
.
FP64
,
in_node
.
dtype
()
core
.
VarDesc
.
VarType
.
FP32
,
not
in
[
core
.
VarDesc
.
VarType
.
FP16
,
core
.
VarDesc
.
VarType
.
FP64
,
]:
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
FP16
,
]
or
'@GRAD'
in
in_node
.
name
()
):
continue
continue
if
in_node
.
dtype
()
==
core
.
VarDesc
.
VarType
.
FP64
:
if
in_node
.
dtype
()
==
core
.
VarDesc
.
VarType
.
FP64
:
...
@@ -1710,7 +1732,7 @@ class OutScaleForTrainingPass:
...
@@ -1710,7 +1732,7 @@ class OutScaleForTrainingPass:
attrs
=
{
attrs
=
{
'moving_rate'
:
self
.
_moving_rate
,
'moving_rate'
:
self
.
_moving_rate
,
'is_test'
:
self
.
_is_test
,
'is_test'
:
self
.
_is_test
,
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
,
'op_role'
:
op
.
op
().
attr
(
"op_role"
)
,
}
}
scale_op_node
=
graph
.
create_op_node
(
scale_op_node
=
graph
.
create_op_node
(
op_type
=
'moving_average_abs_max_scale'
,
op_type
=
'moving_average_abs_max_scale'
,
...
@@ -1953,7 +1975,10 @@ class AddQuantDequantPass:
...
@@ -1953,7 +1975,10 @@ class AddQuantDequantPass:
quant_var_node
,
quant_var_node
,
_
,
_
,
)
=
self
.
_inser_quant_dequant_moving_average_abs_max_op
(
)
=
self
.
_inser_quant_dequant_moving_average_abs_max_op
(
graph
,
in_node
,
self
.
_quant_bits
graph
,
in_node
,
self
.
_quant_bits
,
op_node
.
op
().
attr
(
"op_role"
),
)
)
dequantized_vars_map
[
arg_name
]
=
quant_var_node
dequantized_vars_map
[
arg_name
]
=
quant_var_node
graph
.
update_input_link
(
graph
.
update_input_link
(
...
@@ -1978,7 +2003,7 @@ class AddQuantDequantPass:
...
@@ -1978,7 +2003,7 @@ class AddQuantDequantPass:
return
graph
return
graph
def
_inser_quant_dequant_moving_average_abs_max_op
(
def
_inser_quant_dequant_moving_average_abs_max_op
(
self
,
graph
,
var_node
,
quant_bits
self
,
graph
,
var_node
,
quant_bits
,
op_role
):
):
"""Insert fake_quantize_dequantize_moving_average_abs_max op."""
"""Insert fake_quantize_dequantize_moving_average_abs_max op."""
quant_var_node
=
graph
.
create_var_node
(
quant_var_node
=
graph
.
create_var_node
(
...
@@ -2068,7 +2093,7 @@ class AddQuantDequantPass:
...
@@ -2068,7 +2093,7 @@ class AddQuantDequantPass:
'bit_length'
:
quant_bits
,
'bit_length'
:
quant_bits
,
'moving_rate'
:
self
.
_moving_rate
,
'moving_rate'
:
self
.
_moving_rate
,
'is_test'
:
self
.
_is_test
,
'is_test'
:
self
.
_is_test
,
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
,
'op_role'
:
op_role
,
}
}
quant_op_node
=
graph
.
create_op_node
(
quant_op_node
=
graph
.
create_op_node
(
...
@@ -2131,7 +2156,12 @@ class InsertQuantizeLinear:
...
@@ -2131,7 +2156,12 @@ class InsertQuantizeLinear:
self
.
_scale_dict
=
scale_dict
self
.
_scale_dict
=
scale_dict
def
insert_quant_op
(
def
insert_quant_op
(
self
,
graph
,
var_node
,
var_name
=
None
,
scale_var_node
=
None
self
,
graph
,
var_node
,
var_name
=
None
,
scale_var_node
=
None
,
op_role
=
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
,
):
):
assert
var_node
.
is_var
(),
f
'
{
var_node
.
name
()
}
is not a var'
assert
var_node
.
is_var
(),
f
'
{
var_node
.
name
()
}
is not a var'
var_name
=
var_node
.
name
()
if
not
var_name
else
var_name
var_name
=
var_node
.
name
()
if
not
var_name
else
var_name
...
@@ -2200,7 +2230,7 @@ class InsertQuantizeLinear:
...
@@ -2200,7 +2230,7 @@ class InsertQuantizeLinear:
inputs
[
"ZeroPoint"
]
=
zero_point_node
inputs
[
"ZeroPoint"
]
=
zero_point_node
attrs
=
{
"quant_axis"
:
self
.
quant_axis
,
"bit_length"
:
self
.
quant_bits
}
attrs
=
{
"quant_axis"
:
self
.
quant_axis
,
"bit_length"
:
self
.
quant_bits
}
attrs
[
"op_role"
]
=
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
attrs
[
"op_role"
]
=
op_role
outputs
=
{
"Y"
:
quant_var_node
}
outputs
=
{
"Y"
:
quant_var_node
}
if
not
self
.
_is_test
:
if
not
self
.
_is_test
:
scale_out_node
=
graph
.
create_var_node_from_desc
(
scale_out_node
=
graph
.
create_var_node_from_desc
(
...
@@ -2271,7 +2301,7 @@ class InsertQuantizeLinear:
...
@@ -2271,7 +2301,7 @@ class InsertQuantizeLinear:
graph
.
link_to
(
quant_op_node
,
scale_out_node
)
graph
.
link_to
(
quant_op_node
,
scale_out_node
)
return
quant_var_node
,
scale_var_node
return
quant_var_node
,
scale_var_node
def
insert_dequant_op
(
self
,
graph
,
var_node
,
scale_var_node
):
def
insert_dequant_op
(
self
,
graph
,
var_node
,
scale_var_node
,
op_role
):
assert
var_node
.
is_var
(),
f
'
{
var_node
.
name
()
}
is not a var'
assert
var_node
.
is_var
(),
f
'
{
var_node
.
name
()
}
is not a var'
dequant_var_node
=
graph
.
create_var_node
(
dequant_var_node
=
graph
.
create_var_node
(
...
@@ -2301,7 +2331,7 @@ class InsertQuantizeLinear:
...
@@ -2301,7 +2331,7 @@ class InsertQuantizeLinear:
inputs
[
"ZeroPoint"
]
=
zero_point_node
inputs
[
"ZeroPoint"
]
=
zero_point_node
attrs
=
{
"quant_axis"
:
self
.
quant_axis
,
"bit_length"
:
self
.
quant_bits
}
attrs
=
{
"quant_axis"
:
self
.
quant_axis
,
"bit_length"
:
self
.
quant_bits
}
attrs
[
"op_role"
]
=
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
attrs
[
"op_role"
]
=
op_role
quant_op_node
=
graph
.
create_op_node
(
quant_op_node
=
graph
.
create_op_node
(
op_type
=
"dequantize_linear"
,
op_type
=
"dequantize_linear"
,
...
@@ -2513,6 +2543,7 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
...
@@ -2513,6 +2543,7 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
def
_transform_forward
(
self
,
graph
,
op
):
def
_transform_forward
(
self
,
graph
,
op
):
op
.
op
().
_set_attr
(
"quantization_type"
,
"qat_with_weight"
)
op
.
op
().
_set_attr
(
"quantization_type"
,
"qat_with_weight"
)
op_role
=
op
.
op
().
attr
(
"op_role"
)
weight_scale_node
=
None
weight_scale_node
=
None
inputs
=
op
.
inputs
inputs
=
op
.
inputs
for
var_node
in
inputs
:
for
var_node
in
inputs
:
...
@@ -2592,10 +2623,10 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
...
@@ -2592,10 +2623,10 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
quant_var_node
,
quant_var_node
,
scale_var_node
,
scale_var_node
,
)
=
insert_quant_pass
.
insert_quant_op
(
)
=
insert_quant_pass
.
insert_quant_op
(
graph
,
var_node
,
var_name
=
name
graph
,
var_node
,
var_name
=
name
,
op_role
=
op_role
)
)
dequant_var_node
=
insert_quant_pass
.
insert_dequant_op
(
dequant_var_node
=
insert_quant_pass
.
insert_dequant_op
(
graph
,
quant_var_node
,
scale_var_node
graph
,
quant_var_node
,
scale_var_node
,
op_role
)
)
self
.
dequantized_vars
[
name
]
=
dequant_var_node
self
.
dequantized_vars
[
name
]
=
dequant_var_node
...
@@ -2676,9 +2707,13 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
...
@@ -2676,9 +2707,13 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
var_node
,
var_node
,
var_name
=
var_node
.
name
(),
var_name
=
var_node
.
name
(),
scale_var_node
=
scale_var_node
,
scale_var_node
=
scale_var_node
,
op_role
=
op
.
op
().
attr
(
"op_role"
),
)
)
dequant_var_node
=
insert_quant_pass
.
insert_dequant_op
(
dequant_var_node
=
insert_quant_pass
.
insert_dequant_op
(
graph
,
quant_var_node
,
scale_var_node
graph
,
quant_var_node
,
scale_var_node
,
op
.
op
().
attr
(
"op_role"
),
)
)
graph
.
update_input_link
(
var_node
,
dequant_var_node
,
op
)
graph
.
update_input_link
(
var_node
,
dequant_var_node
,
op
)
...
@@ -2913,11 +2948,16 @@ class AddQuantDequantPassV2:
...
@@ -2913,11 +2948,16 @@ class AddQuantDequantPassV2:
quant_var_node
,
quant_var_node
,
scale_var_node
,
scale_var_node
,
)
=
insert_quant_pass
.
insert_quant_op
(
)
=
insert_quant_pass
.
insert_quant_op
(
graph
,
in_node
graph
,
in_node
,
op_role
=
op_node
.
op
().
attr
(
"op_role"
),
)
)
dequant_var_node
=
(
dequant_var_node
=
(
insert_quant_pass
.
insert_dequant_op
(
insert_quant_pass
.
insert_dequant_op
(
graph
,
quant_var_node
,
scale_var_node
graph
,
quant_var_node
,
scale_var_node
,
op_node
.
op
().
attr
(
"op_role"
),
)
)
)
)
dequantized_vars_map
[
arg_name
]
=
dequant_var_node
dequantized_vars_map
[
arg_name
]
=
dequant_var_node
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录