Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
86c06294
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
1 年多 前同步成功
通知
283
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
86c06294
编写于
10月 14, 2020
作者:
W
wuzewu
提交者:
GitHub
10月 14, 2020
浏览文件
操作
浏览文件
下载
差异文件
add detection module
上级
cbae1549
722c2d58
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
1145 addition
and
15 deletion
+1145
-15
demo/detection/yolov3_darknet53_pascalvoc/4026.jpeg
demo/detection/yolov3_darknet53_pascalvoc/4026.jpeg
+0
-0
demo/detection/yolov3_darknet53_pascalvoc/predict.py
demo/detection/yolov3_darknet53_pascalvoc/predict.py
+9
-0
demo/detection/yolov3_darknet53_pascalvoc/train.py
demo/detection/yolov3_darknet53_pascalvoc/train.py
+24
-0
hub_module/modules/image/object_detection/yolov3_darknet53_pascalvoc/module.py
...age/object_detection/yolov3_darknet53_pascalvoc/module.py
+318
-0
paddlehub/datasets/pascalvoc.py
paddlehub/datasets/pascalvoc.py
+190
-0
paddlehub/module/cv_module.py
paddlehub/module/cv_module.py
+130
-12
paddlehub/process/detect_transforms.py
paddlehub/process/detect_transforms.py
+342
-0
paddlehub/process/functional.py
paddlehub/process/functional.py
+128
-1
paddlehub/process/transforms.py
paddlehub/process/transforms.py
+4
-2
未找到文件。
demo/detection/yolov3_darknet53_pascalvoc/4026.jpeg
0 → 100644
浏览文件 @
86c06294
83.0 KB
demo/detection/yolov3_darknet53_pascalvoc/predict.py
0 → 100644
浏览文件 @
86c06294
import
paddle
import
paddlehub
as
hub
if
__name__
==
'__main__'
:
place
=
paddle
.
CUDAPlace
(
0
)
paddle
.
disable_static
()
model
=
hub
.
Module
(
name
=
'yolov3_darknet53_pascalvoc'
,
is_train
=
False
)
model
.
eval
()
model
.
predict
(
imgpath
=
"4026.jpeg"
,
filelist
=
"/PATH/TO/JSON/FILE"
)
demo/detection/yolov3_darknet53_pascalvoc/train.py
0 → 100644
浏览文件 @
86c06294
import
paddle
import
paddlehub
as
hub
import
paddle.nn
as
nn
from
paddlehub.finetune.trainer
import
Trainer
from
paddlehub.datasets.pascalvoc
import
DetectionData
import
paddlehub.process.detect_transforms
as
T
if
__name__
==
"__main__"
:
paddle
.
disable_static
()
transform
=
T
.
Compose
([
T
.
RandomDistort
(),
T
.
RandomExpand
(
fill
=
[
0.485
,
0.456
,
0.406
]),
T
.
RandomCrop
(),
T
.
Resize
(
target_size
=
416
),
T
.
RandomFlip
(),
T
.
ShuffleBox
(),
T
.
Normalize
(
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
])
])
train_reader
=
DetectionData
(
transform
)
model
=
hub
.
Module
(
name
=
'yolov3_darknet53_pascalvoc'
)
optimizer
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
0.0001
,
parameters
=
model
.
parameters
())
trainer
=
Trainer
(
model
,
optimizer
,
checkpoint_dir
=
'test_ckpt_img_det'
)
trainer
.
train
(
train_reader
,
epochs
=
5
,
batch_size
=
4
,
eval_dataset
=
train_reader
,
log_interval
=
1
,
save_interval
=
1
)
hub_module/modules/image/object_detection/yolov3_darknet53_pascalvoc/module.py
0 → 100644
浏览文件 @
86c06294
import
os
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle.nn.initializer
import
Normal
,
Constant
from
paddle.regularizer
import
L2Decay
from
paddlehub.module.cv_module
import
Yolov3Module
import
paddlehub.process.detect_transforms
as
T
from
paddlehub.module.module
import
moduleinfo
class
ConvBNLayer
(
nn
.
Layer
):
"""Basic block for Darknet"""
def
__init__
(
self
,
ch_in
:
int
,
ch_out
:
int
,
filter_size
:
int
=
3
,
stride
:
int
=
1
,
groups
:
int
=
1
,
padding
:
int
=
0
,
act
:
str
=
'leakly'
,
is_test
:
bool
=
False
):
super
(
ConvBNLayer
,
self
).
__init__
()
self
.
conv
=
nn
.
Conv2d
(
ch_in
,
ch_out
,
filter_size
,
padding
=
padding
,
stride
=
stride
,
groups
=
groups
,
weight_attr
=
paddle
.
ParamAttr
(
initializer
=
Normal
(
0.
,
0.02
)),
bias_attr
=
False
)
self
.
batch_norm
=
nn
.
BatchNorm
(
num_channels
=
ch_out
,
is_test
=
is_test
,
param_attr
=
paddle
.
ParamAttr
(
initializer
=
Normal
(
0.
,
0.02
),
regularizer
=
L2Decay
(
0.
)))
self
.
act
=
act
def
forward
(
self
,
inputs
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
out
=
self
.
conv
(
inputs
)
out
=
self
.
batch_norm
(
out
)
if
self
.
act
==
"leakly"
:
out
=
F
.
leaky_relu
(
x
=
out
,
negative_slope
=
0.1
)
return
out
class
DownSample
(
nn
.
Layer
):
"""Downsample block for Darknet"""
def
__init__
(
self
,
ch_in
:
int
,
ch_out
:
int
,
filter_size
:
int
=
3
,
stride
:
int
=
2
,
padding
:
int
=
1
,
is_test
:
bool
=
False
):
super
(
DownSample
,
self
).
__init__
()
self
.
conv_bn_layer
=
ConvBNLayer
(
ch_in
=
ch_in
,
ch_out
=
ch_out
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
padding
,
is_test
=
is_test
)
self
.
ch_out
=
ch_out
def
forward
(
self
,
inputs
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
out
=
self
.
conv_bn_layer
(
inputs
)
return
out
class
BasicBlock
(
nn
.
Layer
):
"""Basic residual block for Darknet"""
def
__init__
(
self
,
ch_in
:
int
,
ch_out
:
int
,
is_test
:
bool
=
False
):
super
(
BasicBlock
,
self
).
__init__
()
self
.
conv1
=
ConvBNLayer
(
ch_in
=
ch_in
,
ch_out
=
ch_out
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
is_test
=
is_test
)
self
.
conv2
=
ConvBNLayer
(
ch_in
=
ch_out
,
ch_out
=
ch_out
*
2
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
is_test
=
is_test
)
def
forward
(
self
,
inputs
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
conv1
=
self
.
conv1
(
inputs
)
conv2
=
self
.
conv2
(
conv1
)
out
=
paddle
.
elementwise_add
(
x
=
inputs
,
y
=
conv2
,
act
=
None
)
return
out
class
LayerWarp
(
nn
.
Layer
):
"""Warp layer composed by basic residual blocks"""
def
__init__
(
self
,
ch_in
:
int
,
ch_out
:
int
,
count
:
int
,
is_test
:
bool
=
False
):
super
(
LayerWarp
,
self
).
__init__
()
self
.
basicblock0
=
BasicBlock
(
ch_in
,
ch_out
,
is_test
=
is_test
)
self
.
res_out_list
=
[]
for
i
in
range
(
1
,
count
):
res_out
=
self
.
add_sublayer
(
"basic_block_%d"
%
(
i
),
BasicBlock
(
ch_out
*
2
,
ch_out
,
is_test
=
is_test
))
self
.
res_out_list
.
append
(
res_out
)
self
.
ch_out
=
ch_out
def
forward
(
self
,
inputs
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
y
=
self
.
basicblock0
(
inputs
)
for
basic_block_i
in
self
.
res_out_list
:
y
=
basic_block_i
(
y
)
return
y
class
DarkNet53_conv_body
(
nn
.
Layer
):
"""Darknet53
Args:
ch_in(int): Input channels, default is 3.
is_test (bool): Set the test mode, default is True.
"""
def
__init__
(
self
,
ch_in
:
int
=
3
,
is_test
:
bool
=
False
):
super
(
DarkNet53_conv_body
,
self
).
__init__
()
self
.
stages
=
[
1
,
2
,
8
,
8
,
4
]
self
.
stages
=
self
.
stages
[
0
:
5
]
self
.
conv0
=
ConvBNLayer
(
ch_in
=
ch_in
,
ch_out
=
32
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
is_test
=
is_test
)
self
.
downsample0
=
DownSample
(
ch_in
=
32
,
ch_out
=
32
*
2
,
is_test
=
is_test
)
self
.
darknet53_conv_block_list
=
[]
self
.
downsample_list
=
[]
ch_in
=
[
64
,
128
,
256
,
512
,
1024
]
for
i
,
stage
in
enumerate
(
self
.
stages
):
conv_block
=
self
.
add_sublayer
(
"stage_%d"
%
(
i
),
LayerWarp
(
int
(
ch_in
[
i
]),
32
*
(
2
**
i
),
stage
,
is_test
=
is_test
))
self
.
darknet53_conv_block_list
.
append
(
conv_block
)
for
i
in
range
(
len
(
self
.
stages
)
-
1
):
downsample
=
self
.
add_sublayer
(
"stage_%d_downsample"
%
i
,
DownSample
(
ch_in
=
32
*
(
2
**
(
i
+
1
)),
ch_out
=
32
*
(
2
**
(
i
+
2
)),
is_test
=
is_test
))
self
.
downsample_list
.
append
(
downsample
)
def
forward
(
self
,
inputs
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
out
=
self
.
conv0
(
inputs
)
out
=
self
.
downsample0
(
out
)
blocks
=
[]
for
i
,
conv_block_i
in
enumerate
(
self
.
darknet53_conv_block_list
):
out
=
conv_block_i
(
out
)
blocks
.
append
(
out
)
if
i
<
len
(
self
.
stages
)
-
1
:
out
=
self
.
downsample_list
[
i
](
out
)
return
blocks
[
-
1
:
-
4
:
-
1
]
class
YoloDetectionBlock
(
nn
.
Layer
):
"""Basic block for Yolov3"""
def
__init__
(
self
,
ch_in
:
int
,
channel
:
int
,
is_test
:
bool
=
True
):
super
(
YoloDetectionBlock
,
self
).
__init__
()
assert
channel
%
2
==
0
,
\
"channel {} cannot be divided by 2"
.
format
(
channel
)
self
.
conv0
=
ConvBNLayer
(
ch_in
=
ch_in
,
ch_out
=
channel
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
is_test
=
is_test
)
self
.
conv1
=
ConvBNLayer
(
ch_in
=
channel
,
ch_out
=
channel
*
2
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
is_test
=
is_test
)
self
.
conv2
=
ConvBNLayer
(
ch_in
=
channel
*
2
,
ch_out
=
channel
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
is_test
=
is_test
)
self
.
conv3
=
ConvBNLayer
(
ch_in
=
channel
,
ch_out
=
channel
*
2
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
is_test
=
is_test
)
self
.
route
=
ConvBNLayer
(
ch_in
=
channel
*
2
,
ch_out
=
channel
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
is_test
=
is_test
)
self
.
tip
=
ConvBNLayer
(
ch_in
=
channel
,
ch_out
=
channel
*
2
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
is_test
=
is_test
)
def
forward
(
self
,
inputs
):
out
=
self
.
conv0
(
inputs
)
out
=
self
.
conv1
(
out
)
out
=
self
.
conv2
(
out
)
out
=
self
.
conv3
(
out
)
route
=
self
.
route
(
out
)
tip
=
self
.
tip
(
route
)
return
route
,
tip
class
Upsample
(
nn
.
Layer
):
"""Upsample block for Yolov3"""
def
__init__
(
self
,
scale
:
int
=
2
):
super
(
Upsample
,
self
).
__init__
()
self
.
scale
=
scale
def
forward
(
self
,
inputs
:
paddle
.
Tensor
):
shape_nchw
=
paddle
.
to_tensor
(
inputs
.
shape
)
shape_hw
=
paddle
.
slice
(
shape_nchw
,
axes
=
[
0
],
starts
=
[
2
],
ends
=
[
4
])
shape_hw
.
stop_gradient
=
True
in_shape
=
paddle
.
cast
(
shape_hw
,
dtype
=
'int32'
)
out_shape
=
in_shape
*
self
.
scale
out_shape
.
stop_gradient
=
True
out
=
F
.
resize_nearest
(
input
=
inputs
,
scale
=
self
.
scale
,
actual_shape
=
out_shape
)
return
out
@
moduleinfo
(
name
=
"yolov3_darknet53_pascalvoc"
,
type
=
"CV/image_editing"
,
author
=
"paddlepaddle"
,
author_email
=
""
,
summary
=
"Yolov3 is a detection model, this module is trained with VOC dataset."
,
version
=
"1.0.0"
,
meta
=
Yolov3Module
)
class
YOLOv3
(
nn
.
Layer
):
"""YOLOV3 for detection
Args:
ch_in(int): Input channels, default is 3.
class_num(int): Categories for detection,if dataset is voc, class_num is 20.
ignore_thresh(float): The ignore threshold to ignore confidence loss.
valid_thresh(float): Threshold to filter out bounding boxes with low confidence score.
nms_topk(int): Maximum number of detections to be kept according to the confidences after the filtering
detections based on score_threshold.
nms_posk(int): Number of total bboxes to be kept per image after NMS step. -1 means keeping all bboxes after NMS
step.
nms_thresh (float): The threshold to be used in NMS. Default: 0.3.
is_train (bool): Set the train mode, default is True.
load_checkpoint(str): Whether to load checkpoint.
"""
def
__init__
(
self
,
ch_in
:
int
=
3
,
class_num
:
int
=
20
,
ignore_thresh
:
float
=
0.7
,
valid_thresh
:
float
=
0.005
,
nms_topk
:
int
=
400
,
nms_posk
:
int
=
100
,
nms_thresh
:
float
=
0.45
,
is_train
:
bool
=
True
,
load_checkpoint
:
str
=
None
):
super
(
YOLOv3
,
self
).
__init__
()
self
.
is_train
=
is_train
self
.
block
=
DarkNet53_conv_body
(
ch_in
=
ch_in
,
is_test
=
not
self
.
is_train
)
self
.
block_outputs
=
[]
self
.
yolo_blocks
=
[]
self
.
route_blocks_2
=
[]
self
.
anchor_masks
=
[[
6
,
7
,
8
],
[
3
,
4
,
5
],
[
0
,
1
,
2
]]
self
.
anchors
=
[
10
,
13
,
16
,
30
,
33
,
23
,
30
,
61
,
62
,
45
,
59
,
119
,
116
,
90
,
156
,
198
,
373
,
326
]
self
.
class_num
=
class_num
self
.
ignore_thresh
=
ignore_thresh
self
.
valid_thresh
=
valid_thresh
self
.
nms_topk
=
nms_topk
self
.
nms_posk
=
nms_posk
self
.
nms_thresh
=
nms_thresh
ch_in_list
=
[
1024
,
768
,
384
]
for
i
in
range
(
3
):
yolo_block
=
self
.
add_sublayer
(
"yolo_detecton_block_%d"
%
(
i
),
YoloDetectionBlock
(
ch_in_list
[
i
],
channel
=
512
//
(
2
**
i
),
is_test
=
not
self
.
is_train
))
self
.
yolo_blocks
.
append
(
yolo_block
)
num_filters
=
len
(
self
.
anchor_masks
[
i
])
*
(
self
.
class_num
+
5
)
block_out
=
self
.
add_sublayer
(
"block_out_%d"
%
(
i
),
nn
.
Conv2d
(
1024
//
(
2
**
i
),
num_filters
,
1
,
stride
=
1
,
padding
=
0
,
weight_attr
=
paddle
.
ParamAttr
(
initializer
=
Normal
(
0.
,
0.02
)),
bias_attr
=
paddle
.
ParamAttr
(
initializer
=
Constant
(
0.0
),
regularizer
=
L2Decay
(
0.
))))
self
.
block_outputs
.
append
(
block_out
)
if
i
<
2
:
route
=
self
.
add_sublayer
(
"route2_%d"
%
i
,
ConvBNLayer
(
ch_in
=
512
//
(
2
**
i
),
ch_out
=
256
//
(
2
**
i
),
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
is_test
=
(
not
self
.
is_train
)))
self
.
route_blocks_2
.
append
(
route
)
self
.
upsample
=
Upsample
()
if
load_checkpoint
is
not
None
:
model_dict
=
paddle
.
load
(
load_checkpoint
)[
0
]
self
.
set_dict
(
model_dict
)
print
(
"load custom checkpoint success"
)
else
:
checkpoint
=
os
.
path
.
join
(
self
.
directory
,
'yolov3_darknet53_voc.pdparams'
)
if
not
os
.
path
.
exists
(
checkpoint
):
os
.
system
(
'wget https://paddlehub.bj.bcebos.com/dygraph/detection/yolov3_darknet53_voc.pdparams -O '
\
+
checkpoint
)
model_dict
=
paddle
.
load
(
checkpoint
)[
0
]
self
.
set_dict
(
model_dict
)
print
(
"load pretrained checkpoint success"
)
def
transform
(
self
,
img
):
if
self
.
is_train
:
transform
=
T
.
Compose
([
T
.
RandomDistort
(),
T
.
RandomExpand
(
fill
=
[
0.485
,
0.456
,
0.406
]),
T
.
RandomCrop
(),
T
.
Resize
(
target_size
=
416
),
T
.
RandomFlip
(),
T
.
ShuffleBox
(),
T
.
Normalize
(
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
])
])
else
:
transform
=
T
.
Compose
([
T
.
Resize
(
target_size
=
416
,
interp
=
'CUBIC'
),
T
.
Normalize
(
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
])
])
return
transform
(
img
)
def
forward
(
self
,
inputs
:
paddle
.
Tensor
):
outputs
=
[]
blocks
=
self
.
block
(
inputs
)
route
=
None
for
i
,
block
in
enumerate
(
blocks
):
if
i
>
0
:
block
=
paddle
.
concat
([
route
,
block
],
axis
=
1
)
route
,
tip
=
self
.
yolo_blocks
[
i
](
block
)
block_out
=
self
.
block_outputs
[
i
](
tip
)
outputs
.
append
(
block_out
)
if
i
<
2
:
route
=
self
.
route_blocks_2
[
i
](
route
)
route
=
self
.
upsample
(
route
)
return
outputs
paddlehub/datasets/pascalvoc.py
0 → 100644
浏览文件 @
86c06294
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
copy
from
typing
import
Callable
import
paddle
import
numpy
as
np
from
paddlehub.env
import
DATA_HOME
from
pycocotools.coco
import
COCO
class
DetectCatagory
:
"""Load label name, id and map from detection dataset.
Args:
attrbox(Callable): Method to get detection attributes of images.
data_dir(str): Image dataset path.
Returns:
label_names(List(str)): The dataset label names.
label_ids(List(int)): The dataset label ids.
category_to_id_map(dict): Mapping relations of category and id for images.
"""
def
__init__
(
self
,
attrbox
:
Callable
,
data_dir
:
str
):
self
.
attrbox
=
attrbox
self
.
img_dir
=
data_dir
def
__call__
(
self
):
self
.
categories
=
self
.
attrbox
.
loadCats
(
self
.
attrbox
.
getCatIds
())
self
.
num_category
=
len
(
self
.
categories
)
label_names
=
[]
label_ids
=
[]
for
category
in
self
.
categories
:
label_names
.
append
(
category
[
'name'
])
label_ids
.
append
(
int
(
category
[
'id'
]))
category_to_id_map
=
{
v
:
i
for
i
,
v
in
enumerate
(
label_ids
)}
return
label_names
,
label_ids
,
category_to_id_map
class
ParseImages
:
"""Prepare images for detection.
Args:
attrbox(Callable): Method to get detection attributes of images.
data_dir(str): Image dataset path.
category_to_id_map(dict): Mapping relations of category and id for images.
Returns:
imgs(dict): The input for detection model, it is a dict.
"""
def
__init__
(
self
,
attrbox
:
Callable
,
data_dir
:
str
,
category_to_id_map
:
dict
):
self
.
attrbox
=
attrbox
self
.
img_dir
=
data_dir
self
.
category_to_id_map
=
category_to_id_map
self
.
parse_gt_annotations
=
GTAnotations
(
self
.
attrbox
,
self
.
category_to_id_map
)
def
__call__
(
self
):
image_ids
=
self
.
attrbox
.
getImgIds
()
image_ids
.
sort
()
imgs
=
copy
.
deepcopy
(
self
.
attrbox
.
loadImgs
(
image_ids
))
for
img
in
imgs
:
img
[
'image'
]
=
os
.
path
.
join
(
self
.
img_dir
,
img
[
'file_name'
])
assert
os
.
path
.
exists
(
img
[
'image'
]),
"image {} not found."
.
format
(
img
[
'image'
])
box_num
=
50
img
[
'gt_boxes'
]
=
np
.
zeros
((
box_num
,
4
),
dtype
=
np
.
float32
)
img
[
'gt_labels'
]
=
np
.
zeros
((
box_num
),
dtype
=
np
.
int32
)
img
=
self
.
parse_gt_annotations
(
img
)
return
imgs
class
GTAnotations
:
"""Set gt boxes and gt labels for train.
Args:
attrbox(Callable): Method for get detection attributes for images.
category_to_id_map(dict): Mapping relations of category and id for images.
img(dict): Input for detection model.
Returns:
img(dict): Set specific value on the attributes of 'gt boxes' and 'gt labels' for input.
"""
def
__init__
(
self
,
attrbox
:
Callable
,
category_to_id_map
:
dict
):
self
.
attrbox
=
attrbox
self
.
category_to_id_map
=
category_to_id_map
def
box_to_center_relative
(
self
,
box
:
list
,
img_height
:
int
,
img_width
:
int
)
->
np
.
ndarray
:
"""
Convert COCO annotations box with format [x1, y1, w, h] to
center mode [center_x, center_y, w, h] and divide image width
and height to get relative value in range[0, 1]
"""
assert
len
(
box
)
==
4
,
"box should be a len(4) list or tuple"
x
,
y
,
w
,
h
=
box
x1
=
max
(
x
,
0
)
x2
=
min
(
x
+
w
-
1
,
img_width
-
1
)
y1
=
max
(
y
,
0
)
y2
=
min
(
y
+
h
-
1
,
img_height
-
1
)
x
=
(
x1
+
x2
)
/
2
/
img_width
y
=
(
y1
+
y2
)
/
2
/
img_height
w
=
(
x2
-
x1
)
/
img_width
h
=
(
y2
-
y1
)
/
img_height
return
np
.
array
([
x
,
y
,
w
,
h
])
def
__call__
(
self
,
img
:
dict
):
img_height
=
img
[
'height'
]
img_width
=
img
[
'width'
]
anno
=
self
.
attrbox
.
loadAnns
(
self
.
attrbox
.
getAnnIds
(
imgIds
=
img
[
'id'
],
iscrowd
=
None
))
gt_index
=
0
for
target
in
anno
:
if
target
[
'area'
]
<
-
1
:
continue
if
'ignore'
in
target
and
target
[
'ignore'
]:
continue
box
=
self
.
box_to_center_relative
(
target
[
'bbox'
],
img_height
,
img_width
)
if
box
[
2
]
<=
0
and
box
[
3
]
<=
0
:
continue
img
[
'gt_boxes'
][
gt_index
]
=
box
img
[
'gt_labels'
][
gt_index
]
=
\
self
.
category_to_id_map
[
target
[
'category_id'
]]
gt_index
+=
1
if
gt_index
>=
50
:
break
return
img
class
DetectionData
(
paddle
.
io
.
Dataset
):
"""
Dataset for image detection.
Args:
transform(callmethod) : The method of preprocess images.
mode(str): The mode for preparing dataset.
Returns:
DataSet: An iterable object for data iterating
"""
def
__init__
(
self
,
transform
:
Callable
,
size
:
int
=
416
,
mode
:
str
=
'train'
):
self
.
mode
=
mode
self
.
transform
=
transform
self
.
size
=
size
if
self
.
mode
==
'train'
:
train_file_list
=
'annotations/instances_train2017.json'
train_data_dir
=
'train2017'
self
.
train_file_list
=
os
.
path
.
join
(
DATA_HOME
,
'voc'
,
train_file_list
)
self
.
train_data_dir
=
os
.
path
.
join
(
DATA_HOME
,
'voc'
,
train_data_dir
)
self
.
COCO
=
COCO
(
self
.
train_file_list
)
self
.
img_dir
=
self
.
train_data_dir
elif
self
.
mode
==
'test'
:
val_file_list
=
'annotations/instances_val2017.json'
val_data_dir
=
'val2017'
self
.
val_file_list
=
os
.
path
.
join
(
DATA_HOME
,
'voc'
,
val_file_list
)
self
.
val_data_dir
=
os
.
path
.
join
(
DATA_HOME
,
'voc'
,
val_data_dir
)
self
.
COCO
=
COCO
(
self
.
val_file_list
)
self
.
img_dir
=
self
.
val_data_dir
parse_dataset_catagory
=
DetectCatagory
(
self
.
COCO
,
self
.
img_dir
)
self
.
label_names
,
self
.
label_ids
,
self
.
category_to_id_map
=
parse_dataset_catagory
()
parse_images
=
ParseImages
(
self
.
COCO
,
self
.
img_dir
,
self
.
category_to_id_map
)
self
.
data
=
parse_images
()
def
__getitem__
(
self
,
idx
:
int
):
img
=
self
.
data
[
idx
]
im
,
data
=
self
.
transform
(
img
)
out_img
,
gt_boxes
,
gt_labels
,
gt_scores
=
im
,
data
[
'gt_boxes'
],
data
[
'gt_labels'
],
data
[
'gt_scores'
]
return
out_img
,
gt_boxes
,
gt_labels
,
gt_scores
def
__len__
(
self
):
return
len
(
self
.
data
)
paddlehub/module/cv_module.py
浏览文件 @
86c06294
...
...
@@ -27,8 +27,8 @@ from PIL import Image
from
paddlehub.module.module
import
serving
,
RunModule
from
paddlehub.utils.utils
import
base64_to_cv2
from
paddlehub.process.transforms
import
ConvertColorSpace
,
ColorPostprocess
,
Resize
from
paddlehub.process.functional
import
subtract_imagenet_mean_batch
,
gram_matrix
import
paddlehub.process.transforms
as
T
import
paddlehub.process.functional
as
Func
class
ImageServing
(
object
):
...
...
@@ -136,8 +136,8 @@ class ImageColorizeModule(RunModule, ImageServing):
visual_ret
=
OrderedDict
()
psnrs
=
[]
lab2rgb
=
ConvertColorSpace
(
mode
=
'LAB2RGB'
)
process
=
ColorPostprocess
()
lab2rgb
=
T
.
ConvertColorSpace
(
mode
=
'LAB2RGB'
)
process
=
T
.
ColorPostprocess
()
for
i
in
range
(
batch
[
0
].
numpy
().
shape
[
0
]):
real
=
lab2rgb
(
np
.
concatenate
((
batch
[
0
].
numpy
(),
batch
[
3
].
numpy
()),
axis
=
1
))[
i
]
...
...
@@ -163,9 +163,9 @@ class ImageColorizeModule(RunModule, ImageServing):
Returns:
results(list[dict]) : The prediction result of each input image
'''
lab2rgb
=
ConvertColorSpace
(
mode
=
'LAB2RGB'
)
process
=
ColorPostprocess
()
resize
=
Resize
((
256
,
256
))
lab2rgb
=
T
.
ConvertColorSpace
(
mode
=
'LAB2RGB'
)
process
=
T
.
ColorPostprocess
()
resize
=
T
.
Resize
((
256
,
256
))
visual_ret
=
OrderedDict
()
im
=
self
.
transforms
(
images
,
is_train
=
False
)
out_class
,
out_reg
=
self
(
paddle
.
to_tensor
(
im
[
'A'
]),
paddle
.
to_variable
(
im
[
'hint_B'
]),
...
...
@@ -196,6 +196,124 @@ class ImageColorizeModule(RunModule, ImageServing):
return
result
class
Yolov3Module
(
RunModule
,
ImageServing
):
def
training_step
(
self
,
batch
:
int
,
batch_idx
:
int
)
->
dict
:
'''
One step for training, which should be called as forward computation.
Args:
batch(list[paddle.Tensor]): The one batch data, which contains images, ground truth boxes, labels and scores.
batch_idx(int): The index of batch.
Returns:
results(dict): The model outputs, such as loss.
'''
return
self
.
validation_step
(
batch
,
batch_idx
)
def
validation_step
(
self
,
batch
:
int
,
batch_idx
:
int
)
->
dict
:
'''
One step for validation, which should be called as forward computation.
Args:
batch(list[paddle.Tensor]): The one batch data, which contains images, ground truth boxes, labels and scores.
batch_idx(int): The index of batch.
Returns:
results(dict) : The model outputs, such as metrics.
'''
img
=
batch
[
0
].
astype
(
'float32'
)
gtbox
=
batch
[
1
].
astype
(
'float32'
)
gtlabel
=
batch
[
2
].
astype
(
'int32'
)
gtscore
=
batch
[
3
].
astype
(
"float32"
)
losses
=
[]
outputs
=
self
(
img
)
self
.
downsample
=
32
for
i
,
out
in
enumerate
(
outputs
):
anchor_mask
=
self
.
anchor_masks
[
i
]
loss
=
F
.
yolov3_loss
(
x
=
out
,
gt_box
=
gtbox
,
gt_label
=
gtlabel
,
gt_score
=
gtscore
,
anchors
=
self
.
anchors
,
anchor_mask
=
anchor_mask
,
class_num
=
self
.
class_num
,
ignore_thresh
=
self
.
ignore_thresh
,
downsample_ratio
=
32
,
use_label_smooth
=
False
)
losses
.
append
(
paddle
.
reduce_mean
(
loss
))
self
.
downsample
//=
2
return
{
'loss'
:
sum
(
losses
)}
def
predict
(
self
,
imgpath
:
str
,
filelist
:
str
,
visualization
:
bool
=
True
,
save_path
:
str
=
'result'
):
'''
Detect images
Args:
imgpath(str): Image path .
filelist(str): Path to get label name.
visualization(bool): Whether to save result image.
save_path(str) : Path to save detected images.
Returns:
boxes(np.ndarray): Predict box information.
scores(np.ndarray): Predict score.
labels(np.ndarray): Predict labels.
'''
boxes
=
[]
scores
=
[]
self
.
downsample
=
32
im
=
self
.
transform
(
imgpath
)
h
,
w
,
c
=
Func
.
img_shape
(
imgpath
)
im_shape
=
paddle
.
to_tensor
(
np
.
array
([[
h
,
w
]]).
astype
(
'int32'
))
label_names
=
Func
.
get_label_infos
(
filelist
)
img_data
=
paddle
.
to_tensor
(
np
.
array
([
im
]).
astype
(
'float32'
))
outputs
=
self
(
img_data
)
for
i
,
out
in
enumerate
(
outputs
):
anchor_mask
=
self
.
anchor_masks
[
i
]
mask_anchors
=
[]
for
m
in
anchor_mask
:
mask_anchors
.
append
((
self
.
anchors
[
2
*
m
]))
mask_anchors
.
append
(
self
.
anchors
[
2
*
m
+
1
])
box
,
score
=
F
.
yolo_box
(
x
=
out
,
img_size
=
im_shape
,
anchors
=
mask_anchors
,
class_num
=
self
.
class_num
,
conf_thresh
=
self
.
valid_thresh
,
downsample_ratio
=
self
.
downsample
,
name
=
"yolo_box"
+
str
(
i
))
boxes
.
append
(
box
)
scores
.
append
(
paddle
.
transpose
(
score
,
perm
=
[
0
,
2
,
1
]))
self
.
downsample
//=
2
yolo_boxes
=
paddle
.
concat
(
boxes
,
axis
=
1
)
yolo_scores
=
paddle
.
concat
(
scores
,
axis
=
2
)
pred
=
F
.
multiclass_nms
(
bboxes
=
yolo_boxes
,
scores
=
yolo_scores
,
score_threshold
=
self
.
valid_thresh
,
nms_top_k
=
self
.
nms_topk
,
keep_top_k
=
self
.
nms_posk
,
nms_threshold
=
self
.
nms_thresh
,
background_label
=-
1
)
bboxes
=
pred
.
numpy
()
labels
=
bboxes
[:,
0
].
astype
(
'int32'
)
scores
=
bboxes
[:,
1
].
astype
(
'float32'
)
boxes
=
bboxes
[:,
2
:].
astype
(
'float32'
)
if
visualization
:
Func
.
draw_boxes_on_image
(
imgpath
,
boxes
,
scores
,
labels
,
label_names
,
0.5
)
return
boxes
,
scores
,
labels
class
StyleTransferModule
(
RunModule
,
ImageServing
):
def
training_step
(
self
,
batch
:
int
,
batch_idx
:
int
)
->
dict
:
'''
...
...
@@ -228,19 +346,19 @@ class StyleTransferModule(RunModule, ImageServing):
y
=
self
(
batch
[
0
])
xc
=
paddle
.
to_tensor
(
batch
[
0
].
numpy
().
copy
())
y
=
subtract_imagenet_mean_batch
(
y
)
xc
=
subtract_imagenet_mean_batch
(
xc
)
y
=
Func
.
subtract_imagenet_mean_batch
(
y
)
xc
=
Func
.
subtract_imagenet_mean_batch
(
xc
)
features_y
=
self
.
getFeature
(
y
)
features_xc
=
self
.
getFeature
(
xc
)
f_xc_c
=
paddle
.
to_tensor
(
features_xc
[
1
].
numpy
(),
stop_gradient
=
True
)
content_loss
=
mse_loss
(
features_y
[
1
],
f_xc_c
)
batch
[
1
]
=
subtract_imagenet_mean_batch
(
batch
[
1
])
batch
[
1
]
=
Func
.
subtract_imagenet_mean_batch
(
batch
[
1
])
features_style
=
self
.
getFeature
(
batch
[
1
])
gram_style
=
[
gram_matrix
(
y
)
for
y
in
features_style
]
gram_style
=
[
Func
.
gram_matrix
(
y
)
for
y
in
features_style
]
style_loss
=
0.
for
m
in
range
(
len
(
features_y
)):
gram_y
=
gram_matrix
(
features_y
[
m
])
gram_y
=
Func
.
gram_matrix
(
features_y
[
m
])
gram_s
=
paddle
.
to_tensor
(
np
.
tile
(
gram_style
[
m
].
numpy
(),
(
N
,
1
,
1
,
1
)))
style_loss
+=
mse_loss
(
gram_y
,
gram_s
[:
N
,
:,
:])
...
...
paddlehub/process/detect_transforms.py
0 → 100644
浏览文件 @
86c06294
import
os
import
random
from
typing
import
Callable
import
cv2
import
numpy
as
np
import
matplotlib
import
PIL
from
PIL
import
Image
,
ImageEnhance
from
matplotlib
import
pyplot
as
plt
from
paddlehub.process.functional
import
*
matplotlib
.
use
(
'Agg'
)
class
RandomDistort
:
"""
Distort the input image randomly.
Args:
lower(float): The lower bound value for enhancement, default is 0.5.
upper(float): The upper bound value for enhancement, default is 1.5.
Returns:
img(np.ndarray): Distorted image.
data(dict): Image info and label info.
"""
def
__init__
(
self
,
lower
:
float
=
0.5
,
upper
:
float
=
1.5
):
self
.
lower
=
lower
self
.
upper
=
upper
def
random_brightness
(
self
,
img
:
PIL
.
Image
):
e
=
np
.
random
.
uniform
(
self
.
lower
,
self
.
upper
)
return
ImageEnhance
.
Brightness
(
img
).
enhance
(
e
)
def
random_contrast
(
self
,
img
:
PIL
.
Image
):
e
=
np
.
random
.
uniform
(
self
.
lower
,
self
.
upper
)
return
ImageEnhance
.
Contrast
(
img
).
enhance
(
e
)
def
random_color
(
self
,
img
:
PIL
.
Image
):
e
=
np
.
random
.
uniform
(
self
.
lower
,
self
.
upper
)
return
ImageEnhance
.
Color
(
img
).
enhance
(
e
)
def
__call__
(
self
,
img
:
np
.
ndarray
,
data
:
dict
):
ops
=
[
self
.
random_brightness
,
self
.
random_contrast
,
self
.
random_color
]
np
.
random
.
shuffle
(
ops
)
img
=
Image
.
fromarray
(
img
)
img
=
ops
[
0
](
img
)
img
=
ops
[
1
](
img
)
img
=
ops
[
2
](
img
)
img
=
np
.
asarray
(
img
)
return
img
,
data
class
RandomExpand
:
"""
Randomly expand images and gt boxes by random ratio. It is a data enhancement operation for model training.
Args:
max_ratio(float): Max value for expansion ratio, default is 4.
fill(list): Initialize the pixel value of the image with the input fill value, default is None.
keep_ratio(bool): Whether image keeps ratio.
thresh(float): If random ratio does not exceed the thresh, return original images and gt boxes, default is 0.5.
Return:
img(np.ndarray): Distorted image.
data(dict): Image info and label info.
"""
def
__init__
(
self
,
max_ratio
:
float
=
4.
,
fill
:
list
=
None
,
keep_ratio
:
bool
=
True
,
thresh
:
float
=
0.5
):
self
.
max_ratio
=
max_ratio
self
.
fill
=
fill
self
.
keep_ratio
=
keep_ratio
self
.
thresh
=
thresh
def
__call__
(
self
,
img
:
np
.
ndarray
,
data
:
dict
):
gtboxes
=
data
[
'gt_boxes'
]
if
random
.
random
()
>
self
.
thresh
:
return
img
,
data
if
self
.
max_ratio
<
1.0
:
return
img
,
data
h
,
w
,
c
=
img
.
shape
ratio_x
=
random
.
uniform
(
1
,
self
.
max_ratio
)
if
self
.
keep_ratio
:
ratio_y
=
ratio_x
else
:
ratio_y
=
random
.
uniform
(
1
,
self
.
max_ratio
)
oh
=
int
(
h
*
ratio_y
)
ow
=
int
(
w
*
ratio_x
)
off_x
=
random
.
randint
(
0
,
ow
-
w
)
off_y
=
random
.
randint
(
0
,
oh
-
h
)
out_img
=
np
.
zeros
((
oh
,
ow
,
c
))
if
self
.
fill
and
len
(
self
.
fill
)
==
c
:
for
i
in
range
(
c
):
out_img
[:,
:,
i
]
=
self
.
fill
[
i
]
*
255.0
out_img
[
off_y
:
off_y
+
h
,
off_x
:
off_x
+
w
,
:]
=
img
gtboxes
[:,
0
]
=
((
gtboxes
[:,
0
]
*
w
)
+
off_x
)
/
float
(
ow
)
gtboxes
[:,
1
]
=
((
gtboxes
[:,
1
]
*
h
)
+
off_y
)
/
float
(
oh
)
gtboxes
[:,
2
]
=
gtboxes
[:,
2
]
/
ratio_x
gtboxes
[:,
3
]
=
gtboxes
[:,
3
]
/
ratio_y
data
[
'gt_boxes'
]
=
gtboxes
img
=
out_img
.
astype
(
'uint8'
)
return
img
,
data
class
RandomCrop
:
"""
Random crop the input image according to constraints.
Args:
scales(list): The value of the cutting area relative to the original area, expressed in the form of
\
[min, max]. The default value is [.3, 1.].
max_ratio(float): Max ratio of the original area relative to the cutting area, default is 2.0.
constraints(list): The value of min and max iou values, default is None.
max_trial(int): The max trial for finding a valid crop area. The default value is 50.
Returns:
img(np.ndarray): Distorted image.
data(dict): Image info and label info.
"""
def
__init__
(
self
,
scales
:
list
=
[
0.3
,
1.0
],
max_ratio
:
float
=
2.0
,
constraints
:
list
=
None
,
max_trial
:
int
=
50
):
self
.
scales
=
scales
self
.
max_ratio
=
max_ratio
self
.
constraints
=
constraints
self
.
max_trial
=
max_trial
def
__call__
(
self
,
img
:
np
.
ndarray
,
data
:
dict
):
boxes
=
data
[
'gt_boxes'
]
labels
=
data
[
'gt_labels'
]
scores
=
data
[
'gt_scores'
]
if
len
(
boxes
)
==
0
:
return
img
,
data
if
not
self
.
constraints
:
self
.
constraints
=
[(
0.1
,
1.0
),
(
0.3
,
1.0
),
(
0.5
,
1.0
),
(
0.7
,
1.0
),
(
0.9
,
1.0
),
(
0.0
,
1.0
)]
img
=
Image
.
fromarray
(
img
)
w
,
h
=
img
.
size
crops
=
[(
0
,
0
,
w
,
h
)]
for
min_iou
,
max_iou
in
self
.
constraints
:
for
_
in
range
(
self
.
max_trial
):
scale
=
random
.
uniform
(
self
.
scales
[
0
],
self
.
scales
[
1
])
aspect_ratio
=
random
.
uniform
(
max
(
1
/
self
.
max_ratio
,
scale
*
scale
),
\
min
(
self
.
max_ratio
,
1
/
scale
/
scale
))
crop_h
=
int
(
h
*
scale
/
np
.
sqrt
(
aspect_ratio
))
crop_w
=
int
(
w
*
scale
*
np
.
sqrt
(
aspect_ratio
))
crop_x
=
random
.
randrange
(
w
-
crop_w
)
crop_y
=
random
.
randrange
(
h
-
crop_h
)
crop_box
=
np
.
array
([[(
crop_x
+
crop_w
/
2.0
)
/
w
,
(
crop_y
+
crop_h
/
2.0
)
/
h
,
crop_w
/
float
(
w
),
crop_h
/
float
(
h
)]])
iou
=
box_iou_xywh
(
crop_box
,
boxes
)
if
min_iou
<=
iou
.
min
()
and
max_iou
>=
iou
.
max
():
crops
.
append
((
crop_x
,
crop_y
,
crop_w
,
crop_h
))
break
while
crops
:
crop
=
crops
.
pop
(
np
.
random
.
randint
(
0
,
len
(
crops
)))
crop_boxes
,
crop_labels
,
crop_scores
,
box_num
=
box_crop
(
boxes
,
labels
,
scores
,
crop
,
(
w
,
h
))
if
box_num
<
1
:
continue
img
=
img
.
crop
((
crop
[
0
],
crop
[
1
],
crop
[
0
]
+
crop
[
2
],
crop
[
1
]
+
crop
[
3
])).
resize
(
img
.
size
,
Image
.
LANCZOS
)
img
=
np
.
asarray
(
img
)
data
[
'gt_boxes'
]
=
crop_boxes
data
[
'gt_labels'
]
=
crop_labels
data
[
'gt_scores'
]
=
crop_scores
return
img
,
data
img
=
np
.
asarray
(
img
)
data
[
'gt_boxes'
]
=
boxes
data
[
'gt_labels'
]
=
labels
data
[
'gt_scores'
]
=
scores
return
img
,
data
class
RandomFlip
:
"""Flip the images and gt boxes randomly.
Args:
thresh: Probability for random flip.
Returns:
img(np.ndarray): Distorted image.
data(dict): Image info and label info.
"""
def
__init__
(
self
,
thresh
:
float
=
0.5
):
self
.
thresh
=
thresh
def
__call__
(
self
,
img
,
data
):
gtboxes
=
data
[
'gt_boxes'
]
if
random
.
random
()
>
self
.
thresh
:
img
=
img
[:,
::
-
1
,
:]
gtboxes
[:,
0
]
=
1.0
-
gtboxes
[:,
0
]
data
[
'gt_boxes'
]
=
gtboxes
return
img
,
data
class
Compose
:
"""
Preprocess the input data according to the operators.
Args:
transforms(list): Preprocessing operators.
Returns:
img(np.ndarray): Preprocessed image.
data(dict): Image info and label info, default is None.
"""
def
__init__
(
self
,
transforms
:
list
):
if
not
isinstance
(
transforms
,
list
):
raise
TypeError
(
'The transforms must be a list!'
)
if
len
(
transforms
)
<
1
:
raise
ValueError
(
'The length of transforms '
+
\
'must be equal or larger than 1!'
)
self
.
transforms
=
transforms
def
__call__
(
self
,
data
:
dict
):
if
isinstance
(
data
,
dict
):
if
isinstance
(
data
[
'image'
],
str
):
img
=
cv2
.
imread
(
data
[
'image'
])
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
)
gt_labels
=
data
[
'gt_labels'
].
copy
()
data
[
'gt_scores'
]
=
np
.
ones_like
(
gt_labels
)
for
op
in
self
.
transforms
:
img
,
data
=
op
(
img
,
data
)
img
=
img
.
transpose
((
2
,
0
,
1
))
return
img
,
data
if
isinstance
(
data
,
str
):
img
=
cv2
.
imread
(
data
)
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
)
for
op
in
self
.
transforms
:
img
,
data
=
op
(
img
,
data
)
img
=
img
.
transpose
((
2
,
0
,
1
))
return
img
class
Resize
:
"""
Resize the input images.
Args:
target_size(int): Targeted input size.
interp(str): Interpolation method.
Returns:
img(np.ndarray): Preprocessed image.
data(dict): Image info and label info, default is None.
"""
def
__init__
(
self
,
target_size
:
int
=
512
,
interp
:
str
=
'RANDOM'
):
self
.
interp_dict
=
{
'NEAREST'
:
cv2
.
INTER_NEAREST
,
'LINEAR'
:
cv2
.
INTER_LINEAR
,
'CUBIC'
:
cv2
.
INTER_CUBIC
,
'AREA'
:
cv2
.
INTER_AREA
,
'LANCZOS4'
:
cv2
.
INTER_LANCZOS4
}
self
.
interp
=
interp
if
not
(
interp
==
"RANDOM"
or
interp
in
self
.
interp_dict
):
raise
ValueError
(
"interp should be one of {}"
.
format
(
self
.
interp_dict
.
keys
()))
if
isinstance
(
target_size
,
list
)
or
isinstance
(
target_size
,
tuple
):
if
len
(
target_size
)
!=
2
:
raise
TypeError
(
'when target is list or tuple, it should include 2 elements, but it is {}'
.
format
(
target_size
))
elif
not
isinstance
(
target_size
,
int
):
raise
TypeError
(
"Type of target_size is invalid. Must be Integer or List or tuple, now is {}"
.
format
(
type
(
target_size
)))
self
.
target_size
=
target_size
def
__call__
(
self
,
img
,
data
=
None
):
if
self
.
interp
==
"RANDOM"
:
interp
=
random
.
choice
(
list
(
self
.
interp_dict
.
keys
()))
else
:
interp
=
self
.
interp
img
=
resize
(
img
,
self
.
target_size
,
self
.
interp_dict
[
interp
])
if
data
is
not
None
:
return
img
,
data
else
:
return
img
class
Normalize
:
"""
Normalize the input images.
Args:
mean(list): Mean values for normalization, default is [0.5, 0.5, 0.5].
std(list): Standard deviation for normalization, default is [0.5, 0.5, 0.5].
Returns:
img(np.ndarray): Preprocessed image.
data(dict): Image info and label info, default is None.
"""
def
__init__
(
self
,
mean
=
[
0.5
,
0.5
,
0.5
],
std
=
[
0.5
,
0.5
,
0.5
]):
self
.
mean
=
mean
self
.
std
=
std
if
not
(
isinstance
(
self
.
mean
,
list
)
and
isinstance
(
self
.
std
,
list
)):
raise
ValueError
(
"{}: input type is invalid."
.
format
(
self
))
from
functools
import
reduce
if
reduce
(
lambda
x
,
y
:
x
*
y
,
self
.
std
)
==
0
:
raise
ValueError
(
'{}: std is invalid!'
.
format
(
self
))
def
__call__
(
self
,
im
,
data
=
None
):
mean
=
np
.
array
(
self
.
mean
)[
np
.
newaxis
,
np
.
newaxis
,
:]
std
=
np
.
array
(
self
.
std
)[
np
.
newaxis
,
np
.
newaxis
,
:]
im
=
normalize
(
im
,
mean
,
std
)
if
data
is
not
None
:
return
im
,
data
else
:
return
im
class
ShuffleBox
:
"""Shuffle detection information for corresponding input image."""
def
__call__
(
self
,
img
,
data
):
gt
=
np
.
concatenate
([
data
[
'gt_boxes'
],
data
[
'gt_labels'
][:,
np
.
newaxis
],
data
[
'gt_scores'
][:,
np
.
newaxis
]],
axis
=
1
)
idx
=
np
.
arange
(
gt
.
shape
[
0
])
np
.
random
.
shuffle
(
idx
)
gt
=
gt
[
idx
,
:]
data
[
'gt_boxes'
],
data
[
'gt_labels'
],
data
[
'gt_scores'
]
=
gt
[:,
:
4
],
gt
[:,
4
],
gt
[:,
5
]
return
img
,
data
paddlehub/process/functional.py
浏览文件 @
86c06294
...
...
@@ -11,13 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
cv2
import
paddle
import
matplotlib
import
numpy
as
np
from
pycocotools.coco
import
COCO
from
PIL
import
Image
,
ImageEnhance
from
matplotlib
import
pyplot
as
plt
matplotlib
.
use
(
'Agg'
)
def
normalize
(
im
,
mean
,
std
):
...
...
@@ -120,6 +124,129 @@ def get_img_file(dir_name: str) -> list:
return
images
def
box_crop
(
boxes
:
np
.
ndarray
,
labels
:
np
.
ndarray
,
scores
:
np
.
ndarray
,
crop
:
list
,
img_shape
:
list
):
"""Crop the boxes ,labels, scores according to the given shape"""
x
,
y
,
w
,
h
=
map
(
float
,
crop
)
im_w
,
im_h
=
map
(
float
,
img_shape
)
boxes
=
boxes
.
copy
()
boxes
[:,
0
],
boxes
[:,
2
]
=
(
boxes
[:,
0
]
-
boxes
[:,
2
]
/
2
)
*
im_w
,
(
boxes
[:,
0
]
+
boxes
[:,
2
]
/
2
)
*
im_w
boxes
[:,
1
],
boxes
[:,
3
]
=
(
boxes
[:,
1
]
-
boxes
[:,
3
]
/
2
)
*
im_h
,
(
boxes
[:,
1
]
+
boxes
[:,
3
]
/
2
)
*
im_h
crop_box
=
np
.
array
([
x
,
y
,
x
+
w
,
y
+
h
])
centers
=
(
boxes
[:,
:
2
]
+
boxes
[:,
2
:])
/
2.0
mask
=
np
.
logical_and
(
crop_box
[:
2
]
<=
centers
,
centers
<=
crop_box
[
2
:]).
all
(
axis
=
1
)
boxes
[:,
:
2
]
=
np
.
maximum
(
boxes
[:,
:
2
],
crop_box
[:
2
])
boxes
[:,
2
:]
=
np
.
minimum
(
boxes
[:,
2
:],
crop_box
[
2
:])
boxes
[:,
:
2
]
-=
crop_box
[:
2
]
boxes
[:,
2
:]
-=
crop_box
[:
2
]
mask
=
np
.
logical_and
(
mask
,
(
boxes
[:,
:
2
]
<
boxes
[:,
2
:]).
all
(
axis
=
1
))
boxes
=
boxes
*
np
.
expand_dims
(
mask
.
astype
(
'float32'
),
axis
=
1
)
labels
=
labels
*
mask
.
astype
(
'float32'
)
scores
=
scores
*
mask
.
astype
(
'float32'
)
boxes
[:,
0
],
boxes
[:,
2
]
=
(
boxes
[:,
0
]
+
boxes
[:,
2
])
/
2
/
w
,
(
boxes
[:,
2
]
-
boxes
[:,
0
])
/
w
boxes
[:,
1
],
boxes
[:,
3
]
=
(
boxes
[:,
1
]
+
boxes
[:,
3
])
/
2
/
h
,
(
boxes
[:,
3
]
-
boxes
[:,
1
])
/
h
return
boxes
,
labels
,
scores
,
mask
.
sum
()
def
box_iou_xywh
(
box1
:
np
.
ndarray
,
box2
:
np
.
ndarray
)
->
float
:
"""Calculate iou by xywh"""
assert
box1
.
shape
[
-
1
]
==
4
,
"Box1 shape[-1] should be 4."
assert
box2
.
shape
[
-
1
]
==
4
,
"Box2 shape[-1] should be 4."
b1_x1
,
b1_x2
=
box1
[:,
0
]
-
box1
[:,
2
]
/
2
,
box1
[:,
0
]
+
box1
[:,
2
]
/
2
b1_y1
,
b1_y2
=
box1
[:,
1
]
-
box1
[:,
3
]
/
2
,
box1
[:,
1
]
+
box1
[:,
3
]
/
2
b2_x1
,
b2_x2
=
box2
[:,
0
]
-
box2
[:,
2
]
/
2
,
box2
[:,
0
]
+
box2
[:,
2
]
/
2
b2_y1
,
b2_y2
=
box2
[:,
1
]
-
box2
[:,
3
]
/
2
,
box2
[:,
1
]
+
box2
[:,
3
]
/
2
inter_x1
=
np
.
maximum
(
b1_x1
,
b2_x1
)
inter_x2
=
np
.
minimum
(
b1_x2
,
b2_x2
)
inter_y1
=
np
.
maximum
(
b1_y1
,
b2_y1
)
inter_y2
=
np
.
minimum
(
b1_y2
,
b2_y2
)
inter_w
=
inter_x2
-
inter_x1
inter_h
=
inter_y2
-
inter_y1
inter_w
[
inter_w
<
0
]
=
0
inter_h
[
inter_h
<
0
]
=
0
inter_area
=
inter_w
*
inter_h
b1_area
=
(
b1_x2
-
b1_x1
)
*
(
b1_y2
-
b1_y1
)
b2_area
=
(
b2_x2
-
b2_x1
)
*
(
b2_y2
-
b2_y1
)
return
inter_area
/
(
b1_area
+
b2_area
-
inter_area
)
def
draw_boxes_on_image
(
image_path
:
str
,
boxes
:
np
.
ndarray
,
scores
:
np
.
ndarray
,
labels
:
np
.
ndarray
,
label_names
:
list
,
score_thresh
:
float
=
0.5
):
"""Draw boxes on images."""
image
=
np
.
array
(
Image
.
open
(
image_path
))
plt
.
figure
()
_
,
ax
=
plt
.
subplots
(
1
)
ax
.
imshow
(
image
)
image_name
=
image_path
.
split
(
'/'
)[
-
1
]
print
(
"Image {} detect: "
.
format
(
image_name
))
colors
=
{}
for
box
,
score
,
label
in
zip
(
boxes
,
scores
,
labels
):
if
score
<
score_thresh
:
continue
if
box
[
2
]
<=
box
[
0
]
or
box
[
3
]
<=
box
[
1
]:
continue
label
=
int
(
label
)
if
label
not
in
colors
:
colors
[
label
]
=
plt
.
get_cmap
(
'hsv'
)(
label
/
len
(
label_names
))
x1
,
y1
,
x2
,
y2
=
box
[
0
],
box
[
1
],
box
[
2
],
box
[
3
]
rect
=
plt
.
Rectangle
((
x1
,
y1
),
x2
-
x1
,
y2
-
y1
,
fill
=
False
,
linewidth
=
2.0
,
edgecolor
=
colors
[
label
])
ax
.
add_patch
(
rect
)
ax
.
text
(
x1
,
y1
,
'{} {:.4f}'
.
format
(
label_names
[
label
],
score
),
verticalalignment
=
'bottom'
,
horizontalalignment
=
'left'
,
bbox
=
{
'facecolor'
:
colors
[
label
],
'alpha'
:
0.5
,
'pad'
:
0
},
fontsize
=
8
,
color
=
'white'
)
print
(
"
\t
{:15s} at {:25} score: {:.5f}"
.
format
(
label_names
[
int
(
label
)],
str
(
list
(
map
(
int
,
list
(
box
)))),
score
))
image_name
=
image_name
.
replace
(
'jpg'
,
'png'
)
plt
.
axis
(
'off'
)
plt
.
gca
().
xaxis
.
set_major_locator
(
plt
.
NullLocator
())
plt
.
gca
().
yaxis
.
set_major_locator
(
plt
.
NullLocator
())
plt
.
savefig
(
"./output/{}"
.
format
(
image_name
),
bbox_inches
=
'tight'
,
pad_inches
=
0.0
)
print
(
"Detect result save at ./output/{}
\n
"
.
format
(
image_name
))
plt
.
cla
()
plt
.
close
(
'all'
)
def
img_shape
(
img_path
:
str
):
"""Get image shape."""
im
=
cv2
.
imread
(
img_path
)
im
=
cv2
.
cvtColor
(
im
,
cv2
.
COLOR_BGR2RGB
)
h
,
w
,
c
=
im
.
shape
return
h
,
w
,
c
def
get_label_infos
(
file_list
:
str
):
"""Get label names by corresponding category ids."""
map_label
=
COCO
(
file_list
)
label_names
=
[]
categories
=
map_label
.
loadCats
(
map_label
.
getCatIds
())
for
category
in
categories
:
label_names
.
append
(
category
[
'name'
])
return
label_names
def
subtract_imagenet_mean_batch
(
batch
:
paddle
.
Tensor
)
->
paddle
.
Tensor
:
"""Subtract ImageNet mean pixel-wise from a BGR image."""
mean
=
np
.
zeros
(
shape
=
batch
.
shape
,
dtype
=
'float32'
)
...
...
paddlehub/process/transforms.py
浏览文件 @
86c06294
...
...
@@ -12,13 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
random
import
copy
from
typing
import
Callable
from
collections
import
OrderedDict
import
cv2
import
numpy
as
np
from
PIL
import
Image
from
PIL
import
Image
,
ImageEnhance
from
paddlehub.process.functional
import
*
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录