Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
aaf77a05
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
aaf77a05
编写于
1月 07, 2020
作者:
littletomatodonkey
提交者:
GitHub
1月 07, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add gcnet model (#166)
* add gcnet model :
https://arxiv.org/abs/1904.11492
上级
c925d0f6
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
442 addition
and
17 deletion
+442
-17
configs/gcnet/README.md
configs/gcnet/README.md
+34
-0
configs/gcnet/mask_rcnn_r50_vd_fpn_gcb_add_r16_2x.yml
configs/gcnet/mask_rcnn_r50_vd_fpn_gcb_add_r16_2x.yml
+119
-0
configs/gcnet/mask_rcnn_r50_vd_fpn_gcb_mul_r16_2x.yml
configs/gcnet/mask_rcnn_r50_vd_fpn_gcb_mul_r16_2x.yml
+119
-0
docs/MODEL_ZOO.md
docs/MODEL_ZOO.md
+3
-0
docs/MODEL_ZOO_cn.md
docs/MODEL_ZOO_cn.md
+3
-0
ppdet/modeling/backbones/gc_block.py
ppdet/modeling/backbones/gc_block.py
+124
-0
ppdet/modeling/backbones/resnet.py
ppdet/modeling/backbones/resnet.py
+40
-17
未找到文件。
configs/gcnet/README.md
0 → 100644
浏览文件 @
aaf77a05
# GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond
## Introduction
-
GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond
:
[
https://arxiv.org/abs/1904.11492
](
https://arxiv.org/abs/1904.11492
)
```
@article{DBLP:journals/corr/abs-1904-11492,
author = {Yue Cao and
Jiarui Xu and
Stephen Lin and
Fangyun Wei and
Han Hu},
title = {GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond},
journal = {CoRR},
volume = {abs/1904.11492},
year = {2019},
url = {http://arxiv.org/abs/1904.11492},
archivePrefix = {arXiv},
eprint = {1904.11492},
timestamp = {Tue, 09 Jul 2019 16:48:55 +0200},
biburl = {https://dblp.org/rec/bib/journals/corr/abs-1904-11492},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
```
## Model Zoo
| Backbone | Type | Context| Image/gpu | Lr schd | Inf time (fps) | Box AP | Mask AP | Download |
| :---------------------- | :-------------: | :-------------: | :-------: | :-----: | :------------: | :----: | :-----: | :----------------------------------------------------------: |
| ResNet50-vd-FPN | Mask | GC(c3-c5, r16, add) | 2 | 2x | 15.31 | 41.4 | 36.8 |
[
model
](
https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r50_vd_fpn_gcb_add_r16_2x.tar
)
|
| ResNet50-vd-FPN | Mask | GC(c3-c5, r16, mul) | 2 | 2x | 15.35 | 40.7 | 36.1 |
[
model
](
https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r50_vd_fpn_gcb_mul_r16_2x.tar
)
|
configs/gcnet/mask_rcnn_r50_vd_fpn_gcb_add_r16_2x.yml
0 → 100644
浏览文件 @
aaf77a05
architecture
:
MaskRCNN
use_gpu
:
true
max_iters
:
180000
snapshot_iter
:
10000
log_smooth_window
:
20
save_dir
:
output
pretrain_weights
:
https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar
metric
:
COCO
weights
:
output/mask_rcnn_r50_vd_fpn_gcb_add_r16_2x/model_final/
num_classes
:
81
MaskRCNN
:
backbone
:
ResNet
fpn
:
FPN
rpn_head
:
FPNRPNHead
roi_extractor
:
FPNRoIAlign
bbox_head
:
BBoxHead
bbox_assigner
:
BBoxAssigner
ResNet
:
depth
:
50
feature_maps
:
[
2
,
3
,
4
,
5
]
freeze_at
:
2
norm_type
:
bn
variant
:
d
gcb_stages
:
[
3
,
4
,
5
]
gcb_params
:
ratio
:
0.0625
pooling_type
:
att
fusion_types
:
[
channel_add
]
FPN
:
max_level
:
6
min_level
:
2
num_chan
:
256
spatial_scale
:
[
0.03125
,
0.0625
,
0.125
,
0.25
]
FPNRPNHead
:
anchor_generator
:
aspect_ratios
:
[
0.5
,
1.0
,
2.0
]
variance
:
[
1.0
,
1.0
,
1.0
,
1.0
]
anchor_start_size
:
32
max_level
:
6
min_level
:
2
num_chan
:
256
rpn_target_assign
:
rpn_batch_size_per_im
:
256
rpn_fg_fraction
:
0.5
rpn_negative_overlap
:
0.3
rpn_positive_overlap
:
0.7
rpn_straddle_thresh
:
0.0
train_proposal
:
min_size
:
0.0
nms_thresh
:
0.7
pre_nms_top_n
:
2000
post_nms_top_n
:
2000
test_proposal
:
min_size
:
0.0
nms_thresh
:
0.7
pre_nms_top_n
:
1000
post_nms_top_n
:
1000
FPNRoIAlign
:
canconical_level
:
4
canonical_size
:
224
max_level
:
5
min_level
:
2
box_resolution
:
7
sampling_ratio
:
2
mask_resolution
:
14
MaskHead
:
dilation
:
1
conv_dim
:
256
num_convs
:
4
resolution
:
28
BBoxAssigner
:
batch_size_per_im
:
512
bbox_reg_weights
:
[
0.1
,
0.1
,
0.2
,
0.2
]
bg_thresh_hi
:
0.5
bg_thresh_lo
:
0.0
fg_fraction
:
0.25
fg_thresh
:
0.5
MaskAssigner
:
resolution
:
28
BBoxHead
:
head
:
TwoFCHead
nms
:
keep_top_k
:
100
nms_threshold
:
0.5
score_threshold
:
0.05
TwoFCHead
:
mlp_dim
:
1024
LearningRate
:
base_lr
:
0.02
schedulers
:
-
!PiecewiseDecay
gamma
:
0.1
milestones
:
[
120000
,
160000
]
-
!LinearWarmup
start_factor
:
0.1
steps
:
1000
OptimizerBuilder
:
optimizer
:
momentum
:
0.9
type
:
Momentum
regularizer
:
factor
:
0.0001
type
:
L2
_READER_
:
'
../mask_fpn_reader.yml'
TrainReader
:
batch_size
:
2
configs/gcnet/mask_rcnn_r50_vd_fpn_gcb_mul_r16_2x.yml
0 → 100644
浏览文件 @
aaf77a05
architecture
:
MaskRCNN
use_gpu
:
true
max_iters
:
180000
snapshot_iter
:
10000
log_smooth_window
:
20
save_dir
:
output
pretrain_weights
:
https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar
metric
:
COCO
weights
:
output/mask_rcnn_r50_vd_fpn_gcb_mul_r16_2x/model_final/
num_classes
:
81
MaskRCNN
:
backbone
:
ResNet
fpn
:
FPN
rpn_head
:
FPNRPNHead
roi_extractor
:
FPNRoIAlign
bbox_head
:
BBoxHead
bbox_assigner
:
BBoxAssigner
ResNet
:
depth
:
50
feature_maps
:
[
2
,
3
,
4
,
5
]
freeze_at
:
2
norm_type
:
bn
variant
:
d
gcb_stages
:
[
3
,
4
,
5
]
gcb_params
:
ratio
:
0.0625
pooling_type
:
att
fusion_types
:
[
channel_mul
]
FPN
:
max_level
:
6
min_level
:
2
num_chan
:
256
spatial_scale
:
[
0.03125
,
0.0625
,
0.125
,
0.25
]
FPNRPNHead
:
anchor_generator
:
aspect_ratios
:
[
0.5
,
1.0
,
2.0
]
variance
:
[
1.0
,
1.0
,
1.0
,
1.0
]
anchor_start_size
:
32
max_level
:
6
min_level
:
2
num_chan
:
256
rpn_target_assign
:
rpn_batch_size_per_im
:
256
rpn_fg_fraction
:
0.5
rpn_negative_overlap
:
0.3
rpn_positive_overlap
:
0.7
rpn_straddle_thresh
:
0.0
train_proposal
:
min_size
:
0.0
nms_thresh
:
0.7
pre_nms_top_n
:
2000
post_nms_top_n
:
2000
test_proposal
:
min_size
:
0.0
nms_thresh
:
0.7
pre_nms_top_n
:
1000
post_nms_top_n
:
1000
FPNRoIAlign
:
canconical_level
:
4
canonical_size
:
224
max_level
:
5
min_level
:
2
box_resolution
:
7
sampling_ratio
:
2
mask_resolution
:
14
MaskHead
:
dilation
:
1
conv_dim
:
256
num_convs
:
4
resolution
:
28
BBoxAssigner
:
batch_size_per_im
:
512
bbox_reg_weights
:
[
0.1
,
0.1
,
0.2
,
0.2
]
bg_thresh_hi
:
0.5
bg_thresh_lo
:
0.0
fg_fraction
:
0.25
fg_thresh
:
0.5
MaskAssigner
:
resolution
:
28
BBoxHead
:
head
:
TwoFCHead
nms
:
keep_top_k
:
100
nms_threshold
:
0.5
score_threshold
:
0.05
TwoFCHead
:
mlp_dim
:
1024
LearningRate
:
base_lr
:
0.02
schedulers
:
-
!PiecewiseDecay
gamma
:
0.1
milestones
:
[
120000
,
160000
]
-
!LinearWarmup
start_factor
:
0.1
steps
:
1000
OptimizerBuilder
:
optimizer
:
momentum
:
0.9
type
:
Momentum
regularizer
:
factor
:
0.0001
type
:
L2
_READER_
:
'
../mask_fpn_reader.yml'
TrainReader
:
batch_size
:
2
docs/MODEL_ZOO.md
浏览文件 @
aaf77a05
...
...
@@ -102,6 +102,9 @@ The backbone models pretrained on ImageNet are available. All backbone models ar
### IOU loss
*
GIOU loss and DIOU loss are included now. See more details in
[
IOU loss model zoo
](
../configs/iou_loss/README.md
)
.
### GCNet
*
See more details in
[
GCNet model zoo
](
../configs/gcnet/README.md
)
.
### Group Normalization
| Backbone | Type | Image/gpu | Lr schd | Box AP | Mask AP | Download |
...
...
docs/MODEL_ZOO_cn.md
浏览文件 @
aaf77a05
...
...
@@ -99,6 +99,9 @@ Paddle提供基于ImageNet的骨架网络预训练模型。所有预训练模型
### IOU loss
*
目前模型库中包括GIOU loss和DIOU loss,详情加
[
IOU loss模型库
](
../configs/iou_loss/README.md
)
.
### GCNet
*
详情见
[
GCNet模型库
](
../configs/gcnet/README.md
)
.
### Group Normalization
| 骨架网络 | 网络类型 | 每张GPU图片个数 | 学习率策略 | Box AP | Mask AP | 下载 |
...
...
ppdet/modeling/backbones/gc_block.py
0 → 100755
浏览文件 @
aaf77a05
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
unicode_literals
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid
import
ParamAttr
from
paddle.fluid.initializer
import
ConstantInitializer
def
spatial_pool
(
x
,
pooling_type
,
name
):
_
,
channel
,
height
,
width
=
x
.
shape
if
pooling_type
==
'att'
:
input_x
=
x
# [N, 1, C, H * W]
input_x
=
fluid
.
layers
.
reshape
(
input_x
,
shape
=
(
0
,
1
,
channel
,
-
1
))
context_mask
=
fluid
.
layers
.
conv2d
(
input
=
x
,
num_filters
=
1
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
param_attr
=
ParamAttr
(
name
=
name
+
"_weights"
),
bias_attr
=
ParamAttr
(
name
=
name
+
"_bias"
))
# [N, 1, H * W]
context_mask
=
fluid
.
layers
.
reshape
(
context_mask
,
shape
=
(
0
,
0
,
-
1
))
# [N, 1, H * W]
context_mask
=
fluid
.
layers
.
softmax
(
context_mask
,
axis
=
2
)
# [N, 1, H * W, 1]
context_mask
=
fluid
.
layers
.
reshape
(
context_mask
,
shape
=
(
0
,
0
,
-
1
,
1
))
# [N, 1, C, 1]
context
=
fluid
.
layers
.
matmul
(
input_x
,
context_mask
)
# [N, C, 1, 1]
context
=
fluid
.
layers
.
reshape
(
context
,
shape
=
(
0
,
channel
,
1
,
1
))
else
:
# [N, C, 1, 1]
context
=
fluid
.
layers
.
pool2d
(
input
=
x
,
pool_type
=
'avg'
,
global_pooling
=
True
)
return
context
def
channel_conv
(
input
,
inner_ch
,
out_ch
,
name
):
conv
=
fluid
.
layers
.
conv2d
(
input
=
input
,
num_filters
=
inner_ch
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
param_attr
=
ParamAttr
(
name
=
name
+
"_conv1_weights"
),
bias_attr
=
ParamAttr
(
name
=
name
+
"_conv1_bias"
),
name
=
name
+
"_conv1"
,
)
conv
=
fluid
.
layers
.
layer_norm
(
conv
,
begin_norm_axis
=
1
,
param_attr
=
ParamAttr
(
name
=
name
+
"_ln_weights"
),
bias_attr
=
ParamAttr
(
name
=
name
+
"_ln_bias"
),
act
=
"relu"
,
name
=
name
+
"_ln"
)
conv
=
fluid
.
layers
.
conv2d
(
input
=
conv
,
num_filters
=
out_ch
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
param_attr
=
ParamAttr
(
name
=
name
+
"_conv2_weights"
,
initializer
=
ConstantInitializer
(
value
=
0.0
),
),
bias_attr
=
ParamAttr
(
name
=
name
+
"_conv2_bias"
,
initializer
=
ConstantInitializer
(
value
=
0.0
),
),
name
=
name
+
"_conv2"
)
return
conv
def
add_gc_block
(
x
,
ratio
=
1.0
/
16
,
pooling_type
=
'att'
,
fusion_types
=
[
'channel_add'
],
name
=
None
):
'''
GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond, see https://arxiv.org/abs/1904.11492
Args:
ratio (float): channel reduction ratio
pooling_type (str): pooling type, support att and avg
fusion_types (list): fusion types, support channel_add and channel_mul
name (str): prefix name of gc block
'''
assert
pooling_type
in
[
'avg'
,
'att'
]
assert
isinstance
(
fusion_types
,
(
list
,
tuple
))
valid_fusion_types
=
[
'channel_add'
,
'channel_mul'
]
assert
all
([
f
in
valid_fusion_types
for
f
in
fusion_types
])
assert
len
(
fusion_types
)
>
0
,
'at least one fusion should be used'
inner_ch
=
int
(
ratio
*
x
.
shape
[
1
])
out_ch
=
x
.
shape
[
1
]
context
=
spatial_pool
(
x
,
pooling_type
,
name
+
"_spatial_pool"
)
out
=
x
if
'channel_mul'
in
fusion_types
:
inner_out
=
channel_conv
(
context
,
inner_ch
,
out_ch
,
name
+
"_mul"
)
channel_mul_term
=
fluid
.
layers
.
sigmoid
(
inner_out
)
out
=
out
*
channel_mul_term
if
'channel_add'
in
fusion_types
:
channel_add_term
=
channel_conv
(
context
,
inner_ch
,
out_ch
,
name
+
"_add"
)
out
=
out
+
channel_add_term
return
out
ppdet/modeling/backbones/resnet.py
浏览文件 @
aaf77a05
...
...
@@ -28,6 +28,7 @@ from ppdet.core.workspace import register, serializable
from
numbers
import
Integral
from
.nonlocal_helper
import
add_space_nonlocal
from
.gc_block
import
add_gc_block
from
.name_adapter
import
NameAdapter
__all__
=
[
'ResNet'
,
'ResNetC5'
]
...
...
@@ -48,6 +49,10 @@ class ResNet(object):
feature_maps (list): index of stages whose feature maps are returned
dcn_v2_stages (list): index of stages who select deformable conv v2
nonlocal_stages (list): index of stages who select nonlocal networks
gcb_stages (list): index of stages who select gc blocks
gcb_params (dict): gc blocks config, includes ratio(default as 1.0/16),
pooling_type(default as "att") and
fusion_types(default as ['channel_add'])
"""
__shared__
=
[
'norm_type'
,
'freeze_norm'
,
'weight_prefix_name'
]
...
...
@@ -61,7 +66,9 @@ class ResNet(object):
feature_maps
=
[
2
,
3
,
4
,
5
],
dcn_v2_stages
=
[],
weight_prefix_name
=
''
,
nonlocal_stages
=
[]):
nonlocal_stages
=
[],
gcb_stages
=
[],
gcb_params
=
dict
()):
super
(
ResNet
,
self
).
__init__
()
if
isinstance
(
feature_maps
,
Integral
):
...
...
@@ -97,15 +104,18 @@ class ResNet(object):
self
.
_c1_out_chan_num
=
64
self
.
na
=
NameAdapter
(
self
)
self
.
prefix_name
=
weight_prefix_name
self
.
nonlocal_stages
=
nonlocal_stages
self
.
nonlocal_mod_cfg
=
{
50
:
2
,
101
:
5
,
152
:
8
,
200
:
12
,
50
:
2
,
101
:
5
,
152
:
8
,
200
:
12
,
}
self
.
gcb_stages
=
gcb_stages
self
.
gcb_params
=
gcb_params
def
_conv_offset
(
self
,
input
,
filter_size
,
...
...
@@ -257,7 +267,9 @@ class ResNet(object):
stride
,
is_first
,
name
,
dcn_v2
=
False
):
dcn_v2
=
False
,
gcb
=
False
,
gcb_name
=
None
):
if
self
.
variant
==
'a'
:
stride1
,
stride2
=
stride
,
1
else
:
...
...
@@ -309,6 +321,8 @@ class ResNet(object):
if
callable
(
getattr
(
self
,
'_squeeze_excitation'
,
None
)):
residual
=
self
.
_squeeze_excitation
(
input
=
residual
,
num_channels
=
num_filters
,
name
=
'fc'
+
name
)
if
gcb
:
residual
=
add_gc_block
(
residual
,
name
=
gcb_name
,
**
self
.
gcb_params
)
return
fluid
.
layers
.
elementwise_add
(
x
=
short
,
y
=
residual
,
act
=
'relu'
,
name
=
name
+
".add.output.5"
)
...
...
@@ -318,8 +332,11 @@ class ResNet(object):
stride
,
is_first
,
name
,
dcn_v2
=
False
):
dcn_v2
=
False
,
gcb
=
False
,
gcb_name
=
None
):
assert
dcn_v2
is
False
,
"Not implemented yet."
assert
gcb
is
False
,
"Not implemented yet."
conv0
=
self
.
_conv_norm
(
input
=
input
,
num_filters
=
num_filters
,
...
...
@@ -354,11 +371,12 @@ class ResNet(object):
ch_out
=
self
.
stage_filters
[
stage_num
-
2
]
is_first
=
False
if
stage_num
!=
2
else
True
dcn_v2
=
True
if
stage_num
in
self
.
dcn_v2_stages
else
False
nonlocal_mod
=
1000
if
stage_num
in
self
.
nonlocal_stages
:
nonlocal_mod
=
self
.
nonlocal_mod_cfg
[
self
.
depth
]
if
stage_num
==
4
else
2
nonlocal_mod
=
self
.
nonlocal_mod_cfg
[
self
.
depth
]
if
stage_num
==
4
else
2
# Make the layer name and parameter name consistent
# with ImageNet pre-trained model
conv
=
input
...
...
@@ -366,21 +384,26 @@ class ResNet(object):
conv_name
=
self
.
na
.
fix_layer_warp_name
(
stage_num
,
count
,
i
)
if
self
.
depth
<
50
:
is_first
=
True
if
i
==
0
and
stage_num
==
2
else
False
gcb
=
stage_num
in
self
.
gcb_stages
gcb_name
=
"gcb_res{}_b{}"
.
format
(
stage_num
,
i
)
conv
=
block_func
(
input
=
conv
,
num_filters
=
ch_out
,
stride
=
2
if
i
==
0
and
stage_num
!=
2
else
1
,
is_first
=
is_first
,
name
=
conv_name
,
dcn_v2
=
dcn_v2
)
dcn_v2
=
dcn_v2
,
gcb
=
gcb
,
gcb_name
=
gcb_name
)
# add non local model
dim_in
=
conv
.
shape
[
1
]
nonlocal_name
=
"nonlocal_conv{}"
.
format
(
stage_num
)
nonlocal_name
=
"nonlocal_conv{}"
.
format
(
stage_num
)
if
i
%
nonlocal_mod
==
nonlocal_mod
-
1
:
conv
=
add_space_nonlocal
(
conv
,
dim_in
,
dim_in
,
nonlocal_name
+
'_{}'
.
format
(
i
),
int
(
dim_in
/
2
)
)
conv
=
add_space_nonlocal
(
conv
,
dim_in
,
dim_in
,
nonlocal_name
+
'_{}'
.
format
(
i
)
,
int
(
dim_in
/
2
)
)
return
conv
def
c1_stage
(
self
,
input
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录