Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
5148424a
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
1 年多 前同步成功
通知
284
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5148424a
编写于
9月 21, 2020
作者:
H
haoyuying
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add yolov3_darknet_pascalvoc
上级
9a1eac7b
变更
8
展开全部
隐藏空白更改
内联
并排
Showing
8 changed files
with
1171 additions
and
50 deletions
+1171
-50
demo/detection/yolov3_darknet53_pascalvoc/4026.jpeg
demo/detection/yolov3_darknet53_pascalvoc/4026.jpeg
+0
-0
demo/detection/yolov3_darknet53_pascalvoc/predict.py
demo/detection/yolov3_darknet53_pascalvoc/predict.py
+9
-0
demo/detection/yolov3_darknet53_pascalvoc/train.py
demo/detection/yolov3_darknet53_pascalvoc/train.py
+22
-0
hub_module/modules/image/object_detection/yolov3_darknet53_pascalvoc/darknet.py
...ge/object_detection/yolov3_darknet53_pascalvoc/darknet.py
+144
-0
hub_module/modules/image/object_detection/yolov3_darknet53_pascalvoc/module.py
...age/object_detection/yolov3_darknet53_pascalvoc/module.py
+247
-0
paddlehub/datasets/pascalvoc.py
paddlehub/datasets/pascalvoc.py
+74
-0
paddlehub/module/cv_module.py
paddlehub/module/cv_module.py
+112
-13
paddlehub/process/transforms.py
paddlehub/process/transforms.py
+563
-37
未找到文件。
demo/detection/yolov3_darknet53_pascalvoc/4026.jpeg
0 → 100644
浏览文件 @
5148424a
83.0 KB
demo/detection/yolov3_darknet53_pascalvoc/predict.py
0 → 100644
浏览文件 @
5148424a
import
paddle
import
paddlehub
as
hub
if __name__ == '__main__':
    # NOTE(review): `place` is created but never passed to any call below;
    # confirm whether it was meant to be handed to `paddle.disable_static`.
    place = paddle.CUDAPlace(0)
    paddle.disable_static()

    # Build the detection module with training disabled and switch it to eval mode.
    # (The original statement was the chained assignment `model = model = ...`.)
    model = hub.Module(name='yolov3_darknet53_pascalvoc', is_train=False)
    model.eval()

    # Both arguments are placeholders and must be replaced with real paths before running.
    model.predict(imgpath="/PATH/TO/IMAGE", filelist="/PATH/TO/JSON/FILE")
demo/detection/yolov3_darknet53_pascalvoc/train.py
0 → 100644
浏览文件 @
5148424a
import
paddle
import
paddlehub
as
hub
import
paddle.nn
as
nn
from
paddlehub.finetune.trainer
import
Trainer
from
paddlehub.datasets.pascalvoc
import
DetectionData
from
paddlehub.process.transforms
import
DetectTrainReader
,
DetectTestReader
if __name__ == "__main__":
    # NOTE(review): `place` is created but not passed to any call below.
    place = paddle.CUDAPlace(0)
    paddle.disable_static()

    is_train = True
    # NOTE(review): only the training branch defines `train_reader`, which the
    # `trainer.train` call below requires; the other branch is dead while
    # `is_train` is hard-coded to True.
    if not is_train:
        transform = DetectTestReader()
        test_reader = DetectionData(transform)
    else:
        transform = DetectTrainReader()
        train_reader = DetectionData(transform)

    model = hub.Module(name='yolov3_darknet53_pascalvoc')
    model.train()

    optimizer = paddle.optimizer.Adam(
        learning_rate=0.0001,
        parameters=model.parameters(),
    )
    trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls')

    # The training set is also used as the evaluation set.
    train_options = dict(
        epochs=5,
        batch_size=4,
        eval_dataset=train_reader,
        log_interval=1,
        save_interval=1,
    )
    trainer.train(train_reader, **train_options)
