Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleX
提交
91a601e2
P
PaddleX
项目概览
PaddlePaddle
/
PaddleX
通知
138
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
43
列表
看板
标记
里程碑
合并请求
5
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleX
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
43
Issue
43
列表
看板
标记
里程碑
合并请求
5
合并请求
5
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
91a601e2
编写于
5月 28, 2020
作者:
S
sunyanfang01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add yolo with iou aware loss
上级
948032a7
变更
9
隐藏空白更改
内联
并排
Showing
9 changed files
with
1109 additions
and
179 deletions
+1109
-179
paddlex/cv/datasets/dataset.py
paddlex/cv/datasets/dataset.py
+13
-10
paddlex/cv/models/utils/pretrain_weights.py
paddlex/cv/models/utils/pretrain_weights.py
+18
-15
paddlex/cv/models/yolo_v3.py
paddlex/cv/models/yolo_v3.py
+65
-17
paddlex/cv/nets/detection/iou_aware.py
paddlex/cv/nets/detection/iou_aware.py
+80
-0
paddlex/cv/nets/detection/loss/iou_aware_loss.py
paddlex/cv/nets/detection/loss/iou_aware_loss.py
+73
-0
paddlex/cv/nets/detection/loss/iou_loss.py
paddlex/cv/nets/detection/loss/iou_loss.py
+230
-0
paddlex/cv/nets/detection/loss/yolo_loss.py
paddlex/cv/nets/detection/loss/yolo_loss.py
+339
-0
paddlex/cv/nets/detection/yolo_v3.py
paddlex/cv/nets/detection/yolo_v3.py
+107
-37
paddlex/cv/transforms/det_transforms.py
paddlex/cv/transforms/det_transforms.py
+184
-100
未找到文件。
paddlex/cv/datasets/dataset.py
浏览文件 @
91a601e2
...
...
@@ -114,7 +114,7 @@ def multithread_reader(mapper,
while
not
isinstance
(
sample
,
EndSignal
):
batch_data
.
append
(
sample
)
if
len
(
batch_data
)
==
batch_size
:
batch_data
=
GenerateMiniBatch
(
batch_data
)
batch_data
=
GenerateMiniBatch
(
batch_data
,
mapper
)
yield
batch_data
batch_data
=
[]
sample
=
out_queue
.
get
()
...
...
@@ -126,11 +126,11 @@ def multithread_reader(mapper,
else
:
batch_data
.
append
(
sample
)
if
len
(
batch_data
)
==
batch_size
:
batch_data
=
GenerateMiniBatch
(
batch_data
)
batch_data
=
GenerateMiniBatch
(
batch_data
,
mapper
)
yield
batch_data
batch_data
=
[]
if
not
drop_last
and
len
(
batch_data
)
!=
0
:
batch_data
=
GenerateMiniBatch
(
batch_data
)
batch_data
=
GenerateMiniBatch
(
batch_data
,
mapper
)
yield
batch_data
batch_data
=
[]
...
...
@@ -187,18 +187,21 @@ def multiprocess_reader(mapper,
else
:
batch_data
.
append
(
sample
)
if
len
(
batch_data
)
==
batch_size
:
batch_data
=
GenerateMiniBatch
(
batch_data
)
batch_data
=
GenerateMiniBatch
(
batch_data
,
mapper
)
yield
batch_data
batch_data
=
[]
if
len
(
batch_data
)
!=
0
and
not
drop_last
:
batch_data
=
GenerateMiniBatch
(
batch_data
)
batch_data
=
GenerateMiniBatch
(
batch_data
,
mapper
)
yield
batch_data
batch_data
=
[]
return
queue_reader
def
GenerateMiniBatch
(
batch_data
):
def
GenerateMiniBatch
(
batch_data
,
mapper
):
if
mapper
.
batch_transforms
is
not
None
:
for
op
in
mapper
.
batch_transforms
:
batch_data
=
op
(
batch_data
)
if
len
(
batch_data
)
==
1
:
return
batch_data
width
=
[
data
[
0
].
shape
[
2
]
for
data
in
batch_data
]
...
...
@@ -209,8 +212,8 @@ def GenerateMiniBatch(batch_data):
padding_batch
=
[]
for
data
in
batch_data
:
im_c
,
im_h
,
im_w
=
data
[
0
].
shape
[:]
padding_im
=
np
.
zeros
(
(
im_c
,
max_shape
[
1
],
max_shape
[
2
]),
dtype
=
np
.
float32
)
padding_im
=
np
.
zeros
(
(
im_c
,
max_shape
[
1
],
max_shape
[
2
]),
dtype
=
np
.
float32
)
padding_im
[:,
:
im_h
,
:
im_w
]
=
data
[
0
]
padding_batch
.
append
((
padding_im
,
)
+
data
[
1
:])
return
padding_batch
...
...
@@ -226,8 +229,8 @@ class Dataset:
if
num_workers
==
'auto'
:
import
multiprocessing
as
mp
num_workers
=
mp
.
cpu_count
()
//
2
if
mp
.
cpu_count
()
//
2
<
8
else
8
if
platform
.
platform
().
startswith
(
"Darwin"
)
or
platform
.
platform
(
).
startswith
(
"Windows"
):
if
platform
.
platform
().
startswith
(
"Darwin"
)
or
platform
.
platform
(
).
startswith
(
"Windows"
):
parallel_method
=
'thread'
if
transforms
is
None
:
raise
Exception
(
"transform should be defined."
)
...
...
paddlex/cv/models/utils/pretrain_weights.py
浏览文件 @
91a601e2
...
...
@@ -56,20 +56,11 @@ image_pretrain = {
'https://paddle-imagenet-models-name.bj.bcebos.com/Xception65_deeplab_pretrained.tar'
,
'ShuffleNetV2'
:
'https://paddle-imagenet-models-name.bj.bcebos.com/ShuffleNetV2_pretrained.tar'
,
'HRNet_W18'
:
'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W18_C_pretrained.tar'
,
'HRNet_W30'
:
'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W30_C_pretrained.tar'
,
'HRNet_W32'
:
'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W32_C_pretrained.tar'
,
'HRNet_W40'
:
'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W40_C_pretrained.tar'
,
'HRNet_W48'
:
'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W48_C_pretrained.tar'
,
'HRNet_W60'
:
'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W60_C_pretrained.tar'
,
'HRNet_W64'
:
'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W64_C_pretrained.tar'
,
}
obj365_pretrain
=
{
'ResNet50_vd_dcn_db_obj365'
:
'https://paddlemodels.bj.bcebos.com/object_detection/ResNet50_vd_dcn_db_obj365_pretrained.tar'
,
}
coco_pretrain
=
{
...
...
@@ -117,6 +108,18 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
raise
Exception
(
"Unexpected error, please make sure paddlehub >= 1.6.2"
)
return
osp
.
join
(
new_save_dir
,
backbone
)
elif
flag
==
'Object365'
:
new_save_dir
=
save_dir
if
hasattr
(
paddlex
,
'pretrain_dir'
):
new_save_dir
=
paddlex
.
pretrain_dir
if
backbone
==
'ResNet50_vd'
:
backbone
=
'ResNet50_vd_dcn_db_obj365'
assert
backbone
in
obj365_pretrain
,
"There is not Object365 pretrain weights for {}, you may try ImageNet."
.
format
(
backbone
)
url
=
obj365_pretrain
[
backbone
]
fname
=
osp
.
split
(
url
)[
-
1
].
split
(
'.'
)[
0
]
paddlex
.
utils
.
download_and_decompress
(
url
,
path
=
new_save_dir
)
return
osp
.
join
(
new_save_dir
,
fname
)
elif
flag
==
'COCO'
:
new_save_dir
=
save_dir
if
hasattr
(
paddlex
,
'pretrain_dir'
):
...
...
@@ -144,5 +147,5 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
return
osp
.
join
(
new_save_dir
,
backbone
)
else
:
raise
Exception
(
"pretrain_weights need to be defined as directory path or `IMAGENET` or 'COCO' (download pretrain weights automatically)."
"pretrain_weights need to be defined as directory path or `IMAGENET` or
`Object365` or
'COCO' (download pretrain weights automatically)."
)
paddlex/cv/models/yolo_v3.py
浏览文件 @
91a601e2
...
...
@@ -58,13 +58,18 @@ class YOLOv3(BaseAPI):
nms_keep_topk
=
100
,
nms_iou_threshold
=
0.45
,
label_smooth
=
False
,
use_iou_loss
=
False
,
use_iou_aware_loss
=
False
,
iou_aware_factor
=
0.4
,
use_drop_block
=
False
,
use_dcn_v2
=
False
,
train_random_shapes
=
[
320
,
352
,
384
,
416
,
448
,
480
,
512
,
544
,
576
,
608
]):
self
.
init_params
=
locals
()
super
(
YOLOv3
,
self
).
__init__
(
'detector'
)
backbones
=
[
'DarkNet53'
,
'ResNet34'
,
'MobileNetV1'
,
'MobileNetV3_large'
'DarkNet53'
,
'ResNet34'
,
'MobileNetV1'
,
'MobileNetV3_large'
,
'ResNet50_vd'
]
assert
backbone
in
backbones
,
"backbone should be one of {}"
.
format
(
backbones
)
...
...
@@ -79,8 +84,18 @@ class YOLOv3(BaseAPI):
self
.
nms_iou_threshold
=
nms_iou_threshold
self
.
label_smooth
=
label_smooth
self
.
sync_bn
=
True
self
.
use_iou_loss
=
use_iou_loss
self
.
use_iou_aware_loss
=
use_iou_aware_loss
self
.
iou_aware_factor
=
iou_aware_factor
self
.
use_drop_block
=
use_drop_block
self
.
use_dcn_v2
=
use_dcn_v2
self
.
train_random_shapes
=
train_random_shapes
self
.
fixed_input_shape
=
None
if
self
.
anchors
is
None
:
self
.
anchors
=
[[
10
,
13
],
[
16
,
30
],
[
33
,
23
],
[
30
,
61
],
[
62
,
45
],
[
59
,
119
],
[
116
,
90
],
[
156
,
198
],
[
373
,
326
]]
if
self
.
anchor_masks
is
None
:
self
.
anchor_masks
=
[[
6
,
7
,
8
],
[
3
,
4
,
5
],
[
0
,
1
,
2
]]
def
_get_backbone
(
self
,
backbone_name
):
if
backbone_name
==
'DarkNet53'
:
...
...
@@ -93,6 +108,16 @@ class YOLOv3(BaseAPI):
norm_decay
=
0.
,
feature_maps
=
[
3
,
4
,
5
],
freeze_at
=
0
)
elif
backbone_name
==
'ResNet50_vd'
:
backbone
=
paddlex
.
cv
.
nets
.
ResNet
(
norm_type
=
'sync_bn'
,
layers
=
50
,
variant
=
'd'
,
freeze_norm
=
False
,
norm_decay
=
0.
,
feature_maps
=
[
3
,
4
,
5
],
freeze_at
=
0
,
dcn_v2_stages
=
[
5
]
if
self
.
use_dcn_v2
else
[])
elif
backbone_name
==
'MobileNetV1'
:
backbone
=
paddlex
.
cv
.
nets
.
MobileNetV1
(
norm_type
=
'sync_bn'
)
elif
backbone_name
.
startswith
(
'MobileNetV3'
):
...
...
@@ -115,7 +140,12 @@ class YOLOv3(BaseAPI):
nms_keep_topk
=
self
.
nms_keep_topk
,
nms_iou_threshold
=
self
.
nms_iou_threshold
,
train_random_shapes
=
self
.
train_random_shapes
,
fixed_input_shape
=
self
.
fixed_input_shape
)
fixed_input_shape
=
self
.
fixed_input_shape
,
use_iou_loss
=
self
.
use_iou_loss
,
use_iou_aware_loss
=
self
.
use_iou_aware_loss
,
iou_aware_factor
=
self
.
iou_aware_factor
,
use_drop_block
=
self
.
use_drop_block
,
batch_size
=
self
.
train_batch_size
if
hasattr
(
self
,
'train_batch_size'
)
else
8
)
inputs
=
model
.
generate_inputs
()
model_out
=
model
.
build_net
(
inputs
)
outputs
=
OrderedDict
([(
'bbox'
,
model_out
)])
...
...
@@ -217,7 +247,22 @@ class YOLOv3(BaseAPI):
assert
metric
in
[
'COCO'
,
'VOC'
],
"Metric only support 'VOC' or 'COCO'"
self
.
metric
=
metric
self
.
train_batch_size
=
train_batch_size
self
.
labels
=
train_dataset
.
labels
if
self
.
use_iou_loss
or
self
.
use_iou_aware_loss
:
if
self
.
train_random_shapes
is
None
or
len
(
self
.
train_random_shapes
)
==
0
:
for
transform
in
train_dataset
.
transforms
.
transforms
:
if
isinstance
(
transform
,
paddlex
.
det
.
transforms
.
Resize
):
self
.
train_random_shapes
=
[
transform
.
target_size
]
break
train_dataset
.
transforms
.
batch_transforms
=
[]
reshape_bt
=
paddlex
.
det
.
transforms
.
RandomShape
train_dataset
.
transforms
.
batch_transforms
.
append
(
reshape_bt
(
random_shapes
=
self
.
train_random_shapes
))
iou_bt
=
paddlex
.
det
.
transforms
.
GenerateYoloTarget
train_dataset
.
transforms
.
batch_transforms
.
append
(
iou_bt
(
anchors
=
self
.
anchors
,
anchor_masks
=
self
.
anchor_masks
,
num_classes
=
self
.
num_classes
))
# 构建训练网络
if
optimizer
is
None
:
# 构建默认的优化策略
...
...
@@ -306,10 +351,11 @@ class YOLOv3(BaseAPI):
images
=
np
.
array
([
d
[
0
]
for
d
in
data
])
im_sizes
=
np
.
array
([
d
[
1
]
for
d
in
data
])
feed_data
=
{
'image'
:
images
,
'im_size'
:
im_sizes
}
outputs
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
[
feed_data
],
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
return_numpy
=
False
)
outputs
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
[
feed_data
],
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
return_numpy
=
False
)
res
=
{
'bbox'
:
(
np
.
array
(
outputs
[
0
]),
outputs
[
0
].
recursive_sequence_lengths
())
...
...
@@ -325,13 +371,13 @@ class YOLOv3(BaseAPI):
res
[
'gt_label'
]
=
(
res_gt_label
,
[])
res
[
'is_difficult'
]
=
(
res_is_difficult
,
[])
results
.
append
(
res
)
logging
.
debug
(
"[EVAL] Epoch={}, Step={}/{}"
.
format
(
epoch_id
,
step
+
1
,
total_steps
))
logging
.
debug
(
"[EVAL] Epoch={}, Step={}/{}"
.
format
(
epoch_id
,
step
+
1
,
total_steps
))
box_ap_stats
,
eval_details
=
eval_results
(
results
,
metric
,
eval_dataset
.
coco_gt
,
with_background
=
False
)
evaluate_metrics
=
OrderedDict
(
zip
([
'bbox_mmap'
if
metric
==
'COCO'
else
'bbox_map'
],
box_ap_stats
))
zip
([
'bbox_mmap'
if
metric
==
'COCO'
else
'bbox_map'
],
box_ap_stats
))
if
return_details
:
return
evaluate_metrics
,
eval_details
return
evaluate_metrics
...
...
@@ -345,8 +391,7 @@ class YOLOv3(BaseAPI):
Returns:
list: 预测结果列表,每个预测结果由预测框类别标签、
预测框类别名称、预测框坐标(坐标格式为[xmin, ymin, w, h])、
预测框得分组成。
预测框类别名称、预测框坐标、预测框得分组成。
"""
if
transforms
is
None
and
not
hasattr
(
self
,
'test_transforms'
):
raise
Exception
(
"transforms need to be defined, now is None."
)
...
...
@@ -359,11 +404,14 @@ class YOLOv3(BaseAPI):
im
,
im_size
=
self
.
test_transforms
(
img_file
)
im
=
np
.
expand_dims
(
im
,
axis
=
0
)
im_size
=
np
.
expand_dims
(
im_size
,
axis
=
0
)
outputs
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
{
'image'
:
im
,
'im_size'
:
im_size
},
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
return_numpy
=
False
)
outputs
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
{
'image'
:
im
,
'im_size'
:
im_size
},
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
return_numpy
=
False
)
res
=
{
k
:
(
np
.
array
(
v
),
v
.
recursive_sequence_lengths
())
for
k
,
v
in
zip
(
list
(
self
.
test_outputs
.
keys
()),
outputs
)
...
...
paddlex/cv/nets/detection/iou_aware.py
0 → 100644
浏览文件 @
91a601e2
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import
paddle.fluid
as
fluid
def _split_ioup(output, an_num, num_classes):
    """
    Separate the predicted-IoU channels from the remaining output channels.

    The first ``an_num`` channels along axis 1 are sliced off and passed
    through a sigmoid; channels ``an_num`` up to
    ``an_num * (num_classes + 6)`` are returned unchanged.

    Args:
        output: feature map variable that is split along axis 1.
        an_num (int): number of leading channels holding the predicted iou.
        num_classes (int): class count, used to compute the end channel.

    Returns:
        tuple: ``(ioup, oriout)`` - the sigmoid-activated iou slice and the
        remaining slice of ``output``.
    """
    channel_axis = [1]
    last_channel = an_num * (num_classes + 6)

    iou_logits = fluid.layers.slice(
        output, axes=channel_axis, starts=[0], ends=[an_num])
    ioup = fluid.layers.sigmoid(iou_logits)

    oriout = fluid.layers.slice(
        output, axes=channel_axis, starts=[an_num], ends=[last_channel])
    return (ioup, oriout)
def _de_sigmoid(x, eps=1e-7):
    """
    Compute ``-log(1 / x - 1)`` with clipping on both intermediate values.

    ``x`` is first clipped to ``[eps, 1 / eps]``; the quantity ``1 / x - 1``
    is clipped to the same range before the logarithm is taken.

    Args:
        x: input variable.
        eps (float): lower clip bound; ``1 / eps`` is the upper clip bound.

    Returns:
        The variable ``-log(clip(1 / clip(x) - 1))``.
    """
    upper = 1 / eps
    clipped = fluid.layers.clip(x, eps, upper)
    one = fluid.layers.fill_constant(
        shape=[1, 1, 1, 1], dtype=clipped.dtype, value=1.)
    inv_odds = one / clipped - 1.0
    inv_odds = fluid.layers.clip(inv_odds, eps, upper)
    return -fluid.layers.log(inv_odds)
def _postprocess_output(ioup, output, an_num, num_classes, iou_aware_factor):
    """
    Rebuild ``output`` with an objectness channel fused with the predicted iou.

    For each of the ``an_num`` channel groups of ``output`` (axis 1), the
    first four channels and the ``num_classes`` channels starting at offset 5
    are copied as-is. The channel at offset 4 is replaced by
    ``_de_sigmoid(sigmoid(obj) ** (1 - iou_aware_factor)
    * ip ** iou_aware_factor)``, where ``ip`` is the matching channel of
    ``ioup``.

    Args:
        ioup: predicted-iou variable, one channel per anchor on axis 1.
        output: output variable whose axis-1 size is split evenly into
            ``an_num`` groups.
        an_num (int): number of channel groups (anchors).
        num_classes (int): number of class channels per group.
        iou_aware_factor (float): exponent applied to the predicted iou;
            ``1 - iou_aware_factor`` is applied to the objectness.

    Returns:
        The concatenation (axis 1) of all rebuilt channel groups.
    """
    pieces = []
    group_size = output.shape[1] // an_num
    for anchor_idx in range(an_num):
        base = group_size * anchor_idx

        # The first four channels of the group are forwarded untouched.
        box_part = fluid.layers.slice(
            output, axes=[1], starts=[base + 0], ends=[base + 4])

        # Channel 4 of the group is turned into a probability ...
        obj_prob = fluid.layers.sigmoid(
            fluid.layers.slice(
                output, axes=[1], starts=[base + 4], ends=[base + 5]))
        iou_prob = fluid.layers.slice(
            ioup, axes=[1], starts=[anchor_idx], ends=[anchor_idx + 1])

        # ... blended with the predicted iou, then mapped back through
        # _de_sigmoid before being placed in the output.
        obj_term = fluid.layers.pow(obj_prob, (1 - iou_aware_factor))
        iou_term = fluid.layers.pow(iou_prob, iou_aware_factor)
        fused_obj = _de_sigmoid(obj_term * iou_term)

        cls_part = fluid.layers.slice(
            output,
            axes=[1],
            starts=[base + 5],
            ends=[base + 5 + num_classes])

        pieces += [box_part, fused_obj, cls_part]

    return fluid.layers.concat(pieces, axis=1)
def get_iou_aware_score(output, an_num, num_classes, iou_aware_factor):
    """
    Split off the predicted iou from ``output`` and fuse it into objectness.

    Chains ``_split_ioup`` and ``_postprocess_output``.

    Args:
        output: output variable that carries the iou channels first.
        an_num (int): number of anchors of this output layer.
        num_classes (int): number of classes.
        iou_aware_factor (float): weight of the predicted iou in the fusion.

    Returns:
        The post-processed output returned by ``_postprocess_output``.
    """
    iou_pred, plain_output = _split_ioup(output, an_num, num_classes)
    return _postprocess_output(iou_pred, plain_output, an_num, num_classes,
                               iou_aware_factor)
paddlex/cv/nets/detection/loss/iou_aware_loss.py
0 → 100644
浏览文件 @
91a601e2
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.initializer
import
NumpyArrayInitializer
from
paddle
import
fluid
from
.iou_loss
import
IouLoss
class IouAwareLoss(IouLoss):
    """
    IoU-aware loss, see https://arxiv.org/abs/1912.05992

    Reuses the box decoding and iou computation inherited from ``IouLoss``
    and compares the predicted iou against the computed iou.

    Args:
        loss_weight (float): iou aware loss weight, default is 1.0
        max_height (int): max height of input to support random shape input
        max_width (int): max width of input to support random shape input
    """

    def __init__(self, loss_weight=1.0, max_height=608, max_width=608):
        # Only these three settings are forwarded; the remaining IouLoss
        # options keep their defaults.
        super(IouAwareLoss, self).__init__(
            loss_weight=loss_weight,
            max_height=max_height,
            max_width=max_width)

    def __call__(self,
                 ioup,
                 x,
                 y,
                 w,
                 h,
                 tx,
                 ty,
                 tw,
                 th,
                 anchors,
                 downsample_ratio,
                 batch_size,
                 eps=1.e-10):
        """
        Build the iou-aware loss for one output layer.

        Args:
            ioup ([Variables]): the predicted iou
            x | y | w | h ([Variables]): the output of yolov3 for encoded x|y|w|h
            tx |ty |tw |th ([Variables]): the target of yolov3 for encoded x|y|w|h
            anchors ([float]): list of anchors for current output layer
            downsample_ratio (float): the downsample ratio for current output layer
            batch_size (int): training batch size
            eps (float): small constant that keeps the denominator from
                being zero

        Returns:
            Cross entropy (soft label) between ``ioup`` and the computed
            iou, multiplied by the loss weight.
        """
        pred_box = self._bbox_transform(x, y, w, h, anchors, downsample_ratio,
                                        batch_size, False)
        gt_box = self._bbox_transform(tx, ty, tw, th, anchors,
                                      downsample_ratio, batch_size, True)

        # The computed iou acts as a fixed label, so gradients are stopped.
        iou_label = self._iou(pred_box, gt_box, ioup, eps)
        iou_label.stop_gradient = True

        raw_loss = fluid.layers.cross_entropy(
            ioup, iou_label, soft_label=True)
        return raw_loss * self._loss_weight
paddlex/cv/nets/detection/loss/iou_loss.py
0 → 100644
浏览文件 @
91a601e2
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.initializer
import
NumpyArrayInitializer
from
paddle
import
fluid
class IouLoss(object):
    """
    iou loss, see https://arxiv.org/abs/1908.03851
    loss = 1.0 - iou * iou

    Args:
        loss_weight (float): iou loss weight, default is 2.5
        max_height (int): max height of input to support random shape input
        max_width (int): max width of input to support random shape input
        ciou_term (bool): whether to add ciou_term
        loss_square (bool): whether to square the iou term
    """

    def __init__(self,
                 loss_weight=2.5,
                 max_height=608,
                 max_width=608,
                 ciou_term=False,
                 loss_square=True):
        self._loss_weight = loss_weight
        self._MAX_HI = max_height
        self._MAX_WI = max_width
        self.ciou_term = ciou_term
        self.loss_square = loss_square

    def __call__(self,
                 x,
                 y,
                 w,
                 h,
                 tx,
                 ty,
                 tw,
                 th,
                 anchors,
                 downsample_ratio,
                 batch_size,
                 ioup=None,
                 eps=1.e-10):
        """
        Build the iou loss for one output layer.

        Args:
            x | y | w | h ([Variables]): the output of yolov3 for encoded x|y|w|h
            tx |ty |tw |th ([Variables]): the target of yolov3 for encoded x|y|w|h
            anchors ([float]): list of anchors for current output layer
            downsample_ratio (float): the downsample ratio for current output layer
            batch_size (int): training batch size
            ioup: forwarded to ``_iou``, which does not use it
            eps (float): small constant that keeps the denominator from
                being zero

        Returns:
            ``(1 - iou * iou) * loss_weight`` when ``loss_square`` is set,
            otherwise ``(1 - iou) * loss_weight``.
        """
        pred = self._bbox_transform(x, y, w, h, anchors, downsample_ratio,
                                    batch_size, False)
        gt = self._bbox_transform(tx, ty, tw, th, anchors, downsample_ratio,
                                  batch_size, True)
        iouk = self._iou(pred, gt, ioup, eps)
        if self.loss_square:
            loss_iou = 1. - iouk * iouk
        else:
            loss_iou = 1. - iouk
        loss_iou = loss_iou * self._loss_weight

        return loss_iou

    def _iou(self, pred, gt, ioup=None, eps=1.e-10):
        """
        Compute the element-wise iou between predicted and target boxes.

        Args:
            pred (tuple): ``(x1, y1, x2, y2)`` of the predicted boxes.
            gt (tuple): ``(x1, y1, x2, y2)`` of the target boxes.
            ioup: unused; kept so existing callers can keep passing it.
            eps (float): added to the union to avoid division by zero.

        Returns:
            The iou variable; when ``self.ciou_term`` is set, the value of
            ``get_ciou_term`` is subtracted from it.
        """
        x1, y1, x2, y2 = pred
        x1g, y1g, x2g, y2g = gt
        # Force x2 >= x1 and y2 >= y1 for the predicted box.
        x2 = fluid.layers.elementwise_max(x1, x2)
        y2 = fluid.layers.elementwise_max(y1, y2)

        xkis1 = fluid.layers.elementwise_max(x1, x1g)
        ykis1 = fluid.layers.elementwise_max(y1, y1g)
        xkis2 = fluid.layers.elementwise_min(x2, x2g)
        ykis2 = fluid.layers.elementwise_min(y2, y2g)

        intsctk = (xkis2 - xkis1) * (ykis2 - ykis1)
        # Zero the intersection where the boxes do not overlap.
        intsctk = intsctk * fluid.layers.greater_than(
            xkis2, xkis1) * fluid.layers.greater_than(ykis2, ykis1)
        unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g
                                                        ) - intsctk + eps
        iouk = intsctk / unionk
        if self.ciou_term:
            # NOTE: the un-clamped ``pred`` is passed here, as in the
            # original implementation.
            ciou = self.get_ciou_term(pred, gt, iouk, eps)
            iouk = iouk - ciou
        return iouk

    def get_ciou_term(self, pred, gt, iouk, eps):
        """
        Compute the penalty that is subtracted from the iou when
        ``ciou_term`` is enabled.

        The result is the sum of a centre-distance term (squared centre
        distance over squared diagonal of the enclosing box) and an
        aspect-ratio term weighted by ``alpha``.

        Args:
            pred (tuple): ``(x1, y1, x2, y2)`` of the predicted boxes.
            gt (tuple): ``(x1, y1, x2, y2)`` of the target boxes.
            iouk: iou variable of the same boxes.
            eps (float): small constant used in the denominators.

        Returns:
            ``diou_term + ciou_term``.
        """
        x1, y1, x2, y2 = pred
        x1g, y1g, x2g, y2g = gt

        cx = (x1 + x2) / 2
        cy = (y1 + y2) / 2
        # Add 1 where the extent is exactly zero so the aspect ratio below
        # never divides by zero.
        w = (x2 - x1) + fluid.layers.cast((x2 - x1) == 0, 'float32')
        h = (y2 - y1) + fluid.layers.cast((y2 - y1) == 0, 'float32')

        cxg = (x1g + x2g) / 2
        cyg = (y1g + y2g) / 2
        wg = x2g - x1g
        hg = y2g - y1g

        # A or B: corners of the box enclosing both boxes
        xc1 = fluid.layers.elementwise_min(x1, x1g)
        yc1 = fluid.layers.elementwise_min(y1, y1g)
        xc2 = fluid.layers.elementwise_max(x2, x2g)
        yc2 = fluid.layers.elementwise_max(y2, y2g)

        # DIOU term
        dist_intersection = (cx - cxg) * (cx - cxg) + (cy - cyg) * (cy - cyg)
        dist_union = (xc2 - xc1) * (xc2 - xc1) + (yc2 - yc1) * (yc2 - yc1)
        diou_term = (dist_intersection + eps) / (dist_union + eps)

        # CIOU term
        ar_gt = wg / hg
        ar_pred = w / h
        arctan = fluid.layers.atan(ar_gt) - fluid.layers.atan(ar_pred)
        ar_loss = 4. / np.pi / np.pi * arctan * arctan
        alpha = ar_loss / (1 - iouk + ar_loss + eps)
        # alpha is used as a constant weight; no gradient flows through it.
        alpha.stop_gradient = True
        ciou_term = alpha * ar_loss
        return diou_term + ciou_term

    def _bbox_transform(self, dcx, dcy, dw, dh, anchors, downsample_ratio,
                        batch_size, is_gt):
        """
        Decode encoded x|y|w|h into box corners ``(x1, y1, x2, y2)``.

        Constant grid-offset and anchor tensors are created at the maximum
        supported size (``max_width`` / ``max_height`` divided by
        ``downsample_ratio``) and cropped to the shape of ``dcx``. Centres
        are divided by the runtime width/height of ``dcx``; sizes are
        divided by that size times ``downsample_ratio``.

        Args:
            dcx | dcy | dw | dh ([Variables]): encoded x|y|w|h
            anchors ([float]): flat list ``[w0, h0, w1, h1, ...]`` of
                anchors for the current output layer
            downsample_ratio (float): the downsample ratio for current output layer
            batch_size (int): training batch size
            is_gt (bool): True for targets (used as-is, gradients stopped),
                False for predictions (sigmoid applied to dcx/dcy)

        Returns:
            tuple: ``(x1, y1, x2, y2)`` variables.
        """
        grid_x = int(self._MAX_WI / downsample_ratio)
        grid_y = int(self._MAX_HI / downsample_ratio)
        an_num = len(anchors) // 2

        shape_fmp = fluid.layers.shape(dcx)
        shape_fmp.stop_gradient = True
        # generate the grid_w x grid_h center of feature map
        idx_i = np.arange(grid_x).reshape(1, grid_x)
        idx_j = np.arange(grid_y).reshape(grid_y, 1)
        gi_np = np.repeat(idx_i, grid_y, axis=0)
        gi_np = np.reshape(gi_np, newshape=[1, 1, grid_y, grid_x])
        gi_np = np.tile(gi_np, reps=[batch_size, an_num, 1, 1])
        gj_np = np.repeat(idx_j, grid_x, axis=1)
        gj_np = np.reshape(gj_np, newshape=[1, 1, grid_y, grid_x])
        gj_np = np.tile(gj_np, reps=[batch_size, an_num, 1, 1])
        gi_max = self._create_tensor_from_numpy(gi_np.astype(np.float32))
        gi = fluid.layers.crop(x=gi_max, shape=dcx)
        gi.stop_gradient = True
        gj_max = self._create_tensor_from_numpy(gj_np.astype(np.float32))
        gj = fluid.layers.crop(x=gj_max, shape=dcx)
        gj.stop_gradient = True

        grid_x_act = fluid.layers.cast(shape_fmp[3], dtype="float32")
        grid_x_act.stop_gradient = True
        grid_y_act = fluid.layers.cast(shape_fmp[2], dtype="float32")
        grid_y_act.stop_gradient = True
        if is_gt:
            # Targets: no gradient. The original wrote ``.gradient = True``,
            # which only set a stray attribute; ``stop_gradient`` is what the
            # rest of this branch uses.
            cx = fluid.layers.elementwise_add(dcx, gi) / grid_x_act
            cx.stop_gradient = True
            cy = fluid.layers.elementwise_add(dcy, gj) / grid_y_act
            cy.stop_gradient = True
        else:
            dcx_sig = fluid.layers.sigmoid(dcx)
            cx = fluid.layers.elementwise_add(dcx_sig, gi) / grid_x_act
            dcy_sig = fluid.layers.sigmoid(dcy)
            cy = fluid.layers.elementwise_add(dcy_sig, gj) / grid_y_act

        # Even positions of ``anchors`` hold widths, odd positions heights.
        anchor_w_np = np.array(anchors[0::2])
        anchor_w_np = np.reshape(anchor_w_np, newshape=[1, an_num, 1, 1])
        anchor_w_np = np.tile(
            anchor_w_np, reps=[batch_size, 1, grid_y, grid_x])
        anchor_w_max = self._create_tensor_from_numpy(
            anchor_w_np.astype(np.float32))
        anchor_w = fluid.layers.crop(x=anchor_w_max, shape=dcx)
        anchor_w.stop_gradient = True
        anchor_h_np = np.array(anchors[1::2])
        anchor_h_np = np.reshape(anchor_h_np, newshape=[1, an_num, 1, 1])
        anchor_h_np = np.tile(
            anchor_h_np, reps=[batch_size, 1, grid_y, grid_x])
        anchor_h_max = self._create_tensor_from_numpy(
            anchor_h_np.astype(np.float32))
        anchor_h = fluid.layers.crop(x=anchor_h_max, shape=dcx)
        anchor_h.stop_gradient = True
        # e^tw e^th
        exp_dw = fluid.layers.exp(dw)
        exp_dh = fluid.layers.exp(dh)
        pw = fluid.layers.elementwise_mul(exp_dw, anchor_w) / \
            (grid_x_act * downsample_ratio)
        ph = fluid.layers.elementwise_mul(exp_dh, anchor_h) / \
            (grid_y_act * downsample_ratio)
        if is_gt:
            exp_dw.stop_gradient = True
            exp_dh.stop_gradient = True
            pw.stop_gradient = True
            ph.stop_gradient = True

        x1 = cx - 0.5 * pw
        y1 = cy - 0.5 * ph
        x2 = cx + 0.5 * pw
        y2 = cy + 0.5 * ph
        if is_gt:
            x1.stop_gradient = True
            y1.stop_gradient = True
            x2.stop_gradient = True
            y2.stop_gradient = True

        return x1, y1, x2, y2

    def _create_tensor_from_numpy(self, numpy_array):
        """
        Create a parameter initialised from ``numpy_array`` with gradients
        stopped.

        Args:
            numpy_array (numpy.ndarray): values, shape and dtype to use.

        Returns:
            The created parameter variable.
        """
        paddle_array = fluid.layers.create_parameter(
            attr=ParamAttr(),
            shape=numpy_array.shape,
            dtype=numpy_array.dtype,
            default_initializer=NumpyArrayInitializer(numpy_array))
        paddle_array.stop_gradient = True
        return paddle_array
paddlex/cv/nets/detection/loss/yolo_loss.py
0 → 100644
浏览文件 @
91a601e2
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
paddle
import
fluid
from
.iou_loss
import
IouLoss
from
.iou_aware_loss
import
IouAwareLoss
class
YOLOv3Loss
(
object
):
"""
Combined loss for YOLOv3 network
Args:
batch_size (int): training batch size
ignore_thresh (float): threshold to ignore confidence loss
label_smooth (bool): whether to use label smoothing
"""
def
__init__
(
self
,
batch_size
=
8
,
ignore_thresh
=
0.7
,
label_smooth
=
True
,
iou_loss_weight
=
None
,
iou_aware_loss_weight
=
None
,
scale_x_y
=
1.
,
match_score
=
False
):
self
.
_batch_size
=
batch_size
self
.
_ignore_thresh
=
ignore_thresh
self
.
_label_smooth
=
label_smooth
self
.
_iou_loss_weight
=
iou_loss_weight
self
.
_iou_aware_loss_weight
=
iou_aware_loss_weight
self
.
match_score
=
match_score
def
__call__
(
self
,
outputs
,
gt_box
,
gt_label
,
gt_score
,
targets
,
anchors
,
anchor_masks
,
mask_anchors
,
num_classes
,
prefix_name
,
max_size
):
if
len
(
targets
)
!=
0
:
losses_all
=
self
.
_get_fine_grained_loss
(
outputs
,
targets
,
gt_box
,
self
.
_batch_size
,
num_classes
,
mask_anchors
,
self
.
_ignore_thresh
,
max_size
)
else
:
losses
=
[]
downsample
=
32
for
i
,
output
in
enumerate
(
outputs
):
anchor_mask
=
anchor_masks
[
i
]
loss
=
fluid
.
layers
.
yolov3_loss
(
x
=
output
,
gt_box
=
gt_box
,
gt_label
=
gt_label
,
gt_score
=
gt_score
,
anchors
=
anchors
,
anchor_mask
=
anchor_mask
,
class_num
=
num_classes
,
ignore_thresh
=
self
.
_ignore_thresh
,
downsample_ratio
=
downsample
,
use_label_smooth
=
self
.
_label_smooth
,
name
=
prefix_name
+
"yolo_loss"
+
str
(
i
))
losses
.
append
(
fluid
.
layers
.
reduce_mean
(
loss
))
downsample
//=
2
losses_all
=
{
'loss'
:
sum
(
losses
)}
total_loss
=
fluid
.
layers
.
sum
(
list
(
losses_all
.
values
()))
return
total_loss
def
_get_fine_grained_loss
(
self
,
outputs
,
targets
,
gt_box
,
batch_size
,
num_classes
,
mask_anchors
,
ignore_thresh
,
max_size
):
"""
Calculate fine grained YOLOv3 loss
Args:
outputs ([Variables]): List of Variables, output of backbone stages
targets ([Variables]): List of Variables, The targets for yolo
loss calculatation.
gt_box (Variable): The ground-truth boudding boxes.
batch_size (int): The training batch size
num_classes (int): class num of dataset
mask_anchors ([[float]]): list of anchors in each output layer
ignore_thresh (float): prediction bbox overlap any gt_box greater
than ignore_thresh, objectness loss will
be ignored.
Returns:
Type: dict
xy_loss (Variable): YOLOv3 (x, y) coordinates loss
wh_loss (Variable): YOLOv3 (w, h) coordinates loss
obj_loss (Variable): YOLOv3 objectness score loss
cls_loss (Variable): YOLOv3 classification loss
"""
assert
len
(
outputs
)
==
len
(
targets
),
\
"YOLOv3 output layer number not equal target number"
loss_xys
,
loss_whs
,
loss_objs
,
loss_clss
=
[],
[],
[],
[]
if
self
.
_iou_loss_weight
is
not
None
:
loss_ious
=
[]
if
self
.
_iou_aware_loss_weight
is
not
None
:
loss_iou_awares
=
[]
downsample
=
32
for
i
,
(
output
,
target
,
anchors
)
in
enumerate
(
zip
(
outputs
,
targets
,
mask_anchors
)):
an_num
=
len
(
anchors
)
//
2
if
self
.
_iou_aware_loss_weight
is
not
None
:
ioup
,
output
=
self
.
_split_ioup
(
output
,
an_num
,
num_classes
)
x
,
y
,
w
,
h
,
obj
,
cls
=
self
.
_split_output
(
output
,
an_num
,
num_classes
)
tx
,
ty
,
tw
,
th
,
tscale
,
tobj
,
tcls
=
self
.
_split_target
(
target
)
tscale_tobj
=
tscale
*
tobj
loss_x
=
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
x
,
tx
)
*
tscale_tobj
loss_x
=
fluid
.
layers
.
reduce_sum
(
loss_x
,
dim
=
[
1
,
2
,
3
])
loss_y
=
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
y
,
ty
)
*
tscale_tobj
loss_y
=
fluid
.
layers
.
reduce_sum
(
loss_y
,
dim
=
[
1
,
2
,
3
])
# NOTE: we refined loss function of (w, h) as L1Loss
loss_w
=
fluid
.
layers
.
abs
(
w
-
tw
)
*
tscale_tobj
loss_w
=
fluid
.
layers
.
reduce_sum
(
loss_w
,
dim
=
[
1
,
2
,
3
])
loss_h
=
fluid
.
layers
.
abs
(
h
-
th
)
*
tscale_tobj
loss_h
=
fluid
.
layers
.
reduce_sum
(
loss_h
,
dim
=
[
1
,
2
,
3
])
if
self
.
_iou_loss_weight
is
not
None
:
iou_loss_obj
=
IouLoss
(
self
.
_iou_loss_weight
,
max_size
,
max_size
)
loss_iou
=
iou_loss_obj
(
x
,
y
,
w
,
h
,
tx
,
ty
,
tw
,
th
,
anchors
,
downsample
,
self
.
_batch_size
)
loss_iou
=
loss_iou
*
tscale_tobj
loss_iou
=
fluid
.
layers
.
reduce_sum
(
loss_iou
,
dim
=
[
1
,
2
,
3
])
loss_ious
.
append
(
fluid
.
layers
.
reduce_mean
(
loss_iou
))
if
self
.
_iou_aware_loss_weight
is
not
None
:
iou_aware_loss_obj
=
IouAwareLoss
(
self
.
_iou_aware_loss_weight
,
max_size
,
max_size
)
loss_iou_aware
=
iou_aware_loss_obj
(
ioup
,
x
,
y
,
w
,
h
,
tx
,
ty
,
tw
,
th
,
anchors
,
downsample
,
self
.
_batch_size
)
loss_iou_aware
=
loss_iou_aware
*
tobj
loss_iou_aware
=
fluid
.
layers
.
reduce_sum
(
loss_iou_aware
,
dim
=
[
1
,
2
,
3
])
loss_iou_awares
.
append
(
fluid
.
layers
.
reduce_mean
(
loss_iou_aware
))
loss_obj_pos
,
loss_obj_neg
=
self
.
_calc_obj_loss
(
output
,
obj
,
tobj
,
gt_box
,
self
.
_batch_size
,
anchors
,
num_classes
,
downsample
,
self
.
_ignore_thresh
)
loss_cls
=
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
cls
,
tcls
)
loss_cls
=
fluid
.
layers
.
elementwise_mul
(
loss_cls
,
tobj
,
axis
=
0
)
loss_cls
=
fluid
.
layers
.
reduce_sum
(
loss_cls
)
loss_xys
.
append
(
fluid
.
layers
.
reduce_mean
(
loss_x
+
loss_y
))
loss_whs
.
append
(
fluid
.
layers
.
reduce_mean
(
loss_w
+
loss_h
))
loss_objs
.
append
(
fluid
.
layers
.
reduce_mean
(
loss_obj_pos
+
loss_obj_neg
))
loss_clss
.
append
(
fluid
.
layers
.
reduce_mean
(
loss_cls
))
downsample
//=
2
losses_all
=
{
"loss_xy"
:
fluid
.
layers
.
sum
(
loss_xys
),
"loss_wh"
:
fluid
.
layers
.
sum
(
loss_whs
),
"loss_obj"
:
fluid
.
layers
.
sum
(
loss_objs
),
"loss_cls"
:
fluid
.
layers
.
sum
(
loss_clss
),
}
if
self
.
_iou_loss_weight
is
not
None
:
losses_all
[
"loss_iou"
]
=
fluid
.
layers
.
sum
(
loss_ious
)
if
self
.
_iou_aware_loss_weight
is
not
None
:
losses_all
[
"loss_iou_aware"
]
=
fluid
.
layers
.
sum
(
loss_iou_awares
)
return
losses_all
def _split_ioup(self, output, an_num, num_classes):
    """Separate the IoU-prediction channels from the remaining head channels.

    The first ``an_num`` channels along axis 1 of ``output`` are sliced off
    and passed through a sigmoid. The channels from ``an_num`` up to
    ``an_num * (num_classes + 6)`` are returned as they are.

    Args:
        output: Head feature map; it is sliced along axis 1.
        an_num: Number of anchors used by this head.
        num_classes: Number of object classes.

    Returns:
        tuple: ``(ioup, oriout)`` -- the sigmoid-activated IoU prediction
        and the remaining part of the feature map.
    """
    layers = fluid.layers
    channel_end = an_num * (num_classes + 6)

    # Leading an_num channels hold the IoU logits.
    iou_logits = layers.slice(output, axes=[1], starts=[0], ends=[an_num])
    iou_pred = layers.sigmoid(iou_logits)

    # Everything after the IoU channels is returned without activation.
    remainder = layers.slice(
        output, axes=[1], starts=[an_num], ends=[channel_end])
    return (iou_pred, remainder)
def _split_output(self, output, an_num, num_classes):
    """Split a head feature map into x, y, w, h, objectness and class parts.

    All splitting happens along axis 1 (the channel axis). Channels are
    read in groups of ``5 + num_classes``: offsets 0..4 of every group give
    x, y, w, h and objectness, and the following ``num_classes`` channels
    give the class logits.

    Args:
        output: Head feature map whose ``shape[1]`` is known at graph
            build time.
        an_num: Number of anchors used by this head.
        num_classes: Number of object classes.

    Returns:
        tuple: ``(x, y, w, h, obj, cls)``. ``cls`` is built by stacking the
        per-anchor class slices on axis 1 and transposing with
        ``perm=[0, 1, 3, 4, 2]`` so that the class axis comes last.
    """
    group_size = 5 + num_classes
    channel_count = output.shape[1]

    def pick_every_group(offset):
        # Take channel `offset` of every anchor group.
        return fluid.layers.strided_slice(
            output,
            axes=[1],
            starts=[offset],
            ends=[channel_count],
            strides=[group_size])

    x = pick_every_group(0)
    y = pick_every_group(1)
    w = pick_every_group(2)
    h = pick_every_group(3)
    obj = pick_every_group(4)

    # Class logits are contiguous inside each anchor group, so they are
    # taken with one plain slice per anchor.
    stride = channel_count // an_num
    class_slices = []
    for anchor_idx in range(an_num):
        first = stride * anchor_idx + 5
        class_slices.append(
            fluid.layers.slice(
                output,
                axes=[1],
                starts=[first],
                ends=[first + num_classes]))

    stacked = fluid.layers.stack(class_slices, axis=1)
    cls = fluid.layers.transpose(stacked, perm=[0, 1, 3, 4, 2])
    return (x, y, w, h, obj, cls)
def _split_target(self, target):
    """Split a target tensor into its regression, score and class parts.

    ``target`` is in shape ``[N, an_num, 6 + class_num, H, W]``; the split
    is done along dimension 2.

    Args:
        target: Ground-truth target tensor for one head.

    Returns:
        tuple: ``(tx, ty, tw, th, tscale, tobj, tcls)``. The first six are
        the slices at indices 0..5 of dimension 2. ``tcls`` holds the
        remaining entries (index 6 onwards), transposed with
        ``perm=[0, 1, 3, 4, 2]`` and marked ``stop_gradient``.
    """
    # Indices 0..5 of dimension 2: x, y, w, h, scale, objectness score.
    tx, ty, tw, th, tscale, tobj = [
        target[:, :, field, :, :] for field in range(6)
    ]

    class_part = target[:, :, 6:, :, :]
    tcls = fluid.layers.transpose(class_part, perm=[0, 1, 3, 4, 2])
    # Class targets are labels; no gradient should flow into them.
    tcls.stop_gradient = True

    return (tx, ty, tw, th, tscale, tobj, tcls)
def _calc_obj_loss(self, output, obj, tobj, gt_box, batch_size, anchors,
                   num_classes, downsample, ignore_thresh):
    """Compute the positive and negative objectness losses for one head.

    Predicted boxes are decoded with ``yolo_box`` and compared against the
    ground-truth boxes of the same sample. Negative grids whose best IoU
    with any ground-truth box exceeds ``ignore_thresh`` are excluded from
    the negative objectness loss. When ``self.match_score`` is set, grids
    whose maximum ``yolo_box`` score exceeds 0.25 are excluded as well.

    Args:
        output: Raw head feature map passed to ``yolo_box``.
        obj: Objectness logits.
        tobj: Objectness target; holds the gt score, > 0 where an object
            is assigned.
        gt_box: Ground-truth boxes; converted from (x, y, w, h) center
            form to corner form before the IoU computation.
        batch_size: Number of samples in the batch; used to split the
            tensors per sample.
        anchors: Flat list of anchor sizes for this head
            (``len(anchors) // 2`` anchors).
        num_classes: Number of object classes.
        downsample: Downsample ratio of this head.
        ignore_thresh: IoU threshold above which a negative grid is
            ignored.

    Returns:
        tuple: ``(loss_obj_pos, loss_obj_neg)``, each reduced over
        dimensions 1, 2 and 3.
    """

    def box_xywh2xyxy(box):
        # Convert rows of (center x, center y, w, h) to (x1, y1, x2, y2).
        x = box[:, 0]
        y = box[:, 1]
        w = box[:, 2]
        h = box[:, 3]
        return fluid.layers.stack(
            [
                x - w / 2.,
                y - h / 2.,
                x + w / 2.,
                y + h / 2.,
            ],
            axis=1)

    # A prediction bbox overlap any gt_bbox over ignore_thresh,
    # objectness loss will be ignored, process as follows:

    # 1. get pred bbox, which is same with YOLOv3 infer mode, use yolo_box here
    # NOTE: img_size is set as 1.0 to get normalized pred bbox
    bbox, prob = fluid.layers.yolo_box(
        x=output,
        img_size=fluid.layers.ones(
            shape=[batch_size, 2], dtype="int32"),
        anchors=anchors,
        class_num=num_classes,
        conf_thresh=0.,
        downsample_ratio=downsample,
        clip_bbox=False)

    # 2. split pred bbox and gt bbox by sample, calculate IoU between pred bbox
    #    and gt bbox in each sample
    if batch_size > 1:
        preds = fluid.layers.split(bbox, batch_size, dim=0)
        gts = fluid.layers.split(gt_box, batch_size, dim=0)
    else:
        preds = [bbox]
        gts = [gt_box]

    ious = []
    for pred, gt in zip(preds, gts):
        pred = fluid.layers.squeeze(pred, axes=[0])
        gt = box_xywh2xyxy(fluid.layers.squeeze(gt, axes=[0]))
        ious.append(fluid.layers.iou_similarity(pred, gt))
    iou = fluid.layers.stack(ious, axis=0)

    # 3. Get iou_mask by IoU between gt bbox and prediction bbox,
    #    Get obj_mask by tobj(holds gt_score), calculate objectness loss
    max_iou = fluid.layers.reduce_max(iou, dim=-1)
    iou_mask = fluid.layers.cast(max_iou <= ignore_thresh, dtype="float32")
    if self.match_score:
        max_prob = fluid.layers.reduce_max(prob, dim=-1)
        iou_mask = iou_mask * fluid.layers.cast(
            max_prob <= 0.25, dtype="float32")
    output_shape = fluid.layers.shape(output)
    an_num = len(anchors) // 2
    iou_mask = fluid.layers.reshape(
        iou_mask, (-1, an_num, output_shape[2], output_shape[3]))
    iou_mask.stop_gradient = True

    # NOTE: tobj holds gt_score, obj_mask holds object existence mask
    obj_mask = fluid.layers.cast(tobj > 0., dtype="float32")
    obj_mask.stop_gradient = True

    # For positive objectness grids, objectness loss should be calculated
    # For negative objectness grids, objectness loss is calculated only iou_mask == 1.0
    loss_obj = fluid.layers.sigmoid_cross_entropy_with_logits(obj, obj_mask)
    loss_obj_pos = fluid.layers.reduce_sum(loss_obj * tobj, dim=[1, 2, 3])
    loss_obj_neg = fluid.layers.reduce_sum(
        loss_obj * (1.0 - obj_mask) * iou_mask, dim=[1, 2, 3])

    return loss_obj_pos, loss_obj_neg
paddlex/cv/nets/detection/yolo_v3.py
浏览文件 @
91a601e2
...
...
@@ -13,9 +13,12 @@
# limitations under the License.
from
paddle
import
fluid
from
paddle.fluid.initializer
import
NumpyArrayInitializer
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.regularizer
import
L2Decay
from
collections
import
OrderedDict
from
.loss
import
yolo_loss
from
.iou_aware
import
get_iou_aware_score
class
YOLOv3
:
...
...
@@ -34,7 +37,12 @@ class YOLOv3:
train_random_shapes
=
[
320
,
352
,
384
,
416
,
448
,
480
,
512
,
544
,
576
,
608
],
fixed_input_shape
=
None
):
fixed_input_shape
=
None
,
use_iou_loss
=
False
,
use_iou_aware_loss
=
False
,
iou_aware_factor
=
0.4
,
use_drop_block
=
False
,
batch_size
=
8
):
if
anchors
is
None
:
anchors
=
[[
10
,
13
],
[
16
,
30
],
[
33
,
23
],
[
30
,
61
],
[
62
,
45
],
[
59
,
119
],
[
116
,
90
],
[
156
,
198
],
[
373
,
326
]]
...
...
@@ -56,6 +64,13 @@ class YOLOv3:
self
.
prefix_name
=
''
self
.
train_random_shapes
=
train_random_shapes
self
.
fixed_input_shape
=
fixed_input_shape
self
.
use_iou_loss
=
use_iou_loss
self
.
use_iou_aware_loss
=
use_iou_aware_loss
self
.
iou_aware_factor
=
iou_aware_factor
self
.
use_drop_block
=
use_drop_block
self
.
block_size
=
3
self
.
keep_prob
=
0.9
self
.
batch_size
=
batch_size
def
_head
(
self
,
feats
):
outputs
=
[]
...
...
@@ -71,7 +86,10 @@ class YOLOv3:
channel
=
512
//
(
2
**
i
),
name
=
self
.
prefix_name
+
'yolo_block.{}'
.
format
(
i
))
num_filters
=
len
(
self
.
anchor_masks
[
i
])
*
(
self
.
num_classes
+
5
)
if
self
.
use_iou_aware_loss
:
num_filters
=
len
(
self
.
anchor_masks
[
i
])
*
(
self
.
num_classes
+
6
)
else
:
num_filters
=
len
(
self
.
anchor_masks
[
i
])
*
(
self
.
num_classes
+
5
)
block_out
=
fluid
.
layers
.
conv2d
(
input
=
tip
,
num_filters
=
num_filters
,
...
...
@@ -155,6 +173,55 @@ class YOLOv3:
out
=
fluid
.
layers
.
resize_nearest
(
input
=
input
,
scale
=
float
(
scale
),
name
=
name
)
return
out
def _dropblock(self, input, block_size=3, keep_prob=0.9):
    """Apply DropBlock-style masking to a feature map during training.

    Outside of ``'train'`` mode the input is returned unchanged. In
    training, a random seed mask is drawn, dilated with a max-pool of size
    ``block_size`` so that whole blocks are dropped, and the surviving
    activations are rescaled by ``numel / sum(mask)``.

    Args:
        input: Feature map to mask. The size of axis 3 of its shape is
            used as the feature-map side length.
        block_size: Side length of the dropped blocks.
        keep_prob: Target keep probability used to derive the seed rate.

    Returns:
        The masked and rescaled feature map, or ``input`` itself when not
        training.
    """
    if self.mode != 'train':
        return input

    layers = fluid.layers

    # --- seed rate (gamma) --------------------------------------------
    # gamma = feat_area * (1 - keep_prob) / (block_area * useful_area)
    side = layers.slice(layers.shape(input), [0], [3], [4])
    side = layers.cast(side, dtype="float32")
    side = layers.reshape(side, [1, 1, 1, 1])
    feat_area = layers.pow(side, factor=2)

    block_side = layers.fill_constant(
        shape=[1, 1, 1, 1], value=block_size, dtype='float32')
    block_area = layers.pow(block_side, factor=2)

    # Number of positions along one side where a full block still fits.
    valid_side = side - block_side + 1
    valid_area = layers.pow(valid_side, factor=2)

    numerator = feat_area * (1 - keep_prob)
    denominator = block_area * valid_area
    gamma = numerator / denominator

    # --- random block mask --------------------------------------------
    input_shape = layers.shape(input)
    drop_rate = layers.expand_as(gamma, input)

    shape_i64 = layers.cast(input_shape, dtype="int64")
    # NOTE(review): seed=1 is a fixed, non-zero seed; confirm against the
    # uniform_random documentation whether this yields the same mask on
    # every step and whether that is intended.
    noise = layers.uniform_random(
        shape_i64, dtype='float32', min=0.0, max=1.0, seed=1)
    seeds = layers.less_than(noise, drop_rate)
    seeds.stop_gradient = True
    seeds = layers.cast(seeds, dtype="float32")

    # Max-pooling spreads every seed over a block_size x block_size window.
    dropped = layers.pool2d(
        seeds,
        pool_size=block_size,
        pool_type='max',
        pool_stride=1,
        pool_padding=block_size // 2)
    mask = 1.0 - dropped

    # --- rescale the kept activations ---------------------------------
    total = layers.cast(layers.reduce_prod(input_shape), dtype="float32")
    total.stop_gradient = True
    kept = layers.cast(layers.reduce_sum(mask), dtype="float32")
    kept.stop_gradient = True

    return input * mask * total / kept
def
_detection_block
(
self
,
input
,
channel
,
name
=
None
):
assert
channel
%
2
==
0
,
"channel({}) cannot be divided by 2 in detection block({})"
.
format
(
...
...
@@ -179,6 +246,16 @@ class YOLOv3:
padding
=
1
,
is_test
=
is_test
,
name
=
'{}.{}.1'
.
format
(
name
,
i
))
if
self
.
use_drop_block
and
i
==
0
and
channel
!=
512
:
conv
=
self
.
_dropblock
(
conv
,
block_size
=
self
.
block_size
,
keep_prob
=
self
.
keep_prob
)
if
self
.
use_drop_block
and
channel
==
512
:
conv
=
self
.
_dropblock
(
conv
,
block_size
=
self
.
block_size
,
keep_prob
=
self
.
keep_prob
)
route
=
self
.
_conv_bn
(
conv
,
channel
,
...
...
@@ -197,31 +274,28 @@ class YOLOv3:
name
=
'{}.tip'
.
format
(
name
))
return
route
,
tip
def
_get_loss
(
self
,
inputs
,
gt_box
,
gt_label
,
gt_score
):
def
_get_loss
(
self
,
inputs
,
gt_box
,
gt_label
,
gt_score
,
targets
):
losses
=
[]
downsample
=
32
for
i
,
input
in
enumerate
(
inputs
):
loss
=
fluid
.
layers
.
yolov3_loss
(
x
=
input
,
gt_box
=
gt_box
,
gt_label
=
gt_label
,
gt_score
=
gt_score
,
anchors
=
self
.
anchors
,
anchor_mask
=
self
.
anchor_masks
[
i
],
class_num
=
self
.
num_classes
,
ignore_thresh
=
self
.
ignore_thresh
,
downsample_ratio
=
downsample
,
use_label_smooth
=
self
.
label_smooth
,
name
=
self
.
prefix_name
+
'yolo_loss'
+
str
(
i
))
losses
.
append
(
fluid
.
layers
.
reduce_mean
(
loss
))
downsample
//=
2
return
sum
(
losses
)
yolo_loss_obj
=
yolo_loss
.
YOLOv3Loss
(
batch_size
=
self
.
batch_size
,
ignore_thresh
=
self
.
ignore_thresh
,
label_smooth
=
self
.
label_smooth
,
iou_loss_weight
=
2.5
if
self
.
use_iou_loss
else
None
,
iou_aware_loss_weight
=
1.0
if
self
.
use_iou_aware_loss
else
None
)
return
yolo_loss_obj
(
inputs
,
gt_box
,
gt_label
,
gt_score
,
targets
,
self
.
anchors
,
self
.
anchor_masks
,
self
.
mask_anchors
,
self
.
num_classes
,
self
.
prefix_name
,
max
(
self
.
train_random_shapes
))
def
_get_prediction
(
self
,
inputs
,
im_size
):
boxes
=
[]
scores
=
[]
downsample
=
32
for
i
,
input
in
enumerate
(
inputs
):
if
self
.
use_iou_aware_loss
:
input
=
get_iou_aware_score
(
input
,
len
(
self
.
anchor_masks
[
i
]),
self
.
num_classes
,
self
.
iou_aware_factor
)
box
,
score
=
fluid
.
layers
.
yolo_box
(
x
=
input
,
img_size
=
im_size
,
...
...
@@ -267,6 +341,12 @@ class YOLOv3:
dtype
=
'float32'
,
shape
=
[
None
,
None
],
name
=
'gt_score'
)
inputs
[
'im_size'
]
=
fluid
.
data
(
dtype
=
'int32'
,
shape
=
[
None
,
2
],
name
=
'im_size'
)
if
self
.
use_iou_loss
or
self
.
use_iou_aware_loss
:
for
i
,
mask
in
enumerate
(
self
.
anchor_masks
):
inputs
[
'target{}'
.
format
(
i
)]
=
fluid
.
data
(
dtype
=
'float32'
,
shape
=
[
None
,
len
(
mask
),
6
+
self
.
num_classes
,
None
,
None
],
name
=
'target{}'
.
format
(
i
))
elif
self
.
mode
==
'eval'
:
inputs
[
'im_size'
]
=
fluid
.
data
(
dtype
=
'int32'
,
shape
=
[
None
,
2
],
name
=
'im_size'
)
...
...
@@ -285,22 +365,6 @@ class YOLOv3:
def
build_net
(
self
,
inputs
):
image
=
inputs
[
'image'
]
if
self
.
mode
==
'train'
:
if
isinstance
(
self
.
train_random_shapes
,
(
list
,
tuple
))
and
len
(
self
.
train_random_shapes
)
>
0
:
import
numpy
as
np
shapes
=
np
.
array
(
self
.
train_random_shapes
)
shapes
=
np
.
stack
([
shapes
,
shapes
],
axis
=
1
).
astype
(
'float32'
)
shapes_tensor
=
fluid
.
layers
.
assign
(
shapes
)
index
=
fluid
.
layers
.
uniform_random
(
shape
=
[
1
],
dtype
=
'float32'
,
min
=
0.0
,
max
=
1
)
index
=
fluid
.
layers
.
cast
(
index
*
len
(
self
.
train_random_shapes
),
dtype
=
'int32'
)
shape
=
fluid
.
layers
.
gather
(
shapes_tensor
,
index
)
shape
=
fluid
.
layers
.
reshape
(
shape
,
[
-
1
])
shape
=
fluid
.
layers
.
cast
(
shape
,
dtype
=
'int32'
)
image
=
fluid
.
layers
.
resize_nearest
(
image
,
out_shape
=
shape
,
align_corners
=
False
)
feats
=
self
.
backbone
(
image
)
if
isinstance
(
feats
,
OrderedDict
):
feat_names
=
list
(
feats
.
keys
())
...
...
@@ -320,8 +384,14 @@ class YOLOv3:
whwh
=
fluid
.
layers
.
cast
(
whwh
,
dtype
=
'float32'
)
whwh
.
stop_gradient
=
True
normalized_box
=
fluid
.
layers
.
elementwise_div
(
gt_box
,
whwh
)
targets
=
[]
if
self
.
use_iou_loss
or
self
.
use_iou_aware_loss
:
for
i
,
mask
in
enumerate
(
self
.
anchor_masks
):
k
=
'target{}'
.
format
(
i
)
if
k
in
inputs
:
targets
.
append
(
inputs
[
k
])
return
self
.
_get_loss
(
head_outputs
,
normalized_box
,
gt_label
,
gt_score
)
gt_score
,
targets
)
else
:
im_size
=
inputs
[
'im_size'
]
return
self
.
_get_prediction
(
head_outputs
,
im_size
)
paddlex/cv/transforms/det_transforms.py
浏览文件 @
91a601e2
...
...
@@ -49,13 +49,14 @@ class Compose(DetTransform):
ValueError: 数据长度不匹配。
"""
def
__init__
(
self
,
transforms
):
def
__init__
(
self
,
transforms
,
batch_transforms
=
None
):
if
not
isinstance
(
transforms
,
list
):
raise
TypeError
(
'The transforms must be a list!'
)
if
len
(
transforms
)
<
1
:
raise
ValueError
(
'The length of transforms '
+
\
'must be equal or larger than 1!'
)
self
.
transforms
=
transforms
self
.
batch_transforms
=
batch_transforms
self
.
use_mixup
=
False
for
t
in
self
.
transforms
:
if
type
(
t
).
__name__
==
'MixupImage'
:
...
...
@@ -498,9 +499,10 @@ class Normalize(DetTransform):
TypeError: 形参数据类型不满足需求。
"""
def
__init__
(
self
,
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
]):
def
__init__
(
self
,
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
]
,
is_scale
=
True
):
self
.
mean
=
mean
self
.
std
=
std
self
.
is_scale
=
is_scale
if
not
(
isinstance
(
self
.
mean
,
list
)
and
isinstance
(
self
.
std
,
list
)):
raise
TypeError
(
"NormalizeImage: input type is invalid."
)
from
functools
import
reduce
...
...
@@ -521,7 +523,7 @@ class Normalize(DetTransform):
"""
mean
=
np
.
array
(
self
.
mean
)[
np
.
newaxis
,
np
.
newaxis
,
:]
std
=
np
.
array
(
self
.
std
)[
np
.
newaxis
,
np
.
newaxis
,
:]
im
=
normalize
(
im
,
mean
,
std
)
im
=
normalize
(
im
,
mean
,
std
,
self
.
is_scale
)
if
label_info
is
None
:
return
(
im
,
im_info
)
else
:
...
...
@@ -1233,108 +1235,190 @@ class ArrangeYOLOv3(DetTransform):
im_shape
=
im_info
[
'image_shape'
]
outputs
=
(
im
,
im_shape
)
return
outputs
class
RandomShape
(
DetTransform
):
"""调整图像大小(resize)。
对batch数据中的每张图像全部resize到random_shapes中任意一个大小。
注意:当插值方式为“RANDOM”时,则随机选取一种插值方式进行resize。
Args:
random_shapes (list): resize大小选择列表。
默认为[320, 352, 384, 416, 448, 480, 512, 544, 576, 608]。
interp (str): resize的插值方式,与opencv的插值方式对应,取值范围为
['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM']。默认为"RANDOM"。
Raises:
ValueError: 插值方式不在['NEAREST', 'LINEAR', 'CUBIC',
'AREA', 'LANCZOS4', 'RANDOM']中。
"""
# The interpolation mode
interp_dict
=
{
'NEAREST'
:
cv2
.
INTER_NEAREST
,
'LINEAR'
:
cv2
.
INTER_LINEAR
,
'CUBIC'
:
cv2
.
INTER_CUBIC
,
'AREA'
:
cv2
.
INTER_AREA
,
'LANCZOS4'
:
cv2
.
INTER_LANCZOS4
}
def
__init__
(
self
,
random_shapes
=
[
320
,
352
,
384
,
416
,
448
,
480
,
512
,
544
,
576
,
608
],
interp
=
'RANDOM'
):
if
not
(
interp
==
"RANDOM"
or
interp
in
self
.
interp_dict
):
raise
ValueError
(
"interp should be one of {}"
.
format
(
self
.
interp_dict
.
keys
()))
self
.
random_shapes
=
random_shapes
self
.
interp
=
interp
class
ComposedRCNNTransforms
(
Compose
):
""" RCNN模型(faster-rcnn/mask-rcnn)图像处理流程,具体如下,
训练阶段:
1. 随机以0.5的概率将图像水平翻转
2. 图像归一化
3. 图像按比例Resize,scale计算方式如下
scale = min_max_size[0] / short_size_of_image
if max_size_of_image * scale > min_max_size[1]:
scale = min_max_size[1] / max_size_of_image
4. 将3步骤的长宽进行padding,使得长宽为32的倍数
验证阶段:
1. 图像归一化
2. 图像按比例Resize,scale计算方式同上训练阶段
3. 将2步骤的长宽进行padding,使得长宽为32的倍数
def
__call__
(
self
,
batch_data
):
"""
Args:
mode(str): 图像处理流程所处阶段,训练/验证/预测,分别对应'train', 'eval', 'test'
min_max_size(list): 图像在缩放时,最小边和最大边的约束条件
mean(list): 图像均值
std(list): 图像方差
"""
batch_data (list): 由与图像相关的各种信息组成的batch数据。
def
__init__
(
self
,
mode
,
min_max_size
=
[
800
,
1333
],
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
]):
if
mode
==
'train'
:
# 训练时的transforms,包含数据增强
transforms
=
[
RandomHorizontalFlip
(
prob
=
0.5
),
Normalize
(
mean
=
mean
,
std
=
std
),
ResizeByShort
(
short_size
=
min_max_size
[
0
],
max_size
=
min_max_size
[
1
]),
Padding
(
coarsest_stride
=
32
)
]
Returns:
list: 由与图像相关的各种信息组成的batch数据。
"""
shape
=
np
.
random
.
choice
(
self
.
random_shapes
)
if
self
.
interp
==
"RANDOM"
:
interp
=
random
.
choice
(
list
(
self
.
interp_dict
.
keys
()))
else
:
# 验证/预测时的transforms
transforms
=
[
Normalize
(
mean
=
mean
,
std
=
std
),
ResizeByShort
(
short_size
=
min_max_size
[
0
],
max_size
=
min_max_size
[
1
]),
Padding
(
coarsest_stride
=
32
)
]
super
(
ComposedRCNNTransforms
,
self
).
__init__
(
transforms
)
class
ComposedYOLOTransforms
(
Compose
):
"""YOLOv3模型的图像预处理流程,具体如下,
训练阶段:
1. 在前mixup_epoch轮迭代中,使用MixupImage策略,见https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/det_transforms.html#mixupimage
2. 对图像进行随机扰动,包括亮度,对比度,饱和度和色调
3. 随机扩充图像,见https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/det_transforms.html#randomexpand
4. 随机裁剪图像
5. 将4步骤的输出图像Resize成shape参数的大小
6. 随机0.5的概率水平翻转图像
7. 图像归一化
验证/预测阶段:
1. 将图像Resize成shape参数大小
2. 图像归一化
Args:
mode(str): 图像处理流程所处阶段,训练/验证/预测,分别对应'train', 'eval', 'test'
shape(list): 输入模型中图像的大小,输入模型的图像会被Resize成此大小
mixup_epoch(int): 模型训练过程中,前mixup_epoch会使用mixup策略
mean(list): 图像均值
std(list): 图像方差
interp
=
self
.
interp
for
data_id
,
data
in
enumerate
(
batch_data
):
data_list
=
list
(
data
)
im
=
data_list
[
0
]
im
=
np
.
swapaxes
(
im
,
1
,
0
)
im
=
np
.
swapaxes
(
im
,
1
,
2
)
im
=
resize
(
im
,
shape
,
self
.
interp_dict
[
interp
])
im
=
np
.
swapaxes
(
im
,
1
,
2
)
im
=
np
.
swapaxes
(
im
,
1
,
0
)
data_list
[
0
]
=
im
batch_data
[
data_id
]
=
tuple
(
data_list
)
np
.
save
(
'im.npy'
,
im
)
return
batch_data
class
GenerateYoloTarget
(
DetTransform
):
"""生成YOLOv3的ground truth(真实标注框)在不同特征层的位置转换信息。
该transform只在YOLOv3计算细粒度loss时使用。
Args:
anchors (list|tuple): anchor框的宽度和高度。
anchor_masks (list|tuple): 在计算损失时,使用anchor的mask索引。
num_classes (int): 类别数。默认为80。
iou_thresh (float): iou阈值,当anchor和真实标注框的iou大于该阈值时,计入target。默认为1.0。
"""
def
__init__
(
self
,
anchors
,
anchor_masks
,
num_classes
=
80
,
iou_thresh
=
1.
):
super
(
GenerateYoloTarget
,
self
).
__init__
()
self
.
anchors
=
anchors
self
.
anchor_masks
=
anchor_masks
self
.
num_classes
=
num_classes
self
.
iou_thresh
=
iou_thresh
def
__call__
(
self
,
batch_data
):
"""
Args:
batch_data (list): 由与图像相关的各种信息组成的batch数据。
def
__init__
(
self
,
mode
,
shape
=
[
608
,
608
],
mixup_epoch
=
250
,
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
]):
width
=
shape
if
isinstance
(
shape
,
list
):
if
shape
[
0
]
!=
shape
[
1
]:
raise
Exception
(
"In YOLOv3 model, width and height should be equal"
)
width
=
shape
[
0
]
if
width
%
32
!=
0
:
raise
Exception
(
"In YOLOv3 model, width and height should be multiple of 32, e.g 224、256、320...."
)
Returns:
list: 由与图像相关的各种信息组成的batch数据。
其中,每个数据新添加的字段为:
- target0 (np.ndarray): YOLOv3的ground truth在特征层0的位置转换信息,
形状为(特征层0的anchor数量, 6+类别数, 特征层0的h, 特征层0的w)。
- target1 (np.ndarray): YOLOv3的ground truth在特征层1的位置转换信息,
形状为(特征层1的anchor数量, 6+类别数, 特征层1的h, 特征层1的w)。
- ...
-targetn (np.ndarray): YOLOv3的ground truth在特征层n的位置转换信息,
形状为(特征层n的anchor数量, 6+类别数, 特征层n的h, 特征层n的w)。
n的是大小由anchor_masks的长度决定。
"""
im
=
batch_data
[
0
][
0
]
h
=
im
.
shape
[
1
]
w
=
im
.
shape
[
2
]
an_hw
=
np
.
array
(
self
.
anchors
)
/
np
.
array
([[
w
,
h
]])
for
data_id
,
data
in
enumerate
(
batch_data
):
gt_bbox
=
data
[
1
]
gt_class
=
data
[
2
]
gt_score
=
data
[
3
]
im_shape
=
data
[
4
]
origin_h
=
float
(
im_shape
[
0
])
origin_w
=
float
(
im_shape
[
1
])
data_list
=
list
(
data
)
for
i
,
mask
in
enumerate
(
self
.
anchor_masks
):
downsample_ratio
=
32
//
pow
(
2
,
i
)
grid_h
=
int
(
h
/
downsample_ratio
)
grid_w
=
int
(
w
/
downsample_ratio
)
target
=
np
.
zeros
(
(
len
(
mask
),
6
+
self
.
num_classes
,
grid_h
,
grid_w
),
dtype
=
np
.
float32
)
for
b
in
range
(
gt_bbox
.
shape
[
0
]):
gx
=
gt_bbox
[
b
,
0
]
/
float
(
origin_w
)
gy
=
gt_bbox
[
b
,
1
]
/
float
(
origin_h
)
gw
=
gt_bbox
[
b
,
2
]
/
float
(
origin_w
)
gh
=
gt_bbox
[
b
,
3
]
/
float
(
origin_h
)
cls
=
gt_class
[
b
]
score
=
gt_score
[
b
]
if
gw
<=
0.
or
gh
<=
0.
or
score
<=
0.
:
continue
# find best match anchor index
best_iou
=
0.
best_idx
=
-
1
for
an_idx
in
range
(
an_hw
.
shape
[
0
]):
iou
=
jaccard_overlap
(
[
0.
,
0.
,
gw
,
gh
],
[
0.
,
0.
,
an_hw
[
an_idx
,
0
],
an_hw
[
an_idx
,
1
]])
if
iou
>
best_iou
:
best_iou
=
iou
best_idx
=
an_idx
gi
=
int
(
gx
*
grid_w
)
gj
=
int
(
gy
*
grid_h
)
# gt box should be regressed in this layer if its best-match
# anchor index is in the anchor mask of this layer
if
best_idx
in
mask
:
best_n
=
mask
.
index
(
best_idx
)
# x, y, w, h, scale
target
[
best_n
,
0
,
gj
,
gi
]
=
gx
*
grid_w
-
gi
target
[
best_n
,
1
,
gj
,
gi
]
=
gy
*
grid_h
-
gj
target
[
best_n
,
2
,
gj
,
gi
]
=
np
.
log
(
gw
*
w
/
self
.
anchors
[
best_idx
][
0
])
target
[
best_n
,
3
,
gj
,
gi
]
=
np
.
log
(
gh
*
h
/
self
.
anchors
[
best_idx
][
1
])
target
[
best_n
,
4
,
gj
,
gi
]
=
2.0
-
gw
*
gh
# objectness record gt_score
target
[
best_n
,
5
,
gj
,
gi
]
=
score
# classification
target
[
best_n
,
6
+
cls
,
gj
,
gi
]
=
1.
# For non-matched anchors, calculate the target if the iou
# between anchor and gt is larger than iou_thresh
if
self
.
iou_thresh
<
1
:
for
idx
,
mask_i
in
enumerate
(
mask
):
if
mask_i
==
best_idx
:
continue
iou
=
jaccard_overlap
(
[
0.
,
0.
,
gw
,
gh
],
[
0.
,
0.
,
an_hw
[
mask_i
,
0
],
an_hw
[
mask_i
,
1
]])
if
iou
>
self
.
iou_thresh
:
# x, y, w, h, scale
target
[
idx
,
0
,
gj
,
gi
]
=
gx
*
grid_w
-
gi
target
[
idx
,
1
,
gj
,
gi
]
=
gy
*
grid_h
-
gj
target
[
idx
,
2
,
gj
,
gi
]
=
np
.
log
(
gw
*
w
/
self
.
anchors
[
mask_i
][
0
])
target
[
idx
,
3
,
gj
,
gi
]
=
np
.
log
(
gh
*
h
/
self
.
anchors
[
mask_i
][
1
])
target
[
idx
,
4
,
gj
,
gi
]
=
2.0
-
gw
*
gh
# objectness record gt_score
target
[
idx
,
5
,
gj
,
gi
]
=
score
# classification
target
[
idx
,
6
+
cls
,
gj
,
gi
]
=
1.
data_list
.
append
(
target
)
batch_data
[
data_id
]
=
tuple
(
data_list
)
return
batch_data
if
mode
==
'train'
:
# 训练时的transforms,包含数据增强
transforms
=
[
MixupImage
(
mixup_epoch
=
mixup_epoch
),
RandomDistort
(),
RandomExpand
(),
RandomCrop
(),
Resize
(
target_size
=
width
,
interp
=
'RANDOM'
),
RandomHorizontalFlip
(),
Normalize
(
mean
=
mean
,
std
=
std
)
]
else
:
# 验证/预测时的transforms
transforms
=
[
Resize
(
target_size
=
width
,
interp
=
'CUBIC'
),
Normalize
(
mean
=
mean
,
std
=
std
)
]
super
(
ComposedYOLOTransforms
,
self
).
__init__
(
transforms
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录