Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ef536250
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ef536250
编写于
8月 31, 2021
作者:
X
XGZhang
提交者:
GitHub
8月 31, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support fuse layers for ptq (#35015)
上级
561841d2
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
340 addition
and
6 deletion
+340
-6
python/paddle/fluid/contrib/slim/quantization/imperative/fuse_utils.py
.../fluid/contrib/slim/quantization/imperative/fuse_utils.py
+175
-0
python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py
.../paddle/fluid/contrib/slim/quantization/imperative/ptq.py
+13
-4
python/paddle/fluid/contrib/slim/tests/imperative_test_utils.py
.../paddle/fluid/contrib/slim/tests/imperative_test_utils.py
+60
-0
python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py
...on/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py
+92
-2
未找到文件。
python/paddle/fluid/contrib/slim/quantization/imperative/fuse_utils.py
0 → 100644
浏览文件 @
ef536250
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
copy
import
paddle
import
paddle.nn
as
nn
from
.
import
utils
class
Identity
(
nn
.
Layer
):
'''a layer to replace bn or relu layers'''
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
Identity
,
self
).
__init__
()
def
forward
(
self
,
input
):
return
input
def
fuse_layers
(
model
,
layers_to_fuse
,
inplace
=
False
):
'''
fuse layers in layers_to_fuse
Args:
model(paddle.nn.Layer): The model to be fused.
layers_to_fuse(list): The layers' names to be fused. For
example,"fuse_list = [["conv1", "bn1"], ["conv2", "bn2"]]".
A TypeError would be raised if "fuse" was set as
True but "fuse_list" was None.
Default: None.
inplace(bool): Whether apply fusing to the input model.
Default: False.
Return
fused_model(paddle.nn.Layer): The fused model.
'''
if
inplace
==
False
:
model
=
copy
.
deepcopy
(
model
)
for
layers
in
layers_to_fuse
:
_fuse_layers
(
model
,
layers
)
return
model
def
_fuse_layers
(
model
,
layers_list
):
'''fuse all the layers in layers_list'''
layer_list
=
[]
for
layer_name
in
layers_list
:
parent_layer
,
sub_name
=
utils
.
find_parent_layer_and_sub_name
(
model
,
layer_name
)
layer_list
.
append
(
getattr
(
parent_layer
,
sub_name
))
new_layers
=
_fuse_func
(
layer_list
)
for
i
,
item
in
enumerate
(
layers_list
):
parent_layer
,
sub_name
=
utils
.
find_parent_layer_and_sub_name
(
model
,
item
)
setattr
(
parent_layer
,
sub_name
,
new_layers
[
i
])
def
_fuse_func
(
layer_list
):
'''choose the fuser method and fuse layers'''
types
=
tuple
(
type
(
m
)
for
m
in
layer_list
)
fusion_method
=
types_to_fusion_method
.
get
(
types
,
None
)
new_layers
=
[
None
]
*
len
(
layer_list
)
fused_layer
=
fusion_method
(
*
layer_list
)
for
handle_id
,
pre_hook_fn
in
layer_list
[
0
].
_forward_pre_hooks
.
items
():
fused_layer
.
register_forward_pre_hook
(
pre_hook_fn
)
del
layer_list
[
0
].
_forward_pre_hooks
[
handle_id
]
for
handle_id
,
hook_fn
in
layer_list
[
-
1
].
_forward_post_hooks
.
items
():
fused_layer
.
register_forward_post_hook
(
hook_fn
)
del
layer_list
[
-
1
].
_forward_post_hooks
[
handle_id
]
new_layers
[
0
]
=
fused_layer
for
i
in
range
(
1
,
len
(
layer_list
)):
identity
=
Identity
()
identity
.
training
=
layer_list
[
0
].
training
new_layers
[
i
]
=
identity
return
new_layers
def
_fuse_conv_bn
(
conv
,
bn
):
'''fuse conv and bn for train or eval'''
assert
(
conv
.
training
==
bn
.
training
),
\
"Conv and BN both must be in the same mode (train or eval)."
if
conv
.
training
:
assert
bn
.
_num_features
==
conv
.
_out_channels
,
'Output channel of Conv2d must match num_features of BatchNorm2d'
raise
NotImplementedError
else
:
return
_fuse_conv_bn_eval
(
conv
,
bn
)
def
_fuse_conv_bn_eval
(
conv
,
bn
):
'''fuse conv and bn for eval'''
assert
(
not
(
conv
.
training
or
bn
.
training
)),
"Fusion only for eval!"
fused_conv
=
copy
.
deepcopy
(
conv
)
fused_weight
,
fused_bias
=
_fuse_conv_bn_weights
(
fused_conv
.
weight
,
fused_conv
.
bias
,
bn
.
_mean
,
bn
.
_variance
,
bn
.
_epsilon
,
bn
.
weight
,
bn
.
bias
)
fused_conv
.
weight
.
set_value
(
fused_weight
)
if
fused_conv
.
bias
is
None
:
fused_conv
.
bias
=
paddle
.
create_parameter
(
shape
=
[
fused_conv
.
_out_channels
],
is_bias
=
True
,
dtype
=
bn
.
bias
.
dtype
)
fused_conv
.
bias
.
set_value
(
fused_bias
)
return
fused_conv
def
_fuse_conv_bn_weights
(
conv_w
,
conv_b
,
bn_rm
,
bn_rv
,
bn_eps
,
bn_w
,
bn_b
):
'''fuse weights and bias of conv and bn'''
if
conv_b
is
None
:
conv_b
=
paddle
.
zeros_like
(
bn_rm
)
if
bn_w
is
None
:
bn_w
=
paddle
.
ones_like
(
bn_rm
)
if
bn_b
is
None
:
bn_b
=
paddle
.
zeros_like
(
bn_rm
)
bn_var_rsqrt
=
paddle
.
rsqrt
(
bn_rv
+
bn_eps
)
conv_w
=
conv_w
*
\
(
bn_w
*
bn_var_rsqrt
).
reshape
([
-
1
]
+
[
1
]
*
(
len
(
conv_w
.
shape
)
-
1
))
conv_b
=
(
conv_b
-
bn_rm
)
*
bn_var_rsqrt
*
bn_w
+
bn_b
return
conv_w
,
conv_b
def
_fuse_linear_bn
(
linear
,
bn
):
'''fuse linear and bn'''
assert
(
linear
.
training
==
bn
.
training
),
\
"Linear and BN both must be in the same mode (train or eval)."
if
linear
.
training
:
assert
bn
.
_num_features
==
linear
.
weight
.
shape
[
1
],
'Output channel of Linear must match num_features of BatchNorm'
raise
NotImplementedError
else
:
return
_fuse_linear_bn_eval
(
linear
,
bn
)
def
_fuse_linear_bn_eval
(
linear
,
bn
):
'''fuse linear and bn for eval'''
assert
(
not
(
linear
.
training
or
bn
.
training
)),
"Fusion only for eval!"
fused_linear
=
copy
.
deepcopy
(
linear
)
fused_weight
,
fused_bias
=
_fuse_linear_bn_weights
(
fused_linear
.
weight
,
fused_linear
.
bias
,
bn
.
_mean
,
bn
.
_variance
,
bn
.
_epsilon
,
bn
.
weight
,
bn
.
bias
)
fused_linear
.
weight
.
set_value
(
fused_weight
)
if
fused_linear
.
bias
is
None
:
fused_linear
.
bias
=
paddle
.
create_parameter
(
shape
=
[
fused_linear
.
weight
.
shape
[
1
]],
is_bias
=
True
,
dtype
=
bn
.
bias
.
dtype
)
fused_linear
.
bias
.
set_value
(
fused_bias
)
return
fused_linear
def
_fuse_linear_bn_weights
(
linear_w
,
linear_b
,
bn_rm
,
bn_rv
,
bn_eps
,
bn_w
,
bn_b
):
'''fuse weights and bias of linear and bn'''
if
linear_b
is
None
:
linear_b
=
paddle
.
zeros_like
(
bn_rm
)
bn_scale
=
bn_w
*
paddle
.
rsqrt
(
bn_rv
+
bn_eps
)
fused_w
=
linear_w
*
bn_scale
.
unsqueeze
(
-
1
)
fused_b
=
(
linear_b
-
bn_rm
)
*
bn_scale
+
bn_b
return
fused_w
,
fused_b
types_to_fusion_method
=
{
(
nn
.
Conv2D
,
nn
.
BatchNorm2D
):
_fuse_conv_bn
,
(
nn
.
Linear
,
nn
.
BatchNorm1D
):
_fuse_linear_bn
,
}
python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py
浏览文件 @
ef536250
...
...
@@ -22,6 +22,7 @@ import paddle.nn.quant.quant_layers as quant_layers
from
paddle.fluid.log_helper
import
get_logger
from
paddle.fluid.dygraph.io
import
INFER_MODEL_SUFFIX
,
INFER_PARAMS_SUFFIX
from
.
import
fuse_utils
from
.
import
utils
from
.
import
ptq_hooks
from
.
import
ptq_config
...
...
@@ -55,7 +56,7 @@ class ImperativePTQ(object):
self
.
_quant_config
=
quant_config
def
quantize
(
self
,
model
,
inplace
=
False
):
def
quantize
(
self
,
model
,
inplace
=
False
,
fuse
=
False
,
fuse_list
=
None
):
"""
Add quant config and hook to the target layer.
...
...
@@ -63,15 +64,23 @@ class ImperativePTQ(object):
model(paddle.nn.Layer): The model to be quantized.
inplace(bool): Whether apply quantization to the input model.
Default: False.
Returns:
fuse(bool): Whether to fuse layers.
Default: False.
fuse_list(list): The layers' names to be fused. For example,
"fuse_list = [["conv1", "bn1"], ["conv2", "bn2"]]".
A TypeError would be raised if "fuse" was set as
True but "fuse_list" was None.
Default: None.
Return
quantized_model(paddle.nn.Layer): The quantized model.
"""
assert
isinstance
(
model
,
paddle
.
nn
.
Layer
),
\
"The model must be the instance of paddle.nn.Layer."
if
not
inplace
:
model
=
copy
.
deepcopy
(
model
)
if
fuse
:
model
.
eval
()
model
=
fuse_utils
.
fuse_layers
(
model
,
fuse_list
)
for
name
,
layer
in
model
.
named_sublayers
():
if
PTQRegistry
.
is_supported_layer
(
layer
)
\
and
utils
.
is_leaf_layer
(
layer
)
\
...
...
python/paddle/fluid/contrib/slim/tests/imperative_test_utils.py
浏览文件 @
ef536250
...
...
@@ -20,6 +20,7 @@ from paddle.fluid import core
from
paddle.fluid.dygraph.container
import
Sequential
from
paddle.nn
import
ReLU
,
ReLU6
,
LeakyReLU
,
Sigmoid
,
Softmax
,
PReLU
from
paddle.nn
import
Linear
,
Conv2D
,
Softmax
,
BatchNorm2D
,
MaxPool2D
from
paddle.nn
import
BatchNorm1D
from
paddle.fluid.log_helper
import
get_logger
...
...
@@ -43,6 +44,15 @@ def fix_model_dict(model):
return
model
def
pre_hook
(
layer
,
input
):
input_return
=
(
input
[
0
]
*
2
)
return
input_return
def
post_hook
(
layer
,
input
,
output
):
return
output
*
2
def
train_lenet
(
lenet
,
reader
,
optimizer
):
loss_list
=
[]
lenet
.
train
()
...
...
@@ -224,3 +234,53 @@ class ImperativeLenetWithSkipQuant(fluid.dygraph.Layer):
x
=
self
.
softmax_0
(
x
)
return
x
class
ImperativeLinearBn
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
):
super
(
ImperativeLinearBn
,
self
).
__init__
()
fc_w_attr
=
paddle
.
ParamAttr
(
name
=
"fc_weight"
,
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
0.5
))
fc_b_attr
=
paddle
.
ParamAttr
(
name
=
"fc_bias"
,
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
1.0
))
bn_w_attr
=
paddle
.
ParamAttr
(
name
=
"bn_weight"
,
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
0.5
))
self
.
linear
=
Linear
(
in_features
=
10
,
out_features
=
10
,
weight_attr
=
fc_w_attr
,
bias_attr
=
fc_b_attr
)
self
.
bn
=
BatchNorm1D
(
10
,
weight_attr
=
bn_w_attr
)
def
forward
(
self
,
inputs
):
x
=
self
.
linear
(
inputs
)
x
=
self
.
bn
(
x
)
return
x
class
ImperativeLinearBn_hook
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
):
super
(
ImperativeLinearBn_hook
,
self
).
__init__
()
fc_w_attr
=
paddle
.
ParamAttr
(
name
=
"linear_weight"
,
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
0.5
))
self
.
linear
=
Linear
(
in_features
=
10
,
out_features
=
10
,
weight_attr
=
fc_w_attr
)
self
.
bn
=
BatchNorm1D
(
10
)
forward_pre
=
self
.
linear
.
register_forward_pre_hook
(
pre_hook
)
forward_post
=
self
.
bn
.
register_forward_post_hook
(
post_hook
)
def
forward
(
self
,
inputs
):
x
=
self
.
linear
(
inputs
)
x
=
self
.
bn
(
x
)
return
x
python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py
浏览文件 @
ef536250
...
...
@@ -23,18 +23,48 @@ import unittest
import
copy
import
logging
import
paddle.nn
as
nn
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.contrib.slim.quantization
import
*
from
paddle.fluid.log_helper
import
get_logger
from
paddle.dataset.common
import
download
from
imperative_test_utils
import
fix_model_dict
,
ImperativeLenet
from
imperative_test_utils
import
fix_model_dict
,
ImperativeLenet
,
ImperativeLinearBn
from
imperative_test_utils
import
ImperativeLinearBn_hook
_logger
=
get_logger
(
__name__
,
logging
.
INFO
,
fmt
=
'%(asctime)s-%(levelname)s: %(message)s'
)
class
TestFuseLinearBn
(
unittest
.
TestCase
):
"""
Fuse the linear and bn layers, and then quantize the model.
"""
def
test_fuse
(
self
):
model
=
ImperativeLinearBn
()
model_h
=
ImperativeLinearBn_hook
()
inputs
=
paddle
.
randn
((
3
,
10
),
dtype
=
"float32"
)
config
=
PTQConfig
(
AbsmaxQuantizer
(),
AbsmaxQuantizer
())
ptq
=
ImperativePTQ
(
config
)
f_l
=
[[
'linear'
,
'bn'
]]
quant_model
=
ptq
.
quantize
(
model
,
fuse
=
True
,
fuse_list
=
f_l
)
quant_h
=
ptq
.
quantize
(
model_h
,
fuse
=
True
,
fuse_list
=
f_l
)
for
name
,
layer
in
quant_model
.
named_sublayers
():
if
name
in
f_l
:
assert
not
(
isinstance
(
layer
,
nn
.
BatchNorm1D
)
or
isinstance
(
layer
,
nn
.
BatchNorm2D
))
out
=
model
(
inputs
)
out_h
=
model_h
(
inputs
)
out_quant
=
quant_model
(
inputs
)
out_quant_h
=
quant_h
(
inputs
)
cos_sim_func
=
nn
.
CosineSimilarity
(
axis
=
0
)
print
(
'fuse linear+bn'
,
cos_sim_func
(
out
.
flatten
(),
out_quant
.
flatten
()))
print
(
cos_sim_func
(
out_h
.
flatten
(),
out_quant_h
.
flatten
()))
class
TestImperativePTQ
(
unittest
.
TestCase
):
"""
"""
...
...
@@ -177,7 +207,6 @@ class TestImperativePTQ(unittest.TestCase):
model
=
ImperativeLenet
()
model_state_dict
=
paddle
.
load
(
params_path
)
model
.
set_state_dict
(
model_state_dict
)
# Quantize, calibrate and save
quant_model
=
self
.
ptq
.
quantize
(
model
)
before_acc_top1
=
self
.
model_test
(
quant_model
,
self
.
batch_num
,
...
...
@@ -216,6 +245,67 @@ class TestImperativePTQ(unittest.TestCase):
print
(
"total time: %ss
\n
"
%
(
end_time
-
start_time
))
class
TestImperativePTQfuse
(
TestImperativePTQ
):
def
test_ptq
(
self
):
start_time
=
time
.
time
()
self
.
set_vars
()
# Load model
params_path
=
self
.
download_model
(
self
.
lenet_url
,
self
.
lenet_md5
,
"lenet"
)
params_path
+=
"/lenet_pretrained/lenet.pdparams"
model
=
ImperativeLenet
()
model_state_dict
=
paddle
.
load
(
params_path
)
model
.
set_state_dict
(
model_state_dict
)
# Quantize, calibrate and save
f_l
=
[[
'features.0'
,
'features.1'
],
[
'features.4'
,
'features.5'
]]
quant_model
=
self
.
ptq
.
quantize
(
model
,
fuse
=
True
,
fuse_list
=
f_l
)
for
name
,
layer
in
quant_model
.
named_sublayers
():
if
name
in
f_l
:
assert
not
(
isinstance
(
layer
,
nn
.
BatchNorm1D
)
or
isinstance
(
layer
,
nn
.
BatchNorm2D
))
before_acc_top1
=
self
.
model_test
(
quant_model
,
self
.
batch_num
,
self
.
batch_size
)
input_spec
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
1
,
28
,
28
],
dtype
=
'float32'
)
]
self
.
ptq
.
save_quantized_model
(
model
=
quant_model
,
path
=
self
.
save_path
,
input_spec
=
input_spec
)
print
(
'Quantized model saved in {%s}'
%
self
.
save_path
)
after_acc_top1
=
self
.
model_test
(
quant_model
,
self
.
batch_num
,
self
.
batch_size
)
paddle
.
enable_static
()
infer_acc_top1
=
self
.
program_test
(
self
.
save_path
,
self
.
batch_num
,
self
.
batch_size
)
paddle
.
disable_static
()
# Check
print
(
'Before converted acc_top1: %s'
%
before_acc_top1
)
print
(
'After converted acc_top1: %s'
%
after_acc_top1
)
print
(
'Infer acc_top1: %s'
%
infer_acc_top1
)
#Check whether the quant_model is correct after converting.
#The acc of quantized model should be higher than 0.95.
self
.
assertTrue
(
after_acc_top1
>=
self
.
eval_acc_top1
,
msg
=
"The test acc {%f} is less than {%f}."
%
(
after_acc_top1
,
self
.
eval_acc_top1
))
#Check the saved infer_model.The acc of infer model
#should not be lower than the one of dygraph model.
self
.
assertTrue
(
infer_acc_top1
>=
after_acc_top1
,
msg
=
'The acc is lower after converting model.'
)
end_time
=
time
.
time
()
print
(
"total time: %ss
\n
"
%
(
end_time
-
start_time
))
class
TestImperativePTQHist
(
TestImperativePTQ
):
def
set_vars
(
self
):
config
=
PTQConfig
(
HistQuantizer
(),
AbsmaxQuantizer
())
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录