Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
f201b465
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f201b465
编写于
10月 16, 2019
作者:
J
juncaipeng
提交者:
GitHub
10月 16, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Move pool2d to add_quant_dequant_pass, test=develop (#20586)
* move pool2d to add_quant_dequant_pass, test=develop
上级
efa10937
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
57 addition
and
35 deletion
+57
-35
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
...ddle/fluid/contrib/slim/quantization/quantization_pass.py
+16
-11
python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py
...paddle/fluid/contrib/slim/tests/test_quantization_pass.py
+41
-24
未找到文件。
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
浏览文件 @
f201b465
...
...
@@ -26,7 +26,7 @@ __all__ = [
'AddQuantDequantPass'
]
_quantizable_op_list
=
[
'conv2d'
,
'depthwise_conv2d'
,
'mul'
,
'pool2d'
]
_quantizable_op_list
=
[
'conv2d'
,
'depthwise_conv2d'
,
'mul'
]
_fake_quant_op_list
=
[
'fake_quantize_abs_max'
,
'fake_quantize_range_abs_max'
,
...
...
@@ -161,13 +161,11 @@ class QuantizationTransformPass(object):
persistable_vars
=
[
p
.
name
()
for
p
in
graph
.
all_persistable_nodes
()]
def
_quant_preprocess
(
op_node
):
pool_skipped
=
op_node
.
op
().
has_attr
(
"pooling_type"
)
and
\
op_node
.
op
().
attr
(
"pooling_type"
)
==
'avg'
user_skipped
=
isinstance
(
self
.
_skip_pattern
,
str
)
and
\
op_node
.
op
().
has_attr
(
"op_namescope"
)
and
\
op_node
.
op
().
attr
(
"op_namescope"
).
find
(
self
.
_skip_pattern
)
!=
-
1
if
pool_skipped
or
user_skipped
:
if
user_skipped
:
op_node
.
op
().
_set_attr
(
"skip_quant"
,
True
)
def
_transform_forward
(
graph
,
op
):
...
...
@@ -1163,10 +1161,15 @@ class ScaleForInferencePass(object):
class
AddQuantDequantPass
(
object
):
def
__init__
(
self
,
scope
=
None
,
place
=
None
,
moving_rate
=
0.9
,
quant_bits
=
8
):
def
__init__
(
self
,
scope
=
None
,
place
=
None
,
moving_rate
=
0.9
,
quant_bits
=
8
,
skip_pattern
=
'skip_quant'
):
"""
This pass is used to add quant_dequant op for some ops, such as the
'elementwise_add' and '
average
pool2d' op.
'elementwise_add' and 'pool2d' op.
"""
self
.
_scope
=
scope
self
.
_place
=
place
...
...
@@ -1175,11 +1178,12 @@ class AddQuantDequantPass(object):
self
.
_is_test
=
None
self
.
_target_ops
=
[
"elementwise_add"
,
"pool2d"
]
self
.
_target_grad_ops
=
[
'%s_grad'
%
(
op
)
for
op
in
self
.
_target_ops
]
self
.
_skip_pattern
=
skip_pattern
def
apply
(
self
,
graph
):
"""
Add quant_dequant before some ops, such as the 'elementwise_add'
and '
average
pool2d' op.
and 'pool2d' op.
Args:
graph(IrGraph): the target graph.
"""
...
...
@@ -1191,6 +1195,11 @@ class AddQuantDequantPass(object):
for
op_node
in
ops
:
if
op_node
.
name
()
in
self
.
_target_ops
:
if
isinstance
(
self
.
_skip_pattern
,
str
)
and
\
op_node
.
op
().
has_attr
(
"op_namescope"
)
and
\
op_node
.
op
().
attr
(
"op_namescope"
).
find
(
self
.
_skip_pattern
)
!=
-
1
:
continue
in_nodes_all_not_persistable
=
True
for
input_name
in
op_node
.
input_arg_names
():
in_node
=
graph
.
_find_node_by_name
(
op_node
.
inputs
,
...
...
@@ -1201,10 +1210,6 @@ class AddQuantDequantPass(object):
if
not
in_nodes_all_not_persistable
:
continue
if
op_node
.
op
().
has_attr
(
"pooling_type"
)
and
\
op_node
.
op
().
attr
(
"pooling_type"
)
==
'max'
:
continue
input_names
=
op_node
.
input_arg_names
()
for
input_name
in
input_names
:
in_node
=
graph
.
_find_node_by_name
(
op_node
.
inputs
,
...
...
python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py
浏览文件 @
f201b465
...
...
@@ -42,7 +42,7 @@ def linear_fc(num):
return
loss
def
residual_block
(
num
):
def
residual_block
(
num
,
quant_skip_pattern
=
None
):
def
conv_bn_layer
(
input
,
ch_out
,
filter_size
,
...
...
@@ -67,8 +67,14 @@ def residual_block(num):
conv
=
conv_bn_layer
(
hidden
,
16
,
3
,
1
,
1
,
act
=
None
,
bias_attr
=
True
)
short
=
conv_bn_layer
(
hidden
,
16
,
1
,
1
,
0
,
act
=
None
)
hidden
=
fluid
.
layers
.
elementwise_add
(
x
=
conv
,
y
=
short
,
act
=
'relu'
)
pool
=
fluid
.
layers
.
pool2d
(
input
=
hidden
,
pool_size
=
2
,
pool_type
=
'avg'
,
pool_stride
=
2
)
if
quant_skip_pattern
:
with
fluid
.
name_scope
(
quant_skip_pattern
):
pool
=
fluid
.
layers
.
pool2d
(
input
=
hidden
,
pool_size
=
2
,
pool_type
=
'avg'
,
pool_stride
=
2
)
else
:
pool
=
fluid
.
layers
.
pool2d
(
input
=
hidden
,
pool_size
=
2
,
pool_type
=
'avg'
,
pool_stride
=
2
)
fc
=
fluid
.
layers
.
fc
(
input
=
pool
,
size
=
10
)
loss
=
fluid
.
layers
.
cross_entropy
(
input
=
fc
,
label
=
label
)
loss
=
fluid
.
layers
.
mean
(
loss
)
...
...
@@ -134,7 +140,10 @@ class TestQuantizationTransformPass(unittest.TestCase):
arg_name
.
endswith
(
'.quantized.dequantized'
))
self
.
assertTrue
(
arg_name
in
quantized_ops
)
def
linear_fc_quant
(
self
,
activation_quant_type
,
for_ci
=
True
):
def
linear_fc_quant
(
self
,
activation_quant_type
,
weight_quantize_type
,
for_ci
=
True
):
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
,
startup
):
...
...
@@ -146,7 +155,8 @@ class TestQuantizationTransformPass(unittest.TestCase):
transform_pass
=
QuantizationTransformPass
(
scope
=
fluid
.
global_scope
(),
place
=
place
,
activation_quantize_type
=
activation_quant_type
)
activation_quantize_type
=
activation_quant_type
,
weight_quantize_type
=
weight_quantize_type
)
transform_pass
.
apply
(
graph
)
if
not
for_ci
:
marked_nodes
=
set
()
...
...
@@ -167,15 +177,19 @@ class TestQuantizationTransformPass(unittest.TestCase):
val_marked_nodes
)
def
test_linear_fc_quant_abs_max
(
self
):
self
.
linear_fc_quant
(
'abs_max'
,
for_ci
=
True
)
self
.
linear_fc_quant
(
'abs_max'
,
'abs_max'
,
for_ci
=
True
)
def
test_linear_fc_quant_range_abs_max
(
self
):
self
.
linear_fc_quant
(
'range_abs_max'
,
for_ci
=
True
)
self
.
linear_fc_quant
(
'range_abs_max'
,
'abs_max'
,
for_ci
=
True
)
def
test_linear_fc_quant_moving_average_abs_max
(
self
):
self
.
linear_fc_quant
(
'moving_average_abs_max'
,
for_ci
=
True
)
self
.
linear_fc_quant
(
'moving_average_abs_max'
,
'channel_wise_abs_max'
,
for_ci
=
True
)
def
residual_block_quant
(
self
,
activation_quant_type
,
for_ci
=
True
):
def
residual_block_quant
(
self
,
activation_quant_type
,
weight_quantize_type
,
for_ci
=
True
):
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
,
startup
):
...
...
@@ -187,7 +201,8 @@ class TestQuantizationTransformPass(unittest.TestCase):
transform_pass
=
QuantizationTransformPass
(
scope
=
fluid
.
global_scope
(),
place
=
place
,
activation_quantize_type
=
activation_quant_type
)
activation_quantize_type
=
activation_quant_type
,
weight_quantize_type
=
weight_quantize_type
)
transform_pass
.
apply
(
graph
)
if
not
for_ci
:
marked_nodes
=
set
()
...
...
@@ -208,13 +223,14 @@ class TestQuantizationTransformPass(unittest.TestCase):
val_marked_nodes
)
def
test_residual_block_abs_max
(
self
):
self
.
residual_block_quant
(
'abs_max'
,
for_ci
=
True
)
self
.
residual_block_quant
(
'abs_max'
,
'abs_max'
,
for_ci
=
True
)
def
test_residual_block_range_abs_max
(
self
):
self
.
residual_block_quant
(
'range_abs_max'
,
for_ci
=
True
)
self
.
residual_block_quant
(
'range_abs_max'
,
'abs_max'
,
for_ci
=
True
)
def
test_residual_block_moving_average_abs_max
(
self
):
self
.
residual_block_quant
(
'moving_average_abs_max'
,
for_ci
=
True
)
self
.
residual_block_quant
(
'moving_average_abs_max'
,
'channel_wise_abs_max'
,
for_ci
=
True
)
class
TestQuantizationFreezePass
(
unittest
.
TestCase
):
...
...
@@ -494,11 +510,14 @@ class TestAddQuantDequantPass(unittest.TestCase):
self
.
_target_ops
=
{
'elementwise_add'
,
'pool2d'
}
self
.
_target_grad_ops
=
{
'elementwise_add_grad'
,
'pool2d_grad'
}
def
check_graph
(
self
,
graph
):
def
check_graph
(
self
,
graph
,
skip_pattern
=
None
):
ops
=
graph
.
all_op_nodes
()
for
op_node
in
ops
:
if
op_node
.
name
()
in
self
.
_target_ops
:
if
skip_pattern
and
op_node
.
op
().
has_attr
(
"op_namescope"
)
and
\
op_node
.
op
().
attr
(
"op_namescope"
).
find
(
skip_pattern
)
!=
-
1
:
continue
in_nodes_all_not_persistable
=
True
for
input_name
in
op_node
.
input_arg_names
():
in_node
=
graph
.
_find_node_by_name
(
op_node
.
inputs
,
...
...
@@ -508,20 +527,15 @@ class TestAddQuantDequantPass(unittest.TestCase):
not
in_node
.
persistable
())
if
not
in_nodes_all_not_persistable
:
continue
if
op_node
.
op
().
has_attr
(
"pooling_type"
)
and
\
op_node
.
op
().
attr
(
"pooling_type"
)
==
'max'
:
continue
input_names
=
op_node
.
input_arg_names
()
for
input_name
in
input_names
:
self
.
assertTrue
(
input_name
.
endswith
(
'.quant_dequant'
))
def
residual_block_quant
(
self
,
for_ci
=
True
):
def
residual_block_quant
(
self
,
skip_pattern
=
None
,
for_ci
=
True
):
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
,
startup
):
loss
=
residual_block
(
1
)
loss
=
residual_block
(
2
,
skip_pattern
)
opt
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
0.001
)
opt
.
minimize
(
loss
)
place
=
fluid
.
CPUPlace
()
...
...
@@ -535,7 +549,7 @@ class TestAddQuantDequantPass(unittest.TestCase):
if
op
.
name
().
find
(
'quant'
)
>
-
1
:
marked_nodes
.
add
(
op
)
graph
.
draw
(
'.'
,
'add_quant_dequant_graph'
,
marked_nodes
)
self
.
check_graph
(
graph
)
self
.
check_graph
(
graph
,
skip_pattern
)
program
=
graph
.
to_program
()
val_graph
=
IrGraph
(
core
.
Graph
(
program
.
desc
),
for_test
=
False
)
if
not
for_ci
:
...
...
@@ -546,7 +560,10 @@ class TestAddQuantDequantPass(unittest.TestCase):
val_graph
.
draw
(
'.'
,
'val_add_quant_dequant_graph'
,
val_marked_nodes
)
def
test_residual_block
(
self
):
self
.
residual_block_quant
(
for_ci
=
True
)
self
.
residual_block_quant
(
skip_pattern
=
None
,
for_ci
=
True
)
def
test_residual_block_skip_pattern
(
self
):
self
.
residual_block_quant
(
skip_pattern
=
'skip_quant'
,
for_ci
=
True
)
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录