hub_module/modules/image/object_detection/yolov3_darknet53_pascalvoc/darknet.py
0 → 100644
浏览文件 @
5148424a
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle.regularizer
import
L2Decay
from
paddle.nn.initializer
import
Normal
class ConvBNLayer(nn.Layer):
    """Convolution followed by batch normalization and an optional leaky ReLU.

    Args:
        ch_in(int): Number of input channels.
        ch_out(int): Number of output channels.
        filter_size(int): Convolution kernel size, default is 3.
        stride(int): Convolution stride, default is 1.
        groups(int): Number of convolution groups, default is 1.
        padding(int): Convolution padding, default is 0.
        act(str): Activation selector; only the value 'leakly' enables the leaky ReLU, default is 'leakly'.
        is_test(bool): Forwarded to the batch normalization layer, default is False.
    """

    def __init__(self,
                 ch_in: int,
                 ch_out: int,
                 filter_size: int = 3,
                 stride: int = 1,
                 groups: int = 1,
                 padding: int = 0,
                 act: str = 'leakly',
                 is_test: bool = False):
        super().__init__()

        # The convolution has no bias term; its weights are drawn from Normal(0, 0.02).
        conv_weight_attr = paddle.ParamAttr(initializer=Normal(0., 0.02))
        self.conv = nn.Conv2d(
            ch_in,
            ch_out,
            filter_size,
            padding=padding,
            stride=stride,
            groups=groups,
            weight_attr=conv_weight_attr,
            bias_attr=False)

        # The batch-norm parameter attribute uses the same initializer with an L2Decay(0.) regularizer.
        bn_param_attr = paddle.ParamAttr(initializer=Normal(0., 0.02), regularizer=L2Decay(0.))
        self.batch_norm = nn.BatchNorm(num_channels=ch_out, is_test=is_test, param_attr=bn_param_attr)

        self.act = act

    def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
        """Apply conv -> batch norm, then leaky ReLU (slope 0.1) when `act` is 'leakly'."""
        normalized = self.batch_norm(self.conv(inputs))
        if self.act != "leakly":
            return normalized
        return F.leaky_relu(x=normalized, negative_slope=0.1)
class DownSample(nn.Layer):
    """Downsample block for Darknet: a single ConvBNLayer that uses stride 2 by default.

    Args:
        ch_in(int): Number of input channels.
        ch_out(int): Number of output channels.
        filter_size(int): Convolution kernel size, default is 3.
        stride(int): Convolution stride, default is 2.
        padding(int): Convolution padding, default is 1.
        is_test(bool): Forwarded to the wrapped ConvBNLayer, default is False.
    """

    def __init__(self,
                 ch_in: int,
                 ch_out: int,
                 filter_size: int = 3,
                 stride: int = 2,
                 padding: int = 1,
                 is_test: bool = False):
        super().__init__()

        # Kept as a plain attribute alongside the wrapped layer.
        self.ch_out = ch_out

        layer_options = dict(
            ch_in=ch_in,
            ch_out=ch_out,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            is_test=is_test)
        self.conv_bn_layer = ConvBNLayer(**layer_options)

    def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
        """Return the output of the wrapped ConvBNLayer."""
        return self.conv_bn_layer(inputs)
class BasicBlock(nn.Layer):
    """Basic residual block for Darknet.

    A 1x1 ConvBNLayer producing `ch_out` channels is followed by a 3x3 ConvBNLayer producing
    `ch_out * 2` channels, and that result is added element-wise to the block input.

    Args:
        ch_in(int): Number of input channels.
        ch_out(int): Number of channels produced by the 1x1 convolution.
        is_test(bool): Forwarded to both ConvBNLayers, default is False.
    """

    def __init__(self, ch_in: int, ch_out: int, is_test: bool = False):
        super().__init__()

        self.conv1 = ConvBNLayer(ch_in=ch_in, ch_out=ch_out, filter_size=1, stride=1, padding=0, is_test=is_test)
        self.conv2 = ConvBNLayer(
            ch_in=ch_out, ch_out=ch_out * 2, filter_size=3, stride=1, padding=1, is_test=is_test)

    def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
        """Return `inputs` plus the output of the two stacked convolutions."""
        squeezed = self.conv1(inputs)
        expanded = self.conv2(squeezed)
        # Residual connection; presumably requires ch_in == ch_out * 2 so the shapes match.
        return paddle.elementwise_add(x=inputs, y=expanded, act=None)
class LayerWarp(nn.Layer):
    """Stage made of `count` BasicBlocks applied one after another.

    Args:
        ch_in(int): Number of input channels of the first block.
        ch_out(int): `ch_out` passed to every block; later blocks take `ch_out * 2` input channels.
        count(int): Total number of BasicBlocks in the stage.
        is_test(bool): Forwarded to every BasicBlock, default is False.
    """

    def __init__(self, ch_in: int, ch_out: int, count: int, is_test: bool = False):
        super().__init__()

        self.basicblock0 = BasicBlock(ch_in, ch_out, is_test=is_test)
        # Remaining blocks are registered as sublayers named basic_block_1 .. basic_block_{count-1}.
        self.res_out_list = [
            self.add_sublayer("basic_block_%d" % (index), BasicBlock(ch_out * 2, ch_out, is_test=is_test))
            for index in range(1, count)
        ]
        self.ch_out = ch_out

    def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
        """Run the input through the first block and then every remaining block in order."""
        features = self.basicblock0(inputs)
        for residual_block in self.res_out_list:
            features = residual_block(features)
        return features
