Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
76f6c939
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
76f6c939
编写于
5月 08, 2020
作者:
W
wangguanzhong
提交者:
GitHub
5月 08, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add Yolo v4 (#594)
* add yolov4
上级
80728741
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
1015 addition
and
108 deletion
+1015
-108
configs/yolo/README.md
configs/yolo/README.md
+46
-0
configs/yolo/yolov4_cspdarknet.yml
configs/yolo/yolov4_cspdarknet.yml
+122
-0
configs/yolo/yolov4_cspdarknet_voc.yml
configs/yolo/yolov4_cspdarknet_voc.yml
+182
-0
ppdet/data/source/coco.py
ppdet/data/source/coco.py
+51
-40
ppdet/data/transform/batch_operators.py
ppdet/data/transform/batch_operators.py
+34
-4
ppdet/data/transform/operators.py
ppdet/data/transform/operators.py
+1
-1
ppdet/modeling/anchor_heads/yolo_head.py
ppdet/modeling/anchor_heads/yolo_head.py
+200
-8
ppdet/modeling/architectures/__init__.py
ppdet/modeling/architectures/__init__.py
+2
-2
ppdet/modeling/architectures/yolo.py
ppdet/modeling/architectures/yolo.py
+26
-2
ppdet/modeling/backbones/__init__.py
ppdet/modeling/backbones/__init__.py
+2
-0
ppdet/modeling/backbones/cspdarknet.py
ppdet/modeling/backbones/cspdarknet.py
+212
-0
ppdet/modeling/losses/iou_aware_loss.py
ppdet/modeling/losses/iou_aware_loss.py
+5
-2
ppdet/modeling/losses/iou_loss.py
ppdet/modeling/losses/iou_loss.py
+60
-24
ppdet/modeling/losses/yolo_loss.py
ppdet/modeling/losses/yolo_loss.py
+28
-10
ppdet/utils/coco_eval.py
ppdet/utils/coco_eval.py
+30
-9
ppdet/utils/eval_utils.py
ppdet/utils/eval_utils.py
+6
-3
ppdet/utils/voc_eval.py
ppdet/utils/voc_eval.py
+0
-1
tools/eval.py
tools/eval.py
+4
-2
tools/train.py
tools/train.py
+4
-0
未找到文件。
configs/yolo/README.md
0 → 100644
浏览文件 @
76f6c939
# YOLO v4
## 内容
-
[
简介
](
#简介
)
-
[
模型库与基线
](
#模型库与基线
)
-
[
未来工作
](
#未来工作
)
-
[
如何贡献代码
](
#如何贡献代码
)
## 简介
[
YOLO v4
](
https://arxiv.org/abs/2004.10934
)
的Paddle实现版本
目前PaddleDetection中转换了
[
darknet
](
https://github.com/AlexeyAB/darknet
)
中YOLO v4的权重,可以直接对图片进行预测,在
[
test-dev2019
](
http://cocodataset.org/#detection-2019
)
中精度为43.5%。另外,PaddleDetection支持VOC数据集上finetune,精度达到86.0%
PaddleDetection支持YOLO v4的多个模块:
-
mish激活函数
-
PAN模块
-
SPP模块
-
ciou loss
-
label_smooth
## 模型库
下表中展示了PaddleDetection当前支持的网络结构。
| | GPU个数 | 测试集 | 骨干网络 | 精度 | 模型下载 | 配置文件 |
|:------------------------:|:-------:|:------:|:--------------------------:|:------------------------:| :---------:| :-----: |
| YOLO v4 | - |test-dev2019 | CSPDarkNet53 | 43.5 |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/yolov4_cspdarknet.pdparams
)
|
[
配置文件
](
https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/yolo/yolov4_cspdarknet.yml
)
|
| YOLO v4 VOC | 2 | VOC2007 | CSPDarkNet53 | - |
[
下载链接
](
https://paddlemodels.bj.bcebos.com/object_detection/yolov4_cspdarknet_voc.pdparams
)
|
[
配置文件
](
https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/yolo/yolov4_cspdarknet_voc.yml
)
|
**注意:**
-
由于原版YOLO v4使用coco trainval2014进行训练,训练样本中包含部分评估样本,若使用val集会导致精度虚高,因此使用coco test集对模型进行评估。
-
YOLO v4模型仅支持coco test集评估和图片预测,由于test集不包含目标框的真实标注,评估时会将预测结果保存在json文件中,请将结果提交至
[
cocodataset
](
http://cocodataset.org/#detection-2019
)
上查看最终精度指标。
-
coco测试集使用test2017,下载请参考
[
coco2017
](
http://cocodataset.org/#download
)
## 未来工作
1.
mish激活函数优化
2.
mosaic数据预处理实现
3.
scale
\_
x
\_
y为yolo_box中decode时对box的位置进行微调,该功能将在Paddle2.0版本中实现
## 如何贡献代码
我们非常欢迎您可以为PaddleDetection提供代码,您可以提交PR供我们review;也十分感谢您的反馈,可以提交相应issue,我们会及时解答。
configs/yolo/yolov4_cspdarknet.yml
0 → 100644
浏览文件 @
76f6c939
architecture
:
YOLOv4
use_gpu
:
true
max_iters
:
500200
log_smooth_window
:
20
save_dir
:
output
snapshot_iter
:
10000
metric
:
COCO
pretrain_weights
:
https://paddlemodels.bj.bcebos.com/object_detection/yolov4_cspdarknet.pdparams
weights
:
output/yolov4_cspdarknet/model_final
num_classes
:
80
use_fine_grained_loss
:
true
save_prediction_only
:
True
YOLOv4
:
backbone
:
CSPDarkNet
yolo_head
:
YOLOv4Head
CSPDarkNet
:
norm_type
:
sync_bn
norm_decay
:
0.
depth
:
53
YOLOv4Head
:
anchors
:
[[
12
,
16
],
[
19
,
36
],
[
40
,
28
],
[
36
,
75
],
[
76
,
55
],
[
72
,
146
],
[
142
,
110
],
[
192
,
243
],
[
459
,
401
]]
anchor_masks
:
[[
0
,
1
,
2
],
[
3
,
4
,
5
],
[
6
,
7
,
8
]]
nms
:
background_label
:
-1
keep_top_k
:
-1
nms_threshold
:
0.45
nms_top_k
:
-1
normalized
:
true
score_threshold
:
0.001
downsample
:
[
8
,
16
,
32
]
YOLOv3Loss
:
# batch_size here is only used for fine grained loss, not used
# for training batch_size setting, training batch_size setting
# is in configs/yolov3_reader.yml TrainReader.batch_size, batch
# size here should be set as same value as TrainReader.batch_size
batch_size
:
4
ignore_thresh
:
0.7
label_smooth
:
true
downsample
:
[
8
,
16
,
32
]
#scale_x_y: [1.2, 1.1, 1.05]
iou_loss
:
IouLoss
match_score
:
true
IouLoss
:
loss_weight
:
0.07
max_height
:
608
max_width
:
608
ciou_term
:
true
loss_square
:
false
LearningRate
:
base_lr
:
0.0001
schedulers
:
-
!PiecewiseDecay
gamma
:
0.1
milestones
:
-
400000
-
450000
-
!LinearWarmup
start_factor
:
0.
steps
:
1000
OptimizerBuilder
:
clip_grad_by_norm
:
10.
optimizer
:
momentum
:
0.949
type
:
Momentum
regularizer
:
factor
:
0.0005
type
:
L2
_READER_
:
'
../yolov3_reader.yml'
EvalReader
:
inputs_def
:
fields
:
[
'
image'
,
'
im_size'
,
'
im_id'
]
num_max_boxes
:
90
dataset
:
!COCODataSet
image_dir
:
test2017
anno_path
:
annotations/image_info_test-dev2017.json
dataset_dir
:
data/coco
with_background
:
false
sample_transforms
:
-
!DecodeImage
to_rgb
:
True
-
!ResizeImage
target_size
:
608
interp
:
1
-
!NormalizeImage
mean
:
[
0.
,
0.
,
0.
]
std
:
[
1.
,
1.
,
1.
]
is_scale
:
True
is_channel_first
:
false
-
!Permute
to_bgr
:
false
channel_first
:
True
batch_size
:
1
TestReader
:
dataset
:
!ImageFolder
use_default_label
:
true
with_background
:
false
sample_transforms
:
-
!DecodeImage
to_rgb
:
True
-
!ResizeImage
target_size
:
608
interp
:
1
-
!NormalizeImage
mean
:
[
0.
,
0.
,
0.
]
std
:
[
1.
,
1.
,
1.
]
is_scale
:
True
is_channel_first
:
false
-
!Permute
to_bgr
:
false
channel_first
:
True
configs/yolo/yolov4_cspdarknet_voc.yml
0 → 100644
浏览文件 @
76f6c939
architecture
:
YOLOv4
use_gpu
:
true
max_iters
:
70000
log_smooth_window
:
20
save_dir
:
output
snapshot_iter
:
1000
metric
:
VOC
pretrain_weights
:
https://paddlemodels.bj.bcebos.com/object_detection/yolov4_cspdarknet.pdparams
weights
:
output/yolov4_cspdarknet_voc/model_final
num_classes
:
20
use_fine_grained_loss
:
true
YOLOv4
:
backbone
:
CSPDarkNet
yolo_head
:
YOLOv4Head
CSPDarkNet
:
norm_type
:
sync_bn
norm_decay
:
0.
depth
:
53
YOLOv4Head
:
anchors
:
[[
12
,
16
],
[
19
,
36
],
[
40
,
28
],
[
36
,
75
],
[
76
,
55
],
[
72
,
146
],
[
142
,
110
],
[
192
,
243
],
[
459
,
401
]]
anchor_masks
:
[[
0
,
1
,
2
],
[
3
,
4
,
5
],
[
6
,
7
,
8
]]
nms
:
background_label
:
-1
keep_top_k
:
-1
nms_threshold
:
0.45
nms_top_k
:
-1
normalized
:
true
score_threshold
:
0.001
downsample
:
[
8
,
16
,
32
]
YOLOv3Loss
:
# batch_size here is only used for fine grained loss, not used
# for training batch_size setting, training batch_size setting
# is in configs/yolov3_reader.yml TrainReader.batch_size, batch
# size here should be set as same value as TrainReader.batch_size
batch_size
:
4
ignore_thresh
:
0.7
label_smooth
:
true
downsample
:
[
8
,
16
,
32
]
#scale_x_y: [1.2, 1.1, 1.05]
iou_loss
:
IouLoss
match_score
:
true
IouLoss
:
loss_weight
:
0.07
max_height
:
608
max_width
:
608
ciou_term
:
true
loss_square
:
true
LearningRate
:
base_lr
:
0.0001
schedulers
:
-
!PiecewiseDecay
gamma
:
0.1
milestones
:
-
56000
-
63000
-
!LinearWarmup
start_factor
:
0.
steps
:
1000
OptimizerBuilder
:
clip_grad_by_norm
:
10.
optimizer
:
momentum
:
0.949
type
:
Momentum
regularizer
:
factor
:
0.0005
type
:
L2
_READER_
:
'
../yolov3_reader.yml'
TrainReader
:
inputs_def
:
fields
:
[
'
image'
,
'
gt_bbox'
,
'
gt_class'
,
'
gt_score'
,
'
im_id'
]
num_max_boxes
:
50
dataset
:
!VOCDataSet
anno_path
:
trainval.txt
dataset_dir
:
dataset/voc
with_background
:
false
sample_transforms
:
-
!DecodeImage
to_rgb
:
True
with_mixup
:
True
-
!MixupImage
alpha
:
1.5
beta
:
1.5
-
!ColorDistort
{}
-
!RandomExpand
fill_value
:
[
123.675
,
116.28
,
103.53
]
-
!RandomCrop
{}
-
!RandomFlipImage
is_normalized
:
false
-
!NormalizeBox
{}
-
!PadBox
num_max_boxes
:
50
-
!BboxXYXY2XYWH
{}
batch_transforms
:
-
!RandomShape
sizes
:
[
320
,
352
,
384
,
416
,
448
,
480
,
512
,
544
,
576
,
608
]
random_inter
:
True
-
!NormalizeImage
mean
:
[
0.
,
0.
,
0.
]
std
:
[
1.
,
1.
,
1.
]
is_scale
:
True
is_channel_first
:
false
-
!Permute
to_bgr
:
false
channel_first
:
True
# Gt2YoloTarget is only used when use_fine_grained_loss set as true,
# this operator will be deleted automatically if use_fine_grained_loss
# is set as false
-
!Gt2YoloTarget
anchor_masks
:
[[
0
,
1
,
2
],
[
3
,
4
,
5
],
[
6
,
7
,
8
]]
anchors
:
[[
12
,
16
],
[
19
,
36
],
[
40
,
28
],
[
36
,
75
],
[
76
,
55
],
[
72
,
146
],
[
142
,
110
],
[
192
,
243
],
[
459
,
401
]]
downsample_ratios
:
[
8
,
16
,
32
]
batch_size
:
4
shuffle
:
true
mixup_epoch
:
28
drop_last
:
true
worker_num
:
8
bufsize
:
16
use_process
:
true
drop_empty
:
false
EvalReader
:
inputs_def
:
fields
:
[
'
image'
,
'
im_size'
,
'
im_id'
,
'
gt_bbox'
,
'
gt_class'
,
'
is_difficult'
]
num_max_boxes
:
90
dataset
:
!VOCDataSet
anno_path
:
test.txt
dataset_dir
:
dataset/voc
use_default_label
:
true
with_background
:
false
sample_transforms
:
-
!DecodeImage
to_rgb
:
True
-
!ResizeImage
target_size
:
608
interp
:
1
-
!NormalizeImage
mean
:
[
0.
,
0.
,
0.
]
std
:
[
1.
,
1.
,
1.
]
is_scale
:
True
is_channel_first
:
false
-
!PadBox
num_max_boxes
:
90
-
!Permute
to_bgr
:
false
channel_first
:
True
batch_size
:
4
drop_empty
:
false
worker_num
:
8
bufsize
:
16
TestReader
:
dataset
:
!ImageFolder
use_default_label
:
true
with_background
:
false
sample_transforms
:
-
!DecodeImage
to_rgb
:
True
-
!ResizeImage
target_size
:
608
interp
:
1
-
!NormalizeImage
mean
:
[
0.
,
0.
,
0.
]
std
:
[
1.
,
1.
,
1.
]
is_scale
:
True
is_channel_first
:
false
-
!Permute
to_bgr
:
false
channel_first
:
True
ppdet/data/source/coco.py
浏览文件 @
76f6c939
...
...
@@ -67,6 +67,7 @@ class COCODataSet(DataSet):
self
.
roidbs
=
None
# a dict used to map category name to class id
self
.
cname2cid
=
None
self
.
load_image_only
=
False
def
load_roidb_and_cname2cid
(
self
):
anno_path
=
os
.
path
.
join
(
self
.
dataset_dir
,
self
.
anno_path
)
...
...
@@ -92,61 +93,71 @@ class COCODataSet(DataSet):
for
catid
,
clsid
in
catid2clsid
.
items
()
})
if
'annotations'
not
in
coco
.
dataset
:
self
.
load_image_only
=
True
logger
.
warn
(
'Annotation file: {} does not contains ground truth '
'and load image information only.'
.
format
(
anno_path
))
for
img_id
in
img_ids
:
img_anno
=
coco
.
loadImgs
(
img_id
)[
0
]
im_fname
=
img_anno
[
'file_name'
]
im_w
=
float
(
img_anno
[
'width'
])
im_h
=
float
(
img_anno
[
'height'
])
ins_anno_ids
=
coco
.
getAnnIds
(
imgIds
=
img_id
,
iscrowd
=
False
)
instances
=
coco
.
loadAnns
(
ins_anno_ids
)
bboxes
=
[]
for
inst
in
instances
:
x
,
y
,
box_w
,
box_h
=
inst
[
'bbox'
]
x1
=
max
(
0
,
x
)
y1
=
max
(
0
,
y
)
x2
=
min
(
im_w
-
1
,
x1
+
max
(
0
,
box_w
-
1
))
y2
=
min
(
im_h
-
1
,
y1
+
max
(
0
,
box_h
-
1
))
if
inst
[
'area'
]
>
0
and
x2
>=
x1
and
y2
>=
y1
:
inst
[
'clean_bbox'
]
=
[
x1
,
y1
,
x2
,
y2
]
bboxes
.
append
(
inst
)
else
:
logger
.
warn
(
'Found an invalid bbox in annotations: im_id: {}, '
'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'
.
format
(
img_id
,
float
(
inst
[
'area'
]),
x1
,
y1
,
x2
,
y2
))
num_bbox
=
len
(
bboxes
)
gt_bbox
=
np
.
zeros
((
num_bbox
,
4
),
dtype
=
np
.
float32
)
gt_class
=
np
.
zeros
((
num_bbox
,
1
),
dtype
=
np
.
int32
)
gt_score
=
np
.
ones
((
num_bbox
,
1
),
dtype
=
np
.
float32
)
is_crowd
=
np
.
zeros
((
num_bbox
,
1
),
dtype
=
np
.
int32
)
difficult
=
np
.
zeros
((
num_bbox
,
1
),
dtype
=
np
.
int32
)
gt_poly
=
[
None
]
*
num_bbox
for
i
,
box
in
enumerate
(
bboxes
):
catid
=
box
[
'category_id'
]
gt_class
[
i
][
0
]
=
catid2clsid
[
catid
]
gt_bbox
[
i
,
:]
=
box
[
'clean_bbox'
]
is_crowd
[
i
][
0
]
=
box
[
'iscrowd'
]
if
'segmentation'
in
box
:
gt_poly
[
i
]
=
box
[
'segmentation'
]
im_fname
=
os
.
path
.
join
(
image_dir
,
im_fname
)
if
image_dir
else
im_fname
coco_rec
=
{
'im_file'
:
im_fname
,
'im_id'
:
np
.
array
([
img_id
]),
'h'
:
im_h
,
'w'
:
im_w
,
'is_crowd'
:
is_crowd
,
'gt_class'
:
gt_class
,
'gt_bbox'
:
gt_bbox
,
'gt_score'
:
gt_score
,
'gt_poly'
:
gt_poly
,
}
if
not
self
.
load_image_only
:
ins_anno_ids
=
coco
.
getAnnIds
(
imgIds
=
img_id
,
iscrowd
=
False
)
instances
=
coco
.
loadAnns
(
ins_anno_ids
)
bboxes
=
[]
for
inst
in
instances
:
x
,
y
,
box_w
,
box_h
=
inst
[
'bbox'
]
x1
=
max
(
0
,
x
)
y1
=
max
(
0
,
y
)
x2
=
min
(
im_w
-
1
,
x1
+
max
(
0
,
box_w
-
1
))
y2
=
min
(
im_h
-
1
,
y1
+
max
(
0
,
box_h
-
1
))
if
inst
[
'area'
]
>
0
and
x2
>=
x1
and
y2
>=
y1
:
inst
[
'clean_bbox'
]
=
[
x1
,
y1
,
x2
,
y2
]
bboxes
.
append
(
inst
)
else
:
logger
.
warn
(
'Found an invalid bbox in annotations: im_id: {}, '
'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'
.
format
(
img_id
,
float
(
inst
[
'area'
]),
x1
,
y1
,
x2
,
y2
))
num_bbox
=
len
(
bboxes
)
gt_bbox
=
np
.
zeros
((
num_bbox
,
4
),
dtype
=
np
.
float32
)
gt_class
=
np
.
zeros
((
num_bbox
,
1
),
dtype
=
np
.
int32
)
gt_score
=
np
.
ones
((
num_bbox
,
1
),
dtype
=
np
.
float32
)
is_crowd
=
np
.
zeros
((
num_bbox
,
1
),
dtype
=
np
.
int32
)
difficult
=
np
.
zeros
((
num_bbox
,
1
),
dtype
=
np
.
int32
)
gt_poly
=
[
None
]
*
num_bbox
for
i
,
box
in
enumerate
(
bboxes
):
catid
=
box
[
'category_id'
]
gt_class
[
i
][
0
]
=
catid2clsid
[
catid
]
gt_bbox
[
i
,
:]
=
box
[
'clean_bbox'
]
is_crowd
[
i
][
0
]
=
box
[
'iscrowd'
]
if
'segmentation'
in
box
:
gt_poly
[
i
]
=
box
[
'segmentation'
]
coco_rec
.
update
({
'is_crowd'
:
is_crowd
,
'gt_class'
:
gt_class
,
'gt_bbox'
:
gt_bbox
,
'gt_score'
:
gt_score
,
'gt_poly'
:
gt_poly
,
})
logger
.
debug
(
'Load file: {}, im_id: {}, h: {}, w: {}.'
.
format
(
im_fname
,
img_id
,
im_h
,
im_w
))
records
.
append
(
coco_rec
)
...
...
ppdet/data/transform/batch_operators.py
浏览文件 @
76f6c939
...
...
@@ -179,13 +179,18 @@ class Gt2YoloTarget(BaseOperator):
fine grained YOLOv3 loss mode
"""
def
__init__
(
self
,
anchors
,
anchor_masks
,
downsample_ratios
,
num_classes
=
80
):
def
__init__
(
self
,
anchors
,
anchor_masks
,
downsample_ratios
,
num_classes
=
80
,
iou_thresh
=
1.
):
super
(
Gt2YoloTarget
,
self
).
__init__
()
self
.
anchors
=
anchors
self
.
anchor_masks
=
anchor_masks
self
.
downsample_ratios
=
downsample_ratios
self
.
num_classes
=
num_classes
self
.
iou_thresh
=
iou_thresh
def
__call__
(
self
,
samples
,
context
=
None
):
assert
len
(
self
.
anchor_masks
)
==
len
(
self
.
downsample_ratios
),
\
...
...
@@ -225,12 +230,13 @@ class Gt2YoloTarget(BaseOperator):
best_iou
=
iou
best_idx
=
an_idx
gi
=
int
(
gx
*
grid_w
)
gj
=
int
(
gy
*
grid_h
)
# gtbox should be regresed in this layes if best match
# anchor index in anchor mask of this layer
if
best_idx
in
mask
:
best_n
=
mask
.
index
(
best_idx
)
gi
=
int
(
gx
*
grid_w
)
gj
=
int
(
gy
*
grid_h
)
# x, y, w, h, scale
target
[
best_n
,
0
,
gj
,
gi
]
=
gx
*
grid_w
-
gi
...
...
@@ -246,6 +252,30 @@ class Gt2YoloTarget(BaseOperator):
# classification
target
[
best_n
,
6
+
cls
,
gj
,
gi
]
=
1.
# For non-matched anchors, calculate the target if the iou
# between anchor and gt is larger than iou_thresh
if
self
.
iou_thresh
<
1
:
for
idx
,
mask_i
in
enumerate
(
mask
):
if
mask_i
==
best_idx
:
continue
iou
=
jaccard_overlap
(
[
0.
,
0.
,
gw
,
gh
],
[
0.
,
0.
,
an_hw
[
mask_i
,
0
],
an_hw
[
mask_i
,
1
]])
if
iou
>
self
.
iou_thresh
:
# x, y, w, h, scale
target
[
idx
,
0
,
gj
,
gi
]
=
gx
*
grid_w
-
gi
target
[
idx
,
1
,
gj
,
gi
]
=
gy
*
grid_h
-
gj
target
[
idx
,
2
,
gj
,
gi
]
=
np
.
log
(
gw
*
w
/
self
.
anchors
[
mask_i
][
0
])
target
[
idx
,
3
,
gj
,
gi
]
=
np
.
log
(
gh
*
h
/
self
.
anchors
[
mask_i
][
1
])
target
[
idx
,
4
,
gj
,
gi
]
=
2.0
-
gw
*
gh
# objectness record gt_score
target
[
idx
,
5
,
gj
,
gi
]
=
score
# classification
target
[
idx
,
6
+
cls
,
gj
,
gi
]
=
1.
sample
[
'target{}'
.
format
(
i
)]
=
target
return
samples
...
...
ppdet/data/transform/operators.py
浏览文件 @
76f6c939
...
...
@@ -114,6 +114,7 @@ class DecodeImage(BaseOperator):
im
=
sample
[
'image'
]
data
=
np
.
frombuffer
(
im
,
dtype
=
'uint8'
)
im
=
cv2
.
imdecode
(
data
,
1
)
# BGR mode, but need RGB mode
if
self
.
to_rgb
:
im
=
cv2
.
cvtColor
(
im
,
cv2
.
COLOR_BGR2RGB
)
sample
[
'image'
]
=
im
...
...
@@ -331,7 +332,6 @@ class ResizeImage(BaseOperator):
im
=
Image
.
fromarray
(
im
)
im
=
im
.
resize
((
int
(
resize_w
),
int
(
resize_h
)),
self
.
interp
)
im
=
np
.
array
(
im
)
sample
[
'image'
]
=
im
return
sample
...
...
ppdet/modeling/anchor_heads/yolo_head.py
浏览文件 @
76f6c939
...
...
@@ -25,8 +25,12 @@ from ppdet.modeling.losses.yolo_loss import YOLOv3Loss
from
ppdet.core.workspace
import
register
from
ppdet.modeling.ops
import
DropBlock
from
.iou_aware
import
get_iou_aware_score
try
:
from
collections.abc
import
Sequence
except
Exception
:
from
collections
import
Sequence
__all__
=
[
'YOLOv3Head'
]
__all__
=
[
'YOLOv3Head'
,
'YOLOv4Head'
]
@
register
...
...
@@ -62,7 +66,9 @@ class YOLOv3Head(object):
keep_top_k
=
100
,
nms_threshold
=
0.45
,
background_label
=-
1
).
__dict__
,
weight_prefix_name
=
''
):
weight_prefix_name
=
''
,
downsample
=
[
32
,
16
,
8
],
scale_x_y
=
1.0
):
self
.
norm_decay
=
norm_decay
self
.
num_classes
=
num_classes
self
.
anchor_masks
=
anchor_masks
...
...
@@ -77,6 +83,9 @@ class YOLOv3Head(object):
self
.
keep_prob
=
keep_prob
if
isinstance
(
nms
,
dict
):
self
.
nms
=
MultiClassNMS
(
**
nms
)
self
.
downsample
=
downsample
# TODO(guanzhong) activate scale_x_y in Paddle 2.0
#self.scale_x_y = scale_x_y
def
_conv_bn
(
self
,
input
,
...
...
@@ -105,7 +114,6 @@ class YOLOv3Head(object):
out
=
fluid
.
layers
.
batch_norm
(
input
=
conv
,
act
=
None
,
is_test
=
is_test
,
param_attr
=
bn_param_attr
,
bias_attr
=
bn_bias_attr
,
moving_mean_name
=
bn_name
+
'.mean'
,
...
...
@@ -301,27 +309,211 @@ class YOLOv3Head(object):
boxes
=
[]
scores
=
[]
downsample
=
32
for
i
,
output
in
enumerate
(
outputs
):
if
self
.
iou_aware
:
output
=
get_iou_aware_score
(
output
,
len
(
self
.
anchor_masks
[
i
]),
self
.
num_classes
,
self
.
iou_aware_factor
)
#scale_x_y = self.scale_x_y if not isinstance(
# self.scale_x_y, Sequence) else self.scale_x_y[i]
box
,
score
=
fluid
.
layers
.
yolo_box
(
x
=
output
,
img_size
=
im_size
,
anchors
=
self
.
mask_anchors
[
i
],
class_num
=
self
.
num_classes
,
conf_thresh
=
self
.
nms
.
score_threshold
,
downsample_ratio
=
downsample
,
name
=
self
.
prefix_name
+
"yolo_box"
+
str
(
i
))
downsample_ratio
=
self
.
downsample
[
i
],
name
=
self
.
prefix_name
+
"yolo_box"
+
str
(
i
),
clip_bbox
=
False
)
boxes
.
append
(
box
)
scores
.
append
(
fluid
.
layers
.
transpose
(
score
,
perm
=
[
0
,
2
,
1
]))
downsample
//=
2
yolo_boxes
=
fluid
.
layers
.
concat
(
boxes
,
axis
=
1
)
yolo_scores
=
fluid
.
layers
.
concat
(
scores
,
axis
=
2
)
pred
=
self
.
nms
(
bboxes
=
yolo_boxes
,
scores
=
yolo_scores
)
return
{
'bbox'
:
pred
}
@
register
class
YOLOv4Head
(
YOLOv3Head
):
"""
Head block for YOLOv4 network
Args:
anchors (list): anchors
anchor_masks (list): anchor masks
nms (object): an instance of `MultiClassNMS`
spp_stage (int): apply spp on which stage.
num_classes (int): number of output classes
downsample (list): downsample ratio for each yolo_head
scale_x_y (list): scale the left top point of bbox at each stage
"""
__inject__
=
[
'nms'
,
'yolo_loss'
]
__shared__
=
[
'num_classes'
,
'weight_prefix_name'
]
def
__init__
(
self
,
anchors
=
[[
12
,
16
],
[
19
,
36
],
[
40
,
28
],
[
36
,
75
],
[
76
,
55
],
[
72
,
146
],
[
142
,
110
],
[
192
,
243
],
[
459
,
401
]],
anchor_masks
=
[[
0
,
1
,
2
],
[
3
,
4
,
5
],
[
6
,
7
,
8
]],
nms
=
MultiClassNMS
(
score_threshold
=
0.01
,
nms_top_k
=-
1
,
keep_top_k
=-
1
,
nms_threshold
=
0.45
,
background_label
=-
1
).
__dict__
,
spp_stage
=
5
,
num_classes
=
80
,
weight_prefix_name
=
''
,
downsample
=
[
8
,
16
,
32
],
scale_x_y
=
[
1.2
,
1.1
,
1.05
],
yolo_loss
=
"YOLOv3Loss"
,
iou_aware
=
False
,
iou_aware_factor
=
0.4
,
):
super
(
YOLOv4Head
,
self
).
__init__
(
anchors
=
anchors
,
anchor_masks
=
anchor_masks
,
nms
=
nms
,
num_classes
=
num_classes
,
weight_prefix_name
=
weight_prefix_name
,
downsample
=
downsample
,
scale_x_y
=
scale_x_y
,
yolo_loss
=
yolo_loss
,
iou_aware
=
iou_aware
,
iou_aware_factor
=
iou_aware_factor
)
self
.
spp_stage
=
spp_stage
def
_upsample
(
self
,
input
,
scale
=
2
,
name
=
None
):
out
=
fluid
.
layers
.
resize_nearest
(
input
=
input
,
scale
=
float
(
scale
),
name
=
name
)
return
out
def
max_pool
(
self
,
input
,
size
):
pad
=
[(
size
-
1
)
//
2
]
*
2
return
fluid
.
layers
.
pool2d
(
input
,
size
,
'max'
,
pool_padding
=
pad
)
def
spp
(
self
,
input
):
branch_a
=
self
.
max_pool
(
input
,
13
)
branch_b
=
self
.
max_pool
(
input
,
9
)
branch_c
=
self
.
max_pool
(
input
,
5
)
out
=
fluid
.
layers
.
concat
([
branch_a
,
branch_b
,
branch_c
,
input
],
axis
=
1
)
return
out
def
stack_conv
(
self
,
input
,
ch_list
=
[
512
,
1024
,
512
],
filter_list
=
[
1
,
3
,
1
],
stride
=
1
,
name
=
None
):
conv
=
input
for
i
,
(
ch_out
,
f_size
)
in
enumerate
(
zip
(
ch_list
,
filter_list
)):
padding
=
1
if
f_size
==
3
else
0
conv
=
self
.
_conv_bn
(
conv
,
ch_out
=
ch_out
,
filter_size
=
f_size
,
stride
=
stride
,
padding
=
padding
,
name
=
'{}.{}'
.
format
(
name
,
i
))
return
conv
def
spp_module
(
self
,
input
,
name
=
None
):
conv
=
self
.
stack_conv
(
input
,
name
=
name
+
'.stack_conv.0'
)
spp_out
=
self
.
spp
(
conv
)
conv
=
self
.
stack_conv
(
spp_out
,
name
=
name
+
'.stack_conv.1'
)
return
conv
def
pan_module
(
self
,
input
,
filter_list
,
name
=
None
):
for
i
in
range
(
1
,
len
(
input
)):
ch_out
=
input
[
i
].
shape
[
1
]
//
2
conv_left
=
self
.
_conv_bn
(
input
[
i
],
ch_out
=
ch_out
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
name
=
name
+
'.{}.left'
.
format
(
i
))
ch_out
=
input
[
i
-
1
].
shape
[
1
]
//
2
conv_right
=
self
.
_conv_bn
(
input
[
i
-
1
],
ch_out
=
ch_out
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
name
=
name
+
'.{}.right'
.
format
(
i
))
conv_right
=
self
.
_upsample
(
conv_right
)
pan_out
=
fluid
.
layers
.
concat
([
conv_left
,
conv_right
],
axis
=
1
)
ch_list
=
[
pan_out
.
shape
[
1
]
//
2
*
k
for
k
in
[
1
,
2
,
1
,
2
,
1
]]
input
[
i
]
=
self
.
stack_conv
(
pan_out
,
ch_list
=
ch_list
,
filter_list
=
filter_list
,
name
=
name
+
'.stack_conv.{}'
.
format
(
i
))
return
input
def
_get_outputs
(
self
,
input
,
is_train
=
True
):
outputs
=
[]
filter_list
=
[
1
,
3
,
1
,
3
,
1
]
spp_stage
=
len
(
input
)
-
self
.
spp_stage
# get last out_layer_num blocks in reverse order
out_layer_num
=
len
(
self
.
anchor_masks
)
blocks
=
input
[
-
1
:
-
out_layer_num
-
1
:
-
1
]
blocks
[
spp_stage
]
=
self
.
spp_module
(
blocks
[
spp_stage
],
name
=
self
.
prefix_name
+
"spp_module"
)
blocks
=
self
.
pan_module
(
blocks
,
filter_list
=
filter_list
,
name
=
self
.
prefix_name
+
'pan_module'
)
# reverse order back to input
blocks
=
blocks
[::
-
1
]
route
=
None
for
i
,
block
in
enumerate
(
blocks
):
if
i
>
0
:
# perform concat in first 2 detection_block
route
=
self
.
_conv_bn
(
route
,
ch_out
=
route
.
shape
[
1
]
*
2
,
filter_size
=
3
,
stride
=
2
,
padding
=
1
,
name
=
self
.
prefix_name
+
'yolo_block.route.{}'
.
format
(
i
))
block
=
fluid
.
layers
.
concat
(
input
=
[
route
,
block
],
axis
=
1
)
ch_list
=
[
block
.
shape
[
1
]
//
2
*
k
for
k
in
[
1
,
2
,
1
,
2
,
1
]]
block
=
self
.
stack_conv
(
block
,
ch_list
=
ch_list
,
filter_list
=
filter_list
,
name
=
self
.
prefix_name
+
'yolo_block.stack_conv.{}'
.
format
(
i
))
route
=
block
block_out
=
self
.
_conv_bn
(
block
,
ch_out
=
block
.
shape
[
1
]
*
2
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
name
=
self
.
prefix_name
+
'yolo_output.{}.conv.0'
.
format
(
i
))
if
self
.
iou_aware
:
num_filters
=
len
(
self
.
anchor_masks
[
i
])
*
(
self
.
num_classes
+
6
)
else
:
num_filters
=
len
(
self
.
anchor_masks
[
i
])
*
(
self
.
num_classes
+
5
)
block_out
=
fluid
.
layers
.
conv2d
(
input
=
block_out
,
num_filters
=
num_filters
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
self
.
prefix_name
+
"yolo_output.{}.conv.1.weights"
.
format
(
i
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.
),
name
=
self
.
prefix_name
+
"yolo_output.{}.conv.1.bias"
.
format
(
i
)))
outputs
.
append
(
block_out
)
return
outputs
ppdet/modeling/architectures/__init__.py
浏览文件 @
76f6c939
...
...
@@ -19,7 +19,7 @@ from . import mask_rcnn
from
.
import
cascade_rcnn
from
.
import
cascade_mask_rcnn
from
.
import
cascade_rcnn_cls_aware
from
.
import
yolo
v3
from
.
import
yolo
from
.
import
ssd
from
.
import
retinanet
from
.
import
efficientdet
...
...
@@ -33,7 +33,7 @@ from .mask_rcnn import *
from
.cascade_rcnn
import
*
from
.cascade_mask_rcnn
import
*
from
.cascade_rcnn_cls_aware
import
*
from
.yolo
v3
import
*
from
.yolo
import
*
from
.ssd
import
*
from
.retinanet
import
*
from
.efficientdet
import
*
...
...
ppdet/modeling/architectures/yolo
v3
.py
→
ppdet/modeling/architectures/yolo.py
浏览文件 @
76f6c939
...
...
@@ -23,7 +23,7 @@ from paddle import fluid
from
ppdet.experimental
import
mixed_precision_global_state
from
ppdet.core.workspace
import
register
__all__
=
[
'YOLOv3'
]
__all__
=
[
'YOLOv3'
,
'YOLOv4'
]
@
register
...
...
@@ -42,7 +42,7 @@ class YOLOv3(object):
def
__init__
(
self
,
backbone
,
yolo_head
=
'YOLOv
3
Head'
,
yolo_head
=
'YOLOv
4
Head'
,
use_fine_grained_loss
=
False
):
super
(
YOLOv3
,
self
).
__init__
()
self
.
backbone
=
backbone
...
...
@@ -160,3 +160,27 @@ class YOLOv3(object):
def
test
(
self
,
feed_vars
):
return
self
.
build
(
feed_vars
,
mode
=
'test'
)
@
register
class
YOLOv4
(
YOLOv3
):
"""
YOLOv4 network, see https://arxiv.org/abs/2004.10934
Args:
backbone (object): an backbone instance
yolo_head (object): an `YOLOv4Head` instance
"""
__category__
=
'architecture'
__inject__
=
[
'backbone'
,
'yolo_head'
]
__shared__
=
[
'use_fine_grained_loss'
]
def
__init__
(
self
,
backbone
,
yolo_head
=
'YOLOv4Head'
,
use_fine_grained_loss
=
False
):
super
(
YOLOv4
,
self
).
__init__
(
backbone
=
backbone
,
yolo_head
=
yolo_head
,
use_fine_grained_loss
=
use_fine_grained_loss
)
ppdet/modeling/backbones/__init__.py
浏览文件 @
76f6c939
...
...
@@ -32,6 +32,7 @@ from . import bfp
from
.
import
hourglass
from
.
import
efficientnet
from
.
import
bifpn
from
.
import
cspdarknet
from
.resnet
import
*
from
.resnext
import
*
...
...
@@ -51,3 +52,4 @@ from .bfp import *
from
.hourglass
import
*
from
.efficientnet
import
*
from
.bifpn
import
*
from
.cspdarknet
import
*
ppdet/modeling/backbones/cspdarknet.py
0 → 100644
浏览文件 @
76f6c939
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
six
from
paddle
import
fluid
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.regularizer
import
L2Decay
from
ppdet.core.workspace
import
register
__all__
=
[
'CSPDarkNet'
]
@
register
class
CSPDarkNet
(
object
):
"""
CSPDarkNet, see https://arxiv.org/abs/1911.11929
Args:
depth (int): network depth, currently only cspdarknet 53 is supported
norm_type (str): normalization type, 'bn' and 'sync_bn' are supported
norm_decay (float): weight decay for normalization layer weights
"""
__shared__
=
[
'norm_type'
,
'weight_prefix_name'
]
def
__init__
(
self
,
depth
=
53
,
norm_type
=
'bn'
,
norm_decay
=
0.
,
weight_prefix_name
=
''
):
assert
depth
in
[
53
],
"unsupported depth value"
self
.
depth
=
depth
self
.
norm_type
=
norm_type
self
.
norm_decay
=
norm_decay
self
.
depth_cfg
=
{
53
:
([
1
,
2
,
8
,
8
,
4
],
self
.
basicblock
)}
self
.
prefix_name
=
weight_prefix_name
def
_softplus
(
self
,
input
):
expf
=
fluid
.
layers
.
exp
(
fluid
.
layers
.
clip
(
input
,
-
200
,
50
))
return
fluid
.
layers
.
log
(
1
+
expf
)
def
_mish
(
self
,
input
):
return
input
*
fluid
.
layers
.
tanh
(
self
.
_softplus
(
input
))
def
_conv_norm
(
self
,
input
,
ch_out
,
filter_size
,
stride
,
padding
,
act
=
'mish'
,
name
=
None
):
conv
=
fluid
.
layers
.
conv2d
(
input
=
input
,
num_filters
=
ch_out
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
padding
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
name
+
".conv.weights"
),
bias_attr
=
False
)
bn_name
=
name
+
".bn"
bn_param_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
float
(
self
.
norm_decay
)),
name
=
bn_name
+
'.scale'
)
bn_bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
float
(
self
.
norm_decay
)),
name
=
bn_name
+
'.offset'
)
out
=
fluid
.
layers
.
batch_norm
(
input
=
conv
,
act
=
None
,
param_attr
=
bn_param_attr
,
bias_attr
=
bn_bias_attr
,
moving_mean_name
=
bn_name
+
'.mean'
,
moving_variance_name
=
bn_name
+
'.var'
)
if
act
==
'mish'
:
out
=
self
.
_mish
(
out
)
return
out
def
_downsample
(
self
,
input
,
ch_out
,
filter_size
=
3
,
stride
=
2
,
padding
=
1
,
name
=
None
):
return
self
.
_conv_norm
(
input
,
ch_out
=
ch_out
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
padding
,
name
=
name
)
def
conv_layer
(
self
,
input
,
ch_out
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
name
=
None
):
return
self
.
_conv_norm
(
input
,
ch_out
=
ch_out
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
padding
,
name
=
name
)
def
basicblock
(
self
,
input
,
ch_out
,
scale_first
=
False
,
name
=
None
):
conv1
=
self
.
_conv_norm
(
input
,
ch_out
=
ch_out
//
2
if
scale_first
else
ch_out
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
name
=
name
+
".0"
)
conv2
=
self
.
_conv_norm
(
conv1
,
ch_out
=
ch_out
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
name
=
name
+
".1"
)
out
=
fluid
.
layers
.
elementwise_add
(
x
=
input
,
y
=
conv2
,
act
=
None
)
return
out
def
layer_warp
(
self
,
block_func
,
input
,
ch_out
,
count
,
keep_ch
=
False
,
scale_first
=
False
,
name
=
None
):
if
scale_first
:
ch_out
=
ch_out
*
2
right
=
self
.
conv_layer
(
input
,
ch_out
,
name
=
'{}.route_in.right'
.
format
(
name
))
neck
=
self
.
conv_layer
(
input
,
ch_out
,
name
=
'{}.neck'
.
format
(
name
))
out
=
block_func
(
neck
,
ch_out
=
ch_out
,
scale_first
=
scale_first
,
name
=
'{}.0'
.
format
(
name
))
for
j
in
six
.
moves
.
xrange
(
1
,
count
):
out
=
block_func
(
out
,
ch_out
=
ch_out
,
name
=
'{}.{}'
.
format
(
name
,
j
))
left
=
self
.
conv_layer
(
out
,
ch_out
,
name
=
'{}.route_in.left'
.
format
(
name
))
route
=
fluid
.
layers
.
concat
([
left
,
right
],
axis
=
1
)
out
=
self
.
conv_layer
(
route
,
ch_out
=
ch_out
if
keep_ch
else
ch_out
*
2
,
name
=
'{}.conv_layer'
.
format
(
name
))
return
out
def
__call__
(
self
,
input
):
"""
Get the backbone of CSPDarkNet, that is output for the 5 stages.
Args:
input (Variable): input variable.
Returns:
The last variables of each stage.
"""
stages
,
block_func
=
self
.
depth_cfg
[
self
.
depth
]
stages
=
stages
[
0
:
5
]
conv
=
self
.
_conv_norm
(
input
=
input
,
ch_out
=
32
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
act
=
'mish'
,
name
=
self
.
prefix_name
+
"conv"
)
blocks
=
[]
for
i
,
stage
in
enumerate
(
stages
):
input
=
conv
if
i
==
0
else
block
downsample_
=
self
.
_downsample
(
input
=
input
,
ch_out
=
input
.
shape
[
1
]
*
2
,
name
=
self
.
prefix_name
+
"stage.{}.downsample"
.
format
(
i
))
block
=
self
.
layer_warp
(
block_func
=
block_func
,
input
=
downsample_
,
ch_out
=
32
*
2
**
i
,
count
=
stage
,
keep_ch
=
(
i
==
0
),
scale_first
=
i
==
0
,
name
=
self
.
prefix_name
+
"stage.{}"
.
format
(
i
))
blocks
.
append
(
block
)
return
blocks
ppdet/modeling/losses/iou_aware_loss.py
浏览文件 @
76f6c939
...
...
@@ -66,8 +66,11 @@ class IouAwareLoss(IouLoss):
eps (float): the decimal to prevent the denominator eqaul zero
'''
iouk
=
self
.
_iou
(
x
,
y
,
w
,
h
,
tx
,
ty
,
tw
,
th
,
anchors
,
downsample_ratio
,
batch_size
,
ioup
,
eps
)
pred
=
self
.
_bbox_transform
(
x
,
y
,
w
,
h
,
anchors
,
downsample_ratio
,
batch_size
,
False
)
gt
=
self
.
_bbox_transform
(
tx
,
ty
,
tw
,
th
,
anchors
,
downsample_ratio
,
batch_size
,
True
)
iouk
=
self
.
_iou
(
pred
,
gt
,
ioup
,
eps
)
iouk
.
stop_gradient
=
True
loss_iou_aware
=
fluid
.
layers
.
cross_entropy
(
ioup
,
iouk
,
soft_label
=
True
)
...
...
ppdet/modeling/losses/iou_loss.py
浏览文件 @
76f6c939
...
...
@@ -35,12 +35,21 @@ class IouLoss(object):
loss_weight (float): iou loss weight, default is 2.5
max_height (int): max height of input to support random shape input
max_width (int): max width of input to support random shape input
ciou_term (bool): whether to add ciou_term
loss_square (bool): whether to square the iou term
"""
def
__init__
(
self
,
loss_weight
=
2.5
,
max_height
=
608
,
max_width
=
608
):
def
__init__
(
self
,
loss_weight
=
2.5
,
max_height
=
608
,
max_width
=
608
,
ciou_term
=
False
,
loss_square
=
True
):
self
.
_loss_weight
=
loss_weight
self
.
_MAX_HI
=
max_height
self
.
_MAX_WI
=
max_width
self
.
ciou_term
=
ciou_term
self
.
loss_square
=
loss_square
def
__call__
(
self
,
x
,
...
...
@@ -65,33 +74,22 @@ class IouLoss(object):
batch_size (int): training batch size
eps (float): the decimal to prevent the denominator eqaul zero
'''
iouk
=
self
.
_iou
(
x
,
y
,
w
,
h
,
tx
,
ty
,
tw
,
th
,
anchors
,
downsample_ratio
,
batch_size
,
ioup
,
eps
)
loss_iou
=
1.
-
iouk
*
iouk
pred
=
self
.
_bbox_transform
(
x
,
y
,
w
,
h
,
anchors
,
downsample_ratio
,
batch_size
,
False
)
gt
=
self
.
_bbox_transform
(
tx
,
ty
,
tw
,
th
,
anchors
,
downsample_ratio
,
batch_size
,
True
)
iouk
=
self
.
_iou
(
pred
,
gt
,
ioup
,
eps
)
if
self
.
loss_square
:
loss_iou
=
1.
-
iouk
*
iouk
else
:
loss_iou
=
1.
-
iouk
loss_iou
=
loss_iou
*
self
.
_loss_weight
return
loss_iou
def
_iou
(
self
,
x
,
y
,
w
,
h
,
tx
,
ty
,
tw
,
th
,
anchors
,
downsample_ratio
,
batch_size
,
ioup
=
None
,
eps
=
1.e-10
):
x1
,
y1
,
x2
,
y2
=
self
.
_bbox_transform
(
x
,
y
,
w
,
h
,
anchors
,
downsample_ratio
,
batch_size
,
False
)
x1g
,
y1g
,
x2g
,
y2g
=
self
.
_bbox_transform
(
tx
,
ty
,
tw
,
th
,
anchors
,
downsample_ratio
,
batch_size
,
True
)
def
_iou
(
self
,
pred
,
gt
,
ioup
=
None
,
eps
=
1.e-10
):
x1
,
y1
,
x2
,
y2
=
pred
x1g
,
y1g
,
x2g
,
y2g
=
gt
x2
=
fluid
.
layers
.
elementwise_max
(
x1
,
x2
)
y2
=
fluid
.
layers
.
elementwise_max
(
y1
,
y2
)
...
...
@@ -106,8 +104,46 @@ class IouLoss(object):
unionk
=
(
x2
-
x1
)
*
(
y2
-
y1
)
+
(
x2g
-
x1g
)
*
(
y2g
-
y1g
)
-
intsctk
+
eps
iouk
=
intsctk
/
unionk
if
self
.
ciou_term
:
ciou
=
self
.
get_ciou_term
(
pred
,
gt
,
iouk
,
eps
)
iouk
=
iouk
-
ciou
return
iouk
def
get_ciou_term
(
self
,
pred
,
gt
,
iouk
,
eps
):
x1
,
y1
,
x2
,
y2
=
pred
x1g
,
y1g
,
x2g
,
y2g
=
gt
cx
=
(
x1
+
x2
)
/
2
cy
=
(
y1
+
y2
)
/
2
w
=
(
x2
-
x1
)
+
fluid
.
layers
.
cast
((
x2
-
x1
)
==
0
,
'float32'
)
h
=
(
y2
-
y1
)
+
fluid
.
layers
.
cast
((
y2
-
y1
)
==
0
,
'float32'
)
cxg
=
(
x1g
+
x2g
)
/
2
cyg
=
(
y1g
+
y2g
)
/
2
wg
=
x2g
-
x1g
hg
=
y2g
-
y1g
# A or B
xc1
=
fluid
.
layers
.
elementwise_min
(
x1
,
x1g
)
yc1
=
fluid
.
layers
.
elementwise_min
(
y1
,
y1g
)
xc2
=
fluid
.
layers
.
elementwise_max
(
x2
,
x2g
)
yc2
=
fluid
.
layers
.
elementwise_max
(
y2
,
y2g
)
# DIOU term
dist_intersection
=
(
cx
-
cxg
)
*
(
cx
-
cxg
)
+
(
cy
-
cyg
)
*
(
cy
-
cyg
)
dist_union
=
(
xc2
-
xc1
)
*
(
xc2
-
xc1
)
+
(
yc2
-
yc1
)
*
(
yc2
-
yc1
)
diou_term
=
(
dist_intersection
+
eps
)
/
(
dist_union
+
eps
)
# CIOU term
ciou_term
=
0
ar_gt
=
wg
/
hg
ar_pred
=
w
/
h
arctan
=
fluid
.
layers
.
atan
(
ar_gt
)
-
fluid
.
layers
.
atan
(
ar_pred
)
ar_loss
=
4.
/
np
.
pi
/
np
.
pi
*
arctan
*
arctan
alpha
=
ar_loss
/
(
1
-
iouk
+
ar_loss
+
eps
)
alpha
.
stop_gradient
=
True
ciou_term
=
alpha
*
ar_loss
return
diou_term
+
ciou_term
def
_bbox_transform
(
self
,
dcx
,
dcy
,
dw
,
dh
,
anchors
,
downsample_ratio
,
batch_size
,
is_gt
):
grid_x
=
int
(
self
.
_MAX_WI
/
downsample_ratio
)
...
...
ppdet/modeling/losses/yolo_loss.py
浏览文件 @
76f6c939
...
...
@@ -18,6 +18,10 @@ from __future__ import print_function
from
paddle
import
fluid
from
ppdet.core.workspace
import
register
try
:
from
collections.abc
import
Sequence
except
Exception
:
from
collections
import
Sequence
__all__
=
[
'YOLOv3Loss'
]
...
...
@@ -43,13 +47,20 @@ class YOLOv3Loss(object):
label_smooth
=
True
,
use_fine_grained_loss
=
False
,
iou_loss
=
None
,
iou_aware_loss
=
None
):
iou_aware_loss
=
None
,
downsample
=
[
32
,
16
,
8
],
scale_x_y
=
1.
,
match_score
=
False
):
self
.
_batch_size
=
batch_size
self
.
_ignore_thresh
=
ignore_thresh
self
.
_label_smooth
=
label_smooth
self
.
_use_fine_grained_loss
=
use_fine_grained_loss
self
.
_iou_loss
=
iou_loss
self
.
_iou_aware_loss
=
iou_aware_loss
self
.
downsample
=
downsample
# TODO(guanzhong) activate scale_x_y in Paddle 2.0
#self.scale_x_y = scale_x_y
self
.
match_score
=
match_score
def
__call__
(
self
,
outputs
,
gt_box
,
gt_label
,
gt_score
,
targets
,
anchors
,
anchor_masks
,
mask_anchors
,
num_classes
,
prefix_name
):
...
...
@@ -59,8 +70,9 @@ class YOLOv3Loss(object):
mask_anchors
,
self
.
_ignore_thresh
)
else
:
losses
=
[]
downsample
=
32
for
i
,
output
in
enumerate
(
outputs
):
#scale_x_y = self.scale_x_y if not isinstance(
# self.scale_x_y, Sequence) else self.scale_x_y[i]
anchor_mask
=
anchor_masks
[
i
]
loss
=
fluid
.
layers
.
yolov3_loss
(
x
=
output
,
...
...
@@ -71,11 +83,10 @@ class YOLOv3Loss(object):
anchor_mask
=
anchor_mask
,
class_num
=
num_classes
,
ignore_thresh
=
self
.
_ignore_thresh
,
downsample_ratio
=
downsample
,
downsample_ratio
=
self
.
downsample
[
i
]
,
use_label_smooth
=
self
.
_label_smooth
,
name
=
prefix_name
+
"yolo_loss"
+
str
(
i
))
losses
.
append
(
fluid
.
layers
.
reduce_mean
(
loss
))
downsample
//=
2
return
{
'loss'
:
sum
(
losses
)}
...
...
@@ -108,7 +119,6 @@ class YOLOv3Loss(object):
assert
len
(
outputs
)
==
len
(
targets
),
\
"YOLOv3 output layer number not equal target number"
downsample
=
32
loss_xys
,
loss_whs
,
loss_objs
,
loss_clss
=
[],
[],
[],
[]
if
self
.
_iou_loss
is
not
None
:
loss_ious
=
[]
...
...
@@ -116,6 +126,7 @@ class YOLOv3Loss(object):
loss_iou_awares
=
[]
for
i
,
(
output
,
target
,
anchors
)
in
enumerate
(
zip
(
outputs
,
targets
,
mask_anchors
)):
downsample
=
self
.
downsample
[
i
]
an_num
=
len
(
anchors
)
//
2
if
self
.
_iou_aware_loss
is
not
None
:
ioup
,
output
=
self
.
_split_ioup
(
output
,
an_num
,
num_classes
)
...
...
@@ -151,9 +162,11 @@ class YOLOv3Loss(object):
loss_iou_aware
,
dim
=
[
1
,
2
,
3
])
loss_iou_awares
.
append
(
fluid
.
layers
.
reduce_mean
(
loss_iou_aware
))
#scale_x_y = self.scale_x_y if not isinstance(
# self.scale_x_y, Sequence) else self.scale_x_y[i]
loss_obj_pos
,
loss_obj_neg
=
self
.
_calc_obj_loss
(
output
,
obj
,
tobj
,
gt_box
,
self
.
_batch_size
,
anchors
,
num_classes
,
downsample
,
self
.
_ignore_thresh
)
num_classes
,
downsample
,
self
.
_ignore_thresh
,
scale_x_y
)
loss_cls
=
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
cls
,
tcls
)
loss_cls
=
fluid
.
layers
.
elementwise_mul
(
loss_cls
,
tobj
,
axis
=
0
)
...
...
@@ -165,7 +178,6 @@ class YOLOv3Loss(object):
fluid
.
layers
.
reduce_mean
(
loss_obj_pos
+
loss_obj_neg
))
loss_clss
.
append
(
fluid
.
layers
.
reduce_mean
(
loss_cls
))
downsample
//=
2
losses_all
=
{
"loss_xy"
:
fluid
.
layers
.
sum
(
loss_xys
),
"loss_wh"
:
fluid
.
layers
.
sum
(
loss_whs
),
...
...
@@ -264,13 +276,13 @@ class YOLOv3Loss(object):
return
(
tx
,
ty
,
tw
,
th
,
tscale
,
tobj
,
tcls
)
def
_calc_obj_loss
(
self
,
output
,
obj
,
tobj
,
gt_box
,
batch_size
,
anchors
,
num_classes
,
downsample
,
ignore_thresh
):
num_classes
,
downsample
,
ignore_thresh
,
scale_x_y
):
# A prediction bbox overlap any gt_bbox over ignore_thresh,
# objectness loss will be ignored, process as follows:
# 1. get pred bbox, which is same with YOLOv3 infer mode, use yolo_box here
# NOTE: img_size is set as 1.0 to get noramlized pred bbox
bbox
,
_
=
fluid
.
layers
.
yolo_box
(
bbox
,
prob
=
fluid
.
layers
.
yolo_box
(
x
=
output
,
img_size
=
fluid
.
layers
.
ones
(
shape
=
[
batch_size
,
2
],
dtype
=
"int32"
),
...
...
@@ -288,6 +300,7 @@ class YOLOv3Loss(object):
else
:
preds
=
[
bbox
]
gts
=
[
gt_box
]
probs
=
[
prob
]
ious
=
[]
for
pred
,
gt
in
zip
(
preds
,
gts
):
...
...
@@ -307,12 +320,17 @@ class YOLOv3Loss(object):
pred
=
fluid
.
layers
.
squeeze
(
pred
,
axes
=
[
0
])
gt
=
box_xywh2xyxy
(
fluid
.
layers
.
squeeze
(
gt
,
axes
=
[
0
]))
ious
.
append
(
fluid
.
layers
.
iou_similarity
(
pred
,
gt
))
iou
=
fluid
.
layers
.
stack
(
ious
,
axis
=
0
)
iou
=
fluid
.
layers
.
stack
(
ious
,
axis
=
0
)
# 3. Get iou_mask by IoU between gt bbox and prediction bbox,
# Get obj_mask by tobj(holds gt_score), calculate objectness loss
max_iou
=
fluid
.
layers
.
reduce_max
(
iou
,
dim
=-
1
)
iou_mask
=
fluid
.
layers
.
cast
(
max_iou
<=
ignore_thresh
,
dtype
=
"float32"
)
if
self
.
match_score
:
max_prob
=
fluid
.
layers
.
reduce_max
(
prob
,
dim
=-
1
)
iou_mask
=
iou_mask
*
fluid
.
layers
.
cast
(
max_prob
<=
0.25
,
dtype
=
"float32"
)
output_shape
=
fluid
.
layers
.
shape
(
output
)
an_num
=
len
(
anchors
)
//
2
iou_mask
=
fluid
.
layers
.
reshape
(
iou_mask
,
(
-
1
,
an_num
,
output_shape
[
2
],
...
...
ppdet/utils/coco_eval.py
浏览文件 @
76f6c939
...
...
@@ -37,11 +37,13 @@ __all__ = [
]
def
clip_bbox
(
bbox
):
xmin
=
max
(
min
(
bbox
[
0
],
1.
),
0.
)
ymin
=
max
(
min
(
bbox
[
1
],
1.
),
0.
)
xmax
=
max
(
min
(
bbox
[
2
],
1.
),
0.
)
ymax
=
max
(
min
(
bbox
[
3
],
1.
),
0.
)
def
clip_bbox
(
bbox
,
im_size
=
None
):
h
=
1.
if
im_size
is
None
else
im_size
[
0
]
w
=
1.
if
im_size
is
None
else
im_size
[
1
]
xmin
=
max
(
min
(
bbox
[
0
],
w
),
0.
)
ymin
=
max
(
min
(
bbox
[
1
],
h
),
0.
)
xmax
=
max
(
min
(
bbox
[
2
],
w
),
0.
)
ymax
=
max
(
min
(
bbox
[
3
],
h
),
0.
)
return
xmin
,
ymin
,
xmax
,
ymax
...
...
@@ -66,7 +68,8 @@ def bbox_eval(results,
anno_file
,
outfile
,
with_background
=
True
,
is_bbox_normalized
=
False
):
is_bbox_normalized
=
False
,
save_only
=
False
):
assert
'bbox'
in
results
[
0
]
assert
outfile
.
endswith
(
'.json'
)
from
pycocotools.coco
import
COCO
...
...
@@ -91,13 +94,23 @@ def bbox_eval(results,
with
open
(
outfile
,
'w'
)
as
f
:
json
.
dump
(
xywh_results
,
f
)
if
save_only
:
logger
.
info
(
'The bbox result is saved to {} and do not '
'evaluate the mAP.'
.
format
(
outfile
))
return
map_stats
=
cocoapi_eval
(
outfile
,
'bbox'
,
coco_gt
=
coco_gt
)
# flush coco evaluation result
sys
.
stdout
.
flush
()
return
map_stats
def
mask_eval
(
results
,
anno_file
,
outfile
,
resolution
,
thresh_binarize
=
0.5
):
def
mask_eval
(
results
,
anno_file
,
outfile
,
resolution
,
thresh_binarize
=
0.5
,
save_only
=
False
):
assert
'mask'
in
results
[
0
]
assert
outfile
.
endswith
(
'.json'
)
from
pycocotools.coco
import
COCO
...
...
@@ -143,6 +156,11 @@ def mask_eval(results, anno_file, outfile, resolution, thresh_binarize=0.5):
with
open
(
outfile
,
'w'
)
as
f
:
json
.
dump
(
segm_results
,
f
)
if
save_only
:
logger
.
info
(
'The mask result is saved to {} and do not '
'evaluate the mAP.'
.
format
(
outfile
))
return
cocoapi_eval
(
outfile
,
'segm'
,
coco_gt
=
coco_gt
)
...
...
@@ -257,8 +275,11 @@ def bbox2out(results, clsid2catid, is_bbox_normalized=False):
w
*=
im_width
h
*=
im_height
else
:
w
=
xmax
-
xmin
+
1
h
=
ymax
-
ymin
+
1
im_size
=
t
[
'im_size'
][
0
][
i
].
tolist
()
xmin
,
ymin
,
xmax
,
ymax
=
\
clip_bbox
([
xmin
,
ymin
,
xmax
,
ymax
],
im_size
)
w
=
xmax
-
xmin
h
=
ymax
-
ymin
bbox
=
[
xmin
,
ymin
,
w
,
h
]
coco_res
=
{
...
...
ppdet/utils/eval_utils.py
浏览文件 @
76f6c939
...
...
@@ -191,7 +191,8 @@ def eval_results(results,
is_bbox_normalized
=
False
,
output_directory
=
None
,
map_type
=
'11point'
,
dataset
=
None
):
dataset
=
None
,
save_only
=
False
):
"""Evaluation for evaluation program results"""
box_ap_stats
=
[]
if
metric
==
'COCO'
:
...
...
@@ -213,13 +214,15 @@ def eval_results(results,
anno_file
,
output
,
with_background
,
is_bbox_normalized
=
is_bbox_normalized
)
is_bbox_normalized
=
is_bbox_normalized
,
save_only
=
save_only
)
if
'mask'
in
results
[
0
]:
output
=
'mask.json'
if
output_directory
:
output
=
os
.
path
.
join
(
output_directory
,
'mask.json'
)
mask_eval
(
results
,
anno_file
,
output
,
resolution
)
mask_eval
(
results
,
anno_file
,
output
,
resolution
,
save_only
=
save_only
)
else
:
if
'accum_map'
in
results
[
-
1
]:
res
=
np
.
mean
(
results
[
-
1
][
'accum_map'
][
0
])
...
...
ppdet/utils/voc_eval.py
浏览文件 @
76f6c939
...
...
@@ -68,7 +68,6 @@ def bbox_eval(results,
if
bboxes
.
shape
==
(
1
,
1
)
or
bboxes
is
None
:
continue
gt_boxes
=
t
[
'gt_bbox'
][
0
]
gt_labels
=
t
[
'gt_class'
][
0
]
difficults
=
t
[
'is_difficult'
][
0
]
if
not
evaluate_difficult
\
...
...
tools/eval.py
浏览文件 @
76f6c939
...
...
@@ -111,7 +111,7 @@ def main():
extra_keys
=
[]
if
cfg
.
metric
==
'COCO'
:
extra_keys
=
[
'im_info'
,
'im_id'
,
'im_shape'
]
extra_keys
=
[
'im_info'
,
'im_id'
,
'im_shape'
,
'im_size'
]
if
cfg
.
metric
==
'VOC'
:
extra_keys
=
[
'gt_bbox'
,
'gt_class'
,
'is_difficult'
]
...
...
@@ -160,6 +160,7 @@ def main():
# evaluation
# if map_type not set, use default 11point, only use in VOC eval
map_type
=
cfg
.
map_type
if
'map_type'
in
cfg
else
'11point'
save_only
=
getattr
(
cfg
,
'save_prediction_only'
,
False
)
eval_results
(
results
,
cfg
.
metric
,
...
...
@@ -168,7 +169,8 @@ def main():
is_bbox_normalized
,
FLAGS
.
output_eval
,
map_type
,
dataset
=
dataset
)
dataset
=
dataset
,
save_only
=
save_only
)
if
__name__
==
'__main__'
:
...
...
tools/train.py
浏览文件 @
76f6c939
...
...
@@ -79,6 +79,10 @@ def main():
# check if paddlepaddle version is satisfied
check_version
()
save_only
=
getattr
(
cfg
,
'save_prediction_only'
,
False
)
if
save_only
:
raise
NotImplementedError
(
'The config file only support prediction,'
' training stage is not implemented now'
)
main_arch
=
cfg
.
architecture
if
cfg
.
use_gpu
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录