Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
07e6a942
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
07e6a942
编写于
11月 26, 2019
作者:
I
itminner
提交者:
whs
11月 26, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
paddleslim quantization skip pattern support list of string (#21141)
上级
d8e7d252
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
47 addition
and
14 deletion
+47
-14
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
...ddle/fluid/contrib/slim/quantization/quantization_pass.py
+20
-10
python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py
...paddle/fluid/contrib/slim/tests/test_quantization_pass.py
+27
-4
未找到文件。
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
浏览文件 @
07e6a942
...
...
@@ -99,7 +99,7 @@ class QuantizationTransformPass(object):
weight_quantize_type
=
'abs_max'
,
window_size
=
10000
,
moving_rate
=
0.9
,
skip_pattern
=
'skip_quant'
,
skip_pattern
=
[
'skip_quant'
]
,
quantizable_op_type
=
[
'conv2d'
,
'depthwise_conv2d'
,
'mul'
]):
"""
Convert and rewrite the IrGraph according to weight and
...
...
@@ -126,9 +126,9 @@ class QuantizationTransformPass(object):
model is well trained.
window_size(int): the window size for 'range_abs_max' quantization.
moving_rate(float): the param for 'moving_average_abs_max' quantization.
skip_pattern(str): The user-defined quantization skip pattern, which
skip_pattern(str
or str list
): The user-defined quantization skip pattern, which
will be presented in the name scope of an op. When the skip pattern is
detected in an op's name scope, the corresponding op will not be quantized.
detected in an op's name scope, the corresponding op will not be quantized.
quantizable_op_type(list[str]): List the type of ops that will be quantized.
Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
QuantizationFreezePass and ConvertToInt8Pass must be the same as this.
...
...
@@ -206,9 +206,13 @@ class QuantizationTransformPass(object):
persistable_vars
=
[
p
.
name
()
for
p
in
graph
.
all_persistable_nodes
()]
def
_quant_preprocess
(
op_node
):
user_skipped
=
isinstance
(
self
.
_skip_pattern
,
str
)
and
\
op_node
.
op
().
has_attr
(
"op_namescope"
)
and
\
op_node
.
op
().
attr
(
"op_namescope"
).
find
(
self
.
_skip_pattern
)
!=
-
1
user_skipped
=
False
if
isinstance
(
self
.
_skip_pattern
,
list
):
user_skipped
=
op_node
.
op
().
has_attr
(
"op_namescope"
)
and
\
any
(
pattern
in
op_node
.
op
().
attr
(
"op_namescope"
)
for
pattern
in
self
.
_skip_pattern
)
elif
isinstance
(
self
.
_skip_pattern
,
str
):
user_skipped
=
op_node
.
op
().
has_attr
(
"op_namescope"
)
and
\
op_node
.
op
().
attr
(
"op_namescope"
).
find
(
self
.
_skip_pattern
)
!=
-
1
if
user_skipped
:
op_node
.
op
().
_set_attr
(
"skip_quant"
,
True
)
...
...
@@ -1245,7 +1249,7 @@ class AddQuantDequantPass(object):
place
=
None
,
moving_rate
=
0.9
,
quant_bits
=
8
,
skip_pattern
=
'skip_quant'
,
skip_pattern
=
[
"skip_quant"
]
,
quantizable_op_type
=
[
"elementwise_add"
,
"pool2d"
,
"concat"
],
is_full_quantized
=
False
):
"""
...
...
@@ -1313,9 +1317,15 @@ class AddQuantDequantPass(object):
all_op_nodes
=
graph
.
all_op_nodes
()
for
op_node
in
all_op_nodes
:
if
op_node
.
name
()
in
self
.
_quantizable_op_type
:
if
isinstance
(
self
.
_skip_pattern
,
str
)
and
\
op_node
.
op
().
has_attr
(
"op_namescope"
)
and
\
op_node
.
op
().
attr
(
"op_namescope"
).
find
(
self
.
_skip_pattern
)
!=
-
1
:
user_skipped
=
False
if
isinstance
(
self
.
_skip_pattern
,
list
):
user_skipped
=
op_node
.
op
().
has_attr
(
"op_namescope"
)
and
\
any
(
pattern
in
op_node
.
op
().
attr
(
"op_namescope"
)
for
pattern
in
self
.
_skip_pattern
)
elif
isinstance
(
self
.
_skip_pattern
,
str
):
user_skipped
=
op_node
.
op
().
has_attr
(
"op_namescope"
)
and
\
op_node
.
op
().
attr
(
"op_namescope"
).
find
(
self
.
_skip_pattern
)
!=
-
1
if
user_skipped
:
continue
if
not
self
.
_is_input_all_not_persistable
(
graph
,
op_node
):
...
...
python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py
浏览文件 @
07e6a942
...
...
@@ -531,7 +531,7 @@ def quant_dequant_residual_block(num, quant_skip_pattern=None):
short
=
conv_bn_layer
(
hidden
,
16
,
1
,
1
,
0
,
act
=
None
)
hidden
=
fluid
.
layers
.
elementwise_add
(
x
=
conv
,
y
=
short
,
act
=
'relu'
)
if
quant_skip_pattern
:
if
isinstance
(
quant_skip_pattern
,
str
)
:
with
fluid
.
name_scope
(
quant_skip_pattern
):
pool1
=
fluid
.
layers
.
pool2d
(
input
=
hidden
,
pool_size
=
2
,
pool_type
=
'avg'
,
pool_stride
=
2
)
...
...
@@ -539,6 +539,18 @@ def quant_dequant_residual_block(num, quant_skip_pattern=None):
input
=
hidden
,
pool_size
=
2
,
pool_type
=
'max'
,
pool_stride
=
2
)
pool_add
=
fluid
.
layers
.
elementwise_add
(
x
=
pool1
,
y
=
pool2
,
act
=
'relu'
)
elif
isinstance
(
quant_skip_pattern
,
list
):
assert
len
(
quant_skip_pattern
)
>
1
,
'test config error: the len of quant_skip_pattern list should be greater than 1.'
with
fluid
.
name_scope
(
quant_skip_pattern
[
0
]):
pool1
=
fluid
.
layers
.
pool2d
(
input
=
hidden
,
pool_size
=
2
,
pool_type
=
'avg'
,
pool_stride
=
2
)
pool2
=
fluid
.
layers
.
pool2d
(
input
=
hidden
,
pool_size
=
2
,
pool_type
=
'max'
,
pool_stride
=
2
)
with
fluid
.
name_scope
(
quant_skip_pattern
[
1
]):
pool_add
=
fluid
.
layers
.
elementwise_add
(
x
=
pool1
,
y
=
pool2
,
act
=
'relu'
)
else
:
pool1
=
fluid
.
layers
.
pool2d
(
input
=
hidden
,
pool_size
=
2
,
pool_type
=
'avg'
,
pool_stride
=
2
)
...
...
@@ -560,8 +572,15 @@ class TestAddQuantDequantPass(unittest.TestCase):
ops
=
graph
.
all_op_nodes
()
for
op_node
in
ops
:
if
op_node
.
name
()
in
self
.
_target_ops
:
if
skip_pattern
and
op_node
.
op
().
has_attr
(
"op_namescope"
)
and
\
op_node
.
op
().
attr
(
"op_namescope"
).
find
(
skip_pattern
)
!=
-
1
:
user_skipped
=
False
if
isinstance
(
skip_pattern
,
list
):
user_skipped
=
op_node
.
op
().
has_attr
(
"op_namescope"
)
and
\
any
(
pattern
in
op_node
.
op
().
attr
(
"op_namescope"
)
for
pattern
in
skip_pattern
)
elif
isinstance
(
skip_pattern
,
str
):
user_skipped
=
op_node
.
op
().
has_attr
(
"op_namescope"
)
and
\
op_node
.
op
().
attr
(
"op_namescope"
).
find
(
skip_pattern
)
!=
-
1
if
user_skipped
:
continue
in_nodes_all_not_persistable
=
True
...
...
@@ -587,7 +606,7 @@ class TestAddQuantDequantPass(unittest.TestCase):
place
=
fluid
.
CPUPlace
()
graph
=
IrGraph
(
core
.
Graph
(
main
.
desc
),
for_test
=
False
)
add_quant_dequant_pass
=
AddQuantDequantPass
(
scope
=
fluid
.
global_scope
(),
place
=
place
)
scope
=
fluid
.
global_scope
(),
place
=
place
,
skip_pattern
=
skip_pattern
)
add_quant_dequant_pass
.
apply
(
graph
)
if
not
for_ci
:
marked_nodes
=
set
()
...
...
@@ -611,6 +630,10 @@ class TestAddQuantDequantPass(unittest.TestCase):
def
test_residual_block_skip_pattern
(
self
):
self
.
residual_block_quant
(
skip_pattern
=
'skip_quant'
,
for_ci
=
True
)
def
test_residual_block_skip_pattern
(
self
):
self
.
residual_block_quant
(
skip_pattern
=
[
'skip_quant1'
,
'skip_quant2'
],
for_ci
=
True
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录