机器未来 / Paddle (forked from PaddlePaddle / Paddle)

Commit aa731e63 (unverified)
Authored Mar 24, 2021 by Wojciech Uss
Committed by GitHub on Mar 24, 2021
update scale collection and propagation algorithm (#31783) (#31810)
Parent: f3b0f8db

Showing 1 changed file with 25 additions and 25 deletions.

python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py (+25 -25)
@@ -62,9 +62,8 @@ class Quant2Int8MkldnnPass(object):
         self._ops_to_quantize = _ops_to_quantize
         self._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip is not None else set(
             [-1])
-        self._scale_immutable_ops = [
-            'transpose2', 'reshape2', 'pool2d', 'scale'
-        ]
+        self._scale_immutable_ops = ['transpose2', 'reshape2', 'pool2d']
+        self._scale_ops = ['scale']
         self._conv_ops = ['conv2d', 'depthwise_conv2d']
         self._pool_ops = ['pool2d']
         self._mul_ops = ['mul']
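Reviewer note, not part of the patch: the reason 'scale' is pulled out of _scale_immutable_ops is that, unlike transpose2, reshape2 or pool2d, a scale op changes the numeric range of its tensor, so one quantization scale cannot be shared between its input and output. A back-of-the-envelope illustration with made-up numbers (the actual formula lives in _update_scale_op_in_scale, which is outside this diff):

# Made-up numbers, only to illustrate why a 'scale' op cannot be treated as
# scale-immutable: it changes the tensor's numeric range, so whatever
# quantization-scale convention is used, the input and output scales differ
# by the op's scale attribute.
scale_attr = 2.0                        # the op computes out = in * 2.0
input_abs_max = 4.0                     # observed |max| of the input tensor
output_abs_max = input_abs_max * scale_attr
print(output_abs_max)                   # 8.0 -- not the same range as the input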
@@ -87,8 +86,8 @@ class Quant2Int8MkldnnPass(object):
         self._reset_pass_idx_and_group('int8')
         graph = self._label_skip_quantized_op(graph)
         graph = self._gather_weight_thresholds_from_fake(graph)
-        graph = self._gather_output_scales_from_attr(graph)
         graph = self._gather_input_scales_from_fake(graph)
+        graph = self._gather_output_scales_from_attr(graph)
         graph = self._remove_fake_ops(graph)
         graph = self._dequantize_weights(graph)
         graph = self._optimize_fp32_graph(graph)
@@ -160,12 +159,16 @@ class Quant2Int8MkldnnPass(object):
                     op_node.op()._set_attr("skip_quant", True)
         return graph
 
-    def _gather_input_scales_from_fake(self, graph):
-        def _add_scale_for_vars(var_names, use_unsigned_int, lod_tensor):
-            scales = self._var_quant_scales
-            for var_name in var_names:
-                scales[var_name] = (use_unsigned_int, lod_tensor)
-
+    def _add_scale_for_vars(self, var_names, use_unsigned_int, lod_tensor):
+        """
+        Save quantization scales for variables. Do not overwrite.
+        """
+        scales = self._var_quant_scales
+        for var_name in var_names:
+            if var_name not in scales:
+                scales[var_name] = (use_unsigned_int, lod_tensor)
+
+    def _gather_input_scales_from_fake(self, graph):
         # fake_quantize_dequantize_abs_max doesn't have scale value
         fake_ops = ['fake_quantize_dequantize_moving_average_abs_max']
         fake_ops.extend(self._fake_quantize_types)
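The helper introduced here is a first-write-wins store: once a variable has an entry in _var_quant_scales, later calls leave it untouched. Together with the reordering in apply() earlier in this diff (input scales from fake-quantize ops are now gathered before output scales from out_threshold attributes), the net effect is that fake-quantize scales keep precedence. A small standalone sketch of that behavior, using a plain dict in place of the pass's _var_quant_scales and placeholder strings instead of LoDTensors:

# Standalone sketch of the first-write-wins behaviour of _add_scale_for_vars;
# the names 'conv0.in'/'conv0.out' and the tensor placeholders are made up.
var_quant_scales = {}

def add_scale_for_vars(var_names, use_unsigned_int, lod_tensor):
    for var_name in var_names:
        if var_name not in var_quant_scales:
            var_quant_scales[var_name] = (use_unsigned_int, lod_tensor)

# Gathered first (e.g. from a fake_quantize op) ...
add_scale_for_vars(['conv0.in', 'conv0.out'], False, 'scale_from_fake_quantize')
# ... a later pass (e.g. from an out_threshold attribute) cannot overwrite it.
add_scale_for_vars(['conv0.out'], False, 'scale_from_out_threshold')

print(var_quant_scales['conv0.out'])  # (False, 'scale_from_fake_quantize') -- first write wins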
@@ -185,8 +188,8 @@ class Quant2Int8MkldnnPass(object):
                 scale[scale == np.Inf] = 0.0
                 lod_tensor = self._convert_scale2tensor(scale)
                 use_unsigned_int = False
-                _add_scale_for_vars([input_name, output_name],
-                                    use_unsigned_int, lod_tensor)
+                self._add_scale_for_vars([input_name, output_name],
+                                         use_unsigned_int, lod_tensor)
 
         return graph
 
@@ -219,8 +222,8 @@ class Quant2Int8MkldnnPass(object):
                 use_unsigned_int = False
                 for output_name in op.op().outputs():
                     for out_var_name in op.op().output(output_name):
-                        self._var_quant_scales[out_var_name] = (
-                            use_unsigned_int, scale_lod_tensor)
+                        self._add_scale_for_vars(
+                            [out_var_name], use_unsigned_int, scale_lod_tensor)
 
         return graph
 
@@ -239,24 +242,21 @@ class Quant2Int8MkldnnPass(object):
                     output_name = op.output("Out")[0]
                     tensor_names = [input_name, output_name]
 
-                    # Scale is not quantized, so if it doesn't have any scales
-                    # to propagate, its tensors won't be added to the waiting list.
-                    if all(name not in self._var_quant_scales
-                           for name in tensor_names) \
-                           and op.name() != 'scale':
+                    if all(name not in self._var_quant_scales
+                           for name in tensor_names):
                         waiting_for_scale.update(tensor_names)
                         continue
-                    if input_name in self._var_quant_scales:
+                    elif input_name in self._var_quant_scales:
                         self._var_quant_scales[
                             output_name] = self._var_quant_scales[input_name]
                     elif output_name in self._var_quant_scales:
-                        if op.name() == 'scale':
-                            _update_scale_op_in_scale(op, input_name,
-                                                      output_name)
-                        else:
-                            self._var_quant_scales[
-                                input_name] = self._var_quant_scales[
-                                    output_name]
+                        self._var_quant_scales[
+                            input_name] = self._var_quant_scales[output_name]
+                elif op.name() in self._scale_ops:
+                    input_name = op.input("X")[0]
+                    output_name = op.output("Out")[0]
+                    if output_name in self._var_quant_scales:
+                        _update_scale_op_in_scale(op, input_name, output_name)
             return waiting_for_scale
 
         waiting_for_scale = _update_scales(graph)
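A toy rendering of one sweep of the propagation loop after this change, outside of Paddle's IrGraph API: scale-immutable ops copy a known scale from one side to the other, tensors with no known scale are parked in waiting_for_scale so the caller can decide whether another sweep is needed, and 'scale' ops are handled in their own branch, where _update_scale_op_in_scale (defined outside this diff) derives the input scale from the output scale. Ops are modeled here as plain (type, input, output) tuples and scales as placeholder strings; none of this is the pass's real data model.

# Toy model of one sweep of _update_scales after this patch.
SCALE_IMMUTABLE_OPS = ['transpose2', 'reshape2', 'pool2d']
SCALE_OPS = ['scale']

def update_scales(ops, var_quant_scales):
    waiting_for_scale = set()
    for op_type, input_name, output_name in ops:
        if op_type in SCALE_IMMUTABLE_OPS:
            tensor_names = [input_name, output_name]
            if all(name not in var_quant_scales for name in tensor_names):
                # Neither side is known yet; park both for a later sweep.
                waiting_for_scale.update(tensor_names)
            elif input_name in var_quant_scales:
                var_quant_scales[output_name] = var_quant_scales[input_name]
            elif output_name in var_quant_scales:
                var_quant_scales[input_name] = var_quant_scales[output_name]
        elif op_type in SCALE_OPS:
            if output_name in var_quant_scales:
                # The real pass calls _update_scale_op_in_scale() here to
                # recompute the input scale; a plain copy stands in for it.
                var_quant_scales[input_name] = var_quant_scales[output_name]
    return waiting_for_scale

scales = {'pool.out': 's0'}
ops = [('transpose2', 'x', 'y'),        # neither x nor y is known on this sweep
       ('reshape2', 'pool.out', 'x')]   # x picks up pool.out's scale
waiting = update_scales(ops, scales)
print(scales)   # {'pool.out': 's0', 'x': 's0'}
print(waiting)  # e.g. {'x', 'y'} -- a second sweep would now resolve y as well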