Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
1f2180b6
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1f2180b6
编写于
9月 25, 2020
作者:
H
haoyuying
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
revise yolov3 tansformer and basemodel
上级
9a210a9e
变更
8
展开全部
显示空白变更内容
内联
并排
Showing
8 changed file
with
677 addition
and
651 deletion
+677
-651
demo/detection/yolov3_darknet53_pascalvoc/predict.py
demo/detection/yolov3_darknet53_pascalvoc/predict.py
+1
-1
demo/detection/yolov3_darknet53_pascalvoc/train.py
demo/detection/yolov3_darknet53_pascalvoc/train.py
+11
-8
hub_module/modules/image/object_detection/yolov3_darknet53_pascalvoc/module.py
...age/object_detection/yolov3_darknet53_pascalvoc/module.py
+21
-73
paddlehub/datasets/pascalvoc.py
paddlehub/datasets/pascalvoc.py
+4
-8
paddlehub/module/cv_module.py
paddlehub/module/cv_module.py
+72
-54
paddlehub/process/detect_transforms.py
paddlehub/process/detect_transforms.py
+427
-0
paddlehub/process/functional.py
paddlehub/process/functional.py
+141
-2
paddlehub/process/transforms.py
paddlehub/process/transforms.py
+0
-505
未找到文件。
demo/detection/yolov3_darknet53_pascalvoc/predict.py
浏览文件 @
1f2180b6
...
...
@@ -6,4 +6,4 @@ if __name__ == '__main__':
paddle
.
disable_static
()
model
=
model
=
hub
.
Module
(
name
=
'yolov3_darknet53_pascalvoc'
,
is_train
=
False
)
model
.
eval
()
model
.
predict
(
imgpath
=
"
/PATH/TO/IMAGE
"
,
filelist
=
"/PATH/TO/JSON/FILE"
)
model
.
predict
(
imgpath
=
"
4026.jpeg
"
,
filelist
=
"/PATH/TO/JSON/FILE"
)
demo/detection/yolov3_darknet53_pascalvoc/train.py
浏览文件 @
1f2180b6
...
...
@@ -3,18 +3,21 @@ import paddlehub as hub
import
paddle.nn
as
nn
from
paddlehub.finetune.trainer
import
Trainer
from
paddlehub.datasets.pascalvoc
import
DetectionData
from
paddlehub.process.
transforms
import
DetectTrainReader
,
DetectTestReader
from
paddlehub.process.
detect_transforms
import
Compose
,
RandomDistort
,
RandomExpand
,
RandomCrop
,
RandomFlip
,
Normalize
,
Resize
,
ShuffleBox
if
__name__
==
"__main__"
:
place
=
paddle
.
CUDAPlace
(
0
)
paddle
.
disable_static
()
is_train
=
True
if
is_train
:
transform
=
DetectTrainReader
()
transform
=
Compose
([
RandomDistort
(),
RandomExpand
(
fill
=
[
0.485
,
0.456
,
0.406
]),
RandomCrop
(),
Resize
(
target_size
=
416
),
RandomFlip
(),
ShuffleBox
(),
Normalize
(
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
])
])
train_reader
=
DetectionData
(
transform
)
else
:
transform
=
DetectTestReader
()
test_reader
=
DetectionData
(
transform
)
model
=
hub
.
Module
(
name
=
'yolov3_darknet53_pascalvoc'
)
model
.
train
()
optimizer
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
0.0001
,
parameters
=
model
.
parameters
())
...
...
hub_module/modules/image/object_detection/yolov3_darknet53_pascalvoc/module.py
浏览文件 @
1f2180b6
...
...
@@ -7,7 +7,7 @@ from paddle.nn.initializer import Normal, Constant
from
paddle.regularizer
import
L2Decay
from
pycocotools.coco
import
COCO
from
paddlehub.module.cv_module
import
Yolov3Module
from
paddlehub.process.
transforms
import
DetectTrainReader
,
DetectTestReader
from
paddlehub.process.
detect_transforms
import
Compose
,
RandomDistort
,
RandomExpand
,
RandomCrop
,
Resize
,
RandomFlip
,
ShuffleBox
,
Normalize
from
paddlehub.module.module
import
moduleinfo
...
...
@@ -286,12 +286,24 @@ class YOLOv3(nn.Layer):
self
.
set_dict
(
model_dict
)
print
(
"load pretrained checkpoint success"
)
def
transform
(
self
,
img
:
paddle
.
Tensor
,
size
:
int
):
def
transform
(
self
,
img
):
if
self
.
is_train
:
transforms
=
DetectTrainReader
()
transform
=
Compose
([
RandomDistort
(),
RandomExpand
(
fill
=
[
0.485
,
0.456
,
0.406
]),
RandomCrop
(),
Resize
(
target_size
=
416
),
RandomFlip
(),
ShuffleBox
(),
Normalize
(
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
])
])
else
:
transforms
=
DetectTestReader
()
return
transforms
(
img
,
size
)
transform
=
Compose
([
Resize
(
target_size
=
416
,
interp
=
'CUBIC'
),
Normalize
(
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
])
])
return
transform
(
img
)
def
get_label_infos
(
self
,
file_list
:
str
):
self
.
COCO
=
COCO
(
file_list
)
...
...
@@ -301,23 +313,8 @@ class YOLOv3(nn.Layer):
label_names
.
append
(
category
[
'name'
])
return
label_names
def
forward
(
self
,
inputs
:
paddle
.
Tensor
,
gtbox
:
paddle
.
Tensor
=
None
,
gtlabel
:
paddle
.
Tensor
=
None
,
gtscore
:
paddle
.
Tensor
=
None
,
im_shape
:
paddle
.
Tensor
=
None
):
self
.
gtbox
=
gtbox
self
.
gtlabel
=
gtlabel
self
.
gtscore
=
gtscore
self
.
im_shape
=
im_shape
self
.
outputs
=
[]
self
.
boxes
=
[]
self
.
scores
=
[]
self
.
losses
=
[]
self
.
pred
=
[]
self
.
downsample
=
32
def
forward
(
self
,
inputs
:
paddle
.
Tensor
):
outputs
=
[]
blocks
=
self
.
block
(
inputs
)
route
=
None
for
i
,
block
in
enumerate
(
blocks
):
...
...
@@ -325,58 +322,9 @@ class YOLOv3(nn.Layer):
block
=
paddle
.
concat
([
route
,
block
],
axis
=
1
)
route
,
tip
=
self
.
yolo_blocks
[
i
](
block
)
block_out
=
self
.
block_outputs
[
i
](
tip
)
self
.
outputs
.
append
(
block_out
)
outputs
.
append
(
block_out
)
if
i
<
2
:
route
=
self
.
route_blocks_2
[
i
](
route
)
route
=
self
.
upsample
(
route
)
for
i
,
out
in
enumerate
(
self
.
outputs
):
anchor_mask
=
self
.
anchor_masks
[
i
]
if
self
.
is_train
:
loss
=
F
.
yolov3_loss
(
x
=
out
,
gt_box
=
self
.
gtbox
,
gt_label
=
self
.
gtlabel
,
gt_score
=
self
.
gtscore
,
anchors
=
self
.
anchors
,
anchor_mask
=
anchor_mask
,
class_num
=
self
.
class_num
,
ignore_thresh
=
self
.
ignore_thresh
,
downsample_ratio
=
self
.
downsample
,
use_label_smooth
=
False
)
else
:
loss
=
paddle
.
to_tensor
(
0.0
)
self
.
losses
.
append
(
paddle
.
reduce_mean
(
loss
))
mask_anchors
=
[]
for
m
in
anchor_mask
:
mask_anchors
.
append
((
self
.
anchors
[
2
*
m
]))
mask_anchors
.
append
(
self
.
anchors
[
2
*
m
+
1
])
boxes
,
scores
=
F
.
yolo_box
(
x
=
out
,
img_size
=
self
.
im_shape
,
anchors
=
mask_anchors
,
class_num
=
self
.
class_num
,
conf_thresh
=
self
.
valid_thresh
,
downsample_ratio
=
self
.
downsample
,
name
=
"yolo_box"
+
str
(
i
))
self
.
boxes
.
append
(
boxes
)
self
.
scores
.
append
(
paddle
.
transpose
(
scores
,
perm
=
[
0
,
2
,
1
]))
self
.
downsample
//=
2
for
i
in
range
(
self
.
boxes
[
0
].
shape
[
0
]):
yolo_boxes
=
paddle
.
unsqueeze
(
paddle
.
concat
([
self
.
boxes
[
0
][
i
],
self
.
boxes
[
1
][
i
],
self
.
boxes
[
2
][
i
]],
axis
=
0
),
0
)
yolo_scores
=
paddle
.
unsqueeze
(
paddle
.
concat
([
self
.
scores
[
0
][
i
],
self
.
scores
[
1
][
i
],
self
.
scores
[
2
][
i
]],
axis
=
1
),
0
)
pred
=
F
.
multiclass_nms
(
bboxes
=
yolo_boxes
,
scores
=
yolo_scores
,
score_threshold
=
self
.
valid_thresh
,
nms_top_k
=
self
.
nms_topk
,
keep_top_k
=
self
.
nms_posk
,
nms_threshold
=
self
.
nms_thresh
,
background_label
=-
1
)
self
.
pred
.
append
(
pred
)
return
sum
(
self
.
losses
),
self
.
pred
return
outputs
paddlehub/datasets/pascalvoc.py
浏览文件 @
1f2180b6
...
...
@@ -61,14 +61,10 @@ class DetectionData(paddle.io.Dataset):
self
.
data
=
parse_images
()
def
__getitem__
(
self
,
idx
:
int
):
if
self
.
mode
==
"train"
:
img
=
self
.
data
[
idx
]
out_img
,
gt_boxes
,
gt_labels
,
gt_scores
=
self
.
transform
(
img
,
416
)
im
,
data
=
self
.
transform
(
img
)
out_img
,
gt_boxes
,
gt_labels
,
gt_scores
=
im
,
data
[
'gt_boxes'
],
data
[
'gt_labels'
],
data
[
'gt_scores'
]
return
out_img
,
gt_boxes
,
gt_labels
,
gt_scores
elif
self
.
mode
==
"test"
:
img
=
self
.
data
[
idx
]
out_img
,
id
,
(
h
,
w
)
=
self
.
transform
(
img
)
return
out_img
,
id
,
(
h
,
w
)
def
__len__
(
self
):
return
len
(
self
.
data
)
paddlehub/module/cv_module.py
浏览文件 @
1f2180b6
...
...
@@ -26,7 +26,8 @@ from PIL import Image
from
paddlehub.module.module
import
serving
,
RunModule
from
paddlehub.utils.utils
import
base64_to_cv2
from
paddlehub.process.transforms
import
ConvertColorSpace
,
ColorPostprocess
,
Resize
,
BoxTool
from
paddlehub.process.transforms
import
ConvertColorSpace
,
ColorPostprocess
,
Resize
from
paddlehub.process.functional
import
subtract_imagenet_mean_batch
,
gram_matrix
,
draw_boxes_on_image
,
img_shape
class
ImageServing
(
object
):
...
...
@@ -218,38 +219,30 @@ class Yolov3Module(RunModule, ImageServing):
Returns:
results(dict) : The model outputs, such as metrics.
'''
ious
=
[]
boxtool
=
BoxTool
()
img
=
batch
[
0
].
astype
(
'float32'
)
B
,
C
,
W
,
H
=
img
.
shape
im_shape
=
np
.
array
([(
W
,
H
)]
*
B
).
astype
(
'int32'
)
im_shape
=
paddle
.
to_tensor
(
im_shape
)
gt_box
=
batch
[
1
].
astype
(
'float32'
)
gt_label
=
batch
[
2
].
astype
(
'int32'
)
gt_score
=
batch
[
3
].
astype
(
"float32"
)
loss
,
pred
=
self
(
img
,
gt_box
,
gt_label
,
gt_score
,
im_shape
)
for
i
in
range
(
len
(
pred
)):
bboxes
=
pred
[
i
].
numpy
()
labels
=
bboxes
[:,
0
].
astype
(
'int32'
)
scores
=
bboxes
[:,
1
].
astype
(
'float32'
)
boxes
=
bboxes
[:,
2
:].
astype
(
'float32'
)
iou
=
[]
for
j
,
(
box
,
score
,
label
)
in
enumerate
(
zip
(
boxes
,
scores
,
labels
)):
x1
,
y1
,
x2
,
y2
=
box
w
=
x2
-
x1
+
1
h
=
y2
-
y1
+
1
bbox
=
[
x1
,
y1
,
w
,
h
]
bbox
=
np
.
expand_dims
(
boxtool
.
coco_anno_box_to_center_relative
(
bbox
,
H
,
W
),
0
)
gt
=
gt_box
[
i
].
numpy
()
iou
.
append
(
max
(
boxtool
.
box_iou_xywh
(
bbox
,
gt
)))
ious
.
append
(
max
(
iou
))
ious
=
paddle
.
to_tensor
(
np
.
array
(
ious
))
return
{
'loss'
:
loss
,
'metrics'
:
{
'iou'
:
ious
}}
gtbox
=
batch
[
1
].
astype
(
'float32'
)
gtlabel
=
batch
[
2
].
astype
(
'int32'
)
gtscore
=
batch
[
3
].
astype
(
"float32"
)
losses
=
[]
outputs
=
self
(
img
)
self
.
downsample
=
32
for
i
,
out
in
enumerate
(
outputs
):
anchor_mask
=
self
.
anchor_masks
[
i
]
loss
=
F
.
yolov3_loss
(
x
=
out
,
gt_box
=
gtbox
,
gt_label
=
gtlabel
,
gt_score
=
gtscore
,
anchors
=
self
.
anchors
,
anchor_mask
=
anchor_mask
,
class_num
=
self
.
class_num
,
ignore_thresh
=
self
.
ignore_thresh
,
downsample_ratio
=
32
,
use_label_smooth
=
False
)
losses
.
append
(
paddle
.
reduce_mean
(
loss
))
self
.
downsample
//=
2
return
{
'loss'
:
sum
(
losses
)}
def
predict
(
self
,
imgpath
:
str
,
filelist
:
str
,
visualization
:
bool
=
True
,
save_path
:
str
=
'result'
):
'''
...
...
@@ -266,28 +259,53 @@ class Yolov3Module(RunModule, ImageServing):
scores(np.ndarray): Predict score.
labels(np.ndarray): Predict labels.
'''
boxtool
=
BoxTool
()
img
=
{}
img
[
'image'
]
=
imgpath
img
[
'id'
]
=
0
im
,
im_id
,
im_shape
=
self
.
transform
(
img
,
416
)
boxes
=
[]
scores
=
[]
self
.
downsample
=
32
im
=
self
.
transform
(
imgpath
)
h
,
w
,
c
=
img_shape
(
imgpath
)
im_shape
=
paddle
.
to_tensor
(
np
.
array
([[
h
,
w
]]).
astype
(
'int32'
))
label_names
=
self
.
get_label_infos
(
filelist
)
img_data
=
np
.
array
([
im
]).
astype
(
'float32'
)
img_data
=
paddle
.
to_tensor
(
img_data
)
im_shape
=
np
.
array
([
im_shape
]).
astype
(
'int32'
)
im_shape
=
paddle
.
to_tensor
(
im_shape
)
output
,
pred
=
self
(
img_data
,
None
,
None
,
None
,
im_shape
)
for
i
in
range
(
len
(
pred
)):
bboxes
=
pred
[
i
].
numpy
()
img_data
=
paddle
.
to_tensor
(
np
.
array
([
im
]).
astype
(
'float32'
))
outputs
=
self
(
img_data
)
for
i
,
out
in
enumerate
(
outputs
):
anchor_mask
=
self
.
anchor_masks
[
i
]
mask_anchors
=
[]
for
m
in
anchor_mask
:
mask_anchors
.
append
((
self
.
anchors
[
2
*
m
]))
mask_anchors
.
append
(
self
.
anchors
[
2
*
m
+
1
])
box
,
score
=
F
.
yolo_box
(
x
=
out
,
img_size
=
im_shape
,
anchors
=
mask_anchors
,
class_num
=
self
.
class_num
,
conf_thresh
=
self
.
valid_thresh
,
downsample_ratio
=
self
.
downsample
,
name
=
"yolo_box"
+
str
(
i
))
boxes
.
append
(
box
)
scores
.
append
(
paddle
.
transpose
(
score
,
perm
=
[
0
,
2
,
1
]))
self
.
downsample
//=
2
yolo_boxes
=
paddle
.
concat
(
boxes
,
axis
=
1
)
yolo_scores
=
paddle
.
concat
(
scores
,
axis
=
2
)
pred
=
F
.
multiclass_nms
(
bboxes
=
yolo_boxes
,
scores
=
yolo_scores
,
score_threshold
=
self
.
valid_thresh
,
nms_top_k
=
self
.
nms_topk
,
keep_top_k
=
self
.
nms_posk
,
nms_threshold
=
self
.
nms_thresh
,
background_label
=-
1
)
bboxes
=
pred
.
numpy
()
labels
=
bboxes
[:,
0
].
astype
(
'int32'
)
scores
=
bboxes
[:,
1
].
astype
(
'float32'
)
boxes
=
bboxes
[:,
2
:].
astype
(
'float32'
)
if
visualization
:
if
not
os
.
path
.
exists
(
save_path
):
os
.
mkdir
(
save_path
)
boxtool
.
draw_boxes_on_image
(
imgpath
,
boxes
,
scores
,
labels
,
label_names
,
0.5
)
draw_boxes_on_image
(
imgpath
,
boxes
,
scores
,
labels
,
label_names
,
0.5
)
return
boxes
,
scores
,
labels
paddlehub/process/detect_transforms.py
0 → 100644
浏览文件 @
1f2180b6
import
copy
import
os
import
random
from
typing
import
Callable
import
cv2
import
numpy
as
np
import
matplotlib
import
PIL
from
PIL
import
Image
,
ImageEnhance
from
matplotlib
import
pyplot
as
plt
from
paddlehub.process.functional
import
*
matplotlib
.
use
(
'Agg'
)
class
DetectCatagory
:
"""Load label name, id and map from detection dataset.
Args:
attrbox(Callable): Method to get detection attributes of images.
data_dir(str): Image dataset path.
Returns:
label_names(List(str)): The dataset label names.
label_ids(List(int)): The dataset label ids.
category_to_id_map(dict): Mapping relations of category and id for images.
"""
def
__init__
(
self
,
attrbox
:
Callable
,
data_dir
:
str
):
self
.
attrbox
=
attrbox
self
.
img_dir
=
data_dir
def
__call__
(
self
):
self
.
categories
=
self
.
attrbox
.
loadCats
(
self
.
attrbox
.
getCatIds
())
self
.
num_category
=
len
(
self
.
categories
)
label_names
=
[]
label_ids
=
[]
for
category
in
self
.
categories
:
label_names
.
append
(
category
[
'name'
])
label_ids
.
append
(
int
(
category
[
'id'
]))
category_to_id_map
=
{
v
:
i
for
i
,
v
in
enumerate
(
label_ids
)}
return
label_names
,
label_ids
,
category_to_id_map
class
ParseImages
:
"""Prepare images for detection.
Args:
attrbox(Callable): Method to get detection attributes of images.
is_train(bool): Select the mode for train or test.
data_dir(str): Image dataset path.
category_to_id_map(dict): Mapping relations of category and id for images.
Returns:
imgs(dict): The input for detection model, it is a dict.
"""
def
__init__
(
self
,
attrbox
:
Callable
,
data_dir
:
str
,
category_to_id_map
:
dict
):
self
.
attrbox
=
attrbox
self
.
img_dir
=
data_dir
self
.
category_to_id_map
=
category_to_id_map
self
.
parse_gt_annotations
=
GTAnotations
(
self
.
attrbox
,
self
.
category_to_id_map
)
def
__call__
(
self
):
image_ids
=
self
.
attrbox
.
getImgIds
()
image_ids
.
sort
()
imgs
=
copy
.
deepcopy
(
self
.
attrbox
.
loadImgs
(
image_ids
))
for
img
in
imgs
:
img
[
'image'
]
=
os
.
path
.
join
(
self
.
img_dir
,
img
[
'file_name'
])
assert
os
.
path
.
exists
(
img
[
'image'
]),
"image {} not found."
.
format
(
img
[
'image'
])
box_num
=
50
img
[
'gt_boxes'
]
=
np
.
zeros
((
box_num
,
4
),
dtype
=
np
.
float32
)
img
[
'gt_labels'
]
=
np
.
zeros
((
box_num
),
dtype
=
np
.
int32
)
img
=
self
.
parse_gt_annotations
(
img
)
return
imgs
class
GTAnotations
:
"""Set gt boxes and gt labels for train.
Args:
attrbox(Callable): Method for get detection attributes for images.
category_to_id_map(dict): Mapping relations of category and id for images.
img(dict): Input for detection model.
Returns:
img(dict): Set specific value on the attributes of 'gt boxes' and 'gt labels' for input.
"""
def
__init__
(
self
,
attrbox
:
Callable
,
category_to_id_map
:
dict
):
self
.
attrbox
=
attrbox
self
.
category_to_id_map
=
category_to_id_map
def
__call__
(
self
,
img
:
dict
):
img_height
=
img
[
'height'
]
img_width
=
img
[
'width'
]
anno
=
self
.
attrbox
.
loadAnns
(
self
.
attrbox
.
getAnnIds
(
imgIds
=
img
[
'id'
],
iscrowd
=
None
))
gt_index
=
0
for
target
in
anno
:
if
target
[
'area'
]
<
-
1
:
continue
if
'ignore'
in
target
and
target
[
'ignore'
]:
continue
box
=
coco_anno_box_to_center_relative
(
target
[
'bbox'
],
img_height
,
img_width
)
if
box
[
2
]
<=
0
and
box
[
3
]
<=
0
:
continue
img
[
'gt_boxes'
][
gt_index
]
=
box
img
[
'gt_labels'
][
gt_index
]
=
\
self
.
category_to_id_map
[
target
[
'category_id'
]]
gt_index
+=
1
if
gt_index
>=
50
:
break
return
img
class
RandomDistort
:
""" Distort the input image randomly.
Args:
lower(float): The lower bound value for enhancement, default is 0.5.
upper(float): The upper bound value for enhancement, default is 1.5.
Returns:
img(np.ndarray): Distorted image.
data(dict): Image info and label info.
"""
def
__init__
(
self
,
lower
:
float
=
0.5
,
upper
:
float
=
1.5
):
self
.
lower
=
lower
self
.
upper
=
upper
def
random_brightness
(
self
,
img
:
PIL
.
Image
):
e
=
np
.
random
.
uniform
(
self
.
lower
,
self
.
upper
)
return
ImageEnhance
.
Brightness
(
img
).
enhance
(
e
)
def
random_contrast
(
self
,
img
:
PIL
.
Image
):
e
=
np
.
random
.
uniform
(
self
.
lower
,
self
.
upper
)
return
ImageEnhance
.
Contrast
(
img
).
enhance
(
e
)
def
random_color
(
self
,
img
:
PIL
.
Image
):
e
=
np
.
random
.
uniform
(
self
.
lower
,
self
.
upper
)
return
ImageEnhance
.
Color
(
img
).
enhance
(
e
)
def
__call__
(
self
,
img
:
np
.
ndarray
,
data
:
dict
):
ops
=
[
self
.
random_brightness
,
self
.
random_contrast
,
self
.
random_color
]
np
.
random
.
shuffle
(
ops
)
img
=
Image
.
fromarray
(
img
)
img
=
ops
[
0
](
img
)
img
=
ops
[
1
](
img
)
img
=
ops
[
2
](
img
)
img
=
np
.
asarray
(
img
)
return
img
,
data
class
RandomExpand
:
"""Randomly expand images and gt boxes by random ratio. It is a data enhancement operation for model training.
Args:
max_ratio(float): Max value for expansion ratio, default is 4.
fill(list): Initialize the pixel value of the image with the input fill value, default is None.
keep_ratio(bool): Whether image keeps ratio.
thresh(float): If random ratio does not exceed the thresh, return original images and gt boxes, default is 0.5.
Return:
img(np.ndarray): Distorted image.
data(dict): Image info and label info.
"""
def
__init__
(
self
,
max_ratio
:
float
=
4.
,
fill
:
list
=
None
,
keep_ratio
:
bool
=
True
,
thresh
:
float
=
0.5
):
self
.
max_ratio
=
max_ratio
self
.
fill
=
fill
self
.
keep_ratio
=
keep_ratio
self
.
thresh
=
thresh
def
__call__
(
self
,
img
:
np
.
ndarray
,
data
:
dict
):
gtboxes
=
data
[
'gt_boxes'
]
if
random
.
random
()
>
self
.
thresh
:
return
img
,
data
if
self
.
max_ratio
<
1.0
:
return
img
,
data
h
,
w
,
c
=
img
.
shape
ratio_x
=
random
.
uniform
(
1
,
self
.
max_ratio
)
if
self
.
keep_ratio
:
ratio_y
=
ratio_x
else
:
ratio_y
=
random
.
uniform
(
1
,
self
.
max_ratio
)
oh
=
int
(
h
*
ratio_y
)
ow
=
int
(
w
*
ratio_x
)
off_x
=
random
.
randint
(
0
,
ow
-
w
)
off_y
=
random
.
randint
(
0
,
oh
-
h
)
out_img
=
np
.
zeros
((
oh
,
ow
,
c
))
if
self
.
fill
and
len
(
self
.
fill
)
==
c
:
for
i
in
range
(
c
):
out_img
[:,
:,
i
]
=
self
.
fill
[
i
]
*
255.0
out_img
[
off_y
:
off_y
+
h
,
off_x
:
off_x
+
w
,
:]
=
img
gtboxes
[:,
0
]
=
((
gtboxes
[:,
0
]
*
w
)
+
off_x
)
/
float
(
ow
)
gtboxes
[:,
1
]
=
((
gtboxes
[:,
1
]
*
h
)
+
off_y
)
/
float
(
oh
)
gtboxes
[:,
2
]
=
gtboxes
[:,
2
]
/
ratio_x
gtboxes
[:,
3
]
=
gtboxes
[:,
3
]
/
ratio_y
data
[
'gt_boxes'
]
=
gtboxes
img
=
out_img
.
astype
(
'uint8'
)
return
img
,
data
class
RandomCrop
:
"""
Random crop the input image according to constraints.
Args:
scales(list): The value of the cutting area relative to the original area, expressed in the form of
\
[min, max]. The default value is [.3, 1.].
max_ratio(float): Max ratio of the original area relative to the cutting area, default is 2.0.
constraints(list): The value of min and max iou values, default is None.
max_trial(int): The max trial for finding a valid crop area. The default value is 50.
Returns:
img(np.ndarray): Distorted image.
data(dict): Image info and label info.
"""
def
__init__
(
self
,
scales
:
list
=
[
0.3
,
1.0
],
max_ratio
:
float
=
2.0
,
constraints
:
list
=
None
,
max_trial
:
int
=
50
):
self
.
scales
=
scales
self
.
max_ratio
=
max_ratio
self
.
constraints
=
constraints
self
.
max_trial
=
max_trial
def
__call__
(
self
,
img
:
np
.
ndarray
,
data
:
dict
):
boxes
=
data
[
'gt_boxes'
]
labels
=
data
[
'gt_labels'
]
scores
=
data
[
'gt_scores'
]
if
len
(
boxes
)
==
0
:
return
img
,
data
if
not
self
.
constraints
:
self
.
constraints
=
[(
0.1
,
1.0
),
(
0.3
,
1.0
),
(
0.5
,
1.0
),
(
0.7
,
1.0
),
(
0.9
,
1.0
),
(
0.0
,
1.0
)]
img
=
Image
.
fromarray
(
img
)
w
,
h
=
img
.
size
crops
=
[(
0
,
0
,
w
,
h
)]
for
min_iou
,
max_iou
in
self
.
constraints
:
for
_
in
range
(
self
.
max_trial
):
scale
=
random
.
uniform
(
self
.
scales
[
0
],
self
.
scales
[
1
])
aspect_ratio
=
random
.
uniform
(
max
(
1
/
self
.
max_ratio
,
scale
*
scale
),
\
min
(
self
.
max_ratio
,
1
/
scale
/
scale
))
crop_h
=
int
(
h
*
scale
/
np
.
sqrt
(
aspect_ratio
))
crop_w
=
int
(
w
*
scale
*
np
.
sqrt
(
aspect_ratio
))
crop_x
=
random
.
randrange
(
w
-
crop_w
)
crop_y
=
random
.
randrange
(
h
-
crop_h
)
crop_box
=
np
.
array
([[(
crop_x
+
crop_w
/
2.0
)
/
w
,
(
crop_y
+
crop_h
/
2.0
)
/
h
,
crop_w
/
float
(
w
),
crop_h
/
float
(
h
)]])
iou
=
box_iou_xywh
(
crop_box
,
boxes
)
if
min_iou
<=
iou
.
min
()
and
max_iou
>=
iou
.
max
():
crops
.
append
((
crop_x
,
crop_y
,
crop_w
,
crop_h
))
break
while
crops
:
crop
=
crops
.
pop
(
np
.
random
.
randint
(
0
,
len
(
crops
)))
crop_boxes
,
crop_labels
,
crop_scores
,
box_num
=
box_crop
(
boxes
,
labels
,
scores
,
crop
,
(
w
,
h
))
if
box_num
<
1
:
continue
img
=
img
.
crop
((
crop
[
0
],
crop
[
1
],
crop
[
0
]
+
crop
[
2
],
crop
[
1
]
+
crop
[
3
])).
resize
(
img
.
size
,
Image
.
LANCZOS
)
img
=
np
.
asarray
(
img
)
data
[
'gt_boxes'
]
=
crop_boxes
data
[
'gt_labels'
]
=
crop_labels
data
[
'gt_scores'
]
=
crop_scores
return
img
,
data
img
=
np
.
asarray
(
img
)
data
[
'gt_boxes'
]
=
boxes
data
[
'gt_labels'
]
=
labels
data
[
'gt_scores'
]
=
scores
return
img
,
data
class
RandomFlip
:
"""Flip the images and gt boxes randomly.
Args:
thresh: Probability for random flip.
Returns:
img(np.ndarray): Distorted image.
data(dict): Image info and label info.
"""
def
__init__
(
self
,
thresh
:
float
=
0.5
):
self
.
thresh
=
thresh
def
__call__
(
self
,
img
,
data
):
gtboxes
=
data
[
'gt_boxes'
]
if
random
.
random
()
>
self
.
thresh
:
img
=
img
[:,
::
-
1
,
:]
gtboxes
[:,
0
]
=
1.0
-
gtboxes
[:,
0
]
data
[
'gt_boxes'
]
=
gtboxes
return
img
,
data
class
Compose
:
"""Preprocess the input data according to the operators.
Args:
transforms(list): Preprocessing operators.
Returns:
img(np.ndarray): Preprocessed image.
data(dict): Image info and label info, default is None.
"""
def
__init__
(
self
,
transforms
:
list
):
if
not
isinstance
(
transforms
,
list
):
raise
TypeError
(
'The transforms must be a list!'
)
if
len
(
transforms
)
<
1
:
raise
ValueError
(
'The length of transforms '
+
\
'must be equal or larger than 1!'
)
self
.
transforms
=
transforms
def
__call__
(
self
,
data
:
dict
):
if
isinstance
(
data
,
dict
):
if
isinstance
(
data
[
'image'
],
str
):
img
=
cv2
.
imread
(
data
[
'image'
])
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
)
gt_labels
=
data
[
'gt_labels'
].
copy
()
data
[
'gt_scores'
]
=
np
.
ones_like
(
gt_labels
)
for
op
in
self
.
transforms
:
img
,
data
=
op
(
img
,
data
)
img
=
img
.
transpose
((
2
,
0
,
1
))
return
img
,
data
if
isinstance
(
data
,
str
):
img
=
cv2
.
imread
(
data
)
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
)
for
op
in
self
.
transforms
:
img
,
data
=
op
(
img
,
data
)
img
=
img
.
transpose
((
2
,
0
,
1
))
return
img
class
Resize
:
"""Resize the input images.
Args:
target_size(int): Targeted input size.
interp(str): Interpolation method.
Returns:
img(np.ndarray): Preprocessed image.
data(dict): Image info and label info, default is None.
"""
def
__init__
(
self
,
target_size
:
int
=
512
,
interp
:
str
=
'RANDOM'
):
self
.
interp_dict
=
{
'NEAREST'
:
cv2
.
INTER_NEAREST
,
'LINEAR'
:
cv2
.
INTER_LINEAR
,
'CUBIC'
:
cv2
.
INTER_CUBIC
,
'AREA'
:
cv2
.
INTER_AREA
,
'LANCZOS4'
:
cv2
.
INTER_LANCZOS4
}
self
.
interp
=
interp
if
not
(
interp
==
"RANDOM"
or
interp
in
self
.
interp_dict
):
raise
ValueError
(
"interp should be one of {}"
.
format
(
self
.
interp_dict
.
keys
()))
if
isinstance
(
target_size
,
list
)
or
isinstance
(
target_size
,
tuple
):
if
len
(
target_size
)
!=
2
:
raise
TypeError
(
'when target is list or tuple, it should include 2 elements, but it is {}'
.
format
(
target_size
))
elif
not
isinstance
(
target_size
,
int
):
raise
TypeError
(
"Type of target_size is invalid. Must be Integer or List or tuple, now is {}"
.
format
(
type
(
target_size
)))
self
.
target_size
=
target_size
def
__call__
(
self
,
img
,
data
=
None
):
if
self
.
interp
==
"RANDOM"
:
interp
=
random
.
choice
(
list
(
self
.
interp_dict
.
keys
()))
else
:
interp
=
self
.
interp
img
=
resize
(
img
,
self
.
target_size
,
self
.
interp_dict
[
interp
])
if
data
is
not
None
:
return
img
,
data
else
:
return
img
class
Normalize
:
"""Normalize the input images.
Args:
mean(list): Mean values for normalization, default is [0.5, 0.5, 0.5].
std(list): Standard deviation for normalization, default is [0.5, 0.5, 0.5].
Returns:
img(np.ndarray): Preprocessed image.
data(dict): Image info and label info, default is None.
"""
def
__init__
(
self
,
mean
=
[
0.5
,
0.5
,
0.5
],
std
=
[
0.5
,
0.5
,
0.5
]):
self
.
mean
=
mean
self
.
std
=
std
if
not
(
isinstance
(
self
.
mean
,
list
)
and
isinstance
(
self
.
std
,
list
)):
raise
ValueError
(
"{}: input type is invalid."
.
format
(
self
))
from
functools
import
reduce
if
reduce
(
lambda
x
,
y
:
x
*
y
,
self
.
std
)
==
0
:
raise
ValueError
(
'{}: std is invalid!'
.
format
(
self
))
def
__call__
(
self
,
im
,
data
=
None
):
if
data
is
not
None
:
mean
=
np
.
array
(
self
.
mean
)[
np
.
newaxis
,
np
.
newaxis
,
:]
std
=
np
.
array
(
self
.
std
)[
np
.
newaxis
,
np
.
newaxis
,
:]
im
=
normalize
(
im
,
mean
,
std
)
return
im
,
data
else
:
mean
=
np
.
array
(
self
.
mean
)[
np
.
newaxis
,
np
.
newaxis
,
:]
std
=
np
.
array
(
self
.
std
)[
np
.
newaxis
,
np
.
newaxis
,
:]
im
=
normalize
(
im
,
mean
,
std
)
return
im
class
ShuffleBox
:
"""Shuffle data information."""
def
__call__
(
self
,
img
,
data
):
gt
=
np
.
concatenate
([
data
[
'gt_boxes'
],
data
[
'gt_labels'
][:,
np
.
newaxis
],
data
[
'gt_scores'
][:,
np
.
newaxis
]],
axis
=
1
)
idx
=
np
.
arange
(
gt
.
shape
[
0
])
np
.
random
.
shuffle
(
idx
)
gt
=
gt
[
idx
,
:]
data
[
'gt_boxes'
],
data
[
'gt_labels'
],
data
[
'gt_scores'
]
=
gt
[:,
:
4
],
gt
[:,
4
],
gt
[:,
5
]
return
img
,
data
paddlehub/process/functional.py
浏览文件 @
1f2180b6
...
...
@@ -11,12 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
cv2
import
paddle
import
matplotlib
import
numpy
as
np
from
PIL
import
Image
,
ImageEnhance
from
matplotlib
import
pyplot
as
plt
matplotlib
.
use
(
'Agg'
)
def
normalize
(
im
,
mean
,
std
):
...
...
@@ -118,3 +122,138 @@ def get_img_file(dir_name: str) -> list:
images
.
append
(
img_path
)
images
.
sort
()
return
images
def
coco_anno_box_to_center_relative
(
box
:
list
,
img_height
:
int
,
img_width
:
int
)
->
np
.
ndarray
:
"""
Convert COCO annotations box with format [x1, y1, w, h] to
center mode [center_x, center_y, w, h] and divide image width
and height to get relative value in range[0, 1]
"""
assert
len
(
box
)
==
4
,
"box should be a len(4) list or tuple"
x
,
y
,
w
,
h
=
box
x1
=
max
(
x
,
0
)
x2
=
min
(
x
+
w
-
1
,
img_width
-
1
)
y1
=
max
(
y
,
0
)
y2
=
min
(
y
+
h
-
1
,
img_height
-
1
)
x
=
(
x1
+
x2
)
/
2
/
img_width
y
=
(
y1
+
y2
)
/
2
/
img_height
w
=
(
x2
-
x1
)
/
img_width
h
=
(
y2
-
y1
)
/
img_height
return
np
.
array
([
x
,
y
,
w
,
h
])
def
box_crop
(
boxes
:
np
.
ndarray
,
labels
:
np
.
ndarray
,
scores
:
np
.
ndarray
,
crop
:
list
,
img_shape
:
list
):
"""Crop the boxes ,labels, scores according to the given shape"""
x
,
y
,
w
,
h
=
map
(
float
,
crop
)
im_w
,
im_h
=
map
(
float
,
img_shape
)
boxes
=
boxes
.
copy
()
boxes
[:,
0
],
boxes
[:,
2
]
=
(
boxes
[:,
0
]
-
boxes
[:,
2
]
/
2
)
*
im_w
,
(
boxes
[:,
0
]
+
boxes
[:,
2
]
/
2
)
*
im_w
boxes
[:,
1
],
boxes
[:,
3
]
=
(
boxes
[:,
1
]
-
boxes
[:,
3
]
/
2
)
*
im_h
,
(
boxes
[:,
1
]
+
boxes
[:,
3
]
/
2
)
*
im_h
crop_box
=
np
.
array
([
x
,
y
,
x
+
w
,
y
+
h
])
centers
=
(
boxes
[:,
:
2
]
+
boxes
[:,
2
:])
/
2.0
mask
=
np
.
logical_and
(
crop_box
[:
2
]
<=
centers
,
centers
<=
crop_box
[
2
:]).
all
(
axis
=
1
)
boxes
[:,
:
2
]
=
np
.
maximum
(
boxes
[:,
:
2
],
crop_box
[:
2
])
boxes
[:,
2
:]
=
np
.
minimum
(
boxes
[:,
2
:],
crop_box
[
2
:])
boxes
[:,
:
2
]
-=
crop_box
[:
2
]
boxes
[:,
2
:]
-=
crop_box
[:
2
]
mask
=
np
.
logical_and
(
mask
,
(
boxes
[:,
:
2
]
<
boxes
[:,
2
:]).
all
(
axis
=
1
))
boxes
=
boxes
*
np
.
expand_dims
(
mask
.
astype
(
'float32'
),
axis
=
1
)
labels
=
labels
*
mask
.
astype
(
'float32'
)
scores
=
scores
*
mask
.
astype
(
'float32'
)
boxes
[:,
0
],
boxes
[:,
2
]
=
(
boxes
[:,
0
]
+
boxes
[:,
2
])
/
2
/
w
,
(
boxes
[:,
2
]
-
boxes
[:,
0
])
/
w
boxes
[:,
1
],
boxes
[:,
3
]
=
(
boxes
[:,
1
]
+
boxes
[:,
3
])
/
2
/
h
,
(
boxes
[:,
3
]
-
boxes
[:,
1
])
/
h
return
boxes
,
labels
,
scores
,
mask
.
sum
()
def
box_iou_xywh
(
box1
:
np
.
ndarray
,
box2
:
np
.
ndarray
)
->
float
:
"""Calculate iou by xywh"""
assert
box1
.
shape
[
-
1
]
==
4
,
"Box1 shape[-1] should be 4."
assert
box2
.
shape
[
-
1
]
==
4
,
"Box2 shape[-1] should be 4."
b1_x1
,
b1_x2
=
box1
[:,
0
]
-
box1
[:,
2
]
/
2
,
box1
[:,
0
]
+
box1
[:,
2
]
/
2
b1_y1
,
b1_y2
=
box1
[:,
1
]
-
box1
[:,
3
]
/
2
,
box1
[:,
1
]
+
box1
[:,
3
]
/
2
b2_x1
,
b2_x2
=
box2
[:,
0
]
-
box2
[:,
2
]
/
2
,
box2
[:,
0
]
+
box2
[:,
2
]
/
2
b2_y1
,
b2_y2
=
box2
[:,
1
]
-
box2
[:,
3
]
/
2
,
box2
[:,
1
]
+
box2
[:,
3
]
/
2
inter_x1
=
np
.
maximum
(
b1_x1
,
b2_x1
)
inter_x2
=
np
.
minimum
(
b1_x2
,
b2_x2
)
inter_y1
=
np
.
maximum
(
b1_y1
,
b2_y1
)
inter_y2
=
np
.
minimum
(
b1_y2
,
b2_y2
)
inter_w
=
inter_x2
-
inter_x1
inter_h
=
inter_y2
-
inter_y1
inter_w
[
inter_w
<
0
]
=
0
inter_h
[
inter_h
<
0
]
=
0
inter_area
=
inter_w
*
inter_h
b1_area
=
(
b1_x2
-
b1_x1
)
*
(
b1_y2
-
b1_y1
)
b2_area
=
(
b2_x2
-
b2_x1
)
*
(
b2_y2
-
b2_y1
)
return
inter_area
/
(
b1_area
+
b2_area
-
inter_area
)
def
draw_boxes_on_image
(
image_path
:
str
,
boxes
:
np
.
ndarray
,
scores
:
np
.
ndarray
,
labels
:
np
.
ndarray
,
label_names
:
list
,
score_thresh
:
float
=
0.5
):
"""Draw boxes on images."""
image
=
np
.
array
(
Image
.
open
(
image_path
))
plt
.
figure
()
_
,
ax
=
plt
.
subplots
(
1
)
ax
.
imshow
(
image
)
image_name
=
image_path
.
split
(
'/'
)[
-
1
]
print
(
"Image {} detect: "
.
format
(
image_name
))
colors
=
{}
for
box
,
score
,
label
in
zip
(
boxes
,
scores
,
labels
):
if
score
<
score_thresh
:
continue
if
box
[
2
]
<=
box
[
0
]
or
box
[
3
]
<=
box
[
1
]:
continue
label
=
int
(
label
)
if
label
not
in
colors
:
colors
[
label
]
=
plt
.
get_cmap
(
'hsv'
)(
label
/
len
(
label_names
))
x1
,
y1
,
x2
,
y2
=
box
[
0
],
box
[
1
],
box
[
2
],
box
[
3
]
rect
=
plt
.
Rectangle
((
x1
,
y1
),
x2
-
x1
,
y2
-
y1
,
fill
=
False
,
linewidth
=
2.0
,
edgecolor
=
colors
[
label
])
ax
.
add_patch
(
rect
)
ax
.
text
(
x1
,
y1
,
'{} {:.4f}'
.
format
(
label_names
[
label
],
score
),
verticalalignment
=
'bottom'
,
horizontalalignment
=
'left'
,
bbox
=
{
'facecolor'
:
colors
[
label
],
'alpha'
:
0.5
,
'pad'
:
0
},
fontsize
=
8
,
color
=
'white'
)
print
(
"
\t
{:15s} at {:25} score: {:.5f}"
.
format
(
label_names
[
int
(
label
)],
str
(
list
(
map
(
int
,
list
(
box
)))),
score
))
image_name
=
image_name
.
replace
(
'jpg'
,
'png'
)
plt
.
axis
(
'off'
)
plt
.
gca
().
xaxis
.
set_major_locator
(
plt
.
NullLocator
())
plt
.
gca
().
yaxis
.
set_major_locator
(
plt
.
NullLocator
())
plt
.
savefig
(
"./output/{}"
.
format
(
image_name
),
bbox_inches
=
'tight'
,
pad_inches
=
0.0
)
print
(
"Detect result save at ./output/{}
\n
"
.
format
(
image_name
))
plt
.
cla
()
plt
.
close
(
'all'
)
def
img_shape
(
img_path
:
str
):
"""Get image shape."""
im
=
cv2
.
imread
(
img_path
)
im
=
cv2
.
cvtColor
(
im
,
cv2
.
COLOR_BGR2RGB
)
h
,
w
,
c
=
im
.
shape
return
h
,
w
,
c
paddlehub/process/transforms.py
浏览文件 @
1f2180b6
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录