Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
457f649a
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
457f649a
编写于
6月 07, 2021
作者:
X
xiaoting
提交者:
GitHub
6月 07, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add FPN_SSH for blazeface (#3267)
* add fpn for blazeface
上级
6ec36068
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
427 addition
and
20 deletion
+427
-20
configs/face_detection/README.md
configs/face_detection/README.md
+18
-0
configs/face_detection/_base_/blazeface.yml
configs/face_detection/_base_/blazeface.yml
+10
-4
configs/face_detection/_base_/blazeface_fpn.yml
configs/face_detection/_base_/blazeface_fpn.yml
+45
-0
configs/face_detection/blazeface_fpn_ssh_1000e.yml
configs/face_detection/blazeface_fpn_ssh_1000e.yml
+9
-0
deploy/python/infer.py
deploy/python/infer.py
+1
-8
ppdet/engine/trainer.py
ppdet/engine/trainer.py
+0
-1
ppdet/modeling/architectures/__init__.py
ppdet/modeling/architectures/__init__.py
+1
-0
ppdet/modeling/architectures/blazeface.py
ppdet/modeling/architectures/blazeface.py
+91
-0
ppdet/modeling/backbones/blazenet.py
ppdet/modeling/backbones/blazenet.py
+17
-3
ppdet/modeling/heads/face_head.py
ppdet/modeling/heads/face_head.py
+3
-3
ppdet/modeling/necks/__init__.py
ppdet/modeling/necks/__init__.py
+1
-0
ppdet/modeling/necks/blazeface_fpn.py
ppdet/modeling/necks/blazeface_fpn.py
+230
-0
ppdet/utils/checkpoint.py
ppdet/utils/checkpoint.py
+1
-1
未找到文件。
configs/face_detection/README.md
浏览文件 @
457f649a
...
...
@@ -12,6 +12,7 @@
| 网络结构 | 输入尺寸 | 图片个数/GPU | 学习率策略 | Easy/Medium/Hard Set | 预测时延(SD855)| 模型大小(MB) | 下载 | 配置文件 |
|:------------:|:--------:|:----:|:-------:|:-------:|:---------:|:----------:|:---------:|:--------:|
| BlazeFace | 640 | 8 | 1000e | 0.885 / 0.855 / 0.731 | - | 0.472 |
[
下载链接
](
https://paddledet.bj.bcebos.com/models/blazeface_1000e.pdparams
)
|
[
配置文件
](
https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.1/configs/face_detection/blazeface_1000e.yml
)
|
| BlazeFace-FPN-SSH | 640 | 8 | 1000e | 0.907 / 0.883 / 0.793 | - | 0.479 |
[
下载链接
](
https://paddledet.bj.bcebos.com/models/blazeface_fpn_ssh_1000e.pdparams
)
|
[
配置文件
](
https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.1/configs/face_detection/blazeface_fpn_ssh_1000e.yml
)
|
**注意:**
-
我们使用多尺度评估策略得到
`Easy/Medium/Hard Set`
里的mAP。具体细节请参考
[
在WIDER-FACE数据集上评估
](
#在WIDER-FACE数据集上评估
)
。
...
...
@@ -52,6 +53,23 @@
cd dataset/wider_face && ./download_wider_face.sh
```
### 参数配置
基础模型的配置可以参考
`configs/face_detection/_base_/blazeface.yml`
;
改进模型增加FPN和SSH的neck结构,配置文件可以参考
`configs/face_detection/_base_/blazeface_fpn.yml`
,可以根据需求配置FPN和SSH,具体如下:
```
yaml
BlazeNet
:
blaze_filters
:
[[
24
,
24
],
[
24
,
24
],
[
24
,
48
,
2
],
[
48
,
48
],
[
48
,
48
]]
double_blaze_filters
:
[[
48
,
24
,
96
,
2
],
[
96
,
24
,
96
],
[
96
,
24
,
96
],
[
96
,
24
,
96
,
2
],
[
96
,
24
,
96
],
[
96
,
24
,
96
]]
act
:
hard_swish
#配置backbone中BlazeBlock的激活函数,基础模型为relu,增加FPN和SSH时需使用hard_swish
BlazeNeck
:
neck_type
:
fpn_ssh
#可选only_fpn、only_ssh和fpn_ssh
in_channel
:
[
96
,
96
]
```
### 训练与评估
训练流程与评估流程方法与其他算法一致,请参考
[
GETTING_STARTED_cn.md
](
../../docs/tutorials/GETTING_STARTED_cn.md
)
。
**注意:**
人脸检测模型目前不支持边训练边评估。
...
...
configs/face_detection/_base_/blazeface.yml
浏览文件 @
457f649a
architecture
:
SSD
architecture
:
BlazeFace
SSD
:
BlazeFace
:
backbone
:
BlazeNet
ssd_head
:
FaceHead
neck
:
BlazeNeck
blaze_head
:
FaceHead
post_process
:
BBoxPostProcess
BlazeNet
:
blaze_filters
:
[[
24
,
24
],
[
24
,
24
],
[
24
,
48
,
2
],
[
48
,
48
],
[
48
,
48
]]
double_blaze_filters
:
[[
48
,
24
,
96
,
2
],
[
96
,
24
,
96
],
[
96
,
24
,
96
],
[
96
,
24
,
96
,
2
],
[
96
,
24
,
96
],
[
96
,
24
,
96
]]
act
:
relu
BlazeNeck
:
neck_type
:
None
in_channel
:
[
96
,
96
]
FaceHead
:
in_channels
:
[
96
,
96
]
in_channels
:
[
96
,
96
]
anchor_generator
:
AnchorGeneratorSSD
loss
:
SSDLoss
...
...
configs/face_detection/_base_/blazeface_fpn.yml
0 → 100644
浏览文件 @
457f649a
architecture
:
BlazeFace
BlazeFace
:
backbone
:
BlazeNet
neck
:
BlazeNeck
blaze_head
:
FaceHead
post_process
:
BBoxPostProcess
BlazeNet
:
blaze_filters
:
[[
24
,
24
],
[
24
,
24
],
[
24
,
48
,
2
],
[
48
,
48
],
[
48
,
48
]]
double_blaze_filters
:
[[
48
,
24
,
96
,
2
],
[
96
,
24
,
96
],
[
96
,
24
,
96
],
[
96
,
24
,
96
,
2
],
[
96
,
24
,
96
],
[
96
,
24
,
96
]]
act
:
hard_swish
BlazeNeck
:
neck_type
:
fpn_ssh
in_channel
:
[
96
,
96
]
FaceHead
:
in_channels
:
[
48
,
48
]
anchor_generator
:
AnchorGeneratorSSD
loss
:
SSDLoss
SSDLoss
:
overlap_threshold
:
0.35
AnchorGeneratorSSD
:
steps
:
[
8.
,
16.
]
aspect_ratios
:
[[
1.
],
[
1.
]]
min_sizes
:
[[
16.
,
24.
],
[
32.
,
48.
,
64.
,
80.
,
96.
,
128.
]]
max_sizes
:
[[],
[]]
offset
:
0.5
flip
:
False
min_max_aspect_ratios_order
:
false
BBoxPostProcess
:
decode
:
name
:
SSDBox
nms
:
name
:
MultiClassNMS
keep_top_k
:
750
score_threshold
:
0.01
nms_threshold
:
0.3
nms_top_k
:
5000
nms_eta
:
1.0
configs/face_detection/blazeface_fpn_ssh_1000e.yml
0 → 100644
浏览文件 @
457f649a
_BASE_
:
[
'
../datasets/wider_face.yml'
,
'
../runtime.yml'
,
'
_base_/optimizer_1000e.yml'
,
'
_base_/blazeface_fpn.yml'
,
'
_base_/face_reader.yml'
,
]
weights
:
output/blazeface_fpn_ssh_1000e/model_final
multi_scale_eval
:
True
deploy/python/infer.py
浏览文件 @
457f649a
...
...
@@ -36,6 +36,7 @@ SUPPORT_MODELS = {
'YOLO'
,
'RCNN'
,
'SSD'
,
'Face'
,
'FCOS'
,
'SOLOv2'
,
'TTFNet'
,
...
...
@@ -113,14 +114,6 @@ class Detector(object):
threshold
=
0.5
):
# postprocess output of predictor
results
=
{}
if
self
.
pred_config
.
arch
in
[
'Face'
]:
h
,
w
=
inputs
[
'im_shape'
]
scale_y
,
scale_x
=
inputs
[
'scale_factor'
]
w
,
h
=
float
(
h
)
/
scale_y
,
float
(
w
)
/
scale_x
np_boxes
[:,
2
]
*=
h
np_boxes
[:,
3
]
*=
w
np_boxes
[:,
4
]
*=
h
np_boxes
[:,
5
]
*=
w
results
[
'boxes'
]
=
np_boxes
results
[
'boxes_num'
]
=
np_boxes_num
if
np_masks
is
not
None
:
...
...
ppdet/engine/trainer.py
浏览文件 @
457f649a
...
...
@@ -433,7 +433,6 @@ class Trainer(object):
if
'segm'
in
batch_res
else
None
keypoint_res
=
batch_res
[
'keypoint'
][
start
:
end
]
\
if
'keypoint'
in
batch_res
else
None
image
=
visualize_results
(
image
,
bbox_res
,
mask_res
,
segm_res
,
keypoint_res
,
int
(
im_id
),
catid2name
,
draw_threshold
)
...
...
ppdet/modeling/architectures/__init__.py
浏览文件 @
457f649a
...
...
@@ -38,3 +38,4 @@ from .jde import *
from
.deepsort
import
*
from
.fairmot
import
*
from
.centernet
import
*
from
.blazeface
import
*
ppdet/modeling/architectures/blazeface.py
0 → 100644
浏览文件 @
457f649a
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
ppdet.core.workspace
import
register
,
create
from
.meta_arch
import
BaseArch
__all__
=
[
'BlazeFace'
]
@
register
class
BlazeFace
(
BaseArch
):
"""
BlazeFace: Sub-millisecond Neural Face Detection on Mobile GPUs,
see https://arxiv.org/abs/1907.05047
Args:
backbone (nn.Layer): backbone instance
neck (nn.Layer): neck instance
blaze_head (nn.Layer): `blazeHead` instance
post_process (object): `BBoxPostProcess` instance
"""
__category__
=
'architecture'
__inject__
=
[
'post_process'
]
def
__init__
(
self
,
backbone
,
blaze_head
,
neck
,
post_process
):
super
(
BlazeFace
,
self
).
__init__
()
self
.
backbone
=
backbone
self
.
neck
=
neck
self
.
blaze_head
=
blaze_head
self
.
post_process
=
post_process
@
classmethod
def
from_config
(
cls
,
cfg
,
*
args
,
**
kwargs
):
# backbone
backbone
=
create
(
cfg
[
'backbone'
])
# fpn
kwargs
=
{
'input_shape'
:
backbone
.
out_shape
}
neck
=
create
(
cfg
[
'neck'
],
**
kwargs
)
# head
kwargs
=
{
'input_shape'
:
neck
.
out_shape
}
blaze_head
=
create
(
cfg
[
'blaze_head'
],
**
kwargs
)
return
{
'backbone'
:
backbone
,
'neck'
:
neck
,
'blaze_head'
:
blaze_head
,
}
def
_forward
(
self
):
# Backbone
body_feats
=
self
.
backbone
(
self
.
inputs
)
# neck
neck_feats
=
self
.
neck
(
body_feats
)
# blaze Head
if
self
.
training
:
return
self
.
blaze_head
(
neck_feats
,
self
.
inputs
[
'image'
],
self
.
inputs
[
'gt_bbox'
],
self
.
inputs
[
'gt_class'
])
else
:
preds
,
anchors
=
self
.
blaze_head
(
neck_feats
,
self
.
inputs
[
'image'
])
bbox
,
bbox_num
=
self
.
post_process
(
preds
,
anchors
,
self
.
inputs
[
'im_shape'
],
self
.
inputs
[
'scale_factor'
])
return
bbox
,
bbox_num
def
get_loss
(
self
,
):
return
{
"loss"
:
self
.
_forward
()}
def
get_pred
(
self
):
bbox_pred
,
bbox_num
=
self
.
_forward
()
output
=
{
"bbox"
:
bbox_pred
,
"bbox_num"
:
bbox_num
,
}
return
output
ppdet/modeling/backbones/blazenet.py
浏览文件 @
457f649a
...
...
@@ -29,6 +29,10 @@ from ..shape_spec import ShapeSpec
__all__
=
[
'BlazeNet'
]
def
hard_swish
(
x
):
return
x
*
F
.
relu6
(
x
+
3
)
/
6.
class
ConvBNLayer
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
...
...
@@ -80,6 +84,10 @@ class ConvBNLayer(nn.Layer):
x
=
F
.
relu
(
x
)
elif
self
.
act
==
"relu6"
:
x
=
F
.
relu6
(
x
)
elif
self
.
act
==
'leaky'
:
x
=
F
.
leaky_relu
(
x
)
elif
self
.
act
==
'hard_swish'
:
x
=
hard_swish
(
x
)
return
x
...
...
@@ -91,6 +99,7 @@ class BlazeBlock(nn.Layer):
double_channels
=
None
,
stride
=
1
,
use_5x5kernel
=
True
,
act
=
'relu'
,
name
=
None
):
super
(
BlazeBlock
,
self
).
__init__
()
assert
stride
in
[
1
,
2
]
...
...
@@ -132,14 +141,14 @@ class BlazeBlock(nn.Layer):
padding
=
1
,
num_groups
=
out_channels1
,
name
=
name
+
"1_dw_2"
)))
act
=
'relu'
if
self
.
use_double_block
else
None
self
.
act
=
act
if
self
.
use_double_block
else
None
self
.
conv_pw
=
ConvBNLayer
(
in_channels
=
out_channels1
,
out_channels
=
out_channels2
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
,
act
=
act
,
act
=
self
.
act
,
name
=
name
+
"1_sep"
)
if
self
.
use_double_block
:
self
.
conv_dw2
=
[]
...
...
@@ -237,7 +246,8 @@ class BlazeNet(nn.Layer):
blaze_filters
=
[[
24
,
24
],
[
24
,
24
],
[
24
,
48
,
2
],
[
48
,
48
],
[
48
,
48
]],
double_blaze_filters
=
[[
48
,
24
,
96
,
2
],
[
96
,
24
,
96
],
[
96
,
24
,
96
],
[
96
,
24
,
96
,
2
],
[
96
,
24
,
96
],
[
96
,
24
,
96
]],
use_5x5kernel
=
True
):
use_5x5kernel
=
True
,
act
=
None
):
super
(
BlazeNet
,
self
).
__init__
()
conv1_num_filters
=
blaze_filters
[
0
][
0
]
self
.
conv1
=
ConvBNLayer
(
...
...
@@ -262,6 +272,7 @@ class BlazeNet(nn.Layer):
v
[
0
],
v
[
1
],
use_5x5kernel
=
use_5x5kernel
,
act
=
act
,
name
=
'blaze_{}'
.
format
(
k
))))
elif
len
(
v
)
==
3
:
self
.
blaze_block
.
append
(
...
...
@@ -273,6 +284,7 @@ class BlazeNet(nn.Layer):
v
[
1
],
stride
=
v
[
2
],
use_5x5kernel
=
use_5x5kernel
,
act
=
act
,
name
=
'blaze_{}'
.
format
(
k
))))
in_channels
=
v
[
1
]
...
...
@@ -289,6 +301,7 @@ class BlazeNet(nn.Layer):
v
[
1
],
double_channels
=
v
[
2
],
use_5x5kernel
=
use_5x5kernel
,
act
=
act
,
name
=
'double_blaze_{}'
.
format
(
k
))))
elif
len
(
v
)
==
4
:
self
.
blaze_block
.
append
(
...
...
@@ -301,6 +314,7 @@ class BlazeNet(nn.Layer):
double_channels
=
v
[
2
],
stride
=
v
[
3
],
use_5x5kernel
=
use_5x5kernel
,
act
=
act
,
name
=
'double_blaze_{}'
.
format
(
k
))))
in_channels
=
v
[
2
]
self
.
_out_channels
.
append
(
in_channels
)
...
...
ppdet/modeling/heads/face_head.py
浏览文件 @
457f649a
...
...
@@ -41,7 +41,7 @@ class FaceHead(nn.Layer):
def
__init__
(
self
,
num_classes
=
80
,
in_channels
=
(
96
,
96
)
,
in_channels
=
[
96
,
96
]
,
anchor_generator
=
AnchorGeneratorSSD
().
__dict__
,
kernel_size
=
3
,
padding
=
1
,
...
...
@@ -65,7 +65,7 @@ class FaceHead(nn.Layer):
box_conv
=
self
.
add_sublayer
(
box_conv_name
,
nn
.
Conv2D
(
in_channels
=
in_channels
[
i
],
in_channels
=
self
.
in_channels
[
i
],
out_channels
=
num_prior
*
4
,
kernel_size
=
kernel_size
,
padding
=
padding
))
...
...
@@ -75,7 +75,7 @@ class FaceHead(nn.Layer):
score_conv
=
self
.
add_sublayer
(
score_conv_name
,
nn
.
Conv2D
(
in_channels
=
in_channels
[
i
],
in_channels
=
self
.
in_channels
[
i
],
out_channels
=
num_prior
*
self
.
num_classes
,
kernel_size
=
kernel_size
,
padding
=
padding
))
...
...
ppdet/modeling/necks/__init__.py
浏览文件 @
457f649a
...
...
@@ -23,3 +23,4 @@ from .yolo_fpn import *
from
.hrfpn
import
*
from
.ttf_fpn
import
*
from
.centernet_fpn
import
*
from
.blazeface_fpn
import
*
ppdet/modeling/necks/blazeface_fpn.py
0 → 100644
浏览文件 @
457f649a
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
math
import
paddle
import
paddle.nn.functional
as
F
from
paddle
import
ParamAttr
import
paddle.nn
as
nn
from
paddle.nn.initializer
import
KaimingNormal
from
ppdet.core.workspace
import
register
,
serializable
from
ppdet.modeling.layers
import
ConvNormLayer
from
..shape_spec
import
ShapeSpec
__all__
=
[
'BlazeNeck'
]
def
hard_swish
(
x
):
return
x
*
F
.
relu6
(
x
+
3
)
/
6.
class
ConvBNLayer
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
,
num_groups
=
1
,
act
=
'relu'
,
conv_lr
=
0.1
,
conv_decay
=
0.
,
norm_decay
=
0.
,
norm_type
=
'bn'
,
name
=
None
):
super
(
ConvBNLayer
,
self
).
__init__
()
self
.
act
=
act
self
.
_conv
=
nn
.
Conv2D
(
in_channels
,
out_channels
,
kernel_size
=
kernel_size
,
stride
=
stride
,
padding
=
padding
,
groups
=
num_groups
,
weight_attr
=
ParamAttr
(
learning_rate
=
conv_lr
,
initializer
=
KaimingNormal
(),
name
=
name
+
"_weights"
),
bias_attr
=
False
)
param_attr
=
ParamAttr
(
name
=
name
+
"_bn_scale"
)
bias_attr
=
ParamAttr
(
name
=
name
+
"_bn_offset"
)
if
norm_type
==
'sync_bn'
:
self
.
_batch_norm
=
nn
.
SyncBatchNorm
(
out_channels
,
weight_attr
=
param_attr
,
bias_attr
=
bias_attr
)
else
:
self
.
_batch_norm
=
nn
.
BatchNorm
(
out_channels
,
act
=
None
,
param_attr
=
param_attr
,
bias_attr
=
bias_attr
,
use_global_stats
=
False
,
moving_mean_name
=
name
+
'_bn_mean'
,
moving_variance_name
=
name
+
'_bn_variance'
)
def
forward
(
self
,
x
):
x
=
self
.
_conv
(
x
)
x
=
self
.
_batch_norm
(
x
)
if
self
.
act
==
"relu"
:
x
=
F
.
relu
(
x
)
elif
self
.
act
==
"relu6"
:
x
=
F
.
relu6
(
x
)
elif
self
.
act
==
'leaky'
:
x
=
F
.
leaky_relu
(
x
)
elif
self
.
act
==
'hard_swish'
:
x
=
hard_swish
(
x
)
return
x
class
FPN
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
name
=
None
):
super
(
FPN
,
self
).
__init__
()
self
.
conv1_fpn
=
ConvBNLayer
(
in_channels
,
out_channels
//
2
,
kernel_size
=
1
,
padding
=
0
,
stride
=
1
,
act
=
'leaky'
,
name
=
name
+
'_output1'
)
self
.
conv2_fpn
=
ConvBNLayer
(
in_channels
,
out_channels
//
2
,
kernel_size
=
1
,
padding
=
0
,
stride
=
1
,
act
=
'leaky'
,
name
=
name
+
'_output2'
)
self
.
conv3_fpn
=
ConvBNLayer
(
out_channels
//
2
,
out_channels
//
2
,
kernel_size
=
3
,
padding
=
1
,
stride
=
1
,
act
=
'leaky'
,
name
=
name
+
'_merge'
)
def
forward
(
self
,
input
):
output1
=
self
.
conv1_fpn
(
input
[
0
])
output2
=
self
.
conv2_fpn
(
input
[
1
])
up2
=
F
.
upsample
(
output2
,
size
=
paddle
.
shape
(
output1
)[
-
2
:],
mode
=
'nearest'
)
output1
=
paddle
.
add
(
output1
,
up2
)
output1
=
self
.
conv3_fpn
(
output1
)
return
output1
,
output2
class
SSH
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
name
=
None
):
super
(
SSH
,
self
).
__init__
()
assert
out_channels
%
4
==
0
self
.
conv0_ssh
=
ConvBNLayer
(
in_channels
,
out_channels
//
2
,
kernel_size
=
3
,
padding
=
1
,
stride
=
1
,
act
=
None
,
name
=
name
+
'ssh_conv3'
)
self
.
conv1_ssh
=
ConvBNLayer
(
out_channels
//
2
,
out_channels
//
4
,
kernel_size
=
3
,
padding
=
1
,
stride
=
1
,
act
=
'leaky'
,
name
=
name
+
'ssh_conv5_1'
)
self
.
conv2_ssh
=
ConvBNLayer
(
out_channels
//
4
,
out_channels
//
4
,
kernel_size
=
3
,
padding
=
1
,
stride
=
1
,
act
=
None
,
name
=
name
+
'ssh_conv5_2'
)
self
.
conv3_ssh
=
ConvBNLayer
(
out_channels
//
4
,
out_channels
//
4
,
kernel_size
=
3
,
padding
=
1
,
stride
=
1
,
act
=
'leaky'
,
name
=
name
+
'ssh_conv7_1'
)
self
.
conv4_ssh
=
ConvBNLayer
(
out_channels
//
4
,
out_channels
//
4
,
kernel_size
=
3
,
padding
=
1
,
stride
=
1
,
act
=
None
,
name
=
name
+
'ssh_conv7_2'
)
def
forward
(
self
,
x
):
conv0
=
self
.
conv0_ssh
(
x
)
conv1
=
self
.
conv1_ssh
(
conv0
)
conv2
=
self
.
conv2_ssh
(
conv1
)
conv3
=
self
.
conv3_ssh
(
conv2
)
conv4
=
self
.
conv4_ssh
(
conv3
)
concat
=
paddle
.
concat
([
conv0
,
conv2
,
conv4
],
axis
=
1
)
return
F
.
relu
(
concat
)
@
register
@
serializable
class
BlazeNeck
(
nn
.
Layer
):
def
__init__
(
self
,
in_channel
,
neck_type
=
"None"
,
data_format
=
'NCHW'
):
super
(
BlazeNeck
,
self
).
__init__
()
self
.
neck_type
=
neck_type
self
.
reture_input
=
False
self
.
_out_channels
=
in_channel
if
self
.
neck_type
==
'None'
:
self
.
reture_input
=
True
if
"fpn"
in
self
.
neck_type
:
self
.
fpn
=
FPN
(
self
.
_out_channels
[
0
],
self
.
_out_channels
[
1
],
name
=
'fpn'
)
self
.
_out_channels
=
[
self
.
_out_channels
[
0
]
//
2
,
self
.
_out_channels
[
1
]
//
2
]
if
"ssh"
in
self
.
neck_type
:
self
.
ssh1
=
SSH
(
self
.
_out_channels
[
0
],
self
.
_out_channels
[
0
],
name
=
'ssh1'
)
self
.
ssh2
=
SSH
(
self
.
_out_channels
[
1
],
self
.
_out_channels
[
1
],
name
=
'ssh2'
)
self
.
_out_channels
=
[
self
.
_out_channels
[
0
],
self
.
_out_channels
[
1
]]
def
forward
(
self
,
inputs
):
if
self
.
reture_input
:
return
inputs
output1
,
output2
=
None
,
None
if
"fpn"
in
self
.
neck_type
:
backout_4
,
backout_1
=
inputs
output1
,
output2
=
self
.
fpn
([
backout_4
,
backout_1
])
if
self
.
neck_type
==
"only_fpn"
:
return
[
output1
,
output2
]
if
self
.
neck_type
==
"only_ssh"
:
output1
,
output2
=
inputs
feature1
=
self
.
ssh1
(
output1
)
feature2
=
self
.
ssh2
(
output2
)
return
[
feature1
,
feature2
]
@
property
def
out_shape
(
self
):
return
[
ShapeSpec
(
channels
=
c
)
for
c
in
[
self
.
_out_channels
[
0
],
self
.
_out_channels
[
1
]]
]
ppdet/utils/checkpoint.py
浏览文件 @
457f649a
...
...
@@ -162,7 +162,7 @@ def load_pretrain_weight(model, pretrain_weight):
# hack: fit for faster rcnn. Pretrain weights contain prefix of 'backbone'
# while res5 module is located in bbox_head.head. Replace the prefix of
# res5 with 'bbox_head.head' to load pretrain weights correctly.
for
k
in
param_state_dict
.
keys
(
):
for
k
in
list
(
param_state_dict
.
keys
()
):
if
'backbone.res5'
in
k
:
new_k
=
k
.
replace
(
'backbone'
,
'bbox_head.head'
)
if
new_k
in
model_dict
.
keys
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录