Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
1adb26ef
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1adb26ef
编写于
6月 05, 2021
作者:
G
Guanghua Yu
提交者:
GitHub
6月 05, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add solov2_r101vd model (#3286)
上级
32abf1a0
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
92 addition
and
23 deletion
+92
-23
configs/solov2/README.md
configs/solov2/README.md
+1
-0
configs/solov2/_base_/solov2_r50_fpn.yml
configs/solov2/_base_/solov2_r50_fpn.yml
+0
-1
configs/solov2/solov2_r101_vd_fpn_3x_coco.yml
configs/solov2/solov2_r101_vd_fpn_3x_coco.yml
+66
-0
ppdet/modeling/heads/solov2_head.py
ppdet/modeling/heads/solov2_head.py
+25
-22
未找到文件。
configs/solov2/README.md
浏览文件 @
1adb26ef
...
...
@@ -21,6 +21,7 @@ SOLOv2 (Segmenting Objects by Locations) is a fast instance segmentation framewo
| SOLOv2 (Paper) | X101-DCN-FPN | True | 3x | 42.4 | 5.9 | V100 | - | - |
| SOLOv2 | R50-FPN | False | 1x | 35.5 | 21.9 | V100 |
[
model
](
https://paddledet.bj.bcebos.com/models/solov2_r50_fpn_1x_coco.pdparams
)
|
[
config
](
https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/solov2/solov2_r50_fpn_1x_coco.yml
)
|
| SOLOv2 | R50-FPN | True | 3x | 38.0 | 21.9 | V100 |
[
model
](
https://paddledet.bj.bcebos.com/models/solov2_r50_fpn_3x_coco.pdparams
)
|
[
config
](
https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/solov2/solov2_r50_fpn_3x_coco.yml
)
|
| SOLOv2 | R101vd-FPN | True | 3x | 42.7 | 12.1 | V100 |
[
model
](
https://paddledet.bj.bcebos.com/models/solov2_r101_vd_fpn_3x_coco.pdparams
)
|
[
config
](
https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/solov2/solov2_r101_vd_fpn_3x_coco.yml
)
|
**Notes:**
...
...
configs/solov2/_base_/solov2_r50_fpn.yml
浏览文件 @
1adb26ef
...
...
@@ -9,7 +9,6 @@ SOLOv2:
ResNet
:
depth
:
50
norm_type
:
bn
freeze_at
:
0
return_idx
:
[
0
,
1
,
2
,
3
]
num_stages
:
4
...
...
configs/solov2/solov2_r101_vd_fpn_3x_coco.yml
0 → 100644
浏览文件 @
1adb26ef
_BASE_
:
[
'
../datasets/coco_instance.yml'
,
'
../runtime.yml'
,
'
_base_/solov2_r50_fpn.yml'
,
'
_base_/optimizer_1x.yml'
,
'
_base_/solov2_reader.yml'
,
]
pretrain_weights
:
https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_vd_pretrained.pdparams
weights
:
output/solov2_r101_vd_fpn_3x_coco/model_final
epoch
:
36
use_ema
:
true
ema_decay
:
0.9998
ResNet
:
depth
:
101
variant
:
d
freeze_at
:
0
return_idx
:
[
0
,
1
,
2
,
3
]
dcn_v2_stages
:
[
1
,
2
,
3
]
num_stages
:
4
SOLOv2Head
:
seg_feat_channels
:
512
stacked_convs
:
4
num_grids
:
[
40
,
36
,
24
,
16
,
12
]
kernel_out_channels
:
256
solov2_loss
:
SOLOv2Loss
mask_nms
:
MaskMatrixNMS
dcn_v2_stages
:
[
0
,
1
,
2
,
3
]
SOLOv2MaskHead
:
mid_channels
:
128
out_channels
:
256
start_level
:
0
end_level
:
3
use_dcn_in_tower
:
True
LearningRate
:
base_lr
:
0.01
schedulers
:
-
!PiecewiseDecay
gamma
:
0.1
milestones
:
[
24
,
33
]
-
!LinearWarmup
start_factor
:
0.
steps
:
2000
TrainReader
:
sample_transforms
:
-
Decode
:
{}
-
Poly2Mask
:
{}
-
RandomResize
:
{
interp
:
1
,
target_size
:
[[
640
,
1333
],
[
672
,
1333
],
[
704
,
1333
],
[
736
,
1333
],
[
768
,
1333
],
[
800
,
1333
]],
keep_ratio
:
True
}
-
RandomFlip
:
{}
-
NormalizeImage
:
{
is_scale
:
true
,
mean
:
[
0.485
,
0.456
,
0.406
],
std
:
[
0.229
,
0.224
,
0.225
]}
-
Permute
:
{}
batch_transforms
:
-
PadBatch
:
{
pad_to_stride
:
32
}
-
Gt2Solov2Target
:
{
num_grids
:
[
40
,
36
,
24
,
16
,
12
],
scale_ranges
:
[[
1
,
96
],
[
48
,
192
],
[
96
,
384
],
[
192
,
768
],
[
384
,
2048
]],
coord_sigma
:
0.2
}
batch_size
:
2
shuffle
:
true
drop_last
:
true
ppdet/modeling/heads/solov2_head.py
浏览文件 @
1adb26ef
...
...
@@ -43,6 +43,7 @@ class SOLOv2MaskHead(nn.Layer):
end_level (int): The position where the input ends.
use_dcn_in_tower (bool): Whether to use dcn in tower or not.
"""
__shared__
=
[
'norm_type'
]
def
__init__
(
self
,
in_channels
=
256
,
...
...
@@ -50,7 +51,8 @@ class SOLOv2MaskHead(nn.Layer):
out_channels
=
256
,
start_level
=
0
,
end_level
=
3
,
use_dcn_in_tower
=
False
):
use_dcn_in_tower
=
False
,
norm_type
=
'gn'
):
super
(
SOLOv2MaskHead
,
self
).
__init__
()
assert
start_level
>=
0
and
end_level
>=
start_level
self
.
in_channels
=
in_channels
...
...
@@ -58,24 +60,22 @@ class SOLOv2MaskHead(nn.Layer):
self
.
mid_channels
=
mid_channels
self
.
use_dcn_in_tower
=
use_dcn_in_tower
self
.
range_level
=
end_level
-
start_level
+
1
# TODO: add DeformConvNorm
conv_type
=
[
ConvNormLayer
]
self
.
conv_func
=
conv_type
[
0
]
if
self
.
use_dcn_in_tower
:
self
.
conv_func
=
conv_type
[
1
]
self
.
use_dcn
=
True
if
self
.
use_dcn_in_tower
else
False
self
.
convs_all_levels
=
[]
self
.
norm_type
=
norm_type
for
i
in
range
(
start_level
,
end_level
+
1
):
conv_feat_name
=
'mask_feat_head.convs_all_levels.{}'
.
format
(
i
)
conv_pre_feat
=
nn
.
Sequential
()
if
i
==
start_level
:
conv_pre_feat
.
add_sublayer
(
conv_feat_name
+
'.conv'
+
str
(
i
),
self
.
conv_func
(
ConvNormLayer
(
ch_in
=
self
.
in_channels
,
ch_out
=
self
.
mid_channels
,
filter_size
=
3
,
stride
=
1
,
norm_type
=
'gn'
))
use_dcn
=
self
.
use_dcn
,
norm_type
=
self
.
norm_type
))
self
.
add_sublayer
(
'conv_pre_feat'
+
str
(
i
),
conv_pre_feat
)
self
.
convs_all_levels
.
append
(
conv_pre_feat
)
else
:
...
...
@@ -87,12 +87,13 @@ class SOLOv2MaskHead(nn.Layer):
ch_in
=
self
.
mid_channels
conv_pre_feat
.
add_sublayer
(
conv_feat_name
+
'.conv'
+
str
(
j
),
self
.
conv_func
(
ConvNormLayer
(
ch_in
=
ch_in
,
ch_out
=
self
.
mid_channels
,
filter_size
=
3
,
stride
=
1
,
norm_type
=
'gn'
))
use_dcn
=
self
.
use_dcn
,
norm_type
=
self
.
norm_type
))
conv_pre_feat
.
add_sublayer
(
conv_feat_name
+
'.conv'
+
str
(
j
)
+
'act'
,
nn
.
ReLU
())
conv_pre_feat
.
add_sublayer
(
...
...
@@ -105,12 +106,13 @@ class SOLOv2MaskHead(nn.Layer):
conv_pred_name
=
'mask_feat_head.conv_pred.0'
self
.
conv_pred
=
self
.
add_sublayer
(
conv_pred_name
,
self
.
conv_func
(
ConvNormLayer
(
ch_in
=
self
.
mid_channels
,
ch_out
=
self
.
out_channels
,
filter_size
=
1
,
stride
=
1
,
norm_type
=
'gn'
))
use_dcn
=
self
.
use_dcn
,
norm_type
=
self
.
norm_type
))
def
forward
(
self
,
inputs
):
"""
...
...
@@ -165,7 +167,7 @@ class SOLOv2Head(nn.Layer):
mask_nms (object): MaskMatrixNMS instance.
"""
__inject__
=
[
'solov2_loss'
,
'mask_nms'
]
__shared__
=
[
'num_classes'
]
__shared__
=
[
'n
orm_type'
,
'n
um_classes'
]
def
__init__
(
self
,
num_classes
=
80
,
...
...
@@ -179,7 +181,8 @@ class SOLOv2Head(nn.Layer):
solov2_loss
=
None
,
score_threshold
=
0.1
,
mask_threshold
=
0.5
,
mask_nms
=
None
):
mask_nms
=
None
,
norm_type
=
'gn'
):
super
(
SOLOv2Head
,
self
).
__init__
()
self
.
num_classes
=
num_classes
self
.
in_channels
=
in_channels
...
...
@@ -194,33 +197,33 @@ class SOLOv2Head(nn.Layer):
self
.
mask_nms
=
mask_nms
self
.
score_threshold
=
score_threshold
self
.
mask_threshold
=
mask_threshold
self
.
norm_type
=
norm_type
conv_type
=
[
ConvNormLayer
]
self
.
conv_func
=
conv_type
[
0
]
self
.
kernel_pred_convs
=
[]
self
.
cate_pred_convs
=
[]
for
i
in
range
(
self
.
stacked_convs
):
if
i
in
self
.
dcn_v2_stages
:
self
.
conv_func
=
conv_type
[
1
]
use_dcn
=
True
if
i
in
self
.
dcn_v2_stages
else
False
ch_in
=
self
.
in_channels
+
2
if
i
==
0
else
self
.
seg_feat_channels
kernel_conv
=
self
.
add_sublayer
(
'bbox_head.kernel_convs.'
+
str
(
i
),
self
.
conv_func
(
ConvNormLayer
(
ch_in
=
ch_in
,
ch_out
=
self
.
seg_feat_channels
,
filter_size
=
3
,
stride
=
1
,
norm_type
=
'gn'
))
use_dcn
=
use_dcn
,
norm_type
=
self
.
norm_type
))
self
.
kernel_pred_convs
.
append
(
kernel_conv
)
ch_in
=
self
.
in_channels
if
i
==
0
else
self
.
seg_feat_channels
cate_conv
=
self
.
add_sublayer
(
'bbox_head.cate_convs.'
+
str
(
i
),
self
.
conv_func
(
ConvNormLayer
(
ch_in
=
ch_in
,
ch_out
=
self
.
seg_feat_channels
,
filter_size
=
3
,
stride
=
1
,
norm_type
=
'gn'
))
use_dcn
=
use_dcn
,
norm_type
=
self
.
norm_type
))
self
.
cate_pred_convs
.
append
(
cate_conv
)
self
.
solo_kernel
=
self
.
add_sublayer
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录