Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
76f6c939
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
696
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
76f6c939
编写于
5月 08, 2020
作者:
W
wangguanzhong
提交者:
GitHub
5月 08, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add Yolo v4 (#594)
* add yolov4
上级
80728741
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
1015 addition
and
108 deletion
+1015
-108
configs/yolo/README.md
configs/yolo/README.md
+46
-0
configs/yolo/yolov4_cspdarknet.yml
configs/yolo/yolov4_cspdarknet.yml
+122
-0
configs/yolo/yolov4_cspdarknet_voc.yml
configs/yolo/yolov4_cspdarknet_voc.yml
+182
-0
ppdet/data/source/coco.py
ppdet/data/source/coco.py
+51
-40
ppdet/data/transform/batch_operators.py
ppdet/data/transform/batch_operators.py
+34
-4
ppdet/data/transform/operators.py
ppdet/data/transform/operators.py
+1
-1
ppdet/modeling/anchor_heads/yolo_head.py
ppdet/modeling/anchor_heads/yolo_head.py
+200
-8
ppdet/modeling/architectures/__init__.py
ppdet/modeling/architectures/__init__.py
+2
-2
ppdet/modeling/architectures/yolo.py
ppdet/modeling/architectures/yolo.py
+26
-2
ppdet/modeling/backbones/__init__.py
ppdet/modeling/backbones/__init__.py
+2
-0
ppdet/modeling/backbones/cspdarknet.py
ppdet/modeling/backbones/cspdarknet.py
+212
-0
ppdet/modeling/losses/iou_aware_loss.py
ppdet/modeling/losses/iou_aware_loss.py
+5
-2
ppdet/modeling/losses/iou_loss.py
ppdet/modeling/losses/iou_loss.py
+60
-24
ppdet/modeling/losses/yolo_loss.py
ppdet/modeling/losses/yolo_loss.py
+28
-10
ppdet/utils/coco_eval.py
ppdet/utils/coco_eval.py
+30
-9
ppdet/utils/eval_utils.py
ppdet/utils/eval_utils.py
+6
-3
ppdet/utils/voc_eval.py
ppdet/utils/voc_eval.py
+0
-1
tools/eval.py
tools/eval.py
+4
-2
tools/train.py
tools/train.py
+4
-0
未找到文件。
configs/yolo/README.md
0 → 100644
浏览文件 @
76f6c939
# YOLO v4
## 内容
-
[
简介
](
#简介
)
-
[
模型库与基线
](
#模型库与基线
)
-
[
未来工作
](
#未来工作
)
-
[
如何贡献代码
](
#如何贡献代码
)
## 简介
[
YOLO v4
](
https://arxiv.org/abs/2004.10934
)
的Paddle实现版本
目前PaddleDetection中转换了
[
darknet
](
https://github.com/AlexeyAB/darknet
)
中YOLO v4的权重,可以直接对图片进行预测,在
[
test-dev2019
](
http://cocodataset.org/#detection-2019
)
中精度为43.5%。另外,PaddleDetection支持VOC数据集上finetune,精度达到86.0%
PaddleDetection支持YOLO v4的多个模块:
-
mish激活函数
-
PAN模块
-
SPP模块
-
ciou loss
-
label_smooth
## 模型库
下表中展示了PaddleDetection当前支持的网络结构。
| | GPU个数 | 测试集 | 骨干网络 | 精度 | 模型下载 | 配置文件 |
|:------------------------:|:-------:|:------:|:--------------------------:|:------------------------:| :---------:| :-----: |
| YOLO v4 | - |test-dev2019 | CSPDarkNet53 | 43.5 |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/yolov4_cspdarknet.pdparams
)
|
[
配置文件
](
https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/yolo/yolov4_cspdarknet.yml
)
|
| YOLO v4 VOC | 2 | VOC2007 | CSPDarkNet53 | - |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/yolov4_cspdarknet_voc.pdparams
)
|
[
配置文件
](
https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/yolo/yolov4_cspdarknet_voc.yml
)
|
**注意:**
-
由于原版YOLO v4使用coco trainval2014进行训练,训练样本中包含部分评估样本,若使用val集会导致精度虚高,因此使用coco test集对模型进行评估。
-
YOLO v4模型仅支持coco test集评估和图片预测,由于test集不包含目标框的真实标注,评估时会将预测结果保存在json文件中,请将结果提交至
[
cocodataset
](
http://cocodataset.org/#detection-2019
)
上查看最终精度指标。
-
coco测试集使用test2017,下载请参考
[
coco2017
](
http://cocodataset.org/#download
)
## 未来工作
1.
mish激活函数优化
2.
mosaic数据预处理实现
3.
scale
\_
x
\_
y为yolo_box中decode时对box的位置进行微调,该功能将在Paddle2.0版本中实现
## 如何贡献代码
我们非常欢迎您可以为PaddleDetection提供代码,您可以提交PR供我们review;也十分感谢您的反馈,可以提交相应issue,我们会及时解答。
configs/yolo/yolov4_cspdarknet.yml
0 → 100644
浏览文件 @
76f6c939
architecture
:
YOLOv4
use_gpu
:
true
max_iters
:
500200
log_smooth_window
:
20
save_dir
:
output
snapshot_iter
:
10000
metric
:
COCO
pretrain_weights
:
https://paddlemodels.bj.bcebos.com/object_detection/yolov4_cspdarknet.pdparams
weights
:
output/yolov4_cspdarknet/model_final
num_classes
:
80
use_fine_grained_loss
:
true
save_prediction_only
:
True
YOLOv4
:
backbone
:
CSPDarkNet
yolo_head
:
YOLOv4Head
CSPDarkNet
:
norm_type
:
sync_bn
norm_decay
:
0.
depth
:
53
YOLOv4Head
:
anchors
:
[[
12
,
16
],
[
19
,
36
],
[
40
,
28
],
[
36
,
75
],
[
76
,
55
],
[
72
,
146
],
[
142
,
110
],
[
192
,
243
],
[
459
,
401
]]
anchor_masks
:
[[
0
,
1
,
2
],
[
3
,
4
,
5
],
[
6
,
7
,
8
]]
nms
:
background_label
:
-1
keep_top_k
:
-1
nms_threshold
:
0.45
nms_top_k
:
-1
normalized
:
true
score_threshold
:
0.001
downsample
:
[
8
,
16
,
32
]
YOLOv3Loss
:
# batch_size here is only used for fine grained loss, not used
# for training batch_size setting, training batch_size setting
# is in configs/yolov3_reader.yml TrainReader.batch_size, batch
# size here should be set as same value as TrainReader.batch_size
batch_size
:
4
ignore_thresh
:
0.7
label_smooth
:
true
downsample
:
[
8
,
16
,
32
]
#scale_x_y: [1.2, 1.1, 1.05]
iou_loss
:
IouLoss
match_score
:
true
IouLoss
:
loss_weight
:
0.07
max_height
:
608
max_width
:
608
ciou_term
:
true
loss_square
:
false
LearningRate
:
base_lr
:
0.0001
schedulers
:
-
!PiecewiseDecay
gamma
:
0.1
milestones
:
-
400000
-
450000
-
!LinearWarmup
start_factor
:
0.
steps
:
1000
OptimizerBuilder
:
clip_grad_by_norm
:
10.
optimizer
:
momentum
:
0.949
type
:
Momentum
regularizer
:
factor
:
0.0005
type
:
L2
_READER_
:
'
../yolov3_reader.yml'
EvalReader
:
inputs_def
:
fields
:
[
'
image'
,
'
im_size'
,
'
im_id'
]
num_max_boxes
:
90
dataset
:
!COCODataSet
image_dir
:
test2017
anno_path
:
annotations/image_info_test-dev2017.json
dataset_dir
:
data/coco
with_background
:
false
sample_transforms
:
-
!DecodeImage
to_rgb
:
True
-
!ResizeImage
target_size
:
608
interp
:
1
-
!NormalizeImage
mean
:
[
0.
,
0.
,
0.
]
std
:
[
1.
,
1.
,
1.
]
is_scale
:
True
is_channel_first
:
false
-
!Permute
to_bgr
:
false
channel_first
:
True
batch_size
:
1
TestReader
:
dataset
:
!ImageFolder
use_default_label
:
true
with_background
:
false
sample_transforms
:
-
!DecodeImage
to_rgb
:
True
-
!ResizeImage
target_size
:
608
interp
:
1
-
!NormalizeImage
mean
:
[
0.
,
0.
,
0.
]
std
:
[
1.
,
1.
,
1.
]
is_scale
:
True
is_channel_first
:
false
-
!Permute
to_bgr
:
false
channel_first
:
True
configs/yolo/yolov4_cspdarknet_voc.yml
0 → 100644
浏览文件 @
76f6c939
architecture
:
YOLOv4
use_gpu
:
true
max_iters
:
70000
log_smooth_window
:
20
save_dir
:
output
snapshot_iter
:
1000
metric
:
VOC
pretrain_weights
:
https://paddlemodels.bj.bcebos.com/object_detection/yolov4_cspdarknet.pdparams
weights
:
output/yolov4_cspdarknet_voc/model_final
num_classes
:
20
use_fine_grained_loss
:
true
YOLOv4
:
backbone
:
CSPDarkNet
yolo_head
:
YOLOv4Head
CSPDarkNet
:
norm_type
:
sync_bn
norm_decay
:
0.
depth
:
53
YOLOv4Head
:
anchors
:
[[
12
,
16
],
[
19
,
36
],
[
40
,
28
],
[
36
,
75
],
[
76
,
55
],
[
72
,
146
],
[
142
,
110
],
[
192
,
243
],
[
459
,
401
]]
anchor_masks
:
[[
0
,
1
,
2
],
[
3
,
4
,
5
],
[
6
,
7
,
8
]]
nms
:
background_label
:
-1
keep_top_k
:
-1
nms_threshold
:
0.45
nms_top_k
:
-1
normalized
:
true
score_threshold
:
0.001
downsample
:
[
8
,
16
,
32
]
YOLOv3Loss
:
# batch_size here is only used for fine grained loss, not used
# for training batch_size setting, training batch_size setting
# is in configs/yolov3_reader.yml TrainReader.batch_size, batch
# size here should be set as same value as TrainReader.batch_size
batch_size
:
4
ignore_thresh
:
0.7
label_smooth
:
true
downsample
:
[
8
,
16
,
32
]
#scale_x_y: [1.2, 1.1, 1.05]
iou_loss
:
IouLoss
match_score
:
true
IouLoss
:
loss_weight
:
0.07
max_height
:
608
max_width
:
608
ciou_term
:
true
loss_square
:
true
LearningRate
:
base_lr
:
0.0001
schedulers
:
-
!PiecewiseDecay
gamma
:
0.1
milestones
:
-
56000
-
63000
-
!LinearWarmup
start_factor
:
0.
steps
:
1000
OptimizerBuilder
:
clip_grad_by_norm
:
10.
optimizer
:
momentum
:
0.949
type
:
Momentum
regularizer
:
factor
:
0.0005
type
:
L2
_READER_
:
'
../yolov3_reader.yml'
TrainReader
:
inputs_def
:
fields
:
[
'
image'
,
'
gt_bbox'
,
'
gt_class'
,
'
gt_score'
,
'
im_id'
]
num_max_boxes
:
50
dataset
:
!VOCDataSet
anno_path
:
trainval.txt
dataset_dir
:
dataset/voc
with_background
:
false
sample_transforms
:
-
!DecodeImage
to_rgb
:
True
with_mixup
:
True
-
!MixupImage
alpha
:
1.5
beta
:
1.5
-
!ColorDistort
{}
-
!RandomExpand
fill_value
:
[
123.675
,
116.28
,
103.53
]
-
!RandomCrop
{}
-
!RandomFlipImage
is_normalized
:
false
-
!NormalizeBox
{}
-
!PadBox
num_max_boxes
:
50
-
!BboxXYXY2XYWH
{}
batch_transforms
:
-
!RandomShape
sizes
:
[
320
,
352
,
384
,
416
,
448
,
480
,
512
,
544
,
576
,
608
]
random_inter
:
True
-
!NormalizeImage
mean
:
[
0.
,
0.
,
0.
]
std
:
[
1.
,
1.
,
1.
]
is_scale
:
True
is_channel_first
:
false
-
!Permute
to_bgr
:
false
channel_first
:
True
# Gt2YoloTarget is only used when use_fine_grained_loss set as true,
# this operator will be deleted automatically if use_fine_grained_loss
# is set as false
-
!Gt2YoloTarget
anchor_masks
:
[[
0
,
1
,
2
],
[
3
,
4
,
5
],
[
6
,
7
,
8
]]
anchors
:
[[
12
,
16
],
[
19
,
36
],
[
40
,
28
],
[
36
,
75
],
[
76
,
55
],
[
72
,
146
],
[
142
,
110
],
[
192
,
243
],
[
459
,
401
]]
downsample_ratios
:
[
8
,
16
,
32
]
batch_size
:
4
shuffle
:
true
mixup_epoch
:
28
drop_last
:
true
worker_num
:
8
bufsize
:
16
use_process
:
true
drop_empty
:
false
EvalReader
:
inputs_def
:
fields
:
[
'
image'
,
'
im_size'
,
'
im_id'
,
'
gt_bbox'
,
'
gt_class'
,
'
is_difficult'
]
num_max_boxes
:
90
dataset
:
!VOCDataSet
anno_path
:
test.txt
dataset_dir
:
dataset/voc
use_default_label
:
true
with_background
:
false
sample_transforms
:
-
!DecodeImage
to_rgb
:
True
-
!ResizeImage
target_size
:
608
interp
:
1
-
!NormalizeImage
mean
:
[
0.
,
0.
,
0.
]
std
:
[
1.
,
1.
,
1.
]
is_scale
:
True
is_channel_first
:
false
-
!PadBox
num_max_boxes
:
90
-
!Permute
to_bgr
:
false
channel_first
:
True
batch_size
:
4
drop_empty
:
false
worker_num
:
8
bufsize
:
16
TestReader
:
dataset
:
!ImageFolder
use_default_label
:
true
with_background
:
false
sample_transforms
:
-
!DecodeImage
to_rgb
:
True
-
!ResizeImage
target_size
:
608
interp
:
1
-
!NormalizeImage
mean
:
[
0.
,
0.
,
0.
]
std
:
[
1.
,
1.
,
1.
]
is_scale
:
True
is_channel_first
:
false
-
!Permute
to_bgr
:
false
channel_first
:
True
ppdet/data/source/coco.py
浏览文件 @
76f6c939
...
...
@@ -67,6 +67,7 @@ class COCODataSet(DataSet):
self
.
roidbs
=
None
# a dict used to map category name to class id
self
.
cname2cid
=
None
self
.
load_image_only
=
False
def
load_roidb_and_cname2cid
(
self
):
anno_path
=
os
.
path
.
join
(
self
.
dataset_dir
,
self
.
anno_path
)
...
...
@@ -92,61 +93,71 @@ class COCODataSet(DataSet):
for
catid
,
clsid
in
catid2clsid
.
items
()
})
if
'annotations'
not
in
coco
.
dataset
:
self
.
load_image_only
=
True
logger
.
warn
(
'Annotation file: {} does not contains ground truth '
'and load image information only.'
.
format
(
anno_path
))
for
img_id
in
img_ids
:
img_anno
=
coco
.
loadImgs
(
img_id
)[
0
]
im_fname
=
img_anno
[
'file_name'
]
im_w
=
float
(
img_anno
[
'width'
])
im_h
=
float
(
img_anno
[
'height'
])
ins_anno_ids
=
coco
.
getAnnIds
(
imgIds
=
img_id
,
iscrowd
=
False
)
instances
=
coco
.
loadAnns
(
ins_anno_ids
)
bboxes
=
[]
for
inst
in
instances
:
x
,
y
,
box_w
,
box_h
=
inst
[
'bbox'
]
x1
=
max
(
0
,
x
)
y1
=
max
(
0
,
y
)
x2
=
min
(
im_w
-
1
,
x1
+
max
(
0
,
box_w
-
1
))
y2
=
min
(
im_h
-
1
,
y1
+
max
(
0
,
box_h
-
1
))
if
inst
[
'area'
]
>
0
and
x2
>=
x1
and
y2
>=
y1
:
inst
[
'clean_bbox'
]
=
[
x1
,
y1
,
x2
,
y2
]
bboxes
.
append
(
inst
)
else
:
logger
.
warn
(
'Found an invalid bbox in annotations: im_id: {}, '
'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'
.
format
(
img_id
,
float
(
inst
[
'area'
]),
x1
,
y1
,
x2
,
y2
))
num_bbox
=
len
(
bboxes
)
gt_bbox
=
np
.
zeros
((
num_bbox
,
4
),
dtype
=
np
.
float32
)
gt_class
=
np
.
zeros
((
num_bbox
,
1
),
dtype
=
np
.
int32
)
gt_score
=
np
.
ones
((
num_bbox
,
1
),
dtype
=
np
.
float32
)
is_crowd
=
np
.
zeros
((
num_bbox
,
1
),
dtype
=
np
.
int32
)
difficult
=
np
.
zeros
((
num_bbox
,
1
),
dtype
=
np
.
int32
)
gt_poly
=
[
None
]
*
num_bbox
for
i
,
box
in
enumerate
(
bboxes
):
catid
=
box
[
'category_id'
]
gt_class
[
i
][
0
]
=
catid2clsid
[
catid
]
gt_bbox
[
i
,
:]
=
box
[
'clean_bbox'
]
is_crowd
[
i
][
0
]
=
box
[
'iscrowd'
]
if
'segmentation'
in
box
:
gt_poly
[
i
]
=
box
[
'segmentation'
]
im_fname
=
os
.
path
.
join
(
image_dir
,
im_fname
)
if
image_dir
else
im_fname
coco_rec
=
{
'im_file'
:
im_fname
,
'im_id'
:
np
.
array
([
img_id
]),
'h'
:
im_h
,
'w'
:
im_w
,
'is_crowd'
:
is_crowd
,
'gt_class'
:
gt_class
,
'gt_bbox'
:
gt_bbox
,
'gt_score'
:
gt_score
,
'gt_poly'
:
gt_poly
,
}
if
not
self
.
load_image_only
:
ins_anno_ids
=
coco
.
getAnnIds
(
imgIds
=
img_id
,
iscrowd
=
False
)
instances
=
coco
.
loadAnns
(
ins_anno_ids
)
bboxes
=
[]
for
inst
in
instances
:
x
,
y
,
box_w
,
box_h
=
inst
[
'bbox'
]
x1
=
max
(
0
,
x
)
y1
=
max
(
0
,
y
)
x2
=
min
(
im_w
-
1
,
x1
+
max
(
0
,
box_w
-
1
))
y2
=
min
(
im_h
-
1
,
y1
+
max
(
0
,
box_h
-
1
))
if
inst
[
'area'
]
>
0
and
x2
>=
x1
and
y2
>=
y1
:
inst
[
'clean_bbox'
]
=
[
x1
,
y1
,
x2
,
y2
]
bboxes
.
append
(
inst
)
else
:
logger
.
warn
(
'Found an invalid bbox in annotations: im_id: {}, '
'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'
.
format
(
img_id
,
float
(
inst
[
'area'
]),
x1
,
y1
,
x2
,
y2
))
num_bbox
=
len
(
bboxes
)
gt_bbox
=
np
.
zeros
((
num_bbox
,
4
),
dtype
=
np
.
float32
)
gt_class
=
np
.
zeros
((
num_bbox
,
1
),
dtype
=
np
.
int32
)
gt_score
=
np
.
ones
((
num_bbox
,
1
),
dtype
=
np
.
float32
)
is_crowd
=
np
.
zeros
((
num_bbox
,
1
),
dtype
=
np
.
int32
)
difficult
=
np
.
zeros
((
num_bbox
,
1
),
dtype
=
np
.
int32
)
gt_poly
=
[
None
]
*
num_bbox
for
i
,
box
in
enumerate
(
bboxes
):
catid
=
box
[
'category_id'
]
gt_class
[
i
][
0
]
=
catid2clsid
[
catid
]
gt_bbox
[
i
,
:]
=
box
[
'clean_bbox'
]
is_crowd
[
i
][
0
]
=
box
[
'iscrowd'
]
if
'segmentation'
in
box
:
gt_poly
[
i
]
=
box
[
'segmentation'
]
coco_rec
.
update
({
'is_crowd'
:
is_crowd
,
'gt_class'
:
gt_class
,
'gt_bbox'
:
gt_bbox
,
'gt_score'
:
gt_score
,
'gt_poly'
:
gt_poly
,
})
logger
.
debug
(
'Load file: {}, im_id: {}, h: {}, w: {}.'
.
format
(
im_fname
,
img_id
,
im_h
,
im_w
))
records
.
append
(
coco_rec
)
...
...
ppdet/data/transform/batch_operators.py
浏览文件 @
76f6c939
...
...
@@ -179,13 +179,18 @@ class Gt2YoloTarget(BaseOperator):
fine grained YOLOv3 loss mode
"""
def
__init__
(
self
,
anchors
,
anchor_masks
,
downsample_ratios
,
num_classes
=
80
):
def
__init__
(
self
,
anchors
,
anchor_masks
,
downsample_ratios
,
num_classes
=
80
,
iou_thresh
=
1.
):
super
(
Gt2YoloTarget
,
self
).
__init__
()
self
.
anchors
=
anchors
self
.
anchor_masks
=
anchor_masks
self
.
downsample_ratios
=
downsample_ratios
self
.
num_classes
=
num_classes
self
.
iou_thresh
=
iou_thresh
def
__call__
(
self
,
samples
,
context
=
None
):
assert
len
(
self
.
anchor_masks
)
==
len
(
self
.
downsample_ratios
),
\
...
...
@@ -225,12 +230,13 @@ class Gt2YoloTarget(BaseOperator):
best_iou
=
iou
best_idx
=
an_idx
gi
=
int
(
gx
*
grid_w
)
gj
=
int
(
gy
*
grid_h
)
# gtbox should be regresed in this layes if best match
# anchor index in anchor mask of this layer
if
best_idx
in
mask
:
best_n
=
mask
.
index
(
best_idx
)
gi
=
int
(
gx
*
grid_w
)
gj
=
int
(
gy
*
grid_h
)
# x, y, w, h, scale
target
[
best_n
,
0
,
gj
,
gi
]
=
gx
*
grid_w
-
gi
...
...
@@ -246,6 +252,30 @@ class Gt2YoloTarget(BaseOperator):
# classification
target
[
best_n
,
6
+
cls
,
gj
,
gi
]
=
1.
# For non-matched anchors, calculate the target if the iou
# between anchor and gt is larger than iou_thresh
if
self
.
iou_thresh
<
1
:
for
idx
,
mask_i
in
enumerate
(
mask
):
if
mask_i
==
best_idx
:
continue
iou
=
jaccard_overlap
(
[
0.
,
0.
,
gw
,
gh
],
[
0.
,
0.
,
an_hw
[
mask_i
,
0
],
an_hw
[
mask_i
,
1
]])
if
iou
>
self
.
iou_thresh
:
# x, y, w, h, scale
target
[
idx
,
0
,
gj
,
gi
]
=
gx
*
grid_w
-
gi
target
[
idx
,
1
,
gj
,
gi
]
=
gy
*
grid_h
-
gj
target
[
idx
,
2
,
gj
,
gi
]
=
np
.
log
(
gw
*
w
/
self
.
anchors
[
mask_i
][
0
])
target
[
idx
,
3
,
gj
,
gi
]
=
np
.
log
(
gh
*
h
/
self
.
anchors
[
mask_i
][
1
])
target
[
idx
,
4
,
gj
,
gi
]
=
2.0
-
gw
*
gh
# objectness record gt_score
target
[
idx
,
5
,
gj
,
gi
]
=
score
# classification
target
[
idx
,
6
+
cls
,
gj
,
gi
]
=
1.
sample
[
'target{}'
.
format
(
i
)]
=
target
return
samples
...
...
ppdet/data/transform/operators.py
浏览文件 @
76f6c939
...
...
@@ -114,6 +114,7 @@ class DecodeImage(BaseOperator):
im
=
sample
[
'image'
]
data
=
np
.
frombuffer
(
im
,
dtype
=
'uint8'
)
im
=
cv2
.
imdecode
(
data
,
1
)
# BGR mode, but need RGB mode
if
self
.
to_rgb
:
im
=
cv2
.
cvtColor
(
im
,
cv2
.
COLOR_BGR2RGB
)
sample
[
'image'
]
=
im
...
...
@@ -331,7 +332,6 @@ class ResizeImage(BaseOperator):
im
=
Image
.
fromarray
(
im
)
im
=
im
.
resize
((
int
(
resize_w
),
int
(
resize_h
)),
self
.
interp
)
im
=
np
.
array
(
im
)
sample
[
'image'
]
=
im
return
sample
...
...
ppdet/modeling/anchor_heads/yolo_head.py
浏览文件 @
76f6c939
...
...
@@ -25,8 +25,12 @@ from ppdet.modeling.losses.yolo_loss import YOLOv3Loss
from
ppdet.core.workspace
import
register
from
ppdet.modeling.ops
import
DropBlock
from
.iou_aware
import
get_iou_aware_score
try
:
from
collections.abc
import
Sequence
except
Exception
:
from
collections
import
Sequence
__all__
=
[
'YOLOv3Head'
]
__all__
=
[
'YOLOv3Head'
,
'YOLOv4Head'
]
@
register
...
...
@@ -62,7 +66,9 @@ class YOLOv3Head(object):
keep_top_k
=
100
,
nms_threshold
=
0.45
,
background_label
=-
1
).
__dict__
,
weight_prefix_name
=
''
):
weight_prefix_name
=
''
,
downsample
=
[
32
,
16
,
8
],
scale_x_y
=
1.0
):
self
.
norm_decay
=
norm_decay
self
.
num_classes
=
num_classes
self
.
anchor_masks
=
anchor_masks
...
...
@@ -77,6 +83,9 @@ class YOLOv3Head(object):
self
.
keep_prob
=
keep_prob
if
isinstance
(
nms
,
dict
):
self
.
nms
=
MultiClassNMS
(
**
nms
)
self
.
downsample
=
downsample
# TODO(guanzhong) activate scale_x_y in Paddle 2.0
#self.scale_x_y = scale_x_y
def
_conv_bn
(
self
,
input
,
...
...
@@ -105,7 +114,6 @@ class YOLOv3Head(object):
out
=
fluid
.
layers
.
batch_norm
(
input
=
conv
,
act
=
None
,
is_test
=
is_test
,
param_attr
=
bn_param_attr
,
bias_attr
=
bn_bias_attr
,
moving_mean_name
=
bn_name
+
'.mean'
,
...
...
@@ -301,27 +309,211 @@ class YOLOv3Head(object):
boxes
=
[]
scores
=
[]
downsample
=
32
for
i
,
output
in
enumerate
(
outputs
):
if
self
.
iou_aware
:
output
=
get_iou_aware_score
(
output
,
len
(
self
.
anchor_masks
[
i
]),
self
.
num_classes
,
self
.
iou_aware_factor
)
#scale_x_y = self.scale_x_y if not isinstance(
# self.scale_x_y, Sequence) else self.scale_x_y[i]
box
,
score
=
fluid
.
layers
.
yolo_box
(
x
=
output
,
img_size
=
im_size
,
anchors
=
self
.
mask_anchors
[
i
],
class_num
=
self
.
num_classes
,
conf_thresh
=
self
.
nms
.
score_threshold
,
downsample_ratio
=
downsample
,
name
=
self
.
prefix_name
+
"yolo_box"
+
str
(
i
))
downsample_ratio
=
self
.
downsample
[
i
],
name
=
self
.
prefix_name
+
"yolo_box"
+
str
(
i
),
clip_bbox
=
False
)
boxes
.
append
(
box
)
scores
.
append
(
fluid
.
layers
.
transpose
(
score
,
perm
=
[
0
,
2
,
1
]))
downsample
//=
2
yolo_boxes
=
fluid
.
layers
.
concat
(
boxes
,
axis
=
1
)
yolo_scores
=
fluid
.
layers
.
concat
(
scores
,
axis
=
2
)
pred
=
self
.
nms
(
bboxes
=
yolo_boxes
,
scores
=
yolo_scores
)
return
{
'bbox'
:
pred
}
@
register
class
YOLOv4Head
(
YOLOv3Head
):
"""
Head block for YOLOv4 network
Args:
anchors (list): anchors
anchor_masks (list): anchor masks
nms (object): an instance of `MultiClassNMS`
spp_stage (int): apply spp on which stage.
num_classes (int): number of output classes
downsample (list): downsample ratio for each yolo_head
scale_x_y (list): scale the left top point of bbox at each stage
"""
__inject__
=
[
'nms'
,
'yolo_loss'
]
__shared__
=
[
'num_classes'
,
'weight_prefix_name'
]
def
__init__
(
self
,
anchors
=
[[
12
,
16
],
[
19
,
36
],
[
40
,
28
],
[
36
,
75
],
[
76
,
55
],
[
72
,
146
],
[
142
,
110
],
[
192
,
243
],
[
459
,
401
]],
anchor_masks
=
[[
0
,
1
,
2
],
[
3
,
4
,
5
],
[
6
,
7
,
8
]],
nms
=
MultiClassNMS
(
score_threshold
=
0.01
,
nms_top_k
=-
1
,
keep_top_k
=-
1
,
nms_threshold
=
0.45
,
background_label
=-
1
).
__dict__
,
spp_stage
=
5
,
num_classes
=
80
,
weight_prefix_name
=
''
,
downsample
=
[
8
,
16
,
32
],
scale_x_y
=
[
1.2
,
1.1
,
1.05
],
yolo_loss
=
"YOLOv3Loss"
,
iou_aware
=
False
,
iou_aware_factor
=
0.4
,
):
super
(
YOLOv4Head
,
self
).
__init__
(
anchors
=
anchors
,
anchor_masks
=
anchor_masks
,
nms
=
nms
,
num_classes
=
num_classes
,
weight_prefix_name
=
weight_prefix_name
,
downsample
=
downsample
,
scale_x_y
=
scale_x_y
,
yolo_loss
=
yolo_loss
,
iou_aware
=
iou_aware
,
iou_aware_factor
=
iou_aware_factor
)
self
.
spp_stage
=
spp_stage
def
_upsample
(
self
,
input
,
scale
=
2
,
name
=
None
):
out
=
fluid
.
layers
.
resize_nearest
(
input
=
input
,
scale
=
float
(
scale
),
name
=
name
)
return
out
def
max_pool
(
self
,
input
,
size
):
pad
=
[(
size
-
1
)
//
2
]
*
2
return
fluid
.
layers
.
pool2d
(
input
,
size
,
'max'
,
pool_padding
=
pad
)
def
spp
(
self
,
input
):
branch_a
=
self
.
max_pool
(
input
,
13
)
branch_b
=
self
.
max_pool
(
input
,
9
)
branch_c
=
self
.
max_pool
(
input
,
5
)
out
=
fluid
.
layers
.
concat
([
branch_a
,
branch_b
,
branch_c
,
input
],
axis
=
1
)
return
out
def
stack_conv
(
self
,
input
,
ch_list
=
[
512
,
1024
,
512
],
filter_list
=
[
1
,
3
,
1
],
stride
=
1
,
name
=
None
):
conv
=
input
for
i
,
(
ch_out
,
f_size
)
in
enumerate
(
zip
(
ch_list
,
filter_list
)):
padding
=
1
if
f_size
==
3
else
0
conv
=
self
.
_conv_bn
(
conv
,
ch_out
=
ch_out
,
filter_size
=
f_size
,
stride
=
stride
,
padding
=
padding
,
name
=
'{}.{}'
.
format
(
name
,
i
))
return
conv
def
spp_module
(
self
,
input
,
name
=
None
):
conv
=
self
.
stack_conv
(
input
,
name
=
name
+
'.stack_conv.0'
)
spp_out
=
self
.
spp
(
conv
)
conv
=
self
.
stack_conv
(
spp_out
,
name
=
name
+
'.stack_conv.1'
)
return
conv
def
pan_module
(
self
,
input
,
filter_list
,
name
=
None
):
for
i
in
range
(
1
,
len
(
input
)):
ch_out
=
input
[
i
].
shape
[
1
]
//
2
conv_left
=
self
.
_conv_bn
(
input
[
i
],
ch_out
=
ch_out
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
name
=
name
+
'.{}.left'
.
format
(
i
))
ch_out
=
input
[
i
-
1
].
shape
[
1
]
//
2
conv_right
=
self
.
_conv_bn
(
input
[
i
-
1
],
ch_out
=
ch_out
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
name
=
name
+
'.{}.right'
.
format
(
i
))
conv_right
=
self
.
_upsample
(
conv_right
)
pan_out
=
fluid
.
layers
.
concat
([
conv_left
,
conv_right
],
axis
=
1
)
ch_list
=
[
pan_out
.
shape
[
1
]
//
2
*
k
for
k
in
[
1
,
2
,
1
,
2
,
1
]]
input
[
i
]
=
self
.
stack_conv
(
pan_out
,
ch_list
=
ch_list
,
filter_list
=
filter_list
,
name
=
name
+
'.stack_conv.{}'
.
format
(
i
))
return
input
def
_get_outputs
(
self
,
input
,
is_train
=
True
):
outputs
=
[]
filter_list
=
[
1
,
3
,
1
,
3
,
1
]
spp_stage
=
len
(
input
)
-
self
.
spp_stage
# get last out_layer_num blocks in reverse order
out_layer_num
=
len
(
self
.
anchor_masks
)
blocks
=
input
[
-
1
:
-
out_layer_num
-
1
:
-
1
]
blocks
[
spp_stage
]
=
self
.
spp_module
(
blocks
[
spp_stage
],
name
=
self
.
prefix_name
+
"spp_module"
)
blocks
=
self
.
pan_module
(
blocks
,
filter_list
=
filter_list
,
name
=
self
.
prefix_name
+
'pan_module'
)
# reverse order back to input
blocks
=
blocks
[::
-
1
]
route
=
None
for
i
,
block
in
enumerate
(
blocks
):
if
i
>
0
:
# perform concat in first 2 detection_block
route
=
self
.
_conv_bn
(
route
,
ch_out
=
route
.
shape
[
1
]
*
2
,
filter_size
=
3
,
stride
=
2
,
padding
=
1
,
name
=
self
.
prefix_name
+
'yolo_block.route.{}'
.
format
(
i
))
block
=
fluid
.
layers
.
concat
(
input
=
[
route
,
block
],
axis
=
1
)
ch_list
=
[
block
.
shape
[
1
]
//
2
*
k
for
k
in
[
1
,
2
,
1
,
2
,
1
]]
block
=
self
.
stack_conv
(
block
,
ch_list
=
ch_list
,
filter_list
=
filter_list
,
name
=
self
.
prefix_name
+
'yolo_block.stack_conv.{}'
.
format
(
i
))
route
=
block
block_out
=
self
.
_conv_bn
(
block
,
ch_out
=
block
.
shape
[
1
]
*
2
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
name
=
self
.
prefix_name
+
'yolo_output.{}.conv.0'
.
format
(
i
))
if
self
.
iou_aware
:
num_filters
=
len
(
self
.
anchor_masks
[
i
])
*
(
self
.
num_classes
+
6
)
else
:
num_filters
=
len
(
self
.
anchor_masks
[
i
])
*
(
self
.
num_classes
+
5
)
block_out
=
fluid
.
layers
.
conv2d
(
input
=
block_out
,
num_filters
=
num_filters
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
self
.
prefix_name
+
"yolo_output.{}.conv.1.weights"
.
format
(
i
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.
),
name
=
self
.
prefix_name
+
"yolo_output.{}.conv.1.bias"
.
format
(
i
)))
outputs
.
append
(
block_out
)
return
outputs
ppdet/modeling/architectures/__init__.py
浏览文件 @
76f6c939
...
...
@@ -19,7 +19,7 @@ from . import mask_rcnn
from
.
import
cascade_rcnn
from
.
import
cascade_mask_rcnn
from
.
import
cascade_rcnn_cls_aware
from
.
import
yolo
v3
from
.
import
yolo
from
.
import
ssd
from
.
import
retinanet
from
.
import
efficientdet
...
...
@@ -33,7 +33,7 @@ from .mask_rcnn import *
from
.cascade_rcnn
import
*
from
.cascade_mask_rcnn
import
*
from
.cascade_rcnn_cls_aware
import
*
from
.yolo
v3
import
*
from
.yolo
import
*
from
.ssd
import
*
from
.retinanet
import
*
from
.efficientdet
import
*
...
...
ppdet/modeling/architectures/yolo
v3
.py
→
ppdet/modeling/architectures/yolo.py
浏览文件 @
76f6c939
...
...
@@ -23,7 +23,7 @@ from paddle import fluid
from
ppdet.experimental
import
mixed_precision_global_state
from
ppdet.core.workspace
import
register
__all__
=
[
'YOLOv3'
]
__all__
=
[
'YOLOv3'
,
'YOLOv4'
]
@
register
...
...
@@ -42,7 +42,7 @@ class YOLOv3(object):
def
__init__
(
self
,
backbone
,
yolo_head
=
'YOLOv
3
Head'
,
yolo_head
=
'YOLOv
4
Head'
,
use_fine_grained_loss
=
False
):
super
(
YOLOv3
,
self
).
__init__
()
self
.
backbone
=
backbone
...
...
@@ -160,3 +160,27 @@ class YOLOv3(object):
def
test
(
self
,
feed_vars
):
return
self
.
build
(
feed_vars
,
mode
=
'test'
)
@
register
class
YOLOv4
(
YOLOv3
):
"""
YOLOv4 network, see https://arxiv.org/abs/2004.10934
Args:
backbone (object): an backbone instance
yolo_head (object): an `YOLOv4Head` instance
"""
__category__
=
'architecture'
__inject__
=
[
'backbone'
,
'yolo_head'
]
__shared__
=
[
'use_fine_grained_loss'
]
def
__init__
(
self
,
backbone
,
yolo_head
=
'YOLOv4Head'
,
use_fine_grained_loss
=
False
):
super
(
YOLOv4
,
self
).
__init__
(
backbone
=
backbone
,
yolo_head
=
yolo_head
,
use_fine_grained_loss
=
use_fine_grained_loss
)
ppdet/modeling/backbones/__init__.py
浏览文件 @
76f6c939
...
...
@@ -32,6 +32,7 @@ from . import bfp
from
.
import
hourglass
from
.
import
efficientnet
from
.
import
bifpn
from
.
import
cspdarknet
from
.resnet
import
*
from
.resnext
import
*
...
...
@@ -51,3 +52,4 @@ from .bfp import *
from
.hourglass
import
*
from
.efficientnet
import
*
from
.bifpn
import
*
from
.cspdarknet
import
*
ppdet/modeling/backbones/cspdarknet.py
0 → 100644
浏览文件 @
76f6c939
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
six
from
paddle
import
fluid
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.regularizer
import
L2Decay
from
ppdet.core.workspace
import
register
__all__
=
[
'CSPDarkNet'
]
@
register
class
CSPDarkNet
(
object
):
"""
CSPDarkNet, see https://arxiv.org/abs/1911.11929
Args:
depth (int): network depth, currently only cspdarknet 53 is supported
norm_type (str): normalization type, 'bn' and 'sync_bn' are supported
norm_decay (float): weight decay for normalization layer weights
"""
__shared__
=
[
'norm_type'
,
'weight_prefix_name'
]
def
__init__
(
self
,
depth
=
53
,
norm_type
=
'bn'
,
norm_decay
=
0.
,
weight_prefix_name
=
''
):
assert
depth
in
[
53
],
"unsupported depth value"
self
.
depth
=
depth
self
.
norm_type
=
norm_type
self
.
norm_decay
=
norm_decay
self
.
depth_cfg
=
{
53
:
([
1
,
2
,
8
,
8
,
4
],
self
.
basicblock
)}
self
.
prefix_name
=
weight_prefix_name
def
_softplus
(
self
,
input
):
expf
=
fluid
.
layers
.
exp
(
fluid
.
layers
.
clip
(
input
,
-
200
,
50
))
return
fluid
.
layers
.
log
(
1
+
expf
)
def
_mish
(
self
,
input
):
return
input
*
fluid
.
layers
.
tanh
(
self
.
_softplus
(
input
))
def
_conv_norm
(
self
,
input
,
ch_out
,
filter_size
,
stride
,
padding
,
act
=
'mish'
,
name
=
None
):
conv
=
fluid
.
layers
.
conv2d
(
input
=
input
,
num_filters
=
ch_out
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
padding
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
name
+
".conv.weights"
),
bias_attr
=
False
)
bn_name
=
name
+
".bn"
bn_param_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
float
(
self
.
norm_decay
)),
name
=
bn_name
+
'.scale'
)
bn_bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
float
(
self
.
norm_decay
)),
name
=
bn_name
+
'.offset'
)
out
=
fluid
.
layers
.
batch_norm
(
input
=
conv
,
act
=
None
,
param_attr
=
bn_param_attr
,
bias_attr
=
bn_bias_attr
,
moving_mean_name
=
bn_name
+
'.mean'
,
moving_variance_name
=
bn_name
+
'.var'
)
if
act
==
'mish'
:
out
=
self
.
_mish
(
out
)
return
out
def
_downsample
(
self
,
input
,
ch_out
,
filter_size
=
3
,
stride
=
2
,
padding
=
1
,
name
=
None
):
return
self
.
_conv_norm
(
input
,
ch_out
=
ch_out
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
padding
,
name
=
name
)
def
conv_layer
(
self
,
input
,
ch_out
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
name
=
None
):
return
self
.
_conv_norm
(
input
,
ch_out
=
ch_out
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
padding
,
name
=
name
)
def
basicblock
(
self
,
input
,
ch_out
,
scale_first
=
False
,
name
=
None
):
conv1
=
self
.
_conv_norm
(
input
,
ch_out
=
ch_out
//
2
if
scale_first
else
ch_out
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
name
=
name
+
".0"
)
conv2
=
self
.
_conv_norm
(
conv1
,
ch_out
=
ch_out
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
name
=
name
+
".1"
)
out
=
fluid
.
layers
.
elementwise_add
(
x
=
input
,
y
=
conv2
,
act
=
None
)
return
out
def
layer_warp
(
self
,
block_func
,
input
,
ch_out
,
count
,
keep_ch
=
False
,
scale_first
=
False
,
name
=
None
):
if
scale_first
:
ch_out
=
ch_out
*
2
right
=
self
.
conv_layer
(
input
,
ch_out
,
name
=
'{}.route_in.right'
.
format
(
name
))
neck
=
self
.
conv_layer
(
input
,
ch_out
,
name
=
'{}.neck'
.
format
(
name
))
out
=
block_func
(
neck
,
ch_out
=
ch_out
,
scale_first
=
scale_first
,
name
=
'{}.0'
.
format
(
name
))
for
j
in
six
.
moves
.
xrange
(
1
,
count
):
out
=
block_func
(
out
,
ch_out
=
ch_out
,
name
=
'{}.{}'
.
format
(
name
,
j
))
left
=
self
.
conv_layer
(
out
,
ch_out
,
name
=
'{}.route_in.left'
.
format
(
name
))
route
=
fluid
.
layers
.
concat
([
left
,
right
],
axis
=
1
)
out
=
self
.
conv_layer
(
route
,
ch_out
=
ch_out
if
keep_ch
else
ch_out
*
2
,
name
=
'{}.conv_layer'
.
format
(
name
))
return
out
def
__call__
(
self
,
input
):
"""
Get the backbone of CSPDarkNet, that is output for the 5 stages.
Args:
input (Variable): input variable.
Returns:
The last variables of each stage.
"""
stages
,
block_func
=
self
.
depth_cfg
[
self
.
depth
]
stages
=
stages
[
0
:
5
]
conv
=
self
.
_conv_norm
(
input
=
input
,
ch_out
=
32
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
act
=
'mish'
,
name
=
self
.
prefix_name
+
"conv"
)
blocks
=
[]
for
i
,
stage
in
enumerate
(
stages
):
input
=
conv
if
i
==
0
else
block
downsample_
=
self
.
_downsample
(
input
=
input
,
ch_out
=
input
.
shape
[
1
]
*
2
,
name
=
self
.
prefix_name
+
"stage.{}.downsample"
.
format
(
i
))
block
=
self
.
layer_warp
(
block_func
=
block_func
,
input
=
downsample_
,
ch_out
=
32
*
2
**
i
,
count
=
stage
,
keep_ch
=
(
i
==
0
),
scale_first
=
i
==
0
,
name
=
self
.
prefix_name
+
"stage.{}"
.
format
(
i
))
blocks
.
append
(
block
)
return
blocks
ppdet/modeling/losses/iou_aware_loss.py
浏览文件 @
76f6c939
...
...
@@ -66,8 +66,11 @@ class IouAwareLoss(IouLoss):
eps (float): the decimal to prevent the denominator eqaul zero
'''
iouk
=
self
.
_iou
(
x
,
y
,
w
,
h
,
tx
,
ty
,
tw
,
th
,
anchors
,
downsample_ratio
,
batch_size
,
ioup
,
eps
)
pred
=
self
.
_bbox_transform
(
x
,
y
,
w
,
h
,
anchors
,
downsample_ratio
,
batch_size
,
False
)
gt
=
self
.
_bbox_transform
(
tx
,
ty
,
tw
,
th
,
anchors
,
downsample_ratio
,
batch_size
,
True
)
iouk
=
self
.
_iou
(
pred
,
gt
,
ioup
,
eps
)
iouk
.
stop_gradient
=
True
loss_iou_aware
=
fluid
.
layers
.
cross_entropy
(
ioup
,
iouk
,
soft_label
=
True
)
...
...
ppdet/modeling/losses/iou_loss.py
浏览文件 @
76f6c939
...
...
@@ -35,12 +35,21 @@ class IouLoss(object):
loss_weight (float): iou loss weight, default is 2.5
max_height (int): max height of input to support random shape input
max_width (int): max width of input to support random shape input
ciou_term (bool): whether to add ciou_term
loss_square (bool): whether to square the iou term
"""
def
__init__
(
self
,
loss_weight
=
2.5
,
max_height
=
608
,
max_width
=
608
):
def
__init__
(
self
,
loss_weight
=
2.5
,
max_height
=
608
,
max_width
=
608
,
ciou_term
=
False
,
loss_square
=
True
):
self
.
_loss_weight
=
loss_weight
self
.
_MAX_HI
=
max_height
self
.
_MAX_WI
=
max_width
self
.
ciou_term
=
ciou_term
self
.
loss_square
=
loss_square
def
__call__
(
self
,
x
,
...
...
@@ -65,33 +74,22 @@ class IouLoss(object):
batch_size (int): training batch size
eps (float): the decimal to prevent the denominator eqaul zero
'''
iouk
=
self
.
_iou
(
x
,
y
,
w
,
h
,
tx
,
ty
,
tw
,
th
,
anchors
,
downsample_ratio
,
batch_size
,
ioup
,
eps
)
loss_iou
=
1.
-
iouk
*
iouk
pred
=
self
.
_bbox_transform
(
x
,
y
,
w
,
h
,
anchors
,
downsample_ratio
,
batch_size
,
False
)
gt
=
self
.
_bbox_transform
(
tx
,
ty
,
tw
,
th
,
anchors
,
downsample_ratio
,
batch_size
,
True
)
iouk
=
self
.
_iou
(
pred
,
gt
,
ioup
,
eps
)
if
self
.
loss_square
:
loss_iou
=
1.
-
iouk
*
iouk
else
:
loss_iou
=
1.
-
iouk
loss_iou
=
loss_iou
*
self
.
_loss_weight
return
loss_iou
def
_iou
(
self
,
x
,
y
,
w
,
h
,
tx
,
ty
,
tw
,
th
,
anchors
,
downsample_ratio
,
batch_size
,
ioup
=
None
,
eps
=
1.e-10
):
x1
,
y1
,
x2
,
y2
=
self
.
_bbox_transform
(
x
,
y
,
w
,
h
,
anchors
,
downsample_ratio
,
batch_size
,
False
)
x1g
,
y1g
,
x2g
,
y2g
=
self
.
_bbox_transform
(
tx
,
ty
,
tw
,
th
,
anchors
,
downsample_ratio
,
batch_size
,
True
)
def
_iou
(
self
,
pred
,
gt
,
ioup
=
None
,
eps
=
1.e-10
):
x1
,
y1
,
x2
,
y2
=
pred
x1g
,
y1g
,
x2g
,
y2g
=
gt
x2
=
fluid
.
layers
.
elementwise_max
(
x1
,
x2
)
y2
=
fluid
.
layers
.
elementwise_max
(
y1
,
y2
)
...
...
@@ -106,8 +104,46 @@ class IouLoss(object):
unionk
=
(
x2
-
x1
)
*
(
y2
-
y1
)
+
(
x2g
-
x1g
)
*
(
y2g
-
y1g
)
-
intsctk
+
eps
iouk
=
intsctk
/
unionk
if
self
.
ciou_term
:
ciou
=
self
.
get_ciou_term
(
pred
,
gt
,
iouk
,
eps
)
iouk
=
iouk
-
ciou
return
iouk
def
get_ciou_term
(
self
,
pred
,
gt
,
iouk
,
eps
):
x1
,
y1
,
x2
,
y2
=
pred
x1g
,
y1g
,
x2g
,
y2g
=
gt
cx
=
(
x1
+
x2
)
/
2
cy
=
(
y1
+
y2
)
/
2
w
=
(
x2
-
x1
)
+
fluid
.
layers
.
cast
((
x2
-
x1
)
==
0
,
'float32'
)
h
=
(
y2
-
y1
)
+
fluid
.
layers
.
cast
((
y2
-
y1
)
==
0
,
'float32'
)
cxg
=
(
x1g
+
x2g
)
/
2
cyg
=
(
y1g
+
y2g
)
/
2
wg
=
x2g
-
x1g
hg
=
y2g
-
y1g
# A or B
xc1
=
fluid
.
layers
.
elementwise_min
(
x1
,
x1g
)
yc1
=
fluid
.
layers
.
elementwise_min
(
y1
,
y1g
)
xc2
=
fluid
.
layers
.
elementwise_max
(
x2
,
x2g
)
yc2
=
fluid
.
layers
.
elementwise_max
(
y2
,
y2g
)
# DIOU term
dist_intersection
=
(
cx
-
cxg
)
*
(
cx
-
cxg
)
+
(
cy
-
cyg
)
*
(
cy
-
cyg
)
dist_union
=
(
xc2
-
xc1
)
*
(
xc2
-
xc1
)
+
(
yc2
-
yc1
)
*
(
yc2
-
yc1
)
diou_term
=
(
dist_intersection
+
eps
)
/
(
dist_union
+
eps
)
# CIOU term
ciou_term
=
0
ar_gt
=
wg
/
hg
ar_pred
=
w
/
h
arctan
=
fluid
.
layers
.
atan
(
ar_gt
)
-
fluid
.
layers
.
atan
(
ar_pred
)
ar_loss
=
4.
/
np
.
pi
/
np
.
pi
*
arctan
*
arctan
alpha
=
ar_loss
/
(
1
-
iouk
+
ar_loss
+
eps
)
alpha
.
stop_gradient
=
True
ciou_term
=
alpha
*
ar_loss
return
diou_term
+
ciou_term
def
_bbox_transform
(
self
,
dcx
,
dcy
,
dw
,
dh
,
anchors
,
downsample_ratio
,
batch_size
,
is_gt
):
grid_x
=
int
(
self
.
_MAX_WI
/
downsample_ratio
)
...
...
ppdet/modeling/losses/yolo_loss.py
浏览文件 @
76f6c939
...
...
@@ -18,6 +18,10 @@ from __future__ import print_function
from
paddle
import
fluid
from
ppdet.core.workspace
import
register
try
:
from
collections.abc
import
Sequence
except
Exception
:
from
collections
import
Sequence
__all__
=
[
'YOLOv3Loss'
]
...
...
@@ -43,13 +47,20 @@ class YOLOv3Loss(object):
label_smooth
=
True
,
use_fine_grained_loss
=
False
,
iou_loss
=
None
,
iou_aware_loss
=
None
):
iou_aware_loss
=
None
,
downsample
=
[
32
,
16
,
8
],
scale_x_y
=
1.
,
match_score
=
False
):
self
.
_batch_size
=
batch_size
self
.
_ignore_thresh
=
ignore_thresh
self
.
_label_smooth
=
label_smooth
self
.
_use_fine_grained_loss
=
use_fine_grained_loss
self
.
_iou_loss
=
iou_loss
self
.
_iou_aware_loss
=
iou_aware_loss
self
.
downsample
=
downsample
# TODO(guanzhong) activate scale_x_y in Paddle 2.0
#self.scale_x_y = scale_x_y
self
.
match_score
=
match_score
def
__call__
(
self
,
outputs
,
gt_box
,
gt_label
,
gt_score
,
targets
,
anchors
,
anchor_masks
,
mask_anchors
,
num_classes
,
prefix_name
):
...
...
@@ -59,8 +70,9 @@ class YOLOv3Loss(object):
mask_anchors
,
self
.
_ignore_thresh
)
else
:
losses
=
[]
downsample
=
32
for
i
,
output
in
enumerate
(
outputs
):
#scale_x_y = self.scale_x_y if not isinstance(
# self.scale_x_y, Sequence) else self.scale_x_y[i]
anchor_mask
=
anchor_masks
[
i
]
loss
=
fluid
.
layers
.
yolov3_loss
(
x
=
output
,
...
...
@@ -71,11 +83,10 @@ class YOLOv3Loss(object):
anchor_mask
=
anchor_mask
,
class_num
=
num_classes
,
ignore_thresh
=
self
.
_ignore_thresh
,
downsample_ratio
=
downsample
,
downsample_ratio
=
self
.
downsample
[
i
]
,
use_label_smooth
=
self
.
_label_smooth
,
name
=
prefix_name
+
"yolo_loss"
+
str
(
i
))
losses
.
append
(
fluid
.
layers
.
reduce_mean
(
loss
))
downsample
//=
2
return
{
'loss'
:
sum
(
losses
)}
...
...
@@ -108,7 +119,6 @@ class YOLOv3Loss(object):
assert
len
(
outputs
)
==
len
(
targets
),
\
"YOLOv3 output layer number not equal target number"
downsample
=
32
loss_xys
,
loss_whs
,
loss_objs
,
loss_clss
=
[],
[],
[],
[]
if
self
.
_iou_loss
is
not
None
:
loss_ious
=
[]
...
...
@@ -116,6 +126,7 @@ class YOLOv3Loss(object):
loss_iou_awares
=
[]
for
i
,
(
output
,
target
,
anchors
)
in
enumerate
(
zip
(
outputs
,
targets
,
mask_anchors
)):
downsample
=
self
.
downsample
[
i
]
an_num
=
len
(
anchors
)
//
2
if
self
.
_iou_aware_loss
is
not
None
:
ioup
,
output
=
self
.
_split_ioup
(
output
,
an_num
,
num_classes
)
...
...
@@ -151,9 +162,11 @@ class YOLOv3Loss(object):
loss_iou_aware
,
dim
=
[
1
,
2
,
3
])
loss_iou_awares
.
append
(
fluid
.
layers
.
reduce_mean
(
loss_iou_aware
))
#scale_x_y = self.scale_x_y if not isinstance(
# self.scale_x_y, Sequence) else self.scale_x_y[i]
loss_obj_pos
,
loss_obj_neg
=
self
.
_calc_obj_loss
(
output
,
obj
,
tobj
,
gt_box
,
self
.
_batch_size
,
anchors
,
num_classes
,
downsample
,
self
.
_ignore_thresh
)
num_classes
,
downsample
,
self
.
_ignore_thresh
,
scale_x_y
)
loss_cls
=
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
cls
,
tcls
)
loss_cls
=
fluid
.
layers
.
elementwise_mul
(
loss_cls
,
tobj
,
axis
=
0
)
...
...
@@ -165,7 +178,6 @@ class YOLOv3Loss(object):
fluid
.
layers
.
reduce_mean
(
loss_obj_pos
+
loss_obj_neg
))
loss_clss
.
append
(
fluid
.
layers
.
reduce_mean
(
loss_cls
))
downsample
//=
2
losses_all
=
{
"loss_xy"
:
fluid
.
layers
.
sum
(
loss_xys
),
"loss_wh"
:
fluid
.
layers
.
sum
(
loss_whs
),
...
...
@@ -264,13 +276,13 @@ class YOLOv3Loss(object):
return
(
tx
,
ty
,
tw
,
th
,
tscale
,
tobj
,
tcls
)
def
_calc_obj_loss
(
self
,
output
,
obj
,
tobj
,
gt_box
,
batch_size
,
anchors
,
num_classes
,
downsample
,
ignore_thresh
):
num_classes
,
downsample
,
ignore_thresh
,
scale_x_y
):
# A prediction bbox overlap any gt_bbox over ignore_thresh,
# objectness loss will be ignored, process as follows:
# 1. get pred bbox, which is same with YOLOv3 infer mode, use yolo_box here
# NOTE: img_size is set as 1.0 to get noramlized pred bbox
bbox
,
_
=
fluid
.
layers
.
yolo_box
(
bbox
,
prob
=
fluid
.
layers
.
yolo_box
(
x
=
output
,
img_size
=
fluid
.
layers
.
ones
(
shape
=
[
batch_size
,
2
],
dtype
=
"int32"
),
...
...
@@ -288,6 +300,7 @@ class YOLOv3Loss(object):
else
:
preds
=
[
bbox
]
gts
=
[
gt_box
]
probs
=
[
prob
]
ious
=
[]
for
pred
,
gt
in
zip
(
preds
,
gts
):
...
...
@@ -307,12 +320,17 @@ class YOLOv3Loss(object):
pred
=
fluid
.
layers
.
squeeze
(
pred
,
axes
=
[
0
])
gt
=
box_xywh2xyxy
(
fluid
.
layers
.
squeeze
(
gt
,
axes
=
[
0
]))
ious
.
append
(
fluid
.
layers
.
iou_similarity
(
pred
,
gt
))
iou
=
fluid
.
layers
.
stack
(
ious
,
axis
=
0
)
iou
=
fluid
.
layers
.
stack
(
ious
,
axis
=
0
)
# 3. Get iou_mask by IoU between gt bbox and prediction bbox,
# Get obj_mask by tobj(holds gt_score), calculate objectness loss
max_iou
=
fluid
.
layers
.
reduce_max
(
iou
,
dim
=-
1
)
iou_mask
=
fluid
.
layers
.
cast
(
max_iou
<=
ignore_thresh
,
dtype
=
"float32"
)
if
self
.
match_score
:
max_prob
=
fluid
.
layers
.
reduce_max
(
prob
,
dim
=-
1
)
iou_mask
=
iou_mask
*
fluid
.
layers
.
cast
(
max_prob
<=
0.25
,
dtype
=
"float32"
)
output_shape
=
fluid
.
layers
.
shape
(
output
)
an_num
=
len
(
anchors
)
//
2
iou_mask
=
fluid
.
layers
.
reshape
(
iou_mask
,
(
-
1
,
an_num
,
output_shape
[
2
],
...
...
ppdet/utils/coco_eval.py
浏览文件 @
76f6c939
...
...
@@ -37,11 +37,13 @@ __all__ = [
]
def
clip_bbox
(
bbox
):
xmin
=
max
(
min
(
bbox
[
0
],
1.
),
0.
)
ymin
=
max
(
min
(
bbox
[
1
],
1.
),
0.
)
xmax
=
max
(
min
(
bbox
[
2
],
1.
),
0.
)
ymax
=
max
(
min
(
bbox
[
3
],
1.
),
0.
)
def
clip_bbox
(
bbox
,
im_size
=
None
):
h
=
1.
if
im_size
is
None
else
im_size
[
0
]
w
=
1.
if
im_size
is
None
else
im_size
[
1
]
xmin
=
max
(
min
(
bbox
[
0
],
w
),
0.
)
ymin
=
max
(
min
(
bbox
[
1
],
h
),
0.
)
xmax
=
max
(
min
(
bbox
[
2
],
w
),
0.
)
ymax
=
max
(
min
(
bbox
[
3
],
h
),
0.
)
return
xmin
,
ymin
,
xmax
,
ymax
...
...
@@ -66,7 +68,8 @@ def bbox_eval(results,
anno_file
,
outfile
,
with_background
=
True
,
is_bbox_normalized
=
False
):
is_bbox_normalized
=
False
,
save_only
=
False
):
assert
'bbox'
in
results
[
0
]
assert
outfile
.
endswith
(
'.json'
)
from
pycocotools.coco
import
COCO
...
...
@@ -91,13 +94,23 @@ def bbox_eval(results,
with
open
(
outfile
,
'w'
)
as
f
:
json
.
dump
(
xywh_results
,
f
)
if
save_only
:
logger
.
info
(
'The bbox result is saved to {} and do not '
'evaluate the mAP.'
.
format
(
outfile
))
return
map_stats
=
cocoapi_eval
(
outfile
,
'bbox'
,
coco_gt
=
coco_gt
)
# flush coco evaluation result
sys
.
stdout
.
flush
()
return
map_stats
def
mask_eval
(
results
,
anno_file
,
outfile
,
resolution
,
thresh_binarize
=
0.5
):
def
mask_eval
(
results
,
anno_file
,
outfile
,
resolution
,
thresh_binarize
=
0.5
,
save_only
=
False
):
assert
'mask'
in
results
[
0
]
assert
outfile
.
endswith
(
'.json'
)
from
pycocotools.coco
import
COCO
...
...
@@ -143,6 +156,11 @@ def mask_eval(results, anno_file, outfile, resolution, thresh_binarize=0.5):
with
open
(
outfile
,
'w'
)
as
f
:
json
.
dump
(
segm_results
,
f
)
if
save_only
:
logger
.
info
(
'The mask result is saved to {} and do not '
'evaluate the mAP.'
.
format
(
outfile
))
return
cocoapi_eval
(
outfile
,
'segm'
,
coco_gt
=
coco_gt
)
...
...
@@ -257,8 +275,11 @@ def bbox2out(results, clsid2catid, is_bbox_normalized=False):
w
*=
im_width
h
*=
im_height
else
:
w
=
xmax
-
xmin
+
1
h
=
ymax
-
ymin
+
1
im_size
=
t
[
'im_size'
][
0
][
i
].
tolist
()
xmin
,
ymin
,
xmax
,
ymax
=
\
clip_bbox
([
xmin
,
ymin
,
xmax
,
ymax
],
im_size
)
w
=
xmax
-
xmin
h
=
ymax
-
ymin
bbox
=
[
xmin
,
ymin
,
w
,
h
]
coco_res
=
{
...
...
ppdet/utils/eval_utils.py
浏览文件 @
76f6c939
...
...
@@ -191,7 +191,8 @@ def eval_results(results,
is_bbox_normalized
=
False
,
output_directory
=
None
,
map_type
=
'11point'
,
dataset
=
None
):
dataset
=
None
,
save_only
=
False
):
"""Evaluation for evaluation program results"""
box_ap_stats
=
[]
if
metric
==
'COCO'
:
...
...
@@ -213,13 +214,15 @@ def eval_results(results,
anno_file
,
output
,
with_background
,
is_bbox_normalized
=
is_bbox_normalized
)
is_bbox_normalized
=
is_bbox_normalized
,
save_only
=
save_only
)
if
'mask'
in
results
[
0
]:
output
=
'mask.json'
if
output_directory
:
output
=
os
.
path
.
join
(
output_directory
,
'mask.json'
)
mask_eval
(
results
,
anno_file
,
output
,
resolution
)
mask_eval
(
results
,
anno_file
,
output
,
resolution
,
save_only
=
save_only
)
else
:
if
'accum_map'
in
results
[
-
1
]:
res
=
np
.
mean
(
results
[
-
1
][
'accum_map'
][
0
])
...
...
ppdet/utils/voc_eval.py
浏览文件 @
76f6c939
...
...
@@ -68,7 +68,6 @@ def bbox_eval(results,
if
bboxes
.
shape
==
(
1
,
1
)
or
bboxes
is
None
:
continue
gt_boxes
=
t
[
'gt_bbox'
][
0
]
gt_labels
=
t
[
'gt_class'
][
0
]
difficults
=
t
[
'is_difficult'
][
0
]
if
not
evaluate_difficult
\
...
...
tools/eval.py
浏览文件 @
76f6c939
...
...
@@ -111,7 +111,7 @@ def main():
extra_keys
=
[]
if
cfg
.
metric
==
'COCO'
:
extra_keys
=
[
'im_info'
,
'im_id'
,
'im_shape'
]
extra_keys
=
[
'im_info'
,
'im_id'
,
'im_shape'
,
'im_size'
]
if
cfg
.
metric
==
'VOC'
:
extra_keys
=
[
'gt_bbox'
,
'gt_class'
,
'is_difficult'
]
...
...
@@ -160,6 +160,7 @@ def main():
# evaluation
# if map_type not set, use default 11point, only use in VOC eval
map_type
=
cfg
.
map_type
if
'map_type'
in
cfg
else
'11point'
save_only
=
getattr
(
cfg
,
'save_prediction_only'
,
False
)
eval_results
(
results
,
cfg
.
metric
,
...
...
@@ -168,7 +169,8 @@ def main():
is_bbox_normalized
,
FLAGS
.
output_eval
,
map_type
,
dataset
=
dataset
)
dataset
=
dataset
,
save_only
=
save_only
)
if
__name__
==
'__main__'
:
...
...
tools/train.py
浏览文件 @
76f6c939
...
...
@@ -79,6 +79,10 @@ def main():
# check if paddlepaddle version is satisfied
check_version
()
save_only
=
getattr
(
cfg
,
'save_prediction_only'
,
False
)
if
save_only
:
raise
NotImplementedError
(
'The config file only support prediction,'
' training stage is not implemented now'
)
main_arch
=
cfg
.
architecture
if
cfg
.
use_gpu
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录