Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
a972c39e
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a972c39e
编写于
5月 13, 2022
作者:
G
Guanghua Yu
提交者:
GitHub
5月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support fuse conv bn when export model (#5977)
上级
67742521
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
186 addition
and
0 deletion
+186
-0
configs/runtime.yml
configs/runtime.yml
+1
-0
ppdet/engine/trainer.py
ppdet/engine/trainer.py
+6
-0
ppdet/utils/fuse_utils.py
ppdet/utils/fuse_utils.py
+179
-0
未找到文件。
configs/runtime.yml
浏览文件 @
a972c39e
...
@@ -10,3 +10,4 @@ export:
...
@@ -10,3 +10,4 @@ export:
post_process
:
True
# Whether post-processing is included in the network when export model.
post_process
:
True
# Whether post-processing is included in the network when export model.
nms
:
True
# Whether NMS is included in the network when export model.
nms
:
True
# Whether NMS is included in the network when export model.
benchmark
:
False
# Used for testing model performance; if set to `True`, post-processing and NMS will not be exported.
benchmark
:
False
# Used for testing model performance; if set to `True`, post-processing and NMS will not be exported.
fuse_conv_bn
:
False
ppdet/engine/trainer.py
浏览文件 @
a972c39e
...
@@ -44,6 +44,7 @@ from ppdet.metrics import RBoxMetric, JDEDetMetric, SNIPERCOCOMetric
...
@@ -44,6 +44,7 @@ from ppdet.metrics import RBoxMetric, JDEDetMetric, SNIPERCOCOMetric
from
ppdet.data.source.sniper_coco
import
SniperCOCODataSet
from
ppdet.data.source.sniper_coco
import
SniperCOCODataSet
from
ppdet.data.source.category
import
get_categories
from
ppdet.data.source.category
import
get_categories
import
ppdet.utils.stats
as
stats
import
ppdet.utils.stats
as
stats
from
ppdet.utils.fuse_utils
import
fuse_conv_bn
from
ppdet.utils
import
profiler
from
ppdet.utils
import
profiler
from
.callbacks
import
Callback
,
ComposeCallback
,
LogPrinter
,
Checkpointer
,
WiferFaceEval
,
VisualDLWriter
,
SniperProposalsGenerator
,
WandbCallback
from
.callbacks
import
Callback
,
ComposeCallback
,
LogPrinter
,
Checkpointer
,
WiferFaceEval
,
VisualDLWriter
,
SniperProposalsGenerator
,
WandbCallback
...
@@ -770,6 +771,11 @@ class Trainer(object):
...
@@ -770,6 +771,11 @@ class Trainer(object):
def
export
(
self
,
output_dir
=
'output_inference'
):
def
export
(
self
,
output_dir
=
'output_inference'
):
self
.
model
.
eval
()
self
.
model
.
eval
()
if
hasattr
(
self
.
cfg
,
'export'
)
and
'fuse_conv_bn'
in
self
.
cfg
[
'export'
]
and
self
.
cfg
[
'export'
][
'fuse_conv_bn'
]:
self
.
model
=
fuse_conv_bn
(
self
.
model
)
model_name
=
os
.
path
.
splitext
(
os
.
path
.
split
(
self
.
cfg
.
filename
)[
-
1
])[
0
]
model_name
=
os
.
path
.
splitext
(
os
.
path
.
split
(
self
.
cfg
.
filename
)[
-
1
])[
0
]
save_dir
=
os
.
path
.
join
(
output_dir
,
model_name
)
save_dir
=
os
.
path
.
join
(
output_dir
,
model_name
)
if
not
os
.
path
.
exists
(
save_dir
):
if
not
os
.
path
.
exists
(
save_dir
):
...
...
ppdet/utils/fuse_utils.py
0 → 100644
浏览文件 @
a972c39e
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
copy
import
paddle
import
paddle.nn
as
nn
__all__
=
[
'fuse_conv_bn'
]
def fuse_conv_bn(model):
    """
    Return a copy of `model` in which Conv2D/BatchNorm2D pairs are fused.

    Pairs are discovered heuristically from the order in which
    `model.named_sublayers()` yields layers: every `nn.BatchNorm2D` is paired
    with the most recent `nn.Conv2D` that has not been paired yet. A
    BatchNorm2D that shows up before any unpaired Conv2D is ignored, so a
    normalization layer is never fused into a convolution that comes after it.
    NOTE(review): this assumes registration order matches forward order --
    confirm for models that register layers in a different order.

    Args:
        model(paddle.nn.Layer): the model to fuse. It is not structurally
            modified; fusion is applied to the copy made by `fuse_layers`.
    Returns:
        fused_model(paddle.nn.Layer): the fused model, in the same
            train/eval mode that `model` had on entry.
    """
    was_training = model.training
    if was_training:
        # Fusion is only implemented for eval mode, so switch temporarily.
        model.eval()

    fuse_list = []
    pending_conv = None
    for name, layer in model.named_sublayers():
        if isinstance(layer, nn.Conv2D):
            pending_conv = name
        elif isinstance(layer, nn.BatchNorm2D) and pending_conv is not None:
            fuse_list.append([pending_conv, name])
            pending_conv = None

    try:
        fused_model = fuse_layers(model, fuse_list)
    finally:
        if was_training:
            # fuse_layers works on a deep copy by default, so the caller's
            # model has to be put back into training mode explicitly.
            model.train()

    if was_training:
        fused_model.train()
    return fused_model
def find_parent_layer_and_sub_name(model, name):
    """
    Resolve a dotted layer name to its parent layer and last name segment.

    The name is scanned from left to right. At each '.', the text since the
    last accepted split point is tried as an attribute of the current parent
    layer; when that attribute exists the walk descends into it, otherwise
    the '.' stays part of the pending segment.
    For example, if name is 'block_1.convbn_1.conv_1' and every prefix
    resolves, the parent layer is `model.block_1.convbn_1` and the sub_name
    is 'conv_1'.

    Args:
        model(paddle.nn.Layer): the root layer to start the lookup from.
        name(string): the dotted name of a layer; must not be empty.
    Returns:
        parent_layer, sub_name
    """
    assert isinstance(model, nn.Layer), \
        "The model must be the instance of paddle.nn.Layer."
    assert len(name) > 0, "The input (name) should not be empty."

    parent_layer = model
    segment_start = 0
    for pos, char in enumerate(name):
        if char != '.':
            continue
        candidate = name[segment_start:pos]
        if hasattr(parent_layer, candidate):
            parent_layer = getattr(parent_layer, candidate)
            segment_start = pos + 1

    # Whatever follows the last accepted split point is the attribute name.
    return parent_layer, name[segment_start:]
class Identity(nn.Layer):
    """Pass-through layer used in place of layers removed by fusion."""

    def __init__(self, *args, **kwargs):
        # Any constructor arguments are accepted and ignored.
        super().__init__()

    def forward(self, input):
        """Return `input` unchanged."""
        return input
def fuse_layers(model, layers_to_fuse, inplace=False):
    """
    Fuse each named group of layers in `layers_to_fuse`.

    For every group, the named layers are looked up in the model, handed to
    `_fuse_func`, and then replaced one-for-one by the layers it returns.

    Args:
        model(nn.Layer): The model to be fused.
        layers_to_fuse(list): Groups of layer names to fuse, for example
            [["conv1", "bn1"], ["conv2", "bn2"]].
        inplace(bool): Whether to apply fusing to the input model itself.
            When False, a deep copy of the model is fused instead.
            Default: False.
    Return
        fused_model(paddle.nn.Layer): The fused model.
    """
    target = model if inplace else copy.deepcopy(model)

    for group_names in layers_to_fuse:
        originals = []
        for layer_name in group_names:
            parent, attr = find_parent_layer_and_sub_name(target, layer_name)
            originals.append(getattr(parent, attr))

        replacements = _fuse_func(originals)

        for layer_name, replacement in zip(group_names, replacements):
            parent, attr = find_parent_layer_and_sub_name(target, layer_name)
            setattr(parent, attr, replacement)

    return target
def _fuse_func(layer_list):
    """
    Choose the fusion method for `layer_list` and fuse the layers.

    The method is looked up in `types_to_fusion_method` by the exact types of
    the layers. Forward pre-hooks of the first layer and forward post-hooks
    of the last layer are moved onto the fused layer.

    Args:
        layer_list(list): the layers to fuse, in order.
    Returns:
        list: same length as `layer_list`; the fused layer comes first and
            every remaining position holds an `Identity` layer.
    Raises:
        TypeError: if no fusion method is registered for the layer types.
    """
    types = tuple(type(m) for m in layer_list)
    fusion_method = types_to_fusion_method.get(types)
    if fusion_method is None:
        raise TypeError(
            "No fusion method is registered for layer types ({}).".format(
                ", ".join(t.__name__ for t in types)))
    fused_layer = fusion_method(*layer_list)

    first_layer = layer_list[0]
    last_layer = layer_list[-1]

    # Iterate over snapshots: deleting entries while iterating the live hook
    # dict raises RuntimeError as soon as a hook is present.
    for handle_id, pre_hook_fn in list(
            first_layer._forward_pre_hooks.items()):
        fused_layer.register_forward_pre_hook(pre_hook_fn)
        del first_layer._forward_pre_hooks[handle_id]
    for handle_id, hook_fn in list(last_layer._forward_post_hooks.items()):
        fused_layer.register_forward_post_hook(hook_fn)
        del last_layer._forward_post_hooks[handle_id]

    new_layers = [fused_layer]
    for _ in layer_list[1:]:
        identity = Identity()
        identity.training = first_layer.training
        new_layers.append(identity)
    return new_layers
def _fuse_conv_bn(conv, bn):
    """
    Fuse a conv layer and a batch-norm layer, dispatching on their mode.

    Both layers must be in the same mode. Eval mode is delegated to
    `_fuse_conv_bn_eval`; training mode is not implemented.

    Args:
        conv: the convolution layer.
        bn: the batch-norm layer.
    Returns:
        The fused layer returned by `_fuse_conv_bn_eval`.
    Raises:
        NotImplementedError: if both layers are in training mode.
    """
    assert (conv.training == bn.training), \
        "Conv and BN both must be in the same mode (train or eval)."

    if not conv.training:
        return _fuse_conv_bn_eval(conv, bn)

    assert bn._num_features == conv._out_channels, \
        'Output channel of Conv2d must match num_features of BatchNorm2d'
    raise NotImplementedError
def _fuse_conv_bn_eval(conv, bn):
    """
    Fuse an eval-mode conv layer and batch-norm layer into one conv layer.

    The conv layer is deep-copied, and the copy receives the weight and bias
    computed by `_fuse_conv_bn_weights` from the conv parameters and the
    batch-norm statistics and affine parameters. If the conv has no bias,
    a bias parameter is created on the copy first.

    Args:
        conv: the convolution layer (left untouched).
        bn: the batch-norm layer.
    Returns:
        The fused copy of `conv`.
    """
    any_training = conv.training or bn.training
    assert not any_training, "Fusion only for eval!"

    merged = copy.deepcopy(conv)
    new_weight, new_bias = _fuse_conv_bn_weights(
        merged.weight,
        merged.bias,
        bn._mean,
        bn._variance,
        bn._epsilon,
        bn.weight,
        bn.bias)

    merged.weight.set_value(new_weight)
    if merged.bias is None:
        # The conv was built without a bias; the fused layer always needs one.
        merged.bias = paddle.create_parameter(
            shape=[merged._out_channels],
            is_bias=True,
            dtype=bn.bias.dtype)
    merged.bias.set_value(new_bias)
    return merged
def _fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    """
    Compute the fused weight and bias of a conv followed by a batch norm.

    With inv_std = rsqrt(bn_rv + bn_eps):
        fused_w = conv_w * (bn_w * inv_std), broadcast over the first axis
        fused_b = (conv_b - bn_rm) * inv_std * bn_w + bn_b
    A missing conv bias or batch-norm bias is treated as zeros and a missing
    batch-norm weight as ones, each shaped like `bn_rm`.

    Args:
        conv_w: conv weight.
        conv_b: conv bias, or None.
        bn_rm: batch-norm running mean.
        bn_rv: batch-norm running variance.
        bn_eps: batch-norm epsilon.
        bn_w: batch-norm weight, or None.
        bn_b: batch-norm bias, or None.
    Returns:
        fused_w, fused_b
    """
    conv_b = paddle.zeros_like(bn_rm) if conv_b is None else conv_b
    bn_w = paddle.ones_like(bn_rm) if bn_w is None else bn_w
    bn_b = paddle.zeros_like(bn_rm) if bn_b is None else bn_b

    inv_std = paddle.rsqrt(bn_rv + bn_eps)

    # One scale per leading-axis entry, reshaped to broadcast over the rest.
    broadcast_shape = [-1] + [1] * (len(conv_w.shape) - 1)
    channel_scale = (bn_w * inv_std).reshape(broadcast_shape)

    fused_w = conv_w * channel_scale
    fused_b = (conv_b - bn_rm) * inv_std * bn_w + bn_b
    return fused_w, fused_b
# Maps a tuple of exact layer types to the function that fuses layers of
# those types; looked up by _fuse_func.
types_to_fusion_method = {
    (nn.Conv2D, nn.BatchNorm2D): _fuse_conv_bn,
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录