Commit ef536250 (unverified) — 机器未来 / Paddle (forked from PaddlePaddle / Paddle)
Authored Aug 31, 2021 by XGZhang; committed via GitHub on Aug 31, 2021

support fuse layers for ptq (#35015)

Parent: 561841d2
Showing 4 changed files, with 340 additions and 6 deletions (+340 −6):
python/paddle/fluid/contrib/slim/quantization/imperative/fuse_utils.py (+175 −0)
python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py (+13 −4)
python/paddle/fluid/contrib/slim/tests/imperative_test_utils.py (+60 −0)
python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py (+92 −2)
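At a glance, the commit wires an optional layer-fusion pass into post-training quantization: `ImperativePTQ.quantize` gains `fuse` and `fuse_list` arguments, backed by a new `fuse_utils` module that folds Conv2D+BatchNorm2D and Linear+BatchNorm1D pairs before calibration. A minimal usage sketch (editor's illustration; `model` and the sublayer names 'conv1'/'bn1' are placeholders, not from this commit):

import paddle
from paddle.fluid.contrib.slim.quantization import (
    ImperativePTQ, PTQConfig, AbsmaxQuantizer)

# model: any dygraph model (paddle.nn.Layer); 'conv1'/'bn1' stand in for the
# names of a conv sublayer and the batch-norm sublayer that follows it.
ptq = ImperativePTQ(PTQConfig(AbsmaxQuantizer(), AbsmaxQuantizer()))
quant_model = ptq.quantize(model, fuse=True, fuse_list=[['conv1', 'bn1']])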
python/paddle/fluid/contrib/slim/quantization/imperative/fuse_utils.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import paddle
import paddle.nn as nn
from . import utils


class Identity(nn.Layer):
    '''a layer to replace bn or relu layers'''

    def __init__(self, *args, **kwargs):
        super(Identity, self).__init__()

    def forward(self, input):
        return input


def fuse_layers(model, layers_to_fuse, inplace=False):
    '''
    fuse layers in layers_to_fuse

    Args:
        model(paddle.nn.Layer): The model to be fused.
        layers_to_fuse(list): The layers' names to be fused. For
            example, "fuse_list = [["conv1", "bn1"], ["conv2", "bn2"]]".
            A TypeError would be raised if "fuse" was set as
            True but "fuse_list" was None.
            Default: None.
        inplace(bool): Whether to apply fusing to the input model.
            Default: False.
    Return:
        fused_model(paddle.nn.Layer): The fused model.
    '''
    if inplace == False:
        model = copy.deepcopy(model)
    for layers in layers_to_fuse:
        _fuse_layers(model, layers)
    return model
def _fuse_layers(model, layers_list):
    '''fuse all the layers in layers_list'''
    layer_list = []
    for layer_name in layers_list:
        parent_layer, sub_name = utils.find_parent_layer_and_sub_name(
            model, layer_name)
        layer_list.append(getattr(parent_layer, sub_name))
    new_layers = _fuse_func(layer_list)
    for i, item in enumerate(layers_list):
        parent_layer, sub_name = utils.find_parent_layer_and_sub_name(
            model, item)
        setattr(parent_layer, sub_name, new_layers[i])
def _fuse_func(layer_list):
    '''choose the fuser method and fuse layers'''
    types = tuple(type(m) for m in layer_list)
    fusion_method = types_to_fusion_method.get(types, None)
    new_layers = [None] * len(layer_list)
    fused_layer = fusion_method(*layer_list)
    # Move the forward pre-hooks of the first layer and the forward post-hooks
    # of the last layer onto the fused layer, so hooked models keep behaving
    # the same after fusion.
    for handle_id, pre_hook_fn in layer_list[0]._forward_pre_hooks.items():
        fused_layer.register_forward_pre_hook(pre_hook_fn)
        del layer_list[0]._forward_pre_hooks[handle_id]
    for handle_id, hook_fn in layer_list[-1]._forward_post_hooks.items():
        fused_layer.register_forward_post_hook(hook_fn)
        del layer_list[-1]._forward_post_hooks[handle_id]
    # The fused layer takes the first slot; the remaining layers are replaced
    # by Identity so the module structure (and names) stay intact.
    new_layers[0] = fused_layer
    for i in range(1, len(layer_list)):
        identity = Identity()
        identity.training = layer_list[0].training
        new_layers[i] = identity
    return new_layers
def _fuse_conv_bn(conv, bn):
    '''fuse conv and bn for train or eval'''
    assert (conv.training == bn.training), \
        "Conv and BN both must be in the same mode (train or eval)."
    if conv.training:
        assert bn._num_features == conv._out_channels, \
            'Output channel of Conv2d must match num_features of BatchNorm2d'
        raise NotImplementedError
    else:
        return _fuse_conv_bn_eval(conv, bn)
def _fuse_conv_bn_eval(conv, bn):
    '''fuse conv and bn for eval'''
    assert (not (conv.training or bn.training)), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    fused_weight, fused_bias = _fuse_conv_bn_weights(
        fused_conv.weight, fused_conv.bias, bn._mean, bn._variance,
        bn._epsilon, bn.weight, bn.bias)
    fused_conv.weight.set_value(fused_weight)
    if fused_conv.bias is None:
        fused_conv.bias = paddle.create_parameter(
            shape=[fused_conv._out_channels],
            is_bias=True,
            dtype=bn.bias.dtype)
    fused_conv.bias.set_value(fused_bias)
    return fused_conv
def _fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    '''fuse weights and bias of conv and bn'''
    if conv_b is None:
        conv_b = paddle.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = paddle.ones_like(bn_rm)
    if bn_b is None:
        bn_b = paddle.zeros_like(bn_rm)
    bn_var_rsqrt = paddle.rsqrt(bn_rv + bn_eps)
    conv_w = conv_w * \
        (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b
    return conv_w, conv_b
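
# Editor's note (comment added for this write-up, not in the original file):
# the function above is the standard eval-mode BN folding. With frozen
# statistics, batch norm is a fixed per-channel affine map, so for
#     y = bn_w * (conv_out - bn_rm) / sqrt(bn_rv + bn_eps) + bn_b,
# where conv_out = W * x + b, the conv parameters can absorb it:
#     W' = W * bn_w / sqrt(bn_rv + bn_eps)
#     b' = (b - bn_rm) * bn_w / sqrt(bn_rv + bn_eps) + bn_b
# `bn_var_rsqrt` is 1 / sqrt(bn_rv + bn_eps), and the reshape broadcasts the
# per-output-channel scale over the weight's remaining axes.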
def _fuse_linear_bn(linear, bn):
    '''fuse linear and bn'''
    assert (linear.training == bn.training), \
        "Linear and BN both must be in the same mode (train or eval)."
    if linear.training:
        assert bn._num_features == linear.weight.shape[1], \
            'Output channel of Linear must match num_features of BatchNorm'
        raise NotImplementedError
    else:
        return _fuse_linear_bn_eval(linear, bn)
def _fuse_linear_bn_eval(linear, bn):
    '''fuse linear and bn for eval'''
    assert (not (linear.training or bn.training)), "Fusion only for eval!"
    fused_linear = copy.deepcopy(linear)

    fused_weight, fused_bias = _fuse_linear_bn_weights(
        fused_linear.weight, fused_linear.bias, bn._mean, bn._variance,
        bn._epsilon, bn.weight, bn.bias)
    fused_linear.weight.set_value(fused_weight)
    if fused_linear.bias is None:
        fused_linear.bias = paddle.create_parameter(
            shape=[fused_linear.weight.shape[1]],
            is_bias=True,
            dtype=bn.bias.dtype)
    fused_linear.bias.set_value(fused_bias)
    return fused_linear
def _fuse_linear_bn_weights(linear_w, linear_b, bn_rm, bn_rv, bn_eps, bn_w,
                            bn_b):
    '''fuse weights and bias of linear and bn'''
    if linear_b is None:
        linear_b = paddle.zeros_like(bn_rm)
    bn_scale = bn_w * paddle.rsqrt(bn_rv + bn_eps)
    fused_w = linear_w * bn_scale.unsqueeze(-1)
    fused_b = (linear_b - bn_rm) * bn_scale + bn_b
    return fused_w, fused_b
types_to_fusion_method = {
    (nn.Conv2D, nn.BatchNorm2D): _fuse_conv_bn,
    (nn.Linear, nn.BatchNorm1D): _fuse_linear_bn,
}
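As a quick sanity check (an editor's sketch, not part of the commit), eval-mode fusion should reproduce the original conv→bn outputs up to floating-point error. The toy model `ToyNet` and the import path below are assumptions for illustration; `fuse_utils` is imported from the package directory this file lives in:

import numpy as np
import paddle
import paddle.nn as nn
from paddle.fluid.contrib.slim.quantization.imperative import fuse_utils


class ToyNet(nn.Layer):
    '''hypothetical conv -> bn model, used only for this sketch'''

    def __init__(self):
        super(ToyNet, self).__init__()
        self.conv = nn.Conv2D(3, 8, 3)
        self.bn = nn.BatchNorm2D(8)

    def forward(self, x):
        return self.bn(self.conv(x))


net = ToyNet()
net.eval()  # fusion is only implemented for eval mode
x = paddle.randn([2, 3, 16, 16])
ref = net(x)  # original conv -> bn output

# inplace defaults to False, so net itself is left untouched; in the copy the
# bn sublayer becomes Identity and the conv carries the folded weights.
fused = fuse_utils.fuse_layers(net, [['conv', 'bn']])
np.testing.assert_allclose(ref.numpy(), fused(x).numpy(), rtol=1e-5, atol=1e-5)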
python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py
...
@@ -22,6 +22,7 @@ import paddle.nn.quant.quant_layers as quant_layers
 from paddle.fluid.log_helper import get_logger
 from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
+from . import fuse_utils
 from . import utils
 from . import ptq_hooks
 from . import ptq_config
...
@@ -55,7 +56,7 @@ class ImperativePTQ(object):
         self._quant_config = quant_config
 
-    def quantize(self, model, inplace=False):
+    def quantize(self, model, inplace=False, fuse=False, fuse_list=None):
         """
         Add quant config and hook to the target layer.
...
@@ -63,15 +64,23 @@ class ImperativePTQ(object):
             model(paddle.nn.Layer): The model to be quantized.
             inplace(bool): Whether apply quantization to the input model.
                 Default: False.
-        Returns:
+            fuse(bool): Whether to fuse layers.
+                Default: False.
+            fuse_list(list): The layers' names to be fused. For example,
+                "fuse_list = [["conv1", "bn1"], ["conv2", "bn2"]]".
+                A TypeError would be raised if "fuse" was set as
+                True but "fuse_list" was None.
+                Default: None.
+        Return
             quantized_model(paddle.nn.Layer): The quantized model.
         """
         assert isinstance(model, paddle.nn.Layer), \
             "The model must be the instance of paddle.nn.Layer."
 
         if not inplace:
             model = copy.deepcopy(model)
+        if fuse:
+            model.eval()
+            model = fuse_utils.fuse_layers(model, fuse_list)
         for name, layer in model.named_sublayers():
             if PTQRegistry.is_supported_layer(layer) \
                 and utils.is_leaf_layer(layer) \
...
python/paddle/fluid/contrib/slim/tests/imperative_test_utils.py
...
@@ -20,6 +20,7 @@ from paddle.fluid import core
 from paddle.fluid.dygraph.container import Sequential
 from paddle.nn import ReLU, ReLU6, LeakyReLU, Sigmoid, Softmax, PReLU
 from paddle.nn import Linear, Conv2D, Softmax, BatchNorm2D, MaxPool2D
+from paddle.nn import BatchNorm1D
 from paddle.fluid.log_helper import get_logger
...
@@ -43,6 +44,15 @@ def fix_model_dict(model):
     return model
 
 
+def pre_hook(layer, input):
+    input_return = (input[0] * 2)
+    return input_return
+
+
+def post_hook(layer, input, output):
+    return output * 2
+
+
 def train_lenet(lenet, reader, optimizer):
     loss_list = []
     lenet.train()
...
@@ -224,3 +234,53 @@ class ImperativeLenetWithSkipQuant(fluid.dygraph.Layer):
         x = self.softmax_0(x)
 
         return x
+
+
+class ImperativeLinearBn(fluid.dygraph.Layer):
+    def __init__(self):
+        super(ImperativeLinearBn, self).__init__()
+        fc_w_attr = paddle.ParamAttr(
+            name="fc_weight",
+            initializer=paddle.nn.initializer.Constant(value=0.5))
+        fc_b_attr = paddle.ParamAttr(
+            name="fc_bias",
+            initializer=paddle.nn.initializer.Constant(value=1.0))
+        bn_w_attr = paddle.ParamAttr(
+            name="bn_weight",
+            initializer=paddle.nn.initializer.Constant(value=0.5))
+        self.linear = Linear(
+            in_features=10,
+            out_features=10,
+            weight_attr=fc_w_attr,
+            bias_attr=fc_b_attr)
+        self.bn = BatchNorm1D(10, weight_attr=bn_w_attr)
+
+    def forward(self, inputs):
+        x = self.linear(inputs)
+        x = self.bn(x)
+        return x
+
+
+class ImperativeLinearBn_hook(fluid.dygraph.Layer):
+    def __init__(self):
+        super(ImperativeLinearBn_hook, self).__init__()
+        fc_w_attr = paddle.ParamAttr(
+            name="linear_weight",
+            initializer=paddle.nn.initializer.Constant(value=0.5))
+        self.linear = Linear(
+            in_features=10,
+            out_features=10,
+            weight_attr=fc_w_attr)
+        self.bn = BatchNorm1D(10)
+        forward_pre = self.linear.register_forward_pre_hook(pre_hook)
+        forward_post = self.bn.register_forward_post_hook(post_hook)
+
+    def forward(self, inputs):
+        x = self.linear(inputs)
+        x = self.bn(x)
+        return x
python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py
...
@@ -23,18 +23,48 @@ import unittest
 import copy
 import logging
 
+import paddle.nn as nn
 import paddle
 import paddle.fluid as fluid
 
 from paddle.fluid.contrib.slim.quantization import *
 from paddle.fluid.log_helper import get_logger
 from paddle.dataset.common import download
 
-from imperative_test_utils import fix_model_dict, ImperativeLenet
+from imperative_test_utils import fix_model_dict, ImperativeLenet, ImperativeLinearBn
+from imperative_test_utils import ImperativeLinearBn_hook
 
 _logger = get_logger(
     __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
+
+
+class TestFuseLinearBn(unittest.TestCase):
+    """
+    Fuse the linear and bn layers, and then quantize the model.
+    """
+
+    def test_fuse(self):
+        model = ImperativeLinearBn()
+        model_h = ImperativeLinearBn_hook()
+        inputs = paddle.randn((3, 10), dtype="float32")
+        config = PTQConfig(AbsmaxQuantizer(), AbsmaxQuantizer())
+        ptq = ImperativePTQ(config)
+        f_l = [['linear', 'bn']]
+        quant_model = ptq.quantize(model, fuse=True, fuse_list=f_l)
+        quant_h = ptq.quantize(model_h, fuse=True, fuse_list=f_l)
+        for name, layer in quant_model.named_sublayers():
+            if name in f_l:
+                assert not (isinstance(layer, nn.BatchNorm1D) or
+                            isinstance(layer, nn.BatchNorm2D))
+        out = model(inputs)
+        out_h = model_h(inputs)
+        out_quant = quant_model(inputs)
+        out_quant_h = quant_h(inputs)
+        cos_sim_func = nn.CosineSimilarity(axis=0)
+        print('fuse linear+bn',
+              cos_sim_func(out.flatten(), out_quant.flatten()))
+        print(cos_sim_func(out_h.flatten(), out_quant_h.flatten()))
 class TestImperativePTQ(unittest.TestCase):
     """
     """
...
@@ -177,7 +207,6 @@ class TestImperativePTQ(unittest.TestCase):
         model = ImperativeLenet()
         model_state_dict = paddle.load(params_path)
         model.set_state_dict(model_state_dict)
-
         # Quantize, calibrate and save
         quant_model = self.ptq.quantize(model)
         before_acc_top1 = self.model_test(quant_model, self.batch_num,
...
@@ -216,6 +245,67 @@ class TestImperativePTQ(unittest.TestCase):
         print("total time: %ss \n" % (end_time - start_time))
 
 
+class TestImperativePTQfuse(TestImperativePTQ):
+    def test_ptq(self):
+        start_time = time.time()
+
+        self.set_vars()
+
+        # Load model
+        params_path = self.download_model(self.lenet_url, self.lenet_md5,
+                                          "lenet")
+        params_path += "/lenet_pretrained/lenet.pdparams"
+
+        model = ImperativeLenet()
+        model_state_dict = paddle.load(params_path)
+        model.set_state_dict(model_state_dict)
+        # Quantize, calibrate and save
+        f_l = [['features.0', 'features.1'], ['features.4', 'features.5']]
+        quant_model = self.ptq.quantize(model, fuse=True, fuse_list=f_l)
+        for name, layer in quant_model.named_sublayers():
+            if name in f_l:
+                assert not (isinstance(layer, nn.BatchNorm1D) or
+                            isinstance(layer, nn.BatchNorm2D))
+        before_acc_top1 = self.model_test(quant_model, self.batch_num,
+                                          self.batch_size)
+
+        input_spec = [
+            paddle.static.InputSpec(
+                shape=[None, 1, 28, 28], dtype='float32')
+        ]
+        self.ptq.save_quantized_model(
+            model=quant_model, path=self.save_path, input_spec=input_spec)
+        print('Quantized model saved in {%s}' % self.save_path)
+
+        after_acc_top1 = self.model_test(quant_model, self.batch_num,
+                                         self.batch_size)
+
+        paddle.enable_static()
+        infer_acc_top1 = self.program_test(self.save_path, self.batch_num,
+                                           self.batch_size)
+        paddle.disable_static()
+
+        # Check
+        print('Before converted acc_top1: %s' % before_acc_top1)
+        print('After converted acc_top1: %s' % after_acc_top1)
+        print('Infer acc_top1: %s' % infer_acc_top1)
+
+        # Check whether the quant_model is correct after converting.
+        # The acc of quantized model should be higher than 0.95.
+        self.assertTrue(
+            after_acc_top1 >= self.eval_acc_top1,
+            msg="The test acc {%f} is less than {%f}." %
+            (after_acc_top1, self.eval_acc_top1))
+        # Check the saved infer_model. The acc of infer model
+        # should not be lower than the one of dygraph model.
+        self.assertTrue(
+            infer_acc_top1 >= after_acc_top1,
+            msg='The acc is lower after converting model.')
+
+        end_time = time.time()
+        print("total time: %ss \n" % (end_time - start_time))
+
+
 class TestImperativePTQHist(TestImperativePTQ):
     def set_vars(self):
         config = PTQConfig(HistQuantizer(), AbsmaxQuantizer())
...