# Per-depth stage configuration: each entry is passed as the block `count` of one stage.
DarkNet_cfg = {
    53: [1, 2, 8, 8, 4],
}
class DarkNet53_conv_body(nn.Layer):
    """Darknet53 convolutional backbone.

    Args:
        ch_in(int): Input channels, default is 3.
        is_test(bool): Forwarded to every sublayer, default is False.
    """

    def __init__(self, ch_in: int = 3, is_test: bool = False):
        super().__init__()

        self.stages = DarkNet_cfg[53][0:5]

        self.conv0 = ConvBNLayer(ch_in=ch_in, ch_out=32, filter_size=3, stride=1, padding=1, is_test=is_test)
        self.downsample0 = DownSample(ch_in=32, ch_out=32 * 2, is_test=is_test)

        self.darknet53_conv_block_list = []
        self.downsample_list = []

        # Input channel count of each stage, in stage order.
        stage_inputs = [64, 128, 256, 512, 1024]
        for index, block_count in enumerate(self.stages):
            stage_layer = LayerWarp(stage_inputs[index], 32 * (2**index), block_count, is_test=is_test)
            self.darknet53_conv_block_list.append(self.add_sublayer("stage_%d" % (index), stage_layer))

        # One downsample layer sits between each pair of consecutive stages.
        for index in range(len(self.stages) - 1):
            between_stages = DownSample(ch_in=32 * (2**(index + 1)), ch_out=32 * (2**(index + 2)), is_test=is_test)
            self.downsample_list.append(self.add_sublayer("stage_%d_downsample" % index, between_stages))

    def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
        """Run the backbone and return the outputs of the last three stages, deepest first.

        Note that the value returned is a list of stage outputs, not a single tensor.
        """
        features = self.downsample0(self.conv0(inputs))

        last_stage = len(self.stages) - 1
        stage_outputs = []
        for index, stage_layer in enumerate(self.darknet53_conv_block_list):
            features = stage_layer(features)
            stage_outputs.append(features)
            # Every stage except the final one is followed by its downsample layer.
            if index < last_stage:
                features = self.downsample_list[index](features)

        # Reverse slice: last, second-to-last and third-to-last stage outputs.
        return stage_outputs[-1:-4:-1]
hub_module/modules/image/object_detection/yolov3_darknet53_pascalvoc/module.py
0 → 100644
浏览文件 @
5148424a
import
os
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle.nn.initializer
import
Normal
,
Constant
from
paddle.regularizer
import
L2Decay
from
pycocotools.coco
import
COCO
from
darknet
import
DarkNet53_conv_body
from
darknet
import
ConvBNLayer
from
paddlehub.module.cv_module
import
Yolov3Module
from
paddlehub.process.transforms
import
DetectTrainReader
,
DetectTestReader
from
paddlehub.module.module
import
moduleinfo
class YoloDetectionBlock(nn.Layer):
    """Basic block for Yolov3: alternating 1x1 and 3x3 ConvBNLayers with a `route` and a `tip` output.

    Args:
        ch_in(int): Number of input channels.
        channel(int): Channel count of the 1x1 layers; the 3x3 layers produce `channel * 2`. Must be even.
        is_test(bool): Forwarded to every ConvBNLayer, default is True.
    """

    def __init__(self, ch_in: int, channel: int, is_test: bool = True):
        super().__init__()

        assert channel % 2 == 0, "channel {} cannot be divided by 2".format(channel)

        def pointwise(source_channels: int, target_channels: int) -> ConvBNLayer:
            # 1x1 convolution block without padding.
            return ConvBNLayer(
                ch_in=source_channels, ch_out=target_channels, filter_size=1, stride=1, padding=0, is_test=is_test)

        def spatial(source_channels: int, target_channels: int) -> ConvBNLayer:
            # 3x3 convolution block with padding 1.
            return ConvBNLayer(
                ch_in=source_channels, ch_out=target_channels, filter_size=3, stride=1, padding=1, is_test=is_test)

        wide = channel * 2
        self.conv0 = pointwise(ch_in, channel)
        self.conv1 = spatial(channel, wide)
        self.conv2 = pointwise(wide, channel)
        self.conv3 = spatial(channel, wide)
        self.route = pointwise(wide, channel)
        self.tip = spatial(channel, wide)

    def forward(self, inputs):
        """Return `(route, tip)`, where `tip` is computed from `route`."""
        features = inputs
        for layer in (self.conv0, self.conv1, self.conv2, self.conv3):
            features = layer(features)
        route = self.route(features)
        tip = self.tip(route)
        return route, tip
