Unverified commit b0ceed6f in BaiXuePrincess / Paddle (a fork of PaddlePaddle / Paddle).
Authored by juncaipeng on Sep 23, 2019; committed via GitHub on Sep 23, 2019.
add fake_quant_dequant_op for average pool2d, test=develop (#19880)

* add fake_quant_dequant_op for average pool2d
* add test
Parent: cb8f3c03
Changes: 2 changed files with 94 additions and 8 deletions (+94 −8)

python/paddle/fluid/contrib/slim/quantization/quantization_pass.py   +30 −7
python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py     +64 −1
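In short, the commit teaches AddQuantDequantPass to treat average pool2d the same way it already treats elementwise_add: a fake quant_dequant op is inserted in front of the op's activation inputs so that an input scale is available when the model is later consumed, e.g. by TensorRT. A minimal sketch of driving the extended pass, mirroring the new test added below (the toy network, its layer names, and the IrGraph import are illustrative assumptions, not part of this commit):

    import paddle.fluid as fluid
    from paddle.fluid import core
    from paddle.fluid.framework import IrGraph
    from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass

    # Illustrative fluid-1.x program that ends in an average pool2d.
    main = fluid.Program()
    with fluid.program_guard(main, fluid.Program()):
        img = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
        conv = fluid.layers.conv2d(img, num_filters=16, filter_size=3)
        pool = fluid.layers.pool2d(conv, pool_size=2, pool_type='avg', pool_stride=2)

    graph = IrGraph(core.Graph(main.desc), for_test=False)
    quant_pass = AddQuantDequantPass(scope=fluid.global_scope(), place=fluid.CPUPlace())
    quant_pass.apply(graph)  # inputs of the avg pool2d now pass through fake quant_dequant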
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py

@@ -90,6 +90,9 @@ class QuantizationTransformPass(object):
             usually is not used for weight, since weights are fixed once the
             model is well trained.
             window_size (int): the window size for 'range_abs_max' quantization.
+            skip_pattern(str): The user-defined quantization skip pattern, which
+                will be presented in the name scope of an op. When the skip pattern is
+                detected in an op's name scope, the corresponding op will not be quantized.
         Examples:
         .. code-block:: python
@@ -1163,29 +1166,31 @@ class AddQuantDequantPass(object):
     def __init__(self, scope=None, place=None, moving_rate=0.9, quant_bits=8):
         """
         This pass is used to add quant_dequant op for some ops, such as the
-        `elementwise_add` op.
+        'elementwise_add' and 'average pool2d' op.
         """
         self._scope = scope
         self._place = place
         self._moving_rate = moving_rate
         self._quant_bits = quant_bits
         self._is_test = None
-        self._target_ops = ["elementwise_add"]
+        self._target_ops = ["elementwise_add", "pool2d"]
         self._target_grad_ops = ['%s_grad' % (op) for op in self._target_ops]

     def apply(self, graph):
         """
-        Add quant_dequant before some ops, such as the `elementwise_add` op. This
-        is required by TensorRT.
+        Add quant_dequant before some ops, such as the 'elementwise_add'
+        and 'average pool2d' op.
         Args:
             graph(IrGraph): the target graph.
         """
         assert isinstance(graph, IrGraph), \
             'graph must be the instance of IrGraph.'
         self._is_test = graph.is_test()
+        dequantized_vars_map = collections.OrderedDict()
         ops = graph.all_op_nodes()
         for op_node in ops:
-            if op_node.name() in self._target_ops:
+            name = op_node.name()
+            if name in self._target_ops:
                 in_nodes_all_not_persistable = True
                 for input_name in op_node.input_arg_names():
                     in_node = graph._find_node_by_name(op_node.inputs,
@@ -1195,13 +1200,31 @@ class AddQuantDequantPass(object):
                         not in_node.persistable())
                 if not in_nodes_all_not_persistable:
                     continue
+                if op_node.op().has_attr("pooling_type") and \
+                    op_node.op().attr("pooling_type") == 'max':
+                    continue
                 input_names = op_node.input_arg_names()
                 for input_name in input_names:
                     in_node = graph._find_node_by_name(op_node.inputs,
                                                        input_name)
-                    quant_var_node, scale_var_node = self._inser_quant_dequant_moving_average_abs_max_op(
-                        graph, in_node, self._quant_bits)
+                    quant_var_node, scale_var_node = \
+                        self._inser_quant_dequant_moving_average_abs_max_op(
+                        graph, in_node, self._quant_bits)
+                    dequantized_vars_map[input_name] = quant_var_node
                     graph.update_input_link(in_node, quant_var_node, op_node)

+        for op_node in ops:
+            if op_node.name() in self._target_grad_ops:
+                for input_name in op_node.input_arg_names():
+                    if input_name in dequantized_vars_map:
+                        in_node = graph._find_node_by_name(op_node.inputs,
+                                                           input_name)
+                        dequant_var_node = dequantized_vars_map[input_name]
+                        graph.update_input_link(in_node, dequant_var_node,
+                                                op_node)
+
         graph.resolve_hazard()
         return graph
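For intuition, the op inserted by _inser_quant_dequant_moving_average_abs_max_op fake-quantizes an activation to quant_bits and immediately dequantizes it, tracking the scale as a moving average of the tensor's abs-max (moving_rate defaults to 0.9 above). A rough standalone sketch of that arithmetic, assuming the usual abs-max fake-quantization convention rather than reproducing Paddle's exact kernel or state update:

    import numpy as np

    def fake_quant_dequant(x, prev_scale, moving_rate=0.9, quant_bits=8):
        # Moving-average abs-max scale (illustrative, not Paddle's exact update).
        scale = moving_rate * prev_scale + (1 - moving_rate) * np.abs(x).max()
        bnt = (1 << (quant_bits - 1)) - 1                    # 127 for 8-bit quantization
        q = np.clip(np.round(x / scale * bnt), -bnt, bnt)    # fake-quantize
        return q * scale / bnt, scale                        # immediately dequantize

    x = np.random.uniform(-1, 1, (2, 16, 8, 8)).astype('float32')
    out, new_scale = fake_quant_dequant(x, prev_scale=1.0)

Note that the pass skips pool2d ops whose pooling_type is 'max', presumably because max pooling merely selects one of its inputs and so introduces no new values needing their own quantization scale; only average pooling is instrumented.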
python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py

@@ -24,6 +24,7 @@ from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
 from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
 from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
 from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
+from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
 from paddle.fluid import core

 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
@@ -66,7 +67,9 @@ def residual_block(num):
     conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
     short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
     hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
-    fc = fluid.layers.fc(input=hidden, size=10)
+    pool = fluid.layers.pool2d(
+        input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
+    fc = fluid.layers.fc(input=pool, size=10)
     loss = fluid.layers.cross_entropy(input=fc, label=label)
     loss = fluid.layers.mean(loss)
     return loss
@@ -486,5 +489,65 @@ class TestQuantizationFreezePass(unittest.TestCase):
             for_ci=True)


+class TestAddQuantDequantPass(unittest.TestCase):
+    def setUp(self):
+        self._target_ops = {'elementwise_add', 'pool2d'}
+        self._target_grad_ops = {'elementwise_add_grad', 'pool2d_grad'}
+
+    def check_graph(self, graph):
+        ops = graph.all_op_nodes()
+        for op_node in ops:
+            if op_node.name() in self._target_ops:
+                in_nodes_all_not_persistable = True
+                for input_name in op_node.input_arg_names():
+                    in_node = graph._find_node_by_name(op_node.inputs,
+                                                       input_name)
+                    in_nodes_all_not_persistable = (
+                        in_nodes_all_not_persistable and
+                        not in_node.persistable())
+                if not in_nodes_all_not_persistable:
+                    continue
+                if op_node.op().has_attr("pooling_type") and \
+                    op_node.op().attr("pooling_type") == 'max':
+                    continue
+                input_names = op_node.input_arg_names()
+                for input_name in input_names:
+                    self.assertTrue(input_name.endswith('.quant_dequant'))
+
+    def residual_block_quant(self, for_ci=True):
+        main = fluid.Program()
+        startup = fluid.Program()
+        with fluid.program_guard(main, startup):
+            loss = residual_block(1)
+            opt = fluid.optimizer.Adam(learning_rate=0.001)
+            opt.minimize(loss)
+        place = fluid.CPUPlace()
+        graph = IrGraph(core.Graph(main.desc), for_test=False)
+        add_quant_dequant_pass = AddQuantDequantPass(
+            scope=fluid.global_scope(), place=place)
+        add_quant_dequant_pass.apply(graph)
+        if not for_ci:
+            marked_nodes = set()
+            for op in graph.all_op_nodes():
+                if op.name().find('quant') > -1:
+                    marked_nodes.add(op)
+            graph.draw('.', 'add_quant_dequant_graph', marked_nodes)
+        self.check_graph(graph)
+        program = graph.to_program()
+        val_graph = IrGraph(core.Graph(program.desc), for_test=False)
+        if not for_ci:
+            val_marked_nodes = set()
+            for op in val_graph.all_op_nodes():
+                if op.name().find('quant') > -1:
+                    val_marked_nodes.add(op)
+            val_graph.draw('.', 'val_add_quant_dequant_graph',
+                           val_marked_nodes)
+
+    def test_residual_block(self):
+        self.residual_block_quant(for_ci=True)
+
+
 if __name__ == '__main__':
     unittest.main()
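The new TestAddQuantDequantPass exercises exactly this path: residual_block now ends in an average pool2d, AddQuantDequantPass is applied on a CPU place, and check_graph asserts that every activation input of a targeted op has been rewired to a variable ending in '.quant_dequant', skipping max pooling with the same pooling_type check as the pass itself. Since the file keeps its unittest.main() entry point, running python test_quantization_pass.py executes it alongside the existing freeze-pass tests.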