Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
4e88bec5
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
接近 2 年 前同步成功
通知
707
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4e88bec5
编写于
5月 13, 2021
作者:
S
shangliang Xu
提交者:
GitHub
5月 13, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add res2net (#2992)
上级
50410757
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
543 addition
and
0 deletion
+543
-0
configs/res2net/README.md
configs/res2net/README.md
+37
-0
configs/res2net/faster_rcnn_res2net50_vb_26w_4s_fpn_1x_coco.yml
...s/res2net/faster_rcnn_res2net50_vb_26w_4s_fpn_1x_coco.yml
+33
-0
configs/res2net/mask_rcnn_res2net50_vb_26w_4s_fpn_2x_coco.yml
...igs/res2net/mask_rcnn_res2net50_vb_26w_4s_fpn_2x_coco.yml
+47
-0
configs/res2net/mask_rcnn_res2net50_vd_26w_4s_fpn_2x_coco.yml
...igs/res2net/mask_rcnn_res2net50_vd_26w_4s_fpn_2x_coco.yml
+47
-0
docs/MODEL_ZOO_cn.md
docs/MODEL_ZOO_cn.md
+20
-0
ppdet/modeling/backbones/__init__.py
ppdet/modeling/backbones/__init__.py
+2
-0
ppdet/modeling/backbones/res2net.py
ppdet/modeling/backbones/res2net.py
+357
-0
未找到文件。
configs/res2net/README.md
0 → 100644
浏览文件 @
4e88bec5
# Res2Net
## Introduction
-
Res2Net: A New Multi-scale Backbone Architecture:
[
https://arxiv.org/abs/1904.01169
](
https://arxiv.org/abs/1904.01169
)
```
@article{DBLP:journals/corr/abs-1904-01169,
author = {Shanghua Gao and
Ming{-}Ming Cheng and
Kai Zhao and
Xinyu Zhang and
Ming{-}Hsuan Yang and
Philip H. S. Torr},
title = {Res2Net: {A} New Multi-scale Backbone Architecture},
journal = {CoRR},
volume = {abs/1904.01169},
year = {2019},
url = {http://arxiv.org/abs/1904.01169},
archivePrefix = {arXiv},
eprint = {1904.01169},
timestamp = {Thu, 25 Apr 2019 10:24:54 +0200},
biburl = {https://dblp.org/rec/bib/journals/corr/abs-1904-01169},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
```
## Model Zoo
| Backbone | Type | Image/gpu | Lr schd | Inf time (fps) | Box AP | Mask AP | Download | Configs |
| :---------------------- | :------------- | :-------: | :-----: | :------------: | :----: | :-----: | :----------------------------------------------------------: | :-----: |
| Res2Net50-FPN | Faster | 2 | 1x | - | 40.6 | - |
[
model
](
https://paddledet.bj.bcebos.com/models/faster_rcnn_res2net50_vb_26w_4s_fpn_1x_coco.pdparams
)
|
[
config
](
https://github.com/PaddlePaddle/PaddleDetection/develop/configs/res2net/faster_rcnn_res2net50_vb_26w_4s_fpn_1x.yml
)
|
| Res2Net50-FPN | Mask | 2 | 2x | - | 42.4 | 38.1 |
[
model
](
https://paddledet.bj.bcebos.com/models/mask_rcnn_res2net50_vb_26w_4s_fpn_2x_coco.pdparams
)
|
[
config
](
https://github.com/PaddlePaddle/PaddleDetection/develop/configs/res2net/mask_rcnn_res2net50_vb_26w_4s_fpn_2x_coco.yml
)
|
| Res2Net50-vd-FPN | Mask | 2 | 2x | - | 42.6 | 38.1 |
[
model
](
https://paddledet.bj.bcebos.com/models/mask_rcnn_res2net50_vd_26w_4s_fpn_2x_coco.pdparams
)
|
[
config
](
https://github.com/PaddlePaddle/PaddleDetection/develop/configs/res2net/mask_rcnn_res2net50_vd_26w_4s_fpn_2x_coco.yml
)
|
Note: all the above models are trained with 8 gpus.
configs/res2net/faster_rcnn_res2net50_vb_26w_4s_fpn_1x_coco.yml
0 → 100644
浏览文件 @
4e88bec5
_BASE_
:
[
'
../datasets/coco_detection.yml'
,
'
../runtime.yml'
,
'
../faster_rcnn/_base_/optimizer_1x.yml'
,
'
../faster_rcnn/_base_/faster_rcnn_r50_fpn.yml'
,
'
../faster_rcnn/_base_/faster_fpn_reader.yml'
,
]
pretrain_weights
:
https://paddledet.bj.bcebos.com/models/pretrained/Res2Net50_26w_4s_pretrained.pdparams
weights
:
output/faster_rcnn_res2net50_vb_26w_4s_fpn_1x_coco/model_final
FasterRCNN
:
backbone
:
Res2Net
neck
:
FPN
rpn_head
:
RPNHead
bbox_head
:
BBoxHead
# post process
bbox_post_process
:
BBoxPostProcess
Res2Net
:
# index 0 stands for res2
depth
:
50
width
:
26
scales
:
4
norm_type
:
bn
freeze_at
:
0
return_idx
:
[
0
,
1
,
2
,
3
]
num_stages
:
4
variant
:
b
TrainReader
:
batch_size
:
2
configs/res2net/mask_rcnn_res2net50_vb_26w_4s_fpn_2x_coco.yml
0 → 100644
浏览文件 @
4e88bec5
_BASE_
:
[
'
../datasets/coco_instance.yml'
,
'
../runtime.yml'
,
'
../mask_rcnn/_base_/optimizer_1x.yml'
,
'
../mask_rcnn/_base_/mask_rcnn_r50_fpn.yml'
,
'
../mask_rcnn/_base_/mask_fpn_reader.yml'
,
]
pretrain_weights
:
https://paddledet.bj.bcebos.com/models/pretrained/Res2Net50_26w_4s_pretrained.pdparams
weights
:
output/mask_rcnn_res2net50_vb_26w_4s_fpn_2x_coco/model_final
MaskRCNN
:
backbone
:
Res2Net
neck
:
FPN
rpn_head
:
RPNHead
bbox_head
:
BBoxHead
mask_head
:
MaskHead
# post process
bbox_post_process
:
BBoxPostProcess
mask_post_process
:
MaskPostProcess
Res2Net
:
# index 0 stands for res2
depth
:
50
width
:
26
scales
:
4
norm_type
:
bn
freeze_at
:
0
return_idx
:
[
0
,
1
,
2
,
3
]
num_stages
:
4
variant
:
b
epoch
:
24
LearningRate
:
base_lr
:
0.01
schedulers
:
-
!PiecewiseDecay
gamma
:
0.1
milestones
:
[
16
,
22
]
-
!LinearWarmup
start_factor
:
0.3333333333333333
steps
:
500
TrainReader
:
batch_size
:
2
configs/res2net/mask_rcnn_res2net50_vd_26w_4s_fpn_2x_coco.yml
0 → 100644
浏览文件 @
4e88bec5
_BASE_
:
[
'
../datasets/coco_instance.yml'
,
'
../runtime.yml'
,
'
../mask_rcnn/_base_/optimizer_1x.yml'
,
'
../mask_rcnn/_base_/mask_rcnn_r50_fpn.yml'
,
'
../mask_rcnn/_base_/mask_fpn_reader.yml'
,
]
pretrain_weights
:
https://paddledet.bj.bcebos.com/models/pretrained/Res2Net50_vd_26w_4s_pretrained.pdparams
weights
:
output/mask_rcnn_res2net50_vd_26w_4s_fpn_2x_coco/model_final
MaskRCNN
:
backbone
:
Res2Net
neck
:
FPN
rpn_head
:
RPNHead
bbox_head
:
BBoxHead
mask_head
:
MaskHead
# post process
bbox_post_process
:
BBoxPostProcess
mask_post_process
:
MaskPostProcess
Res2Net
:
# index 0 stands for res2
depth
:
50
width
:
26
scales
:
4
norm_type
:
bn
freeze_at
:
0
return_idx
:
[
0
,
1
,
2
,
3
]
num_stages
:
4
variant
:
d
epoch
:
24
LearningRate
:
base_lr
:
0.01
schedulers
:
-
!PiecewiseDecay
gamma
:
0.1
milestones
:
[
16
,
22
]
-
!LinearWarmup
start_factor
:
0.3333333333333333
steps
:
500
TrainReader
:
batch_size
:
2
docs/MODEL_ZOO_cn.md
浏览文件 @
4e88bec5
...
...
@@ -63,3 +63,23 @@ Paddle提供基于ImageNet的骨架网络预训练模型。所有预训练模型
### TTFNet
请参考
[
TTFNet
](
https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/ttfnet/
)
### Group Normalization
请参考
[
Group Normalization
](
https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/gn/
)
### Deformable ConvNets v2
请参考
[
Deformable ConvNets v2
](
https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/dcn/
)
### HRNets
请参考
[
HRNets
](
https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/hrnet/
)
### S2ANet
请参考
[
S2ANet
](
https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/dota/
)
### Res2Net
请参考
[
Res2Net
](
https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/res2net/
)
ppdet/modeling/backbones/__init__.py
浏览文件 @
4e88bec5
...
...
@@ -21,6 +21,7 @@ from . import hrnet
from
.
import
blazenet
from
.
import
ghostnet
from
.
import
senet
from
.
import
res2net
from
.vgg
import
*
from
.resnet
import
*
...
...
@@ -31,3 +32,4 @@ from .hrnet import *
from
.blazenet
import
*
from
.ghostnet
import
*
from
.senet
import
*
from
.res2net
import
*
ppdet/modeling/backbones/res2net.py
0 → 100644
浏览文件 @
4e88bec5
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
numbers
import
Integral
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
ppdet.core.workspace
import
register
,
serializable
from
..shape_spec
import
ShapeSpec
from
.resnet
import
ConvNormLayer
__all__
=
[
'Res2Net'
,
'Res2NetC5'
]
Res2Net_cfg
=
{
50
:
[
3
,
4
,
6
,
3
],
101
:
[
3
,
4
,
23
,
3
],
152
:
[
3
,
8
,
36
,
3
],
200
:
[
3
,
12
,
48
,
3
]
}
class
BottleNeck
(
nn
.
Layer
):
def
__init__
(
self
,
ch_in
,
ch_out
,
stride
,
shortcut
,
width
,
scales
=
4
,
variant
=
'b'
,
groups
=
1
,
lr
=
1.0
,
norm_type
=
'bn'
,
norm_decay
=
0.
,
freeze_norm
=
True
,
dcn_v2
=
False
):
super
(
BottleNeck
,
self
).
__init__
()
self
.
shortcut
=
shortcut
self
.
scales
=
scales
self
.
stride
=
stride
if
not
shortcut
:
if
variant
==
'd'
and
stride
==
2
:
self
.
branch1
=
nn
.
Sequential
()
self
.
branch1
.
add_sublayer
(
'pool'
,
nn
.
AvgPool2D
(
kernel_size
=
2
,
stride
=
2
,
padding
=
0
,
ceil_mode
=
True
))
self
.
branch1
.
add_sublayer
(
'conv'
,
ConvNormLayer
(
ch_in
=
ch_in
,
ch_out
=
ch_out
,
filter_size
=
1
,
stride
=
1
,
norm_type
=
norm_type
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
lr
=
lr
))
else
:
self
.
branch1
=
ConvNormLayer
(
ch_in
=
ch_in
,
ch_out
=
ch_out
,
filter_size
=
1
,
stride
=
stride
,
norm_type
=
norm_type
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
lr
=
lr
)
self
.
branch2a
=
ConvNormLayer
(
ch_in
=
ch_in
,
ch_out
=
width
*
scales
,
filter_size
=
1
,
stride
=
stride
if
variant
==
'a'
else
1
,
groups
=
1
,
act
=
'relu'
,
norm_type
=
norm_type
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
lr
=
lr
)
self
.
branch2b
=
nn
.
LayerList
([
ConvNormLayer
(
ch_in
=
width
,
ch_out
=
width
,
filter_size
=
3
,
stride
=
1
if
variant
==
'a'
else
stride
,
groups
=
groups
,
act
=
'relu'
,
norm_type
=
norm_type
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
lr
=
lr
,
dcn_v2
=
dcn_v2
)
for
_
in
range
(
self
.
scales
-
1
)
])
self
.
branch2c
=
ConvNormLayer
(
ch_in
=
width
*
scales
,
ch_out
=
ch_out
,
filter_size
=
1
,
stride
=
1
,
groups
=
1
,
norm_type
=
norm_type
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
lr
=
lr
)
def
forward
(
self
,
inputs
):
out
=
self
.
branch2a
(
inputs
)
feature_split
=
paddle
.
split
(
out
,
self
.
scales
,
1
)
out_split
=
[]
for
i
in
range
(
self
.
scales
-
1
):
if
i
==
0
or
self
.
stride
==
2
:
out_split
.
append
(
self
.
branch2b
[
i
](
feature_split
[
i
]))
else
:
out_split
.
append
(
self
.
branch2b
[
i
](
paddle
.
add
(
feature_split
[
i
],
out_split
[
-
1
])))
if
self
.
stride
==
1
:
out_split
.
append
(
feature_split
[
-
1
])
else
:
out_split
.
append
(
F
.
avg_pool2d
(
feature_split
[
-
1
],
3
,
self
.
stride
,
1
))
out
=
self
.
branch2c
(
paddle
.
concat
(
out_split
,
1
))
if
self
.
shortcut
:
short
=
inputs
else
:
short
=
self
.
branch1
(
inputs
)
out
=
paddle
.
add
(
out
,
short
)
out
=
F
.
relu
(
out
)
return
out
class
Blocks
(
nn
.
Layer
):
def
__init__
(
self
,
ch_in
,
ch_out
,
count
,
stage_num
,
width
,
scales
=
4
,
variant
=
'b'
,
groups
=
1
,
lr
=
1.0
,
norm_type
=
'bn'
,
norm_decay
=
0.
,
freeze_norm
=
True
,
dcn_v2
=
False
):
super
(
Blocks
,
self
).
__init__
()
self
.
blocks
=
nn
.
Sequential
()
for
i
in
range
(
count
):
self
.
blocks
.
add_sublayer
(
str
(
i
),
BottleNeck
(
ch_in
=
ch_in
if
i
==
0
else
ch_out
,
ch_out
=
ch_out
,
stride
=
2
if
i
==
0
and
stage_num
!=
2
else
1
,
shortcut
=
False
if
i
==
0
else
True
,
width
=
width
*
(
2
**
(
stage_num
-
2
)),
scales
=
scales
,
variant
=
variant
,
groups
=
groups
,
lr
=
lr
,
norm_type
=
norm_type
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
dcn_v2
=
dcn_v2
))
def
forward
(
self
,
inputs
):
return
self
.
blocks
(
inputs
)
@
register
@
serializable
class
Res2Net
(
nn
.
Layer
):
"""
Res2Net, see https://arxiv.org/abs/1904.01169
Args:
depth (int): Res2Net depth, should be 50, 101, 152, 200.
width (int): Res2Net width
scales (int): Res2Net scale
variant (str): Res2Net variant, supports 'a', 'b', 'c', 'd' currently
lr_mult_list (list): learning rate ratio of different resnet stages(2,3,4,5),
lower learning rate ratio is need for pretrained model
got using distillation(default as [1.0, 1.0, 1.0, 1.0]).
groups (int): The groups number of the Conv Layer.
norm_type (str): normalization type, 'bn' or 'sync_bn'
norm_decay (float): weight decay for normalization layer weights
freeze_norm (bool): freeze normalization layers
freeze_at (int): freeze the backbone at which stage
return_idx (list): index of stages whose feature maps are returned,
index 0 stands for res2
dcn_v2_stages (list): index of stages who select deformable conv v2
num_stages (int): number of stages created
"""
__shared__
=
[
'norm_type'
]
def
__init__
(
self
,
depth
=
50
,
width
=
26
,
scales
=
4
,
variant
=
'b'
,
lr_mult_list
=
[
1.0
,
1.0
,
1.0
,
1.0
],
groups
=
1
,
norm_type
=
'bn'
,
norm_decay
=
0.
,
freeze_norm
=
True
,
freeze_at
=
0
,
return_idx
=
[
0
,
1
,
2
,
3
],
dcn_v2_stages
=
[
-
1
],
num_stages
=
4
):
super
(
Res2Net
,
self
).
__init__
()
self
.
_model_type
=
'Res2Net'
if
groups
==
1
else
'Res2NeXt'
assert
depth
in
[
50
,
101
,
152
,
200
],
\
"depth {} not in [50, 101, 152, 200]"
assert
variant
in
[
'a'
,
'b'
,
'c'
,
'd'
],
"invalid Res2Net variant"
assert
num_stages
>=
1
and
num_stages
<=
4
self
.
depth
=
depth
self
.
variant
=
variant
self
.
norm_type
=
norm_type
self
.
norm_decay
=
norm_decay
self
.
freeze_norm
=
freeze_norm
self
.
freeze_at
=
freeze_at
if
isinstance
(
return_idx
,
Integral
):
return_idx
=
[
return_idx
]
assert
max
(
return_idx
)
<
num_stages
,
\
'the maximum return index must smaller than num_stages, '
\
'but received maximum return index is {} and num_stages '
\
'is {}'
.
format
(
max
(
return_idx
),
num_stages
)
self
.
return_idx
=
return_idx
self
.
num_stages
=
num_stages
assert
len
(
lr_mult_list
)
==
4
,
\
"lr_mult_list length must be 4 but got {}"
.
format
(
len
(
lr_mult_list
))
if
isinstance
(
dcn_v2_stages
,
Integral
):
dcn_v2_stages
=
[
dcn_v2_stages
]
assert
max
(
dcn_v2_stages
)
<
num_stages
self
.
dcn_v2_stages
=
dcn_v2_stages
block_nums
=
Res2Net_cfg
[
depth
]
# C1 stage
if
self
.
variant
in
[
'c'
,
'd'
]:
conv_def
=
[
[
3
,
32
,
3
,
2
,
"conv1_1"
],
[
32
,
32
,
3
,
1
,
"conv1_2"
],
[
32
,
64
,
3
,
1
,
"conv1_3"
],
]
else
:
conv_def
=
[[
3
,
64
,
7
,
2
,
"conv1"
]]
self
.
res1
=
nn
.
Sequential
()
for
(
c_in
,
c_out
,
k
,
s
,
_name
)
in
conv_def
:
self
.
res1
.
add_sublayer
(
_name
,
ConvNormLayer
(
ch_in
=
c_in
,
ch_out
=
c_out
,
filter_size
=
k
,
stride
=
s
,
groups
=
1
,
act
=
'relu'
,
norm_type
=
norm_type
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
lr
=
1.0
))
self
.
_in_channels
=
[
64
,
256
,
512
,
1024
]
self
.
_out_channels
=
[
256
,
512
,
1024
,
2048
]
self
.
_out_strides
=
[
4
,
8
,
16
,
32
]
# C2-C5 stages
self
.
res_layers
=
[]
for
i
in
range
(
num_stages
):
lr_mult
=
lr_mult_list
[
i
]
stage_num
=
i
+
2
self
.
res_layers
.
append
(
self
.
add_sublayer
(
"res{}"
.
format
(
stage_num
),
Blocks
(
self
.
_in_channels
[
i
],
self
.
_out_channels
[
i
],
count
=
block_nums
[
i
],
stage_num
=
stage_num
,
width
=
width
,
scales
=
scales
,
groups
=
groups
,
lr
=
lr_mult
,
norm_type
=
norm_type
,
norm_decay
=
norm_decay
,
freeze_norm
=
freeze_norm
,
dcn_v2
=
(
i
in
self
.
dcn_v2_stages
))))
@
property
def
out_shape
(
self
):
return
[
ShapeSpec
(
channels
=
self
.
_out_channels
[
i
],
stride
=
self
.
_out_strides
[
i
])
for
i
in
self
.
return_idx
]
def
forward
(
self
,
inputs
):
x
=
inputs
[
'image'
]
res1
=
self
.
res1
(
x
)
x
=
F
.
max_pool2d
(
res1
,
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
outs
=
[]
for
idx
,
stage
in
enumerate
(
self
.
res_layers
):
x
=
stage
(
x
)
if
idx
==
self
.
freeze_at
:
x
.
stop_gradient
=
True
if
idx
in
self
.
return_idx
:
outs
.
append
(
x
)
return
outs
@
register
class
Res2NetC5
(
nn
.
Layer
):
def
__init__
(
self
,
depth
=
50
,
width
=
26
,
scales
=
4
,
variant
=
'b'
):
super
(
Res2NetC5
,
self
).
__init__
()
feat_in
,
feat_out
=
[
1024
,
2048
]
self
.
res5
=
Blocks
(
feat_in
,
feat_out
,
count
=
3
,
stage_num
=
5
,
width
=
width
,
scales
=
scales
,
variant
=
variant
)
self
.
feat_out
=
feat_out
@
property
def
out_shape
(
self
):
return
[
ShapeSpec
(
channels
=
self
.
feat_out
,
stride
=
32
,
)]
def
forward
(
self
,
roi_feat
,
stage
=
0
):
y
=
self
.
res5
(
roi_feat
)
return
y
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录