Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleX
提交
f8085469
P
PaddleX
项目概览
PaddlePaddle
/
PaddleX
通知
138
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
43
列表
看板
标记
里程碑
合并请求
5
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleX
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
43
Issue
43
列表
看板
标记
里程碑
合并请求
5
合并请求
5
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f8085469
编写于
8月 06, 2020
作者:
J
Jason
提交者:
GitHub
8月 06, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #266 from FlyingQianMM/develop_qh
add ppyolo
上级
f82f69df
dadd35ef
变更
18
展开全部
隐藏空白更改
内联
并排
Showing
18 changed file
with
2187 addition
and
444 deletion
+2187
-444
paddlex/cv/__init__.py
paddlex/cv/__init__.py
+1
-0
paddlex/cv/datasets/dataset.py
paddlex/cv/datasets/dataset.py
+10
-8
paddlex/cv/models/__init__.py
paddlex/cv/models/__init__.py
+1
-0
paddlex/cv/models/base.py
paddlex/cv/models/base.py
+17
-8
paddlex/cv/models/ppyolo.py
paddlex/cv/models/ppyolo.py
+555
-0
paddlex/cv/models/yolo_v3.py
paddlex/cv/models/yolo_v3.py
+20
-321
paddlex/cv/nets/detection/iou_aware.py
paddlex/cv/nets/detection/iou_aware.py
+85
-0
paddlex/cv/nets/detection/loss/iou_aware_loss.py
paddlex/cv/nets/detection/loss/iou_aware_loss.py
+77
-0
paddlex/cv/nets/detection/loss/iou_loss.py
paddlex/cv/nets/detection/loss/iou_loss.py
+235
-0
paddlex/cv/nets/detection/loss/yolo_loss.py
paddlex/cv/nets/detection/loss/yolo_loss.py
+371
-0
paddlex/cv/nets/detection/ops.py
paddlex/cv/nets/detection/ops.py
+270
-0
paddlex/cv/nets/detection/yolo_v3.py
paddlex/cv/nets/detection/yolo_v3.py
+295
-105
paddlex/cv/transforms/__init__.py
paddlex/cv/transforms/__init__.py
+4
-1
paddlex/cv/transforms/cls_transforms.py
paddlex/cv/transforms/cls_transforms.py
+1
-1
paddlex/cv/transforms/det_transforms.py
paddlex/cv/transforms/det_transforms.py
+185
-0
paddlex/cv/transforms/seg_transforms.py
paddlex/cv/transforms/seg_transforms.py
+1
-0
paddlex/det.py
paddlex/det.py
+1
-0
tutorials/train/object_detection/ppyolo.py
tutorials/train/object_detection/ppyolo.py
+58
-0
未找到文件。
paddlex/cv/__init__.py
浏览文件 @
f8085469
...
@@ -26,6 +26,7 @@ ResNet50 = models.ResNet50
...
@@ -26,6 +26,7 @@ ResNet50 = models.ResNet50
DarkNet53
=
models
.
DarkNet53
DarkNet53
=
models
.
DarkNet53
# detection
# detection
YOLOv3
=
models
.
YOLOv3
YOLOv3
=
models
.
YOLOv3
PPYOLO
=
models
.
PPYOLO
#EAST = models.EAST
#EAST = models.EAST
FasterRCNN
=
models
.
FasterRCNN
FasterRCNN
=
models
.
FasterRCNN
MaskRCNN
=
models
.
MaskRCNN
MaskRCNN
=
models
.
MaskRCNN
...
...
paddlex/cv/datasets/dataset.py
浏览文件 @
f8085469
...
@@ -115,7 +115,7 @@ def multithread_reader(mapper,
...
@@ -115,7 +115,7 @@ def multithread_reader(mapper,
while
not
isinstance
(
sample
,
EndSignal
):
while
not
isinstance
(
sample
,
EndSignal
):
batch_data
.
append
(
sample
)
batch_data
.
append
(
sample
)
if
len
(
batch_data
)
==
batch_size
:
if
len
(
batch_data
)
==
batch_size
:
batch_data
=
generate_minibatch
(
batch_data
)
batch_data
=
generate_minibatch
(
batch_data
,
mapper
=
mapper
)
yield
batch_data
yield
batch_data
batch_data
=
[]
batch_data
=
[]
sample
=
out_queue
.
get
()
sample
=
out_queue
.
get
()
...
@@ -127,11 +127,11 @@ def multithread_reader(mapper,
...
@@ -127,11 +127,11 @@ def multithread_reader(mapper,
else
:
else
:
batch_data
.
append
(
sample
)
batch_data
.
append
(
sample
)
if
len
(
batch_data
)
==
batch_size
:
if
len
(
batch_data
)
==
batch_size
:
batch_data
=
generate_minibatch
(
batch_data
)
batch_data
=
generate_minibatch
(
batch_data
,
mapper
=
mapper
)
yield
batch_data
yield
batch_data
batch_data
=
[]
batch_data
=
[]
if
not
drop_last
and
len
(
batch_data
)
!=
0
:
if
not
drop_last
and
len
(
batch_data
)
!=
0
:
batch_data
=
generate_minibatch
(
batch_data
)
batch_data
=
generate_minibatch
(
batch_data
,
mapper
=
mapper
)
yield
batch_data
yield
batch_data
batch_data
=
[]
batch_data
=
[]
...
@@ -188,18 +188,21 @@ def multiprocess_reader(mapper,
...
@@ -188,18 +188,21 @@ def multiprocess_reader(mapper,
else
:
else
:
batch_data
.
append
(
sample
)
batch_data
.
append
(
sample
)
if
len
(
batch_data
)
==
batch_size
:
if
len
(
batch_data
)
==
batch_size
:
batch_data
=
generate_minibatch
(
batch_data
)
batch_data
=
generate_minibatch
(
batch_data
,
mapper
=
mapper
)
yield
batch_data
yield
batch_data
batch_data
=
[]
batch_data
=
[]
if
len
(
batch_data
)
!=
0
and
not
drop_last
:
if
len
(
batch_data
)
!=
0
and
not
drop_last
:
batch_data
=
generate_minibatch
(
batch_data
)
batch_data
=
generate_minibatch
(
batch_data
,
mapper
=
mapper
)
yield
batch_data
yield
batch_data
batch_data
=
[]
batch_data
=
[]
return
queue_reader
return
queue_reader
def
generate_minibatch
(
batch_data
,
label_padding_value
=
255
):
def
generate_minibatch
(
batch_data
,
label_padding_value
=
255
,
mapper
=
None
):
if
mapper
is
not
None
and
mapper
.
batch_transforms
is
not
None
:
for
op
in
mapper
.
batch_transforms
:
batch_data
=
op
(
batch_data
)
# if batch_size is 1, do not pad the image
# if batch_size is 1, do not pad the image
if
len
(
batch_data
)
==
1
:
if
len
(
batch_data
)
==
1
:
return
batch_data
return
batch_data
...
@@ -218,14 +221,13 @@ def generate_minibatch(batch_data, label_padding_value=255):
...
@@ -218,14 +221,13 @@ def generate_minibatch(batch_data, label_padding_value=255):
(
im_c
,
max_shape
[
1
],
max_shape
[
2
]),
dtype
=
np
.
float32
)
(
im_c
,
max_shape
[
1
],
max_shape
[
2
]),
dtype
=
np
.
float32
)
padding_im
[:,
:
im_h
,
:
im_w
]
=
data
[
0
]
padding_im
[:,
:
im_h
,
:
im_w
]
=
data
[
0
]
if
len
(
data
)
>
2
:
if
len
(
data
)
>
2
:
# padding the image, label and insert 'padding' into `im_info` of segmentation during evaluating phase.
# padding the image, label and insert 'padding' into `im_info` of segmentation during evaluating phase.
if
len
(
data
[
1
])
==
0
or
'padding'
not
in
[
if
len
(
data
[
1
])
==
0
or
'padding'
not
in
[
data
[
1
][
i
][
0
]
for
i
in
range
(
len
(
data
[
1
]))
data
[
1
][
i
][
0
]
for
i
in
range
(
len
(
data
[
1
]))
]:
]:
data
[
1
].
append
((
'padding'
,
[
im_h
,
im_w
]))
data
[
1
].
append
((
'padding'
,
[
im_h
,
im_w
]))
padding_batch
.
append
((
padding_im
,
data
[
1
],
data
[
2
]))
padding_batch
.
append
((
padding_im
,
data
[
1
],
data
[
2
]))
elif
len
(
data
)
>
1
:
elif
len
(
data
)
>
1
:
if
isinstance
(
data
[
1
],
np
.
ndarray
)
and
len
(
data
[
1
].
shape
)
>
1
:
if
isinstance
(
data
[
1
],
np
.
ndarray
)
and
len
(
data
[
1
].
shape
)
>
1
:
# padding the image and label of segmentation during the training
# padding the image and label of segmentation during the training
...
...
paddlex/cv/models/__init__.py
浏览文件 @
f8085469
...
@@ -38,6 +38,7 @@ from .classifier import HRNet_W18
...
@@ -38,6 +38,7 @@ from .classifier import HRNet_W18
from
.classifier
import
AlexNet
from
.classifier
import
AlexNet
from
.base
import
BaseAPI
from
.base
import
BaseAPI
from
.yolo_v3
import
YOLOv3
from
.yolo_v3
import
YOLOv3
from
.ppyolo
import
PPYOLO
from
.faster_rcnn
import
FasterRCNN
from
.faster_rcnn
import
FasterRCNN
from
.mask_rcnn
import
MaskRCNN
from
.mask_rcnn
import
MaskRCNN
from
.unet
import
UNet
from
.unet
import
UNet
...
...
paddlex/cv/models/base.py
浏览文件 @
f8085469
...
@@ -246,8 +246,8 @@ class BaseAPI:
...
@@ -246,8 +246,8 @@ class BaseAPI:
logging
.
info
(
logging
.
info
(
"Load pretrain weights from {}."
.
format
(
pretrain_weights
),
"Load pretrain weights from {}."
.
format
(
pretrain_weights
),
use_color
=
True
)
use_color
=
True
)
paddlex
.
utils
.
utils
.
load_pretrain_weights
(
self
.
exe
,
self
.
train_prog
,
paddlex
.
utils
.
utils
.
load_pretrain_weights
(
pretrain_weights
,
fuse_bn
)
self
.
exe
,
self
.
train_prog
,
pretrain_weights
,
fuse_bn
)
# 进行裁剪
# 进行裁剪
if
sensitivities_file
is
not
None
:
if
sensitivities_file
is
not
None
:
import
paddleslim
import
paddleslim
...
@@ -351,7 +351,9 @@ class BaseAPI:
...
@@ -351,7 +351,9 @@ class BaseAPI:
logging
.
info
(
"Model saved in {}."
.
format
(
save_dir
))
logging
.
info
(
"Model saved in {}."
.
format
(
save_dir
))
def
export_inference_model
(
self
,
save_dir
):
def
export_inference_model
(
self
,
save_dir
):
test_input_names
=
[
var
.
name
for
var
in
list
(
self
.
test_inputs
.
values
())]
test_input_names
=
[
var
.
name
for
var
in
list
(
self
.
test_inputs
.
values
())
]
test_outputs
=
list
(
self
.
test_outputs
.
values
())
test_outputs
=
list
(
self
.
test_outputs
.
values
())
with
fluid
.
scope_guard
(
self
.
scope
):
with
fluid
.
scope_guard
(
self
.
scope
):
if
self
.
__class__
.
__name__
==
'MaskRCNN'
:
if
self
.
__class__
.
__name__
==
'MaskRCNN'
:
...
@@ -389,7 +391,8 @@ class BaseAPI:
...
@@ -389,7 +391,8 @@ class BaseAPI:
# 模型保存成功的标志
# 模型保存成功的标志
open
(
osp
.
join
(
save_dir
,
'.success'
),
'w'
).
close
()
open
(
osp
.
join
(
save_dir
,
'.success'
),
'w'
).
close
()
logging
.
info
(
"Model for inference deploy saved in {}."
.
format
(
save_dir
))
logging
.
info
(
"Model for inference deploy saved in {}."
.
format
(
save_dir
))
def
train_loop
(
self
,
def
train_loop
(
self
,
num_epochs
,
num_epochs
,
...
@@ -516,11 +519,13 @@ class BaseAPI:
...
@@ -516,11 +519,13 @@ class BaseAPI:
eta
=
((
num_epochs
-
i
)
*
total_num_steps
-
step
-
1
eta
=
((
num_epochs
-
i
)
*
total_num_steps
-
step
-
1
)
*
avg_step_time
)
*
avg_step_time
if
time_eval_one_epoch
is
not
None
:
if
time_eval_one_epoch
is
not
None
:
eval_eta
=
(
total_eval_times
-
i
//
save_interval_epochs
eval_eta
=
(
)
*
time_eval_one_epoch
total_eval_times
-
i
//
save_interval_epochs
)
*
time_eval_one_epoch
else
:
else
:
eval_eta
=
(
total_eval_times
-
i
//
save_interval_epochs
eval_eta
=
(
)
*
total_num_steps_eval
*
avg_step_time
total_eval_times
-
i
//
save_interval_epochs
)
*
total_num_steps_eval
*
avg_step_time
eta_str
=
seconds_to_hms
(
eta
+
eval_eta
)
eta_str
=
seconds_to_hms
(
eta
+
eval_eta
)
logging
.
info
(
logging
.
info
(
...
@@ -543,6 +548,8 @@ class BaseAPI:
...
@@ -543,6 +548,8 @@ class BaseAPI:
current_save_dir
=
osp
.
join
(
save_dir
,
"epoch_{}"
.
format
(
i
+
1
))
current_save_dir
=
osp
.
join
(
save_dir
,
"epoch_{}"
.
format
(
i
+
1
))
if
not
osp
.
isdir
(
current_save_dir
):
if
not
osp
.
isdir
(
current_save_dir
):
os
.
makedirs
(
current_save_dir
)
os
.
makedirs
(
current_save_dir
)
if
hasattr
(
self
,
'use_ema'
):
self
.
exe
.
run
(
self
.
ema
.
apply_program
)
if
eval_dataset
is
not
None
and
eval_dataset
.
num_samples
>
0
:
if
eval_dataset
is
not
None
and
eval_dataset
.
num_samples
>
0
:
self
.
eval_metrics
,
self
.
eval_details
=
self
.
evaluate
(
self
.
eval_metrics
,
self
.
eval_details
=
self
.
evaluate
(
eval_dataset
=
eval_dataset
,
eval_dataset
=
eval_dataset
,
...
@@ -569,6 +576,8 @@ class BaseAPI:
...
@@ -569,6 +576,8 @@ class BaseAPI:
log_writer
.
add_scalar
(
log_writer
.
add_scalar
(
"Metrics/Eval(Epoch): {}"
.
format
(
k
),
v
,
i
+
1
)
"Metrics/Eval(Epoch): {}"
.
format
(
k
),
v
,
i
+
1
)
self
.
save_model
(
save_dir
=
current_save_dir
)
self
.
save_model
(
save_dir
=
current_save_dir
)
if
hasattr
(
self
,
'use_ema'
):
self
.
exe
.
run
(
self
.
ema
.
restore_program
)
time_eval_one_epoch
=
time
.
time
()
-
eval_epoch_start_time
time_eval_one_epoch
=
time
.
time
()
-
eval_epoch_start_time
eval_epoch_start_time
=
time
.
time
()
eval_epoch_start_time
=
time
.
time
()
if
best_model_epoch
>
0
:
if
best_model_epoch
>
0
:
...
...
paddlex/cv/models/ppyolo.py
0 → 100644
浏览文件 @
f8085469
此差异已折叠。
点击以展开。
paddlex/cv/models/yolo_v3.py
浏览文件 @
f8085469
...
@@ -15,21 +15,11 @@
...
@@ -15,21 +15,11 @@
from
__future__
import
absolute_import
from
__future__
import
absolute_import
import
math
import
math
import
tqdm
import
tqdm
import
os.path
as
osp
import
numpy
as
np
from
multiprocessing.pool
import
ThreadPool
import
paddle.fluid
as
fluid
import
paddlex.utils.logging
as
logging
import
paddlex
import
paddlex
import
copy
from
.ppyolo
import
PPYOLO
from
paddlex.cv.transforms
import
arrange_transforms
from
paddlex.cv.datasets
import
generate_minibatch
from
.base
import
BaseAPI
from
collections
import
OrderedDict
from
.utils.detection_eval
import
eval_results
,
bbox2out
class
YOLOv3
(
BaseAPI
):
class
YOLOv3
(
PPYOLO
):
"""构建YOLOv3,并实现其训练、评估、预测和模型导出。
"""构建YOLOv3,并实现其训练、评估、预测和模型导出。
Args:
Args:
...
@@ -65,12 +55,12 @@ class YOLOv3(BaseAPI):
...
@@ -65,12 +55,12 @@ class YOLOv3(BaseAPI):
320
,
352
,
384
,
416
,
448
,
480
,
512
,
544
,
576
,
608
320
,
352
,
384
,
416
,
448
,
480
,
512
,
544
,
576
,
608
]):
]):
self
.
init_params
=
locals
()
self
.
init_params
=
locals
()
super
(
YOLOv3
,
self
).
__init__
(
'detector'
)
backbones
=
[
backbones
=
[
'DarkNet53'
,
'ResNet34'
,
'MobileNetV1'
,
'MobileNetV3_large'
'DarkNet53'
,
'ResNet34'
,
'MobileNetV1'
,
'MobileNetV3_large'
]
]
assert
backbone
in
backbones
,
"backbone should be one of {}"
.
format
(
assert
backbone
in
backbones
,
"backbone should be one of {}"
.
format
(
backbones
)
backbones
)
super
(
YOLOv3
,
self
).
__init__
(
'detector'
)
self
.
backbone
=
backbone
self
.
backbone
=
backbone
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
self
.
anchors
=
anchors
self
.
anchors
=
anchors
...
@@ -84,6 +74,16 @@ class YOLOv3(BaseAPI):
...
@@ -84,6 +74,16 @@ class YOLOv3(BaseAPI):
self
.
sync_bn
=
True
self
.
sync_bn
=
True
self
.
train_random_shapes
=
train_random_shapes
self
.
train_random_shapes
=
train_random_shapes
self
.
fixed_input_shape
=
None
self
.
fixed_input_shape
=
None
self
.
use_fine_grained_loss
=
False
self
.
use_coord_conv
=
False
self
.
use_iou_aware
=
False
self
.
use_spp
=
False
self
.
use_drop_block
=
False
self
.
use_iou_loss
=
False
self
.
scale_x_y
=
1.
self
.
use_matrix_nms
=
False
self
.
use_ema
=
False
self
.
with_dcn_v2
=
False
def
_get_backbone
(
self
,
backbone_name
):
def
_get_backbone
(
self
,
backbone_name
):
if
backbone_name
==
'DarkNet53'
:
if
backbone_name
==
'DarkNet53'
:
...
@@ -104,59 +104,6 @@ class YOLOv3(BaseAPI):
...
@@ -104,59 +104,6 @@ class YOLOv3(BaseAPI):
norm_type
=
'sync_bn'
,
model_name
=
model_name
)
norm_type
=
'sync_bn'
,
model_name
=
model_name
)
return
backbone
return
backbone
def
build_net
(
self
,
mode
=
'train'
):
model
=
paddlex
.
cv
.
nets
.
detection
.
YOLOv3
(
backbone
=
self
.
_get_backbone
(
self
.
backbone
),
num_classes
=
self
.
num_classes
,
mode
=
mode
,
anchors
=
self
.
anchors
,
anchor_masks
=
self
.
anchor_masks
,
ignore_threshold
=
self
.
ignore_threshold
,
label_smooth
=
self
.
label_smooth
,
nms_score_threshold
=
self
.
nms_score_threshold
,
nms_topk
=
self
.
nms_topk
,
nms_keep_topk
=
self
.
nms_keep_topk
,
nms_iou_threshold
=
self
.
nms_iou_threshold
,
train_random_shapes
=
self
.
train_random_shapes
,
fixed_input_shape
=
self
.
fixed_input_shape
)
inputs
=
model
.
generate_inputs
()
model_out
=
model
.
build_net
(
inputs
)
outputs
=
OrderedDict
([(
'bbox'
,
model_out
)])
if
mode
==
'train'
:
self
.
optimizer
.
minimize
(
model_out
)
outputs
=
OrderedDict
([(
'loss'
,
model_out
)])
return
inputs
,
outputs
def
default_optimizer
(
self
,
learning_rate
,
warmup_steps
,
warmup_start_lr
,
lr_decay_epochs
,
lr_decay_gamma
,
num_steps_each_epoch
):
if
warmup_steps
>
lr_decay_epochs
[
0
]
*
num_steps_each_epoch
:
logging
.
error
(
"In function train(), parameters should satisfy: warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset"
,
exit
=
False
)
logging
.
error
(
"See this doc for more information: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/appendix/parameters.md#notice"
,
exit
=
False
)
logging
.
error
(
"warmup_steps should less than {} or lr_decay_epochs[0] greater than {}, please modify 'lr_decay_epochs' or 'warmup_steps' in train function"
.
format
(
lr_decay_epochs
[
0
]
*
num_steps_each_epoch
,
warmup_steps
//
num_steps_each_epoch
))
boundaries
=
[
b
*
num_steps_each_epoch
for
b
in
lr_decay_epochs
]
values
=
[(
lr_decay_gamma
**
i
)
*
learning_rate
for
i
in
range
(
len
(
lr_decay_epochs
)
+
1
)]
lr_decay
=
fluid
.
layers
.
piecewise_decay
(
boundaries
=
boundaries
,
values
=
values
)
lr_warmup
=
fluid
.
layers
.
linear_lr_warmup
(
learning_rate
=
lr_decay
,
warmup_steps
=
warmup_steps
,
start_lr
=
warmup_start_lr
,
end_lr
=
learning_rate
)
optimizer
=
fluid
.
optimizer
.
Momentum
(
learning_rate
=
lr_warmup
,
momentum
=
0.9
,
regularization
=
fluid
.
regularizer
.
L2DecayRegularizer
(
5e-04
))
return
optimizer
def
train
(
self
,
def
train
(
self
,
num_epochs
,
num_epochs
,
train_dataset
,
train_dataset
,
...
@@ -214,259 +161,11 @@ class YOLOv3(BaseAPI):
...
@@ -214,259 +161,11 @@ class YOLOv3(BaseAPI):
ValueError: 评估类型不在指定列表中。
ValueError: 评估类型不在指定列表中。
ValueError: 模型从inference model进行加载。
ValueError: 模型从inference model进行加载。
"""
"""
if
not
self
.
trainable
:
raise
ValueError
(
"Model is not trainable from load_model method."
)
if
metric
is
None
:
if
isinstance
(
train_dataset
,
paddlex
.
datasets
.
CocoDetection
):
metric
=
'COCO'
elif
isinstance
(
train_dataset
,
paddlex
.
datasets
.
VOCDetection
)
or
\
isinstance
(
train_dataset
,
paddlex
.
datasets
.
EasyDataDet
):
metric
=
'VOC'
else
:
raise
ValueError
(
"train_dataset should be datasets.VOCDetection or datasets.COCODetection or datasets.EasyDataDet."
)
assert
metric
in
[
'COCO'
,
'VOC'
],
"Metric only support 'VOC' or 'COCO'"
self
.
metric
=
metric
self
.
labels
=
train_dataset
.
labels
# 构建训练网络
if
optimizer
is
None
:
# 构建默认的优化策略
num_steps_each_epoch
=
train_dataset
.
num_samples
//
train_batch_size
optimizer
=
self
.
default_optimizer
(
learning_rate
=
learning_rate
,
warmup_steps
=
warmup_steps
,
warmup_start_lr
=
warmup_start_lr
,
lr_decay_epochs
=
lr_decay_epochs
,
lr_decay_gamma
=
lr_decay_gamma
,
num_steps_each_epoch
=
num_steps_each_epoch
)
self
.
optimizer
=
optimizer
# 构建训练、验证、预测网络
self
.
build_program
()
# 初始化网络权重
self
.
net_initialize
(
startup_prog
=
fluid
.
default_startup_program
(),
pretrain_weights
=
pretrain_weights
,
save_dir
=
save_dir
,
sensitivities_file
=
sensitivities_file
,
eval_metric_loss
=
eval_metric_loss
,
resume_checkpoint
=
resume_checkpoint
)
# 训练
self
.
train_loop
(
num_epochs
=
num_epochs
,
train_dataset
=
train_dataset
,
train_batch_size
=
train_batch_size
,
eval_dataset
=
eval_dataset
,
save_interval_epochs
=
save_interval_epochs
,
log_interval_steps
=
log_interval_steps
,
save_dir
=
save_dir
,
use_vdl
=
use_vdl
,
early_stop
=
early_stop
,
early_stop_patience
=
early_stop_patience
)
def
evaluate
(
self
,
eval_dataset
,
batch_size
=
1
,
epoch_id
=
None
,
metric
=
None
,
return_details
=
False
):
"""评估。
Args:
eval_dataset (paddlex.datasets): 验证数据读取器。
batch_size (int): 验证数据批大小。默认为1。
epoch_id (int): 当前评估模型所在的训练轮数。
metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认为None,
根据用户传入的Dataset自动选择,如为VOCDetection,则metric为'VOC';
如为COCODetection,则metric为'COCO'。
return_details (bool): 是否返回详细信息。
Returns:
tuple (metrics, eval_details) | dict (metrics): 当return_details为True时,返回(metrics, eval_details),
当return_details为False时,返回metrics。metrics为dict,包含关键字:'bbox_mmap'或者’bbox_map‘,
分别表示平均准确率平均值在各个IoU阈值下的结果取平均值的结果(mmAP)、平均准确率平均值(mAP)。
eval_details为dict,包含关键字:'bbox',对应元素预测结果列表,每个预测结果由图像id、
预测框类别id、预测框坐标、预测框得分;’gt‘:真实标注框相关信息。
"""
arrange_transforms
(
model_type
=
self
.
model_type
,
class_name
=
self
.
__class__
.
__name__
,
transforms
=
eval_dataset
.
transforms
,
mode
=
'eval'
)
if
metric
is
None
:
if
hasattr
(
self
,
'metric'
)
and
self
.
metric
is
not
None
:
metric
=
self
.
metric
else
:
if
isinstance
(
eval_dataset
,
paddlex
.
datasets
.
CocoDetection
):
metric
=
'COCO'
elif
isinstance
(
eval_dataset
,
paddlex
.
datasets
.
VOCDetection
):
metric
=
'VOC'
else
:
raise
Exception
(
"eval_dataset should be datasets.VOCDetection or datasets.COCODetection."
)
assert
metric
in
[
'COCO'
,
'VOC'
],
"Metric only support 'VOC' or 'COCO'"
total_steps
=
math
.
ceil
(
eval_dataset
.
num_samples
*
1.0
/
batch_size
)
results
=
list
()
data_generator
=
eval_dataset
.
generator
(
batch_size
=
batch_size
,
drop_last
=
False
)
logging
.
info
(
"Start to evaluating(total_samples={}, total_steps={})..."
.
format
(
eval_dataset
.
num_samples
,
total_steps
))
for
step
,
data
in
tqdm
.
tqdm
(
enumerate
(
data_generator
()),
total
=
total_steps
):
images
=
np
.
array
([
d
[
0
]
for
d
in
data
])
im_sizes
=
np
.
array
([
d
[
1
]
for
d
in
data
])
feed_data
=
{
'image'
:
images
,
'im_size'
:
im_sizes
}
with
fluid
.
scope_guard
(
self
.
scope
):
outputs
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
[
feed_data
],
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
return_numpy
=
False
)
res
=
{
'bbox'
:
(
np
.
array
(
outputs
[
0
]),
outputs
[
0
].
recursive_sequence_lengths
())
}
res_id
=
[
np
.
array
([
d
[
2
]])
for
d
in
data
]
res
[
'im_id'
]
=
(
res_id
,
[])
if
metric
==
'VOC'
:
res_gt_box
=
[
d
[
3
].
reshape
(
-
1
,
4
)
for
d
in
data
]
res_gt_label
=
[
d
[
4
].
reshape
(
-
1
,
1
)
for
d
in
data
]
res_is_difficult
=
[
d
[
5
].
reshape
(
-
1
,
1
)
for
d
in
data
]
res_id
=
[
np
.
array
([
d
[
2
]])
for
d
in
data
]
res
[
'gt_box'
]
=
(
res_gt_box
,
[])
res
[
'gt_label'
]
=
(
res_gt_label
,
[])
res
[
'is_difficult'
]
=
(
res_is_difficult
,
[])
results
.
append
(
res
)
logging
.
debug
(
"[EVAL] Epoch={}, Step={}/{}"
.
format
(
epoch_id
,
step
+
1
,
total_steps
))
box_ap_stats
,
eval_details
=
eval_results
(
results
,
metric
,
eval_dataset
.
coco_gt
,
with_background
=
False
)
evaluate_metrics
=
OrderedDict
(
zip
([
'bbox_mmap'
if
metric
==
'COCO'
else
'bbox_map'
],
box_ap_stats
))
if
return_details
:
return
evaluate_metrics
,
eval_details
return
evaluate_metrics
@
staticmethod
def
_preprocess
(
images
,
transforms
,
model_type
,
class_name
,
thread_num
=
1
):
arrange_transforms
(
model_type
=
model_type
,
class_name
=
class_name
,
transforms
=
transforms
,
mode
=
'test'
)
pool
=
ThreadPool
(
thread_num
)
batch_data
=
pool
.
map
(
transforms
,
images
)
pool
.
close
()
pool
.
join
()
padding_batch
=
generate_minibatch
(
batch_data
)
im
=
np
.
array
(
[
data
[
0
]
for
data
in
padding_batch
],
dtype
=
padding_batch
[
0
][
0
].
dtype
)
im_size
=
np
.
array
([
data
[
1
]
for
data
in
padding_batch
],
dtype
=
np
.
int32
)
return
im
,
im_size
@
staticmethod
def
_postprocess
(
res
,
batch_size
,
num_classes
,
labels
):
clsid2catid
=
dict
({
i
:
i
for
i
in
range
(
num_classes
)})
xywh_results
=
bbox2out
([
res
],
clsid2catid
)
preds
=
[[]
for
i
in
range
(
batch_size
)]
for
xywh_res
in
xywh_results
:
image_id
=
xywh_res
[
'image_id'
]
del
xywh_res
[
'image_id'
]
xywh_res
[
'category'
]
=
labels
[
xywh_res
[
'category_id'
]]
preds
[
image_id
].
append
(
xywh_res
)
return
preds
def
predict
(
self
,
img_file
,
transforms
=
None
):
"""预测。
Args:
img_file (str|np.ndarray): 预测图像路径,或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组。
transforms (paddlex.det.transforms): 数据预处理操作。
Returns:
list: 预测结果列表,每个预测结果由预测框类别标签、
预测框类别名称、预测框坐标(坐标格式为[xmin, ymin, w, h])、
预测框得分组成。
"""
if
transforms
is
None
and
not
hasattr
(
self
,
'test_transforms'
):
raise
Exception
(
"transforms need to be defined, now is None."
)
if
isinstance
(
img_file
,
(
str
,
np
.
ndarray
)):
images
=
[
img_file
]
else
:
raise
Exception
(
"img_file must be str/np.ndarray"
)
if
transforms
is
None
:
transforms
=
self
.
test_transforms
im
,
im_size
=
YOLOv3
.
_preprocess
(
images
,
transforms
,
self
.
model_type
,
self
.
__class__
.
__name__
)
with
fluid
.
scope_guard
(
self
.
scope
):
result
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
{
'image'
:
im
,
'im_size'
:
im_size
},
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
return_numpy
=
False
,
use_program_cache
=
True
)
res
=
{
k
:
(
np
.
array
(
v
),
v
.
recursive_sequence_lengths
())
for
k
,
v
in
zip
(
list
(
self
.
test_outputs
.
keys
()),
result
)
}
res
[
'im_id'
]
=
(
np
.
array
(
[[
i
]
for
i
in
range
(
len
(
images
))]).
astype
(
'int32'
),
[[]])
preds
=
YOLOv3
.
_postprocess
(
res
,
len
(
images
),
self
.
num_classes
,
self
.
labels
)
return
preds
[
0
]
def
batch_predict
(
self
,
img_file_list
,
transforms
=
None
,
thread_num
=
2
):
"""预测。
Args:
img_file_list (list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径,也可以是解码后的排列格式为(H,W,C)
且类型为float32且为BGR格式的数组。
transforms (paddlex.det.transforms): 数据预处理操作。
thread_num (int): 并发执行各图像预处理时的线程数。
Returns:
list: 每个元素都为列表,表示各图像的预测结果。在各图像的预测结果列表中,每个预测结果由预测框类别标签、
预测框类别名称、预测框坐标(坐标格式为[xmin, ymin, w, h])、
预测框得分组成。
"""
if
transforms
is
None
and
not
hasattr
(
self
,
'test_transforms'
):
raise
Exception
(
"transforms need to be defined, now is None."
)
if
not
isinstance
(
img_file_list
,
(
list
,
tuple
)):
raise
Exception
(
"im_file must be list/tuple"
)
if
transforms
is
None
:
transforms
=
self
.
test_transforms
im
,
im_size
=
YOLOv3
.
_preprocess
(
img_file_list
,
transforms
,
self
.
model_type
,
self
.
__class__
.
__name__
,
thread_num
)
with
fluid
.
scope_guard
(
self
.
scope
):
result
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
{
'image'
:
im
,
'im_size'
:
im_size
},
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
return_numpy
=
False
,
use_program_cache
=
True
)
res
=
{
return
super
(
YOLOv3
,
self
).
train
(
k
:
(
np
.
array
(
v
),
v
.
recursive_sequence_lengths
())
num_epochs
,
train_dataset
,
train_batch_size
,
eval_dataset
,
for
k
,
v
in
zip
(
list
(
self
.
test_outputs
.
keys
()),
result
)
save_interval_epochs
,
log_interval_steps
,
save_dir
,
}
pretrain_weights
,
optimizer
,
learning_rate
,
warmup_steps
,
res
[
'im_id'
]
=
(
np
.
array
(
warmup_start_lr
,
lr_decay_epochs
,
lr_decay_gamma
,
metric
,
use_vdl
,
[[
i
]
for
i
in
range
(
len
(
img_file_list
))]).
astype
(
'int32'
),
[[]])
sensitivities_file
,
eval_metric_loss
,
early_stop
,
preds
=
YOLOv3
.
_postprocess
(
res
,
early_stop_patience
,
resume_checkpoint
,
False
)
len
(
img_file_list
),
self
.
num_classes
,
self
.
labels
)
return
preds
paddlex/cv/nets/detection/iou_aware.py
0 → 100644
浏览文件 @
f8085469
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
paddle
import
fluid
def
_split_ioup
(
output
,
an_num
,
num_classes
):
"""
Split new output feature map to output, predicted iou
along channel dimension
"""
ioup
=
fluid
.
layers
.
slice
(
output
,
axes
=
[
1
],
starts
=
[
0
],
ends
=
[
an_num
])
ioup
=
fluid
.
layers
.
sigmoid
(
ioup
)
oriout
=
fluid
.
layers
.
slice
(
output
,
axes
=
[
1
],
starts
=
[
an_num
],
ends
=
[
an_num
*
(
num_classes
+
6
)])
return
(
ioup
,
oriout
)
def
_de_sigmoid
(
x
,
eps
=
1e-7
):
x
=
fluid
.
layers
.
clip
(
x
,
eps
,
1
/
eps
)
one
=
fluid
.
layers
.
fill_constant
(
shape
=
[
1
,
1
,
1
,
1
],
dtype
=
x
.
dtype
,
value
=
1.
)
x
=
fluid
.
layers
.
clip
((
one
/
x
-
1.0
),
eps
,
1
/
eps
)
x
=
-
fluid
.
layers
.
log
(
x
)
return
x
def
_postprocess_output
(
ioup
,
output
,
an_num
,
num_classes
,
iou_aware_factor
):
"""
post process output objectness score
"""
tensors
=
[]
stride
=
output
.
shape
[
1
]
//
an_num
for
m
in
range
(
an_num
):
tensors
.
append
(
fluid
.
layers
.
slice
(
output
,
axes
=
[
1
],
starts
=
[
stride
*
m
+
0
],
ends
=
[
stride
*
m
+
4
]))
obj
=
fluid
.
layers
.
slice
(
output
,
axes
=
[
1
],
starts
=
[
stride
*
m
+
4
],
ends
=
[
stride
*
m
+
5
])
obj
=
fluid
.
layers
.
sigmoid
(
obj
)
ip
=
fluid
.
layers
.
slice
(
ioup
,
axes
=
[
1
],
starts
=
[
m
],
ends
=
[
m
+
1
])
new_obj
=
fluid
.
layers
.
pow
(
obj
,
(
1
-
iou_aware_factor
))
*
fluid
.
layers
.
pow
(
ip
,
iou_aware_factor
)
new_obj
=
_de_sigmoid
(
new_obj
)
tensors
.
append
(
new_obj
)
tensors
.
append
(
fluid
.
layers
.
slice
(
output
,
axes
=
[
1
],
starts
=
[
stride
*
m
+
5
],
ends
=
[
stride
*
m
+
5
+
num_classes
]))
output
=
fluid
.
layers
.
concat
(
tensors
,
axis
=
1
)
return
output
def
get_iou_aware_score
(
output
,
an_num
,
num_classes
,
iou_aware_factor
):
ioup
,
output
=
_split_ioup
(
output
,
an_num
,
num_classes
)
output
=
_postprocess_output
(
ioup
,
output
,
an_num
,
num_classes
,
iou_aware_factor
)
return
output
paddlex/cv/nets/detection/loss/iou_aware_loss.py
0 → 100644
浏览文件 @
f8085469
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.initializer
import
NumpyArrayInitializer
from
paddle
import
fluid
from
.iou_loss
import
IouLoss
class
IouAwareLoss
(
IouLoss
):
"""
iou aware loss, see https://arxiv.org/abs/1912.05992
Args:
loss_weight (float): iou aware loss weight, default is 1.0
max_height (int): max height of input to support random shape input
max_width (int): max width of input to support random shape input
"""
def
__init__
(
self
,
loss_weight
=
1.0
,
max_height
=
608
,
max_width
=
608
):
super
(
IouAwareLoss
,
self
).
__init__
(
loss_weight
=
loss_weight
,
max_height
=
max_height
,
max_width
=
max_width
)
def
__call__
(
self
,
ioup
,
x
,
y
,
w
,
h
,
tx
,
ty
,
tw
,
th
,
anchors
,
downsample_ratio
,
batch_size
,
scale_x_y
,
eps
=
1.e-10
):
'''
Args:
ioup ([Variables]): the predicted iou
x | y | w | h ([Variables]): the output of yolov3 for encoded x|y|w|h
tx |ty |tw |th ([Variables]): the target of yolov3 for encoded x|y|w|h
anchors ([float]): list of anchors for current output layer
downsample_ratio (float): the downsample ratio for current output layer
batch_size (int): training batch size
eps (float): the decimal to prevent the denominator eqaul zero
'''
pred
=
self
.
_bbox_transform
(
x
,
y
,
w
,
h
,
anchors
,
downsample_ratio
,
batch_size
,
False
,
scale_x_y
,
eps
)
gt
=
self
.
_bbox_transform
(
tx
,
ty
,
tw
,
th
,
anchors
,
downsample_ratio
,
batch_size
,
True
,
scale_x_y
,
eps
)
iouk
=
self
.
_iou
(
pred
,
gt
,
ioup
,
eps
)
iouk
.
stop_gradient
=
True
loss_iou_aware
=
fluid
.
layers
.
cross_entropy
(
ioup
,
iouk
,
soft_label
=
True
)
loss_iou_aware
=
loss_iou_aware
*
self
.
_loss_weight
return
loss_iou_aware
paddlex/cv/nets/detection/loss/iou_loss.py
0 → 100644
浏览文件 @
f8085469
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.initializer
import
NumpyArrayInitializer
from
paddle
import
fluid
class
IouLoss
(
object
):
"""
iou loss, see https://arxiv.org/abs/1908.03851
loss = 1.0 - iou * iou
Args:
loss_weight (float): iou loss weight, default is 2.5
max_height (int): max height of input to support random shape input
max_width (int): max width of input to support random shape input
ciou_term (bool): whether to add ciou_term
loss_square (bool): whether to square the iou term
"""
def
__init__
(
self
,
loss_weight
=
2.5
,
max_height
=
608
,
max_width
=
608
,
ciou_term
=
False
,
loss_square
=
True
):
self
.
_loss_weight
=
loss_weight
self
.
_MAX_HI
=
max_height
self
.
_MAX_WI
=
max_width
self
.
ciou_term
=
ciou_term
self
.
loss_square
=
loss_square
def
__call__
(
self
,
x
,
y
,
w
,
h
,
tx
,
ty
,
tw
,
th
,
anchors
,
downsample_ratio
,
batch_size
,
scale_x_y
=
1.
,
ioup
=
None
,
eps
=
1.e-10
):
'''
Args:
x | y | w | h ([Variables]): the output of yolov3 for encoded x|y|w|h
tx |ty |tw |th ([Variables]): the target of yolov3 for encoded x|y|w|h
anchors ([float]): list of anchors for current output layer
downsample_ratio (float): the downsample ratio for current output layer
batch_size (int): training batch size
eps (float): the decimal to prevent the denominator eqaul zero
'''
pred
=
self
.
_bbox_transform
(
x
,
y
,
w
,
h
,
anchors
,
downsample_ratio
,
batch_size
,
False
,
scale_x_y
,
eps
)
gt
=
self
.
_bbox_transform
(
tx
,
ty
,
tw
,
th
,
anchors
,
downsample_ratio
,
batch_size
,
True
,
scale_x_y
,
eps
)
iouk
=
self
.
_iou
(
pred
,
gt
,
ioup
,
eps
)
if
self
.
loss_square
:
loss_iou
=
1.
-
iouk
*
iouk
else
:
loss_iou
=
1.
-
iouk
loss_iou
=
loss_iou
*
self
.
_loss_weight
return
loss_iou
def
_iou
(
self
,
pred
,
gt
,
ioup
=
None
,
eps
=
1.e-10
):
x1
,
y1
,
x2
,
y2
=
pred
x1g
,
y1g
,
x2g
,
y2g
=
gt
x2
=
fluid
.
layers
.
elementwise_max
(
x1
,
x2
)
y2
=
fluid
.
layers
.
elementwise_max
(
y1
,
y2
)
xkis1
=
fluid
.
layers
.
elementwise_max
(
x1
,
x1g
)
ykis1
=
fluid
.
layers
.
elementwise_max
(
y1
,
y1g
)
xkis2
=
fluid
.
layers
.
elementwise_min
(
x2
,
x2g
)
ykis2
=
fluid
.
layers
.
elementwise_min
(
y2
,
y2g
)
intsctk
=
(
xkis2
-
xkis1
)
*
(
ykis2
-
ykis1
)
intsctk
=
intsctk
*
fluid
.
layers
.
greater_than
(
xkis2
,
xkis1
)
*
fluid
.
layers
.
greater_than
(
ykis2
,
ykis1
)
unionk
=
(
x2
-
x1
)
*
(
y2
-
y1
)
+
(
x2g
-
x1g
)
*
(
y2g
-
y1g
)
-
intsctk
+
eps
iouk
=
intsctk
/
unionk
if
self
.
ciou_term
:
ciou
=
self
.
get_ciou_term
(
pred
,
gt
,
iouk
,
eps
)
iouk
=
iouk
-
ciou
return
iouk
def
get_ciou_term
(
self
,
pred
,
gt
,
iouk
,
eps
):
x1
,
y1
,
x2
,
y2
=
pred
x1g
,
y1g
,
x2g
,
y2g
=
gt
cx
=
(
x1
+
x2
)
/
2
cy
=
(
y1
+
y2
)
/
2
w
=
(
x2
-
x1
)
+
fluid
.
layers
.
cast
((
x2
-
x1
)
==
0
,
'float32'
)
h
=
(
y2
-
y1
)
+
fluid
.
layers
.
cast
((
y2
-
y1
)
==
0
,
'float32'
)
cxg
=
(
x1g
+
x2g
)
/
2
cyg
=
(
y1g
+
y2g
)
/
2
wg
=
x2g
-
x1g
hg
=
y2g
-
y1g
# A or B
xc1
=
fluid
.
layers
.
elementwise_min
(
x1
,
x1g
)
yc1
=
fluid
.
layers
.
elementwise_min
(
y1
,
y1g
)
xc2
=
fluid
.
layers
.
elementwise_max
(
x2
,
x2g
)
yc2
=
fluid
.
layers
.
elementwise_max
(
y2
,
y2g
)
# DIOU term
dist_intersection
=
(
cx
-
cxg
)
*
(
cx
-
cxg
)
+
(
cy
-
cyg
)
*
(
cy
-
cyg
)
dist_union
=
(
xc2
-
xc1
)
*
(
xc2
-
xc1
)
+
(
yc2
-
yc1
)
*
(
yc2
-
yc1
)
diou_term
=
(
dist_intersection
+
eps
)
/
(
dist_union
+
eps
)
# CIOU term
ciou_term
=
0
ar_gt
=
wg
/
hg
ar_pred
=
w
/
h
arctan
=
fluid
.
layers
.
atan
(
ar_gt
)
-
fluid
.
layers
.
atan
(
ar_pred
)
ar_loss
=
4.
/
np
.
pi
/
np
.
pi
*
arctan
*
arctan
alpha
=
ar_loss
/
(
1
-
iouk
+
ar_loss
+
eps
)
alpha
.
stop_gradient
=
True
ciou_term
=
alpha
*
ar_loss
return
diou_term
+
ciou_term
def
_bbox_transform
(
self
,
dcx
,
dcy
,
dw
,
dh
,
anchors
,
downsample_ratio
,
batch_size
,
is_gt
,
scale_x_y
,
eps
):
grid_x
=
int
(
self
.
_MAX_WI
/
downsample_ratio
)
grid_y
=
int
(
self
.
_MAX_HI
/
downsample_ratio
)
an_num
=
len
(
anchors
)
//
2
shape_fmp
=
fluid
.
layers
.
shape
(
dcx
)
shape_fmp
.
stop_gradient
=
True
# generate the grid_w x grid_h center of feature map
idx_i
=
np
.
array
([[
i
for
i
in
range
(
grid_x
)]])
idx_j
=
np
.
array
([[
j
for
j
in
range
(
grid_y
)]]).
transpose
()
gi_np
=
np
.
repeat
(
idx_i
,
grid_y
,
axis
=
0
)
gi_np
=
np
.
reshape
(
gi_np
,
newshape
=
[
1
,
1
,
grid_y
,
grid_x
])
gi_np
=
np
.
tile
(
gi_np
,
reps
=
[
batch_size
,
an_num
,
1
,
1
])
gj_np
=
np
.
repeat
(
idx_j
,
grid_x
,
axis
=
1
)
gj_np
=
np
.
reshape
(
gj_np
,
newshape
=
[
1
,
1
,
grid_y
,
grid_x
])
gj_np
=
np
.
tile
(
gj_np
,
reps
=
[
batch_size
,
an_num
,
1
,
1
])
gi_max
=
self
.
_create_tensor_from_numpy
(
gi_np
.
astype
(
np
.
float32
))
gi
=
fluid
.
layers
.
crop
(
x
=
gi_max
,
shape
=
dcx
)
gi
.
stop_gradient
=
True
gj_max
=
self
.
_create_tensor_from_numpy
(
gj_np
.
astype
(
np
.
float32
))
gj
=
fluid
.
layers
.
crop
(
x
=
gj_max
,
shape
=
dcx
)
gj
.
stop_gradient
=
True
grid_x_act
=
fluid
.
layers
.
cast
(
shape_fmp
[
3
],
dtype
=
"float32"
)
grid_x_act
.
stop_gradient
=
True
grid_y_act
=
fluid
.
layers
.
cast
(
shape_fmp
[
2
],
dtype
=
"float32"
)
grid_y_act
.
stop_gradient
=
True
if
is_gt
:
cx
=
fluid
.
layers
.
elementwise_add
(
dcx
,
gi
)
/
grid_x_act
cx
.
gradient
=
True
cy
=
fluid
.
layers
.
elementwise_add
(
dcy
,
gj
)
/
grid_y_act
cy
.
gradient
=
True
else
:
dcx_sig
=
fluid
.
layers
.
sigmoid
(
dcx
)
dcy_sig
=
fluid
.
layers
.
sigmoid
(
dcy
)
if
(
abs
(
scale_x_y
-
1.0
)
>
eps
):
dcx_sig
=
scale_x_y
*
dcx_sig
-
0.5
*
(
scale_x_y
-
1
)
dcy_sig
=
scale_x_y
*
dcy_sig
-
0.5
*
(
scale_x_y
-
1
)
cx
=
fluid
.
layers
.
elementwise_add
(
dcx_sig
,
gi
)
/
grid_x_act
cy
=
fluid
.
layers
.
elementwise_add
(
dcy_sig
,
gj
)
/
grid_y_act
anchor_w_
=
[
anchors
[
i
]
for
i
in
range
(
0
,
len
(
anchors
))
if
i
%
2
==
0
]
anchor_w_np
=
np
.
array
(
anchor_w_
)
anchor_w_np
=
np
.
reshape
(
anchor_w_np
,
newshape
=
[
1
,
an_num
,
1
,
1
])
anchor_w_np
=
np
.
tile
(
anchor_w_np
,
reps
=
[
batch_size
,
1
,
grid_y
,
grid_x
])
anchor_w_max
=
self
.
_create_tensor_from_numpy
(
anchor_w_np
.
astype
(
np
.
float32
))
anchor_w
=
fluid
.
layers
.
crop
(
x
=
anchor_w_max
,
shape
=
dcx
)
anchor_w
.
stop_gradient
=
True
anchor_h_
=
[
anchors
[
i
]
for
i
in
range
(
0
,
len
(
anchors
))
if
i
%
2
==
1
]
anchor_h_np
=
np
.
array
(
anchor_h_
)
anchor_h_np
=
np
.
reshape
(
anchor_h_np
,
newshape
=
[
1
,
an_num
,
1
,
1
])
anchor_h_np
=
np
.
tile
(
anchor_h_np
,
reps
=
[
batch_size
,
1
,
grid_y
,
grid_x
])
anchor_h_max
=
self
.
_create_tensor_from_numpy
(
anchor_h_np
.
astype
(
np
.
float32
))
anchor_h
=
fluid
.
layers
.
crop
(
x
=
anchor_h_max
,
shape
=
dcx
)
anchor_h
.
stop_gradient
=
True
# e^tw e^th
exp_dw
=
fluid
.
layers
.
exp
(
dw
)
exp_dh
=
fluid
.
layers
.
exp
(
dh
)
pw
=
fluid
.
layers
.
elementwise_mul
(
exp_dw
,
anchor_w
)
/
\
(
grid_x_act
*
downsample_ratio
)
ph
=
fluid
.
layers
.
elementwise_mul
(
exp_dh
,
anchor_h
)
/
\
(
grid_y_act
*
downsample_ratio
)
if
is_gt
:
exp_dw
.
stop_gradient
=
True
exp_dh
.
stop_gradient
=
True
pw
.
stop_gradient
=
True
ph
.
stop_gradient
=
True
x1
=
cx
-
0.5
*
pw
y1
=
cy
-
0.5
*
ph
x2
=
cx
+
0.5
*
pw
y2
=
cy
+
0.5
*
ph
if
is_gt
:
x1
.
stop_gradient
=
True
y1
.
stop_gradient
=
True
x2
.
stop_gradient
=
True
y2
.
stop_gradient
=
True
return
x1
,
y1
,
x2
,
y2
def
_create_tensor_from_numpy
(
self
,
numpy_array
):
paddle_array
=
fluid
.
layers
.
create_parameter
(
attr
=
ParamAttr
(),
shape
=
numpy_array
.
shape
,
dtype
=
numpy_array
.
dtype
,
default_initializer
=
NumpyArrayInitializer
(
numpy_array
))
paddle_array
.
stop_gradient
=
True
return
paddle_array
paddlex/cv/nets/detection/loss/yolo_loss.py
0 → 100644
浏览文件 @
f8085469
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
paddle
import
fluid
try
:
from
collections.abc
import
Sequence
except
Exception
:
from
collections
import
Sequence
class
YOLOv3Loss
(
object
):
"""
Combined loss for YOLOv3 network
Args:
batch_size (int): training batch size
ignore_thresh (float): threshold to ignore confidence loss
label_smooth (bool): whether to use label smoothing
use_fine_grained_loss (bool): whether use fine grained YOLOv3 loss
instead of fluid.layers.yolov3_loss
"""
def
__init__
(
self
,
batch_size
=
8
,
ignore_thresh
=
0.7
,
label_smooth
=
True
,
use_fine_grained_loss
=
False
,
iou_loss
=
None
,
iou_aware_loss
=
None
,
downsample
=
[
32
,
16
,
8
],
scale_x_y
=
1.
,
match_score
=
False
):
self
.
_batch_size
=
batch_size
self
.
_ignore_thresh
=
ignore_thresh
self
.
_label_smooth
=
label_smooth
self
.
_use_fine_grained_loss
=
use_fine_grained_loss
self
.
_iou_loss
=
iou_loss
self
.
_iou_aware_loss
=
iou_aware_loss
self
.
downsample
=
downsample
self
.
scale_x_y
=
scale_x_y
self
.
match_score
=
match_score
def
__call__
(
self
,
outputs
,
gt_box
,
gt_label
,
gt_score
,
targets
,
anchors
,
anchor_masks
,
mask_anchors
,
num_classes
,
prefix_name
):
if
self
.
_use_fine_grained_loss
:
return
self
.
_get_fine_grained_loss
(
outputs
,
targets
,
gt_box
,
self
.
_batch_size
,
num_classes
,
mask_anchors
,
self
.
_ignore_thresh
)
else
:
losses
=
[]
for
i
,
output
in
enumerate
(
outputs
):
scale_x_y
=
self
.
scale_x_y
if
not
isinstance
(
self
.
scale_x_y
,
Sequence
)
else
self
.
scale_x_y
[
i
]
anchor_mask
=
anchor_masks
[
i
]
loss
=
fluid
.
layers
.
yolov3_loss
(
x
=
output
,
gt_box
=
gt_box
,
gt_label
=
gt_label
,
gt_score
=
gt_score
,
anchors
=
anchors
,
anchor_mask
=
anchor_mask
,
class_num
=
num_classes
,
ignore_thresh
=
self
.
_ignore_thresh
,
downsample_ratio
=
self
.
downsample
[
i
],
use_label_smooth
=
self
.
_label_smooth
,
scale_x_y
=
scale_x_y
,
name
=
prefix_name
+
"yolo_loss"
+
str
(
i
))
losses
.
append
(
fluid
.
layers
.
reduce_mean
(
loss
))
return
{
'loss'
:
sum
(
losses
)}
def
_get_fine_grained_loss
(
self
,
outputs
,
targets
,
gt_box
,
batch_size
,
num_classes
,
mask_anchors
,
ignore_thresh
,
eps
=
1.e-10
):
"""
Calculate fine grained YOLOv3 loss
Args:
outputs ([Variables]): List of Variables, output of backbone stages
targets ([Variables]): List of Variables, The targets for yolo
loss calculatation.
gt_box (Variable): The ground-truth boudding boxes.
batch_size (int): The training batch size
num_classes (int): class num of dataset
mask_anchors ([[float]]): list of anchors in each output layer
ignore_thresh (float): prediction bbox overlap any gt_box greater
than ignore_thresh, objectness loss will
be ignored.
Returns:
Type: dict
xy_loss (Variable): YOLOv3 (x, y) coordinates loss
wh_loss (Variable): YOLOv3 (w, h) coordinates loss
obj_loss (Variable): YOLOv3 objectness score loss
cls_loss (Variable): YOLOv3 classification loss
"""
assert
len
(
outputs
)
==
len
(
targets
),
\
"YOLOv3 output layer number not equal target number"
loss_xys
,
loss_whs
,
loss_objs
,
loss_clss
=
[],
[],
[],
[]
if
self
.
_iou_loss
is
not
None
:
loss_ious
=
[]
if
self
.
_iou_aware_loss
is
not
None
:
loss_iou_awares
=
[]
for
i
,
(
output
,
target
,
anchors
)
in
enumerate
(
zip
(
outputs
,
targets
,
mask_anchors
)):
downsample
=
self
.
downsample
[
i
]
an_num
=
len
(
anchors
)
//
2
if
self
.
_iou_aware_loss
is
not
None
:
ioup
,
output
=
self
.
_split_ioup
(
output
,
an_num
,
num_classes
)
x
,
y
,
w
,
h
,
obj
,
cls
=
self
.
_split_output
(
output
,
an_num
,
num_classes
)
tx
,
ty
,
tw
,
th
,
tscale
,
tobj
,
tcls
=
self
.
_split_target
(
target
)
tscale_tobj
=
tscale
*
tobj
scale_x_y
=
self
.
scale_x_y
if
not
isinstance
(
self
.
scale_x_y
,
Sequence
)
else
self
.
scale_x_y
[
i
]
if
(
abs
(
scale_x_y
-
1.0
)
<
eps
):
loss_x
=
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
x
,
tx
)
*
tscale_tobj
loss_x
=
fluid
.
layers
.
reduce_sum
(
loss_x
,
dim
=
[
1
,
2
,
3
])
loss_y
=
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
y
,
ty
)
*
tscale_tobj
loss_y
=
fluid
.
layers
.
reduce_sum
(
loss_y
,
dim
=
[
1
,
2
,
3
])
else
:
dx
=
scale_x_y
*
fluid
.
layers
.
sigmoid
(
x
)
-
0.5
*
(
scale_x_y
-
1.0
)
dy
=
scale_x_y
*
fluid
.
layers
.
sigmoid
(
y
)
-
0.5
*
(
scale_x_y
-
1.0
)
loss_x
=
fluid
.
layers
.
abs
(
dx
-
tx
)
*
tscale_tobj
loss_x
=
fluid
.
layers
.
reduce_sum
(
loss_x
,
dim
=
[
1
,
2
,
3
])
loss_y
=
fluid
.
layers
.
abs
(
dy
-
ty
)
*
tscale_tobj
loss_y
=
fluid
.
layers
.
reduce_sum
(
loss_y
,
dim
=
[
1
,
2
,
3
])
# NOTE: we refined loss function of (w, h) as L1Loss
loss_w
=
fluid
.
layers
.
abs
(
w
-
tw
)
*
tscale_tobj
loss_w
=
fluid
.
layers
.
reduce_sum
(
loss_w
,
dim
=
[
1
,
2
,
3
])
loss_h
=
fluid
.
layers
.
abs
(
h
-
th
)
*
tscale_tobj
loss_h
=
fluid
.
layers
.
reduce_sum
(
loss_h
,
dim
=
[
1
,
2
,
3
])
if
self
.
_iou_loss
is
not
None
:
loss_iou
=
self
.
_iou_loss
(
x
,
y
,
w
,
h
,
tx
,
ty
,
tw
,
th
,
anchors
,
downsample
,
self
.
_batch_size
,
scale_x_y
)
loss_iou
=
loss_iou
*
tscale_tobj
loss_iou
=
fluid
.
layers
.
reduce_sum
(
loss_iou
,
dim
=
[
1
,
2
,
3
])
loss_ious
.
append
(
fluid
.
layers
.
reduce_mean
(
loss_iou
))
if
self
.
_iou_aware_loss
is
not
None
:
loss_iou_aware
=
self
.
_iou_aware_loss
(
ioup
,
x
,
y
,
w
,
h
,
tx
,
ty
,
tw
,
th
,
anchors
,
downsample
,
self
.
_batch_size
,
scale_x_y
)
loss_iou_aware
=
loss_iou_aware
*
tobj
loss_iou_aware
=
fluid
.
layers
.
reduce_sum
(
loss_iou_aware
,
dim
=
[
1
,
2
,
3
])
loss_iou_awares
.
append
(
fluid
.
layers
.
reduce_mean
(
loss_iou_aware
))
loss_obj_pos
,
loss_obj_neg
=
self
.
_calc_obj_loss
(
output
,
obj
,
tobj
,
gt_box
,
self
.
_batch_size
,
anchors
,
num_classes
,
downsample
,
self
.
_ignore_thresh
,
scale_x_y
)
loss_cls
=
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
cls
,
tcls
)
loss_cls
=
fluid
.
layers
.
elementwise_mul
(
loss_cls
,
tobj
,
axis
=
0
)
loss_cls
=
fluid
.
layers
.
reduce_sum
(
loss_cls
,
dim
=
[
1
,
2
,
3
,
4
])
loss_xys
.
append
(
fluid
.
layers
.
reduce_mean
(
loss_x
+
loss_y
))
loss_whs
.
append
(
fluid
.
layers
.
reduce_mean
(
loss_w
+
loss_h
))
loss_objs
.
append
(
fluid
.
layers
.
reduce_mean
(
loss_obj_pos
+
loss_obj_neg
))
loss_clss
.
append
(
fluid
.
layers
.
reduce_mean
(
loss_cls
))
losses_all
=
{
"loss_xy"
:
fluid
.
layers
.
sum
(
loss_xys
),
"loss_wh"
:
fluid
.
layers
.
sum
(
loss_whs
),
"loss_obj"
:
fluid
.
layers
.
sum
(
loss_objs
),
"loss_cls"
:
fluid
.
layers
.
sum
(
loss_clss
),
}
if
self
.
_iou_loss
is
not
None
:
losses_all
[
"loss_iou"
]
=
fluid
.
layers
.
sum
(
loss_ious
)
if
self
.
_iou_aware_loss
is
not
None
:
losses_all
[
"loss_iou_aware"
]
=
fluid
.
layers
.
sum
(
loss_iou_awares
)
return
losses_all
def
_split_ioup
(
self
,
output
,
an_num
,
num_classes
):
"""
Split output feature map to output, predicted iou
along channel dimension
"""
ioup
=
fluid
.
layers
.
slice
(
output
,
axes
=
[
1
],
starts
=
[
0
],
ends
=
[
an_num
])
ioup
=
fluid
.
layers
.
sigmoid
(
ioup
)
oriout
=
fluid
.
layers
.
slice
(
output
,
axes
=
[
1
],
starts
=
[
an_num
],
ends
=
[
an_num
*
(
num_classes
+
6
)])
return
(
ioup
,
oriout
)
def
_split_output
(
self
,
output
,
an_num
,
num_classes
):
"""
Split output feature map to x, y, w, h, objectness, classification
along channel dimension
"""
x
=
fluid
.
layers
.
strided_slice
(
output
,
axes
=
[
1
],
starts
=
[
0
],
ends
=
[
output
.
shape
[
1
]],
strides
=
[
5
+
num_classes
])
y
=
fluid
.
layers
.
strided_slice
(
output
,
axes
=
[
1
],
starts
=
[
1
],
ends
=
[
output
.
shape
[
1
]],
strides
=
[
5
+
num_classes
])
w
=
fluid
.
layers
.
strided_slice
(
output
,
axes
=
[
1
],
starts
=
[
2
],
ends
=
[
output
.
shape
[
1
]],
strides
=
[
5
+
num_classes
])
h
=
fluid
.
layers
.
strided_slice
(
output
,
axes
=
[
1
],
starts
=
[
3
],
ends
=
[
output
.
shape
[
1
]],
strides
=
[
5
+
num_classes
])
obj
=
fluid
.
layers
.
strided_slice
(
output
,
axes
=
[
1
],
starts
=
[
4
],
ends
=
[
output
.
shape
[
1
]],
strides
=
[
5
+
num_classes
])
clss
=
[]
stride
=
output
.
shape
[
1
]
//
an_num
for
m
in
range
(
an_num
):
clss
.
append
(
fluid
.
layers
.
slice
(
output
,
axes
=
[
1
],
starts
=
[
stride
*
m
+
5
],
ends
=
[
stride
*
m
+
5
+
num_classes
]))
cls
=
fluid
.
layers
.
transpose
(
fluid
.
layers
.
stack
(
clss
,
axis
=
1
),
perm
=
[
0
,
1
,
3
,
4
,
2
])
return
(
x
,
y
,
w
,
h
,
obj
,
cls
)
def
_split_target
(
self
,
target
):
"""
split target to x, y, w, h, objectness, classification
along dimension 2
target is in shape [N, an_num, 6 + class_num, H, W]
"""
tx
=
target
[:,
:,
0
,
:,
:]
ty
=
target
[:,
:,
1
,
:,
:]
tw
=
target
[:,
:,
2
,
:,
:]
th
=
target
[:,
:,
3
,
:,
:]
tscale
=
target
[:,
:,
4
,
:,
:]
tobj
=
target
[:,
:,
5
,
:,
:]
tcls
=
fluid
.
layers
.
transpose
(
target
[:,
:,
6
:,
:,
:],
perm
=
[
0
,
1
,
3
,
4
,
2
])
tcls
.
stop_gradient
=
True
return
(
tx
,
ty
,
tw
,
th
,
tscale
,
tobj
,
tcls
)
def
_calc_obj_loss
(
self
,
output
,
obj
,
tobj
,
gt_box
,
batch_size
,
anchors
,
num_classes
,
downsample
,
ignore_thresh
,
scale_x_y
):
# A prediction bbox overlap any gt_bbox over ignore_thresh,
# objectness loss will be ignored, process as follows:
# 1. get pred bbox, which is same with YOLOv3 infer mode, use yolo_box here
# NOTE: img_size is set as 1.0 to get noramlized pred bbox
bbox
,
prob
=
fluid
.
layers
.
yolo_box
(
x
=
output
,
img_size
=
fluid
.
layers
.
ones
(
shape
=
[
batch_size
,
2
],
dtype
=
"int32"
),
anchors
=
anchors
,
class_num
=
num_classes
,
conf_thresh
=
0.
,
downsample_ratio
=
downsample
,
clip_bbox
=
False
,
scale_x_y
=
scale_x_y
)
# 2. split pred bbox and gt bbox by sample, calculate IoU between pred bbox
# and gt bbox in each sample
if
batch_size
>
1
:
preds
=
fluid
.
layers
.
split
(
bbox
,
batch_size
,
dim
=
0
)
gts
=
fluid
.
layers
.
split
(
gt_box
,
batch_size
,
dim
=
0
)
else
:
preds
=
[
bbox
]
gts
=
[
gt_box
]
probs
=
[
prob
]
ious
=
[]
for
pred
,
gt
in
zip
(
preds
,
gts
):
def
box_xywh2xyxy
(
box
):
x
=
box
[:,
0
]
y
=
box
[:,
1
]
w
=
box
[:,
2
]
h
=
box
[:,
3
]
return
fluid
.
layers
.
stack
(
[
x
-
w
/
2.
,
y
-
h
/
2.
,
x
+
w
/
2.
,
y
+
h
/
2.
,
],
axis
=
1
)
pred
=
fluid
.
layers
.
squeeze
(
pred
,
axes
=
[
0
])
gt
=
box_xywh2xyxy
(
fluid
.
layers
.
squeeze
(
gt
,
axes
=
[
0
]))
ious
.
append
(
fluid
.
layers
.
iou_similarity
(
pred
,
gt
))
iou
=
fluid
.
layers
.
stack
(
ious
,
axis
=
0
)
# 3. Get iou_mask by IoU between gt bbox and prediction bbox,
# Get obj_mask by tobj(holds gt_score), calculate objectness loss
max_iou
=
fluid
.
layers
.
reduce_max
(
iou
,
dim
=-
1
)
iou_mask
=
fluid
.
layers
.
cast
(
max_iou
<=
ignore_thresh
,
dtype
=
"float32"
)
if
self
.
match_score
:
max_prob
=
fluid
.
layers
.
reduce_max
(
prob
,
dim
=-
1
)
iou_mask
=
iou_mask
*
fluid
.
layers
.
cast
(
max_prob
<=
0.25
,
dtype
=
"float32"
)
output_shape
=
fluid
.
layers
.
shape
(
output
)
an_num
=
len
(
anchors
)
//
2
iou_mask
=
fluid
.
layers
.
reshape
(
iou_mask
,
(
-
1
,
an_num
,
output_shape
[
2
],
output_shape
[
3
]))
iou_mask
.
stop_gradient
=
True
# NOTE: tobj holds gt_score, obj_mask holds object existence mask
obj_mask
=
fluid
.
layers
.
cast
(
tobj
>
0.
,
dtype
=
"float32"
)
obj_mask
.
stop_gradient
=
True
# For positive objectness grids, objectness loss should be calculated
# For negative objectness grids, objectness loss is calculated only iou_mask == 1.0
loss_obj
=
fluid
.
layers
.
sigmoid_cross_entropy_with_logits
(
obj
,
obj_mask
)
loss_obj_pos
=
fluid
.
layers
.
reduce_sum
(
loss_obj
*
tobj
,
dim
=
[
1
,
2
,
3
])
loss_obj_neg
=
fluid
.
layers
.
reduce_sum
(
loss_obj
*
(
1.0
-
obj_mask
)
*
iou_mask
,
dim
=
[
1
,
2
,
3
])
return
loss_obj_pos
,
loss_obj_neg
paddlex/cv/nets/detection/ops.py
0 → 100644
浏览文件 @
f8085469
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
from
numbers
import
Integral
import
math
import
six
import
paddle
from
paddle
import
fluid
def
DropBlock
(
input
,
block_size
,
keep_prob
,
is_test
):
if
is_test
:
return
input
def
CalculateGamma
(
input
,
block_size
,
keep_prob
):
input_shape
=
fluid
.
layers
.
shape
(
input
)
feat_shape_tmp
=
fluid
.
layers
.
slice
(
input_shape
,
[
0
],
[
3
],
[
4
])
feat_shape_tmp
=
fluid
.
layers
.
cast
(
feat_shape_tmp
,
dtype
=
"float32"
)
feat_shape_t
=
fluid
.
layers
.
reshape
(
feat_shape_tmp
,
[
1
,
1
,
1
,
1
])
feat_area
=
fluid
.
layers
.
pow
(
feat_shape_t
,
factor
=
2
)
block_shape_t
=
fluid
.
layers
.
fill_constant
(
shape
=
[
1
,
1
,
1
,
1
],
value
=
block_size
,
dtype
=
'float32'
)
block_area
=
fluid
.
layers
.
pow
(
block_shape_t
,
factor
=
2
)
useful_shape_t
=
feat_shape_t
-
block_shape_t
+
1
useful_area
=
fluid
.
layers
.
pow
(
useful_shape_t
,
factor
=
2
)
upper_t
=
feat_area
*
(
1
-
keep_prob
)
bottom_t
=
block_area
*
useful_area
output
=
upper_t
/
bottom_t
return
output
gamma
=
CalculateGamma
(
input
,
block_size
=
block_size
,
keep_prob
=
keep_prob
)
input_shape
=
fluid
.
layers
.
shape
(
input
)
p
=
fluid
.
layers
.
expand_as
(
gamma
,
input
)
input_shape_tmp
=
fluid
.
layers
.
cast
(
input_shape
,
dtype
=
"int64"
)
random_matrix
=
fluid
.
layers
.
uniform_random
(
input_shape_tmp
,
dtype
=
'float32'
,
min
=
0.0
,
max
=
1.0
)
one_zero_m
=
fluid
.
layers
.
less_than
(
random_matrix
,
p
)
one_zero_m
.
stop_gradient
=
True
one_zero_m
=
fluid
.
layers
.
cast
(
one_zero_m
,
dtype
=
"float32"
)
mask_flag
=
fluid
.
layers
.
pool2d
(
one_zero_m
,
pool_size
=
block_size
,
pool_type
=
'max'
,
pool_stride
=
1
,
pool_padding
=
block_size
//
2
)
mask
=
1.0
-
mask_flag
elem_numel
=
fluid
.
layers
.
reduce_prod
(
input_shape
)
elem_numel_m
=
fluid
.
layers
.
cast
(
elem_numel
,
dtype
=
"float32"
)
elem_numel_m
.
stop_gradient
=
True
elem_sum
=
fluid
.
layers
.
reduce_sum
(
mask
)
elem_sum_m
=
fluid
.
layers
.
cast
(
elem_sum
,
dtype
=
"float32"
)
elem_sum_m
.
stop_gradient
=
True
output
=
input
*
mask
*
elem_numel_m
/
elem_sum_m
return
output
class
MultiClassNMS
(
object
):
def
__init__
(
self
,
score_threshold
=
.
05
,
nms_top_k
=-
1
,
keep_top_k
=
100
,
nms_threshold
=
.
5
,
normalized
=
False
,
nms_eta
=
1.0
,
background_label
=
0
):
super
(
MultiClassNMS
,
self
).
__init__
()
self
.
score_threshold
=
score_threshold
self
.
nms_top_k
=
nms_top_k
self
.
keep_top_k
=
keep_top_k
self
.
nms_threshold
=
nms_threshold
self
.
normalized
=
normalized
self
.
nms_eta
=
nms_eta
self
.
background_label
=
background_label
def
__call__
(
self
,
bboxes
,
scores
):
return
fluid
.
layers
.
multiclass_nms
(
bboxes
=
bboxes
,
scores
=
scores
,
score_threshold
=
self
.
score_threshold
,
nms_top_k
=
self
.
nms_top_k
,
keep_top_k
=
self
.
keep_top_k
,
normalized
=
self
.
normalized
,
nms_threshold
=
self
.
nms_threshold
,
nms_eta
=
self
.
nms_eta
,
background_label
=
self
.
background_label
)
class
MatrixNMS
(
object
):
def
__init__
(
self
,
score_threshold
=
.
05
,
post_threshold
=
.
05
,
nms_top_k
=-
1
,
keep_top_k
=
100
,
use_gaussian
=
False
,
gaussian_sigma
=
2.
,
normalized
=
False
,
background_label
=
0
):
super
(
MatrixNMS
,
self
).
__init__
()
self
.
score_threshold
=
score_threshold
self
.
post_threshold
=
post_threshold
self
.
nms_top_k
=
nms_top_k
self
.
keep_top_k
=
keep_top_k
self
.
normalized
=
normalized
self
.
use_gaussian
=
use_gaussian
self
.
gaussian_sigma
=
gaussian_sigma
self
.
background_label
=
background_label
def
__call__
(
self
,
bboxes
,
scores
):
return
paddle
.
fluid
.
layers
.
matrix_nms
(
bboxes
=
bboxes
,
scores
=
scores
,
score_threshold
=
self
.
score_threshold
,
post_threshold
=
self
.
post_threshold
,
nms_top_k
=
self
.
nms_top_k
,
keep_top_k
=
self
.
keep_top_k
,
normalized
=
self
.
normalized
,
use_gaussian
=
self
.
use_gaussian
,
gaussian_sigma
=
self
.
gaussian_sigma
,
background_label
=
self
.
background_label
)
class
MultiClassSoftNMS
(
object
):
def
__init__
(
self
,
score_threshold
=
0.01
,
keep_top_k
=
300
,
softnms_sigma
=
0.5
,
normalized
=
False
,
background_label
=
0
,
):
super
(
MultiClassSoftNMS
,
self
).
__init__
()
self
.
score_threshold
=
score_threshold
self
.
keep_top_k
=
keep_top_k
self
.
softnms_sigma
=
softnms_sigma
self
.
normalized
=
normalized
self
.
background_label
=
background_label
def
__call__
(
self
,
bboxes
,
scores
):
def
create_tmp_var
(
program
,
name
,
dtype
,
shape
,
lod_level
):
return
program
.
current_block
().
create_var
(
name
=
name
,
dtype
=
dtype
,
shape
=
shape
,
lod_level
=
lod_level
)
def
_soft_nms_for_cls
(
dets
,
sigma
,
thres
):
"""soft_nms_for_cls"""
dets_final
=
[]
while
len
(
dets
)
>
0
:
maxpos
=
np
.
argmax
(
dets
[:,
0
])
dets_final
.
append
(
dets
[
maxpos
].
copy
())
ts
,
tx1
,
ty1
,
tx2
,
ty2
=
dets
[
maxpos
]
scores
=
dets
[:,
0
]
# force remove bbox at maxpos
scores
[
maxpos
]
=
-
1
x1
=
dets
[:,
1
]
y1
=
dets
[:,
2
]
x2
=
dets
[:,
3
]
y2
=
dets
[:,
4
]
eta
=
0
if
self
.
normalized
else
1
areas
=
(
x2
-
x1
+
eta
)
*
(
y2
-
y1
+
eta
)
xx1
=
np
.
maximum
(
tx1
,
x1
)
yy1
=
np
.
maximum
(
ty1
,
y1
)
xx2
=
np
.
minimum
(
tx2
,
x2
)
yy2
=
np
.
minimum
(
ty2
,
y2
)
w
=
np
.
maximum
(
0.0
,
xx2
-
xx1
+
eta
)
h
=
np
.
maximum
(
0.0
,
yy2
-
yy1
+
eta
)
inter
=
w
*
h
ovr
=
inter
/
(
areas
+
areas
[
maxpos
]
-
inter
)
weight
=
np
.
exp
(
-
(
ovr
*
ovr
)
/
sigma
)
scores
=
scores
*
weight
idx_keep
=
np
.
where
(
scores
>=
thres
)
dets
[:,
0
]
=
scores
dets
=
dets
[
idx_keep
]
dets_final
=
np
.
array
(
dets_final
).
reshape
(
-
1
,
5
)
return
dets_final
def
_soft_nms
(
bboxes
,
scores
):
class_nums
=
scores
.
shape
[
-
1
]
softnms_thres
=
self
.
score_threshold
softnms_sigma
=
self
.
softnms_sigma
keep_top_k
=
self
.
keep_top_k
cls_boxes
=
[[]
for
_
in
range
(
class_nums
)]
cls_ids
=
[[]
for
_
in
range
(
class_nums
)]
start_idx
=
1
if
self
.
background_label
==
0
else
0
for
j
in
range
(
start_idx
,
class_nums
):
inds
=
np
.
where
(
scores
[:,
j
]
>=
softnms_thres
)[
0
]
scores_j
=
scores
[
inds
,
j
]
rois_j
=
bboxes
[
inds
,
j
,
:]
if
len
(
bboxes
.
shape
)
>
2
else
bboxes
[
inds
,
:]
dets_j
=
np
.
hstack
((
scores_j
[:,
np
.
newaxis
],
rois_j
)).
astype
(
np
.
float32
,
copy
=
False
)
cls_rank
=
np
.
argsort
(
-
dets_j
[:,
0
])
dets_j
=
dets_j
[
cls_rank
]
cls_boxes
[
j
]
=
_soft_nms_for_cls
(
dets_j
,
sigma
=
softnms_sigma
,
thres
=
softnms_thres
)
cls_ids
[
j
]
=
np
.
array
([
j
]
*
cls_boxes
[
j
].
shape
[
0
]).
reshape
(
-
1
,
1
)
cls_boxes
=
np
.
vstack
(
cls_boxes
[
start_idx
:])
cls_ids
=
np
.
vstack
(
cls_ids
[
start_idx
:])
pred_result
=
np
.
hstack
([
cls_ids
,
cls_boxes
])
# Limit to max_per_image detections **over all classes**
image_scores
=
cls_boxes
[:,
0
]
if
len
(
image_scores
)
>
keep_top_k
:
image_thresh
=
np
.
sort
(
image_scores
)[
-
keep_top_k
]
keep
=
np
.
where
(
cls_boxes
[:,
0
]
>=
image_thresh
)[
0
]
pred_result
=
pred_result
[
keep
,
:]
return
pred_result
def
_batch_softnms
(
bboxes
,
scores
):
batch_offsets
=
bboxes
.
lod
()
bboxes
=
np
.
array
(
bboxes
)
scores
=
np
.
array
(
scores
)
out_offsets
=
[
0
]
pred_res
=
[]
if
len
(
batch_offsets
)
>
0
:
batch_offset
=
batch_offsets
[
0
]
for
i
in
range
(
len
(
batch_offset
)
-
1
):
s
,
e
=
batch_offset
[
i
],
batch_offset
[
i
+
1
]
pred
=
_soft_nms
(
bboxes
[
s
:
e
],
scores
[
s
:
e
])
out_offsets
.
append
(
pred
.
shape
[
0
]
+
out_offsets
[
-
1
])
pred_res
.
append
(
pred
)
else
:
assert
len
(
bboxes
.
shape
)
==
3
assert
len
(
scores
.
shape
)
==
3
for
i
in
range
(
bboxes
.
shape
[
0
]):
pred
=
_soft_nms
(
bboxes
[
i
],
scores
[
i
])
out_offsets
.
append
(
pred
.
shape
[
0
]
+
out_offsets
[
-
1
])
pred_res
.
append
(
pred
)
res
=
fluid
.
LoDTensor
()
res
.
set_lod
([
out_offsets
])
if
len
(
pred_res
)
==
0
:
pred_res
=
np
.
array
([[
1
]],
dtype
=
np
.
float32
)
res
.
set
(
np
.
vstack
(
pred_res
).
astype
(
np
.
float32
),
fluid
.
CPUPlace
())
return
res
pred_result
=
create_tmp_var
(
fluid
.
default_main_program
(),
name
=
'softnms_pred_result'
,
dtype
=
'float32'
,
shape
=
[
-
1
,
6
],
lod_level
=
1
)
fluid
.
layers
.
py_func
(
func
=
_batch_softnms
,
x
=
[
bboxes
,
scores
],
out
=
pred_result
)
return
pred_result
paddlex/cv/nets/detection/yolo_v3.py
浏览文件 @
f8085469
此差异已折叠。
点击以展开。
paddlex/cv/transforms/__init__.py
浏览文件 @
f8085469
...
@@ -91,7 +91,10 @@ def arrange_transforms(model_type, class_name, transforms, mode='train'):
...
@@ -91,7 +91,10 @@ def arrange_transforms(model_type, class_name, transforms, mode='train'):
elif
model_type
==
'segmenter'
:
elif
model_type
==
'segmenter'
:
arrange_transform
=
seg_transforms
.
ArrangeSegmenter
arrange_transform
=
seg_transforms
.
ArrangeSegmenter
elif
model_type
==
'detector'
:
elif
model_type
==
'detector'
:
arrange_name
=
'Arrange{}'
.
format
(
class_name
)
if
class_name
==
"PPYOLO"
:
arrange_name
=
'ArrangeYOLOv3'
else
:
arrange_name
=
'Arrange{}'
.
format
(
class_name
)
arrange_transform
=
getattr
(
det_transforms
,
arrange_name
)
arrange_transform
=
getattr
(
det_transforms
,
arrange_name
)
else
:
else
:
raise
Exception
(
"Unrecognized model type: {}"
.
format
(
self
.
model_type
))
raise
Exception
(
"Unrecognized model type: {}"
.
format
(
self
.
model_type
))
...
...
paddlex/cv/transforms/cls_transforms.py
浏览文件 @
f8085469
...
@@ -46,7 +46,7 @@ class Compose(ClsTransform):
...
@@ -46,7 +46,7 @@ class Compose(ClsTransform):
raise
ValueError
(
'The length of transforms '
+
\
raise
ValueError
(
'The length of transforms '
+
\
'must be equal or larger than 1!'
)
'must be equal or larger than 1!'
)
self
.
transforms
=
transforms
self
.
transforms
=
transforms
self
.
batch_transforms
=
None
# 检查transforms里面的操作,目前支持PaddleX定义的或者是imgaug操作
# 检查transforms里面的操作,目前支持PaddleX定义的或者是imgaug操作
for
op
in
self
.
transforms
:
for
op
in
self
.
transforms
:
if
not
isinstance
(
op
,
ClsTransform
):
if
not
isinstance
(
op
,
ClsTransform
):
...
...
paddlex/cv/transforms/det_transforms.py
浏览文件 @
f8085469
...
@@ -55,6 +55,7 @@ class Compose(DetTransform):
...
@@ -55,6 +55,7 @@ class Compose(DetTransform):
raise
ValueError
(
'The length of transforms '
+
\
raise
ValueError
(
'The length of transforms '
+
\
'must be equal or larger than 1!'
)
'must be equal or larger than 1!'
)
self
.
transforms
=
transforms
self
.
transforms
=
transforms
self
.
batch_transforms
=
None
self
.
use_mixup
=
False
self
.
use_mixup
=
False
for
t
in
self
.
transforms
:
for
t
in
self
.
transforms
:
if
type
(
t
).
__name__
==
'MixupImage'
:
if
type
(
t
).
__name__
==
'MixupImage'
:
...
@@ -1385,3 +1386,187 @@ class ComposedYOLOv3Transforms(Compose):
...
@@ -1385,3 +1386,187 @@ class ComposedYOLOv3Transforms(Compose):
mean
=
mean
,
std
=
std
)
mean
=
mean
,
std
=
std
)
]
]
super
(
ComposedYOLOv3Transforms
,
self
).
__init__
(
transforms
)
super
(
ComposedYOLOv3Transforms
,
self
).
__init__
(
transforms
)
class
BatchRandomShape
(
DetTransform
):
"""调整图像大小(resize)。
对batch数据中的每张图像全部resize到random_shapes中任意一个大小。
注意:当插值方式为“RANDOM”时,则随机选取一种插值方式进行resize。
Args:
random_shapes (list): resize大小选择列表。
默认为[320, 352, 384, 416, 448, 480, 512, 544, 576, 608]。
interp (str): resize的插值方式,与opencv的插值方式对应,取值范围为
['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM']。默认为"RANDOM"。
Raises:
ValueError: 插值方式不在['NEAREST', 'LINEAR', 'CUBIC',
'AREA', 'LANCZOS4', 'RANDOM']中。
"""
# The interpolation mode
interp_dict
=
{
'NEAREST'
:
cv2
.
INTER_NEAREST
,
'LINEAR'
:
cv2
.
INTER_LINEAR
,
'CUBIC'
:
cv2
.
INTER_CUBIC
,
'AREA'
:
cv2
.
INTER_AREA
,
'LANCZOS4'
:
cv2
.
INTER_LANCZOS4
}
def
__init__
(
self
,
random_shapes
=
[
320
,
352
,
384
,
416
,
448
,
480
,
512
,
544
,
576
,
608
],
interp
=
'RANDOM'
):
if
not
(
interp
==
"RANDOM"
or
interp
in
self
.
interp_dict
):
raise
ValueError
(
"interp should be one of {}"
.
format
(
self
.
interp_dict
.
keys
()))
self
.
random_shapes
=
random_shapes
self
.
interp
=
interp
def
__call__
(
self
,
batch_data
):
"""
Args:
batch_data (list): 由与图像相关的各种信息组成的batch数据。
Returns:
list: 由与图像相关的各种信息组成的batch数据。
"""
shape
=
np
.
random
.
choice
(
self
.
random_shapes
)
if
self
.
interp
==
"RANDOM"
:
interp
=
random
.
choice
(
list
(
self
.
interp_dict
.
keys
()))
else
:
interp
=
self
.
interp
for
data_id
,
data
in
enumerate
(
batch_data
):
data_list
=
list
(
data
)
im
=
data_list
[
0
]
im
=
np
.
swapaxes
(
im
,
1
,
0
)
im
=
np
.
swapaxes
(
im
,
1
,
2
)
im
=
resize
(
im
,
shape
,
self
.
interp_dict
[
interp
])
im
=
np
.
swapaxes
(
im
,
1
,
2
)
im
=
np
.
swapaxes
(
im
,
1
,
0
)
data_list
[
0
]
=
im
batch_data
[
data_id
]
=
tuple
(
data_list
)
return
batch_data
class
GenerateYoloTarget
(
object
):
"""生成YOLOv3的ground truth(真实标注框)在不同特征层的位置转换信息。
该transform只在YOLOv3计算细粒度loss时使用。
Args:
anchors (list|tuple): anchor框的宽度和高度。
anchor_masks (list|tuple): 在计算损失时,使用anchor的mask索引。
num_classes (int): 类别数。默认为80。
iou_thresh (float): iou阈值,当anchor和真实标注框的iou大于该阈值时,计入target。默认为1.0。
"""
def
__init__
(
self
,
anchors
,
anchor_masks
,
downsample_ratios
,
num_classes
=
80
,
iou_thresh
=
1.
):
super
(
GenerateYoloTarget
,
self
).
__init__
()
self
.
anchors
=
anchors
self
.
anchor_masks
=
anchor_masks
self
.
downsample_ratios
=
downsample_ratios
self
.
num_classes
=
num_classes
self
.
iou_thresh
=
iou_thresh
def
__call__
(
self
,
batch_data
):
"""
Args:
batch_data (list): 由与图像相关的各种信息组成的batch数据。
Returns:
list: 由与图像相关的各种信息组成的batch数据。
其中,每个数据新添加的字段为:
- target0 (np.ndarray): YOLOv3的ground truth在特征层0的位置转换信息,
形状为(特征层0的anchor数量, 6+类别数, 特征层0的h, 特征层0的w)。
- target1 (np.ndarray): YOLOv3的ground truth在特征层1的位置转换信息,
形状为(特征层1的anchor数量, 6+类别数, 特征层1的h, 特征层1的w)。
- ...
-targetn (np.ndarray): YOLOv3的ground truth在特征层n的位置转换信息,
形状为(特征层n的anchor数量, 6+类别数, 特征层n的h, 特征层n的w)。
n的是大小由anchor_masks的长度决定。
"""
im
=
batch_data
[
0
][
0
]
h
=
im
.
shape
[
1
]
w
=
im
.
shape
[
2
]
an_hw
=
np
.
array
(
self
.
anchors
)
/
np
.
array
([[
w
,
h
]])
for
data_id
,
data
in
enumerate
(
batch_data
):
gt_bbox
=
data
[
1
]
gt_class
=
data
[
2
]
gt_score
=
data
[
3
]
im_shape
=
data
[
4
]
origin_h
=
float
(
im_shape
[
0
])
origin_w
=
float
(
im_shape
[
1
])
data_list
=
list
(
data
)
for
i
,
(
mask
,
downsample_ratio
)
in
enumerate
(
zip
(
self
.
anchor_masks
,
self
.
downsample_ratios
)):
grid_h
=
int
(
h
/
downsample_ratio
)
grid_w
=
int
(
w
/
downsample_ratio
)
target
=
np
.
zeros
(
(
len
(
mask
),
6
+
self
.
num_classes
,
grid_h
,
grid_w
),
dtype
=
np
.
float32
)
for
b
in
range
(
gt_bbox
.
shape
[
0
]):
gx
=
gt_bbox
[
b
,
0
]
/
float
(
origin_w
)
gy
=
gt_bbox
[
b
,
1
]
/
float
(
origin_h
)
gw
=
gt_bbox
[
b
,
2
]
/
float
(
origin_w
)
gh
=
gt_bbox
[
b
,
3
]
/
float
(
origin_h
)
cls
=
gt_class
[
b
]
score
=
gt_score
[
b
]
if
gw
<=
0.
or
gh
<=
0.
or
score
<=
0.
:
continue
# find best match anchor index
best_iou
=
0.
best_idx
=
-
1
for
an_idx
in
range
(
an_hw
.
shape
[
0
]):
iou
=
jaccard_overlap
(
[
0.
,
0.
,
gw
,
gh
],
[
0.
,
0.
,
an_hw
[
an_idx
,
0
],
an_hw
[
an_idx
,
1
]])
if
iou
>
best_iou
:
best_iou
=
iou
best_idx
=
an_idx
gi
=
int
(
gx
*
grid_w
)
gj
=
int
(
gy
*
grid_h
)
# gtbox should be regresed in this layes if best match
# anchor index in anchor mask of this layer
if
best_idx
in
mask
:
best_n
=
mask
.
index
(
best_idx
)
# x, y, w, h, scale
target
[
best_n
,
0
,
gj
,
gi
]
=
gx
*
grid_w
-
gi
target
[
best_n
,
1
,
gj
,
gi
]
=
gy
*
grid_h
-
gj
target
[
best_n
,
2
,
gj
,
gi
]
=
np
.
log
(
gw
*
w
/
self
.
anchors
[
best_idx
][
0
])
target
[
best_n
,
3
,
gj
,
gi
]
=
np
.
log
(
gh
*
h
/
self
.
anchors
[
best_idx
][
1
])
target
[
best_n
,
4
,
gj
,
gi
]
=
2.0
-
gw
*
gh
# objectness record gt_score
target
[
best_n
,
5
,
gj
,
gi
]
=
score
# classification
target
[
best_n
,
6
+
cls
,
gj
,
gi
]
=
1.
# For non-matched anchors, calculate the target if the iou
# between anchor and gt is larger than iou_thresh
if
self
.
iou_thresh
<
1
:
for
idx
,
mask_i
in
enumerate
(
mask
):
if
mask_i
==
best_idx
:
continue
iou
=
jaccard_overlap
(
[
0.
,
0.
,
gw
,
gh
],
[
0.
,
0.
,
an_hw
[
mask_i
,
0
],
an_hw
[
mask_i
,
1
]])
if
iou
>
self
.
iou_thresh
:
# x, y, w, h, scale
target
[
idx
,
0
,
gj
,
gi
]
=
gx
*
grid_w
-
gi
target
[
idx
,
1
,
gj
,
gi
]
=
gy
*
grid_h
-
gj
target
[
idx
,
2
,
gj
,
gi
]
=
np
.
log
(
gw
*
w
/
self
.
anchors
[
mask_i
][
0
])
target
[
idx
,
3
,
gj
,
gi
]
=
np
.
log
(
gh
*
h
/
self
.
anchors
[
mask_i
][
1
])
target
[
idx
,
4
,
gj
,
gi
]
=
2.0
-
gw
*
gh
# objectness record gt_score
target
[
idx
,
5
,
gj
,
gi
]
=
score
# classification
target
[
idx
,
6
+
cls
,
gj
,
gi
]
=
1.
data_list
.
append
(
target
)
batch_data
[
data_id
]
=
tuple
(
data_list
)
return
batch_data
paddlex/cv/transforms/seg_transforms.py
浏览文件 @
f8085469
...
@@ -49,6 +49,7 @@ class Compose(SegTransform):
...
@@ -49,6 +49,7 @@ class Compose(SegTransform):
raise
ValueError
(
'The length of transforms '
+
\
raise
ValueError
(
'The length of transforms '
+
\
'must be equal or larger than 1!'
)
'must be equal or larger than 1!'
)
self
.
transforms
=
transforms
self
.
transforms
=
transforms
self
.
batch_transforms
=
None
self
.
to_rgb
=
False
self
.
to_rgb
=
False
# 检查transforms里面的操作,目前支持PaddleX定义的或者是imgaug操作
# 检查transforms里面的操作,目前支持PaddleX定义的或者是imgaug操作
for
op
in
self
.
transforms
:
for
op
in
self
.
transforms
:
...
...
paddlex/det.py
浏览文件 @
f8085469
...
@@ -17,6 +17,7 @@ from . import cv
...
@@ -17,6 +17,7 @@ from . import cv
FasterRCNN
=
cv
.
models
.
FasterRCNN
FasterRCNN
=
cv
.
models
.
FasterRCNN
YOLOv3
=
cv
.
models
.
YOLOv3
YOLOv3
=
cv
.
models
.
YOLOv3
PPYOLO
=
cv
.
models
.
PPYOLO
MaskRCNN
=
cv
.
models
.
MaskRCNN
MaskRCNN
=
cv
.
models
.
MaskRCNN
transforms
=
cv
.
transforms
.
det_transforms
transforms
=
cv
.
transforms
.
det_transforms
visualize
=
cv
.
models
.
utils
.
visualize
.
visualize_detection
visualize
=
cv
.
models
.
utils
.
visualize
.
visualize_detection
...
...
tutorials/train/object_detection/ppyolo.py
0 → 100644
浏览文件 @
f8085469
# 环境变量配置,用于控制是否使用GPU
# 说明文档:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
import
os
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
'0'
from
paddlex.det
import
transforms
import
paddlex
as
pdx
# 下载和解压昆虫检测数据集
insect_dataset
=
'https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz'
pdx
.
utils
.
download_and_decompress
(
insect_dataset
,
path
=
'./'
)
# 定义训练和验证时的transforms
# API说明 https://paddlex.readthedocs.io/zh_CN/develop/apis/transforms/det_transforms.html
train_transforms
=
transforms
.
Compose
([
transforms
.
MixupImage
(
mixup_epoch
=
250
),
transforms
.
RandomDistort
(),
transforms
.
RandomExpand
(),
transforms
.
RandomCrop
(),
transforms
.
Resize
(
target_size
=
608
,
interp
=
'RANDOM'
),
transforms
.
RandomHorizontalFlip
(),
transforms
.
Normalize
()
])
eval_transforms
=
transforms
.
Compose
([
transforms
.
Resize
(
target_size
=
608
,
interp
=
'CUBIC'
),
transforms
.
Normalize
()
])
# 定义训练和验证所用的数据集
# API说明:https://paddlex.readthedocs.io/zh_CN/develop/apis/datasets.html#paddlex-datasets-vocdetection
train_dataset
=
pdx
.
datasets
.
VOCDetection
(
data_dir
=
'insect_det'
,
file_list
=
'insect_det/train_list.txt'
,
label_list
=
'insect_det/labels.txt'
,
transforms
=
train_transforms
,
shuffle
=
True
)
eval_dataset
=
pdx
.
datasets
.
VOCDetection
(
data_dir
=
'insect_det'
,
file_list
=
'insect_det/val_list.txt'
,
label_list
=
'insect_det/labels.txt'
,
transforms
=
eval_transforms
)
# 初始化模型,并进行训练
# 可使用VisualDL查看训练指标,参考https://paddlex.readthedocs.io/zh_CN/develop/train/visualdl.html
num_classes
=
len
(
train_dataset
.
labels
)
# API说明: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/detection.html#paddlex-det-yolov3
model
=
pdx
.
det
.
PPYOLO
(
num_classes
=
num_classes
)
# API说明: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/detection.html#train
# 各参数介绍与调整说明:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html
model
.
train
(
num_epochs
=
270
,
train_dataset
=
train_dataset
,
train_batch_size
=
8
,
eval_dataset
=
eval_dataset
,
learning_rate
=
0.000125
,
lr_decay_epochs
=
[
210
,
240
],
save_dir
=
'output/ppyolo'
,
use_vdl
=
True
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录