Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
866332ed
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 2 年 前同步成功
通知
708
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
866332ed
编写于
4月 15, 2021
作者:
W
wangguanzhong
提交者:
GitHub
4月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add ttfnet enhance (#2609)
* add ttfnet enhance * add doc * fix pafnet training
上级
ac701833
变更
21
隐藏空白更改
内联
并排
Showing
21 changed file
with
1142 addition
and
92 deletion
+1142
-92
configs/ttfnet/README.md
configs/ttfnet/README.md
+38
-1
configs/ttfnet/_base_/optimizer_10x.yml
configs/ttfnet/_base_/optimizer_10x.yml
+19
-0
configs/ttfnet/_base_/optimizer_20x.yml
configs/ttfnet/_base_/optimizer_20x.yml
+20
-0
configs/ttfnet/_base_/pafnet.yml
configs/ttfnet/_base_/pafnet.yml
+41
-0
configs/ttfnet/_base_/pafnet_lite.yml
configs/ttfnet/_base_/pafnet_lite.yml
+44
-0
configs/ttfnet/_base_/pafnet_lite_reader.yml
configs/ttfnet/_base_/pafnet_lite_reader.yml
+40
-0
configs/ttfnet/_base_/pafnet_reader.yml
configs/ttfnet/_base_/pafnet_reader.yml
+40
-0
configs/ttfnet/_base_/ttfnet_darknet53.yml
configs/ttfnet/_base_/ttfnet_darknet53.yml
+3
-2
configs/ttfnet/pafnet_10x_coco.yml
configs/ttfnet/pafnet_10x_coco.yml
+8
-0
configs/ttfnet/pafnet_lite_mobilenet_v3_20x_coco.yml
configs/ttfnet/pafnet_lite_mobilenet_v3_20x_coco.yml
+8
-0
ppdet/data/source/dataset.py
ppdet/data/source/dataset.py
+3
-0
ppdet/data/transform/batch_operators.py
ppdet/data/transform/batch_operators.py
+2
-0
ppdet/data/transform/gridmask_utils.py
ppdet/data/transform/gridmask_utils.py
+2
-2
ppdet/data/transform/operators.py
ppdet/data/transform/operators.py
+27
-16
ppdet/modeling/backbones/mobilenet_v3.py
ppdet/modeling/backbones/mobilenet_v3.py
+6
-6
ppdet/modeling/heads/ttf_head.py
ppdet/modeling/heads/ttf_head.py
+104
-29
ppdet/modeling/layers.py
ppdet/modeling/layers.py
+69
-3
ppdet/modeling/necks/ttf_fpn.py
ppdet/modeling/necks/ttf_fpn.py
+140
-27
static/configs/anchor_free/pafnet_10x_coco.yml
static/configs/anchor_free/pafnet_10x_coco.yml
+170
-0
static/configs/anchor_free/pafnet_lite_mobilenet_v3_20x_coco.yml
...configs/anchor_free/pafnet_lite_mobilenet_v3_20x_coco.yml
+171
-0
static/ppdet/modeling/anchor_heads/ttf_head.py
static/ppdet/modeling/anchor_heads/ttf_head.py
+187
-6
未找到文件。
configs/ttfnet/README.md
浏览文件 @
866332ed
# TTFNet
#
1.
TTFNet
## 简介
...
...
@@ -15,6 +15,43 @@ TTFNet是一种用于实时目标检测且对训练时间友好的网络,对Ce
| :-------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: |
| DarkNet53 | TTFNet | 12 | 1x | ---- | 33.5 |
[
下载链接
](
https://paddledet.bj.bcebos.com/models/ttfnet_darknet53_1x_coco.pdparams
)
|
[
配置文件
](
https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/ttfnet/ttfnet_darknet53_1x_coco.yml
)
|
# 2. PAFNet
## 简介
PAFNet(Paddle Anchor Free)是PaddleDetection基于TTFNet的优化模型,精度达到anchor free领域SOTA水平,同时产出移动端轻量级模型PAFNet-Lite
PAFNet系列模型从如下方面优化TTFNet模型:
-
[
CutMix
](
https://arxiv.org/abs/1905.04899
)
-
更优的骨干网络: ResNet50vd-DCN
-
更大的训练batch size: 8 GPUs,每GPU batch_size=18
-
Synchronized Batch Normalization
-
[
Deformable Convolution
](
https://arxiv.org/abs/1703.06211
)
-
[
Exponential Moving Average
](
https://www.investopedia.com/terms/e/ema.asp
)
-
更优的预训练模型
## 模型库
| 骨架网络 | 网络类型 | 每张GPU图片个数 | 学习率策略 |推理时间(fps) | Box AP | 下载 | 配置文件 |
| :-------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: |
| ResNet50vd | PAFNet | 18 | 10x | ---- | 42.2 |
[
下载链接
](
https://paddledet.bj.bcebos.com/models/pafnet_10x_coco.pdparams
)
|
[
配置文件
](
https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/ttfnet/pafnet_10x_coco.yml
)
|
### PAFNet-Lite
| 骨架网络 | 网络类型 | 每张GPU图片个数 | 学习率策略 | Box AP | 麒麟990延时(ms) | 体积(M) | 下载 | 配置文件 |
| :-------------- | :------------- | :-----: | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: |
| MobileNetv3 | PAFNet-Lite | 12 | 20x | 23.9 | 26.00 | 14 |
[
下载链接
](
https://paddledet.bj.bcebos.com/models/pafnet_lite_mobilenet_v3_20x_coco.pdparams
)
|
[
配置文件
](
https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/ttfnet/pafnet_lite_mobilenet_v3_20x_coco.yml
)
|
## Citations
```
@article{liu2019training,
...
...
configs/ttfnet/_base_/optimizer_10x.yml
0 → 100644
浏览文件 @
866332ed
epoch
:
120
LearningRate
:
base_lr
:
0.015
schedulers
:
-
!PiecewiseDecay
gamma
:
0.1
milestones
:
[
80
,
110
]
-
!LinearWarmup
start_factor
:
0.2
steps
:
500
OptimizerBuilder
:
optimizer
:
momentum
:
0.9
type
:
Momentum
regularizer
:
factor
:
0.0004
type
:
L2
configs/ttfnet/_base_/optimizer_20x.yml
0 → 100644
浏览文件 @
866332ed
epoch
:
240
LearningRate
:
base_lr
:
0.015
schedulers
:
-
!PiecewiseDecay
gamma
:
0.1
milestones
:
[
160
,
220
]
-
!LinearWarmup
start_factor
:
0.2
steps
:
1000
OptimizerBuilder
:
clip_grad_by_norm
:
35
optimizer
:
momentum
:
0.9
type
:
Momentum
regularizer
:
factor
:
0.0004
type
:
L2
configs/ttfnet/_base_/pafnet.yml
0 → 100644
浏览文件 @
866332ed
architecture
:
TTFNet
pretrain_weights
:
https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_ssld_pretrained.pdparams
norm_type
:
sync_bn
use_ema
:
true
ema_decay
:
0.9998
TTFNet
:
backbone
:
ResNet
neck
:
TTFFPN
ttf_head
:
TTFHead
post_process
:
BBoxPostProcess
ResNet
:
depth
:
50
variant
:
d
return_idx
:
[
0
,
1
,
2
,
3
]
freeze_at
:
-1
norm_decay
:
0.
variant
:
d
dcn_v2_stages
:
[
1
,
2
,
3
]
TTFFPN
:
planes
:
[
256
,
128
,
64
]
shortcut_num
:
[
3
,
2
,
1
]
TTFHead
:
dcn_head
:
true
hm_loss
:
name
:
CTFocalLoss
loss_weight
:
1.
wh_loss
:
name
:
GIoULoss
loss_weight
:
5.
reduction
:
sum
BBoxPostProcess
:
decode
:
name
:
TTFBox
max_per_img
:
100
score_thresh
:
0.01
down_ratio
:
4
configs/ttfnet/_base_/pafnet_lite.yml
0 → 100644
浏览文件 @
866332ed
architecture
:
TTFNet
pretrain_weights
:
https://paddledet.bj.bcebos.com/models/pretrained/MobileNetV3_large_x1_0_ssld_pretrained.pdparams
norm_type
:
sync_bn
TTFNet
:
backbone
:
MobileNetV3
neck
:
TTFFPN
ttf_head
:
TTFHead
post_process
:
BBoxPostProcess
MobileNetV3
:
scale
:
1.0
model_name
:
large
feature_maps
:
[
5
,
8
,
14
,
17
]
with_extra_blocks
:
true
lr_mult_list
:
[
0.25
,
0.25
,
0.5
,
0.5
,
0.75
]
conv_decay
:
0.00001
norm_decay
:
0.0
extra_block_filters
:
[]
TTFFPN
:
planes
:
[
96
,
48
,
24
]
shortcut_num
:
[
2
,
2
,
1
]
lite_neck
:
true
fusion_method
:
concat
TTFHead
:
hm_head_planes
:
48
wh_head_planes
:
24
lite_head
:
true
hm_loss
:
name
:
CTFocalLoss
loss_weight
:
1.
wh_loss
:
name
:
GIoULoss
loss_weight
:
5.
reduction
:
sum
BBoxPostProcess
:
decode
:
name
:
TTFBox
max_per_img
:
100
score_thresh
:
0.01
down_ratio
:
4
configs/ttfnet/_base_/pafnet_lite_reader.yml
0 → 100644
浏览文件 @
866332ed
worker_num
:
2
TrainReader
:
sample_transforms
:
-
Decode
:
{}
-
Cutmix
:
{
alpha
:
1.5
,
beta
:
1.5
}
-
RandomDistort
:
{}
-
RandomExpand
:
{
fill_value
:
[
123.675
,
116.28
,
103.53
]}
-
RandomCrop
:
{
aspect_ratio
:
NULL
,
cover_all_box
:
True
}
-
RandomFlip
:
{}
-
GridMask
:
{
upper_iter
:
300000
}
batch_transforms
:
-
BatchRandomResize
:
{
target_size
:
[
320
,
352
,
384
,
416
,
448
,
480
,
512
],
random_interp
:
True
,
keep_ratio
:
False
}
-
NormalizeImage
:
{
mean
:
[
123.675
,
116.28
,
103.53
],
std
:
[
58.395
,
57.12
,
57.375
],
is_scale
:
false
}
-
Permute
:
{}
-
Gt2TTFTarget
:
{
down_ratio
:
4
}
-
PadBatch
:
{
pad_to_stride
:
32
}
batch_size
:
12
shuffle
:
true
drop_last
:
true
use_shared_memory
:
true
EvalReader
:
sample_transforms
:
-
Decode
:
{}
-
Resize
:
{
interp
:
1
,
target_size
:
[
320
,
320
],
keep_ratio
:
False
}
-
NormalizeImage
:
{
is_scale
:
false
,
mean
:
[
123.675
,
116.28
,
103.53
],
std
:
[
58.395
,
57.12
,
57.375
]}
-
Permute
:
{}
batch_size
:
1
drop_last
:
false
drop_empty
:
false
TestReader
:
sample_transforms
:
-
Decode
:
{}
-
Resize
:
{
interp
:
1
,
target_size
:
[
320
,
320
],
keep_ratio
:
False
}
-
NormalizeImage
:
{
is_scale
:
false
,
mean
:
[
123.675
,
116.28
,
103.53
],
std
:
[
58.395
,
57.12
,
57.375
]}
-
Permute
:
{}
batch_size
:
1
drop_last
:
false
drop_empty
:
false
configs/ttfnet/_base_/pafnet_reader.yml
0 → 100644
浏览文件 @
866332ed
worker_num
:
2
TrainReader
:
sample_transforms
:
-
Decode
:
{}
-
Cutmix
:
{
alpha
:
1.5
,
beta
:
1.5
}
-
RandomDistort
:
{
random_apply
:
false
,
random_channel
:
true
}
-
RandomExpand
:
{
fill_value
:
[
123.675
,
116.28
,
103.53
]}
-
RandomCrop
:
{
aspect_ratio
:
NULL
,
cover_all_box
:
True
}
-
RandomFlip
:
{
prob
:
0.5
}
batch_transforms
:
-
BatchRandomResize
:
{
target_size
:
[
416
,
448
,
480
,
512
,
544
,
576
,
608
,
640
,
672
],
keep_ratio
:
false
}
-
NormalizeImage
:
{
mean
:
[
123.675
,
116.28
,
103.53
],
std
:
[
58.395
,
57.12
,
57.375
],
is_scale
:
false
}
-
Permute
:
{}
-
Gt2TTFTarget
:
{
down_ratio
:
4
}
-
PadBatch
:
{
pad_to_stride
:
32
}
batch_size
:
18
shuffle
:
true
drop_last
:
true
use_shared_memory
:
true
mixup_epoch
:
100
EvalReader
:
sample_transforms
:
-
Decode
:
{}
-
Resize
:
{
interp
:
1
,
target_size
:
[
512
,
512
],
keep_ratio
:
False
}
-
NormalizeImage
:
{
is_scale
:
false
,
mean
:
[
123.675
,
116.28
,
103.53
],
std
:
[
58.395
,
57.12
,
57.375
]}
-
Permute
:
{}
batch_size
:
1
drop_last
:
false
drop_empty
:
false
TestReader
:
sample_transforms
:
-
Decode
:
{}
-
Resize
:
{
interp
:
1
,
target_size
:
[
512
,
512
],
keep_ratio
:
False
}
-
NormalizeImage
:
{
is_scale
:
false
,
mean
:
[
123.675
,
116.28
,
103.53
],
std
:
[
58.395
,
57.12
,
57.375
]}
-
Permute
:
{}
batch_size
:
1
drop_last
:
false
drop_empty
:
false
configs/ttfnet/_base_/ttfnet_darknet53.yml
浏览文件 @
866332ed
...
...
@@ -14,8 +14,9 @@ DarkNet:
norm_type
:
bn
norm_decay
:
0.0004
# use default config
# TTFFPN:
TTFFPN
:
planes
:
[
256
,
128
,
64
]
shortcut_num
:
[
3
,
2
,
1
]
TTFHead
:
hm_loss
:
...
...
configs/ttfnet/pafnet_10x_coco.yml
0 → 100644
浏览文件 @
866332ed
_BASE_
:
[
'
../datasets/coco_detection.yml'
,
'
../runtime.yml'
,
'
_base_/optimizer_10x.yml'
,
'
_base_/pafnet.yml'
,
'
_base_/pafnet_reader.yml'
,
]
weights
:
output/pafnet_10x_coco/model_final
configs/ttfnet/pafnet_lite_mobilenet_v3_20x_coco.yml
0 → 100644
浏览文件 @
866332ed
_BASE_
:
[
'
../datasets/coco_detection.yml'
,
'
../runtime.yml'
,
'
_base_/optimizer_20x.yml'
,
'
_base_/pafnet_lite.yml'
,
'
_base_/pafnet_lite_reader.yml'
,
]
weights
:
output/pafnet_lite_mobilenet_v3_10x_coco/model_final
ppdet/data/source/dataset.py
浏览文件 @
866332ed
...
...
@@ -55,6 +55,7 @@ class DetDataset(Dataset):
self
.
sample_num
=
sample_num
self
.
use_default_label
=
use_default_label
self
.
_epoch
=
0
self
.
_curr_iter
=
0
def
__len__
(
self
,
):
return
len
(
self
.
roidbs
)
...
...
@@ -76,6 +77,8 @@ class DetDataset(Dataset):
copy
.
deepcopy
(
self
.
roidbs
[
np
.
random
.
randint
(
n
)])
for
_
in
range
(
3
)
]
roidb
[
'curr_iter'
]
=
self
.
_curr_iter
self
.
_curr_iter
+=
1
return
self
.
transform
(
roidb
)
...
...
ppdet/data/transform/batch_operators.py
浏览文件 @
866332ed
...
...
@@ -533,6 +533,8 @@ class Gt2TTFTarget(BaseOperator):
sample
.
pop
(
'is_crowd'
)
sample
.
pop
(
'gt_class'
)
sample
.
pop
(
'gt_bbox'
)
if
'gt_score'
in
sample
:
sample
.
pop
(
'gt_score'
)
return
samples
def
draw_truncate_gaussian
(
self
,
heatmap
,
center
,
h_radius
,
w_radius
):
...
...
ppdet/data/transform/gridmask_utils.py
浏览文件 @
866332ed
...
...
@@ -20,7 +20,7 @@ import numpy as np
from
PIL
import
Image
class
Grid
M
ask
(
object
):
class
Grid
m
ask
(
object
):
def
__init__
(
self
,
use_h
=
True
,
use_w
=
True
,
...
...
@@ -30,7 +30,7 @@ class GridMask(object):
mode
=
1
,
prob
=
0.7
,
upper_iter
=
360000
):
super
(
Grid
M
ask
,
self
).
__init__
()
super
(
Grid
m
ask
,
self
).
__init__
()
self
.
use_h
=
use_h
self
.
use_w
=
use_w
self
.
rotate
=
rotate
...
...
ppdet/data/transform/operators.py
浏览文件 @
866332ed
...
...
@@ -308,8 +308,8 @@ class GridMask(BaseOperator):
self
.
prob
=
prob
self
.
upper_iter
=
upper_iter
from
.gridmask_utils
import
Grid
M
ask
self
.
gridmask_op
=
Grid
M
ask
(
from
.gridmask_utils
import
Grid
m
ask
self
.
gridmask_op
=
Grid
m
ask
(
use_h
,
use_w
,
rotate
=
rotate
,
...
...
@@ -1516,14 +1516,14 @@ class Cutmix(BaseOperator):
bbx2
=
np
.
clip
(
cx
+
cut_w
//
2
,
0
,
w
-
1
)
bby2
=
np
.
clip
(
cy
+
cut_h
//
2
,
0
,
h
-
1
)
img_1
=
np
.
zeros
((
h
,
w
,
img1
.
shape
[
2
]),
'float32'
)
img_1
[:
img1
.
shape
[
0
],
:
img1
.
shape
[
1
],
:]
=
\
img_1
_pad
=
np
.
zeros
((
h
,
w
,
img1
.
shape
[
2
]),
'float32'
)
img_1
_pad
[:
img1
.
shape
[
0
],
:
img1
.
shape
[
1
],
:]
=
\
img1
.
astype
(
'float32'
)
img_2
=
np
.
zeros
((
h
,
w
,
img2
.
shape
[
2
]),
'float32'
)
img_2
[:
img2
.
shape
[
0
],
:
img2
.
shape
[
1
],
:]
=
\
img_2
_pad
=
np
.
zeros
((
h
,
w
,
img2
.
shape
[
2
]),
'float32'
)
img_2
_pad
[:
img2
.
shape
[
0
],
:
img2
.
shape
[
1
],
:]
=
\
img2
.
astype
(
'float32'
)
img_1
[
bby1
:
bby2
,
bbx1
:
bbx2
,
:]
=
img2
[
bby1
:
bby2
,
bbx1
:
bbx2
,
:]
return
img_1
img_1
_pad
[
bby1
:
bby2
,
bbx1
:
bbx2
,
:]
=
img_2_pad
[
bby1
:
bby2
,
bbx1
:
bbx2
,
:]
return
img_1
_pad
def
__call__
(
self
,
sample
,
context
=
None
):
if
not
isinstance
(
sample
,
Sequence
):
...
...
@@ -1546,16 +1546,27 @@ class Cutmix(BaseOperator):
gt_class1
=
sample
[
0
][
'gt_class'
]
gt_class2
=
sample
[
1
][
'gt_class'
]
gt_class
=
np
.
concatenate
((
gt_class1
,
gt_class2
),
axis
=
0
)
gt_score1
=
sample
[
0
][
'gt_score'
]
gt_score2
=
sample
[
1
][
'gt_score'
]
gt_score1
=
np
.
ones_like
(
sample
[
0
][
'gt_class'
])
gt_score2
=
np
.
ones_like
(
sample
[
1
][
'gt_class'
])
gt_score
=
np
.
concatenate
(
(
gt_score1
*
factor
,
gt_score2
*
(
1.
-
factor
)),
axis
=
0
)
sample
=
sample
[
0
]
sample
[
'image'
]
=
img
sample
[
'gt_bbox'
]
=
gt_bbox
sample
[
'gt_score'
]
=
gt_score
sample
[
'gt_class'
]
=
gt_class
return
sample
result
=
copy
.
deepcopy
(
sample
[
0
])
result
[
'image'
]
=
img
result
[
'gt_bbox'
]
=
gt_bbox
result
[
'gt_score'
]
=
gt_score
result
[
'gt_class'
]
=
gt_class
if
'is_crowd'
in
sample
[
0
]:
is_crowd1
=
sample
[
0
][
'is_crowd'
]
is_crowd2
=
sample
[
1
][
'is_crowd'
]
is_crowd
=
np
.
concatenate
((
is_crowd1
,
is_crowd2
),
axis
=
0
)
result
[
'is_crowd'
]
=
is_crowd
if
'difficult'
in
sample
[
0
]:
is_difficult1
=
sample
[
0
][
'difficult'
]
is_difficult2
=
sample
[
1
][
'difficult'
]
is_difficult
=
np
.
concatenate
(
(
is_difficult1
,
is_difficult2
),
axis
=
0
)
result
[
'difficult'
]
=
is_difficult
return
result
@
register_op
...
...
ppdet/modeling/backbones/mobilenet_v3.py
浏览文件 @
866332ed
...
...
@@ -330,16 +330,16 @@ class MobileNetV3(nn.Layer):
[
3
,
16
,
16
,
False
,
"relu"
,
1
],
[
3
,
64
,
24
,
False
,
"relu"
,
2
],
[
3
,
72
,
24
,
False
,
"relu"
,
1
],
[
5
,
72
,
40
,
True
,
"relu"
,
2
],
[
5
,
72
,
40
,
True
,
"relu"
,
2
],
# RCNN output
[
5
,
120
,
40
,
True
,
"relu"
,
1
],
[
5
,
120
,
40
,
True
,
"relu"
,
1
],
# YOLOv3 output
[
3
,
240
,
80
,
False
,
"hard_swish"
,
2
],
[
3
,
240
,
80
,
False
,
"hard_swish"
,
2
],
# RCNN output
[
3
,
200
,
80
,
False
,
"hard_swish"
,
1
],
[
3
,
184
,
80
,
False
,
"hard_swish"
,
1
],
[
3
,
184
,
80
,
False
,
"hard_swish"
,
1
],
[
3
,
480
,
112
,
True
,
"hard_swish"
,
1
],
[
3
,
672
,
112
,
True
,
"hard_swish"
,
1
],
# YOLOv3 output
[
5
,
672
,
160
,
True
,
"hard_swish"
,
2
],
# SSD/SSDLite output
[
5
,
672
,
160
,
True
,
"hard_swish"
,
2
],
# SSD/SSDLite
/RCNN
output
[
5
,
960
,
160
,
True
,
"hard_swish"
,
1
],
[
5
,
960
,
160
,
True
,
"hard_swish"
,
1
],
# YOLOv3 output
]
...
...
@@ -347,14 +347,14 @@ class MobileNetV3(nn.Layer):
self
.
cfg
=
[
# k, exp, c, se, nl, s,
[
3
,
16
,
16
,
True
,
"relu"
,
2
],
[
3
,
72
,
24
,
False
,
"relu"
,
2
],
[
3
,
72
,
24
,
False
,
"relu"
,
2
],
# RCNN output
[
3
,
88
,
24
,
False
,
"relu"
,
1
],
# YOLOv3 output
[
5
,
96
,
40
,
True
,
"hard_swish"
,
2
],
[
5
,
96
,
40
,
True
,
"hard_swish"
,
2
],
# RCNN output
[
5
,
240
,
40
,
True
,
"hard_swish"
,
1
],
[
5
,
240
,
40
,
True
,
"hard_swish"
,
1
],
[
5
,
120
,
48
,
True
,
"hard_swish"
,
1
],
[
5
,
144
,
48
,
True
,
"hard_swish"
,
1
],
# YOLOv3 output
[
5
,
288
,
96
,
True
,
"hard_swish"
,
2
],
# SSD/SSDLite output
[
5
,
288
,
96
,
True
,
"hard_swish"
,
2
],
# SSD/SSDLite
/RCNN
output
[
5
,
576
,
96
,
True
,
"hard_swish"
,
1
],
[
5
,
576
,
96
,
True
,
"hard_swish"
,
1
],
# YOLOv3 output
]
...
...
ppdet/modeling/heads/ttf_head.py
浏览文件 @
866332ed
...
...
@@ -19,6 +19,7 @@ from paddle import ParamAttr
from
paddle.nn.initializer
import
Constant
,
Uniform
,
Normal
from
paddle.regularizer
import
L2Decay
from
ppdet.core.workspace
import
register
from
ppdet.modeling.layers
import
DeformableConvV2
,
LiteConv
import
numpy
as
np
...
...
@@ -30,27 +31,61 @@ class HMHead(nn.Layer):
ch_out (int): The channel number of output Tensor.
num_classes (int): Number of classes.
conv_num (int): The convolution number of hm_feat.
dcn_head(bool): whether use dcn in head. False by default.
lite_head(bool): whether use lite version. False by default.
norm_type (string): norm type, 'sync_bn', 'bn', 'gn' are optional.
bn by default
Return:
Heatmap head output
"""
__shared__
=
[
'num_classes'
]
__shared__
=
[
'num_classes'
,
'norm_type'
]
def
__init__
(
self
,
ch_in
,
ch_out
=
128
,
num_classes
=
80
,
conv_num
=
2
):
def
__init__
(
self
,
ch_in
,
ch_out
=
128
,
num_classes
=
80
,
conv_num
=
2
,
dcn_head
=
False
,
lite_head
=
False
,
norm_type
=
'bn'
,
):
super
(
HMHead
,
self
).
__init__
()
head_conv
=
nn
.
Sequential
()
for
i
in
range
(
conv_num
):
name
=
'conv.{}'
.
format
(
i
)
head_conv
.
add_sublayer
(
name
,
nn
.
Conv2D
(
in_channels
=
ch_in
if
i
==
0
else
ch_out
,
out_channels
=
ch_out
,
kernel_size
=
3
,
padding
=
1
,
weight_attr
=
ParamAttr
(
initializer
=
Normal
(
0
,
0.01
)),
bias_attr
=
ParamAttr
(
learning_rate
=
2.
,
regularizer
=
L2Decay
(
0.
))))
head_conv
.
add_sublayer
(
name
+
'.act'
,
nn
.
ReLU
())
if
lite_head
:
lite_name
=
'hm.'
+
name
head_conv
.
add_sublayer
(
lite_name
,
LiteConv
(
in_channels
=
ch_in
if
i
==
0
else
ch_out
,
out_channels
=
ch_out
,
norm_type
=
norm_type
,
name
=
lite_name
))
head_conv
.
add_sublayer
(
lite_name
+
'.act'
,
nn
.
ReLU6
())
else
:
if
dcn_head
:
head_conv
.
add_sublayer
(
name
,
DeformableConvV2
(
in_channels
=
ch_in
if
i
==
0
else
ch_out
,
out_channels
=
ch_out
,
kernel_size
=
3
,
weight_attr
=
ParamAttr
(
initializer
=
Normal
(
0
,
0.01
)),
name
=
'hm.'
+
name
))
else
:
head_conv
.
add_sublayer
(
name
,
nn
.
Conv2D
(
in_channels
=
ch_in
if
i
==
0
else
ch_out
,
out_channels
=
ch_out
,
kernel_size
=
3
,
padding
=
1
,
weight_attr
=
ParamAttr
(
initializer
=
Normal
(
0
,
0.01
)),
bias_attr
=
ParamAttr
(
learning_rate
=
2.
,
regularizer
=
L2Decay
(
0.
))))
head_conv
.
add_sublayer
(
name
+
'.act'
,
nn
.
ReLU
())
self
.
feat
=
self
.
add_sublayer
(
'hm_feat'
,
head_conv
)
bias_init
=
float
(
-
np
.
log
((
1
-
0.01
)
/
0.01
))
self
.
head
=
self
.
add_sublayer
(
...
...
@@ -78,26 +113,59 @@ class WHHead(nn.Layer):
ch_in (int): The channel number of input Tensor.
ch_out (int): The channel number of output Tensor.
conv_num (int): The convolution number of wh_feat.
dcn_head(bool): whether use dcn in head. False by default.
lite_head(bool): whether use lite version. False by default.
norm_type (string): norm type, 'sync_bn', 'bn', 'gn' are optional.
bn by default
Return:
Width & Height head output
"""
__shared__
=
[
'norm_type'
]
def
__init__
(
self
,
ch_in
,
ch_out
=
64
,
conv_num
=
2
):
def
__init__
(
self
,
ch_in
,
ch_out
=
64
,
conv_num
=
2
,
dcn_head
=
False
,
lite_head
=
False
,
norm_type
=
'bn'
):
super
(
WHHead
,
self
).
__init__
()
head_conv
=
nn
.
Sequential
()
for
i
in
range
(
conv_num
):
name
=
'conv.{}'
.
format
(
i
)
head_conv
.
add_sublayer
(
name
,
nn
.
Conv2D
(
in_channels
=
ch_in
if
i
==
0
else
ch_out
,
out_channels
=
ch_out
,
kernel_size
=
3
,
padding
=
1
,
weight_attr
=
ParamAttr
(
initializer
=
Normal
(
0
,
0.001
)),
bias_attr
=
ParamAttr
(
learning_rate
=
2.
,
regularizer
=
L2Decay
(
0.
))))
head_conv
.
add_sublayer
(
name
+
'.act'
,
nn
.
ReLU
())
if
lite_head
:
lite_name
=
'wh.'
+
name
head_conv
.
add_sublayer
(
lite_name
,
LiteConv
(
in_channels
=
ch_in
if
i
==
0
else
ch_out
,
out_channels
=
ch_out
,
norm_type
=
norm_type
,
name
=
lite_name
))
head_conv
.
add_sublayer
(
lite_name
+
'.act'
,
nn
.
ReLU6
())
else
:
if
dcn_head
:
head_conv
.
add_sublayer
(
name
,
DeformableConvV2
(
in_channels
=
ch_in
if
i
==
0
else
ch_out
,
out_channels
=
ch_out
,
kernel_size
=
3
,
weight_attr
=
ParamAttr
(
initializer
=
Normal
(
0
,
0.01
)),
name
=
'wh.'
+
name
))
else
:
head_conv
.
add_sublayer
(
name
,
nn
.
Conv2D
(
in_channels
=
ch_in
if
i
==
0
else
ch_out
,
out_channels
=
ch_out
,
kernel_size
=
3
,
padding
=
1
,
weight_attr
=
ParamAttr
(
initializer
=
Normal
(
0
,
0.01
)),
bias_attr
=
ParamAttr
(
learning_rate
=
2.
,
regularizer
=
L2Decay
(
0.
))))
head_conv
.
add_sublayer
(
name
+
'.act'
,
nn
.
ReLU
())
self
.
feat
=
self
.
add_sublayer
(
'wh_feat'
,
head_conv
)
self
.
head
=
self
.
add_sublayer
(
'wh_head'
,
...
...
@@ -137,9 +205,12 @@ class TTFHead(nn.Layer):
16.0 by default.
down_ratio (int): the actual down_ratio is calculated by base_down_ratio
(default 16) and the number of upsample layers.
lite_head(bool): whether use lite version. False by default.
norm_type (string): norm type, 'sync_bn', 'bn', 'gn' are optional.
bn by default
"""
__shared__
=
[
'num_classes'
,
'down_ratio'
]
__shared__
=
[
'num_classes'
,
'down_ratio'
,
'norm_type'
]
__inject__
=
[
'hm_loss'
,
'wh_loss'
]
def
__init__
(
self
,
...
...
@@ -152,12 +223,16 @@ class TTFHead(nn.Layer):
hm_loss
=
'CTFocalLoss'
,
wh_loss
=
'GIoULoss'
,
wh_offset_base
=
16.
,
down_ratio
=
4
):
down_ratio
=
4
,
dcn_head
=
False
,
lite_head
=
False
,
norm_type
=
'bn'
):
super
(
TTFHead
,
self
).
__init__
()
self
.
in_channels
=
in_channels
self
.
hm_head
=
HMHead
(
in_channels
,
hm_head_planes
,
num_classes
,
hm_head_conv_num
)
self
.
wh_head
=
WHHead
(
in_channels
,
wh_head_planes
,
wh_head_conv_num
)
hm_head_conv_num
,
dcn_head
,
lite_head
,
norm_type
)
self
.
wh_head
=
WHHead
(
in_channels
,
wh_head_planes
,
wh_head_conv_num
,
dcn_head
,
lite_head
,
norm_type
)
self
.
hm_loss
=
hm_loss
self
.
wh_loss
=
wh_loss
...
...
ppdet/modeling/layers.py
浏览文件 @
866332ed
...
...
@@ -23,7 +23,7 @@ from paddle import ParamAttr
from
paddle
import
to_tensor
from
paddle.nn
import
Conv2D
,
BatchNorm2D
,
GroupNorm
import
paddle.nn.functional
as
F
from
paddle.nn.initializer
import
Normal
,
Constant
from
paddle.nn.initializer
import
Normal
,
Constant
,
XavierUniform
from
paddle.regularizer
import
L2Decay
from
ppdet.core.workspace
import
register
,
serializable
...
...
@@ -112,6 +112,7 @@ class ConvNormLayer(nn.Layer):
ch_out
,
filter_size
,
stride
,
groups
=
1
,
norm_type
=
'bn'
,
norm_decay
=
0.
,
norm_groups
=
32
,
...
...
@@ -142,7 +143,7 @@ class ConvNormLayer(nn.Layer):
kernel_size
=
filter_size
,
stride
=
stride
,
padding
=
(
filter_size
-
1
)
//
2
,
groups
=
1
,
groups
=
groups
,
weight_attr
=
ParamAttr
(
name
=
name
+
"_weight"
,
initializer
=
initializer
,
...
...
@@ -158,7 +159,7 @@ class ConvNormLayer(nn.Layer):
kernel_size
=
filter_size
,
stride
=
stride
,
padding
=
(
filter_size
-
1
)
//
2
,
groups
=
1
,
groups
=
groups
,
weight_attr
=
ParamAttr
(
name
=
name
+
"_weight"
,
initializer
=
initializer
,
...
...
@@ -197,6 +198,71 @@ class ConvNormLayer(nn.Layer):
return
out
class
LiteConv
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
stride
=
1
,
with_act
=
True
,
norm_type
=
'sync_bn'
,
name
=
None
):
super
(
LiteConv
,
self
).
__init__
()
self
.
lite_conv
=
nn
.
Sequential
()
conv1
=
ConvNormLayer
(
in_channels
,
in_channels
,
filter_size
=
5
,
stride
=
stride
,
groups
=
in_channels
,
norm_type
=
norm_type
,
initializer
=
XavierUniform
(),
norm_name
=
name
+
'.conv1.norm'
,
name
=
name
+
'.conv1'
)
conv2
=
ConvNormLayer
(
in_channels
,
out_channels
,
filter_size
=
1
,
stride
=
stride
,
norm_type
=
norm_type
,
initializer
=
XavierUniform
(),
norm_name
=
name
+
'.conv2.norm'
,
name
=
name
+
'.conv2'
)
conv3
=
ConvNormLayer
(
out_channels
,
out_channels
,
filter_size
=
1
,
stride
=
stride
,
norm_type
=
norm_type
,
initializer
=
XavierUniform
(),
norm_name
=
name
+
'.conv3.norm'
,
name
=
name
+
'.conv3'
)
conv4
=
ConvNormLayer
(
out_channels
,
out_channels
,
filter_size
=
5
,
stride
=
stride
,
groups
=
out_channels
,
norm_type
=
norm_type
,
initializer
=
XavierUniform
(),
norm_name
=
name
+
'.conv4.norm'
,
name
=
name
+
'.conv4'
)
conv_list
=
[
conv1
,
conv2
,
conv3
,
conv4
]
self
.
lite_conv
.
add_sublayer
(
'conv1'
,
conv1
)
self
.
lite_conv
.
add_sublayer
(
'relu6_1'
,
nn
.
ReLU6
())
self
.
lite_conv
.
add_sublayer
(
'conv2'
,
conv2
)
if
with_act
:
self
.
lite_conv
.
add_sublayer
(
'relu6_2'
,
nn
.
ReLU6
())
self
.
lite_conv
.
add_sublayer
(
'conv3'
,
conv3
)
self
.
lite_conv
.
add_sublayer
(
'relu6_3'
,
nn
.
ReLU6
())
self
.
lite_conv
.
add_sublayer
(
'conv4'
,
conv4
)
if
with_act
:
self
.
lite_conv
.
add_sublayer
(
'relu6_4'
,
nn
.
ReLU6
())
def
forward
(
self
,
inputs
):
out
=
self
.
lite_conv
(
inputs
)
return
out
@
register
@
serializable
class
AnchorGeneratorRPN
(
object
):
...
...
ppdet/modeling/necks/ttf_fpn.py
浏览文件 @
866332ed
...
...
@@ -16,11 +16,11 @@ import paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle
import
ParamAttr
from
paddle.nn.initializer
import
Constant
,
Uniform
,
Normal
from
paddle.nn.initializer
import
Constant
,
Uniform
,
Normal
,
XavierUniform
from
paddle
import
ParamAttr
from
ppdet.core.workspace
import
register
,
serializable
from
paddle.regularizer
import
L2Decay
from
ppdet.modeling.layers
import
DeformableConvV2
from
ppdet.modeling.layers
import
DeformableConvV2
,
ConvNormLayer
,
LiteConv
import
math
from
ppdet.modeling.ops
import
batch_norm
from
..shape_spec
import
ShapeSpec
...
...
@@ -29,7 +29,7 @@ __all__ = ['TTFFPN']
class
Upsample
(
nn
.
Layer
):
def
__init__
(
self
,
ch_in
,
ch_out
,
name
=
None
):
def
__init__
(
self
,
ch_in
,
ch_out
,
n
orm_type
=
'bn'
,
n
ame
=
None
):
super
(
Upsample
,
self
).
__init__
()
fan_in
=
ch_in
*
3
*
3
stdv
=
1.
/
math
.
sqrt
(
fan_in
)
...
...
@@ -46,7 +46,7 @@ class Upsample(nn.Layer):
regularizer
=
L2Decay
(
0.
))
self
.
bn
=
batch_norm
(
ch_out
,
norm_type
=
'bn'
,
initializer
=
Constant
(
1.
),
name
=
name
)
ch_out
,
norm_type
=
norm_type
,
initializer
=
Constant
(
1.
),
name
=
name
)
def
forward
(
self
,
feat
):
dcn
=
self
.
dcn
(
feat
)
...
...
@@ -56,28 +56,105 @@ class Upsample(nn.Layer):
return
out
class
DeConv
(
nn
.
Layer
):
def
__init__
(
self
,
ch_in
,
ch_out
,
norm_type
=
'bn'
,
name
=
None
):
super
(
DeConv
,
self
).
__init__
()
self
.
deconv
=
nn
.
Sequential
()
conv1
=
ConvNormLayer
(
ch_in
=
ch_in
,
ch_out
=
ch_out
,
stride
=
1
,
filter_size
=
1
,
norm_type
=
norm_type
,
initializer
=
XavierUniform
(),
norm_name
=
name
+
'.conv1.norm'
,
name
=
name
+
'.conv1'
)
conv2
=
nn
.
Conv2DTranspose
(
in_channels
=
ch_out
,
out_channels
=
ch_out
,
kernel_size
=
4
,
padding
=
1
,
stride
=
2
,
groups
=
ch_out
,
weight_attr
=
ParamAttr
(
initializer
=
XavierUniform
()),
bias_attr
=
False
)
bn
=
batch_norm
(
ch_out
,
norm_type
=
norm_type
,
norm_decay
=
0.
,
name
=
name
+
'.bn'
)
conv3
=
ConvNormLayer
(
ch_in
=
ch_out
,
ch_out
=
ch_out
,
stride
=
1
,
filter_size
=
1
,
norm_type
=
norm_type
,
initializer
=
XavierUniform
(),
norm_name
=
name
+
'.conv3.norm'
,
name
=
name
+
'.conv3'
)
self
.
deconv
.
add_sublayer
(
'conv1'
,
conv1
)
self
.
deconv
.
add_sublayer
(
'relu6_1'
,
nn
.
ReLU6
())
self
.
deconv
.
add_sublayer
(
'conv2'
,
conv2
)
self
.
deconv
.
add_sublayer
(
'bn'
,
bn
)
self
.
deconv
.
add_sublayer
(
'relu6_2'
,
nn
.
ReLU6
())
self
.
deconv
.
add_sublayer
(
'conv3'
,
conv3
)
self
.
deconv
.
add_sublayer
(
'relu6_3'
,
nn
.
ReLU6
())
def
forward
(
self
,
inputs
):
return
self
.
deconv
(
inputs
)
class
LiteUpsample
(
nn
.
Layer
):
def
__init__
(
self
,
ch_in
,
ch_out
,
norm_type
=
'bn'
,
name
=
None
):
super
(
LiteUpsample
,
self
).
__init__
()
self
.
deconv
=
DeConv
(
ch_in
,
ch_out
,
norm_type
=
norm_type
,
name
=
name
+
'.deconv'
)
self
.
conv
=
LiteConv
(
ch_in
,
ch_out
,
norm_type
=
norm_type
,
name
=
name
+
'.liteconv'
)
def
forward
(
self
,
inputs
):
deconv_up
=
self
.
deconv
(
inputs
)
conv
=
self
.
conv
(
inputs
)
interp_up
=
F
.
interpolate
(
conv
,
scale_factor
=
2.
,
mode
=
'bilinear'
)
return
deconv_up
+
interp_up
class
ShortCut
(
nn
.
Layer
):
def
__init__
(
self
,
layer_num
,
ch_out
,
name
=
None
):
def
__init__
(
self
,
layer_num
,
ch_in
,
ch_out
,
norm_type
=
'bn'
,
lite_neck
=
False
,
name
=
None
):
super
(
ShortCut
,
self
).
__init__
()
shortcut_conv
=
nn
.
Sequential
()
ch_in
=
ch_out
*
2
for
i
in
range
(
layer_num
):
fan_out
=
3
*
3
*
ch_out
std
=
math
.
sqrt
(
2.
/
fan_out
)
in_channels
=
ch_in
if
i
==
0
else
ch_out
shortcut_name
=
name
+
'.conv.{}'
.
format
(
i
)
shortcut_conv
.
add_sublayer
(
shortcut_name
,
nn
.
Conv2D
(
in_channels
=
in_channels
,
out_channels
=
ch_out
,
kernel_size
=
3
,
padding
=
1
,
weight_attr
=
ParamAttr
(
initializer
=
Normal
(
0
,
std
)),
bias_attr
=
ParamAttr
(
learning_rate
=
2.
,
regularizer
=
L2Decay
(
0.
))))
if
i
<
layer_num
-
1
:
shortcut_conv
.
add_sublayer
(
shortcut_name
+
'.act'
,
nn
.
ReLU
())
if
lite_neck
:
shortcut_conv
.
add_sublayer
(
shortcut_name
,
LiteConv
(
in_channels
=
in_channels
,
out_channels
=
ch_out
,
with_act
=
i
<
layer_num
-
1
,
norm_type
=
norm_type
,
name
=
shortcut_name
))
else
:
shortcut_conv
.
add_sublayer
(
shortcut_name
,
nn
.
Conv2D
(
in_channels
=
in_channels
,
out_channels
=
ch_out
,
kernel_size
=
3
,
padding
=
1
,
weight_attr
=
ParamAttr
(
initializer
=
Normal
(
0
,
std
)),
bias_attr
=
ParamAttr
(
learning_rate
=
2.
,
regularizer
=
L2Decay
(
0.
))))
if
i
<
layer_num
-
1
:
shortcut_conv
.
add_sublayer
(
shortcut_name
+
'.act'
,
nn
.
ReLU
())
self
.
shortcut
=
self
.
add_sublayer
(
'short'
,
shortcut_conv
)
def
forward
(
self
,
feat
):
...
...
@@ -93,35 +170,68 @@ class TTFFPN(nn.Layer):
in_channels (list): number of input feature channels from backbone.
[128,256,512,1024] by default, means the channels of DarkNet53
backbone return_idx [1,2,3,4].
planes (list): the number of output feature channels of FPN.
[256, 128, 64] by default
shortcut_num (list): the number of convolution layers in each shortcut.
[3,2,1] by default, means DarkNet53 backbone return_idx_1 has 3 convs
in its shortcut, return_idx_2 has 2 convs and return_idx_3 has 1 conv.
norm_type (string): norm type, 'sync_bn', 'bn', 'gn' are optional.
bn by default
lite_neck (bool): whether to use lite conv in TTFNet FPN,
False by default
fusion_method (string): the method to fusion upsample and lateral layer.
'add' and 'concat' are optional, add by default
"""
__shared__
=
[
'norm_type'
]
def
__init__
(
self
,
in_channels
=
[
128
,
256
,
512
,
1024
],
shortcut_num
=
[
3
,
2
,
1
]):
in_channels
,
planes
=
[
256
,
128
,
64
],
shortcut_num
=
[
3
,
2
,
1
],
norm_type
=
'bn'
,
lite_neck
=
False
,
fusion_method
=
'add'
):
super
(
TTFFPN
,
self
).
__init__
()
self
.
planes
=
[
c
//
2
for
c
in
in_channels
[:
-
1
]][::
-
1
]
self
.
planes
=
planes
self
.
shortcut_num
=
shortcut_num
[::
-
1
]
self
.
shortcut_len
=
len
(
shortcut_num
)
self
.
ch_in
=
in_channels
[::
-
1
]
self
.
fusion_method
=
fusion_method
self
.
upsample_list
=
[]
self
.
shortcut_list
=
[]
self
.
upper_list
=
[]
for
i
,
out_c
in
enumerate
(
self
.
planes
):
in_c
=
self
.
ch_in
[
i
]
if
i
==
0
else
self
.
ch_in
[
i
]
//
2
in_c
=
self
.
ch_in
[
i
]
if
i
==
0
else
self
.
upper_list
[
-
1
]
upsample_module
=
LiteUpsample
if
lite_neck
else
Upsample
upsample
=
self
.
add_sublayer
(
'upsample.'
+
str
(
i
),
Upsample
(
in_c
,
out_c
,
name
=
'upsample.'
+
str
(
i
)))
upsample_module
(
in_c
,
out_c
,
norm_type
=
norm_type
,
name
=
'deconv_layers.'
+
str
(
i
)))
self
.
upsample_list
.
append
(
upsample
)
if
i
<
self
.
shortcut_len
:
shortcut
=
self
.
add_sublayer
(
'shortcut.'
+
str
(
i
),
ShortCut
(
self
.
shortcut_num
[
i
],
out_c
,
name
=
'shortcut.'
+
str
(
i
)))
self
.
shortcut_num
[
i
],
self
.
ch_in
[
i
+
1
],
out_c
,
norm_type
=
norm_type
,
lite_neck
=
lite_neck
,
name
=
'shortcut.'
+
str
(
i
)))
self
.
shortcut_list
.
append
(
shortcut
)
if
self
.
fusion_method
==
'add'
:
upper_c
=
out_c
elif
self
.
fusion_method
==
'concat'
:
upper_c
=
out_c
*
2
else
:
raise
ValueError
(
'Illegal fusion method. Expected add or
\
concat, but received {}'
.
format
(
self
.
fusion_method
))
self
.
upper_list
.
append
(
upper_c
)
def
forward
(
self
,
inputs
):
feat
=
inputs
[
-
1
]
...
...
@@ -129,7 +239,10 @@ class TTFFPN(nn.Layer):
feat
=
self
.
upsample_list
[
i
](
feat
)
if
i
<
self
.
shortcut_len
:
shortcut
=
self
.
shortcut_list
[
i
](
inputs
[
-
i
-
2
])
feat
=
feat
+
shortcut
if
self
.
fusion_method
==
'add'
:
feat
=
feat
+
shortcut
else
:
feat
=
paddle
.
concat
([
feat
,
shortcut
],
axis
=
1
)
return
feat
@
classmethod
...
...
@@ -138,4 +251,4 @@ class TTFFPN(nn.Layer):
@
property
def
out_shape
(
self
):
return
[
ShapeSpec
(
channels
=
self
.
planes
[
-
1
],
)]
return
[
ShapeSpec
(
channels
=
self
.
upper_list
[
-
1
],
)]
static/configs/anchor_free/pafnet_10x_coco.yml
0 → 100644
浏览文件 @
866332ed
architecture
:
TTFNet
use_gpu
:
true
max_iters
:
150000
log_smooth_window
:
20
save_dir
:
output
snapshot_iter
:
10000
metric
:
COCO
pretrain_weights
:
https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
weights
:
output/pafnet_10x_coco/model_final
num_classes
:
80
use_ema
:
true
ema_decay
:
0.9998
TTFNet
:
backbone
:
ResNet
ttf_head
:
TTFHead
ResNet
:
norm_type
:
sync_bn
freeze_at
:
0
freeze_norm
:
false
norm_decay
:
0.
depth
:
50
feature_maps
:
[
2
,
3
,
4
,
5
]
variant
:
d
dcn_v2_stages
:
[
3
,
4
,
5
]
TTFHead
:
head_conv
:
128
wh_conv
:
64
hm_head_conv_num
:
2
wh_head_conv_num
:
2
wh_offset_base
:
16
wh_loss
:
GiouLoss
dcn_head
:
True
GiouLoss
:
loss_weight
:
5.
do_average
:
false
use_class_weight
:
false
LearningRate
:
base_lr
:
0.015
schedulers
:
-
!PiecewiseDecay
gamma
:
0.1
milestones
:
-
112500
-
137500
-
!LinearWarmup
start_factor
:
0.2
steps
:
500
OptimizerBuilder
:
optimizer
:
momentum
:
0.9
type
:
Momentum
regularizer
:
factor
:
0.0004
type
:
L2
TrainReader
:
inputs_def
:
fields
:
[
'
image'
,
'
ttf_heatmap'
,
'
ttf_box_target'
,
'
ttf_reg_weight'
]
dataset
:
!COCODataSet
image_dir
:
train2017
anno_path
:
annotations/instances_train2017.json
dataset_dir
:
dataset/coco
with_background
:
false
sample_transforms
:
-
!DecodeImage
to_rgb
:
true
with_cutmix
:
True
-
!CutmixImage
alpha
:
1.5
beta
:
1.5
-
!ColorDistort
hue
:
[
-18.
,
18.
,
0.5
]
saturation
:
[
0.5
,
1.5
,
0.5
]
contrast
:
[
0.5
,
1.5
,
0.5
]
brightness
:
[
-32.
,
32.
,
0.5
]
random_apply
:
False
hsv_format
:
True
random_channel
:
True
-
!RandomExpand
ratio
:
4
prob
:
0.5
fill_value
:
[
123.675
,
116.28
,
103.53
]
-
!RandomCrop
aspect_ratio
:
NULL
cover_all_box
:
True
-
!RandomFlipImage
prob
:
0.5
batch_transforms
:
-
!RandomShape
sizes
:
[
416
,
448
,
480
,
512
,
544
,
576
,
608
,
640
,
672
]
random_inter
:
True
resize_box
:
True
-
!NormalizeImage
is_channel_first
:
false
is_scale
:
false
mean
:
[
123.675
,
116.28
,
103.53
]
std
:
[
58.395
,
57.12
,
57.375
]
-
!Permute
to_bgr
:
false
channel_first
:
true
-
!Gt2TTFTarget
num_classes
:
80
down_ratio
:
4
-
!PadBatch
pad_to_stride
:
32
batch_size
:
12
shuffle
:
true
worker_num
:
8
bufsize
:
2
use_process
:
false
cutmix_epoch
:
100
EvalReader
:
inputs_def
:
image_shape
:
[
3
,
512
,
512
]
fields
:
[
'
image'
,
'
im_id'
,
'
scale_factor'
]
dataset
:
!COCODataSet
image_dir
:
val2017
anno_path
:
annotations/instances_val2017.json
dataset_dir
:
dataset/coco
with_background
:
false
sample_transforms
:
-
!DecodeImage
to_rgb
:
True
-
!Resize
target_dim
:
512
-
!NormalizeImage
mean
:
[
123.675
,
116.28
,
103.53
]
std
:
[
58.395
,
57.12
,
57.375
]
is_scale
:
false
is_channel_first
:
false
-
!Permute
to_bgr
:
false
channel_first
:
True
batch_size
:
1
drop_empty
:
false
worker_num
:
8
bufsize
:
16
TestReader
:
inputs_def
:
image_shape
:
[
3
,
512
,
512
]
fields
:
[
'
image'
,
'
im_id'
,
'
scale_factor'
]
dataset
:
!ImageFolder
anno_path
:
annotations/instances_val2017.json
with_background
:
false
sample_transforms
:
-
!DecodeImage
to_rgb
:
True
-
!Resize
interp
:
1
target_dim
:
512
-
!NormalizeImage
mean
:
[
123.675
,
116.28
,
103.53
]
std
:
[
58.395
,
57.12
,
57.375
]
is_scale
:
false
is_channel_first
:
false
-
!Permute
to_bgr
:
false
channel_first
:
True
batch_size
:
1
static/configs/anchor_free/pafnet_lite_mobilenet_v3_20x_coco.yml
0 → 100644
浏览文件 @
866332ed
architecture
:
TTFNet
use_gpu
:
true
max_iters
:
300000
log_smooth_window
:
20
save_dir
:
output
snapshot_iter
:
50000
metric
:
COCO
pretrain_weights
:
https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_ssld_pretrained.tar
weights
:
output/pafnet_lite_mobilenet_v3_20x_coco/model_final
num_classes
:
80
TTFNet
:
backbone
:
MobileNetV3RCNN
ttf_head
:
TTFLiteHead
MobileNetV3RCNN
:
norm_type
:
sync_bn
norm_decay
:
0.0
model_name
:
large
scale
:
1.0
conv_decay
:
0.00001
lr_mult_list
:
[
0.25
,
0.25
,
0.5
,
0.5
,
0.75
]
freeze_norm
:
false
TTFLiteHead
:
head_conv
:
48
GiouLoss
:
loss_weight
:
5.
do_average
:
false
use_class_weight
:
false
LearningRate
:
base_lr
:
0.015
schedulers
:
-
!PiecewiseDecay
gamma
:
0.1
milestones
:
-
225000
-
275000
-
!LinearWarmup
start_factor
:
0.2
steps
:
1000
OptimizerBuilder
:
clip_grad_by_norm
:
35
optimizer
:
momentum
:
0.9
type
:
Momentum
regularizer
:
factor
:
0.0004
type
:
L2
TrainReader
:
inputs_def
:
fields
:
[
'
image'
,
'
ttf_heatmap'
,
'
ttf_box_target'
,
'
ttf_reg_weight'
]
dataset
:
!COCODataSet
image_dir
:
train2017
anno_path
:
annotations/instances_train2017.json
dataset_dir
:
dataset/coco
with_background
:
false
sample_transforms
:
-
!DecodeImage
to_rgb
:
true
with_cutmix
:
True
-
!ColorDistort
hue
:
[
-18.
,
18.
,
0.5
]
saturation
:
[
0.5
,
1.5
,
0.5
]
contrast
:
[
0.5
,
1.5
,
0.5
]
brightness
:
[
-32.
,
32.
,
0.5
]
random_apply
:
False
hsv_format
:
False
random_channel
:
True
-
!RandomExpand
ratio
:
4
prob
:
0.5
fill_value
:
[
123.675
,
116.28
,
103.53
]
-
!RandomCrop
aspect_ratio
:
NULL
cover_all_box
:
True
-
!CutmixImage
alpha
:
1.5
beta
:
1.5
-
!RandomFlipImage
prob
:
0.5
-
!GridMaskOp
use_h
:
true
use_w
:
true
rotate
:
1
offset
:
false
ratio
:
0.5
mode
:
1
prob
:
0.7
upper_iter
:
300000
batch_transforms
:
-
!RandomShape
sizes
:
[
320
,
352
,
384
,
416
,
448
,
480
,
512
]
random_inter
:
True
resize_box
:
True
-
!NormalizeImage
is_channel_first
:
false
is_scale
:
false
mean
:
[
123.675
,
116.28
,
103.53
]
std
:
[
58.395
,
57.12
,
57.375
]
-
!Permute
to_bgr
:
false
channel_first
:
true
-
!Gt2TTFTarget
num_classes
:
80
down_ratio
:
4
-
!PadBatch
pad_to_stride
:
32
batch_size
:
12
shuffle
:
true
worker_num
:
8
bufsize
:
2
use_process
:
false
cutmix_epoch
:
200
EvalReader
:
inputs_def
:
image_shape
:
[
3
,
320
,
320
]
fields
:
[
'
image'
,
'
im_id'
,
'
scale_factor'
]
dataset
:
!COCODataSet
image_dir
:
val2017
anno_path
:
annotations/instances_val2017.json
dataset_dir
:
dataset/coco
with_background
:
false
sample_transforms
:
-
!DecodeImage
to_rgb
:
True
-
!Resize
target_dim
:
320
-
!NormalizeImage
mean
:
[
123.675
,
116.28
,
103.53
]
std
:
[
58.395
,
57.12
,
57.375
]
is_scale
:
false
is_channel_first
:
false
-
!Permute
to_bgr
:
false
channel_first
:
True
batch_size
:
1
drop_empty
:
false
worker_num
:
2
bufsize
:
2
TestReader
:
inputs_def
:
image_shape
:
[
3
,
320
,
320
]
fields
:
[
'
image'
,
'
im_id'
,
'
scale_factor'
]
dataset
:
!ImageFolder
anno_path
:
annotations/instances_val2017.json
with_background
:
false
sample_transforms
:
-
!DecodeImage
to_rgb
:
True
-
!Resize
interp
:
1
target_dim
:
320
-
!NormalizeImage
mean
:
[
123.675
,
116.28
,
103.53
]
std
:
[
58.395
,
57.12
,
57.375
]
is_scale
:
false
is_channel_first
:
false
-
!Permute
to_bgr
:
false
channel_first
:
True
batch_size
:
1
static/ppdet/modeling/anchor_heads/ttf_head.py
浏览文件 @
866332ed
...
...
@@ -24,10 +24,10 @@ from paddle.fluid.param_attr import ParamAttr
from
paddle.fluid.initializer
import
Normal
,
Constant
,
Uniform
,
Xavier
from
paddle.fluid.regularizer
import
L2Decay
from
ppdet.core.workspace
import
register
from
ppdet.modeling.ops
import
DeformConv
,
DropBlock
from
ppdet.modeling.ops
import
DeformConv
,
DropBlock
,
ConvNorm
from
ppdet.modeling.losses
import
GiouLoss
__all__
=
[
'TTFHead'
]
__all__
=
[
'TTFHead'
,
'TTFLiteHead'
]
@
register
...
...
@@ -65,6 +65,8 @@ class TTFHead(object):
drop_block(bool): whether use dropblock. False by default.
block_size(int): block_size parameter for drop_block. 3 by default.
keep_prob(float): keep_prob parameter for drop_block. 0.9 by default.
fusion_method (string): Method to fusion upsample and lateral branch.
'add' and 'concat' are optional, add by default
"""
__inject__
=
[
'wh_loss'
]
...
...
@@ -90,7 +92,8 @@ class TTFHead(object):
dcn_head
=
False
,
drop_block
=
False
,
block_size
=
3
,
keep_prob
=
0.9
):
keep_prob
=
0.9
,
fusion_method
=
'add'
):
super
(
TTFHead
,
self
).
__init__
()
self
.
head_conv
=
head_conv
self
.
num_classes
=
num_classes
...
...
@@ -115,6 +118,7 @@ class TTFHead(object):
self
.
drop_block
=
drop_block
self
.
block_size
=
block_size
self
.
keep_prob
=
keep_prob
self
.
fusion_method
=
fusion_method
def
shortcut
(
self
,
x
,
out_c
,
layer_num
,
kernel_size
=
3
,
padding
=
1
,
name
=
None
):
...
...
@@ -255,7 +259,14 @@ class TTFHead(object):
out_c
,
self
.
shortcut_num
[
i
],
name
=
name
+
'.shortcut_layers.'
+
str
(
i
))
feat
=
fluid
.
layers
.
elementwise_add
(
feat
,
shortcut
)
if
self
.
fusion_method
==
'add'
:
feat
=
fluid
.
layers
.
elementwise_add
(
feat
,
shortcut
)
elif
self
.
fusion_method
==
'concat'
:
feat
=
fluid
.
layers
.
concat
([
feat
,
shortcut
],
axis
=
1
)
else
:
raise
ValueError
(
"Illegal fusion method, expected 'add' or 'concat', but received {}"
.
format
(
self
.
fusion_method
))
hm
=
self
.
hm_head
(
feat
,
name
=
name
+
'.hm'
,
is_test
=
is_test
)
wh
=
self
.
wh_head
(
feat
,
name
=
name
+
'.wh'
)
*
self
.
wh_offset_base
...
...
@@ -273,12 +284,13 @@ class TTFHead(object):
# batch size is 1
scores_r
=
fluid
.
layers
.
reshape
(
scores
,
[
cat
,
-
1
])
topk_scores
,
topk_inds
=
fluid
.
layers
.
topk
(
scores_r
,
k
)
topk_ys
=
topk_inds
/
width
topk_ys
=
topk_inds
/
/
width
topk_xs
=
topk_inds
%
width
topk_score_r
=
fluid
.
layers
.
reshape
(
topk_scores
,
[
-
1
])
topk_score
,
topk_ind
=
fluid
.
layers
.
topk
(
topk_score_r
,
k
)
topk_clses
=
fluid
.
layers
.
cast
(
topk_ind
/
k
,
'float32'
)
k_t
=
fluid
.
layers
.
assign
(
np
.
array
([
k
],
dtype
=
'int64'
))
topk_clses
=
fluid
.
layers
.
cast
(
topk_ind
/
k_t
,
'float32'
)
topk_inds
=
fluid
.
layers
.
reshape
(
topk_inds
,
[
-
1
])
topk_ys
=
fluid
.
layers
.
reshape
(
topk_ys
,
[
-
1
,
1
])
...
...
@@ -384,3 +396,172 @@ class TTFHead(object):
ttf_loss
=
{
'hm_loss'
:
hm_loss
,
'wh_loss'
:
wh_loss
}
return
ttf_loss
@
register
class
TTFLiteHead
(
TTFHead
):
"""
TTFLiteHead
Lite version for TTFNet
Args:
head_conv(int): the default channel number of convolution in head.
32 by default.
num_classes(int): the number of classes, 80 by default.
planes(tuple): the channel number of convolution in each upsample.
(96, 48, 24) by default.
wh_conv(int): the channel number of convolution in wh head.
24 by default.
wh_loss(object): `GiouLoss` instance.
shortcut_num(tuple): the number of convolution layers in each shortcut.
(1, 2, 2) by default.
fusion_method (string): Method to fusion upsample and lateral branch.
'add' and 'concat' are optional, add by default
"""
__inject__
=
[
'wh_loss'
]
__shared__
=
[
'num_classes'
]
def
__init__
(
self
,
head_conv
=
32
,
num_classes
=
80
,
planes
=
(
96
,
48
,
24
),
wh_conv
=
24
,
wh_loss
=
'GiouLoss'
,
shortcut_num
=
(
1
,
2
,
2
),
fusion_method
=
'concat'
):
super
(
TTFLiteHead
,
self
).
__init__
(
head_conv
=
head_conv
,
num_classes
=
num_classes
,
planes
=
planes
,
wh_conv
=
wh_conv
,
wh_loss
=
wh_loss
,
shortcut_num
=
shortcut_num
,
fusion_method
=
fusion_method
)
def
_lite_conv
(
self
,
x
,
out_c
,
act
=
None
,
name
=
None
):
conv1
=
ConvNorm
(
input
=
x
,
num_filters
=
x
.
shape
[
1
],
filter_size
=
5
,
groups
=
x
.
shape
[
1
],
norm_type
=
'bn'
,
act
=
'relu6'
,
initializer
=
Xavier
(),
name
=
name
+
'.depthwise'
,
norm_name
=
name
+
'.depthwise.bn'
)
conv2
=
ConvNorm
(
input
=
conv1
,
num_filters
=
out_c
,
filter_size
=
1
,
norm_type
=
'bn'
,
act
=
act
,
initializer
=
Xavier
(),
name
=
name
+
'.pointwise_linear'
,
norm_name
=
name
+
'.pointwise_linear.bn'
)
conv3
=
ConvNorm
(
input
=
conv2
,
num_filters
=
out_c
,
filter_size
=
1
,
norm_type
=
'bn'
,
act
=
'relu6'
,
initializer
=
Xavier
(),
name
=
name
+
'.pointwise'
,
norm_name
=
name
+
'.pointwise.bn'
)
conv4
=
ConvNorm
(
input
=
conv3
,
num_filters
=
out_c
,
filter_size
=
5
,
groups
=
out_c
,
norm_type
=
'bn'
,
act
=
act
,
initializer
=
Xavier
(),
name
=
name
+
'.depthwise_linear'
,
norm_name
=
name
+
'.depthwise_linear.bn'
)
return
conv4
def
shortcut
(
self
,
x
,
out_c
,
layer_num
,
name
=
None
):
assert
layer_num
>
0
for
i
in
range
(
layer_num
):
param_name
=
name
+
'.layers.'
+
str
(
i
*
2
)
act
=
'relu6'
if
i
<
layer_num
-
1
else
None
x
=
self
.
_lite_conv
(
x
,
out_c
,
act
,
param_name
)
return
x
def
_deconv_upsample
(
self
,
x
,
out_c
,
name
=
None
):
conv1
=
ConvNorm
(
input
=
x
,
num_filters
=
out_c
,
filter_size
=
1
,
norm_type
=
'bn'
,
act
=
'relu6'
,
name
=
name
+
'.pointwise'
,
initializer
=
Xavier
(),
norm_name
=
name
+
'.pointwise.bn'
)
conv2
=
fluid
.
layers
.
conv2d_transpose
(
input
=
conv1
,
num_filters
=
out_c
,
filter_size
=
4
,
padding
=
1
,
stride
=
2
,
groups
=
out_c
,
param_attr
=
ParamAttr
(
name
=
name
+
'.deconv.weights'
,
initializer
=
Xavier
()),
bias_attr
=
False
)
bn
=
fluid
.
layers
.
batch_norm
(
input
=
conv2
,
act
=
'relu6'
,
param_attr
=
ParamAttr
(
name
=
name
+
'.deconv.bn.scale'
,
regularizer
=
L2Decay
(
0.
)),
bias_attr
=
ParamAttr
(
name
=
name
+
'.deconv.bn.offset'
,
regularizer
=
L2Decay
(
0.
)),
moving_mean_name
=
name
+
'.deconv.bn.mean'
,
moving_variance_name
=
name
+
'.deconv.bn.variance'
)
conv3
=
ConvNorm
(
input
=
bn
,
num_filters
=
out_c
,
filter_size
=
1
,
norm_type
=
'bn'
,
act
=
'relu6'
,
name
=
name
+
'.normal'
,
initializer
=
Xavier
(),
norm_name
=
name
+
'.normal.bn'
)
return
conv3
def
_interp_upsample
(
self
,
x
,
out_c
,
name
=
None
):
conv
=
self
.
_lite_conv
(
x
,
out_c
,
'relu6'
,
name
)
up
=
fluid
.
layers
.
resize_bilinear
(
conv
,
scale
=
2
)
return
up
def
upsample
(
self
,
x
,
out_c
,
name
=
None
):
deconv_up
=
self
.
_deconv_upsample
(
x
,
out_c
,
name
=
name
+
'.dilation_up'
)
interp_up
=
self
.
_interp_upsample
(
x
,
out_c
,
name
=
name
+
'.interp_up'
)
return
deconv_up
+
interp_up
def
_head
(
self
,
x
,
out_c
,
conv_num
=
1
,
head_out_c
=
None
,
name
=
None
,
is_test
=
False
):
head_out_c
=
self
.
head_conv
if
not
head_out_c
else
head_out_c
for
i
in
range
(
conv_num
):
conv_name
=
'{}.{}.conv'
.
format
(
name
,
i
)
x
=
self
.
_lite_conv
(
x
,
head_out_c
,
'relu6'
,
conv_name
)
bias_init
=
float
(
-
np
.
log
((
1
-
0.01
)
/
0.01
))
if
'.hm'
in
name
else
0.
conv_b_init
=
Constant
(
bias_init
)
x
=
fluid
.
layers
.
conv2d
(
x
,
out_c
,
1
,
param_attr
=
ParamAttr
(
name
=
'{}.{}.weight'
.
format
(
name
,
conv_num
)),
bias_attr
=
ParamAttr
(
learning_rate
=
2.
,
regularizer
=
L2Decay
(
0.
),
name
=
'{}.{}.bias'
.
format
(
name
,
conv_num
),
initializer
=
conv_b_init
))
return
x
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录