Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleX
提交
5b18edf5
P
PaddleX
项目概览
PaddlePaddle
/
PaddleX
通知
138
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
43
列表
看板
标记
里程碑
合并请求
5
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleX
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
43
Issue
43
列表
看板
标记
里程碑
合并请求
5
合并请求
5
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5b18edf5
编写于
5月 28, 2020
作者:
S
sunyanfang01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix for review
上级
612acfa8
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
183 addition
and
156 deletion
+183
-156
paddlex/cv/models/utils/pretrain_weights.py
paddlex/cv/models/utils/pretrain_weights.py
+21
-7
paddlex/cv/models/yolo_v3.py
paddlex/cv/models/yolo_v3.py
+18
-16
paddlex/cv/nets/detection/yolo_v3.py
paddlex/cv/nets/detection/yolo_v3.py
+8
-6
paddlex/cv/transforms/box_utils.py
paddlex/cv/transforms/box_utils.py
+121
-0
paddlex/cv/transforms/det_transforms.py
paddlex/cv/transforms/det_transforms.py
+12
-125
paddlex/cv/transforms/ops.py
paddlex/cv/transforms/ops.py
+3
-2
未找到文件。
paddlex/cv/models/utils/pretrain_weights.py
浏览文件 @
5b18edf5
...
@@ -73,7 +73,7 @@ image_pretrain = {
...
@@ -73,7 +73,7 @@ image_pretrain = {
}
}
obj365_pretrain
=
{
obj365_pretrain
=
{
'ResNet50_vd_
dcn_db_
obj365'
:
'ResNet50_vd_obj365'
:
'https://paddlemodels.bj.bcebos.com/object_detection/ResNet50_vd_dcn_db_obj365_pretrained.tar'
,
'https://paddlemodels.bj.bcebos.com/object_detection/ResNet50_vd_dcn_db_obj365_pretrained.tar'
,
}
}
...
@@ -127,13 +127,27 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
...
@@ -127,13 +127,27 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir):
if
hasattr
(
paddlex
,
'pretrain_dir'
):
if
hasattr
(
paddlex
,
'pretrain_dir'
):
new_save_dir
=
paddlex
.
pretrain_dir
new_save_dir
=
paddlex
.
pretrain_dir
if
backbone
==
'ResNet50_vd'
:
if
backbone
==
'ResNet50_vd'
:
backbone
=
'ResNet50_vd_
dcn_db_
obj365'
backbone
=
'ResNet50_vd_obj365'
assert
backbone
in
obj365_pretrain
,
"There is not Object365 pretrain weights for {},
you may try ImageNet.
"
.
format
(
assert
backbone
in
obj365_pretrain
,
"There is not Object365 pretrain weights for {},
try use pretrain_weights='IMAGENET'
"
.
format
(
backbone
)
backbone
)
url
=
obj365_pretrain
[
backbone
]
# url = obj365_pretrain[backbone]
fname
=
osp
.
split
(
url
)[
-
1
].
split
(
'.'
)[
0
]
# fname = osp.split(url)[-1].split('.')[0]
paddlex
.
utils
.
download_and_decompress
(
url
,
path
=
new_save_dir
)
# paddlex.utils.download_and_decompress(url, path=new_save_dir)
return
osp
.
join
(
new_save_dir
,
fname
)
# return osp.join(new_save_dir, fname)
try
:
hub
.
download
(
backbone
,
save_path
=
new_save_dir
)
except
Exception
as
e
:
if
isinstance
(
hub
.
ResourceNotFoundError
):
raise
Exception
(
"Resource for backbone {} not found"
.
format
(
backbone
))
elif
isinstance
(
hub
.
ServerConnectionError
):
raise
Exception
(
"Cannot get reource for backbone {}, please check your internet connecgtion"
.
format
(
backbone
))
else
:
raise
Exception
(
"Unexpected error, please make sure paddlehub >= 1.6.2"
)
return
osp
.
join
(
new_save_dir
,
backbone
)
elif
flag
==
'COCO'
:
elif
flag
==
'COCO'
:
new_save_dir
=
save_dir
new_save_dir
=
save_dir
if
hasattr
(
paddlex
,
'pretrain_dir'
):
if
hasattr
(
paddlex
,
'pretrain_dir'
):
...
...
paddlex/cv/models/yolo_v3.py
浏览文件 @
5b18edf5
...
@@ -44,7 +44,10 @@ class YOLOv3(BaseAPI):
...
@@ -44,7 +44,10 @@ class YOLOv3(BaseAPI):
nms_keep_topk (int): 进行NMS后,每个图像要保留的总检测框数。默认为100。
nms_keep_topk (int): 进行NMS后,每个图像要保留的总检测框数。默认为100。
nms_iou_threshold (float): 进行NMS时,用于剔除检测框IOU的阈值。默认为0.45。
nms_iou_threshold (float): 进行NMS时,用于剔除检测框IOU的阈值。默认为0.45。
label_smooth (bool): 是否使用label smooth。默认值为False。
label_smooth (bool): 是否使用label smooth。默认值为False。
train_random_shapes (list|tuple): 训练时从列表中随机选择图像大小。默认值为[320, 352, 384, 416, 448, 480, 512, 544, 576, 608]。
use_iou_loss (bool): 是否使用IoU Loss。默认为False。
use_iou_aware_loss (bool): 是否使用IoU Aware Loss。默认为False。
use_drop_block (bool): 是否使用DropBlock模块。默认为False。
use_dcn_v2 (bool): 是否使用Deformable Convolution v2(可变形卷积)。默认为False。
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -62,10 +65,7 @@ class YOLOv3(BaseAPI):
...
@@ -62,10 +65,7 @@ class YOLOv3(BaseAPI):
use_iou_aware_loss
=
False
,
use_iou_aware_loss
=
False
,
iou_aware_factor
=
0.4
,
iou_aware_factor
=
0.4
,
use_drop_block
=
False
,
use_drop_block
=
False
,
use_dcn_v2
=
False
,
use_dcn_v2
=
False
):
train_random_shapes
=
[
320
,
352
,
384
,
416
,
448
,
480
,
512
,
544
,
576
,
608
]):
self
.
init_params
=
locals
()
self
.
init_params
=
locals
()
super
(
YOLOv3
,
self
).
__init__
(
'detector'
)
super
(
YOLOv3
,
self
).
__init__
(
'detector'
)
backbones
=
[
backbones
=
[
...
@@ -89,7 +89,6 @@ class YOLOv3(BaseAPI):
...
@@ -89,7 +89,6 @@ class YOLOv3(BaseAPI):
self
.
iou_aware_factor
=
iou_aware_factor
self
.
iou_aware_factor
=
iou_aware_factor
self
.
use_drop_block
=
use_drop_block
self
.
use_drop_block
=
use_drop_block
self
.
use_dcn_v2
=
use_dcn_v2
self
.
use_dcn_v2
=
use_dcn_v2
self
.
train_random_shapes
=
train_random_shapes
self
.
fixed_input_shape
=
None
self
.
fixed_input_shape
=
None
if
self
.
anchors
is
None
:
if
self
.
anchors
is
None
:
self
.
anchors
=
[[
10
,
13
],
[
16
,
30
],
[
33
,
23
],
[
30
,
61
],
[
62
,
45
],
self
.
anchors
=
[[
10
,
13
],
[
16
,
30
],
[
33
,
23
],
[
30
,
61
],
[
62
,
45
],
...
@@ -139,13 +138,13 @@ class YOLOv3(BaseAPI):
...
@@ -139,13 +138,13 @@ class YOLOv3(BaseAPI):
nms_topk
=
self
.
nms_topk
,
nms_topk
=
self
.
nms_topk
,
nms_keep_topk
=
self
.
nms_keep_topk
,
nms_keep_topk
=
self
.
nms_keep_topk
,
nms_iou_threshold
=
self
.
nms_iou_threshold
,
nms_iou_threshold
=
self
.
nms_iou_threshold
,
train_random_shapes
=
self
.
train_random_shapes
,
fixed_input_shape
=
self
.
fixed_input_shape
,
fixed_input_shape
=
self
.
fixed_input_shape
,
use_iou_loss
=
self
.
use_iou_loss
,
use_iou_loss
=
self
.
use_iou_loss
,
use_iou_aware_loss
=
self
.
use_iou_aware_loss
,
use_iou_aware_loss
=
self
.
use_iou_aware_loss
,
iou_aware_factor
=
self
.
iou_aware_factor
,
iou_aware_factor
=
self
.
iou_aware_factor
,
use_drop_block
=
self
.
use_drop_block
,
use_drop_block
=
self
.
use_drop_block
,
batch_size
=
self
.
train_batch_size
if
hasattr
(
self
,
'train_batch_size'
)
else
8
)
batch_size
=
self
.
train_batch_size
,
max_shape
=
self
.
max_shape
)
inputs
=
model
.
generate_inputs
()
inputs
=
model
.
generate_inputs
()
model_out
=
model
.
build_net
(
inputs
)
model_out
=
model
.
build_net
(
inputs
)
outputs
=
OrderedDict
([(
'bbox'
,
model_out
)])
outputs
=
OrderedDict
([(
'bbox'
,
model_out
)])
...
@@ -254,15 +253,18 @@ class YOLOv3(BaseAPI):
...
@@ -254,15 +253,18 @@ class YOLOv3(BaseAPI):
if
isinstance
(
transform
,
paddlex
.
det
.
transforms
.
Normalize
):
if
isinstance
(
transform
,
paddlex
.
det
.
transforms
.
Normalize
):
transform
.
is_scale
=
False
transform
.
is_scale
=
False
if
self
.
use_iou_loss
or
self
.
use_iou_aware_loss
:
if
self
.
use_iou_loss
or
self
.
use_iou_aware_loss
:
if
self
.
train_random_shapes
is
None
or
len
(
self
.
train_random_shapes
)
==
0
:
self
.
max_shape
=
0
for
transform
in
train_dataset
.
transforms
.
transforms
:
for
transform
in
train_dataset
.
transforms
.
transforms
:
if
isinstance
(
transform
,
paddlex
.
det
.
transforms
.
Resize
):
if
isinstance
(
transform
,
paddlex
.
det
.
transforms
.
Resize
):
self
.
train_random_shapes
=
[
transform
.
target_size
]
self
.
max_shape
=
[
transform
.
target_size
]
break
break
if
train_dataset
.
transforms
.
batch_transforms
is
None
:
train_dataset
.
transforms
.
batch_transforms
=
[]
train_dataset
.
transforms
.
batch_transforms
=
[]
reshape_bt
=
paddlex
.
det
.
transforms
.
RandomShape
else
:
train_dataset
.
transforms
.
batch_transforms
.
append
(
reshape_bt
(
for
bt
in
train_dataset
.
transforms
.
batch_transforms
:
random_shapes
=
self
.
train_random_shapes
))
if
isinstance
(
bt
,
paddlex
.
det
.
transforms
.
BatchRandomShape
):
self
.
max_shape
=
max
(
bt
.
random_shapes
)
break
iou_bt
=
paddlex
.
det
.
transforms
.
GenerateYoloTarget
iou_bt
=
paddlex
.
det
.
transforms
.
GenerateYoloTarget
train_dataset
.
transforms
.
batch_transforms
.
append
(
iou_bt
(
anchors
=
self
.
anchors
,
train_dataset
.
transforms
.
batch_transforms
.
append
(
iou_bt
(
anchors
=
self
.
anchors
,
anchor_masks
=
self
.
anchor_masks
,
anchor_masks
=
self
.
anchor_masks
,
...
...
paddlex/cv/nets/detection/yolo_v3.py
浏览文件 @
5b18edf5
...
@@ -34,15 +34,17 @@ class YOLOv3:
...
@@ -34,15 +34,17 @@ class YOLOv3:
nms_topk
=
1000
,
nms_topk
=
1000
,
nms_keep_topk
=
100
,
nms_keep_topk
=
100
,
nms_iou_threshold
=
0.45
,
nms_iou_threshold
=
0.45
,
train_random_shapes
=
[
train_random_shapes
=
"Deprecated"
,
320
,
352
,
384
,
416
,
448
,
480
,
512
,
544
,
576
,
608
],
fixed_input_shape
=
None
,
fixed_input_shape
=
None
,
use_iou_loss
=
False
,
use_iou_loss
=
False
,
use_iou_aware_loss
=
False
,
use_iou_aware_loss
=
False
,
iou_aware_factor
=
0.4
,
iou_aware_factor
=
0.4
,
use_drop_block
=
False
,
use_drop_block
=
False
,
batch_size
=
8
):
batch_size
=
8
,
max_shape
=
608
):
if
train_random_shapes
!=
"Deprecated"
:
raise
Exception
(
"The 'train_random_shapes' is deprecated. If you want to set train_random_shapes, "
\
+
"you can use BatchRandomShape. The details you can see "
)
if
anchors
is
None
:
if
anchors
is
None
:
anchors
=
[[
10
,
13
],
[
16
,
30
],
[
33
,
23
],
[
30
,
61
],
[
62
,
45
],
anchors
=
[[
10
,
13
],
[
16
,
30
],
[
33
,
23
],
[
30
,
61
],
[
62
,
45
],
[
59
,
119
],
[
116
,
90
],
[
156
,
198
],
[
373
,
326
]]
[
59
,
119
],
[
116
,
90
],
[
156
,
198
],
[
373
,
326
]]
...
@@ -62,7 +64,6 @@ class YOLOv3:
...
@@ -62,7 +64,6 @@ class YOLOv3:
self
.
nms_iou_threshold
=
nms_iou_threshold
self
.
nms_iou_threshold
=
nms_iou_threshold
self
.
norm_decay
=
0.0
self
.
norm_decay
=
0.0
self
.
prefix_name
=
''
self
.
prefix_name
=
''
self
.
train_random_shapes
=
train_random_shapes
self
.
fixed_input_shape
=
fixed_input_shape
self
.
fixed_input_shape
=
fixed_input_shape
self
.
use_iou_loss
=
use_iou_loss
self
.
use_iou_loss
=
use_iou_loss
self
.
use_iou_aware_loss
=
use_iou_aware_loss
self
.
use_iou_aware_loss
=
use_iou_aware_loss
...
@@ -71,6 +72,7 @@ class YOLOv3:
...
@@ -71,6 +72,7 @@ class YOLOv3:
self
.
block_size
=
3
self
.
block_size
=
3
self
.
keep_prob
=
0.9
self
.
keep_prob
=
0.9
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
self
.
max_shape
=
max_shape
def
_head
(
self
,
feats
):
def
_head
(
self
,
feats
):
outputs
=
[]
outputs
=
[]
...
@@ -284,7 +286,7 @@ class YOLOv3:
...
@@ -284,7 +286,7 @@ class YOLOv3:
return
yolo_loss_obj
(
inputs
,
gt_box
,
gt_label
,
gt_score
,
targets
,
return
yolo_loss_obj
(
inputs
,
gt_box
,
gt_label
,
gt_score
,
targets
,
self
.
anchors
,
self
.
anchor_masks
,
self
.
anchors
,
self
.
anchor_masks
,
self
.
mask_anchors
,
self
.
num_classes
,
self
.
mask_anchors
,
self
.
num_classes
,
self
.
prefix_name
,
max
(
self
.
train_random_shapes
)
)
self
.
prefix_name
,
self
.
max_shape
)
def
_get_prediction
(
self
,
inputs
,
im_size
):
def
_get_prediction
(
self
,
inputs
,
im_size
):
boxes
=
[]
boxes
=
[]
...
...
paddlex/cv/transforms/box_utils.py
浏览文件 @
5b18edf5
...
@@ -221,3 +221,124 @@ def segms_horizontal_flip(segms, height, width):
...
@@ -221,3 +221,124 @@ def segms_horizontal_flip(segms, height, width):
import
pycocotools.mask
as
mask_util
import
pycocotools.mask
as
mask_util
flipped_segms
.
append
(
_flip_rle
(
segm
,
height
,
width
))
flipped_segms
.
append
(
_flip_rle
(
segm
,
height
,
width
))
return
flipped_segms
return
flipped_segms
class
GenerateYoloTarget
(
object
):
"""生成YOLOv3的ground truth(真实标注框)在不同特征层的位置转换信息。
该transform只在YOLOv3计算细粒度loss时使用。
Args:
anchors (list|tuple): anchor框的宽度和高度。
anchor_masks (list|tuple): 在计算损失时,使用anchor的mask索引。
num_classes (int): 类别数。默认为80。
iou_thresh (float): iou阈值,当anchor和真实标注框的iou大于该阈值时,计入target。默认为1.0。
"""
def
__init__
(
self
,
anchors
,
anchor_masks
,
num_classes
=
80
,
iou_thresh
=
1.
):
super
(
GenerateYoloTarget
,
self
).
__init__
()
self
.
anchors
=
anchors
self
.
anchor_masks
=
anchor_masks
self
.
num_classes
=
num_classes
self
.
iou_thresh
=
iou_thresh
def
__call__
(
self
,
batch_data
):
"""
Args:
batch_data (list): 由与图像相关的各种信息组成的batch数据。
Returns:
list: 由与图像相关的各种信息组成的batch数据。
其中,每个数据新添加的字段为:
- target0 (np.ndarray): YOLOv3的ground truth在特征层0的位置转换信息,
形状为(特征层0的anchor数量, 6+类别数, 特征层0的h, 特征层0的w)。
- target1 (np.ndarray): YOLOv3的ground truth在特征层1的位置转换信息,
形状为(特征层1的anchor数量, 6+类别数, 特征层1的h, 特征层1的w)。
- ...
-targetn (np.ndarray): YOLOv3的ground truth在特征层n的位置转换信息,
形状为(特征层n的anchor数量, 6+类别数, 特征层n的h, 特征层n的w)。
n的是大小由anchor_masks的长度决定。
"""
im
=
batch_data
[
0
][
0
]
h
=
im
.
shape
[
1
]
w
=
im
.
shape
[
2
]
an_hw
=
np
.
array
(
self
.
anchors
)
/
np
.
array
([[
w
,
h
]])
for
data_id
,
data
in
enumerate
(
batch_data
):
gt_bbox
=
data
[
1
]
gt_class
=
data
[
2
]
gt_score
=
data
[
3
]
im_shape
=
data
[
4
]
origin_h
=
float
(
im_shape
[
0
])
origin_w
=
float
(
im_shape
[
1
])
data_list
=
list
(
data
)
for
i
,
mask
in
enumerate
(
self
.
anchor_masks
):
downsample_ratio
=
32
//
pow
(
2
,
i
)
grid_h
=
int
(
h
/
downsample_ratio
)
grid_w
=
int
(
w
/
downsample_ratio
)
target
=
np
.
zeros
(
(
len
(
mask
),
6
+
self
.
num_classes
,
grid_h
,
grid_w
),
dtype
=
np
.
float32
)
for
b
in
range
(
gt_bbox
.
shape
[
0
]):
gx
=
gt_bbox
[
b
,
0
]
/
float
(
origin_w
)
gy
=
gt_bbox
[
b
,
1
]
/
float
(
origin_h
)
gw
=
gt_bbox
[
b
,
2
]
/
float
(
origin_w
)
gh
=
gt_bbox
[
b
,
3
]
/
float
(
origin_h
)
cls
=
gt_class
[
b
]
score
=
gt_score
[
b
]
if
gw
<=
0.
or
gh
<=
0.
or
score
<=
0.
:
continue
# find best match anchor index
best_iou
=
0.
best_idx
=
-
1
for
an_idx
in
range
(
an_hw
.
shape
[
0
]):
iou
=
jaccard_overlap
(
[
0.
,
0.
,
gw
,
gh
],
[
0.
,
0.
,
an_hw
[
an_idx
,
0
],
an_hw
[
an_idx
,
1
]])
if
iou
>
best_iou
:
best_iou
=
iou
best_idx
=
an_idx
gi
=
int
(
gx
*
grid_w
)
gj
=
int
(
gy
*
grid_h
)
# gtbox should be regresed in this layes if best match
# anchor index in anchor mask of this layer
if
best_idx
in
mask
:
best_n
=
mask
.
index
(
best_idx
)
# x, y, w, h, scale
target
[
best_n
,
0
,
gj
,
gi
]
=
gx
*
grid_w
-
gi
target
[
best_n
,
1
,
gj
,
gi
]
=
gy
*
grid_h
-
gj
target
[
best_n
,
2
,
gj
,
gi
]
=
np
.
log
(
gw
*
w
/
self
.
anchors
[
best_idx
][
0
])
target
[
best_n
,
3
,
gj
,
gi
]
=
np
.
log
(
gh
*
h
/
self
.
anchors
[
best_idx
][
1
])
target
[
best_n
,
4
,
gj
,
gi
]
=
2.0
-
gw
*
gh
# objectness record gt_score
target
[
best_n
,
5
,
gj
,
gi
]
=
score
# classification
target
[
best_n
,
6
+
cls
,
gj
,
gi
]
=
1.
# For non-matched anchors, calculate the target if the iou
# between anchor and gt is larger than iou_thresh
if
self
.
iou_thresh
<
1
:
for
idx
,
mask_i
in
enumerate
(
mask
):
if
mask_i
==
best_idx
:
continue
iou
=
jaccard_overlap
(
[
0.
,
0.
,
gw
,
gh
],
[
0.
,
0.
,
an_hw
[
mask_i
,
0
],
an_hw
[
mask_i
,
1
]])
if
iou
>
self
.
iou_thresh
:
# x, y, w, h, scale
target
[
idx
,
0
,
gj
,
gi
]
=
gx
*
grid_w
-
gi
target
[
idx
,
1
,
gj
,
gi
]
=
gy
*
grid_h
-
gj
target
[
idx
,
2
,
gj
,
gi
]
=
np
.
log
(
gw
*
w
/
self
.
anchors
[
mask_i
][
0
])
target
[
idx
,
3
,
gj
,
gi
]
=
np
.
log
(
gh
*
h
/
self
.
anchors
[
mask_i
][
1
])
target
[
idx
,
4
,
gj
,
gi
]
=
2.0
-
gw
*
gh
# objectness record gt_score
target
[
idx
,
5
,
gj
,
gi
]
=
score
# classification
target
[
idx
,
6
+
cls
,
gj
,
gi
]
=
1.
data_list
.
append
(
target
)
batch_data
[
data_id
]
=
tuple
(
data_list
)
return
batch_data
paddlex/cv/transforms/det_transforms.py
浏览文件 @
5b18edf5
...
@@ -1238,7 +1238,7 @@ class ArrangeYOLOv3(DetTransform):
...
@@ -1238,7 +1238,7 @@ class ArrangeYOLOv3(DetTransform):
return
outputs
return
outputs
class
RandomShape
(
DetTransform
):
class
Batch
RandomShape
(
DetTransform
):
"""调整图像大小(resize)。
"""调整图像大小(resize)。
对batch数据中的每张图像全部resize到random_shapes中任意一个大小。
对batch数据中的每张图像全部resize到random_shapes中任意一个大小。
...
@@ -1303,127 +1303,6 @@ class RandomShape(DetTransform):
...
@@ -1303,127 +1303,6 @@ class RandomShape(DetTransform):
return
batch_data
return
batch_data
class
GenerateYoloTarget
(
DetTransform
):
"""生成YOLOv3的ground truth(真实标注框)在不同特征层的位置转换信息。
该transform只在YOLOv3计算细粒度loss时使用。
Args:
anchors (list|tuple): anchor框的宽度和高度。
anchor_masks (list|tuple): 在计算损失时,使用anchor的mask索引。
num_classes (int): 类别数。默认为80。
iou_thresh (float): iou阈值,当anchor和真实标注框的iou大于该阈值时,计入target。默认为1.0。
"""
def
__init__
(
self
,
anchors
,
anchor_masks
,
num_classes
=
80
,
iou_thresh
=
1.
):
super
(
GenerateYoloTarget
,
self
).
__init__
()
self
.
anchors
=
anchors
self
.
anchor_masks
=
anchor_masks
self
.
num_classes
=
num_classes
self
.
iou_thresh
=
iou_thresh
def
__call__
(
self
,
batch_data
):
"""
Args:
batch_data (list): 由与图像相关的各种信息组成的batch数据。
Returns:
list: 由与图像相关的各种信息组成的batch数据。
其中,每个数据新添加的字段为:
- target0 (np.ndarray): YOLOv3的ground truth在特征层0的位置转换信息,
形状为(特征层0的anchor数量, 6+类别数, 特征层0的h, 特征层0的w)。
- target1 (np.ndarray): YOLOv3的ground truth在特征层1的位置转换信息,
形状为(特征层1的anchor数量, 6+类别数, 特征层1的h, 特征层1的w)。
- ...
-targetn (np.ndarray): YOLOv3的ground truth在特征层n的位置转换信息,
形状为(特征层n的anchor数量, 6+类别数, 特征层n的h, 特征层n的w)。
n的是大小由anchor_masks的长度决定。
"""
im
=
batch_data
[
0
][
0
]
h
=
im
.
shape
[
1
]
w
=
im
.
shape
[
2
]
an_hw
=
np
.
array
(
self
.
anchors
)
/
np
.
array
([[
w
,
h
]])
for
data_id
,
data
in
enumerate
(
batch_data
):
gt_bbox
=
data
[
1
]
gt_class
=
data
[
2
]
gt_score
=
data
[
3
]
im_shape
=
data
[
4
]
origin_h
=
float
(
im_shape
[
0
])
origin_w
=
float
(
im_shape
[
1
])
data_list
=
list
(
data
)
for
i
,
mask
in
enumerate
(
self
.
anchor_masks
):
downsample_ratio
=
32
//
pow
(
2
,
i
)
grid_h
=
int
(
h
/
downsample_ratio
)
grid_w
=
int
(
w
/
downsample_ratio
)
target
=
np
.
zeros
(
(
len
(
mask
),
6
+
self
.
num_classes
,
grid_h
,
grid_w
),
dtype
=
np
.
float32
)
for
b
in
range
(
gt_bbox
.
shape
[
0
]):
gx
=
gt_bbox
[
b
,
0
]
/
float
(
origin_w
)
gy
=
gt_bbox
[
b
,
1
]
/
float
(
origin_h
)
gw
=
gt_bbox
[
b
,
2
]
/
float
(
origin_w
)
gh
=
gt_bbox
[
b
,
3
]
/
float
(
origin_h
)
cls
=
gt_class
[
b
]
score
=
gt_score
[
b
]
if
gw
<=
0.
or
gh
<=
0.
or
score
<=
0.
:
continue
# find best match anchor index
best_iou
=
0.
best_idx
=
-
1
for
an_idx
in
range
(
an_hw
.
shape
[
0
]):
iou
=
jaccard_overlap
(
[
0.
,
0.
,
gw
,
gh
],
[
0.
,
0.
,
an_hw
[
an_idx
,
0
],
an_hw
[
an_idx
,
1
]])
if
iou
>
best_iou
:
best_iou
=
iou
best_idx
=
an_idx
gi
=
int
(
gx
*
grid_w
)
gj
=
int
(
gy
*
grid_h
)
# gtbox should be regresed in this layes if best match
# anchor index in anchor mask of this layer
if
best_idx
in
mask
:
best_n
=
mask
.
index
(
best_idx
)
# x, y, w, h, scale
target
[
best_n
,
0
,
gj
,
gi
]
=
gx
*
grid_w
-
gi
target
[
best_n
,
1
,
gj
,
gi
]
=
gy
*
grid_h
-
gj
target
[
best_n
,
2
,
gj
,
gi
]
=
np
.
log
(
gw
*
w
/
self
.
anchors
[
best_idx
][
0
])
target
[
best_n
,
3
,
gj
,
gi
]
=
np
.
log
(
gh
*
h
/
self
.
anchors
[
best_idx
][
1
])
target
[
best_n
,
4
,
gj
,
gi
]
=
2.0
-
gw
*
gh
# objectness record gt_score
target
[
best_n
,
5
,
gj
,
gi
]
=
score
# classification
target
[
best_n
,
6
+
cls
,
gj
,
gi
]
=
1.
# For non-matched anchors, calculate the target if the iou
# between anchor and gt is larger than iou_thresh
if
self
.
iou_thresh
<
1
:
for
idx
,
mask_i
in
enumerate
(
mask
):
if
mask_i
==
best_idx
:
continue
iou
=
jaccard_overlap
(
[
0.
,
0.
,
gw
,
gh
],
[
0.
,
0.
,
an_hw
[
mask_i
,
0
],
an_hw
[
mask_i
,
1
]])
if
iou
>
self
.
iou_thresh
:
# x, y, w, h, scale
target
[
idx
,
0
,
gj
,
gi
]
=
gx
*
grid_w
-
gi
target
[
idx
,
1
,
gj
,
gi
]
=
gy
*
grid_h
-
gj
target
[
idx
,
2
,
gj
,
gi
]
=
np
.
log
(
gw
*
w
/
self
.
anchors
[
mask_i
][
0
])
target
[
idx
,
3
,
gj
,
gi
]
=
np
.
log
(
gh
*
h
/
self
.
anchors
[
mask_i
][
1
])
target
[
idx
,
4
,
gj
,
gi
]
=
2.0
-
gw
*
gh
# objectness record gt_score
target
[
idx
,
5
,
gj
,
gi
]
=
score
# classification
target
[
idx
,
6
+
cls
,
gj
,
gi
]
=
1.
data_list
.
append
(
target
)
batch_data
[
data_id
]
=
tuple
(
data_list
)
return
batch_data
class
ComposedRCNNTransforms
(
Compose
):
class
ComposedRCNNTransforms
(
Compose
):
""" RCNN模型(faster-rcnn/mask-rcnn)图像处理流程,具体如下,
""" RCNN模型(faster-rcnn/mask-rcnn)图像处理流程,具体如下,
训练阶段:
训练阶段:
...
@@ -1489,6 +1368,8 @@ class ComposedYOLOTransforms(Compose):
...
@@ -1489,6 +1368,8 @@ class ComposedYOLOTransforms(Compose):
mixup_epoch(int): 模型训练过程中,前mixup_epoch会使用mixup策略
mixup_epoch(int): 模型训练过程中,前mixup_epoch会使用mixup策略
mean(list): 图像均值
mean(list): 图像均值
std(list): 图像方差
std(list): 图像方差
random_shapes (list): resize大小选择列表。
默认为[320, 352, 384, 416, 448, 480, 512, 544, 576, 608]。
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -1496,7 +1377,10 @@ class ComposedYOLOTransforms(Compose):
...
@@ -1496,7 +1377,10 @@ class ComposedYOLOTransforms(Compose):
shape
=
[
608
,
608
],
shape
=
[
608
,
608
],
mixup_epoch
=
250
,
mixup_epoch
=
250
,
mean
=
[
0.485
,
0.456
,
0.406
],
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
]):
std
=
[
0.229
,
0.224
,
0.225
],
random_shapes
=
[
320
,
352
,
384
,
416
,
448
,
480
,
512
,
544
,
576
,
608
]):
width
=
shape
width
=
shape
if
isinstance
(
shape
,
list
):
if
isinstance
(
shape
,
list
):
if
shape
[
0
]
!=
shape
[
1
]:
if
shape
[
0
]
!=
shape
[
1
]:
...
@@ -1517,6 +1401,9 @@ class ComposedYOLOTransforms(Compose):
...
@@ -1517,6 +1401,9 @@ class ComposedYOLOTransforms(Compose):
interp
=
'RANDOM'
),
RandomHorizontalFlip
(),
Normalize
(
interp
=
'RANDOM'
),
RandomHorizontalFlip
(),
Normalize
(
mean
=
mean
,
std
=
std
)
mean
=
mean
,
std
=
std
)
]
]
batch_transforms
=
[
BatchRandomShape
(
random_shapes
=
random_shapes
)
]
else
:
else
:
# 验证/预测时的transforms
# 验证/预测时的transforms
transforms
=
[
transforms
=
[
...
@@ -1524,4 +1411,4 @@ class ComposedYOLOTransforms(Compose):
...
@@ -1524,4 +1411,4 @@ class ComposedYOLOTransforms(Compose):
target_size
=
width
,
interp
=
'CUBIC'
),
Normalize
(
target_size
=
width
,
interp
=
'CUBIC'
),
Normalize
(
mean
=
mean
,
std
=
std
)
mean
=
mean
,
std
=
std
)
]
]
super
(
ComposedYOLOTransforms
,
self
).
__init__
(
transforms
)
super
(
ComposedYOLOTransforms
,
self
).
__init__
(
transforms
,
batch_transforms
)
\ No newline at end of file
\ No newline at end of file
paddlex/cv/transforms/ops.py
浏览文件 @
5b18edf5
...
@@ -18,7 +18,8 @@ import numpy as np
...
@@ -18,7 +18,8 @@ import numpy as np
from
PIL
import
Image
,
ImageEnhance
from
PIL
import
Image
,
ImageEnhance
def
normalize
(
im
,
mean
,
std
):
def
normalize
(
im
,
mean
,
std
,
is_scale
=
True
):
if
is_scale
:
im
=
im
/
255.0
im
=
im
/
255.0
im
-=
mean
im
-=
mean
im
/=
std
im
/=
std
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录