class Upsample(nn.Layer):
    """Upsample block for Yolov3 based on nearest-neighbour resizing.

    Args:
        scale(int): Factor applied to the two trailing dimensions of the input, default is 2.
    """

    def __init__(self, scale: int = 2):
        super().__init__()
        self.scale = scale

    def forward(self, inputs: paddle.Tensor):
        """Resize `inputs` with `F.resize_nearest`, using `scale` and an explicit target shape."""
        full_shape = paddle.to_tensor(inputs.shape)

        # Entries 2 and 3 of the shape (the two trailing dimensions of a 4-D input).
        spatial_shape = paddle.slice(full_shape, axes=[0], starts=[2], ends=[4])
        spatial_shape.stop_gradient = True

        target_shape = paddle.cast(spatial_shape, dtype='int32') * self.scale
        target_shape.stop_gradient = True

        return F.resize_nearest(input=inputs, scale=self.scale, actual_shape=target_shape)
@moduleinfo(
    name="yolov3_darknet53_pascalvoc",
    type="CV/image_editing",
    author="paddlepaddle",
    author_email="",
    summary="Yolov3 is a detection model, this module is trained with VOC dataset.",
    version="1.0.0",
    meta=Yolov3Module)
class YOLOv3(nn.Layer):
    """YOLOV3 for detection

    Args:
        ch_in(int): Input channels, default is 3.
        class_num(int): Categories for detection, if dataset is voc, class_num is 20.
        ignore_thresh(float): The ignore threshold to ignore confidence loss, default is 0.7.
        valid_thresh(float): Threshold to filter out bounding boxes with low confidence score, default is 0.005.
        nms_topk(int): Maximum number of detections to be kept according to the confidences after the filtering
                       detections based on score_threshold, default is 400.
        nms_posk(int): Number of total bboxes to be kept per image after NMS step, default is 100. -1 means keeping
                       all bboxes after NMS step.
        nms_thresh(float): The threshold to be used in NMS, default is 0.45.
        is_train(bool): Set the train mode, default is True.
        load_checkpoint(str): Path of a custom checkpoint to load. When None, the pretrained checkpoint is used and
                              downloaded first if it is not present in the module directory.
    """

    def __init__(self,
                 ch_in: int = 3,
                 class_num: int = 20,
                 ignore_thresh: float = 0.7,
                 valid_thresh: float = 0.005,
                 nms_topk: int = 400,
                 nms_posk: int = 100,
                 nms_thresh: float = 0.45,
                 is_train: bool = True,
                 load_checkpoint: str = None):
        super(YOLOv3, self).__init__()

        self.is_train = is_train
        self.block = DarkNet53_conv_body(ch_in=ch_in, is_test=not self.is_train)
        self.block_outputs = []
        self.yolo_blocks = []
        self.route_blocks_2 = []
        # Indices into the (width, height) anchor pairs below, one group per output level.
        self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
        # Flat list of anchor values; entries 2*m and 2*m+1 belong to anchor m.
        self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326]
        self.class_num = class_num
        self.ignore_thresh = ignore_thresh
        self.valid_thresh = valid_thresh
        self.nms_topk = nms_topk
        self.nms_posk = nms_posk
        self.nms_thresh = nms_thresh

        ch_in_list = [1024, 768, 384]
        for i in range(3):
            # NOTE: the sublayer names (including the "detecton" spelling) are kept verbatim because
            # they are runtime strings.
            yolo_block = self.add_sublayer(
                "yolo_detecton_block_%d" % (i),
                YoloDetectionBlock(ch_in_list[i], channel=512 // (2**i), is_test=not self.is_train))
            self.yolo_blocks.append(yolo_block)

            num_filters = len(self.anchor_masks[i]) * (self.class_num + 5)
            block_out = self.add_sublayer(
                "block_out_%d" % (i),
                nn.Conv2d(
                    1024 // (2**i),
                    num_filters,
                    1,
                    stride=1,
                    padding=0,
                    weight_attr=paddle.ParamAttr(initializer=Normal(0., 0.02)),
                    bias_attr=paddle.ParamAttr(initializer=Constant(0.0), regularizer=L2Decay(0.))))
            self.block_outputs.append(block_out)

            if i < 2:
                route = self.add_sublayer(
                    "route2_%d" % i,
                    ConvBNLayer(
                        ch_in=512 // (2**i),
                        ch_out=256 // (2**i),
                        filter_size=1,
                        stride=1,
                        padding=0,
                        is_test=(not self.is_train)))
                self.route_blocks_2.append(route)
            self.upsample = Upsample()

        if load_checkpoint is not None:
            model_dict = paddle.load(load_checkpoint)[0]
            self.set_dict(model_dict)
            print("load custom checkpoint success")
        else:
            # `self.directory` is not defined in this class; presumably it is provided by the
            # `moduleinfo` decorator - TODO confirm.
            checkpoint = os.path.join(self.directory, 'yolov3_70000.pdparams')
            if not os.path.exists(checkpoint):
                self._download_checkpoint(checkpoint)
            model_dict = paddle.load(checkpoint)[0]
            self.set_dict(model_dict)
            print("load pretrained checkpoint success")

    @staticmethod
    def _download_checkpoint(checkpoint: str):
        """Fetch the pretrained parameters with wget and store them at `checkpoint`.

        The command is run from an argument list rather than a shell string, so a destination path
        containing spaces or shell metacharacters is passed through unchanged. A failed download
        raises instead of being ignored, and the partial file is removed so that a later run does
        not mistake it for a complete checkpoint.

        Args:
            checkpoint(str): Destination path of the downloaded file.

        Raises:
            OSError: If wget cannot be started.
            subprocess.CalledProcessError: If wget exits with a non-zero status.
        """
        import subprocess

        url = 'https://bj.bcebos.com/paddlehub/model/image/object_detection/yolov3_70000.pdparams'
        try:
            subprocess.run(['wget', url, '-O', checkpoint], check=True)
        except (OSError, subprocess.CalledProcessError):
            if os.path.exists(checkpoint):
                os.remove(checkpoint)
            raise

    def transform(self, img: paddle.Tensor, size: int):
        """Apply the train or test reader, chosen by `self.is_train`, to `img` with the given `size`."""
        if self.is_train:
            transforms = DetectTrainReader()
        else:
            transforms = DetectTestReader()
        return transforms(img, size)

    def get_label_infos(self, file_list: str):
        """Load the COCO-format annotation file `file_list` and return its category names.

        The loaded COCO object is also stored on `self.COCO`.
        """
        self.COCO = COCO(file_list)
        categories = self.COCO.loadCats(self.COCO.getCatIds())
        return [category['name'] for category in categories]

    def forward(self,
                inputs: paddle.Tensor,
                gtbox: paddle.Tensor = None,
                gtlabel: paddle.Tensor = None,
                gtscore: paddle.Tensor = None,
                im_shape: paddle.Tensor = None):
        """Run detection and, in train mode, compute the YOLOv3 loss.

        Args:
            inputs(paddle.Tensor): Input image batch.
            gtbox(paddle.Tensor): Ground truth boxes, only passed to the loss in train mode.
            gtlabel(paddle.Tensor): Ground truth labels, only passed to the loss in train mode.
            gtscore(paddle.Tensor): Ground truth scores, only passed to the loss in train mode.
            im_shape(paddle.Tensor): Image sizes passed to `F.yolo_box` as `img_size`.

        Returns:
            tuple: The sum of the per-level losses (zero tensors when not training) and a list holding
                   the NMS result of every image in the batch.
        """
        # Intermediate results are kept on the instance, as in the original implementation.
        self.gtbox = gtbox
        self.gtlabel = gtlabel
        self.gtscore = gtscore
        self.im_shape = im_shape
        self.outputs = []
        self.boxes = []
        self.scores = []
        self.losses = []
        self.pred = []
        self.downsample = 32

        blocks = self.block(inputs)
        route = None
        for i, block in enumerate(blocks):
            # From the second level on, the upsampled route of the previous level is concatenated
            # along axis 1 with the backbone feature.
            if i > 0:
                block = paddle.concat([route, block], axis=1)
            route, tip = self.yolo_blocks[i](block)
            block_out = self.block_outputs[i](tip)
            self.outputs.append(block_out)

            if i < 2:
                route = self.route_blocks_2[i](route)
                route = self.upsample(route)

        for i, out in enumerate(self.outputs):
            anchor_mask = self.anchor_masks[i]
            if self.is_train:
                loss = F.yolov3_loss(
                    x=out,
                    gt_box=self.gtbox,
                    gt_label=self.gtlabel,
                    gt_score=self.gtscore,
                    anchors=self.anchors,
                    anchor_mask=anchor_mask,
                    class_num=self.class_num,
                    ignore_thresh=self.ignore_thresh,
                    downsample_ratio=self.downsample,
                    use_label_smooth=False)
            else:
                loss = paddle.to_tensor(0.0)
            self.losses.append(paddle.reduce_mean(loss))

            # Collect the two anchor values of every anchor selected for this level.
            mask_anchors = []
            for m in anchor_mask:
                mask_anchors.extend(self.anchors[2 * m:2 * m + 2])

            boxes, scores = F.yolo_box(
                x=out,
                img_size=self.im_shape,
                anchors=mask_anchors,
                class_num=self.class_num,
                conf_thresh=self.valid_thresh,
                downsample_ratio=self.downsample,
                name="yolo_box" + str(i))
            self.boxes.append(boxes)
            self.scores.append(paddle.transpose(scores, perm=[0, 2, 1]))
            # Each following level uses half the downsample ratio of the previous one.
            self.downsample //= 2

        # NMS is run image by image over the boxes and scores of all three levels.
        for i in range(self.boxes[0].shape[0]):
            yolo_boxes = paddle.unsqueeze(
                paddle.concat([self.boxes[0][i], self.boxes[1][i], self.boxes[2][i]], axis=0), 0)
            yolo_scores = paddle.unsqueeze(
                paddle.concat([self.scores[0][i], self.scores[1][i], self.scores[2][i]], axis=1), 0)
            pred = F.multiclass_nms(
                bboxes=yolo_boxes,
                scores=yolo_scores,
                score_threshold=self.valid_thresh,
                nms_top_k=self.nms_topk,
                keep_top_k=self.nms_posk,
                nms_threshold=self.nms_thresh,
                background_label=-1)
            self.pred.append(pred)

        return sum(self.losses), self.pred
paddlehub/datasets/pascalvoc.py
0 → 100644
浏览文件 @
5148424a
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
from
typing
import
Callable
import
paddle
from
paddlehub.env
import
DATA_HOME
from
pycocotools.coco
import
COCO
from
paddlehub.process.transforms
import
DetectCatagory
,
ParseImages
class DetectionData(paddle.io.Dataset):
    """
    Dataset for image detection.

    Args:
       transform(callmethod) : The method of preprocess images.
       size(int): Size passed to `transform` for training samples, default is 416.
       mode(str): The mode for preparing dataset, either 'train' or 'test', default is 'train'.

    Returns:
        DataSet: An iterable object for data iterating

    Raises:
        ValueError: If `mode` is neither 'train' nor 'test'.
    """

    def __init__(self, transform: Callable, size: int = 416, mode: str = 'train'):
        self.mode = mode
        self.transform = transform
        self.size = size

        if self.mode == 'train':
            train_file_list = 'annotations/instances_train2017.json'
            train_data_dir = 'train2017'
            self.train_file_list = os.path.join(DATA_HOME, 'voc', train_file_list)
            self.train_data_dir = os.path.join(DATA_HOME, 'voc', train_data_dir)
            self.COCO = COCO(self.train_file_list)
            self.img_dir = self.train_data_dir
        elif self.mode == 'test':
            val_file_list = 'annotations/instances_val2017.json'
            val_data_dir = 'val2017'
            self.val_file_list = os.path.join(DATA_HOME, 'voc', val_file_list)
            self.val_data_dir = os.path.join(DATA_HOME, 'voc', val_data_dir)
            self.COCO = COCO(self.val_file_list)
            self.img_dir = self.val_data_dir
        else:
            # Without this check an unknown mode only surfaced later as a missing `COCO` attribute.
            raise ValueError("mode must be 'train' or 'test', got {!r}".format(mode))

        parse_dataset_catagory = DetectCatagory(self.COCO, self.img_dir)
        self.label_names, self.label_ids, self.category_to_id_map = parse_dataset_catagory()
        parse_images = ParseImages(self.COCO, self.mode, self.img_dir, self.category_to_id_map)
        self.data = parse_images()

    def __getitem__(self, idx: int):
        """Return the transformed sample at `idx`.

        In 'train' mode the result is `(image, gt_boxes, gt_labels, gt_scores)`; in 'test' mode it is
        `(image, image_id, (h, w))`.
        """
        img = self.data[idx]
        if self.mode == "train":
            # Use the configured size; it was previously hard-coded to 416, which is still the default.
            out_img, gt_boxes, gt_labels, gt_scores = self.transform(img, self.size)
            return out_img, gt_boxes, gt_labels, gt_scores
        if self.mode == "test":
            # NOTE(review): no size is passed here, so the transform's own default applies -
            # confirm whether `self.size` should be forwarded as well.
            out_img, img_id, (h, w) = self.transform(img)
            return out_img, img_id, (h, w)
        raise ValueError("mode must be 'train' or 'test', got {!r}".format(self.mode))

    def __len__(self):
        """Return the number of parsed samples."""
        return len(self.data)
paddlehub/module/cv_module.py
浏览文件 @
5148424a
...
...
@@ -26,7 +26,7 @@ from PIL import Image
from
paddlehub.module.module
import
serving
,
RunModule
from
paddlehub.utils.utils
import
base64_to_cv2
from
paddlehub.process.transforms
import
ConvertColorSpace
,
ColorPostprocess
,
Resize
from
paddlehub.process.transforms
import
ConvertColorSpace
,
ColorPostprocess
,
Resize
,
BoxTool
class
ImageServing
(
object
):
...
...
@@ -103,11 +103,11 @@ class ImageColorizeModule(RunModule, ImageServing):
def
training_step
(
self
,
batch
:
int
,
batch_idx
:
int
)
->
dict
:
'''
One step for training, which should be called as forward computation.
Args:
batch(list[paddle.Tensor]): The one batch data, which contains images and labels.
batch_idx(int): The index of batch.
Returns:
results(dict) : The model outputs, such as loss and metrics.
'''
...
...
@@ -116,22 +116,22 @@ class ImageColorizeModule(RunModule, ImageServing):
def
validation_step
(
self
,
batch
:
int
,
batch_idx
:
int
)
->
dict
:
'''
One step for validation, which should be called as forward computation.
Args:
batch(list[paddle.Tensor]): The one batch data, which contains images and labels.
batch_idx(int): The index of batch.
Returns:
results(dict) : The model outputs, such as metrics.
'''
out_class
,
out_reg
=
self
(
batch
[
0
],
batch
[
1
],
batch
[
2
])
criterionCE
=
nn
.
loss
.
CrossEntropyLoss
()
loss_ce
=
criterionCE
(
out_class
,
batch
[
4
][:,
0
,
:,
:])
loss_G_L1_reg
=
paddle
.
sum
(
paddle
.
abs
(
batch
[
3
]
-
out_reg
),
axis
=
1
,
keepdim
=
True
)
loss_G_L1_reg
=
paddle
.
mean
(
loss_G_L1_reg
)
loss
=
loss_ce
+
loss_G_L1_reg
visual_ret
=
OrderedDict
()
psnrs
=
[]
lab2rgb
=
ConvertColorSpace
(
mode
=
'LAB2RGB'
)
...
...
@@ -141,7 +141,7 @@ class ImageColorizeModule(RunModule, ImageServing):
visual_ret
[
'real'
]
=
process
(
real
)
fake
=
lab2rgb
(
np
.
concatenate
((
batch
[
0
].
numpy
(),
out_reg
.
numpy
()),
axis
=
1
))[
i
]
visual_ret
[
'fake_reg'
]
=
process
(
fake
)
mse
=
np
.
mean
((
visual_ret
[
'real'
]
*
1.0
-
visual_ret
[
'fake_reg'
]
*
1.0
)
**
2
)
mse
=
np
.
mean
((
visual_ret
[
'real'
]
*
1.0
-
visual_ret
[
'fake_reg'
]
*
1.0
)
**
2
)
psnr_value
=
20
*
np
.
log10
(
255.
/
np
.
sqrt
(
mse
))
psnrs
.
append
(
psnr_value
)
psnr
=
paddle
.
to_variable
(
np
.
array
(
psnrs
))
...
...
@@ -150,12 +150,12 @@ class ImageColorizeModule(RunModule, ImageServing):
def
predict
(
self
,
images
:
str
,
visualization
:
bool
=
True
,
save_path
:
str
=
'result'
):
'''
Colorize images
Args:
images(str) : Images path to be colorized.
visualization(bool): Whether to save colorized images.
save_path(str) : Path to save colorized images.
Returns:
results(list[dict]) : The prediction result of each input image
'''
...
...
@@ -177,7 +177,7 @@ class ImageColorizeModule(RunModule, ImageServing):
visual_ret
[
'real'
]
=
resize
(
process
(
real
))
fake
=
lab2rgb
(
np
.
concatenate
((
im
[
'A'
],
out_reg
.
numpy
()),
axis
=
1
))[
i
]
visual_ret
[
'fake_reg'
]
=
resize
(
process
(
fake
))
if
visualization
:
fake_name
=
"fake_"
+
str
(
time
.
time
())
+
".png"
if
not
os
.
path
.
exists
(
save_path
):
...
...
@@ -185,8 +185,107 @@ class ImageColorizeModule(RunModule, ImageServing):
fake_path
=
os
.
path
.
join
(
save_path
,
fake_name
)
visual_gray
=
Image
.
fromarray
(
visual_ret
[
'fake_reg'
])
visual_gray
.
save
(
fake_path
)
mse
=
np
.
mean
((
visual_ret
[
'real'
]
*
1.0
-
visual_ret
[
'fake_reg'
]
*
1.0
)
**
2
)
mse
=
np
.
mean
((
visual_ret
[
'real'
]
*
1.0
-
visual_ret
[
'fake_reg'
]
*
1.0
)
**
2
)
psnr_value
=
20
*
np
.
log10
(
255.
/
np
.
sqrt
(
mse
))
result
.
append
(
visual_ret
)
return
result
class Yolov3Module(RunModule, ImageServing):
    """Training, validation and prediction steps shared by Yolov3 detection modules."""

    @staticmethod
    def _split_detections(detections: np.ndarray):
        """Split one NMS result into `(labels, scores, boxes)` arrays.

        Rows are read as `[label, score, x1, y1, x2, y2]`. Anything that is not a 2-D array with six
        columns is treated as "nothing detected" and yields empty arrays; the NMS operator is
        documented to emit a single -1 in that case - TODO confirm for the Paddle version in use.
        """
        if detections.ndim != 2 or detections.shape[1] != 6:
            return (np.zeros((0, ), dtype='int32'), np.zeros((0, ), dtype='float32'),
                    np.zeros((0, 4), dtype='float32'))
        labels = detections[:, 0].astype('int32')
        scores = detections[:, 1].astype('float32')
        boxes = detections[:, 2:].astype('float32')
        return labels, scores, boxes

    def training_step(self, batch: list, batch_idx: int) -> dict:
        '''
        One step for training, which should be called as forward computation.

        Args:
            batch(list[paddle.Tensor]): The one batch data, which contains images, ground truth boxes, labels and scores.
            batch_idx(int): The index of batch.

        Returns:
            results(dict): The model outputs, such as loss.
        '''
        return self.validation_step(batch, batch_idx)

    def validation_step(self, batch: list, batch_idx: int) -> dict:
        '''
        One step for validation, which should be called as forward computation.

        Args:
            batch(list[paddle.Tensor]): The one batch data, which contains images, ground truth boxes, labels and scores.
            batch_idx(int): The index of batch.

        Returns:
            results(dict) : The model outputs, such as metrics. The 'iou' metric holds, per image, the best IoU
                            between any predicted box and the ground truth boxes, or 0.0 when nothing was detected.
        '''
        boxtool = BoxTool()

        img = batch[0].astype('float32')
        # NOTE(review): the names W and H follow the original code; they are simply entries 2 and 3
        # of `img.shape`. For an NCHW tensor those would be height and width - confirm the order
        # expected by `coco_anno_box_to_center_relative` for non-square inputs.
        B, C, W, H = img.shape
        im_shape = np.array([(W, H)] * B).astype('int32')
        im_shape = paddle.to_tensor(im_shape)

        gt_box = batch[1].astype('float32')
        gt_label = batch[2].astype('int32')
        gt_score = batch[3].astype("float32")
        loss, pred = self(img, gt_box, gt_label, gt_score, im_shape)

        ious = []
        for i, detections in enumerate(pred):
            _, _, boxes = self._split_detections(detections.numpy())
            # Ground truth of this image is the same for every predicted box, so fetch it once.
            gt = gt_box[i].numpy()

            iou = []
            for x1, y1, x2, y2 in boxes:
                # Convert corner coordinates to the [x, y, w, h] layout used by the box helpers.
                w = x2 - x1 + 1
                h = y2 - y1 + 1
                bbox = [x1, y1, w, h]
                bbox = np.expand_dims(boxtool.coco_anno_box_to_center_relative(bbox, H, W), 0)
                iou.append(max(boxtool.box_iou_xywh(bbox, gt)))

            # An image without detections contributes an IoU of 0.0 instead of raising.
            ious.append(max(iou, default=0.0))

        ious = paddle.to_tensor(np.array(ious))
        return {'loss': loss, 'metrics': {'iou': ious}}

    def predict(self, imgpath: str, filelist: str, visualization: bool = True, save_path: str = 'result'):
        '''
        Detect images

        Args:
            imgpath(str): Image path .
            filelist(str): Path to get label name.
            visualization(bool): Whether to save result image.
            save_path(str) : Path to save detected images. Currently not used by this method.

        Returns:
            boxes(np.ndarray): Predict box information.
            scores(np.ndarray): Predict score.
            labels(np.ndarray): Predict labels.
            All three are empty arrays when nothing was detected.
        '''
        boxtool = BoxTool()

        img = {'image': imgpath, 'id': 0}
        im, _, im_shape = self.transform(img, 416)
        label_names = self.get_label_infos(filelist)

        img_data = paddle.to_tensor(np.array([im]).astype('float32'))
        im_shape = paddle.to_tensor(np.array([im_shape]).astype('int32'))

        # Ground truth inputs are not needed for prediction.
        _, pred = self(img_data, None, None, None, im_shape)

        # Defaults returned when the model yields no per-image result at all.
        labels, scores, boxes = self._split_detections(np.zeros((1, 1), dtype='float32'))
        for detections in pred:
            labels, scores, boxes = self._split_detections(detections.numpy())

            # Drawing is skipped when there is nothing to draw.
            if visualization and len(boxes) > 0:
                boxtool.draw_boxes_on_image(imgpath, boxes, scores, labels, label_names, 0.5)

        return boxes, scores, labels
paddlehub/process/transforms.py
浏览文件 @
5148424a
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录