PaddlePaddle / PaddleDetection — commit d43e6d9a (unverified)

add ttfnet (#1054)

Authored by wangguanzhong, committed via GitHub on Jul 13, 2020.
Parent: 315fd738

Showing 16 changed files with 932 additions and 74 deletions (+932 −74).
Changed files:

- configs/anchor_free/README.md (+16 −4)
- configs/anchor_free/ttfnet_darknet.yml (+141 −0, new file)
- deploy/python/infer.py (+11 −8)
- ppdet/data/transform/batch_operators.py (+111 −6)
- ppdet/data/transform/op_helper.py (+5 −3)
- ppdet/data/transform/operators.py (+45 −20)
- ppdet/modeling/anchor_heads/__init__.py (+2 −0)
- ppdet/modeling/anchor_heads/ttf_head.py (+383 −0, new file)
- ppdet/modeling/architectures/__init__.py (+2 −0)
- ppdet/modeling/architectures/ttfnet.py (+132 −0, new file)
- ppdet/modeling/backbones/darknet.py (+5 −1)
- ppdet/modeling/losses/giou_loss.py (+32 −10)
- ppdet/modeling/ops.py (+41 −22)
- ppdet/utils/coco_eval.py (+1 −0)
- ppdet/utils/eval_utils.py (+2 −0)
- tools/export_model.py (+3 −0)
configs/anchor_free/README.md

```diff
@@ -12,10 +12,12 @@
 ## Model Zoo and Baselines
 The table below shows the network structures currently supported by PaddleDetection; for specifics, see [algorithm details](#算法细节).
 
-|                                         | ResNet50 | ResNet50-vd | Hourglass104 |
-|:---------------------------------------:|:--------:|:-----------:|:------------:|
-| [CornerNet-Squeeze](#CornerNet-Squeeze) |    x     |      ✓      |      ✓       |
-| [FCOS](#FCOS)                           |    ✓     |      x      |      x       |
+|                                         | ResNet50 | ResNet50-vd | Hourglass104 | DarkNet53 |
+|:---------------------------------------:|:--------:|:-----------:|:------------:|:---------:|
+| [CornerNet-Squeeze](#CornerNet-Squeeze) |    x     |      ✓      |      ✓       |     x     |
+| [FCOS](#FCOS)                           |    ✓     |      x      |      x       |     x     |
+| [TTFNet](#TTFNet)                       |    x     |      x      |      x       |     ✓     |
 
 ### Model Zoo
@@ -31,6 +33,7 @@
 | FCOS | ResNet50 | 2 | [ResNet50_cos_pretrained](https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar) | 39.8 | 18.85 | [model](https://paddlemodels.bj.bcebos.com/object_detection/fcos_r50_fpn_1x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/anchor_free/fcos_r50_fpn_1x.yml) |
 | FCOS+multiscale_train | ResNet50 | 2 | [ResNet50_cos_pretrained](https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar) | 42.0 | 19.05 | [model](https://paddlemodels.bj.bcebos.com/object_detection/fcos_r50_fpn_multiscale_2x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/anchor_free/fcos_r50_fpn_multiscale_2x.yml) |
 | FCOS+DCN | ResNet50 | 2 | [ResNet50_cos_pretrained](https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar) | 44.4 | 13.66 | [model](https://paddlemodels.bj.bcebos.com/object_detection/fcos_dcn_r50_fpn_1x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/anchor_free/fcos_dcn_r50_fpn_1x.yml) |
+| TTFNet | DarkNet53 | 12 | [DarkNet53_pretrained](https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_pretrained.tar) | 32.9 | 85.92 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ttfnet_darknet.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/anchor_free/ttfnet_darknet.yml) |
 
 **Note:**
@@ -64,5 +67,14 @@
 - A single center-ness branch predicts whether each point is an object center, suppressing low-quality false detections.
 
+## TTFNet
+
+**Introduction:** [TTFNet](https://arxiv.org/abs/1909.00700) is a training-time-friendly network for real-time object detection. It addresses the slow convergence of CenterNet by generating training samples with Gaussian kernels, which effectively removes the ambiguity in the anchor-free head. Its simple, lightweight structure also makes it easy to extend to other tasks.
+
+**Features:**
+
+- Simple structure: only two heads are needed to predict object location and size, and the time-consuming post-processing step is removed.
+- Short training time: with a DarkNet53 backbone, about 2 hours of training on 8 V100 GPUs is enough to reach a good model.
+
 ## How to Contribute
 Contributions of anchor-free detection models to PaddleDetection are very welcome — please submit a PR for review. Feedback is equally appreciated: file an issue and we will respond promptly.
```
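The Gaussian-kernel sample generation described in the introduction can be sketched in a few lines of plain NumPy. This is only an illustration of the idea under assumed numbers (a 512×512 input, `down_ratio=4`, one box), not the repository's implementation — the real operator is `Gt2TTFTarget` in the diff below:

```python
import numpy as np

def elliptical_gaussian(h, w, sigma_x, sigma_y):
    """Anisotropic 2-D Gaussian of size (h, w) with peak 1.0 at the center."""
    m, n = (h - 1.) / 2., (w - 1.) / 2.
    y, x = np.ogrid[-m:m + 1, -n:n + 1]
    return np.exp(-(x * x / (2. * sigma_x ** 2) + y * y / (2. * sigma_y ** 2)))

# A 512x512 input with down_ratio=4 gives a 128x128 heatmap. For a box of
# feature-map size 24x40 centered at cell (64, 64), alpha=0.54 shrinks the
# kernel radii so only the central part of the box becomes a soft positive.
alpha = 0.54
h_radius = int(24 / 2. * alpha)   # 6
w_radius = int(40 / 2. * alpha)   # 10
g = elliptical_gaussian(2 * h_radius + 1, 2 * w_radius + 1,
                        sigma_x=(2 * w_radius + 1) / 6.,
                        sigma_y=(2 * h_radius + 1) / 6.)

heatmap = np.zeros((128, 128), dtype='float32')
cy, cx = 64, 64
region = heatmap[cy - h_radius:cy + h_radius + 1,
                 cx - w_radius:cx + w_radius + 1]
# Overlapping boxes are merged with an element-wise maximum, as in the diff.
heatmap[cy - h_radius:cy + h_radius + 1,
        cx - w_radius:cx + w_radius + 1] = np.maximum(region, g)
```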
configs/anchor_free/ttfnet_darknet.yml (new file, mode 100644)

```yaml
architecture: TTFNet
use_gpu: true
max_iters: 15000
log_smooth_window: 20
save_dir: output
snapshot_iter: 1000
metric: COCO
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_pretrained.tar
weights: output/ttfnet_darknet/model_final
num_classes: 80

TTFNet:
  backbone: DarkNet
  ttf_head: TTFHead

DarkNet:
  norm_type: bn
  norm_decay: 0.0004
  depth: 53
  freeze_at: 1

TTFHead:
  head_conv: 128
  wh_conv: 64
  hm_head_conv_num: 2
  wh_head_conv_num: 2
  wh_offset_base: 16
  wh_loss: GiouLoss

GiouLoss:
  loss_weight: 5.
  do_average: false
  use_class_weight: false

LearningRate:
  base_lr: 0.015
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones:
    - 11250
    - 13750
  - !LinearWarmup
    start_factor: 0.2
    steps: 500

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0004
    type: L2

TrainReader:
  inputs_def:
    fields: ['image', 'ttf_heatmap', 'ttf_box_target', 'ttf_reg_weight']
  dataset:
    !COCODataSet
    image_dir: train2017
    anno_path: annotations/instances_train2017.json
    dataset_dir: dataset/coco
    with_background: false
  sample_transforms:
  - !DecodeImage
    to_rgb: true
  - !Resize
    target_dim: 512
  - !RandomFlipImage
    prob: 0.5
  - !NormalizeImage
    is_channel_first: false
    is_scale: false
    mean: [123.675, 116.28, 103.53]
    std: [58.395, 57.12, 57.375]
  - !Permute
    to_bgr: false
    channel_first: true
  batch_transforms:
  - !Gt2TTFTarget
    num_classes: 80
    down_ratio: 4
  - !PadBatch
    pad_to_stride: 32
  batch_size: 12
  shuffle: true
  worker_num: 8
  bufsize: 2
  use_process: true

EvalReader:
  inputs_def:
    image_shape: [3, 512, 512]
    fields: ['image', 'im_id', 'scale_factor']
  dataset:
    !COCODataSet
    image_dir: val2017
    anno_path: annotations/instances_val2017.json
    dataset_dir: dataset/coco
    with_background: false
  sample_transforms:
  - !DecodeImage
    to_rgb: True
  - !Resize
    target_dim: 512
  - !NormalizeImage
    mean: [123.675, 116.28, 103.53]
    std: [58.395, 57.12, 57.375]
    is_scale: false
    is_channel_first: false
  - !Permute
    to_bgr: false
    channel_first: True
  batch_size: 1
  drop_empty: false
  worker_num: 8
  bufsize: 16

TestReader:
  inputs_def:
    image_shape: [3, 512, 512]
    fields: ['image', 'im_id', 'scale_factor']
  dataset:
    !ImageFolder
    anno_path: annotations/instances_val2017.json
    with_background: false
  sample_transforms:
  - !DecodeImage
    to_rgb: True
  - !Resize
    interp: 1
    target_dim: 512
  - !NormalizeImage
    mean: [123.675, 116.28, 103.53]
    std: [58.395, 57.12, 57.375]
    is_scale: false
    is_channel_first: false
  - !Permute
    to_bgr: false
    channel_first: True
  batch_size: 1
```
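For context, training with this config follows the usual PaddleDetection 1.x workflow rather than anything specific to this commit: roughly `python tools/train.py -c configs/anchor_free/ttfnet_darknet.yml` with `CUDA_VISIBLE_DEVICES` set for 8 GPUs (the `max_iters: 15000` / `batch_size: 12` schedule matches the 8-card entry in the model zoo table), and evaluation via `python tools/eval.py -c configs/anchor_free/ttfnet_darknet.yml -o weights=output/ttfnet_darknet/model_final`. The exact command-line flags are the repository's standard ones, assumed here rather than shown in this diff.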
deploy/python/infer.py

```diff
@@ -115,8 +115,7 @@ class Resize(object):
             padding_im[:im_h, :im_w, :] = im
             im = padding_im
         if self.arch in self.scale_set:
-            im_info['scale'] = im_scale_x
+            im_info['scale'] = [im_scale_x, im_scale_y]
         im_info['resize_shape'] = im.shape[:2]
         return im, im_info
@@ -252,18 +251,23 @@ def create_inputs(im, im_info, model_arch='YOLO'):
     inputs['image'] = im
     origin_shape = list(im_info['origin_shape'])
     resize_shape = list(im_info['resize_shape'])
-    scale = im_info['scale']
+    scale_x, scale_y = im_info['scale']
     if 'YOLO' in model_arch:
         im_size = np.array([origin_shape]).astype('int32')
         inputs['im_size'] = im_size
     elif 'RetinaNet' in model_arch:
+        scale = scale_x
         im_info = np.array([resize_shape + [scale]]).astype('float32')
         inputs['im_info'] = im_info
     elif 'RCNN' in model_arch:
+        scale = scale_x
         im_info = np.array([resize_shape + [scale]]).astype('float32')
         im_shape = np.array([origin_shape + [1.]]).astype('float32')
         inputs['im_info'] = im_info
         inputs['im_shape'] = im_shape
+    elif 'TTF' in model_arch:
+        scale_factor = np.array([scale_x, scale_y] * 2).astype('float32')
+        inputs['scale_factor'] = scale_factor
     return inputs
@@ -272,7 +276,7 @@ class Config():
     Args:
         model_dir (str): root path of model.yml
     """
-    support_models = ['YOLO', 'SSD', 'RetinaNet', 'RCNN', 'Face']
+    support_models = ['YOLO', 'SSD', 'RetinaNet', 'RCNN', 'Face', 'TTF']
 
     def __init__(self, model_dir):
         # parsing Yaml config for Preprocess
@@ -298,9 +302,8 @@ class Config():
         for support_model in self.support_models:
             if support_model in yml_conf['arch']:
                 return True
-        raise ValueError(
-            "Unsupported arch: {}, expect SSD, YOLO, RetinaNet, RCNN and Face".
-            format(yml_conf['arch']))
+        raise ValueError("Unsupported arch: {}, expect {}".format(
+            yml_conf['arch'], self.support_models))
 
     def print_config(self):
         print('----------- Model Configuration -----------')
@@ -450,7 +453,7 @@ class Detector():
             np_boxes[:, 3] *= w
             np_boxes[:, 4] *= h
             np_boxes[:, 5] *= w
-        expect_boxes = np_boxes[:, 1] > threshold
+        expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
         np_boxes = np_boxes[expect_boxes, :]
         for box in np_boxes:
             print('class_id:{:d}, confidence:{:.2f},'
```
ppdet/data/transform/batch_operators.py

```diff
@@ -26,13 +26,17 @@ import cv2
 import numpy as np
 
 from .operators import register_op, BaseOperator
-from .op_helper import jaccard_overlap
+from .op_helper import jaccard_overlap, gaussian2D
 
 logger = logging.getLogger(__name__)
 
 __all__ = [
-    'PadBatch', 'RandomShape', 'PadMultiScaleTest', 'Gt2YoloTarget',
-    'Gt2FCOSTarget'
+    'PadBatch',
+    'RandomShape',
+    'PadMultiScaleTest',
+    'Gt2YoloTarget',
+    'Gt2FCOSTarget',
+    'Gt2TTFTarget',
 ]
@@ -41,7 +45,6 @@ class PadBatch(BaseOperator):
     """
     Pad a batch of samples so they can be divisible by a stride.
     The layout of each image should be 'CHW'.
-
     Args:
         pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure
             height and width is divisible by `pad_to_stride`.
@@ -89,13 +92,12 @@ class RandomShape(BaseOperator):
     randomly select an interpolation algorithm from [cv2.INTER_NEAREST,
     cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4].
     If random_inter is False, use cv2.INTER_NEAREST.
-
     Args:
         sizes (list): list of int, randomly choose a size from these
         random_inter (bool): whether to randomly select the interpolation
             method, default False.
     """
 
-    def __init__(self, sizes=[], random_inter=False):
+    def __init__(self, sizes=[], random_inter=False, resize_box=False):
         super(RandomShape, self).__init__()
         self.sizes = sizes
         self.random_inter = random_inter
@@ -106,6 +108,7 @@
             cv2.INTER_CUBIC,
             cv2.INTER_LANCZOS4,
         ] if random_inter else []
+        self.resize_box = resize_box
 
     def __call__(self, samples, context=None):
         shape = np.random.choice(self.sizes)
@@ -119,6 +122,12 @@
             im = cv2.resize(
                 im, None, None, fx=scale_x, fy=scale_y, interpolation=method)
             samples[i]['image'] = im
+            if self.resize_box and 'gt_bbox' in samples[i] and len(
+                    samples[0]['gt_bbox']) > 0:
+                scale_array = np.array([scale_x, scale_y] * 2, dtype=np.float32)
+                samples[i]['gt_bbox'] = np.clip(
+                    samples[i]['gt_bbox'] * scale_array, 0, float(shape) - 1)
         return samples
```

The hunk `@@ -478,3 +487,99 @@` appends a new batch operator after `Gt2FCOSTarget` (whose last lines, `sample['centerness{}'.format(lvl)] = np.reshape(ctn_targets_by_level[lvl], newshape=[grid_h, grid_w, 1])` and `return samples`, are kept as context):

```python
@register_op
class Gt2TTFTarget(BaseOperator):
    """
    Gt2TTFTarget
    Generate TTFNet targets by ground truth data

    Args:
        num_classes(int): the number of classes.
        down_ratio(int): the down ratio from images to heatmap, 4 by default.
        alpha(float): the alpha parameter to generate gaussian target.
            0.54 by default.
    """

    def __init__(self, num_classes, down_ratio=4, alpha=0.54):
        super(Gt2TTFTarget, self).__init__()
        self.down_ratio = down_ratio
        self.num_classes = num_classes
        self.alpha = alpha

    def __call__(self, samples, context=None):
        output_size = samples[0]['image'].shape[1]
        feat_size = output_size // self.down_ratio
        for sample in samples:
            heatmap = np.zeros(
                (self.num_classes, feat_size, feat_size), dtype='float32')
            box_target = np.ones(
                (4, feat_size, feat_size), dtype='float32') * -1
            reg_weight = np.zeros((1, feat_size, feat_size), dtype='float32')

            gt_bbox = sample['gt_bbox']
            gt_class = sample['gt_class']

            bbox_w = gt_bbox[:, 2] - gt_bbox[:, 0] + 1
            bbox_h = gt_bbox[:, 3] - gt_bbox[:, 1] + 1
            area = bbox_w * bbox_h
            boxes_areas_log = np.log(area)
            boxes_ind = np.argsort(boxes_areas_log, axis=0)[::-1]
            boxes_area_topk_log = boxes_areas_log[boxes_ind]
            gt_bbox = gt_bbox[boxes_ind]
            gt_class = gt_class[boxes_ind]

            feat_gt_bbox = gt_bbox / self.down_ratio
            feat_gt_bbox = np.clip(feat_gt_bbox, 0, feat_size - 1)
            feat_hs, feat_ws = (feat_gt_bbox[:, 3] - feat_gt_bbox[:, 1],
                                feat_gt_bbox[:, 2] - feat_gt_bbox[:, 0])

            ct_inds = np.stack(
                [(gt_bbox[:, 0] + gt_bbox[:, 2]) / 2,
                 (gt_bbox[:, 1] + gt_bbox[:, 3]) / 2],
                axis=1) / self.down_ratio

            h_radiuses_alpha = (feat_hs / 2. * self.alpha).astype('int32')
            w_radiuses_alpha = (feat_ws / 2. * self.alpha).astype('int32')

            for k in range(len(gt_bbox)):
                cls_id = gt_class[k]
                fake_heatmap = np.zeros(
                    (feat_size, feat_size), dtype='float32')
                self.draw_truncate_gaussian(fake_heatmap, ct_inds[k],
                                            h_radiuses_alpha[k],
                                            w_radiuses_alpha[k])

                heatmap[cls_id] = np.maximum(heatmap[cls_id], fake_heatmap)
                box_target_inds = fake_heatmap > 0
                box_target[:, box_target_inds] = gt_bbox[k][:, None]

                local_heatmap = fake_heatmap[box_target_inds]
                ct_div = np.sum(local_heatmap)
                local_heatmap *= boxes_area_topk_log[k]
                reg_weight[0, box_target_inds] = local_heatmap / ct_div
            sample['ttf_heatmap'] = heatmap
            sample['ttf_box_target'] = box_target
            sample['ttf_reg_weight'] = reg_weight
        return samples

    def draw_truncate_gaussian(self, heatmap, center, h_radius, w_radius):
        h, w = 2 * h_radius + 1, 2 * w_radius + 1
        sigma_x = w / 6
        sigma_y = h / 6
        gaussian = gaussian2D((h, w), sigma_x, sigma_y)

        x, y = int(center[0]), int(center[1])

        height, width = heatmap.shape[0:2]

        left, right = min(x, w_radius), min(width - x, w_radius + 1)
        top, bottom = min(y, h_radius), min(height - y, h_radius + 1)

        masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
        masked_gaussian = gaussian[h_radius - top:h_radius + bottom,
                                   w_radius - left:w_radius + right]
        if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
            heatmap[y - top:y + bottom, x - left:x + right] = np.maximum(
                masked_heatmap, masked_gaussian)
        return heatmap
```
ppdet/data/transform/op_helper.py

```diff
@@ -438,7 +438,8 @@ def gaussian_radius(bbox_size, min_overlap):
 def draw_gaussian(heatmap, center, radius, k=1, delte=6):
     diameter = 2 * radius + 1
-    gaussian = gaussian2D((diameter, diameter), sigma=diameter / delte)
+    sigma = diameter / delte
+    gaussian = gaussian2D((diameter, diameter), sigma_x=sigma, sigma_y=sigma)
 
     x, y = center
@@ -453,10 +454,11 @@ def draw_gaussian(heatmap, center, radius, k=1, delte=6):
         np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
 
 
-def gaussian2D(shape, sigma=1):
+def gaussian2D(shape, sigma_x=1, sigma_y=1):
     m, n = [(ss - 1.) / 2. for ss in shape]
     y, x = np.ogrid[-m:m + 1, -n:n + 1]
 
-    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
+    h = np.exp(-(x * x / (2 * sigma_x * sigma_x) + y * y /
+                 (2 * sigma_y * sigma_y)))
    h[h < np.finfo(h.dtype).eps * h.max()] = 0
    return h
```
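A quick standalone check of the generalized helper (a copy of the new `gaussian2D` above, runnable without the repo) shows why the `sigma_x`/`sigma_y` split matters: the old isotropic kernel is just the special case `sigma_x == sigma_y`, while `draw_truncate_gaussian` needs a per-axis sigma so the kernel matches the box aspect ratio:

```python
import numpy as np

def gaussian2D(shape, sigma_x=1, sigma_y=1):
    # Mirrors the updated helper: a separate sigma per axis.
    m, n = [(ss - 1.) / 2. for ss in shape]
    y, x = np.ogrid[-m:m + 1, -n:n + 1]
    h = np.exp(-(x * x / (2 * sigma_x * sigma_x) +
                 y * y / (2 * sigma_y * sigma_y)))
    h[h < np.finfo(h.dtype).eps * h.max()] = 0
    return h

iso = gaussian2D((7, 7), sigma_x=1.5, sigma_y=1.5)    # old isotropic case
aniso = gaussian2D((7, 13), sigma_x=3.0, sigma_y=1.5) # wide-box kernel
assert iso.shape == (7, 7) and aniso.shape == (7, 13)
# Both kernels peak at 1.0 in the center cell.
assert np.isclose(iso.max(), 1.0) and np.isclose(aniso.max(), 1.0)
```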
ppdet/data/transform/operators.py

```diff
@@ -92,7 +92,6 @@ class BaseOperator(object):
 class DecodeImage(BaseOperator):
     def __init__(self, to_rgb=True, with_mixup=False, with_cutmix=False):
         """ Transform the image data to numpy format.
-
         Args:
             to_rgb (bool): whether to convert BGR to RGB
             with_mixup (bool): whether or not to mixup image and gt_bbbox/gt_score
@@ -165,7 +164,6 @@ class MultiscaleTestResize(BaseOperator):
                  use_flip=True):
         """
         Rescale image to the each size in target size, and capped at max_size.
-
         Args:
             origin_target_size(int): original target size of image's short side.
             origin_max_size(int): original max size of image.
@@ -274,7 +272,6 @@ class ResizeImage(BaseOperator):
         if max_size != 0.
         If target_size is list, selected a scale randomly as the specified
         target size.
-
         Args:
             target_size (int|list): the target size of image's short side,
                 multi-scale training is adopted when type is list.
@@ -1177,7 +1174,6 @@ class Permute(BaseOperator):
         Args:
             to_bgr (bool): confirm whether to convert RGB to BGR
             channel_first (bool): confirm whether to change channel
-
         """
         super(Permute, self).__init__()
         self.to_bgr = to_bgr
@@ -1386,7 +1382,6 @@ class RandomInterpImage(BaseOperator):
 @register_op
 class Resize(BaseOperator):
     """Resize image and bbox.
-
     Args:
         target_dim (int or list): target size, can be a single number or a list
             (for random shape).
@@ -1419,6 +1414,7 @@ class Resize(BaseOperator):
             scale_array = np.array([scale_x, scale_y] * 2, dtype=np.float32)
             sample['gt_bbox'] = np.clip(sample['gt_bbox'] * scale_array, 0,
                                         dim - 1)
+        sample['scale_factor'] = [scale_x, scale_y] * 2
         sample['h'] = resize_h
         sample['w'] = resize_w
@@ -1430,7 +1426,6 @@
 @register_op
 class ColorDistort(BaseOperator):
     """Random color distortion.
-
     Args:
         hue (list): hue settings.
             in [lower, upper, probability] format.
@@ -1442,6 +1437,8 @@ class ColorDistort(BaseOperator):
             in [lower, upper, probability] format.
         random_apply (bool): whether to apply in random (yolo) or fixed (SSD)
             order.
+        hsv_format (bool): whether to convert color from BGR to HSV
+        random_channel (bool): whether to swap channels randomly
     """
 
     def __init__(self,
@@ -1449,13 +1446,17 @@ class ColorDistort(BaseOperator):
                  saturation=[0.5, 1.5, 0.5],
                  contrast=[0.5, 1.5, 0.5],
                  brightness=[0.5, 1.5, 0.5],
-                 random_apply=True):
+                 random_apply=True,
+                 hsv_format=False,
+                 random_channel=False):
         super(ColorDistort, self).__init__()
         self.hue = hue
         self.saturation = saturation
         self.contrast = contrast
         self.brightness = brightness
         self.random_apply = random_apply
+        self.hsv_format = hsv_format
+        self.random_channel = random_channel
 
     def apply_hue(self, img):
         low, high, prob = self.hue
@@ -1463,6 +1464,11 @@ class ColorDistort(BaseOperator):
             return img
         img = img.astype(np.float32)
+        if self.hsv_format:
+            img[..., 0] += random.uniform(low, high)
+            img[..., 0][img[..., 0] > 360] -= 360
+            img[..., 0][img[..., 0] < 0] += 360
+            return img
 
         # XXX works, but result differ from HSV version
         delta = np.random.uniform(low, high)
@@ -1482,8 +1488,10 @@ class ColorDistort(BaseOperator):
         if np.random.uniform(0., 1.) < prob:
             return img
         delta = np.random.uniform(low, high)
         img = img.astype(np.float32)
+        if self.hsv_format:
+            img[..., 1] *= delta
+            return img
         gray = img * np.array([[[0.299, 0.587, 0.114]]], dtype=np.float32)
         gray = gray.sum(axis=2, keepdims=True)
         gray *= (1.0 - delta)
@@ -1530,12 +1538,24 @@ class ColorDistort(BaseOperator):
             if np.random.randint(0, 2):
                 img = self.apply_contrast(img)
+            if self.hsv_format:
+                img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
             img = self.apply_saturation(img)
             img = self.apply_hue(img)
+            if self.hsv_format:
+                img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
         else:
+            if self.hsv_format:
+                img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
             img = self.apply_saturation(img)
             img = self.apply_hue(img)
+            if self.hsv_format:
+                img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
             img = self.apply_contrast(img)
+        if self.random_channel:
+            if np.random.randint(0, 2):
+                img = img[..., np.random.permutation(3)]
         sample['image'] = img
         return sample
@@ -1603,7 +1623,6 @@ class CornerRandColor(ColorDistort):
 @register_op
 class NormalizePermute(BaseOperator):
     """Normalize and permute channel order.
-
     Args:
         mean (list): mean values in RGB order.
         std (list): std values in RGB order.
@@ -1633,7 +1652,6 @@ class NormalizePermute(BaseOperator):
 @register_op
 class RandomExpand(BaseOperator):
     """Random expand the canvas.
-
     Args:
         ratio (float): maximum expansion ratio.
         prob (float): probability to expand.
@@ -1725,7 +1743,6 @@ class RandomExpand(BaseOperator):
 @register_op
 class RandomCrop(BaseOperator):
     """Random crop image and bboxes.
-
     Args:
         aspect_ratio (list): aspect ratio of cropped region.
             in [min, max] format.
@@ -1852,11 +1869,23 @@ class RandomCrop(BaseOperator):
         found = False
         for i in range(self.num_attempts):
             scale = np.random.uniform(*self.scaling)
-            min_ar, max_ar = self.aspect_ratio
-            aspect_ratio = np.random.uniform(
-                max(min_ar, scale**2), min(max_ar, scale**-2))
-            crop_h = int(h * scale / np.sqrt(aspect_ratio))
-            crop_w = int(w * scale * np.sqrt(aspect_ratio))
+            if self.aspect_ratio is not None:
+                min_ar, max_ar = self.aspect_ratio
+                aspect_ratio = np.random.uniform(
+                    max(min_ar, scale**2), min(max_ar, scale**-2))
+                h_scale = scale / np.sqrt(aspect_ratio)
+                w_scale = scale * np.sqrt(aspect_ratio)
+            else:
+                h_scale = np.random.uniform(*self.scaling)
+                w_scale = np.random.uniform(*self.scaling)
+            crop_h = h * h_scale
+            crop_w = w * w_scale
+            if self.aspect_ratio is None:
+                if crop_h / crop_w < 0.5 or crop_h / crop_w > 2.0:
+                    continue
+            crop_h = int(crop_h)
+            crop_w = int(crop_w)
             crop_y = np.random.randint(0, h - crop_h)
             crop_x = np.random.randint(0, w - crop_w)
             crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
@@ -2008,7 +2037,6 @@ class BboxXYXY2XYWH(BaseOperator):
         return sample
 
-
 @register_op
 class Lighting(BaseOperator):
     """
     Lighting the image by eigenvalues and eigenvectors
@@ -2248,7 +2276,6 @@ class CornerRatio(BaseOperator):
 class RandomScaledCrop(BaseOperator):
     """Resize image and bbox based on long side (with optional random scaling),
     then crop or pad image to target size.
-
     Args:
         target_dim (int): target size.
         scale_range (list): random scale range.
@@ -2303,7 +2330,6 @@ class RandomScaledCrop(BaseOperator):
 @register_op
 class ResizeAndPad(BaseOperator):
     """Resize image and bbox, then pad image to target size.
-
     Args:
         target_dim (int): target size
         interp (int): interpolation method, default to `cv2.INTER_LINEAR`.
@@ -2342,7 +2368,6 @@ class ResizeAndPad(BaseOperator):
 @register_op
 class TargetAssign(BaseOperator):
     """Assign regression target and labels.
-
     Args:
         image_size (int or list): input image size, a single integer or list of
             [h, w]. Default: 512
```
ppdet/modeling/anchor_heads/__init__.py

```diff
@@ -20,6 +20,7 @@ from . import retina_head
 from . import fcos_head
 from . import corner_head
 from . import efficient_head
+from . import ttf_head
 
 from .rpn_head import *
 from .yolo_head import *
@@ -27,3 +28,4 @@ from .retina_head import *
 from .fcos_head import *
 from .corner_head import *
 from .efficient_head import *
+from .ttf_head import *
```
ppdet/modeling/anchor_heads/ttf_head.py (new file, mode 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Normal, Constant, Uniform, Xavier
from paddle.fluid.regularizer import L2Decay
from ppdet.core.workspace import register
from ppdet.modeling.ops import DeformConv, DropBlock
from ppdet.modeling.losses import GiouLoss

__all__ = ['TTFHead']


@register
class TTFHead(object):
    """
    TTFHead

    Args:
        head_conv(int): the default channel number of convolution in head.
            128 by default.
        num_classes(int): the number of classes, 80 by default.
        hm_weight(float): the weight of heatmap branch. 1. by default.
        wh_weight(float): the weight of wh branch. 5. by default.
        wh_offset_base(float): the base offset of width and height.
            16. by default.
        planes(tuple): the channel number of convolution in each upsample.
            (256, 128, 64) by default.
        shortcut_num(tuple): the number of convolution layers in each shortcut.
            (1, 2, 3) by default.
        wh_head_conv_num(int): the number of convolution layers in wh head.
            2 by default.
        hm_head_conv_num(int): the number of convolution layers in hm head.
            2 by default.
        wh_conv(int): the channel number of convolution in wh head.
            64 by default.
        wh_planes(int): the output channel in wh head. 4 by default.
        score_thresh(float): the score threshold to get prediction.
            0.01 by default.
        max_per_img(int): the maximum detection per image. 100 by default.
        base_down_ratio(int): the base down_ratio, the actual down_ratio is
            calculated by base_down_ratio and the number of upsample layers.
            32 by default.
        wh_loss(object): `GiouLoss` instance.
        dcn_upsample(bool): whether upsample by dcn. True by default.
        dcn_head(bool): whether use dcn in head. False by default.
        drop_block(bool): whether use dropblock. False by default.
        block_size(int): block_size parameter for drop_block. 3 by default.
        keep_prob(float): keep_prob parameter for drop_block. 0.9 by default.
    """

    __inject__ = ['wh_loss']
    __shared__ = ['num_classes']

    def __init__(self,
                 head_conv=128,
                 num_classes=80,
                 hm_weight=1.,
                 wh_weight=5.,
                 wh_offset_base=16.,
                 planes=(256, 128, 64),
                 shortcut_num=(1, 2, 3),
                 wh_head_conv_num=2,
                 hm_head_conv_num=2,
                 wh_conv=64,
                 wh_planes=4,
                 score_thresh=0.01,
                 max_per_img=100,
                 base_down_ratio=32,
                 wh_loss='GiouLoss',
                 dcn_upsample=True,
                 dcn_head=False,
                 drop_block=False,
                 block_size=3,
                 keep_prob=0.9):
        super(TTFHead, self).__init__()
        self.head_conv = head_conv
        self.num_classes = num_classes
        self.hm_weight = hm_weight
        self.wh_weight = wh_weight
        self.wh_offset_base = wh_offset_base
        self.planes = planes
        self.shortcut_num = shortcut_num
        self.shortcut_len = len(shortcut_num)
        self.wh_head_conv_num = wh_head_conv_num
        self.hm_head_conv_num = hm_head_conv_num
        self.wh_conv = wh_conv
        self.wh_planes = wh_planes
        self.score_thresh = score_thresh
        self.max_per_img = max_per_img
        self.down_ratio = base_down_ratio // 2**len(planes)
        self.wh_loss = wh_loss
        self.dcn_upsample = dcn_upsample
        self.dcn_head = dcn_head
        self.drop_block = drop_block
        self.block_size = block_size
        self.keep_prob = keep_prob

    def shortcut(self, x, out_c, layer_num, kernel_size=3, padding=1,
                 name=None):
        assert layer_num > 0
        for i in range(layer_num):
            act = 'relu' if i < layer_num - 1 else None
            fan_out = kernel_size * kernel_size * out_c
            std = math.sqrt(2. / fan_out)
            param_name = name + '.layers.' + str(i * 2)
            x = fluid.layers.conv2d(
                x,
                out_c,
                kernel_size,
                padding=padding,
                act=act,
                param_attr=ParamAttr(
                    initializer=Normal(0, std), name=param_name + '.weight'),
                bias_attr=ParamAttr(
                    learning_rate=2.,
                    regularizer=L2Decay(0.),
                    name=param_name + '.bias'))
        return x

    def upsample(self, x, out_c, name=None):
        fan_in = x.shape[1] * 3 * 3
        stdv = 1. / math.sqrt(fan_in)
        if self.dcn_upsample:
            conv = DeformConv(
                x,
                out_c,
                3,
                initializer=Uniform(-stdv, stdv),
                bias_attr=True,
                name=name + '.0')
        else:
            conv = fluid.layers.conv2d(
                x,
                out_c,
                3,
                padding=1,
                param_attr=ParamAttr(initializer=Uniform(-stdv, stdv)),
                bias_attr=ParamAttr(
                    learning_rate=2., regularizer=L2Decay(0.)))

        norm_name = name + '.1'
        pattr = ParamAttr(name=norm_name + '.weight', initializer=Constant(1.))
        battr = ParamAttr(name=norm_name + '.bias', initializer=Constant(0.))
        bn = fluid.layers.batch_norm(
            input=conv,
            act='relu',
            param_attr=pattr,
            bias_attr=battr,
            name=norm_name + '.output.1',
            moving_mean_name=norm_name + '.running_mean',
            moving_variance_name=norm_name + '.running_var')
        up = fluid.layers.resize_bilinear(
            bn, scale=2, name=name + '.2.upsample')
        return up

    def _head(self, x, out_c, conv_num=1, head_out_c=None, name=None,
              is_test=False):
        head_out_c = self.head_conv if not head_out_c else head_out_c
        conv_w_std = 0.01 if '.hm' in name else 0.001
        conv_w_init = Normal(0, conv_w_std)
        for i in range(conv_num):
            conv_name = '{}.{}.conv'.format(name, i)
            if self.dcn_head:
                x = DeformConv(
                    x,
                    head_out_c,
                    3,
                    initializer=conv_w_init,
                    name=conv_name + '.dcn')
                x = fluid.layers.relu(x)
            else:
                x = fluid.layers.conv2d(
                    x,
                    head_out_c,
                    3,
                    padding=1,
                    param_attr=ParamAttr(
                        initializer=conv_w_init, name=conv_name + '.weight'),
                    bias_attr=ParamAttr(
                        learning_rate=2.,
                        regularizer=L2Decay(0.),
                        name=conv_name + '.bias'),
                    act='relu')
        if self.drop_block and '.hm' in name:
            x = DropBlock(
                x,
                block_size=self.block_size,
                keep_prob=self.keep_prob,
                is_test=is_test)
        bias_init = float(-np.log((1 - 0.01) / 0.01)) if '.hm' in name else 0.
        conv_b_init = Constant(bias_init)
        x = fluid.layers.conv2d(
            x,
            out_c,
            1,
            param_attr=ParamAttr(
                initializer=conv_w_init,
                name='{}.{}.weight'.format(name, conv_num)),
            bias_attr=ParamAttr(
                learning_rate=2.,
                regularizer=L2Decay(0.),
                name='{}.{}.bias'.format(name, conv_num),
                initializer=conv_b_init))
        return x

    def hm_head(self, x, name=None, is_test=False):
        hm = self._head(
            x,
            self.num_classes,
            self.hm_head_conv_num,
            name=name,
            is_test=is_test)
        return hm

    def wh_head(self, x, name=None):
        planes = self.wh_planes
        wh = self._head(
            x, planes, self.wh_head_conv_num, self.wh_conv, name=name)
        return fluid.layers.relu(wh)

    def get_output(self, input, name=None, is_test=False):
        feat = input[-1]
        for i, out_c in enumerate(self.planes):
            feat = self.upsample(
                feat, out_c, name=name + '.deconv_layers.' + str(i))
            if i < self.shortcut_len:
                shortcut = self.shortcut(
                    input[-i - 2],
                    out_c,
                    self.shortcut_num[i],
                    name=name + '.shortcut_layers.' + str(i))
                feat = fluid.layers.elementwise_add(feat, shortcut)

        hm = self.hm_head(feat, name=name + '.hm', is_test=is_test)
        wh = self.wh_head(feat, name=name + '.wh') * self.wh_offset_base

        return hm, wh

    def _simple_nms(self, heat, kernel=3):
        pad = (kernel - 1) // 2
        hmax = fluid.layers.pool2d(heat, kernel, 'max', pool_padding=pad)
        keep = fluid.layers.cast(hmax == heat, 'float32')
        return heat * keep

    def _topk(self, scores, k):
        cat, height, width = scores.shape[1:]
        # batch size is 1
        scores_r = fluid.layers.reshape(scores, [cat, -1])
        topk_scores, topk_inds = fluid.layers.topk(scores_r, k)
        topk_ys = topk_inds / width
        topk_xs = topk_inds % width

        topk_score_r = fluid.layers.reshape(topk_scores, [-1])
        topk_score, topk_ind = fluid.layers.topk(topk_score_r, k)
        topk_clses = fluid.layers.cast(topk_ind / k, 'float32')

        topk_inds = fluid.layers.reshape(topk_inds, [-1])
        topk_ys = fluid.layers.reshape(topk_ys, [-1, 1])
        topk_xs = fluid.layers.reshape(topk_xs, [-1, 1])
        topk_inds = fluid.layers.gather(topk_inds, topk_ind)
        topk_ys = fluid.layers.gather(topk_ys, topk_ind)
        topk_xs = fluid.layers.gather(topk_xs, topk_ind)

        return topk_score, topk_inds, topk_clses, topk_ys, topk_xs

    def get_bboxes(self, heatmap, wh, scale_factor):
        heatmap = fluid.layers.sigmoid(heatmap)
        heat = self._simple_nms(heatmap)
        scores, inds, clses, ys, xs = self._topk(heat, self.max_per_img)
        ys = fluid.layers.cast(ys, 'float32') * self.down_ratio
        xs = fluid.layers.cast(xs, 'float32') * self.down_ratio
        scores = fluid.layers.unsqueeze(scores, [1])
        clses = fluid.layers.unsqueeze(clses, [1])

        wh_t = fluid.layers.transpose(wh, [0, 2, 3, 1])
        wh = fluid.layers.reshape(wh_t, [-1, wh_t.shape[-1]])
        wh = fluid.layers.gather(wh, inds)

        x1 = xs - wh[:, 0:1]
        y1 = ys - wh[:, 1:2]
        x2 = xs + wh[:, 2:3]
        y2 = ys + wh[:, 3:4]

        bboxes = fluid.layers.concat([x1, y1, x2, y2], axis=1)
        bboxes = fluid.layers.elementwise_div(bboxes, scale_factor, axis=-1)
        results = fluid.layers.concat([clses, scores, bboxes], axis=1)
        # hack: append result with cls=-1 and score=1. to avoid all scores
        # are less than score_thresh which may cause error in gather.
        fill_r = fluid.layers.assign(
            np.array(
                [[-1, 1., 0, 0, 0, 0]], dtype='float32'))
        results = fluid.layers.concat([results, fill_r])
        scores = results[:, 1]
        valid_ind = fluid.layers.where(scores > self.score_thresh)
        results = fluid.layers.gather(results, valid_ind)
        return {'bbox': results}

    def ct_focal_loss(self, pred_hm, target_hm, gamma=2.0):
        fg_map = fluid.layers.cast(target_hm == 1, 'float32')
        fg_map.stop_gradient = True
        bg_map = fluid.layers.cast(target_hm < 1, 'float32')
        bg_map.stop_gradient = True

        neg_weights = fluid.layers.pow(1 - target_hm, 4) * bg_map
        pos_loss = 0 - fluid.layers.log(pred_hm) * fluid.layers.pow(
            1 - pred_hm, gamma) * fg_map
        neg_loss = 0 - fluid.layers.log(1 - pred_hm) * fluid.layers.pow(
            pred_hm, gamma) * neg_weights
        pos_loss = fluid.layers.reduce_sum(pos_loss)
        neg_loss = fluid.layers.reduce_sum(neg_loss)

        fg_num = fluid.layers.reduce_sum(fg_map)
        focal_loss = (pos_loss + neg_loss) / (
            fg_num + fluid.layers.cast(fg_num == 0, 'float32'))
        return focal_loss

    def filter_box_by_weight(self, pred, target, weight):
        index = fluid.layers.where(weight > 0)
        index.stop_gradient = True
        weight = fluid.layers.gather_nd(weight, index)
        pred = fluid.layers.gather_nd(pred, index)
        target = fluid.layers.gather_nd(target, index)
        return pred, target, weight

    def get_loss(self, pred_hm, pred_wh, target_hm, box_target,
                 target_weight):
        pred_hm = paddle.tensor.clamp(
            fluid.layers.sigmoid(pred_hm), 1e-4, 1 - 1e-4)
        hm_loss = self.ct_focal_loss(pred_hm, target_hm) * self.hm_weight
        shape = fluid.layers.shape(target_hm)
        shape.stop_gradient = True
        H, W = shape[2], shape[3]

        mask = fluid.layers.reshape(target_weight, [-1, H, W])
        avg_factor = fluid.layers.reduce_sum(mask) + 1e-4

        base_step = self.down_ratio
        zero = fluid.layers.fill_constant(shape=[1], value=0, dtype='int32')
        shifts_x = paddle.arange(zero, W * base_step, base_step, dtype='int32')
        shifts_y = paddle.arange(zero, H * base_step, base_step, dtype='int32')
        shift_y, shift_x = paddle.tensor.meshgrid([shifts_y, shifts_x])
        base_loc = fluid.layers.stack([shift_x, shift_y], axis=0)
        base_loc.stop_gradient = True

        pred_boxes = fluid.layers.concat(
            [0 - pred_wh[:, 0:2, :, :] + base_loc, pred_wh[:, 2:4] + base_loc],
            axis=1)
        pred_boxes = fluid.layers.transpose(pred_boxes, [0, 2, 3, 1])
        boxes = fluid.layers.transpose(box_target, [0, 2, 3, 1])
        boxes.stop_gradient = True

        pred_boxes, boxes, mask = self.filter_box_by_weight(pred_boxes, boxes,
                                                            mask)
        mask.stop_gradient = True
        wh_loss = self.wh_loss(
            pred_boxes, boxes, outside_weight=mask, use_transform=False)
        wh_loss = wh_loss / avg_factor

        ttf_loss = {'hm_loss': hm_loss, 'wh_loss': wh_loss}
        return ttf_loss
```
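To make the decode in `get_bboxes` concrete: each top-k peak at feature cell `(x, y)` is scaled back to input coordinates by `down_ratio`, the four `wh` channels are interpreted as left/top/right/bottom distances (already multiplied by `wh_offset_base` in `get_output`), and the box is finally divided by the per-axis scale factors. A NumPy sketch of the same arithmetic with illustrative numbers, not repository code:

```python
import numpy as np

down_ratio = 4
# One peak at feature cell (x=16, y=12) with distances (l, t, r, b).
xs, ys = 16 * down_ratio, 12 * down_ratio          # 64, 48 in input pixels
l, t, r, b = 20., 10., 30., 15.                    # the four wh channels
box = np.array([xs - l, ys - t, xs + r, ys + b])   # [44., 38., 94., 63.]
# Undo the test-time resize: scale_factor is [sx, sy] * 2, matching
# the [x1, y1, x2, y2] layout.
scale_factor = np.array([0.8, 0.8] * 2)
print(box / scale_factor)                          # box in original image coords
```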
ppdet/modeling/architectures/__init__.py

```diff
@@ -27,6 +27,7 @@ from . import blazeface
 from . import faceboxes
 from . import fcos
 from . import cornernet_squeeze
+from . import ttfnet
 
 from .faster_rcnn import *
 from .mask_rcnn import *
@@ -41,3 +42,4 @@ from .blazeface import *
 from .faceboxes import *
 from .fcos import *
 from .cornernet_squeeze import *
+from .ttfnet import *
```
ppdet/modeling/architectures/ttfnet.py (new file, mode 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import OrderedDict

from paddle import fluid

from ppdet.experimental import mixed_precision_global_state
from ppdet.core.workspace import register

__all__ = ['TTFNet']


@register
class TTFNet(object):
    """
    TTFNet network, see https://arxiv.org/abs/1909.00700

    Args:
        backbone (object): backbone instance
        ttf_head (object): `TTFHead` instance
        num_classes (int): the number of classes, 80 by default.
    """

    __category__ = 'architecture'
    __inject__ = ['backbone', 'ttf_head']
    __shared__ = ['num_classes']

    def __init__(self, backbone, ttf_head='TTFHead', num_classes=80):
        super(TTFNet, self).__init__()
        self.backbone = backbone
        self.ttf_head = ttf_head
        self.num_classes = num_classes

    def build(self, feed_vars, mode='train'):
        im = feed_vars['image']

        mixed_precision_enabled = mixed_precision_global_state() is not None
        # cast inputs to FP16
        if mixed_precision_enabled:
            im = fluid.layers.cast(im, 'float16')
        body_feats = self.backbone(im)

        if isinstance(body_feats, OrderedDict):
            body_feat_names = list(body_feats.keys())
            body_feats = [body_feats[name] for name in body_feat_names]

        # cast features back to FP32
        if mixed_precision_enabled:
            body_feats = [fluid.layers.cast(v, 'float32') for v in body_feats]

        predict_hm, predict_wh = self.ttf_head.get_output(
            body_feats, 'ttf_head', is_test=mode == 'test')

        if mode == 'train':
            heatmap = feed_vars['ttf_heatmap']
            box_target = feed_vars['ttf_box_target']
            reg_weight = feed_vars['ttf_reg_weight']
            loss = self.ttf_head.get_loss(predict_hm, predict_wh, heatmap,
                                          box_target, reg_weight)
            total_loss = fluid.layers.sum(list(loss.values()))
            loss.update({'loss': total_loss})
            return loss
        else:
            results = self.ttf_head.get_bboxes(predict_hm, predict_wh,
                                               feed_vars['scale_factor'])
            return results

    def _inputs_def(self, image_shape, downsample):
        im_shape = [None] + image_shape
        H, W = im_shape[2:]
        target_h = None if H is None else H // downsample
        target_w = None if W is None else W // downsample
        # yapf: disable
        inputs_def = {
            'image':          {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0},
            'scale_factor':   {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 0},
            'im_id':          {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0},
            'ttf_heatmap':    {'shape': [None, self.num_classes, target_h, target_w], 'dtype': 'float32', 'lod_level': 0},
            'ttf_box_target': {'shape': [None, 4, target_h, target_w], 'dtype': 'float32', 'lod_level': 0},
            'ttf_reg_weight': {'shape': [None, 1, target_h, target_w], 'dtype': 'float32', 'lod_level': 0},
        }
        # yapf: enable
        return inputs_def

    def build_inputs(self,
                     image_shape=[3, None, None],
                     fields=[
                         'image', 'ttf_heatmap', 'ttf_box_target',
                         'ttf_reg_weight'
                     ],  # for train
                     use_dataloader=True,
                     iterable=False,
                     downsample=4):
        inputs_def = self._inputs_def(image_shape, downsample)
        feed_vars = OrderedDict([(key, fluid.data(
            name=key,
            shape=inputs_def[key]['shape'],
            dtype=inputs_def[key]['dtype'],
            lod_level=inputs_def[key]['lod_level'])) for key in fields])
        loader = fluid.io.DataLoader.from_generator(
            feed_list=list(feed_vars.values()),
            capacity=16,
            use_double_buffer=True,
            iterable=iterable) if use_dataloader else None
        return feed_vars, loader

    def train(self, feed_vars):
        return self.build(feed_vars, mode='train')

    def eval(self, feed_vars):
        return self.build(feed_vars, mode='test')

    def test(self, feed_vars):
        return self.build(feed_vars, mode='test')
```
ppdet/modeling/backbones/darknet.py

```diff
@@ -42,13 +42,15 @@ class DarkNet(object):
                  depth=53,
                  norm_type='bn',
                  norm_decay=0.,
-                 weight_prefix_name=''):
+                 weight_prefix_name='',
+                 freeze_at=-1):
         assert depth in [53], "unsupported depth value"
         self.depth = depth
         self.norm_type = norm_type
         self.norm_decay = norm_decay
         self.depth_cfg = {53: ([1, 2, 8, 8, 4], self.basicblock)}
         self.prefix_name = weight_prefix_name
+        self.freeze_at = freeze_at
 
     def _conv_norm(self,
                    input,
@@ -161,6 +163,8 @@ class DarkNet(object):
                 ch_out=32 * 2**i,
                 count=stage,
                 name=self.prefix_name + "stage.{}".format(i))
+            if i < self.freeze_at:
+                block.stop_gradient = True
             blocks.append(block)
             if i < len(stages) - 1:  # do not downsample in the last stage
                 downsample_ = self._downsample(
```
ppdet/modeling/losses/giou_loss.py

```diff
@@ -33,14 +33,24 @@ class GiouLoss(object):
         loss_weight (float): giou loss weight, default as 10 in faster-rcnn
         is_cls_agnostic (bool): flag of class-agnostic
         num_classes (int): class num
+        do_average (bool): whether to average the loss
+        use_class_weight(bool): whether to use class weight
     '''
     __shared__ = ['num_classes']
 
-    def __init__(self, loss_weight=10., is_cls_agnostic=False, num_classes=81):
+    def __init__(self,
+                 loss_weight=10.,
+                 is_cls_agnostic=False,
+                 num_classes=81,
+                 do_average=True,
+                 use_class_weight=True):
         super(GiouLoss, self).__init__()
         self.loss_weight = loss_weight
         self.is_cls_agnostic = is_cls_agnostic
         self.num_classes = num_classes
+        self.do_average = do_average
+        self.class_weight = 2 if is_cls_agnostic else num_classes
+        self.use_class_weight = use_class_weight
 
     # deltas: NxMx4
     def bbox_transform(self, deltas, weights):
@@ -78,10 +88,15 @@ class GiouLoss(object):
                  y,
                  inside_weight=None,
                  outside_weight=None,
-                 bbox_reg_weight=[0.1, 0.1, 0.2, 0.2]):
+                 bbox_reg_weight=[0.1, 0.1, 0.2, 0.2],
+                 use_transform=True):
         eps = 1.e-10
-        x1, y1, x2, y2 = self.bbox_transform(x, bbox_reg_weight)
-        x1g, y1g, x2g, y2g = self.bbox_transform(y, bbox_reg_weight)
+        if use_transform:
+            x1, y1, x2, y2 = self.bbox_transform(x, bbox_reg_weight)
+            x1g, y1g, x2g, y2g = self.bbox_transform(y, bbox_reg_weight)
+        else:
+            x1, y1, x2, y2 = fluid.layers.split(x, num_or_sections=4, dim=1)
+            x1g, y1g, x2g, y2g = fluid.layers.split(
+                y, num_or_sections=4, dim=1)
 
         x2 = fluid.layers.elementwise_max(x1, x2)
         y2 = fluid.layers.elementwise_max(y1, y2)
@@ -99,9 +114,9 @@
         intsctk = (xkis2 - xkis1) * (ykis2 - ykis1)
         intsctk = intsctk * fluid.layers.greater_than(
             xkis2, xkis1) * fluid.layers.greater_than(ykis2, ykis1)
         unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g
                                                         ) - intsctk + eps
         iouk = intsctk / unionk
 
         area_c = (xc2 - xc1) * (yc2 - yc1) + eps
@@ -116,10 +131,17 @@
             outside_weight = fluid.layers.reduce_mean(outside_weight, dim=1)
             iou_weights = inside_weight * outside_weight
-            class_weight = 2 if self.is_cls_agnostic else self.num_classes
-            iouk = fluid.layers.reduce_mean(
-                (1 - iouk) * iou_weights) * class_weight
-            miouk = fluid.layers.reduce_mean(
-                (1 - miouk) * iou_weights) * class_weight
+        elif outside_weight is not None:
+            iou_weights = outside_weight
+
+        if self.do_average:
+            miouk = fluid.layers.reduce_mean((1 - miouk) * iou_weights)
+        else:
+            iou_distance = fluid.layers.elementwise_mul(
+                1 - miouk, iou_weights, axis=0)
+            miouk = fluid.layers.reduce_sum(iou_distance)
+
+        if self.use_class_weight:
+            miouk = miouk * self.class_weight
 
         return miouk * self.loss_weight
```
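For reference, the GIoU quantity this loss is built on can be written in a few lines of NumPy. This sketch computes GIoU for two axis-aligned `[x1, y1, x2, y2]` boxes and is only an illustration of the formula (the loss above minimizes a weighted `1 - GIoU`), not the fluid implementation:

```python
import numpy as np

def giou(a, b, eps=1e-10):
    """GIoU of two [x1, y1, x2, y2] boxes."""
    # Intersection rectangle and area (zero if the boxes are disjoint).
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(ix2 - ix1, 0) * max(iy2 - iy1, 0)
    union = ((a[2] - a[0]) * (a[3] - a[1]) +
             (b[2] - b[0]) * (b[3] - b[1]) - inter + eps)
    iou = inter / union
    # Smallest enclosing box of a and b.
    cx1, cy1 = min(a[0], b[0]), min(a[1], b[1])
    cx2, cy2 = max(a[2], b[2]), max(a[3], b[3])
    area_c = (cx2 - cx1) * (cy2 - cy1) + eps
    return iou - (area_c - union) / area_c

print(giou([0, 0, 10, 10], [0, 0, 10, 10]))   # ~1.0, perfect overlap
print(giou([0, 0, 10, 10], [20, 0, 30, 10]))  # negative, disjoint boxes
```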
ppdet/modeling/ops.py

```diff
@@ -30,7 +30,8 @@ __all__ = [
     'GenerateProposals', 'MultiClassNMS', 'BBoxAssigner', 'MaskAssigner',
     'RoIAlign', 'RoIPool', 'MultiBoxHead', 'SSDLiteMultiBoxHead',
     'SSDOutputDecoder', 'RetinaTargetAssign', 'RetinaOutputDecoder', 'ConvNorm',
-    'DeformConvNorm', 'MultiClassSoftNMS', 'MatrixNMS', 'LibraBBoxAssigner'
+    'DeformConvNorm', 'MultiClassSoftNMS', 'MatrixNMS', 'LibraBBoxAssigner',
+    'DeformConv'
 ]
@@ -43,36 +44,32 @@ def _conv_offset(input, filter_size, stride, padding, act=None, name=None):
         stride=stride,
         padding=padding,
         param_attr=ParamAttr(
-            initializer=fluid.initializer.Constant(value=0),
-            name=name + ".w_0"),
+            initializer=fluid.initializer.Constant(0), name=name + ".w_0"),
         bias_attr=ParamAttr(
-            initializer=fluid.initializer.Constant(value=0),
+            initializer=fluid.initializer.Constant(0),
             learning_rate=2.,
             regularizer=L2Decay(0.),
             name=name + ".b_0"),
         act=act,
         name=name)
     return out
 
 
-def DeformConvNorm(input,
-                   num_filters,
-                   filter_size,
-                   stride=1,
-                   groups=1,
-                   norm_decay=0.,
-                   norm_type='affine_channel',
-                   norm_groups=32,
-                   dilation=1,
-                   lr_scale=1,
-                   freeze_norm=False,
-                   act=None,
-                   norm_name=None,
-                   initializer=None,
-                   bias_attr=False,
-                   name=None):
+def DeformConv(input,
+               num_filters,
+               filter_size,
+               stride=1,
+               groups=1,
+               dilation=1,
+               lr_scale=1,
+               initializer=None,
+               bias_attr=False,
+               name=None):
     if bias_attr:
         bias_para = ParamAttr(
             name=name + "_bias",
-            initializer=fluid.initializer.Constant(value=0),
+            initializer=fluid.initializer.Constant(0),
             regularizer=L2Decay(0.),
             learning_rate=lr_scale * 2)
     else:
         bias_para = False
@@ -109,6 +106,29 @@ def DeformConvNorm(input,
         bias_attr=bias_para,
         name=name + ".conv2d.output.1")
     return conv
+
+
+def DeformConvNorm(input,
+                   num_filters,
+                   filter_size,
+                   stride=1,
+                   groups=1,
+                   norm_decay=0.,
+                   norm_type='affine_channel',
+                   norm_groups=32,
+                   dilation=1,
+                   lr_scale=1,
+                   freeze_norm=False,
+                   act=None,
+                   norm_name=None,
+                   initializer=None,
+                   bias_attr=False,
+                   name=None):
+    assert norm_type in ['bn', 'sync_bn', 'affine_channel']
+    conv = DeformConv(input, num_filters, filter_size, stride, groups,
+                      dilation, lr_scale, initializer, bias_attr, name)
+
     norm_lr = 0. if freeze_norm else 1.
     pattr = ParamAttr(
         name=norm_name + '_scale',
@@ -330,7 +350,6 @@ class AnchorGenerator(object):
 @serializable
 class AnchorGrid(object):
     """Generate anchor grid
-
     Args:
         image_size (int or list): input image size, may be a single integer or
             list of [h, w]. Default: 512
```
...
ppdet/utils/coco_eval.py
浏览文件 @
d43e6d9a
...
...
@@ -261,6 +261,7 @@ def bbox2out(results, clsid2catid, is_bbox_normalized=False):
for
j
in
range
(
num
):
dt
=
bboxes
[
k
]
clsid
,
score
,
xmin
,
ymin
,
xmax
,
ymax
=
dt
.
tolist
()
if
clsid
<
0
:
continue
catid
=
(
clsid2catid
[
int
(
clsid
)])
if
is_bbox_normalized
:
...
...
ppdet/utils/eval_utils.py

```diff
@@ -161,6 +161,8 @@ def eval_run(exe,
         if 'Corner' in cfg.architecture and post_config is not None:
             from ppdet.utils.post_process import corner_post_process
             corner_post_process(res, post_config, cfg.num_classes)
+        if 'TTFNet' in cfg.architecture:
+            res['bbox'][1].append([len(res['bbox'][0])])
         results.append(res)
         if iter_id % 100 == 0:
             logger.info('Test iter {}'.format(iter_id))
```
...
tools/export_model.py
浏览文件 @
d43e6d9a
...
...
@@ -76,6 +76,8 @@ def parse_reader(reader_cfg, metric, arch):
params
[
'max_size'
]
=
max
(
image_shape
[
1
:])
if
arch
in
scale_set
else
0
params
[
'image_shape'
]
=
image_shape
[
1
:]
if
'target_dim'
in
params
:
params
.
pop
(
'target_dim'
)
p
.
update
(
params
)
preprocess_list
.
append
(
p
)
batch_transforms
=
reader_cfg
.
get
(
'batch_transforms'
,
None
)
...
...
@@ -109,6 +111,7 @@ def dump_infer_config(FLAGS, config):
'RCNN'
:
40
,
'RetinaNet'
:
40
,
'Face'
:
3
,
'TTFNet'
:
3
,
}
infer_arch
=
config
[
'architecture'
]
...
...