PaddlePaddle / PaddleX

Commit ed8de1ae — "add ppyolo"
Authored on Aug 06, 2020 by FlyingQianMM
Parent: f1465e6f

Showing 11 changed files, with 1751 additions and 147 deletions (+1751 −147).
Changed files:

    paddlex/cv/datasets/dataset.py                      +10   -8
    paddlex/cv/models/base.py                           +19   -8
    paddlex/cv/models/yolo_v3.py                        +125  -18
    paddlex/cv/nets/detection/iou_aware.py              +85   -0
    paddlex/cv/nets/detection/loss/iou_aware_loss.py    +77   -0
    paddlex/cv/nets/detection/loss/iou_loss.py          +235  -0
    paddlex/cv/nets/detection/loss/yolo_loss.py         +371  -0
    paddlex/cv/nets/detection/ops.py                    +270  -0
    paddlex/cv/nets/detection/yolo_v3.py                +305  -113
    paddlex/cv/transforms/det_transforms.py             +185  -0
    tutorials/train/object_detection/ppyolo.py          +69   -0
paddlex/cv/datasets/dataset.py (view file @ ed8de1ae)

```diff
@@ -115,7 +115,7 @@ def multithread_reader(mapper,
         while not isinstance(sample, EndSignal):
             batch_data.append(sample)
             if len(batch_data) == batch_size:
-                batch_data = generate_minibatch(batch_data)
+                batch_data = generate_minibatch(batch_data, mapper=mapper)
                 yield batch_data
                 batch_data = []
             sample = out_queue.get()
@@ -127,11 +127,11 @@ def multithread_reader(mapper,
             else:
                 batch_data.append(sample)
             if len(batch_data) == batch_size:
-                batch_data = generate_minibatch(batch_data)
+                batch_data = generate_minibatch(batch_data, mapper=mapper)
                 yield batch_data
                 batch_data = []
         if not drop_last and len(batch_data) != 0:
-            batch_data = generate_minibatch(batch_data)
+            batch_data = generate_minibatch(batch_data, mapper=mapper)
             yield batch_data
             batch_data = []
@@ -188,18 +188,21 @@ def multiprocess_reader(mapper,
             else:
                 batch_data.append(sample)
             if len(batch_data) == batch_size:
-                batch_data = generate_minibatch(batch_data)
+                batch_data = generate_minibatch(batch_data, mapper=mapper)
                 yield batch_data
                 batch_data = []
         if len(batch_data) != 0 and not drop_last:
-            batch_data = generate_minibatch(batch_data)
+            batch_data = generate_minibatch(batch_data, mapper=mapper)
             yield batch_data
             batch_data = []

     return queue_reader


-def generate_minibatch(batch_data, label_padding_value=255):
+def generate_minibatch(batch_data, label_padding_value=255, mapper=None):
+    if mapper is not None and mapper.batch_transforms is not None:
+        for op in mapper.batch_transforms:
+            batch_data = op(batch_data)
     # if batch_size is 1, do not pad the image
     if len(batch_data) == 1:
         return batch_data
@@ -218,14 +221,13 @@ def generate_minibatch(batch_data, label_padding_value=255):
                 (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
             padding_im[:, :im_h, :im_w] = data[0]
             if len(data) > 2:
                 # padding the image, label and insert 'padding' into `im_info` of segmentation during evaluating phase.
                 if len(data[1]) == 0 or 'padding' not in [
                         data[1][i][0] for i in range(len(data[1]))
                 ]:
                     data[1].append(('padding', [im_h, im_w]))
                 padding_batch.append((padding_im, data[1], data[2]))
             elif len(data) > 1:
                 if isinstance(data[1], np.ndarray) and len(data[1].shape) > 1:
                     # padding the image and label of segmentation during the training
```
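The practical effect of threading `mapper` into `generate_minibatch` is that batch-level transforms (the `batch_transforms` list this commit adds to the detection `Compose`; see det_transforms.py below) run once per assembled minibatch rather than per sample. A minimal runnable sketch of the hook, with a hypothetical `Mapper` stand-in for the real transform object and a toy batch transform:

```python
import numpy as np

class Mapper(object):
    # hypothetical stand-in: in PaddleX the mapper is the det transforms
    # Compose object, which now carries a `batch_transforms` list
    def __init__(self, batch_transforms=None):
        self.batch_transforms = batch_transforms

def generate_minibatch(batch_data, mapper=None):
    # same hook as in the diff: apply each batch transform to the whole batch
    if mapper is not None and mapper.batch_transforms is not None:
        for op in mapper.batch_transforms:
            batch_data = op(batch_data)
    return batch_data

def to_shared_shape(batch_data):
    # toy batch transform: resize every array in the batch to one shared size
    size = max(im.shape[0] for im, in batch_data)
    return [(np.resize(im, (size,)),) for im, in batch_data]

batch = [(np.zeros(3),), (np.zeros(5),)]
out = generate_minibatch(batch, mapper=Mapper([to_shared_shape]))
print([im.shape for im, in out])  # [(5,), (5,)] — one shape for the whole batch
```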
paddlex/cv/models/base.py (view file @ ed8de1ae)

```diff
@@ -94,6 +94,8 @@ class BaseAPI:
         self.train_inputs, self.train_outputs = self.build_net(mode='train')
         self.train_prog = fluid.default_main_program()
         startup_prog = fluid.default_startup_program()
+        self.train_prog.random_seed = 1000
+        startup_prog.random_seed = 1000
         # build the inference network
         self.test_prog = fluid.Program()
@@ -246,8 +248,8 @@ class BaseAPI:
             logging.info(
                 "Load pretrain weights from {}.".format(pretrain_weights),
                 use_color=True)
         paddlex.utils.utils.load_pretrain_weights(
             self.exe, self.train_prog, pretrain_weights, fuse_bn)
         # pruning
         if sensitivities_file is not None:
             import paddleslim
@@ -351,7 +353,9 @@ class BaseAPI:
         logging.info("Model saved in {}.".format(save_dir))

     def export_inference_model(self, save_dir):
-        test_input_names = [var.name for var in list(self.test_inputs.values())]
+        test_input_names = [
+            var.name for var in list(self.test_inputs.values())
+        ]
         test_outputs = list(self.test_outputs.values())
         with fluid.scope_guard(self.scope):
             if self.__class__.__name__ == 'MaskRCNN':
@@ -389,7 +393,8 @@ class BaseAPI:
         # flag that the model was saved successfully
         open(osp.join(save_dir, '.success'), 'w').close()
-        logging.info("Model for inference deploy saved in {}.".format(save_dir))
+        logging.info(
+            "Model for inference deploy saved in {}.".format(save_dir))

     def train_loop(self,
                    num_epochs,
@@ -516,11 +521,13 @@ class BaseAPI:
                     eta = ((num_epochs - i) * total_num_steps - step - 1
                            ) * avg_step_time
                     if time_eval_one_epoch is not None:
-                        eval_eta = (total_eval_times - i // save_interval_epochs
-                                    ) * time_eval_one_epoch
+                        eval_eta = (
+                            total_eval_times - i // save_interval_epochs
+                        ) * time_eval_one_epoch
                     else:
-                        eval_eta = (total_eval_times - i // save_interval_epochs
-                                    ) * total_num_steps_eval * avg_step_time
+                        eval_eta = (
+                            total_eval_times - i // save_interval_epochs
+                        ) * total_num_steps_eval * avg_step_time
                     eta_str = seconds_to_hms(eta + eval_eta)
                     logging.info(
@@ -543,6 +550,8 @@ class BaseAPI:
                 current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
                 if not osp.isdir(current_save_dir):
                     os.makedirs(current_save_dir)
+                if hasattr(self, 'use_ema'):
+                    self.exe.run(self.ema.apply_program)
                 if eval_dataset is not None and eval_dataset.num_samples > 0:
                     self.eval_metrics, self.eval_details = self.evaluate(
                         eval_dataset=eval_dataset,
@@ -569,6 +578,8 @@ class BaseAPI:
                             log_writer.add_scalar(
                                 "Metrics/Eval(Epoch): {}".format(k), v, i + 1)
                 self.save_model(save_dir=current_save_dir)
+                if hasattr(self, 'use_ema'):
+                    self.exe.run(self.ema.restore_program)
                 time_eval_one_epoch = time.time() - eval_epoch_start_time
                 eval_epoch_start_time = time.time()
                 if best_model_epoch > 0:
```
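The train loop now swaps in the EMA (exponential moving average) weights via `ema.apply_program` before each evaluation and swaps the raw training weights back via `ema.restore_program` afterwards. A minimal NumPy sketch of what the averaging does to a parameter, assuming the standard update `shadow = d * shadow + (1 - d) * param`; the `(1 + step) / (10 + step)` warmup ramp used here is an assumption modeled on the usual `thres_steps`-style schedule, not code from this commit:

```python
import numpy as np

decay = 0.9998
param = np.array([1.0])   # the live training weight
shadow = param.copy()     # the EMA copy, used only for evaluation

for step in range(1, 1001):
    param = param + np.random.normal(0, 0.05, size=1)  # noisy SGD-like updates
    # assumed thres_steps-style warmup: smaller effective decay early on
    d = min(decay, (1.0 + step) / (10.0 + step))
    shadow = d * shadow + (1.0 - d) * param

# "apply_program" evaluates with `shadow`; "restore_program" goes back to `param`
print(param, shadow)  # shadow drifts far less than param
```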
paddlex/cv/models/yolo_v3.py (view file @ ed8de1ae)

```diff
@@ -19,6 +19,8 @@ import os.path as osp
 import numpy as np
 from multiprocessing.pool import ThreadPool
 import paddle.fluid as fluid
+from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
+from paddle.fluid.optimizer import ExponentialMovingAverage
 import paddlex.utils.logging as logging
 import paddlex
 import copy
@@ -28,6 +30,10 @@ from .base import BaseAPI
 from collections import OrderedDict
 from .utils.detection_eval import eval_results, bbox2out
+import random
+
+random.seed(0)
+np.random.seed(0)


 class YOLOv3(BaseAPI):
     """Build YOLOv3 and implement its training, evaluation, prediction and model export.
@@ -50,24 +56,37 @@ class YOLOv3(BaseAPI):
         train_random_shapes (list|tuple): image sizes randomly picked from this list during training. Default: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608].
     """

-    def __init__(self,
-                 num_classes=80,
-                 backbone='MobileNetV1',
-                 anchors=None,
-                 anchor_masks=None,
-                 ignore_threshold=0.7,
-                 nms_score_threshold=0.01,
-                 nms_topk=1000,
-                 nms_keep_topk=100,
-                 nms_iou_threshold=0.45,
-                 label_smooth=False,
-                 train_random_shapes=[
-                     320, 352, 384, 416, 448, 480, 512, 544, 576, 608
-                 ]):
+    def __init__(
+            self,
+            num_classes=80,
+            backbone='MobileNetV1',
+            with_dcn_v2=False,
+            # YOLO Head
+            anchors=None,
+            anchor_masks=None,
+            use_coord_conv=False,
+            use_iou_aware=False,
+            use_spp=False,
+            use_drop_block=False,
+            scale_x_y=1.0,
+            # YOLOv3 Loss
+            ignore_threshold=0.7,
+            label_smooth=False,
+            use_iou_loss=False,
+            # NMS
+            use_matrix_nms=False,
+            nms_score_threshold=0.01,
+            nms_topk=1000,
+            nms_keep_topk=100,
+            nms_iou_threshold=0.45,
+            train_random_shapes=[
+                320, 352, 384, 416, 448, 480, 512, 544, 576, 608
+            ]):
         self.init_params = locals()
         super(YOLOv3, self).__init__('detector')
         backbones = [
-            'DarkNet53', 'ResNet34', 'MobileNetV1', 'MobileNetV3_large'
+            'DarkNet53', 'ResNet34', 'MobileNetV1', 'MobileNetV3_large',
+            'ResNet50_vd'
         ]
         assert backbone in backbones, "backbone should be one of {}".format(
             backbones)
@@ -75,6 +94,11 @@ class YOLOv3(BaseAPI):
         self.num_classes = num_classes
         self.anchors = anchors
         self.anchor_masks = anchor_masks
+        if anchors is None:
+            self.anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
+                            [59, 119], [116, 90], [156, 198], [373, 326]]
+        if anchor_masks is None:
+            self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
         self.ignore_threshold = ignore_threshold
         self.nms_score_threshold = nms_score_threshold
         self.nms_topk = nms_topk
@@ -84,6 +108,20 @@ class YOLOv3(BaseAPI):
         self.sync_bn = True
         self.train_random_shapes = train_random_shapes
         self.fixed_input_shape = None
+        self.use_fine_grained_loss = False
+        if (use_coord_conv or use_iou_aware or use_spp or use_drop_block or
+                use_iou_loss):
+            self.use_fine_grained_loss = True
+        self.use_coord_conv = use_coord_conv
+        self.use_iou_aware = use_iou_aware
+        self.use_spp = use_spp
+        self.use_drop_block = use_drop_block
+        self.use_iou_loss = use_iou_loss
+        self.scale_x_y = scale_x_y
+        self.max_height = 608
+        self.max_width = 608
+        self.use_matrix_nms = use_matrix_nms
+        self.use_ema = False
+        self.with_dcn_v2 = with_dcn_v2

     def _get_backbone(self, backbone_name):
         if backbone_name == 'DarkNet53':
@@ -102,6 +140,16 @@ class YOLOv3(BaseAPI):
             model_name = backbone_name.split('_')[1]
             backbone = paddlex.cv.nets.MobileNetV3(
                 norm_type='sync_bn', model_name=model_name)
+        elif backbone_name == 'ResNet50_vd':
+            backbone = paddlex.cv.nets.ResNet(
+                norm_type='sync_bn',
+                layers=50,
+                freeze_norm=False,
+                norm_decay=0.,
+                feature_maps=[3, 4, 5],
+                freeze_at=0,
+                variant='d',
+                dcn_v2_stages=[5] if self.with_dcn_v2 else [])
         return backbone

     def build_net(self, mode='train'):
@@ -117,14 +165,31 @@ class YOLOv3(BaseAPI):
             nms_topk=self.nms_topk,
             nms_keep_topk=self.nms_keep_topk,
             nms_iou_threshold=self.nms_iou_threshold,
             train_random_shapes=self.train_random_shapes,
-            fixed_input_shape=self.fixed_input_shape)
+            fixed_input_shape=self.fixed_input_shape,
+            coord_conv=self.use_coord_conv,
+            iou_aware=self.use_iou_aware,
+            scale_x_y=self.scale_x_y,
+            spp=self.use_spp,
+            drop_block=self.use_drop_block,
+            use_matrix_nms=self.use_matrix_nms,
+            use_fine_grained_loss=self.use_fine_grained_loss,
+            use_iou_loss=self.use_iou_loss,
+            batch_size=self.batch_size_per_gpu
+            if hasattr(self, 'batch_size_per_gpu') else 8)
+        if mode == 'train' and self.use_iou_loss or self.use_iou_aware:
+            model.max_height = self.max_height
+            model.max_width = self.max_width
         inputs = model.generate_inputs()
         model_out = model.build_net(inputs)
-        outputs = OrderedDict([('bbox', model_out)])
+        outputs = OrderedDict([('bbox', model_out[0])])
         if mode == 'train':
             self.optimizer.minimize(model_out)
             outputs = OrderedDict([('loss', model_out)])
+            if self.use_ema:
+                global_steps = _decay_step_counter()
+                self.ema = ExponentialMovingAverage(
+                    self.ema_decay, thres_steps=global_steps)
+                self.ema.update()
         return inputs, outputs

     def default_optimizer(self, learning_rate, warmup_steps, warmup_start_lr,
@@ -172,6 +237,8 @@ class YOLOv3(BaseAPI):
               warmup_start_lr=0.0,
               lr_decay_epochs=[213, 240],
               lr_decay_gamma=0.1,
+              use_ema=False,
+              ema_decay=0.9998,
               metric=None,
               use_vdl=False,
               sensitivities_file=None,
@@ -242,6 +309,46 @@ class YOLOv3(BaseAPI):
                 lr_decay_gamma=lr_decay_gamma,
                 num_steps_each_epoch=num_steps_each_epoch)
         self.optimizer = optimizer
+        self.use_ema = use_ema
+        self.ema_decay = ema_decay
+        self.batch_size_per_gpu = int(train_batch_size /
+                                      paddlex.env_info['num'])
+        if self.use_fine_grained_loss:
+            for transform in train_dataset.transforms.transforms:
+                if isinstance(transform, paddlex.det.transforms.Resize):
+                    self.max_height = transform.target_size
+                    self.max_width = transform.target_size
+                    break
+        if train_dataset.transforms.batch_transforms is None:
+            train_dataset.transforms.batch_transforms = list()
+        define_random_shape = False
+        for bt in train_dataset.transforms.batch_transforms:
+            if isinstance(bt, paddlex.det.transforms.BatchRandomShape):
+                define_random_shape = True
+        if not define_random_shape:
+            if isinstance(self.train_random_shapes,
+                          (list, tuple)) and len(self.train_random_shapes) > 0:
+                train_dataset.transforms.batch_transforms.append(
+                    paddlex.det.transforms.BatchRandomShape(
+                        random_shapes=self.train_random_shapes))
+                if self.use_fine_grained_loss:
+                    self.max_height = max(self.max_height,
+                                          max(self.train_random_shapes))
+                    self.max_width = max(self.max_width,
+                                         max(self.train_random_shapes))
+        if self.use_fine_grained_loss:
+            define_generate_target = False
+            for bt in train_dataset.transforms.batch_transforms:
+                if isinstance(bt, paddlex.det.transforms.GenerateYoloTarget):
+                    define_generate_target = True
+            if not define_generate_target:
+                train_dataset.transforms.batch_transforms.append(
+                    paddlex.det.transforms.GenerateYoloTarget(
+                        anchors=self.anchors,
+                        anchor_masks=self.anchor_masks,
+                        num_classes=self.num_classes,
+                        downsample_ratios=[32, 16, 8]))
         # build the training, evaluation and prediction networks
         self.build_program()
         # initialize network weights
```
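Among the new PP-YOLO switches, `scale_x_y` changes how the (x, y) logits are decoded: `sigmoid(t)` becomes `scale_x_y * sigmoid(t) - 0.5 * (scale_x_y - 1)`, so the predicted center can actually reach (and slightly cross) the edges of its grid cell, which a plain sigmoid only approaches asymptotically. A small NumPy check of that decoding, matching the formula used in the new iou_loss.py and yolo_loss.py files below:

```python
import numpy as np

def decode_center(t, gi, grid_w, scale_x_y=1.0):
    # same decoding as _bbox_transform: stretch the sigmoid around 0.5,
    # add the owning cell index, normalize by the grid size
    s = 1.0 / (1.0 + np.exp(-t))
    s = scale_x_y * s - 0.5 * (scale_x_y - 1.0)
    return (s + gi) / grid_w

t = np.array([-6.0, 0.0, 6.0])  # extreme and neutral logits
print(decode_center(t, gi=4, grid_w=19))                  # stuck strictly inside the cell
print(decode_center(t, gi=4, grid_w=19, scale_x_y=1.05))  # can hit the cell borders
```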
paddlex/cv/nets/detection/iou_aware.py — new file (0 → 100644), view file @ ed8de1ae

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import fluid


def _split_ioup(output, an_num, num_classes):
    """
    Split the new output feature map into output and predicted iou
    along the channel dimension
    """
    ioup = fluid.layers.slice(output, axes=[1], starts=[0], ends=[an_num])
    ioup = fluid.layers.sigmoid(ioup)
    oriout = fluid.layers.slice(
        output, axes=[1], starts=[an_num], ends=[an_num * (num_classes + 6)])
    return (ioup, oriout)


def _de_sigmoid(x, eps=1e-7):
    x = fluid.layers.clip(x, eps, 1 / eps)
    one = fluid.layers.fill_constant(
        shape=[1, 1, 1, 1], dtype=x.dtype, value=1.)
    x = fluid.layers.clip((one / x - 1.0), eps, 1 / eps)
    x = -fluid.layers.log(x)
    return x


def _postprocess_output(ioup, output, an_num, num_classes, iou_aware_factor):
    """
    post process output objectness score
    """
    tensors = []
    stride = output.shape[1] // an_num
    for m in range(an_num):
        tensors.append(
            fluid.layers.slice(
                output,
                axes=[1],
                starts=[stride * m + 0],
                ends=[stride * m + 4]))
        obj = fluid.layers.slice(
            output, axes=[1], starts=[stride * m + 4], ends=[stride * m + 5])
        obj = fluid.layers.sigmoid(obj)
        ip = fluid.layers.slice(ioup, axes=[1], starts=[m], ends=[m + 1])
        new_obj = fluid.layers.pow(obj, (1 - iou_aware_factor)) * \
            fluid.layers.pow(ip, iou_aware_factor)
        new_obj = _de_sigmoid(new_obj)
        tensors.append(new_obj)
        tensors.append(
            fluid.layers.slice(
                output,
                axes=[1],
                starts=[stride * m + 5],
                ends=[stride * m + 5 + num_classes]))
    output = fluid.layers.concat(tensors, axis=1)
    return output


def get_iou_aware_score(output, an_num, num_classes, iou_aware_factor):
    ioup, output = _split_ioup(output, an_num, num_classes)
    output = _postprocess_output(ioup, output, an_num, num_classes,
                                 iou_aware_factor)
    return output
```
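A NumPy sketch of what `get_iou_aware_score` computes per anchor: the objectness probability and the predicted IoU are fused as `obj^(1-α) * iou^α`, and `_de_sigmoid` maps the fused probability back to logit space so downstream code, which applies a sigmoid later, is unchanged:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def de_sigmoid(p, eps=1e-7):
    # inverse of the sigmoid, as in _de_sigmoid: -log(1/p - 1), with clipping
    p = np.clip(p, eps, 1.0 / eps)
    return -np.log(np.clip(1.0 / p - 1.0, eps, 1.0 / eps))

obj_logit, iou_pred, alpha = 2.0, 0.9, 0.4   # example values
obj = sigmoid(obj_logit)
fused = obj ** (1.0 - alpha) * iou_pred ** alpha
print(de_sigmoid(fused))         # fused score, back in logit space (~2.07)
print(de_sigmoid(sigmoid(2.0)))  # ~2.0: de_sigmoid inverts sigmoid
```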
paddlex/cv/nets/detection/loss/iou_aware_loss.py — new file (0 → 100644), view file @ ed8de1ae

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import NumpyArrayInitializer
from paddle import fluid

from .iou_loss import IouLoss


class IouAwareLoss(IouLoss):
    """
    iou aware loss, see https://arxiv.org/abs/1912.05992
    Args:
        loss_weight (float): iou aware loss weight, default is 1.0
        max_height (int): max height of input to support random shape input
        max_width (int): max width of input to support random shape input
    """

    def __init__(self, loss_weight=1.0, max_height=608, max_width=608):
        super(IouAwareLoss, self).__init__(
            loss_weight=loss_weight,
            max_height=max_height,
            max_width=max_width)

    def __call__(self,
                 ioup,
                 x, y, w, h,
                 tx, ty, tw, th,
                 anchors,
                 downsample_ratio,
                 batch_size,
                 scale_x_y,
                 eps=1.e-10):
        '''
        Args:
            ioup ([Variables]): the predicted iou
            x | y | w | h ([Variables]): the output of yolov3 for encoded x|y|w|h
            tx |ty |tw |th ([Variables]): the target of yolov3 for encoded x|y|w|h
            anchors ([float]): list of anchors for current output layer
            downsample_ratio (float): the downsample ratio for current output layer
            batch_size (int): training batch size
            eps (float): the decimal to prevent the denominator equal zero
        '''
        pred = self._bbox_transform(x, y, w, h, anchors, downsample_ratio,
                                    batch_size, False, scale_x_y, eps)
        gt = self._bbox_transform(tx, ty, tw, th, anchors, downsample_ratio,
                                  batch_size, True, scale_x_y, eps)
        iouk = self._iou(pred, gt, ioup, eps)
        iouk.stop_gradient = True

        loss_iou_aware = fluid.layers.cross_entropy(
            ioup, iouk, soft_label=True)
        loss_iou_aware = loss_iou_aware * self._loss_weight
        return loss_iou_aware
```
paddlex/cv/nets/detection/loss/iou_loss.py — new file (0 → 100644), view file @ ed8de1ae

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import NumpyArrayInitializer
from paddle import fluid


class IouLoss(object):
    """
    iou loss, see https://arxiv.org/abs/1908.03851
    loss = 1.0 - iou * iou
    Args:
        loss_weight (float): iou loss weight, default is 2.5
        max_height (int): max height of input to support random shape input
        max_width (int): max width of input to support random shape input
        ciou_term (bool): whether to add ciou_term
        loss_square (bool): whether to square the iou term
    """

    def __init__(self,
                 loss_weight=2.5,
                 max_height=608,
                 max_width=608,
                 ciou_term=False,
                 loss_square=True):
        self._loss_weight = loss_weight
        self._MAX_HI = max_height
        self._MAX_WI = max_width
        self.ciou_term = ciou_term
        self.loss_square = loss_square

    def __call__(self,
                 x, y, w, h,
                 tx, ty, tw, th,
                 anchors,
                 downsample_ratio,
                 batch_size,
                 scale_x_y=1.,
                 ioup=None,
                 eps=1.e-10):
        '''
        Args:
            x | y | w | h ([Variables]): the output of yolov3 for encoded x|y|w|h
            tx |ty |tw |th ([Variables]): the target of yolov3 for encoded x|y|w|h
            anchors ([float]): list of anchors for current output layer
            downsample_ratio (float): the downsample ratio for current output layer
            batch_size (int): training batch size
            eps (float): the decimal to prevent the denominator equal zero
        '''
        pred = self._bbox_transform(x, y, w, h, anchors, downsample_ratio,
                                    batch_size, False, scale_x_y, eps)
        gt = self._bbox_transform(tx, ty, tw, th, anchors, downsample_ratio,
                                  batch_size, True, scale_x_y, eps)
        iouk = self._iou(pred, gt, ioup, eps)
        if self.loss_square:
            loss_iou = 1. - iouk * iouk
        else:
            loss_iou = 1. - iouk
        loss_iou = loss_iou * self._loss_weight
        return loss_iou

    def _iou(self, pred, gt, ioup=None, eps=1.e-10):
        x1, y1, x2, y2 = pred
        x1g, y1g, x2g, y2g = gt
        x2 = fluid.layers.elementwise_max(x1, x2)
        y2 = fluid.layers.elementwise_max(y1, y2)

        xkis1 = fluid.layers.elementwise_max(x1, x1g)
        ykis1 = fluid.layers.elementwise_max(y1, y1g)
        xkis2 = fluid.layers.elementwise_min(x2, x2g)
        ykis2 = fluid.layers.elementwise_min(y2, y2g)

        intsctk = (xkis2 - xkis1) * (ykis2 - ykis1)
        intsctk = intsctk * fluid.layers.greater_than(
            xkis2, xkis1) * fluid.layers.greater_than(ykis2, ykis1)
        unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g
                                                        ) - intsctk + eps
        iouk = intsctk / unionk
        if self.ciou_term:
            ciou = self.get_ciou_term(pred, gt, iouk, eps)
            iouk = iouk - ciou
        return iouk

    def get_ciou_term(self, pred, gt, iouk, eps):
        x1, y1, x2, y2 = pred
        x1g, y1g, x2g, y2g = gt

        cx = (x1 + x2) / 2
        cy = (y1 + y2) / 2
        w = (x2 - x1) + fluid.layers.cast((x2 - x1) == 0, 'float32')
        h = (y2 - y1) + fluid.layers.cast((y2 - y1) == 0, 'float32')

        cxg = (x1g + x2g) / 2
        cyg = (y1g + y2g) / 2
        wg = x2g - x1g
        hg = y2g - y1g

        # A or B
        xc1 = fluid.layers.elementwise_min(x1, x1g)
        yc1 = fluid.layers.elementwise_min(y1, y1g)
        xc2 = fluid.layers.elementwise_max(x2, x2g)
        yc2 = fluid.layers.elementwise_max(y2, y2g)

        # DIOU term
        dist_intersection = (cx - cxg) * (cx - cxg) + (cy - cyg) * (cy - cyg)
        dist_union = (xc2 - xc1) * (xc2 - xc1) + (yc2 - yc1) * (yc2 - yc1)
        diou_term = (dist_intersection + eps) / (dist_union + eps)
        # CIOU term
        ciou_term = 0
        ar_gt = wg / hg
        ar_pred = w / h
        arctan = fluid.layers.atan(ar_gt) - fluid.layers.atan(ar_pred)
        ar_loss = 4. / np.pi / np.pi * arctan * arctan
        alpha = ar_loss / (1 - iouk + ar_loss + eps)
        alpha.stop_gradient = True
        ciou_term = alpha * ar_loss
        return diou_term + ciou_term

    def _bbox_transform(self, dcx, dcy, dw, dh, anchors, downsample_ratio,
                        batch_size, is_gt, scale_x_y, eps):
        grid_x = int(self._MAX_WI / downsample_ratio)
        grid_y = int(self._MAX_HI / downsample_ratio)
        an_num = len(anchors) // 2

        shape_fmp = fluid.layers.shape(dcx)
        shape_fmp.stop_gradient = True
        # generate the grid_w x grid_h center of feature map
        idx_i = np.array([[i for i in range(grid_x)]])
        idx_j = np.array([[j for j in range(grid_y)]]).transpose()
        gi_np = np.repeat(idx_i, grid_y, axis=0)
        gi_np = np.reshape(gi_np, newshape=[1, 1, grid_y, grid_x])
        gi_np = np.tile(gi_np, reps=[batch_size, an_num, 1, 1])
        gj_np = np.repeat(idx_j, grid_x, axis=1)
        gj_np = np.reshape(gj_np, newshape=[1, 1, grid_y, grid_x])
        gj_np = np.tile(gj_np, reps=[batch_size, an_num, 1, 1])
        gi_max = self._create_tensor_from_numpy(gi_np.astype(np.float32))
        gi = fluid.layers.crop(x=gi_max, shape=dcx)
        gi.stop_gradient = True
        gj_max = self._create_tensor_from_numpy(gj_np.astype(np.float32))
        gj = fluid.layers.crop(x=gj_max, shape=dcx)
        gj.stop_gradient = True

        grid_x_act = fluid.layers.cast(shape_fmp[3], dtype="float32")
        grid_x_act.stop_gradient = True
        grid_y_act = fluid.layers.cast(shape_fmp[2], dtype="float32")
        grid_y_act.stop_gradient = True
        if is_gt:
            cx = fluid.layers.elementwise_add(dcx, gi) / grid_x_act
            cx.gradient = True
            cy = fluid.layers.elementwise_add(dcy, gj) / grid_y_act
            cy.gradient = True
        else:
            dcx_sig = fluid.layers.sigmoid(dcx)
            dcy_sig = fluid.layers.sigmoid(dcy)
            if (abs(scale_x_y - 1.0) > eps):
                dcx_sig = scale_x_y * dcx_sig - 0.5 * (scale_x_y - 1)
                dcy_sig = scale_x_y * dcy_sig - 0.5 * (scale_x_y - 1)
            cx = fluid.layers.elementwise_add(dcx_sig, gi) / grid_x_act
            cy = fluid.layers.elementwise_add(dcy_sig, gj) / grid_y_act

        anchor_w_ = [anchors[i] for i in range(0, len(anchors)) if i % 2 == 0]
        anchor_w_np = np.array(anchor_w_)
        anchor_w_np = np.reshape(anchor_w_np, newshape=[1, an_num, 1, 1])
        anchor_w_np = np.tile(
            anchor_w_np, reps=[batch_size, 1, grid_y, grid_x])
        anchor_w_max = self._create_tensor_from_numpy(
            anchor_w_np.astype(np.float32))
        anchor_w = fluid.layers.crop(x=anchor_w_max, shape=dcx)
        anchor_w.stop_gradient = True
        anchor_h_ = [anchors[i] for i in range(0, len(anchors)) if i % 2 == 1]
        anchor_h_np = np.array(anchor_h_)
        anchor_h_np = np.reshape(anchor_h_np, newshape=[1, an_num, 1, 1])
        anchor_h_np = np.tile(
            anchor_h_np, reps=[batch_size, 1, grid_y, grid_x])
        anchor_h_max = self._create_tensor_from_numpy(
            anchor_h_np.astype(np.float32))
        anchor_h = fluid.layers.crop(x=anchor_h_max, shape=dcx)
        anchor_h.stop_gradient = True
        # e^tw e^th
        exp_dw = fluid.layers.exp(dw)
        exp_dh = fluid.layers.exp(dh)
        pw = fluid.layers.elementwise_mul(exp_dw, anchor_w) / \
            (grid_x_act * downsample_ratio)
        ph = fluid.layers.elementwise_mul(exp_dh, anchor_h) / \
            (grid_y_act * downsample_ratio)
        if is_gt:
            exp_dw.stop_gradient = True
            exp_dh.stop_gradient = True
            pw.stop_gradient = True
            ph.stop_gradient = True

        x1 = cx - 0.5 * pw
        y1 = cy - 0.5 * ph
        x2 = cx + 0.5 * pw
        y2 = cy + 0.5 * ph
        if is_gt:
            x1.stop_gradient = True
            y1.stop_gradient = True
            x2.stop_gradient = True
            y2.stop_gradient = True

        return x1, y1, x2, y2

    def _create_tensor_from_numpy(self, numpy_array):
        paddle_array = fluid.layers.create_parameter(
            attr=ParamAttr(),
            shape=numpy_array.shape,
            dtype=numpy_array.dtype,
            default_initializer=NumpyArrayInitializer(numpy_array))
        paddle_array.stop_gradient = True
        return paddle_array
```
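A NumPy sketch of the loss this class computes, on two axis-aligned boxes in (x1, y1, x2, y2) form; with the defaults `loss_square=True` and `loss_weight=2.5` the value is `2.5 * (1 - iou^2)`:

```python
import numpy as np

def iou_xyxy(a, b, eps=1e-10):
    # intersection, zeroed when the boxes do not overlap
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = ((a[2] - a[0]) * (a[3] - a[1]) +
             (b[2] - b[0]) * (b[3] - b[1]) - inter + eps)
    return inter / union

pred = [0.2, 0.2, 0.6, 0.6]
gt = [0.3, 0.3, 0.7, 0.7]
iou = iou_xyxy(pred, gt)
print(iou)                      # 0.09 / (0.16 + 0.16 - 0.09) ≈ 0.391
print(2.5 * (1.0 - iou * iou))  # the default IouLoss value ≈ 2.12
```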
paddlex/cv/nets/detection/loss/yolo_loss.py — new file (0 → 100644), view file @ ed8de1ae

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import fluid

try:
    from collections.abc import Sequence
except Exception:
    from collections import Sequence


class YOLOv3Loss(object):
    """
    Combined loss for YOLOv3 network
    Args:
        batch_size (int): training batch size
        ignore_thresh (float): threshold to ignore confidence loss
        label_smooth (bool): whether to use label smoothing
        use_fine_grained_loss (bool): whether to use the fine grained YOLOv3
                                      loss instead of fluid.layers.yolov3_loss
    """

    def __init__(self,
                 batch_size=8,
                 ignore_thresh=0.7,
                 label_smooth=True,
                 use_fine_grained_loss=False,
                 iou_loss=None,
                 iou_aware_loss=None,
                 downsample=[32, 16, 8],
                 scale_x_y=1.,
                 match_score=False):
        self._batch_size = batch_size
        self._ignore_thresh = ignore_thresh
        self._label_smooth = label_smooth
        self._use_fine_grained_loss = use_fine_grained_loss
        self._iou_loss = iou_loss
        self._iou_aware_loss = iou_aware_loss
        self.downsample = downsample
        self.scale_x_y = scale_x_y
        self.match_score = match_score

    def __call__(self, outputs, gt_box, gt_label, gt_score, targets, anchors,
                 anchor_masks, mask_anchors, num_classes, prefix_name):
        if self._use_fine_grained_loss:
            return self._get_fine_grained_loss(
                outputs, targets, gt_box, self._batch_size, num_classes,
                mask_anchors, self._ignore_thresh)
        else:
            losses = []
            for i, output in enumerate(outputs):
                scale_x_y = self.scale_x_y if not isinstance(
                    self.scale_x_y, Sequence) else self.scale_x_y[i]
                anchor_mask = anchor_masks[i]
                loss = fluid.layers.yolov3_loss(
                    x=output,
                    gt_box=gt_box,
                    gt_label=gt_label,
                    gt_score=gt_score,
                    anchors=anchors,
                    anchor_mask=anchor_mask,
                    class_num=num_classes,
                    ignore_thresh=self._ignore_thresh,
                    downsample_ratio=self.downsample[i],
                    use_label_smooth=self._label_smooth,
                    scale_x_y=scale_x_y,
                    name=prefix_name + "yolo_loss" + str(i))
                losses.append(fluid.layers.reduce_mean(loss))
            return {'loss': sum(losses)}

    def _get_fine_grained_loss(self,
                               outputs,
                               targets,
                               gt_box,
                               batch_size,
                               num_classes,
                               mask_anchors,
                               ignore_thresh,
                               eps=1.e-10):
        """
        Calculate fine grained YOLOv3 loss

        Args:
            outputs ([Variables]): List of Variables, output of backbone stages
            targets ([Variables]): List of Variables, the targets for yolo
                                   loss calculation.
            gt_box (Variable): The ground-truth bounding boxes.
            batch_size (int): The training batch size
            num_classes (int): class num of dataset
            mask_anchors ([[float]]): list of anchors in each output layer
            ignore_thresh (float): if a prediction bbox overlaps any gt_box by
                                   more than ignore_thresh, its objectness
                                   loss will be ignored.

        Returns:
            Type: dict
                xy_loss (Variable): YOLOv3 (x, y) coordinates loss
                wh_loss (Variable): YOLOv3 (w, h) coordinates loss
                obj_loss (Variable): YOLOv3 objectness score loss
                cls_loss (Variable): YOLOv3 classification loss
        """
        assert len(outputs) == len(targets), \
            "YOLOv3 output layer number not equal target number"

        loss_xys, loss_whs, loss_objs, loss_clss = [], [], [], []
        if self._iou_loss is not None:
            loss_ious = []
        if self._iou_aware_loss is not None:
            loss_iou_awares = []
        for i, (output, target,
                anchors) in enumerate(zip(outputs, targets, mask_anchors)):
            downsample = self.downsample[i]
            an_num = len(anchors) // 2
            if self._iou_aware_loss is not None:
                ioup, output = self._split_ioup(output, an_num, num_classes)
            x, y, w, h, obj, cls = self._split_output(output, an_num,
                                                      num_classes)
            tx, ty, tw, th, tscale, tobj, tcls = self._split_target(target)

            tscale_tobj = tscale * tobj

            scale_x_y = self.scale_x_y if not isinstance(
                self.scale_x_y, Sequence) else self.scale_x_y[i]

            if (abs(scale_x_y - 1.0) < eps):
                loss_x = fluid.layers.sigmoid_cross_entropy_with_logits(
                    x, tx) * tscale_tobj
                loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
                loss_y = fluid.layers.sigmoid_cross_entropy_with_logits(
                    y, ty) * tscale_tobj
                loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
            else:
                dx = scale_x_y * fluid.layers.sigmoid(x) - 0.5 * (scale_x_y -
                                                                  1.0)
                dy = scale_x_y * fluid.layers.sigmoid(y) - 0.5 * (scale_x_y -
                                                                  1.0)
                loss_x = fluid.layers.abs(dx - tx) * tscale_tobj
                loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
                loss_y = fluid.layers.abs(dy - ty) * tscale_tobj
                loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])

            # NOTE: we refined loss function of (w, h) as L1Loss
            loss_w = fluid.layers.abs(w - tw) * tscale_tobj
            loss_w = fluid.layers.reduce_sum(loss_w, dim=[1, 2, 3])
            loss_h = fluid.layers.abs(h - th) * tscale_tobj
            loss_h = fluid.layers.reduce_sum(loss_h, dim=[1, 2, 3])
            if self._iou_loss is not None:
                loss_iou = self._iou_loss(x, y, w, h, tx, ty, tw, th, anchors,
                                          downsample, self._batch_size,
                                          scale_x_y)
                loss_iou = loss_iou * tscale_tobj
                loss_iou = fluid.layers.reduce_sum(loss_iou, dim=[1, 2, 3])
                loss_ious.append(fluid.layers.reduce_mean(loss_iou))

            if self._iou_aware_loss is not None:
                loss_iou_aware = self._iou_aware_loss(
                    ioup, x, y, w, h, tx, ty, tw, th, anchors, downsample,
                    self._batch_size, scale_x_y)
                loss_iou_aware = loss_iou_aware * tobj
                loss_iou_aware = fluid.layers.reduce_sum(
                    loss_iou_aware, dim=[1, 2, 3])
                loss_iou_awares.append(
                    fluid.layers.reduce_mean(loss_iou_aware))

            loss_obj_pos, loss_obj_neg = self._calc_obj_loss(
                output, obj, tobj, gt_box, self._batch_size, anchors,
                num_classes, downsample, self._ignore_thresh, scale_x_y)

            loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(cls,
                                                                      tcls)
            loss_cls = fluid.layers.elementwise_mul(loss_cls, tobj, axis=0)
            loss_cls = fluid.layers.reduce_sum(loss_cls, dim=[1, 2, 3, 4])

            loss_xys.append(fluid.layers.reduce_mean(loss_x + loss_y))
            loss_whs.append(fluid.layers.reduce_mean(loss_w + loss_h))
            loss_objs.append(
                fluid.layers.reduce_mean(loss_obj_pos + loss_obj_neg))
            loss_clss.append(fluid.layers.reduce_mean(loss_cls))

        losses_all = {
            "loss_xy": fluid.layers.sum(loss_xys),
            "loss_wh": fluid.layers.sum(loss_whs),
            "loss_obj": fluid.layers.sum(loss_objs),
            "loss_cls": fluid.layers.sum(loss_clss),
        }
        if self._iou_loss is not None:
            losses_all["loss_iou"] = fluid.layers.sum(loss_ious)
        if self._iou_aware_loss is not None:
            losses_all["loss_iou_aware"] = fluid.layers.sum(loss_iou_awares)
        return losses_all

    def _split_ioup(self, output, an_num, num_classes):
        """
        Split the output feature map into output and predicted iou
        along the channel dimension
        """
        ioup = fluid.layers.slice(output, axes=[1], starts=[0], ends=[an_num])
        ioup = fluid.layers.sigmoid(ioup)
        oriout = fluid.layers.slice(
            output,
            axes=[1],
            starts=[an_num],
            ends=[an_num * (num_classes + 6)])
        return (ioup, oriout)

    def _split_output(self, output, an_num, num_classes):
        """
        Split the output feature map into x, y, w, h, objectness and
        classification along the channel dimension
        """
        x = fluid.layers.strided_slice(
            output,
            axes=[1],
            starts=[0],
            ends=[output.shape[1]],
            strides=[5 + num_classes])
        y = fluid.layers.strided_slice(
            output,
            axes=[1],
            starts=[1],
            ends=[output.shape[1]],
            strides=[5 + num_classes])
        w = fluid.layers.strided_slice(
            output,
            axes=[1],
            starts=[2],
            ends=[output.shape[1]],
            strides=[5 + num_classes])
        h = fluid.layers.strided_slice(
            output,
            axes=[1],
            starts=[3],
            ends=[output.shape[1]],
            strides=[5 + num_classes])
        obj = fluid.layers.strided_slice(
            output,
            axes=[1],
            starts=[4],
            ends=[output.shape[1]],
            strides=[5 + num_classes])
        clss = []
        stride = output.shape[1] // an_num
        for m in range(an_num):
            clss.append(
                fluid.layers.slice(
                    output,
                    axes=[1],
                    starts=[stride * m + 5],
                    ends=[stride * m + 5 + num_classes]))
        cls = fluid.layers.transpose(
            fluid.layers.stack(
                clss, axis=1), perm=[0, 1, 3, 4, 2])
        return (x, y, w, h, obj, cls)

    def _split_target(self, target):
        """
        split target to x, y, w, h, objectness, classification
        along dimension 2
        target is in shape [N, an_num, 6 + class_num, H, W]
        """
        tx = target[:, :, 0, :, :]
        ty = target[:, :, 1, :, :]
        tw = target[:, :, 2, :, :]
        th = target[:, :, 3, :, :]
        tscale = target[:, :, 4, :, :]
        tobj = target[:, :, 5, :, :]
        tcls = fluid.layers.transpose(
            target[:, :, 6:, :, :], perm=[0, 1, 3, 4, 2])
        tcls.stop_gradient = True
        return (tx, ty, tw, th, tscale, tobj, tcls)

    def _calc_obj_loss(self, output, obj, tobj, gt_box, batch_size, anchors,
                       num_classes, downsample, ignore_thresh, scale_x_y):
        # If a prediction bbox overlaps any gt_bbox by more than
        # ignore_thresh, its objectness loss will be ignored. Process:

        # 1. get pred bbox, which is same with YOLOv3 infer mode; use yolo_box here
        # NOTE: img_size is set as 1.0 to get normalized pred bbox
        bbox, prob = fluid.layers.yolo_box(
            x=output,
            img_size=fluid.layers.ones(
                shape=[batch_size, 2], dtype="int32"),
            anchors=anchors,
            class_num=num_classes,
            conf_thresh=0.,
            downsample_ratio=downsample,
            clip_bbox=False,
            scale_x_y=scale_x_y)

        # 2. split pred bbox and gt bbox by sample, then calculate IoU
        #    between pred bbox and gt bbox in each sample
        if batch_size > 1:
            preds = fluid.layers.split(bbox, batch_size, dim=0)
            gts = fluid.layers.split(gt_box, batch_size, dim=0)
        else:
            preds = [bbox]
            gts = [gt_box]
            probs = [prob]
        ious = []
        for pred, gt in zip(preds, gts):

            def box_xywh2xyxy(box):
                x = box[:, 0]
                y = box[:, 1]
                w = box[:, 2]
                h = box[:, 3]
                return fluid.layers.stack(
                    [
                        x - w / 2.,
                        y - h / 2.,
                        x + w / 2.,
                        y + h / 2.,
                    ], axis=1)

            pred = fluid.layers.squeeze(pred, axes=[0])
            gt = box_xywh2xyxy(fluid.layers.squeeze(gt, axes=[0]))
            ious.append(fluid.layers.iou_similarity(pred, gt))

        iou = fluid.layers.stack(ious, axis=0)

        # 3. Get iou_mask from the IoU between gt bbox and prediction bbox,
        #    get obj_mask from tobj (which holds gt_score), then calculate
        #    the objectness loss
        max_iou = fluid.layers.reduce_max(iou, dim=-1)
        iou_mask = fluid.layers.cast(max_iou <= ignore_thresh, dtype="float32")
        if self.match_score:
            max_prob = fluid.layers.reduce_max(prob, dim=-1)
            iou_mask = iou_mask * fluid.layers.cast(
                max_prob <= 0.25, dtype="float32")
        output_shape = fluid.layers.shape(output)
        an_num = len(anchors) // 2
        iou_mask = fluid.layers.reshape(iou_mask, (-1, an_num, output_shape[2],
                                                   output_shape[3]))
        iou_mask.stop_gradient = True

        # NOTE: tobj holds gt_score, obj_mask holds the object existence mask
        obj_mask = fluid.layers.cast(tobj > 0., dtype="float32")
        obj_mask.stop_gradient = True

        # For positive objectness grids, the objectness loss is always calculated.
        # For negative objectness grids, it is calculated only where iou_mask == 1.0
        loss_obj = fluid.layers.sigmoid_cross_entropy_with_logits(obj,
                                                                  obj_mask)
        loss_obj_pos = fluid.layers.reduce_sum(loss_obj * tobj, dim=[1, 2, 3])
        loss_obj_neg = fluid.layers.reduce_sum(
            loss_obj * (1.0 - obj_mask) * iou_mask, dim=[1, 2, 3])

        return loss_obj_pos, loss_obj_neg
```
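`_split_output` relies on the YOLOv3 head layout: the channel axis packs `an_num` anchor blocks of `5 + num_classes` channels each (`x, y, w, h, obj, cls...`), so a strided slice with stride `5 + num_classes` picks one field across all anchors; with IoU-aware enabled, `an_num` extra IoU channels are prepended, which gives the `an_num * (num_classes + 6)` bound seen in `_split_ioup`. A NumPy sketch of that indexing:

```python
import numpy as np

an_num, num_classes = 3, 2
stride = 5 + num_classes
# fake head output: batch 1, an_num * stride channels, 4 x 4 feature map
out = np.arange(1 * an_num * stride * 4 * 4).reshape(1, an_num * stride, 4, 4)

x = out[:, 0::stride]    # channel 0 of every anchor block -> (1, an_num, 4, 4)
y = out[:, 1::stride]    # channel 1 of every anchor block
obj = out[:, 4::stride]  # objectness channel of every anchor block
cls = np.stack([out[:, m * stride + 5:m * stride + 5 + num_classes]
                for m in range(an_num)], axis=1)  # per-anchor class scores
print(x.shape, obj.shape, cls.shape)  # (1, 3, 4, 4) (1, 3, 4, 4) (1, 3, 2, 4, 4)
```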
paddlex/cv/nets/detection/ops.py — new file (0 → 100644), view file @ ed8de1ae

```python
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from numbers import Integral
import math
import six

import paddle
from paddle import fluid


def DropBlock(input, block_size, keep_prob, is_test):
    if is_test:
        return input

    def CalculateGamma(input, block_size, keep_prob):
        input_shape = fluid.layers.shape(input)
        feat_shape_tmp = fluid.layers.slice(input_shape, [0], [3], [4])
        feat_shape_tmp = fluid.layers.cast(feat_shape_tmp, dtype="float32")
        feat_shape_t = fluid.layers.reshape(feat_shape_tmp, [1, 1, 1, 1])
        feat_area = fluid.layers.pow(feat_shape_t, factor=2)

        block_shape_t = fluid.layers.fill_constant(
            shape=[1, 1, 1, 1], value=block_size, dtype='float32')
        block_area = fluid.layers.pow(block_shape_t, factor=2)

        useful_shape_t = feat_shape_t - block_shape_t + 1
        useful_area = fluid.layers.pow(useful_shape_t, factor=2)

        upper_t = feat_area * (1 - keep_prob)
        bottom_t = block_area * useful_area
        output = upper_t / bottom_t
        return output

    gamma = CalculateGamma(input, block_size=block_size, keep_prob=keep_prob)
    input_shape = fluid.layers.shape(input)
    p = fluid.layers.expand_as(gamma, input)

    input_shape_tmp = fluid.layers.cast(input_shape, dtype="int64")
    random_matrix = fluid.layers.uniform_random(
        input_shape_tmp, dtype='float32', min=0.0, max=1.0, seed=1000)
    one_zero_m = fluid.layers.less_than(random_matrix, p)
    one_zero_m.stop_gradient = True
    one_zero_m = fluid.layers.cast(one_zero_m, dtype="float32")

    mask_flag = fluid.layers.pool2d(
        one_zero_m,
        pool_size=block_size,
        pool_type='max',
        pool_stride=1,
        pool_padding=block_size // 2)
    mask = 1.0 - mask_flag

    elem_numel = fluid.layers.reduce_prod(input_shape)
    elem_numel_m = fluid.layers.cast(elem_numel, dtype="float32")
    elem_numel_m.stop_gradient = True

    elem_sum = fluid.layers.reduce_sum(mask)
    elem_sum_m = fluid.layers.cast(elem_sum, dtype="float32")
    elem_sum_m.stop_gradient = True

    output = input * mask * elem_numel_m / elem_sum_m
    return output


class MultiClassNMS(object):
    def __init__(self,
                 score_threshold=.05,
                 nms_top_k=-1,
                 keep_top_k=100,
                 nms_threshold=.5,
                 normalized=False,
                 nms_eta=1.0,
                 background_label=0):
        super(MultiClassNMS, self).__init__()
        self.score_threshold = score_threshold
        self.nms_top_k = nms_top_k
        self.keep_top_k = keep_top_k
        self.nms_threshold = nms_threshold
        self.normalized = normalized
        self.nms_eta = nms_eta
        self.background_label = background_label

    def __call__(self, bboxes, scores):
        return fluid.layers.multiclass_nms(
            bboxes=bboxes,
            scores=scores,
            score_threshold=self.score_threshold,
            nms_top_k=self.nms_top_k,
            keep_top_k=self.keep_top_k,
            normalized=self.normalized,
            nms_threshold=self.nms_threshold,
            nms_eta=self.nms_eta,
            background_label=self.background_label)


class MatrixNMS(object):
    def __init__(self,
                 score_threshold=.05,
                 post_threshold=.05,
                 nms_top_k=-1,
                 keep_top_k=100,
                 use_gaussian=False,
                 gaussian_sigma=2.,
                 normalized=False,
                 background_label=0):
        super(MatrixNMS, self).__init__()
        self.score_threshold = score_threshold
        self.post_threshold = post_threshold
        self.nms_top_k = nms_top_k
        self.keep_top_k = keep_top_k
        self.normalized = normalized
        self.use_gaussian = use_gaussian
        self.gaussian_sigma = gaussian_sigma
        self.background_label = background_label

    def __call__(self, bboxes, scores):
        return paddle.fluid.layers.matrix_nms(
            bboxes=bboxes,
            scores=scores,
            score_threshold=self.score_threshold,
            post_threshold=self.post_threshold,
            nms_top_k=self.nms_top_k,
            keep_top_k=self.keep_top_k,
            normalized=self.normalized,
            use_gaussian=self.use_gaussian,
            gaussian_sigma=self.gaussian_sigma,
            background_label=self.background_label)


class MultiClassSoftNMS(object):
    def __init__(self,
                 score_threshold=0.01,
                 keep_top_k=300,
                 softnms_sigma=0.5,
                 normalized=False,
                 background_label=0):
        super(MultiClassSoftNMS, self).__init__()
        self.score_threshold = score_threshold
        self.keep_top_k = keep_top_k
        self.softnms_sigma = softnms_sigma
        self.normalized = normalized
        self.background_label = background_label

    def __call__(self, bboxes, scores):
        def create_tmp_var(program, name, dtype, shape, lod_level):
            return program.current_block().create_var(
                name=name, dtype=dtype, shape=shape, lod_level=lod_level)

        def _soft_nms_for_cls(dets, sigma, thres):
            """soft_nms_for_cls"""
            dets_final = []
            while len(dets) > 0:
                maxpos = np.argmax(dets[:, 0])
                dets_final.append(dets[maxpos].copy())
                ts, tx1, ty1, tx2, ty2 = dets[maxpos]
                scores = dets[:, 0]
                # force remove bbox at maxpos
                scores[maxpos] = -1
                x1 = dets[:, 1]
                y1 = dets[:, 2]
                x2 = dets[:, 3]
                y2 = dets[:, 4]
                eta = 0 if self.normalized else 1
                areas = (x2 - x1 + eta) * (y2 - y1 + eta)
                xx1 = np.maximum(tx1, x1)
                yy1 = np.maximum(ty1, y1)
                xx2 = np.minimum(tx2, x2)
                yy2 = np.minimum(ty2, y2)
                w = np.maximum(0.0, xx2 - xx1 + eta)
                h = np.maximum(0.0, yy2 - yy1 + eta)
                inter = w * h
                ovr = inter / (areas + areas[maxpos] - inter)
                weight = np.exp(-(ovr * ovr) / sigma)
                scores = scores * weight
                idx_keep = np.where(scores >= thres)
                dets[:, 0] = scores
                dets = dets[idx_keep]
            dets_final = np.array(dets_final).reshape(-1, 5)
            return dets_final

        def _soft_nms(bboxes, scores):
            class_nums = scores.shape[-1]

            softnms_thres = self.score_threshold
            softnms_sigma = self.softnms_sigma
            keep_top_k = self.keep_top_k

            cls_boxes = [[] for _ in range(class_nums)]
            cls_ids = [[] for _ in range(class_nums)]

            start_idx = 1 if self.background_label == 0 else 0
            for j in range(start_idx, class_nums):
                inds = np.where(scores[:, j] >= softnms_thres)[0]
                scores_j = scores[inds, j]
                rois_j = bboxes[inds, j, :] if len(
                    bboxes.shape) > 2 else bboxes[inds, :]
                dets_j = np.hstack((scores_j[:, np.newaxis], rois_j)).astype(
                    np.float32, copy=False)
                cls_rank = np.argsort(-dets_j[:, 0])
                dets_j = dets_j[cls_rank]
                cls_boxes[j] = _soft_nms_for_cls(
                    dets_j, sigma=softnms_sigma, thres=softnms_thres)
                cls_ids[j] = np.array(
                    [j] * cls_boxes[j].shape[0]).reshape(-1, 1)

            cls_boxes = np.vstack(cls_boxes[start_idx:])
            cls_ids = np.vstack(cls_ids[start_idx:])
            pred_result = np.hstack([cls_ids, cls_boxes])

            # Limit to max_per_image detections **over all classes**
            image_scores = cls_boxes[:, 0]
            if len(image_scores) > keep_top_k:
                image_thresh = np.sort(image_scores)[-keep_top_k]
                keep = np.where(cls_boxes[:, 0] >= image_thresh)[0]
                pred_result = pred_result[keep, :]

            return pred_result

        def _batch_softnms(bboxes, scores):
            batch_offsets = bboxes.lod()
            bboxes = np.array(bboxes)
            scores = np.array(scores)
            out_offsets = [0]
            pred_res = []
            if len(batch_offsets) > 0:
                batch_offset = batch_offsets[0]
                for i in range(len(batch_offset) - 1):
                    s, e = batch_offset[i], batch_offset[i + 1]
                    pred = _soft_nms(bboxes[s:e], scores[s:e])
                    out_offsets.append(pred.shape[0] + out_offsets[-1])
                    pred_res.append(pred)
            else:
                assert len(bboxes.shape) == 3
                assert len(scores.shape) == 3
                for i in range(bboxes.shape[0]):
                    pred = _soft_nms(bboxes[i], scores[i])
                    out_offsets.append(pred.shape[0] + out_offsets[-1])
                    pred_res.append(pred)

            res = fluid.LoDTensor()
            res.set_lod([out_offsets])
            if len(pred_res) == 0:
                pred_res = np.array([[1]], dtype=np.float32)
            res.set(np.vstack(pred_res).astype(np.float32), fluid.CPUPlace())
            return res

        pred_result = create_tmp_var(
            fluid.default_main_program(),
            name='softnms_pred_result',
            dtype='float32',
            shape=[-1, 6],
            lod_level=1)
        fluid.layers.py_func(
            func=_batch_softnms, x=[bboxes, scores], out=pred_result)

        return pred_result
```
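`CalculateGamma` implements the seeding rate from the DropBlock paper: to drop roughly `1 - keep_prob` of activations with square blocks of side `block_size`, each position is seeded with probability `gamma = feat² · (1 − keep_prob) / (block² · (feat − block + 1)²)`; max-pooling the seed mask with a `block_size` kernel then grows each seed into a full block. A NumPy check of the formula:

```python
feat, block, keep_prob = 19.0, 3.0, 0.9

gamma = (feat ** 2 * (1.0 - keep_prob)) / (block ** 2 * (feat - block + 1.0) ** 2)
print(gamma)  # ≈ 0.0139: per-position seed probability

# sanity check: seeds * block area / feature area ≈ fraction dropped
expected_dropped = gamma * (feat - block + 1.0) ** 2 * block ** 2 / feat ** 2
print(expected_dropped)  # 0.1 == 1 - keep_prob
```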
paddlex/cv/nets/detection/yolo_v3.py (view file @ ed8de1ae)

This diff is collapsed. Click to expand.
paddlex/cv/transforms/det_transforms.py (view file @ ed8de1ae)

```diff
@@ -55,6 +55,7 @@ class Compose(DetTransform):
             raise ValueError('The length of transforms ' + \
                              'must be equal or larger than 1!')
         self.transforms = transforms
+        self.batch_transforms = None
         self.use_mixup = False
         for t in self.transforms:
             if type(t).__name__ == 'MixupImage':
@@ -1385,3 +1386,187 @@ class ComposedYOLOv3Transforms(Compose):
             transforms.Normalize(mean=mean, std=std)
         ]
         super(ComposedYOLOv3Transforms, self).__init__(transforms)
```

The hunk then adds two new batch-level transforms:

```python
class BatchRandomShape(DetTransform):
    """Resize images.
    Resize every image in the batch to one size randomly picked from random_shapes.
    Note: when the interpolation mode is "RANDOM", an interpolation mode is randomly chosen for the resize.
    Args:
        random_shapes (list): candidate sizes for the resize.
            Default: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608].
        interp (str): interpolation mode for the resize, matching OpenCV's modes; one of
            ['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM']. Default: "RANDOM".
    Raises:
        ValueError: if the interpolation mode is not in ['NEAREST', 'LINEAR', 'CUBIC',
            'AREA', 'LANCZOS4', 'RANDOM'].
    """

    # The interpolation mode
    interp_dict = {
        'NEAREST': cv2.INTER_NEAREST,
        'LINEAR': cv2.INTER_LINEAR,
        'CUBIC': cv2.INTER_CUBIC,
        'AREA': cv2.INTER_AREA,
        'LANCZOS4': cv2.INTER_LANCZOS4
    }

    def __init__(self,
                 random_shapes=[
                     320, 352, 384, 416, 448, 480, 512, 544, 576, 608
                 ],
                 interp='RANDOM'):
        if not (interp == "RANDOM" or interp in self.interp_dict):
            raise ValueError("interp should be one of {}".format(
                self.interp_dict.keys()))
        self.random_shapes = random_shapes
        self.interp = interp

    def __call__(self, batch_data):
        """
        Args:
            batch_data (list): a batch of image-related data.
        Returns:
            list: a batch of image-related data.
        """
        shape = np.random.choice(self.random_shapes)
        if self.interp == "RANDOM":
            interp = random.choice(list(self.interp_dict.keys()))
        else:
            interp = self.interp
        for data_id, data in enumerate(batch_data):
            data_list = list(data)
            im = data_list[0]
            im = np.swapaxes(im, 1, 0)
            im = np.swapaxes(im, 1, 2)
            im = resize(im, shape, self.interp_dict[interp])
            im = np.swapaxes(im, 1, 2)
            im = np.swapaxes(im, 1, 0)
            data_list[0] = im
            batch_data[data_id] = tuple(data_list)
        return batch_data


class GenerateYoloTarget(object):
    """Generate the per-feature-level location encoding of YOLOv3 ground-truth boxes.
    This transform is only used when YOLOv3 computes the fine grained loss.
    Args:
        anchors (list|tuple): widths and heights of the anchor boxes.
        anchor_masks (list|tuple): anchor mask indices used when computing the loss.
        num_classes (int): number of classes. Default: 80.
        iou_thresh (float): IoU threshold; an anchor whose IoU with a ground-truth box
            exceeds this threshold is added to the target. Default: 1.0.
    """

    def __init__(self,
                 anchors,
                 anchor_masks,
                 downsample_ratios,
                 num_classes=80,
                 iou_thresh=1.):
        super(GenerateYoloTarget, self).__init__()
        self.anchors = anchors
        self.anchor_masks = anchor_masks
        self.downsample_ratios = downsample_ratios
        self.num_classes = num_classes
        self.iou_thresh = iou_thresh

    def __call__(self, batch_data):
        """
        Args:
            batch_data (list): a batch of image-related data.
        Returns:
            list: a batch of image-related data.
            Each sample gains the new fields:
                - target0 (np.ndarray): the location encoding of the ground truth at feature level 0,
                  with shape (number of anchors at level 0, 6 + num_classes, h of level 0, w of level 0).
                - target1 (np.ndarray): the location encoding of the ground truth at feature level 1,
                  with shape (number of anchors at level 1, 6 + num_classes, h of level 1, w of level 1).
                - ...
                - targetn (np.ndarray): the location encoding of the ground truth at feature level n,
                  with shape (number of anchors at level n, 6 + num_classes, h of level n, w of level n).
            n is determined by the length of anchor_masks.
        """
        im = batch_data[0][0]
        h = im.shape[1]
        w = im.shape[2]
        an_hw = np.array(self.anchors) / np.array([[w, h]])
        for data_id, data in enumerate(batch_data):
            gt_bbox = data[1]
            gt_class = data[2]
            gt_score = data[3]
            im_shape = data[4]
            origin_h = float(im_shape[0])
            origin_w = float(im_shape[1])
            data_list = list(data)
            for i, (
                    mask, downsample_ratio
            ) in enumerate(zip(self.anchor_masks, self.downsample_ratios)):
                grid_h = int(h / downsample_ratio)
                grid_w = int(w / downsample_ratio)
                target = np.zeros(
                    (len(mask), 6 + self.num_classes, grid_h, grid_w),
                    dtype=np.float32)
                for b in range(gt_bbox.shape[0]):
                    gx = gt_bbox[b, 0] / float(origin_w)
                    gy = gt_bbox[b, 1] / float(origin_h)
                    gw = gt_bbox[b, 2] / float(origin_w)
                    gh = gt_bbox[b, 3] / float(origin_h)
                    cls = gt_class[b]
                    score = gt_score[b]
                    if gw <= 0. or gh <= 0. or score <= 0.:
                        continue
                    # find best match anchor index
                    best_iou = 0.
                    best_idx = -1
                    for an_idx in range(an_hw.shape[0]):
                        iou = jaccard_overlap(
                            [0., 0., gw, gh],
                            [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
                        if iou > best_iou:
                            best_iou = iou
                            best_idx = an_idx
                    gi = int(gx * grid_w)
                    gj = int(gy * grid_h)
                    # gtbox should be regressed in this layer if the best match
                    # anchor index is in the anchor mask of this layer
                    if best_idx in mask:
                        best_n = mask.index(best_idx)
                        # x, y, w, h, scale
                        target[best_n, 0, gj, gi] = gx * grid_w - gi
                        target[best_n, 1, gj, gi] = gy * grid_h - gj
                        target[best_n, 2, gj, gi] = np.log(
                            gw * w / self.anchors[best_idx][0])
                        target[best_n, 3, gj, gi] = np.log(
                            gh * h / self.anchors[best_idx][1])
                        target[best_n, 4, gj, gi] = 2.0 - gw * gh
                        # objectness records gt_score
                        target[best_n, 5, gj, gi] = score
                        # classification
                        target[best_n, 6 + cls, gj, gi] = 1.
                    # For non-matched anchors, calculate the target if the iou
                    # between anchor and gt is larger than iou_thresh
                    if self.iou_thresh < 1:
                        for idx, mask_i in enumerate(mask):
                            if mask_i == best_idx:
                                continue
                            iou = jaccard_overlap(
                                [0., 0., gw, gh],
                                [0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]])
                            if iou > self.iou_thresh:
                                # x, y, w, h, scale
                                target[idx, 0, gj, gi] = gx * grid_w - gi
                                target[idx, 1, gj, gi] = gy * grid_h - gj
                                target[idx, 2, gj, gi] = np.log(
                                    gw * w / self.anchors[mask_i][0])
                                target[idx, 3, gj, gi] = np.log(
                                    gh * h / self.anchors[mask_i][1])
                                target[idx, 4, gj, gi] = 2.0 - gw * gh
                                # objectness records gt_score
                                target[idx, 5, gj, gi] = score
                                # classification
                                target[idx, 6 + cls, gj, gi] = 1.
                data_list.append(target)
            batch_data[data_id] = tuple(data_list)
        return batch_data
```
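A worked NumPy example of the encoding `GenerateYoloTarget` writes for one ground-truth box: the cell offsets `gx·grid_w − gi` and `gy·grid_h − gj`, the log-ratios against the matched anchor, and the `2 − gw·gh` scale that up-weights small boxes:

```python
import numpy as np

w = h = 608                          # network input size
downsample = 32
grid_w = grid_h = w // downsample    # 19 x 19 cells at this level
anchor_w, anchor_h = 116, 90         # matched anchor, in input pixels

# one gt box, already normalized to [0, 1] by the original image size
gx, gy, gw, gh = 0.50, 0.42, 0.30, 0.25
gi, gj = int(gx * grid_w), int(gy * grid_h)   # owning cell: (9, 7)

tx = gx * grid_w - gi            # 0.5  (offset inside the cell)
ty = gy * grid_h - gj            # 0.98
tw = np.log(gw * w / anchor_w)   # log(182.4 / 116) ≈ 0.452
th = np.log(gh * h / anchor_h)   # log(152.0 / 90)  ≈ 0.524
scale = 2.0 - gw * gh            # 1.925: smaller boxes get a bigger weight
print(gi, gj, tx, ty, tw, th, scale)
```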
tutorials/train/object_detection/ppyolo.py — new file (0 → 100644), view file @ ed8de1ae

```python
# Environment variable configuration, used to control whether to use the GPU
# Docs: https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from paddlex.det import transforms
import paddlex as pdx

# Download and extract the insect detection dataset
insect_dataset = 'https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz'
pdx.utils.download_and_decompress(insect_dataset, path='./')

# Define the transforms for training and evaluation
# API docs: https://paddlex.readthedocs.io/zh_CN/develop/apis/transforms/det_transforms.html
train_transforms = transforms.Compose([
    transforms.MixupImage(mixup_epoch=250), transforms.RandomDistort(),
    transforms.RandomExpand(), transforms.RandomCrop(), transforms.Resize(
        target_size=608, interp='RANDOM'), transforms.RandomHorizontalFlip(),
    transforms.Normalize()
])

eval_transforms = transforms.Compose([
    transforms.Resize(
        target_size=608, interp='CUBIC'), transforms.Normalize()
])

# Define the datasets used for training and evaluation
# API docs: https://paddlex.readthedocs.io/zh_CN/develop/apis/datasets.html#paddlex-datasets-vocdetection
train_dataset = pdx.datasets.VOCDetection(
    data_dir='insect_det',
    file_list='insect_det/train_list.txt',
    label_list='insect_det/labels.txt',
    transforms=train_transforms,
    shuffle=True)
eval_dataset = pdx.datasets.VOCDetection(
    data_dir='insect_det',
    file_list='insect_det/val_list.txt',
    label_list='insect_det/labels.txt',
    transforms=eval_transforms)

# Initialize the model and train it
# Training metrics can be inspected with VisualDL; see
# https://paddlex.readthedocs.io/zh_CN/develop/train/visualdl.html
num_classes = len(train_dataset.labels)

# API docs: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/detection.html#paddlex-det-yolov3
model = pdx.det.YOLOv3(
    num_classes=num_classes,
    backbone='ResNet50_vd',
    with_dcn_v2=True,
    use_coord_conv=True,
    use_iou_aware=True,
    use_spp=True,
    use_drop_block=True,
    scale_x_y=1.05,
    use_iou_loss=True,
    use_matrix_nms=True)

# API docs: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/detection.html#train
# Parameter descriptions and tuning notes:
# https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html
model.train(
    num_epochs=270,
    train_dataset=train_dataset,
    train_batch_size=8,
    eval_dataset=eval_dataset,
    learning_rate=0.000125,
    lr_decay_epochs=[210, 240],
    use_ema=True,
    save_dir='output/ppyolo',
    use_vdl=True)
```
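After training, the saved model can be loaded back for inference. A short sketch assuming the usual PaddleX 1.x prediction API (`pdx.load_model`, `model.predict`, `pdx.det.visualize`) and a hypothetical test image path:

```python
import paddlex as pdx

# load the best checkpoint written by model.train() above
model = pdx.load_model('output/ppyolo/best_model')

# predict on a hypothetical image from the insect dataset
result = model.predict('insect_det/JPEGImages/some_image.jpg')

# draw boxes whose score exceeds the threshold and save the visualization
pdx.det.visualize(
    'insect_det/JPEGImages/some_image.jpg',
    result,
    threshold=0.3,
    save_dir='./output/ppyolo')
```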