Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
0d60bf5a
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 2 年 前同步成功
通知
285
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0d60bf5a
编写于
10月 19, 2021
作者:
H
houj04
提交者:
GitHub
10月 19, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add xpu and npu support for object detection series. (#1645)
上级
600eb492
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
701 addition
and
717 deletion
+701
-717
modules/image/object_detection/faster_rcnn_resnet50_coco2017/module.py
.../object_detection/faster_rcnn_resnet50_coco2017/module.py
+124
-129
modules/image/object_detection/faster_rcnn_resnet50_coco2017/processor.py
...ject_detection/faster_rcnn_resnet50_coco2017/processor.py
+11
-27
modules/image/object_detection/ssd_mobilenet_v1_pascal/module.py
.../image/object_detection/ssd_mobilenet_v1_pascal/module.py
+107
-104
modules/image/object_detection/ssd_mobilenet_v1_pascal/processor.py
...age/object_detection/ssd_mobilenet_v1_pascal/processor.py
+11
-27
modules/image/object_detection/yolov3_darknet53_coco2017/module.py
...mage/object_detection/yolov3_darknet53_coco2017/module.py
+87
-30
modules/image/object_detection/yolov3_darknet53_coco2017/processor.py
...e/object_detection/yolov3_darknet53_coco2017/processor.py
+2
-5
modules/image/object_detection/yolov3_darknet53_pedestrian/module.py
...ge/object_detection/yolov3_darknet53_pedestrian/module.py
+111
-99
modules/image/object_detection/yolov3_darknet53_pedestrian/processor.py
...object_detection/yolov3_darknet53_pedestrian/processor.py
+9
-30
modules/image/object_detection/yolov3_mobilenet_v1_coco2017/module.py
...e/object_detection/yolov3_mobilenet_v1_coco2017/module.py
+111
-106
modules/image/object_detection/yolov3_mobilenet_v1_coco2017/processor.py
...bject_detection/yolov3_mobilenet_v1_coco2017/processor.py
+9
-29
modules/image/object_detection/yolov3_resnet50_vd_coco2017/module.py
...ge/object_detection/yolov3_resnet50_vd_coco2017/module.py
+110
-101
modules/image/object_detection/yolov3_resnet50_vd_coco2017/processor.py
...object_detection/yolov3_resnet50_vd_coco2017/processor.py
+9
-30
未找到文件。
modules/image/object_detection/faster_rcnn_resnet50_coco2017/module.py
浏览文件 @
0d60bf5a
...
...
@@ -14,7 +14,10 @@ import numpy as np
import
paddle.fluid
as
fluid
import
paddlehub
as
hub
from
paddlehub.module.module
import
moduleinfo
,
runnable
,
serving
from
paddle.fluid.core
import
PaddleTensor
,
AnalysisConfig
,
create_paddle_predictor
from
paddle.inference
import
Config
from
paddle.inference
import
create_predictor
from
paddlehub.io.parser
import
txt_parser
from
paddlehub.common.paddle_helper
import
add_vars_prefix
...
...
@@ -31,45 +34,65 @@ from faster_rcnn_resnet50_coco2017.roi_extractor import RoIAlign
name
=
"faster_rcnn_resnet50_coco2017"
,
version
=
"1.1.1"
,
type
=
"cv/object_detection"
,
summary
=
"Baidu's Faster R-CNN model for object detection with backbone ResNet50, trained with dataset COCO2017"
,
summary
=
"Baidu's Faster R-CNN model for object detection with backbone ResNet50, trained with dataset COCO2017"
,
author
=
"paddlepaddle"
,
author_email
=
"paddle-dev@baidu.com"
)
class
FasterRCNNResNet50
(
hub
.
Module
):
def
_initialize
(
self
):
# default pretrained model, Faster-RCNN with backbone ResNet50, shape of input tensor is [3, 800, 1333]
self
.
default_pretrained_model_path
=
os
.
path
.
join
(
self
.
directory
,
"faster_rcnn_resnet50_model"
)
self
.
label_names
=
load_label_info
(
os
.
path
.
join
(
self
.
directory
,
"label_file.txt"
))
self
.
default_pretrained_model_path
=
os
.
path
.
join
(
self
.
directory
,
"faster_rcnn_resnet50_model"
)
self
.
label_names
=
load_label_info
(
os
.
path
.
join
(
self
.
directory
,
"label_file.txt"
))
self
.
_set_config
()
def
_get_device_id
(
self
,
places
):
try
:
places
=
os
.
environ
[
places
]
id
=
int
(
places
)
except
:
id
=
-
1
return
id
def
_set_config
(
self
):
"""
predictor config setting
"""
cpu_config
=
AnalysisConfig
(
self
.
default_pretrained_model_path
)
# create default cpu predictor
cpu_config
=
Config
(
self
.
default_pretrained_model_path
)
cpu_config
.
disable_glog_info
()
cpu_config
.
disable_gpu
()
self
.
cpu_predictor
=
create_paddle_predictor
(
cpu_config
)
try
:
_places
=
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
int
(
_places
[
0
])
use_gpu
=
True
except
:
use_gpu
=
False
if
use_gpu
:
gpu_config
=
AnalysisConfig
(
self
.
default_pretrained_model_path
)
self
.
cpu_predictor
=
create_predictor
(
cpu_config
)
# create predictors using various types of devices
# npu
npu_id
=
self
.
_get_device_id
(
"FLAGS_selected_npus"
)
if
npu_id
!=
-
1
:
# use npu
npu_config
=
Config
(
self
.
default_pretrained_model_path
)
npu_config
.
disable_glog_info
()
npu_config
.
enable_npu
(
device_id
=
npu_id
)
self
.
npu_predictor
=
create_predictor
(
npu_config
)
# gpu
gpu_id
=
self
.
_get_device_id
(
"CUDA_VISIBLE_DEVICES"
)
if
gpu_id
!=
-
1
:
# use gpu
gpu_config
=
Config
(
self
.
default_pretrained_model_path
)
gpu_config
.
disable_glog_info
()
gpu_config
.
enable_use_gpu
(
memory_pool_init_size_mb
=
500
,
device_id
=
0
)
self
.
gpu_predictor
=
create_paddle_predictor
(
gpu_config
)
def
context
(
self
,
num_classes
=
81
,
trainable
=
True
,
pretrained
=
True
,
phase
=
'train'
):
gpu_config
.
enable_use_gpu
(
memory_pool_init_size_mb
=
500
,
device_id
=
gpu_id
)
self
.
gpu_predictor
=
create_predictor
(
gpu_config
)
# xpu
xpu_id
=
self
.
_get_device_id
(
"XPU_VISIBLE_DEVICES"
)
if
xpu_id
!=
-
1
:
# use xpu
xpu_config
=
Config
(
self
.
default_pretrained_model_path
)
xpu_config
.
disable_glog_info
()
xpu_config
.
enable_xpu
(
100
)
self
.
xpu_predictor
=
create_predictor
(
xpu_config
)
def
context
(
self
,
num_classes
=
81
,
trainable
=
True
,
pretrained
=
True
,
phase
=
'train'
):
"""
Distill the Head Features, so as to perform transfer learning.
...
...
@@ -88,34 +111,24 @@ class FasterRCNNResNet50(hub.Module):
startup_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
context_prog
,
startup_program
):
with
fluid
.
unique_name
.
guard
():
image
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
-
1
,
3
,
-
1
,
-
1
],
dtype
=
'float32'
)
image
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
-
1
,
3
,
-
1
,
-
1
],
dtype
=
'float32'
)
# backbone
backbone
=
ResNet
(
norm_type
=
'affine_channel'
,
depth
=
50
,
feature_maps
=
4
,
freeze_at
=
2
)
backbone
=
ResNet
(
norm_type
=
'affine_channel'
,
depth
=
50
,
feature_maps
=
4
,
freeze_at
=
2
)
body_feats
=
backbone
(
image
)
# var_prefix
var_prefix
=
'@HUB_{}@'
.
format
(
self
.
name
)
im_info
=
fluid
.
layers
.
data
(
name
=
'im_info'
,
shape
=
[
3
],
dtype
=
'float32'
,
lod_level
=
0
)
im_shape
=
fluid
.
layers
.
data
(
name
=
'im_shape'
,
shape
=
[
3
],
dtype
=
'float32'
,
lod_level
=
0
)
im_info
=
fluid
.
layers
.
data
(
name
=
'im_info'
,
shape
=
[
3
],
dtype
=
'float32'
,
lod_level
=
0
)
im_shape
=
fluid
.
layers
.
data
(
name
=
'im_shape'
,
shape
=
[
3
],
dtype
=
'float32'
,
lod_level
=
0
)
body_feat_names
=
list
(
body_feats
.
keys
())
# rpn_head: RPNHead
rpn_head
=
self
.
rpn_head
()
rois
=
rpn_head
.
get_proposals
(
body_feats
,
im_info
,
mode
=
phase
)
# train
if
phase
==
'train'
:
gt_bbox
=
fluid
.
layers
.
data
(
name
=
'gt_bbox'
,
shape
=
[
4
],
dtype
=
'float32'
,
lod_level
=
1
)
is_crowd
=
fluid
.
layers
.
data
(
name
=
'is_crowd'
,
shape
=
[
1
],
dtype
=
'int32'
,
lod_level
=
1
)
gt_class
=
fluid
.
layers
.
data
(
name
=
'gt_class'
,
shape
=
[
1
],
dtype
=
'int32'
,
lod_level
=
1
)
gt_bbox
=
fluid
.
layers
.
data
(
name
=
'gt_bbox'
,
shape
=
[
4
],
dtype
=
'float32'
,
lod_level
=
1
)
is_crowd
=
fluid
.
layers
.
data
(
name
=
'is_crowd'
,
shape
=
[
1
],
dtype
=
'int32'
,
lod_level
=
1
)
gt_class
=
fluid
.
layers
.
data
(
name
=
'gt_class'
,
shape
=
[
1
],
dtype
=
'int32'
,
lod_level
=
1
)
rpn_loss
=
rpn_head
.
get_loss
(
im_info
,
gt_bbox
,
is_crowd
)
# bbox_assigner: BBoxAssigner
bbox_assigner
=
self
.
bbox_assigner
(
num_classes
)
...
...
@@ -160,18 +173,13 @@ class FasterRCNNResNet50(hub.Module):
'is_crowd'
:
var_prefix
+
is_crowd
.
name
}
outputs
=
{
'head_features'
:
var_prefix
+
head_feat
.
name
,
'rpn_cls_loss'
:
var_prefix
+
rpn_loss
[
'rpn_cls_loss'
].
name
,
'rpn_reg_loss'
:
var_prefix
+
rpn_loss
[
'rpn_reg_loss'
].
name
,
'generate_proposal_labels'
:
[
var_prefix
+
var
.
name
for
var
in
outs
]
'head_features'
:
var_prefix
+
head_feat
.
name
,
'rpn_cls_loss'
:
var_prefix
+
rpn_loss
[
'rpn_cls_loss'
].
name
,
'rpn_reg_loss'
:
var_prefix
+
rpn_loss
[
'rpn_reg_loss'
].
name
,
'generate_proposal_labels'
:
[
var_prefix
+
var
.
name
for
var
in
outs
]
}
elif
phase
==
'predict'
:
pred
=
bbox_head
.
get_prediction
(
roi_feat
,
rois
,
im_info
,
im_shape
)
pred
=
bbox_head
.
get_prediction
(
roi_feat
,
rois
,
im_info
,
im_shape
)
inputs
=
{
'image'
:
var_prefix
+
image
.
name
,
'im_info'
:
var_prefix
+
im_info
.
name
,
...
...
@@ -186,13 +194,9 @@ class FasterRCNNResNet50(hub.Module):
add_vars_prefix
(
startup_program
,
var_prefix
)
global_vars
=
context_prog
.
global_block
().
vars
inputs
=
{
key
:
global_vars
[
value
]
for
key
,
value
in
inputs
.
items
()
}
inputs
=
{
key
:
global_vars
[
value
]
for
key
,
value
in
inputs
.
items
()}
outputs
=
{
key
:
global_vars
[
value
]
if
not
isinstance
(
value
,
list
)
else
[
global_vars
[
var
]
for
var
in
value
]
key
:
global_vars
[
value
]
if
not
isinstance
(
value
,
list
)
else
[
global_vars
[
var
]
for
var
in
value
]
for
key
,
value
in
outputs
.
items
()
}
...
...
@@ -208,14 +212,9 @@ class FasterRCNNResNet50(hub.Module):
if
num_classes
!=
81
:
if
'bbox_pred'
in
var
.
name
or
'cls_score'
in
var
.
name
:
return
False
return
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
default_pretrained_model_path
,
var
.
name
))
fluid
.
io
.
load_vars
(
exe
,
self
.
default_pretrained_model_path
,
predicate
=
_if_exist
)
return
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
default_pretrained_model_path
,
var
.
name
))
fluid
.
io
.
load_vars
(
exe
,
self
.
default_pretrained_model_path
,
predicate
=
_if_exist
)
return
inputs
,
outputs
,
context_prog
def
rpn_head
(
self
):
...
...
@@ -231,16 +230,8 @@ class FasterRCNNResNet50(hub.Module):
rpn_negative_overlap
=
0.3
,
rpn_positive_overlap
=
0.7
,
rpn_straddle_thresh
=
0.0
),
train_proposal
=
GenerateProposals
(
min_size
=
0.0
,
nms_thresh
=
0.7
,
post_nms_top_n
=
12000
,
pre_nms_top_n
=
2000
),
test_proposal
=
GenerateProposals
(
min_size
=
0.0
,
nms_thresh
=
0.7
,
post_nms_top_n
=
6000
,
pre_nms_top_n
=
1000
))
train_proposal
=
GenerateProposals
(
min_size
=
0.0
,
nms_thresh
=
0.7
,
post_nms_top_n
=
12000
,
pre_nms_top_n
=
2000
),
test_proposal
=
GenerateProposals
(
min_size
=
0.0
,
nms_thresh
=
0.7
,
post_nms_top_n
=
6000
,
pre_nms_top_n
=
1000
))
def
roi_extractor
(
self
):
return
RoIAlign
(
resolution
=
14
,
sampling_ratio
=
0
,
spatial_scale
=
0.0625
)
...
...
@@ -248,8 +239,7 @@ class FasterRCNNResNet50(hub.Module):
def
bbox_head
(
self
,
num_classes
):
return
BBoxHead
(
head
=
ResNetC5
(
depth
=
50
,
norm_type
=
'affine_channel'
),
nms
=
MultiClassNMS
(
keep_top_k
=
100
,
nms_threshold
=
0.5
,
score_threshold
=
0.05
),
nms
=
MultiClassNMS
(
keep_top_k
=
100
,
nms_threshold
=
0.5
,
score_threshold
=
0.05
),
bbox_loss
=
SmoothL1Loss
(),
num_classes
=
num_classes
)
...
...
@@ -263,11 +253,7 @@ class FasterRCNNResNet50(hub.Module):
fg_thresh
=
0.5
,
class_nums
=
num_classes
)
def
save_inference_model
(
self
,
dirname
,
model_filename
=
None
,
params_filename
=
None
,
combined
=
True
):
def
save_inference_model
(
self
,
dirname
,
model_filename
=
None
,
params_filename
=
None
,
combined
=
True
):
if
combined
:
model_filename
=
"__model__"
if
not
model_filename
else
model_filename
params_filename
=
"__params__"
if
not
params_filename
else
params_filename
...
...
@@ -294,7 +280,8 @@ class FasterRCNNResNet50(hub.Module):
batch_size
=
1
,
output_dir
=
'detection_result'
,
score_thresh
=
0.5
,
visualization
=
True
):
visualization
=
True
,
use_device
=
None
):
"""API of Object Detection.
Args:
...
...
@@ -305,6 +292,7 @@ class FasterRCNNResNet50(hub.Module):
output_dir (str): The path to store output images.
visualization (bool): Whether to save image or not.
score_thresh (float): threshold for object detecion.
use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
Returns:
res (list[dict]): The result of coco2017 detecion. keys include 'data', 'save_path', the corresponding value is:
...
...
@@ -317,14 +305,25 @@ class FasterRCNNResNet50(hub.Module):
confidence (float): The confidence of detection result.
save_path (str, optional): The path to save output images.
"""
if
use_gpu
:
try
:
_places
=
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
int
(
_places
[
0
])
except
:
raise
RuntimeError
(
"Attempt to use GPU for prediction, but environment variable CUDA_VISIBLE_DEVICES was not set correctly."
)
# real predictor to use
if
use_device
is
not
None
:
if
use_device
==
"cpu"
:
predictor
=
self
.
cpu_predictor
elif
use_device
==
"xpu"
:
predictor
=
self
.
xpu_predictor
elif
use_device
==
"npu"
:
predictor
=
self
.
npu_predictor
elif
use_device
==
"gpu"
:
predictor
=
self
.
gpu_predictor
else
:
raise
Exception
(
"Unsupported device: "
+
use_device
)
else
:
# use_device is not set, therefore follow use_gpu
if
use_gpu
:
predictor
=
self
.
gpu_predictor
else
:
predictor
=
self
.
cpu_predictor
paths
=
paths
if
paths
else
list
()
if
data
and
'image'
in
data
:
paths
+=
data
[
'image'
]
...
...
@@ -345,22 +344,30 @@ class FasterRCNNResNet50(hub.Module):
except
:
pass
padding_image
,
padding_info
,
padding_shape
=
padding_minibatch
(
batch_data
)
padding_image_tensor
=
PaddleTensor
(
padding_image
.
copy
())
padding_info_tensor
=
PaddleTensor
(
padding_info
.
copy
())
padding_shape_tensor
=
PaddleTensor
(
padding_shape
.
copy
())
feed_list
=
[
padding_image_tensor
,
padding_info_tensor
,
padding_shape_tensor
]
if
use_gpu
:
data_out
=
self
.
gpu_predictor
.
run
(
feed_list
)
else
:
data_out
=
self
.
cpu_predictor
.
run
(
feed_list
)
padding_image
,
padding_info
,
padding_shape
=
padding_minibatch
(
batch_data
)
input_names
=
predictor
.
get_input_names
()
padding_image_tensor
=
predictor
.
get_input_handle
(
input_names
[
0
])
padding_image_tensor
.
reshape
(
padding_image
.
shape
)
padding_image_tensor
.
copy_from_cpu
(
padding_image
.
copy
())
padding_info_tensor
=
predictor
.
get_input_handle
(
input_names
[
1
])
padding_info_tensor
.
reshape
(
padding_info
.
shape
)
padding_info_tensor
.
copy_from_cpu
(
padding_info
.
copy
())
padding_shape_tensor
=
predictor
.
get_input_handle
(
input_names
[
2
])
padding_shape_tensor
.
reshape
(
padding_shape
.
shape
)
padding_shape_tensor
.
copy_from_cpu
(
padding_shape
.
copy
())
predictor
.
run
()
output_names
=
predictor
.
get_output_names
()
output_handle
=
predictor
.
get_output_handle
(
output_names
[
0
])
output
=
postprocess
(
paths
=
paths
,
images
=
images
,
data_out
=
data_out
,
data_out
=
output_handle
,
score_thresh
=
score_thresh
,
label_names
=
self
.
label_names
,
output_dir
=
output_dir
,
...
...
@@ -374,29 +381,21 @@ class FasterRCNNResNet50(hub.Module):
Add the command config options
"""
self
.
arg_config_group
.
add_argument
(
'--use_gpu'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether use GPU or not"
)
'--use_gpu'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether use GPU or not"
)
self
.
arg_config_group
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
1
,
help
=
"batch size for prediction"
)
self
.
arg_config_group
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
1
,
help
=
"batch size for prediction"
)
'--use_device'
,
choices
=
[
"cpu"
,
"gpu"
,
"xpu"
,
"npu"
],
help
=
"use cpu, gpu, xpu or npu. overwrites use_gpu flag."
)
def
add_module_input_arg
(
self
):
"""
Add the command input options
"""
self
.
arg_input_group
.
add_argument
(
'--input_path'
,
type
=
str
,
default
=
None
,
help
=
"input data"
)
self
.
arg_input_group
.
add_argument
(
'--input_path'
,
type
=
str
,
default
=
None
,
help
=
"input data"
)
self
.
arg_input_group
.
add_argument
(
'--input_file'
,
type
=
str
,
default
=
None
,
help
=
"file contain input data"
)
self
.
arg_input_group
.
add_argument
(
'--input_file'
,
type
=
str
,
default
=
None
,
help
=
"file contain input data"
)
def
check_input_data
(
self
,
args
):
input_data
=
[]
...
...
@@ -425,12 +424,9 @@ class FasterRCNNResNet50(hub.Module):
prog
=
"hub run {}"
.
format
(
self
.
name
),
usage
=
'%(prog)s'
,
add_help
=
True
)
self
.
arg_input_group
=
self
.
parser
.
add_argument_group
(
title
=
"Input options"
,
description
=
"Input data. Required"
)
self
.
arg_input_group
=
self
.
parser
.
add_argument_group
(
title
=
"Input options"
,
description
=
"Input data. Required"
)
self
.
arg_config_group
=
self
.
parser
.
add_argument_group
(
title
=
"Config options"
,
description
=
"Run configuration for controlling module behavior, not required."
)
title
=
"Config options"
,
description
=
"Run configuration for controlling module behavior, not required."
)
self
.
add_module_config_arg
()
self
.
add_module_input_arg
()
...
...
@@ -442,7 +438,6 @@ class FasterRCNNResNet50(hub.Module):
else
:
for
image_path
in
input_data
:
if
not
os
.
path
.
exists
(
image_path
):
raise
RuntimeError
(
"File %s or %s is not exist."
%
image_path
)
raise
RuntimeError
(
"File %s or %s is not exist."
%
image_path
)
return
self
.
object_detection
(
paths
=
input_data
,
use_gpu
=
args
.
use_gpu
,
batch_size
=
args
.
batch_size
)
paths
=
input_data
,
use_gpu
=
args
.
use_gpu
,
batch_size
=
args
.
batch_size
,
use_device
=
args
.
use_device
)
modules/image/object_detection/faster_rcnn_resnet50_coco2017/processor.py
浏览文件 @
0d60bf5a
...
...
@@ -19,6 +19,7 @@ def base64_to_cv2(b64str):
data
=
cv2
.
imdecode
(
data
,
cv2
.
IMREAD_COLOR
)
return
data
def
check_dir
(
dir_path
):
if
not
os
.
path
.
exists
(
dir_path
):
os
.
makedirs
(
dir_path
)
...
...
@@ -26,6 +27,7 @@ def check_dir(dir_path):
os
.
remove
(
dir_path
)
os
.
makedirs
(
dir_path
)
def
get_save_image_name
(
img
,
output_dir
,
image_path
):
"""Get save image name from source image path.
"""
...
...
@@ -54,23 +56,17 @@ def draw_bounding_box_on_image(image_path, data_list, save_dir):
image
=
Image
.
open
(
image_path
)
draw
=
ImageDraw
.
Draw
(
image
)
for
data
in
data_list
:
left
,
right
,
top
,
bottom
=
data
[
'left'
],
data
[
'right'
],
data
[
'top'
],
data
[
'bottom'
]
left
,
right
,
top
,
bottom
=
data
[
'left'
],
data
[
'right'
],
data
[
'top'
],
data
[
'bottom'
]
# draw bbox
draw
.
line
([(
left
,
top
),
(
left
,
bottom
),
(
right
,
bottom
),
(
right
,
top
),
(
left
,
top
)],
width
=
2
,
fill
=
'red'
)
draw
.
line
([(
left
,
top
),
(
left
,
bottom
),
(
right
,
bottom
),
(
right
,
top
),
(
left
,
top
)],
width
=
2
,
fill
=
'red'
)
# draw label
if
image
.
mode
==
'RGB'
:
text
=
data
[
'label'
]
+
": %.2f%%"
%
(
100
*
data
[
'confidence'
])
textsize_width
,
textsize_height
=
draw
.
textsize
(
text
=
text
)
draw
.
rectangle
(
xy
=
(
left
,
top
-
(
textsize_height
+
5
),
left
+
textsize_width
+
10
,
top
),
fill
=
(
255
,
255
,
255
))
xy
=
(
left
,
top
-
(
textsize_height
+
5
),
left
+
textsize_width
+
10
,
top
),
fill
=
(
255
,
255
,
255
))
draw
.
text
(
xy
=
(
left
,
top
-
15
),
text
=
text
,
fill
=
(
0
,
0
,
0
))
save_name
=
get_save_image_name
(
image
,
save_dir
,
image_path
)
...
...
@@ -98,14 +94,7 @@ def load_label_info(file_path):
return
label_names
def
postprocess
(
paths
,
images
,
data_out
,
score_thresh
,
label_names
,
output_dir
,
handle_id
,
visualization
=
True
):
def
postprocess
(
paths
,
images
,
data_out
,
score_thresh
,
label_names
,
output_dir
,
handle_id
,
visualization
=
True
):
"""
postprocess the lod_tensor produced by fluid.Executor.run
...
...
@@ -130,9 +119,8 @@ def postprocess(paths,
confidence (float): The confidence of detection result.
save_path (str): The path to save output images.
"""
lod_tensor
=
data_out
[
0
]
lod
=
lod_tensor
.
lod
[
0
]
results
=
lod_tensor
.
as_ndarray
()
results
=
data_out
.
copy_to_cpu
()
lod
=
data_out
.
lod
()[
0
]
check_dir
(
output_dir
)
...
...
@@ -162,9 +150,7 @@ def postprocess(paths,
org_img
=
org_img
.
astype
(
np
.
uint8
)
org_img
=
Image
.
fromarray
(
org_img
[:,
:,
::
-
1
])
if
visualization
:
org_img_path
=
get_save_image_name
(
org_img
,
output_dir
,
'image_numpy_{}'
.
format
(
(
handle_id
+
index
)))
org_img_path
=
get_save_image_name
(
org_img
,
output_dir
,
'image_numpy_{}'
.
format
((
handle_id
+
index
)))
org_img
.
save
(
org_img_path
)
org_img_height
=
org_img
.
height
org_img_width
=
org_img
.
width
...
...
@@ -180,13 +166,11 @@ def postprocess(paths,
dt
=
{}
dt
[
'label'
]
=
label_names
[
category_id
]
dt
[
'confidence'
]
=
float
(
confidence
)
dt
[
'left'
],
dt
[
'top'
],
dt
[
'right'
],
dt
[
'bottom'
]
=
clip_bbox
(
bbox
,
org_img_width
,
org_img_height
)
dt
[
'left'
],
dt
[
'top'
],
dt
[
'right'
],
dt
[
'bottom'
]
=
clip_bbox
(
bbox
,
org_img_width
,
org_img_height
)
output_i
[
'data'
].
append
(
dt
)
output
.
append
(
output_i
)
if
visualization
:
output_i
[
'save_path'
]
=
draw_bounding_box_on_image
(
org_img_path
,
output_i
[
'data'
],
output_dir
)
output_i
[
'save_path'
]
=
draw_bounding_box_on_image
(
org_img_path
,
output_i
[
'data'
],
output_dir
)
return
output
modules/image/object_detection/ssd_mobilenet_v1_pascal/module.py
浏览文件 @
0d60bf5a
...
...
@@ -10,7 +10,10 @@ import yaml
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddlehub
as
hub
from
paddle.fluid.core
import
PaddleTensor
,
AnalysisConfig
,
create_paddle_predictor
from
paddle.inference
import
Config
from
paddle.inference
import
create_predictor
from
paddlehub.module.module
import
moduleinfo
,
runnable
,
serving
from
paddlehub.common.paddle_helper
import
add_vars_prefix
...
...
@@ -28,32 +31,59 @@ from ssd_mobilenet_v1_pascal.data_feed import reader
author_email
=
"paddle-dev@baidu.com"
)
class
SSDMobileNetv1
(
hub
.
Module
):
def
_initialize
(
self
):
self
.
default_pretrained_model_path
=
os
.
path
.
join
(
self
.
directory
,
"ssd_mobilenet_v1_model"
)
self
.
label_names
=
load_label_info
(
os
.
path
.
join
(
self
.
directory
,
"label_file.txt"
))
self
.
default_pretrained_model_path
=
os
.
path
.
join
(
self
.
directory
,
"ssd_mobilenet_v1_model"
)
self
.
label_names
=
load_label_info
(
os
.
path
.
join
(
self
.
directory
,
"label_file.txt"
))
self
.
model_config
=
None
self
.
_set_config
()
def
_get_device_id
(
self
,
places
):
try
:
places
=
os
.
environ
[
places
]
id
=
int
(
places
)
except
:
id
=
-
1
return
id
def
_set_config
(
self
):
# predictor config setting.
cpu_config
=
AnalysisConfig
(
self
.
default_pretrained_model_path
)
"""
predictor config setting.
"""
# create default cpu predictor
cpu_config
=
Config
(
self
.
default_pretrained_model_path
)
cpu_config
.
disable_glog_info
()
cpu_config
.
disable_gpu
()
cpu_config
.
switch_ir_optim
(
False
)
self
.
cpu_predictor
=
create_p
addle_p
redictor
(
cpu_config
)
self
.
cpu_predictor
=
create_predictor
(
cpu_config
)
try
:
_places
=
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
int
(
_places
[
0
])
use_gpu
=
True
except
:
use_gpu
=
False
if
use_gpu
:
gpu_config
=
AnalysisConfig
(
self
.
default_pretrained_model_path
)
# create predictors using various types of devices
# npu
npu_id
=
self
.
_get_device_id
(
"FLAGS_selected_npus"
)
if
npu_id
!=
-
1
:
# use npu
npu_config
=
Config
(
self
.
default_pretrained_model_path
)
npu_config
.
disable_glog_info
()
npu_config
.
enable_npu
(
device_id
=
npu_id
)
self
.
npu_predictor
=
create_predictor
(
npu_config
)
# gpu
gpu_id
=
self
.
_get_device_id
(
"CUDA_VISIBLE_DEVICES"
)
if
gpu_id
!=
-
1
:
# use gpu
gpu_config
=
Config
(
self
.
default_pretrained_model_path
)
gpu_config
.
disable_glog_info
()
gpu_config
.
enable_use_gpu
(
memory_pool_init_size_mb
=
500
,
device_id
=
0
)
self
.
gpu_predictor
=
create_paddle_predictor
(
gpu_config
)
gpu_config
.
enable_use_gpu
(
memory_pool_init_size_mb
=
500
,
device_id
=
gpu_id
)
self
.
gpu_predictor
=
create_predictor
(
gpu_config
)
# xpu
xpu_id
=
self
.
_get_device_id
(
"XPU_VISIBLE_DEVICES"
)
if
xpu_id
!=
-
1
:
# use xpu
xpu_config
=
Config
(
self
.
default_pretrained_model_path
)
xpu_config
.
disable_glog_info
()
xpu_config
.
enable_xpu
(
100
)
self
.
xpu_predictor
=
create_predictor
(
xpu_config
)
# model config setting.
if
not
self
.
model_config
:
...
...
@@ -83,55 +113,34 @@ class SSDMobileNetv1(hub.Module):
with
fluid
.
program_guard
(
context_prog
,
startup_program
):
with
fluid
.
unique_name
.
guard
():
# image
image
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
3
,
300
,
300
],
dtype
=
'float32'
)
image
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
3
,
300
,
300
],
dtype
=
'float32'
)
# backbone
backbone
=
MobileNet
(
**
self
.
mobilenet_config
)
# body_feats
body_feats
=
backbone
(
image
)
# im_size
im_size
=
fluid
.
layers
.
data
(
name
=
'im_size'
,
shape
=
[
2
],
dtype
=
'int32'
)
im_size
=
fluid
.
layers
.
data
(
name
=
'im_size'
,
shape
=
[
2
],
dtype
=
'int32'
)
# var_prefix
var_prefix
=
'@HUB_{}@'
.
format
(
self
.
name
)
# names of inputs
inputs
=
{
'image'
:
var_prefix
+
image
.
name
,
'im_size'
:
var_prefix
+
im_size
.
name
}
inputs
=
{
'image'
:
var_prefix
+
image
.
name
,
'im_size'
:
var_prefix
+
im_size
.
name
}
# names of outputs
if
get_prediction
:
locs
,
confs
,
box
,
box_var
=
fluid
.
layers
.
multi_box_head
(
inputs
=
body_feats
,
image
=
image
,
num_classes
=
21
,
**
self
.
multi_box_head_config
)
inputs
=
body_feats
,
image
=
image
,
num_classes
=
21
,
**
self
.
multi_box_head_config
)
pred
=
fluid
.
layers
.
detection_output
(
loc
=
locs
,
scores
=
confs
,
prior_box
=
box
,
prior_box_var
=
box_var
,
**
self
.
output_decoder_config
)
loc
=
locs
,
scores
=
confs
,
prior_box
=
box
,
prior_box_var
=
box_var
,
**
self
.
output_decoder_config
)
outputs
=
{
'bbox_out'
:
[
var_prefix
+
pred
.
name
]}
else
:
outputs
=
{
'body_features'
:
[
var_prefix
+
var
.
name
for
var
in
body_feats
]
}
outputs
=
{
'body_features'
:
[
var_prefix
+
var
.
name
for
var
in
body_feats
]}
# add_vars_prefix
add_vars_prefix
(
context_prog
,
var_prefix
)
add_vars_prefix
(
fluid
.
default_startup_program
(),
var_prefix
)
# inputs
inputs
=
{
key
:
context_prog
.
global_block
().
vars
[
value
]
for
key
,
value
in
inputs
.
items
()
}
inputs
=
{
key
:
context_prog
.
global_block
().
vars
[
value
]
for
key
,
value
in
inputs
.
items
()}
outputs
=
{
out_key
:
[
context_prog
.
global_block
().
vars
[
varname
]
for
varname
in
out_value
]
out_key
:
[
context_prog
.
global_block
().
vars
[
varname
]
for
varname
in
out_value
]
for
out_key
,
out_value
in
outputs
.
items
()
}
# trainable
...
...
@@ -144,14 +153,9 @@ class SSDMobileNetv1(hub.Module):
if
pretrained
:
def
_if_exist
(
var
):
return
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
default_pretrained_model_path
,
var
.
name
))
fluid
.
io
.
load_vars
(
exe
,
self
.
default_pretrained_model_path
,
predicate
=
_if_exist
)
return
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
default_pretrained_model_path
,
var
.
name
))
fluid
.
io
.
load_vars
(
exe
,
self
.
default_pretrained_model_path
,
predicate
=
_if_exist
)
else
:
exe
.
run
(
startup_program
)
...
...
@@ -165,7 +169,8 @@ class SSDMobileNetv1(hub.Module):
use_gpu
=
False
,
output_dir
=
'detection_result'
,
score_thresh
=
0.5
,
visualization
=
True
):
visualization
=
True
,
use_device
=
None
):
"""API of Object Detection.
Args:
...
...
@@ -176,6 +181,7 @@ class SSDMobileNetv1(hub.Module):
output_dir (str): The path to store output images.
visualization (bool): Whether to save image or not.
score_thresh (float): threshold for object detecion.
use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
Returns:
res (list[dict]): The result of coco2017 detecion. keys include 'data', 'save_path', the corresponding value is:
...
...
@@ -188,14 +194,24 @@ class SSDMobileNetv1(hub.Module):
confidence (float): The confidence of detection result.
save_path (str, optional): The path to save output images.
"""
if
use_gpu
:
try
:
_places
=
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
int
(
_places
[
0
])
except
:
raise
RuntimeError
(
"Attempt to use GPU for prediction, but environment variable CUDA_VISIBLE_DEVICES was not set correctly."
)
# real predictor to use
if
use_device
is
not
None
:
if
use_device
==
"cpu"
:
predictor
=
self
.
cpu_predictor
elif
use_device
==
"xpu"
:
predictor
=
self
.
xpu_predictor
elif
use_device
==
"npu"
:
predictor
=
self
.
npu_predictor
elif
use_device
==
"gpu"
:
predictor
=
self
.
gpu_predictor
else
:
raise
Exception
(
"Unsupported device: "
+
use_device
)
else
:
# use_device is not set, therefore follow use_gpu
if
use_gpu
:
predictor
=
self
.
gpu_predictor
else
:
predictor
=
self
.
cpu_predictor
paths
=
paths
if
paths
else
list
()
if
data
and
'image'
in
data
:
...
...
@@ -206,16 +222,22 @@ class SSDMobileNetv1(hub.Module):
res
=
[]
for
iter_id
,
feed_data
in
enumerate
(
batch_reader
()):
feed_data
=
np
.
array
(
feed_data
)
image_tensor
=
PaddleTensor
(
np
.
array
(
list
(
feed_data
[:,
0
])).
copy
())
if
use_gpu
:
data_out
=
self
.
gpu_predictor
.
run
([
image_tensor
])
else
:
data_out
=
self
.
cpu_predictor
.
run
([
image_tensor
])
input_names
=
predictor
.
get_input_names
()
image_data
=
np
.
array
(
list
(
feed_data
[:,
0
]))
image_tensor
=
predictor
.
get_input_handle
(
input_names
[
0
])
image_tensor
.
reshape
(
image_data
.
shape
)
image_tensor
.
copy_from_cpu
(
image_data
.
copy
())
predictor
.
run
()
output_names
=
predictor
.
get_output_names
()
output_handle
=
predictor
.
get_output_handle
(
output_names
[
0
])
output
=
postprocess
(
paths
=
paths
,
images
=
images
,
data_out
=
data_out
,
data_out
=
output_handle
,
score_thresh
=
score_thresh
,
label_names
=
self
.
label_names
,
output_dir
=
output_dir
,
...
...
@@ -224,11 +246,7 @@ class SSDMobileNetv1(hub.Module):
res
.
extend
(
output
)
return
res
def
save_inference_model
(
self
,
dirname
,
model_filename
=
None
,
params_filename
=
None
,
combined
=
True
):
def
save_inference_model
(
self
,
dirname
,
model_filename
=
None
,
params_filename
=
None
,
combined
=
True
):
if
combined
:
model_filename
=
"__model__"
if
not
model_filename
else
model_filename
params_filename
=
"__params__"
if
not
params_filename
else
params_filename
...
...
@@ -266,12 +284,9 @@ class SSDMobileNetv1(hub.Module):
prog
=
'hub run {}'
.
format
(
self
.
name
),
usage
=
'%(prog)s'
,
add_help
=
True
)
self
.
arg_input_group
=
self
.
parser
.
add_argument_group
(
title
=
"Input options"
,
description
=
"Input data. Required"
)
self
.
arg_input_group
=
self
.
parser
.
add_argument_group
(
title
=
"Input options"
,
description
=
"Input data. Required"
)
self
.
arg_config_group
=
self
.
parser
.
add_argument_group
(
title
=
"Config options"
,
description
=
"Run configuration for controlling module behavior, not required."
)
title
=
"Config options"
,
description
=
"Run configuration for controlling module behavior, not required."
)
self
.
add_module_config_arg
()
self
.
add_module_input_arg
()
args
=
self
.
parser
.
parse_args
(
argvs
)
...
...
@@ -281,7 +296,8 @@ class SSDMobileNetv1(hub.Module):
use_gpu
=
args
.
use_gpu
,
output_dir
=
args
.
output_dir
,
visualization
=
args
.
visualization
,
score_thresh
=
args
.
score_thresh
)
score_thresh
=
args
.
score_thresh
,
use_device
=
args
.
use_device
)
return
results
def
add_module_config_arg
(
self
):
...
...
@@ -289,34 +305,21 @@ class SSDMobileNetv1(hub.Module):
Add the command config options.
"""
self
.
arg_config_group
.
add_argument
(
'--use_gpu'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether use GPU or not"
)
'--use_gpu'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether use GPU or not"
)
self
.
arg_config_group
.
add_argument
(
'--output_dir'
,
type
=
str
,
default
=
'detection_result'
,
help
=
"The directory to save output images."
)
'--output_dir'
,
type
=
str
,
default
=
'detection_result'
,
help
=
"The directory to save output images."
)
self
.
arg_config_group
.
add_argument
(
'--visualization'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether to save output as images."
)
'--visualization'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether to save output as images."
)
self
.
arg_config_group
.
add_argument
(
'--use_device'
,
choices
=
[
"cpu"
,
"gpu"
,
"xpu"
,
"npu"
],
help
=
"use cpu, gpu, xpu or npu. overwrites use_gpu flag."
)
def
add_module_input_arg
(
self
):
"""
Add the command input options.
"""
self
.
arg_input_group
.
add_argument
(
'--input_path'
,
type
=
str
,
help
=
"path to image."
)
self
.
arg_input_group
.
add_argument
(
'--batch_size'
,
type
=
ast
.
literal_eval
,
default
=
1
,
help
=
"batch size."
)
self
.
arg_input_group
.
add_argument
(
'--input_path'
,
type
=
str
,
help
=
"path to image."
)
self
.
arg_input_group
.
add_argument
(
'--batch_size'
,
type
=
ast
.
literal_eval
,
default
=
1
,
help
=
"batch size."
)
self
.
arg_input_group
.
add_argument
(
'--score_thresh'
,
type
=
ast
.
literal_eval
,
default
=
0.5
,
help
=
"threshold for object detecion."
)
'--score_thresh'
,
type
=
ast
.
literal_eval
,
default
=
0.5
,
help
=
"threshold for object detecion."
)
modules/image/object_detection/ssd_mobilenet_v1_pascal/processor.py
浏览文件 @
0d60bf5a
...
...
@@ -15,6 +15,7 @@ def base64_to_cv2(b64str):
data
=
cv2
.
imdecode
(
data
,
cv2
.
IMREAD_COLOR
)
return
data
def
check_dir
(
dir_path
):
if
not
os
.
path
.
exists
(
dir_path
):
os
.
makedirs
(
dir_path
)
...
...
@@ -22,6 +23,7 @@ def check_dir(dir_path):
os
.
remove
(
dir_path
)
os
.
makedirs
(
dir_path
)
def
get_save_image_name
(
img
,
output_dir
,
image_path
):
"""
Get save image name from source image path.
...
...
@@ -50,23 +52,17 @@ def draw_bounding_box_on_image(image_path, data_list, save_dir):
image
=
Image
.
open
(
image_path
)
draw
=
ImageDraw
.
Draw
(
image
)
for
data
in
data_list
:
left
,
right
,
top
,
bottom
=
data
[
'left'
],
data
[
'right'
],
data
[
'top'
],
data
[
'bottom'
]
left
,
right
,
top
,
bottom
=
data
[
'left'
],
data
[
'right'
],
data
[
'top'
],
data
[
'bottom'
]
# draw bbox
draw
.
line
([(
left
,
top
),
(
left
,
bottom
),
(
right
,
bottom
),
(
right
,
top
),
(
left
,
top
)],
width
=
2
,
fill
=
'red'
)
draw
.
line
([(
left
,
top
),
(
left
,
bottom
),
(
right
,
bottom
),
(
right
,
top
),
(
left
,
top
)],
width
=
2
,
fill
=
'red'
)
# draw label
if
image
.
mode
==
'RGB'
:
text
=
data
[
'label'
]
+
": %.2f%%"
%
(
100
*
data
[
'confidence'
])
textsize_width
,
textsize_height
=
draw
.
textsize
(
text
=
text
)
draw
.
rectangle
(
xy
=
(
left
,
top
-
(
textsize_height
+
5
),
left
+
textsize_width
+
10
,
top
),
fill
=
(
255
,
255
,
255
))
xy
=
(
left
,
top
-
(
textsize_height
+
5
),
left
+
textsize_width
+
10
,
top
),
fill
=
(
255
,
255
,
255
))
draw
.
text
(
xy
=
(
left
,
top
-
15
),
text
=
text
,
fill
=
(
0
,
0
,
0
))
save_name
=
get_save_image_name
(
image
,
save_dir
,
image_path
)
...
...
@@ -95,14 +91,7 @@ def load_label_info(file_path):
return
label_names
def
postprocess
(
paths
,
images
,
data_out
,
score_thresh
,
label_names
,
output_dir
,
handle_id
,
visualization
=
True
):
def
postprocess
(
paths
,
images
,
data_out
,
score_thresh
,
label_names
,
output_dir
,
handle_id
,
visualization
=
True
):
"""
postprocess the lod_tensor produced by fluid.Executor.run
...
...
@@ -127,9 +116,8 @@ def postprocess(paths,
confidence (float): The confidence of detection result.
save_path (str): The path to save output images.
"""
lod_tensor
=
data_out
[
0
]
lod
=
lod_tensor
.
lod
[
0
]
results
=
lod_tensor
.
as_ndarray
()
results
=
data_out
.
copy_to_cpu
()
lod
=
data_out
.
lod
()[
0
]
check_dir
(
output_dir
)
...
...
@@ -159,9 +147,7 @@ def postprocess(paths,
org_img
=
org_img
.
astype
(
np
.
uint8
)
org_img
=
Image
.
fromarray
(
org_img
[:,
:,
::
-
1
])
if
visualization
:
org_img_path
=
get_save_image_name
(
org_img
,
output_dir
,
'image_numpy_{}'
.
format
(
(
handle_id
+
index
)))
org_img_path
=
get_save_image_name
(
org_img
,
output_dir
,
'image_numpy_{}'
.
format
((
handle_id
+
index
)))
org_img
.
save
(
org_img_path
)
org_img_height
=
org_img
.
height
org_img_width
=
org_img
.
width
...
...
@@ -181,13 +167,11 @@ def postprocess(paths,
dt
=
{}
dt
[
'label'
]
=
label_names
[
category_id
]
dt
[
'confidence'
]
=
float
(
confidence
)
dt
[
'left'
],
dt
[
'top'
],
dt
[
'right'
],
dt
[
'bottom'
]
=
clip_bbox
(
bbox
,
org_img_width
,
org_img_height
)
dt
[
'left'
],
dt
[
'top'
],
dt
[
'right'
],
dt
[
'bottom'
]
=
clip_bbox
(
bbox
,
org_img_width
,
org_img_height
)
output_i
[
'data'
].
append
(
dt
)
output
.
append
(
output_i
)
if
visualization
:
output_i
[
'save_path'
]
=
draw_bounding_box_on_image
(
org_img_path
,
output_i
[
'data'
],
output_dir
)
output_i
[
'save_path'
]
=
draw_bounding_box_on_image
(
org_img_path
,
output_i
[
'data'
],
output_dir
)
return
output
modules/image/object_detection/yolov3_darknet53_coco2017/module.py
浏览文件 @
0d60bf5a
...
...
@@ -9,7 +9,10 @@ from functools import partial
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddlehub
as
hub
from
paddle.fluid.core
import
PaddleTensor
,
AnalysisConfig
,
create_paddle_predictor
from
paddle.inference
import
Config
from
paddle.inference
import
create_predictor
from
paddlehub.module.module
import
moduleinfo
,
runnable
,
serving
from
paddlehub.common.paddle_helper
import
add_vars_prefix
...
...
@@ -32,27 +35,54 @@ class YOLOv3DarkNet53Coco2017(hub.Module):
self
.
label_names
=
load_label_info
(
os
.
path
.
join
(
self
.
directory
,
"label_file.txt"
))
self
.
_set_config
()
def
_get_device_id
(
self
,
places
):
try
:
places
=
os
.
environ
[
places
]
id
=
int
(
places
)
except
:
id
=
-
1
return
id
def
_set_config
(
self
):
"""
predictor config setting.
"""
cpu_config
=
AnalysisConfig
(
self
.
default_pretrained_model_path
)
# create default cpu predictor
cpu_config
=
Config
(
self
.
default_pretrained_model_path
)
cpu_config
.
disable_glog_info
()
cpu_config
.
disable_gpu
()
cpu_config
.
switch_ir_optim
(
False
)
self
.
cpu_predictor
=
create_p
addle_p
redictor
(
cpu_config
)
self
.
cpu_predictor
=
create_predictor
(
cpu_config
)
try
:
_places
=
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
int
(
_places
[
0
])
use_gpu
=
True
except
:
use_gpu
=
False
if
use_gpu
:
gpu_config
=
AnalysisConfig
(
self
.
default_pretrained_model_path
)
# create predictors using various types of devices
# npu
npu_id
=
self
.
_get_device_id
(
"FLAGS_selected_npus"
)
if
npu_id
!=
-
1
:
# use npu
npu_config
=
Config
(
self
.
default_pretrained_model_path
)
npu_config
.
disable_glog_info
()
npu_config
.
enable_npu
(
device_id
=
npu_id
)
self
.
npu_predictor
=
create_predictor
(
npu_config
)
# gpu
gpu_id
=
self
.
_get_device_id
(
"CUDA_VISIBLE_DEVICES"
)
if
gpu_id
!=
-
1
:
# use gpu
gpu_config
=
Config
(
self
.
default_pretrained_model_path
)
gpu_config
.
disable_glog_info
()
gpu_config
.
enable_use_gpu
(
memory_pool_init_size_mb
=
500
,
device_id
=
0
)
self
.
gpu_predictor
=
create_paddle_predictor
(
gpu_config
)
gpu_config
.
enable_use_gpu
(
memory_pool_init_size_mb
=
500
,
device_id
=
gpu_id
)
self
.
gpu_predictor
=
create_predictor
(
gpu_config
)
# xpu
xpu_id
=
self
.
_get_device_id
(
"XPU_VISIBLE_DEVICES"
)
if
xpu_id
!=
-
1
:
# use xpu
xpu_config
=
Config
(
self
.
default_pretrained_model_path
)
xpu_config
.
disable_glog_info
()
xpu_config
.
enable_xpu
(
100
)
self
.
xpu_predictor
=
create_predictor
(
xpu_config
)
def
context
(
self
,
trainable
=
True
,
pretrained
=
True
,
get_prediction
=
False
):
"""
...
...
@@ -135,7 +165,8 @@ class YOLOv3DarkNet53Coco2017(hub.Module):
use_gpu
=
False
,
output_dir
=
'detection_result'
,
score_thresh
=
0.5
,
visualization
=
True
):
visualization
=
True
,
use_device
=
None
):
"""API of Object Detection.
Args:
...
...
@@ -146,6 +177,7 @@ class YOLOv3DarkNet53Coco2017(hub.Module):
output_dir (str): The path to store output images.
visualization (bool): Whether to save image or not.
score_thresh (float): threshold for object detecion.
use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
Returns:
res (list[dict]): The result of coco2017 detecion. keys include 'data', 'save_path', the corresponding value is:
...
...
@@ -158,14 +190,24 @@ class YOLOv3DarkNet53Coco2017(hub.Module):
confidence (float): The confidence of detection result.
save_path (str, optional): The path to save output images.
"""
if
use_gpu
:
try
:
_places
=
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
int
(
_places
[
0
])
except
:
raise
RuntimeError
(
"Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES as cuda_device_id."
)
# real predictor to use
if
use_device
is
not
None
:
if
use_device
==
"cpu"
:
predictor
=
self
.
cpu_predictor
elif
use_device
==
"xpu"
:
predictor
=
self
.
xpu_predictor
elif
use_device
==
"npu"
:
predictor
=
self
.
npu_predictor
elif
use_device
==
"gpu"
:
predictor
=
self
.
gpu_predictor
else
:
raise
Exception
(
"Unsupported device: "
+
use_device
)
else
:
# use_device is not set, therefore follow use_gpu
if
use_gpu
:
predictor
=
self
.
gpu_predictor
else
:
predictor
=
self
.
cpu_predictor
paths
=
paths
if
paths
else
list
()
if
data
and
'image'
in
data
:
...
...
@@ -176,17 +218,27 @@ class YOLOv3DarkNet53Coco2017(hub.Module):
res
=
[]
for
iter_id
,
feed_data
in
enumerate
(
batch_reader
()):
feed_data
=
np
.
array
(
feed_data
)
image_tensor
=
PaddleTensor
(
np
.
array
(
list
(
feed_data
[:,
0
])))
im_size_tensor
=
PaddleTensor
(
np
.
array
(
list
(
feed_data
[:,
1
])))
if
use_gpu
:
data_out
=
self
.
gpu_predictor
.
run
([
image_tensor
,
im_size_tensor
])
else
:
data_out
=
self
.
cpu_predictor
.
run
([
image_tensor
,
im_size_tensor
])
input_names
=
predictor
.
get_input_names
()
image_data
=
np
.
array
(
list
(
feed_data
[:,
0
]))
image_size_data
=
np
.
array
(
list
(
feed_data
[:,
1
]))
image_tensor
=
predictor
.
get_input_handle
(
input_names
[
0
])
image_tensor
.
reshape
(
image_data
.
shape
)
image_tensor
.
copy_from_cpu
(
image_data
.
copy
())
image_size_tensor
=
predictor
.
get_input_handle
(
input_names
[
1
])
image_size_tensor
.
reshape
(
image_size_data
.
shape
)
image_size_tensor
.
copy_from_cpu
(
image_size_data
.
copy
())
predictor
.
run
()
output_names
=
predictor
.
get_output_names
()
output_handle
=
predictor
.
get_output_handle
(
output_names
[
0
])
output
=
postprocess
(
paths
=
paths
,
images
=
images
,
data_out
=
data_out
,
data_out
=
output_handle
,
score_thresh
=
score_thresh
,
label_names
=
self
.
label_names
,
output_dir
=
output_dir
,
...
...
@@ -245,7 +297,8 @@ class YOLOv3DarkNet53Coco2017(hub.Module):
use_gpu
=
args
.
use_gpu
,
output_dir
=
args
.
output_dir
,
visualization
=
args
.
visualization
,
score_thresh
=
args
.
score_thresh
)
score_thresh
=
args
.
score_thresh
,
use_device
=
args
.
use_device
)
return
results
def
add_module_config_arg
(
self
):
...
...
@@ -258,6 +311,10 @@ class YOLOv3DarkNet53Coco2017(hub.Module):
'--output_dir'
,
type
=
str
,
default
=
'detection_result'
,
help
=
"The directory to save output images."
)
self
.
arg_config_group
.
add_argument
(
'--visualization'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether to save output as images."
)
self
.
arg_config_group
.
add_argument
(
'--use_device'
,
choices
=
[
"cpu"
,
"gpu"
,
"xpu"
,
"npu"
],
help
=
"use cpu, gpu, xpu or npu. overwrites use_gpu flag."
)
def
add_module_input_arg
(
self
):
"""
...
...
modules/image/object_detection/yolov3_darknet53_coco2017/processor.py
浏览文件 @
0d60bf5a
...
...
@@ -94,8 +94,6 @@ def postprocess(paths, images, data_out, score_thresh, label_names, output_dir,
paths (list[str]): The paths of images.
images (list(numpy.ndarray)): images data, shape of each is [H, W, C]
data_out (lod_tensor): data output of predictor.
batch_size (int): batch size.
use_gpu (bool): Whether to use gpu.
output_dir (str): The path to store output images.
visualization (bool): Whether to save image or not.
score_thresh (float): the low limit of bounding box.
...
...
@@ -113,9 +111,8 @@ def postprocess(paths, images, data_out, score_thresh, label_names, output_dir,
confidence (float): The confidence of detection result.
save_path (str): The path to save output images.
"""
lod_tensor
=
data_out
[
0
]
lod
=
lod_tensor
.
lod
[
0
]
results
=
lod_tensor
.
as_ndarray
()
results
=
data_out
.
copy_to_cpu
()
lod
=
data_out
.
lod
()[
0
]
check_dir
(
output_dir
)
...
...
modules/image/object_detection/yolov3_darknet53_pedestrian/module.py
浏览文件 @
0d60bf5a
...
...
@@ -9,7 +9,10 @@ from functools import partial
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddlehub
as
hub
from
paddle.fluid.core
import
PaddleTensor
,
AnalysisConfig
,
create_paddle_predictor
from
paddle.inference
import
Config
from
paddle.inference
import
create_predictor
from
paddlehub.module.module
import
moduleinfo
,
runnable
,
serving
from
paddlehub.common.paddle_helper
import
add_vars_prefix
...
...
@@ -23,39 +26,63 @@ from yolov3_darknet53_pedestrian.yolo_head import MultiClassNMS, YOLOv3Head
name
=
"yolov3_darknet53_pedestrian"
,
version
=
"1.0.2"
,
type
=
"CV/object_detection"
,
summary
=
"Baidu's YOLOv3 model for pedestrian detection, with backbone DarkNet53."
,
summary
=
"Baidu's YOLOv3 model for pedestrian detection, with backbone DarkNet53."
,
author
=
"paddlepaddle"
,
author_email
=
"paddle-dev@baidu.com"
)
class
YOLOv3DarkNet53Pedestrian
(
hub
.
Module
):
def
_initialize
(
self
):
self
.
default_pretrained_model_path
=
os
.
path
.
join
(
self
.
directory
,
"yolov3_darknet53_pedestrian_model"
)
self
.
label_names
=
load_label_info
(
os
.
path
.
join
(
self
.
directory
,
"label_file.txt"
))
self
.
default_pretrained_model_path
=
os
.
path
.
join
(
self
.
directory
,
"yolov3_darknet53_pedestrian_model"
)
self
.
label_names
=
load_label_info
(
os
.
path
.
join
(
self
.
directory
,
"label_file.txt"
))
self
.
_set_config
()
def
_get_device_id
(
self
,
places
):
try
:
places
=
os
.
environ
[
places
]
id
=
int
(
places
)
except
:
id
=
-
1
return
id
def
_set_config
(
self
):
"""
predictor config setting.
"""
cpu_config
=
AnalysisConfig
(
self
.
default_pretrained_model_path
)
# create default cpu predictor
cpu_config
=
Config
(
self
.
default_pretrained_model_path
)
cpu_config
.
disable_glog_info
()
cpu_config
.
disable_gpu
()
cpu_config
.
switch_ir_optim
(
False
)
self
.
cpu_predictor
=
create_p
addle_p
redictor
(
cpu_config
)
self
.
cpu_predictor
=
create_predictor
(
cpu_config
)
try
:
_places
=
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
int
(
_places
[
0
])
use_gpu
=
True
except
:
use_gpu
=
False
if
use_gpu
:
gpu_config
=
AnalysisConfig
(
self
.
default_pretrained_model_path
)
# create predictors using various types of devices
# npu
npu_id
=
self
.
_get_device_id
(
"FLAGS_selected_npus"
)
if
npu_id
!=
-
1
:
# use npu
npu_config
=
Config
(
self
.
default_pretrained_model_path
)
npu_config
.
disable_glog_info
()
npu_config
.
enable_npu
(
device_id
=
npu_id
)
self
.
npu_predictor
=
create_predictor
(
npu_config
)
# gpu
gpu_id
=
self
.
_get_device_id
(
"CUDA_VISIBLE_DEVICES"
)
if
gpu_id
!=
-
1
:
# use gpu
gpu_config
=
Config
(
self
.
default_pretrained_model_path
)
gpu_config
.
disable_glog_info
()
gpu_config
.
enable_use_gpu
(
memory_pool_init_size_mb
=
500
,
device_id
=
0
)
self
.
gpu_predictor
=
create_paddle_predictor
(
gpu_config
)
gpu_config
.
enable_use_gpu
(
memory_pool_init_size_mb
=
500
,
device_id
=
gpu_id
)
self
.
gpu_predictor
=
create_predictor
(
gpu_config
)
# xpu
xpu_id
=
self
.
_get_device_id
(
"XPU_VISIBLE_DEVICES"
)
if
xpu_id
!=
-
1
:
# use xpu
xpu_config
=
Config
(
self
.
default_pretrained_model_path
)
xpu_config
.
disable_glog_info
()
xpu_config
.
enable_xpu
(
100
)
self
.
xpu_predictor
=
create_predictor
(
xpu_config
)
def
context
(
self
,
trainable
=
True
,
pretrained
=
True
,
get_prediction
=
False
):
"""
...
...
@@ -76,20 +103,18 @@ class YOLOv3DarkNet53Pedestrian(hub.Module):
with
fluid
.
program_guard
(
context_prog
,
startup_program
):
with
fluid
.
unique_name
.
guard
():
# image
image
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
3
,
608
,
608
],
dtype
=
'float32'
)
image
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
3
,
608
,
608
],
dtype
=
'float32'
)
# backbone
backbone
=
DarkNet
(
norm_type
=
'sync_bn'
,
norm_decay
=
0.
,
depth
=
53
)
# body_feats
body_feats
=
backbone
(
image
)
# im_size
im_size
=
fluid
.
layers
.
data
(
name
=
'im_size'
,
shape
=
[
2
],
dtype
=
'int32'
)
im_size
=
fluid
.
layers
.
data
(
name
=
'im_size'
,
shape
=
[
2
],
dtype
=
'int32'
)
# yolo_head
yolo_head
=
YOLOv3Head
(
anchor_masks
=
[[
6
,
7
,
8
],
[
3
,
4
,
5
],
[
0
,
1
,
2
]],
anchors
=
[[
10
,
13
],
[
16
,
30
],
[
33
,
23
],
[
30
,
61
],
[
62
,
45
],
[
59
,
119
],
[
116
,
90
],
[
156
,
198
],
[
373
,
326
]],
anchors
=
[[
10
,
13
],
[
16
,
30
],
[
33
,
23
],
[
30
,
61
],
[
62
,
45
],
[
59
,
119
],
[
116
,
90
],
[
156
,
198
],
[
373
,
326
]],
norm_decay
=
0.
,
num_classes
=
1
,
ignore_thresh
=
0.7
,
...
...
@@ -102,8 +127,7 @@ class YOLOv3DarkNet53Pedestrian(hub.Module):
normalized
=
False
,
score_threshold
=
0.01
))
# head_features
head_features
,
body_features
=
yolo_head
.
_get_outputs
(
body_feats
,
is_train
=
trainable
)
head_features
,
body_features
=
yolo_head
.
_get_outputs
(
body_feats
,
is_train
=
trainable
)
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
...
...
@@ -112,35 +136,24 @@ class YOLOv3DarkNet53Pedestrian(hub.Module):
# var_prefix
var_prefix
=
'@HUB_{}@'
.
format
(
self
.
name
)
# name of inputs
inputs
=
{
'image'
:
var_prefix
+
image
.
name
,
'im_size'
:
var_prefix
+
im_size
.
name
}
inputs
=
{
'image'
:
var_prefix
+
image
.
name
,
'im_size'
:
var_prefix
+
im_size
.
name
}
# name of outputs
if
get_prediction
:
bbox_out
=
yolo_head
.
get_prediction
(
head_features
,
im_size
)
outputs
=
{
'bbox_out'
:
[
var_prefix
+
bbox_out
.
name
]}
else
:
outputs
=
{
'head_features'
:
[
var_prefix
+
var
.
name
for
var
in
head_features
],
'body_features'
:
[
var_prefix
+
var
.
name
for
var
in
body_features
]
'head_features'
:
[
var_prefix
+
var
.
name
for
var
in
head_features
],
'body_features'
:
[
var_prefix
+
var
.
name
for
var
in
body_features
]
}
# add_vars_prefix
add_vars_prefix
(
context_prog
,
var_prefix
)
add_vars_prefix
(
fluid
.
default_startup_program
(),
var_prefix
)
# inputs
inputs
=
{
key
:
context_prog
.
global_block
().
vars
[
value
]
for
key
,
value
in
inputs
.
items
()
}
inputs
=
{
key
:
context_prog
.
global_block
().
vars
[
value
]
for
key
,
value
in
inputs
.
items
()}
# outputs
outputs
=
{
key
:
[
context_prog
.
global_block
().
vars
[
varname
]
for
varname
in
value
]
key
:
[
context_prog
.
global_block
().
vars
[
varname
]
for
varname
in
value
]
for
key
,
value
in
outputs
.
items
()
}
# trainable
...
...
@@ -150,14 +163,9 @@ class YOLOv3DarkNet53Pedestrian(hub.Module):
if
pretrained
:
def
_if_exist
(
var
):
return
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
default_pretrained_model_path
,
var
.
name
))
fluid
.
io
.
load_vars
(
exe
,
self
.
default_pretrained_model_path
,
predicate
=
_if_exist
)
return
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
default_pretrained_model_path
,
var
.
name
))
fluid
.
io
.
load_vars
(
exe
,
self
.
default_pretrained_model_path
,
predicate
=
_if_exist
)
else
:
exe
.
run
(
startup_program
)
...
...
@@ -170,7 +178,8 @@ class YOLOv3DarkNet53Pedestrian(hub.Module):
use_gpu
=
False
,
output_dir
=
'yolov3_pedestrian_detect_output'
,
score_thresh
=
0.2
,
visualization
=
True
):
visualization
=
True
,
use_device
=
None
):
"""API of Object Detection.
Args:
...
...
@@ -181,6 +190,7 @@ class YOLOv3DarkNet53Pedestrian(hub.Module):
output_dir (str): The path to store output images.
visualization (bool): Whether to save image or not.
score_thresh (float): threshold for object detecion.
use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
Returns:
res (list[dict]): The result of pedestrian detecion. keys include 'data', 'save_path', the corresponding value is:
...
...
@@ -193,14 +203,24 @@ class YOLOv3DarkNet53Pedestrian(hub.Module):
confidence (float): The confidence of detection result.
save_path (str, optional): The path to save output images.
"""
if
use_gpu
:
try
:
_places
=
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
int
(
_places
[
0
])
except
:
raise
RuntimeError
(
"Attempt to use GPU for prediction, but environment variable CUDA_VISIBLE_DEVICES was not set correctly."
)
# real predictor to use
if
use_device
is
not
None
:
if
use_device
==
"cpu"
:
predictor
=
self
.
cpu_predictor
elif
use_device
==
"xpu"
:
predictor
=
self
.
xpu_predictor
elif
use_device
==
"npu"
:
predictor
=
self
.
npu_predictor
elif
use_device
==
"gpu"
:
predictor
=
self
.
gpu_predictor
else
:
raise
Exception
(
"Unsupported device: "
+
use_device
)
else
:
# use_device is not set, therefore follow use_gpu
if
use_gpu
:
predictor
=
self
.
gpu_predictor
else
:
predictor
=
self
.
cpu_predictor
paths
=
paths
if
paths
else
list
()
data_reader
=
partial
(
reader
,
paths
,
images
)
...
...
@@ -208,19 +228,27 @@ class YOLOv3DarkNet53Pedestrian(hub.Module):
res
=
[]
for
iter_id
,
feed_data
in
enumerate
(
batch_reader
()):
feed_data
=
np
.
array
(
feed_data
)
image_tensor
=
PaddleTensor
(
np
.
array
(
list
(
feed_data
[:,
0
])))
im_size_tensor
=
PaddleTensor
(
np
.
array
(
list
(
feed_data
[:,
1
])))
if
use_gpu
:
data_out
=
self
.
gpu_predictor
.
run
(
[
image_tensor
,
im_size_tensor
])
else
:
data_out
=
self
.
cpu_predictor
.
run
(
[
image_tensor
,
im_size_tensor
])
input_names
=
predictor
.
get_input_names
()
image_data
=
np
.
array
(
list
(
feed_data
[:,
0
]))
image_size_data
=
np
.
array
(
list
(
feed_data
[:,
1
]))
image_tensor
=
predictor
.
get_input_handle
(
input_names
[
0
])
image_tensor
.
reshape
(
image_data
.
shape
)
image_tensor
.
copy_from_cpu
(
image_data
.
copy
())
image_size_tensor
=
predictor
.
get_input_handle
(
input_names
[
1
])
image_size_tensor
.
reshape
(
image_size_data
.
shape
)
image_size_tensor
.
copy_from_cpu
(
image_size_data
.
copy
())
predictor
.
run
()
output_names
=
predictor
.
get_output_names
()
output_handle
=
predictor
.
get_output_handle
(
output_names
[
0
])
output
=
postprocess
(
paths
=
paths
,
images
=
images
,
data_out
=
data_out
,
data_out
=
output_handle
,
score_thresh
=
score_thresh
,
label_names
=
self
.
label_names
,
output_dir
=
output_dir
,
...
...
@@ -229,11 +257,7 @@ class YOLOv3DarkNet53Pedestrian(hub.Module):
res
.
extend
(
output
)
return
res
def
save_inference_model
(
self
,
dirname
,
model_filename
=
None
,
params_filename
=
None
,
combined
=
True
):
def
save_inference_model
(
self
,
dirname
,
model_filename
=
None
,
params_filename
=
None
,
combined
=
True
):
if
combined
:
model_filename
=
"__model__"
if
not
model_filename
else
model_filename
params_filename
=
"__params__"
if
not
params_filename
else
params_filename
...
...
@@ -271,12 +295,9 @@ class YOLOv3DarkNet53Pedestrian(hub.Module):
prog
=
'hub run {}'
.
format
(
self
.
name
),
usage
=
'%(prog)s'
,
add_help
=
True
)
self
.
arg_input_group
=
self
.
parser
.
add_argument_group
(
title
=
"Input options"
,
description
=
"Input data. Required"
)
self
.
arg_input_group
=
self
.
parser
.
add_argument_group
(
title
=
"Input options"
,
description
=
"Input data. Required"
)
self
.
arg_config_group
=
self
.
parser
.
add_argument_group
(
title
=
"Config options"
,
description
=
"Run configuration for controlling module behavior, not required."
)
title
=
"Config options"
,
description
=
"Run configuration for controlling module behavior, not required."
)
self
.
add_module_config_arg
()
self
.
add_module_input_arg
()
args
=
self
.
parser
.
parse_args
(
argvs
)
...
...
@@ -286,7 +307,8 @@ class YOLOv3DarkNet53Pedestrian(hub.Module):
use_gpu
=
args
.
use_gpu
,
output_dir
=
args
.
output_dir
,
visualization
=
args
.
visualization
,
score_thresh
=
args
.
score_thresh
)
score_thresh
=
args
.
score_thresh
,
use_device
=
args
.
use_device
)
return
results
def
add_module_config_arg
(
self
):
...
...
@@ -294,34 +316,24 @@ class YOLOv3DarkNet53Pedestrian(hub.Module):
Add the command config options.
"""
self
.
arg_config_group
.
add_argument
(
'--use_gpu'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether use GPU or not"
)
'--use_gpu'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether use GPU or not"
)
self
.
arg_config_group
.
add_argument
(
'--output_dir'
,
type
=
str
,
default
=
'yolov3_pedestrian_detect_output'
,
help
=
"The directory to save output images."
)
self
.
arg_config_group
.
add_argument
(
'--visualization'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether to save output as images."
)
'--visualization'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether to save output as images."
)
self
.
arg_config_group
.
add_argument
(
'--use_device'
,
choices
=
[
"cpu"
,
"gpu"
,
"xpu"
,
"npu"
],
help
=
"use cpu, gpu, xpu or npu. overwrites use_gpu flag."
)
def
add_module_input_arg
(
self
):
"""
Add the command input options.
"""
self
.
arg_input_group
.
add_argument
(
'--input_path'
,
type
=
str
,
help
=
"path to image."
)
self
.
arg_input_group
.
add_argument
(
'--batch_size'
,
type
=
ast
.
literal_eval
,
default
=
1
,
help
=
"batch size."
)
self
.
arg_input_group
.
add_argument
(
'--input_path'
,
type
=
str
,
help
=
"path to image."
)
self
.
arg_input_group
.
add_argument
(
'--batch_size'
,
type
=
ast
.
literal_eval
,
default
=
1
,
help
=
"batch size."
)
self
.
arg_input_group
.
add_argument
(
'--score_thresh'
,
type
=
ast
.
literal_eval
,
default
=
0.2
,
help
=
"threshold for object detecion."
)
'--score_thresh'
,
type
=
ast
.
literal_eval
,
default
=
0.2
,
help
=
"threshold for object detecion."
)
modules/image/object_detection/yolov3_darknet53_pedestrian/processor.py
浏览文件 @
0d60bf5a
...
...
@@ -50,21 +50,15 @@ def draw_bounding_box_on_image(image_path, data_list, save_dir):
image
=
Image
.
open
(
image_path
)
draw
=
ImageDraw
.
Draw
(
image
)
for
data
in
data_list
:
left
,
right
,
top
,
bottom
=
data
[
'left'
],
data
[
'right'
],
data
[
'top'
],
data
[
'bottom'
]
left
,
right
,
top
,
bottom
=
data
[
'left'
],
data
[
'right'
],
data
[
'top'
],
data
[
'bottom'
]
# draw bbox
draw
.
line
([(
left
,
top
),
(
left
,
bottom
),
(
right
,
bottom
),
(
right
,
top
),
(
left
,
top
)],
width
=
2
,
fill
=
'red'
)
draw
.
line
([(
left
,
top
),
(
left
,
bottom
),
(
right
,
bottom
),
(
right
,
top
),
(
left
,
top
)],
width
=
2
,
fill
=
'red'
)
# draw label
if
image
.
mode
==
'RGB'
:
text
=
data
[
'label'
]
+
": %.2f%%"
%
(
100
*
data
[
'confidence'
])
textsize_width
,
textsize_height
=
draw
.
textsize
(
text
=
text
)
draw
.
rectangle
(
xy
=
(
left
,
top
-
(
textsize_height
+
5
),
left
+
textsize_width
+
10
,
top
),
fill
=
(
255
,
255
,
255
))
xy
=
(
left
,
top
-
(
textsize_height
+
5
),
left
+
textsize_width
+
10
,
top
),
fill
=
(
255
,
255
,
255
))
draw
.
text
(
xy
=
(
left
,
top
-
15
),
text
=
text
,
fill
=
(
0
,
0
,
0
))
save_name
=
get_save_image_name
(
image
,
save_dir
,
image_path
)
...
...
@@ -92,14 +86,7 @@ def load_label_info(file_path):
return
label_names
def
postprocess
(
paths
,
images
,
data_out
,
score_thresh
,
label_names
,
output_dir
,
handle_id
,
visualization
=
True
):
def
postprocess
(
paths
,
images
,
data_out
,
score_thresh
,
label_names
,
output_dir
,
handle_id
,
visualization
=
True
):
"""
postprocess the lod_tensor produced by fluid.Executor.run
...
...
@@ -107,8 +94,6 @@ def postprocess(paths,
paths (list[str]): The paths of images.
images (list(numpy.ndarray)): images data, shape of each is [H, W, C]
data_out (lod_tensor): data output of predictor.
batch_size (int): batch size.
use_gpu (bool): Whether to use gpu.
output_dir (str): The path to store output images.
visualization (bool): Whether to save image or not.
score_thresh (float): the low limit of bounding box.
...
...
@@ -126,9 +111,8 @@ def postprocess(paths,
confidence (float): The confidence of detection result.
save_path (str): The path to save output images.
"""
lod_tensor
=
data_out
[
0
]
lod
=
lod_tensor
.
lod
[
0
]
results
=
lod_tensor
.
as_ndarray
()
results
=
data_out
.
copy_to_cpu
()
lod
=
data_out
.
lod
()[
0
]
check_dir
(
output_dir
)
...
...
@@ -146,7 +130,6 @@ def postprocess(paths,
else
:
unhandled_paths_num
=
0
output
=
list
()
for
index
in
range
(
len
(
lod
)
-
1
):
output_i
=
{
'data'
:
[]}
...
...
@@ -158,9 +141,7 @@ def postprocess(paths,
org_img
=
org_img
.
astype
(
np
.
uint8
)
org_img
=
Image
.
fromarray
(
org_img
[:,
:,
::
-
1
])
if
visualization
:
org_img_path
=
get_save_image_name
(
org_img
,
output_dir
,
'image_numpy_{}'
.
format
(
(
handle_id
+
index
)))
org_img_path
=
get_save_image_name
(
org_img
,
output_dir
,
'image_numpy_{}'
.
format
((
handle_id
+
index
)))
org_img
.
save
(
org_img_path
)
org_img_height
=
org_img
.
height
org_img_width
=
org_img
.
width
...
...
@@ -176,13 +157,11 @@ def postprocess(paths,
dt
=
{}
dt
[
'label'
]
=
label_names
[
category_id
]
dt
[
'confidence'
]
=
float
(
confidence
)
dt
[
'left'
],
dt
[
'top'
],
dt
[
'right'
],
dt
[
'bottom'
]
=
clip_bbox
(
bbox
,
org_img_width
,
org_img_height
)
dt
[
'left'
],
dt
[
'top'
],
dt
[
'right'
],
dt
[
'bottom'
]
=
clip_bbox
(
bbox
,
org_img_width
,
org_img_height
)
output_i
[
'data'
].
append
(
dt
)
output
.
append
(
output_i
)
if
visualization
:
output_i
[
'save_path'
]
=
draw_bounding_box_on_image
(
org_img_path
,
output_i
[
'data'
],
output_dir
)
output_i
[
'save_path'
]
=
draw_bounding_box_on_image
(
org_img_path
,
output_i
[
'data'
],
output_dir
)
return
output
modules/image/object_detection/yolov3_mobilenet_v1_coco2017/module.py
浏览文件 @
0d60bf5a
...
...
@@ -9,7 +9,10 @@ from functools import partial
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddlehub
as
hub
from
paddle.fluid.core
import
PaddleTensor
,
AnalysisConfig
,
create_paddle_predictor
from
paddle.inference
import
Config
from
paddle.inference
import
create_predictor
from
paddlehub.module.module
import
moduleinfo
,
runnable
,
serving
from
paddlehub.common.paddle_helper
import
add_vars_prefix
...
...
@@ -23,39 +26,63 @@ from yolov3_mobilenet_v1_coco2017.yolo_head import MultiClassNMS, YOLOv3Head
name
=
"yolov3_mobilenet_v1_coco2017"
,
version
=
"1.0.2"
,
type
=
"CV/object_detection"
,
summary
=
"Baidu's YOLOv3 model for object detection with backbone MobileNet_V1, trained with dataset COCO2017."
,
summary
=
"Baidu's YOLOv3 model for object detection with backbone MobileNet_V1, trained with dataset COCO2017."
,
author
=
"paddlepaddle"
,
author_email
=
"paddle-dev@baidu.com"
)
class
YOLOv3MobileNetV1Coco2017
(
hub
.
Module
):
def
_initialize
(
self
):
self
.
default_pretrained_model_path
=
os
.
path
.
join
(
self
.
directory
,
"yolov3_mobilenet_v1_model"
)
self
.
label_names
=
load_label_info
(
os
.
path
.
join
(
self
.
directory
,
"label_file.txt"
))
self
.
default_pretrained_model_path
=
os
.
path
.
join
(
self
.
directory
,
"yolov3_mobilenet_v1_model"
)
self
.
label_names
=
load_label_info
(
os
.
path
.
join
(
self
.
directory
,
"label_file.txt"
))
self
.
_set_config
()
def
_get_device_id
(
self
,
places
):
try
:
places
=
os
.
environ
[
places
]
id
=
int
(
places
)
except
:
id
=
-
1
return
id
def
_set_config
(
self
):
"""
predictor config setting.
"""
cpu_config
=
AnalysisConfig
(
self
.
default_pretrained_model_path
)
# create default cpu predictor
cpu_config
=
Config
(
self
.
default_pretrained_model_path
)
cpu_config
.
disable_glog_info
()
cpu_config
.
disable_gpu
()
cpu_config
.
switch_ir_optim
(
False
)
self
.
cpu_predictor
=
create_p
addle_p
redictor
(
cpu_config
)
self
.
cpu_predictor
=
create_predictor
(
cpu_config
)
try
:
_places
=
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
int
(
_places
[
0
])
use_gpu
=
True
except
:
use_gpu
=
False
if
use_gpu
:
gpu_config
=
AnalysisConfig
(
self
.
default_pretrained_model_path
)
# create predictors using various types of devices
# npu
npu_id
=
self
.
_get_device_id
(
"FLAGS_selected_npus"
)
if
npu_id
!=
-
1
:
# use npu
npu_config
=
Config
(
self
.
default_pretrained_model_path
)
npu_config
.
disable_glog_info
()
npu_config
.
enable_npu
(
device_id
=
npu_id
)
self
.
npu_predictor
=
create_predictor
(
npu_config
)
# gpu
gpu_id
=
self
.
_get_device_id
(
"CUDA_VISIBLE_DEVICES"
)
if
gpu_id
!=
-
1
:
# use gpu
gpu_config
=
Config
(
self
.
default_pretrained_model_path
)
gpu_config
.
disable_glog_info
()
gpu_config
.
enable_use_gpu
(
memory_pool_init_size_mb
=
500
,
device_id
=
0
)
self
.
gpu_predictor
=
create_paddle_predictor
(
gpu_config
)
gpu_config
.
enable_use_gpu
(
memory_pool_init_size_mb
=
500
,
device_id
=
gpu_id
)
self
.
gpu_predictor
=
create_predictor
(
gpu_config
)
# xpu
xpu_id
=
self
.
_get_device_id
(
"XPU_VISIBLE_DEVICES"
)
if
xpu_id
!=
-
1
:
# use xpu
xpu_config
=
Config
(
self
.
default_pretrained_model_path
)
xpu_config
.
disable_glog_info
()
xpu_config
.
enable_xpu
(
100
)
self
.
xpu_predictor
=
create_predictor
(
xpu_config
)
def
context
(
self
,
trainable
=
True
,
pretrained
=
True
,
get_prediction
=
False
):
"""
...
...
@@ -76,24 +103,17 @@ class YOLOv3MobileNetV1Coco2017(hub.Module):
with
fluid
.
program_guard
(
context_prog
,
startup_program
):
with
fluid
.
unique_name
.
guard
():
# image
image
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
3
,
608
,
608
],
dtype
=
'float32'
)
image
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
3
,
608
,
608
],
dtype
=
'float32'
)
# backbone
backbone
=
MobileNet
(
norm_type
=
'sync_bn'
,
norm_decay
=
0.
,
conv_group_scale
=
1
,
with_extra_blocks
=
False
)
backbone
=
MobileNet
(
norm_type
=
'sync_bn'
,
norm_decay
=
0.
,
conv_group_scale
=
1
,
with_extra_blocks
=
False
)
# body_feats
body_feats
=
backbone
(
image
)
# im_size
im_size
=
fluid
.
layers
.
data
(
name
=
'im_size'
,
shape
=
[
2
],
dtype
=
'int32'
)
im_size
=
fluid
.
layers
.
data
(
name
=
'im_size'
,
shape
=
[
2
],
dtype
=
'int32'
)
# yolo_head
yolo_head
=
YOLOv3Head
(
num_classes
=
80
)
# head_features
head_features
,
body_features
=
yolo_head
.
_get_outputs
(
body_feats
,
is_train
=
trainable
)
head_features
,
body_features
=
yolo_head
.
_get_outputs
(
body_feats
,
is_train
=
trainable
)
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
...
...
@@ -102,35 +122,24 @@ class YOLOv3MobileNetV1Coco2017(hub.Module):
# var_prefix
var_prefix
=
'@HUB_{}@'
.
format
(
self
.
name
)
# name of inputs
inputs
=
{
'image'
:
var_prefix
+
image
.
name
,
'im_size'
:
var_prefix
+
im_size
.
name
}
inputs
=
{
'image'
:
var_prefix
+
image
.
name
,
'im_size'
:
var_prefix
+
im_size
.
name
}
# name of outputs
if
get_prediction
:
bbox_out
=
yolo_head
.
get_prediction
(
head_features
,
im_size
)
outputs
=
{
'bbox_out'
:
[
var_prefix
+
bbox_out
.
name
]}
else
:
outputs
=
{
'head_features'
:
[
var_prefix
+
var
.
name
for
var
in
head_features
],
'body_features'
:
[
var_prefix
+
var
.
name
for
var
in
body_features
]
'head_features'
:
[
var_prefix
+
var
.
name
for
var
in
head_features
],
'body_features'
:
[
var_prefix
+
var
.
name
for
var
in
body_features
]
}
# add_vars_prefix
add_vars_prefix
(
context_prog
,
var_prefix
)
add_vars_prefix
(
startup_program
,
var_prefix
)
# inputs
inputs
=
{
key
:
context_prog
.
global_block
().
vars
[
value
]
for
key
,
value
in
inputs
.
items
()
}
inputs
=
{
key
:
context_prog
.
global_block
().
vars
[
value
]
for
key
,
value
in
inputs
.
items
()}
# outputs
outputs
=
{
key
:
[
context_prog
.
global_block
().
vars
[
varname
]
for
varname
in
value
]
key
:
[
context_prog
.
global_block
().
vars
[
varname
]
for
varname
in
value
]
for
key
,
value
in
outputs
.
items
()
}
# trainable
...
...
@@ -140,14 +149,9 @@ class YOLOv3MobileNetV1Coco2017(hub.Module):
if
pretrained
:
def
_if_exist
(
var
):
return
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
default_pretrained_model_path
,
var
.
name
))
fluid
.
io
.
load_vars
(
exe
,
self
.
default_pretrained_model_path
,
predicate
=
_if_exist
)
return
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
default_pretrained_model_path
,
var
.
name
))
fluid
.
io
.
load_vars
(
exe
,
self
.
default_pretrained_model_path
,
predicate
=
_if_exist
)
else
:
exe
.
run
(
startup_program
)
...
...
@@ -160,7 +164,8 @@ class YOLOv3MobileNetV1Coco2017(hub.Module):
use_gpu
=
False
,
output_dir
=
'detection_result'
,
score_thresh
=
0.5
,
visualization
=
True
):
visualization
=
True
,
use_device
=
None
):
"""API of Object Detection.
Args:
...
...
@@ -171,6 +176,7 @@ class YOLOv3MobileNetV1Coco2017(hub.Module):
output_dir (str): The path to store output images.
visualization (bool): Whether to save image or not.
score_thresh (float): threshold for object detecion.
use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
Returns:
res (list[dict]): The result of coco2017 detecion. keys include 'data', 'save_path', the corresponding value is:
...
...
@@ -183,14 +189,24 @@ class YOLOv3MobileNetV1Coco2017(hub.Module):
confidence (float): The confidence of detection result.
save_path (str, optional): The path to save output images.
"""
if
use_gpu
:
try
:
_places
=
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
int
(
_places
[
0
])
except
:
raise
RuntimeError
(
"Attempt to use GPU for prediction, but environment variable CUDA_VISIBLE_DEVICES was not set correctly."
)
# real predictor to use
if
use_device
is
not
None
:
if
use_device
==
"cpu"
:
predictor
=
self
.
cpu_predictor
elif
use_device
==
"xpu"
:
predictor
=
self
.
xpu_predictor
elif
use_device
==
"npu"
:
predictor
=
self
.
npu_predictor
elif
use_device
==
"gpu"
:
predictor
=
self
.
gpu_predictor
else
:
raise
Exception
(
"Unsupported device: "
+
use_device
)
else
:
# use_device is not set, therefore follow use_gpu
if
use_gpu
:
predictor
=
self
.
gpu_predictor
else
:
predictor
=
self
.
cpu_predictor
paths
=
paths
if
paths
else
list
()
data_reader
=
partial
(
reader
,
paths
,
images
)
...
...
@@ -198,19 +214,27 @@ class YOLOv3MobileNetV1Coco2017(hub.Module):
res
=
[]
for
iter_id
,
feed_data
in
enumerate
(
batch_reader
()):
feed_data
=
np
.
array
(
feed_data
)
image_tensor
=
PaddleTensor
(
np
.
array
(
list
(
feed_data
[:,
0
])))
im_size_tensor
=
PaddleTensor
(
np
.
array
(
list
(
feed_data
[:,
1
])))
if
use_gpu
:
data_out
=
self
.
gpu_predictor
.
run
(
[
image_tensor
,
im_size_tensor
])
else
:
data_out
=
self
.
cpu_predictor
.
run
(
[
image_tensor
,
im_size_tensor
])
input_names
=
predictor
.
get_input_names
()
image_data
=
np
.
array
(
list
(
feed_data
[:,
0
]))
image_size_data
=
np
.
array
(
list
(
feed_data
[:,
1
]))
image_tensor
=
predictor
.
get_input_handle
(
input_names
[
0
])
image_tensor
.
reshape
(
image_data
.
shape
)
image_tensor
.
copy_from_cpu
(
image_data
.
copy
())
image_size_tensor
=
predictor
.
get_input_handle
(
input_names
[
1
])
image_size_tensor
.
reshape
(
image_size_data
.
shape
)
image_size_tensor
.
copy_from_cpu
(
image_size_data
.
copy
())
predictor
.
run
()
output_names
=
predictor
.
get_output_names
()
output_handle
=
predictor
.
get_output_handle
(
output_names
[
0
])
output
=
postprocess
(
paths
=
paths
,
images
=
images
,
data_out
=
data_out
,
data_out
=
output_handle
,
score_thresh
=
score_thresh
,
label_names
=
self
.
label_names
,
output_dir
=
output_dir
,
...
...
@@ -219,11 +243,7 @@ class YOLOv3MobileNetV1Coco2017(hub.Module):
res
.
extend
(
output
)
return
res
def
save_inference_model
(
self
,
dirname
,
model_filename
=
None
,
params_filename
=
None
,
combined
=
True
):
def
save_inference_model
(
self
,
dirname
,
model_filename
=
None
,
params_filename
=
None
,
combined
=
True
):
if
combined
:
model_filename
=
"__model__"
if
not
model_filename
else
model_filename
params_filename
=
"__params__"
if
not
params_filename
else
params_filename
...
...
@@ -261,12 +281,9 @@ class YOLOv3MobileNetV1Coco2017(hub.Module):
prog
=
'hub run {}'
.
format
(
self
.
name
),
usage
=
'%(prog)s'
,
add_help
=
True
)
self
.
arg_input_group
=
self
.
parser
.
add_argument_group
(
title
=
"Input options"
,
description
=
"Input data. Required"
)
self
.
arg_input_group
=
self
.
parser
.
add_argument_group
(
title
=
"Input options"
,
description
=
"Input data. Required"
)
self
.
arg_config_group
=
self
.
parser
.
add_argument_group
(
title
=
"Config options"
,
description
=
"Run configuration for controlling module behavior, not required."
)
title
=
"Config options"
,
description
=
"Run configuration for controlling module behavior, not required."
)
self
.
add_module_config_arg
()
self
.
add_module_input_arg
()
args
=
self
.
parser
.
parse_args
(
argvs
)
...
...
@@ -276,7 +293,8 @@ class YOLOv3MobileNetV1Coco2017(hub.Module):
use_gpu
=
args
.
use_gpu
,
output_dir
=
args
.
output_dir
,
visualization
=
args
.
visualization
,
score_thresh
=
args
.
score_thresh
)
score_thresh
=
args
.
score_thresh
,
use_device
=
args
.
use_device
)
return
results
def
add_module_config_arg
(
self
):
...
...
@@ -284,34 +302,21 @@ class YOLOv3MobileNetV1Coco2017(hub.Module):
Add the command config options.
"""
self
.
arg_config_group
.
add_argument
(
'--use_gpu'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether use GPU or not"
)
'--use_gpu'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether use GPU or not"
)
self
.
arg_config_group
.
add_argument
(
'--output_dir'
,
type
=
str
,
default
=
'detection_result'
,
help
=
"The directory to save output images."
)
'--output_dir'
,
type
=
str
,
default
=
'detection_result'
,
help
=
"The directory to save output images."
)
self
.
arg_config_group
.
add_argument
(
'--visualization'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether to save output as images."
)
'--visualization'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether to save output as images."
)
self
.
arg_config_group
.
add_argument
(
'--use_device'
,
choices
=
[
"cpu"
,
"gpu"
,
"xpu"
,
"npu"
],
help
=
"use cpu, gpu, xpu or npu. overwrites use_gpu flag."
)
def
add_module_input_arg
(
self
):
"""
Add the command input options.
"""
self
.
arg_input_group
.
add_argument
(
'--input_path'
,
type
=
str
,
help
=
"path to image."
)
self
.
arg_input_group
.
add_argument
(
'--batch_size'
,
type
=
ast
.
literal_eval
,
default
=
1
,
help
=
"batch size."
)
self
.
arg_input_group
.
add_argument
(
'--input_path'
,
type
=
str
,
help
=
"path to image."
)
self
.
arg_input_group
.
add_argument
(
'--batch_size'
,
type
=
ast
.
literal_eval
,
default
=
1
,
help
=
"batch size."
)
self
.
arg_input_group
.
add_argument
(
'--score_thresh'
,
type
=
ast
.
literal_eval
,
default
=
0.5
,
help
=
"threshold for object detecion."
)
'--score_thresh'
,
type
=
ast
.
literal_eval
,
default
=
0.5
,
help
=
"threshold for object detecion."
)
modules/image/object_detection/yolov3_mobilenet_v1_coco2017/processor.py
浏览文件 @
0d60bf5a
...
...
@@ -50,21 +50,15 @@ def draw_bounding_box_on_image(image_path, data_list, save_dir):
image
=
Image
.
open
(
image_path
)
draw
=
ImageDraw
.
Draw
(
image
)
for
data
in
data_list
:
left
,
right
,
top
,
bottom
=
data
[
'left'
],
data
[
'right'
],
data
[
'top'
],
data
[
'bottom'
]
left
,
right
,
top
,
bottom
=
data
[
'left'
],
data
[
'right'
],
data
[
'top'
],
data
[
'bottom'
]
# draw bbox
draw
.
line
([(
left
,
top
),
(
left
,
bottom
),
(
right
,
bottom
),
(
right
,
top
),
(
left
,
top
)],
width
=
2
,
fill
=
'red'
)
draw
.
line
([(
left
,
top
),
(
left
,
bottom
),
(
right
,
bottom
),
(
right
,
top
),
(
left
,
top
)],
width
=
2
,
fill
=
'red'
)
# draw label
if
image
.
mode
==
'RGB'
:
text
=
data
[
'label'
]
+
": %.2f%%"
%
(
100
*
data
[
'confidence'
])
textsize_width
,
textsize_height
=
draw
.
textsize
(
text
=
text
)
draw
.
rectangle
(
xy
=
(
left
,
top
-
(
textsize_height
+
5
),
left
+
textsize_width
+
10
,
top
),
fill
=
(
255
,
255
,
255
))
xy
=
(
left
,
top
-
(
textsize_height
+
5
),
left
+
textsize_width
+
10
,
top
),
fill
=
(
255
,
255
,
255
))
draw
.
text
(
xy
=
(
left
,
top
-
15
),
text
=
text
,
fill
=
(
0
,
0
,
0
))
save_name
=
get_save_image_name
(
image
,
save_dir
,
image_path
)
...
...
@@ -92,14 +86,7 @@ def load_label_info(file_path):
return
label_names
def
postprocess
(
paths
,
images
,
data_out
,
score_thresh
,
label_names
,
output_dir
,
handle_id
,
visualization
=
True
):
def
postprocess
(
paths
,
images
,
data_out
,
score_thresh
,
label_names
,
output_dir
,
handle_id
,
visualization
=
True
):
"""
postprocess the lod_tensor produced by fluid.Executor.run
...
...
@@ -108,8 +95,6 @@ def postprocess(paths,
images (list(numpy.ndarray)): images data, shape of each is [H, W, C]
data_out (lod_tensor): data output of predictor.
batch_size (int): batch size.
use_gpu (bool): Whether to use gpu.
output_dir (str): The path to store output images.
visualization (bool): Whether to save image or not.
score_thresh (float): the low limit of bounding box.
label_names (list[str]): label names.
...
...
@@ -126,9 +111,8 @@ def postprocess(paths,
confidence (float): The confidence of detection result.
save_path (str): The path to save output images.
"""
lod_tensor
=
data_out
[
0
]
lod
=
lod_tensor
.
lod
[
0
]
results
=
lod_tensor
.
as_ndarray
()
results
=
data_out
.
copy_to_cpu
()
lod
=
data_out
.
lod
()[
0
]
check_dir
(
output_dir
)
...
...
@@ -157,9 +141,7 @@ def postprocess(paths,
org_img
=
org_img
.
astype
(
np
.
uint8
)
org_img
=
Image
.
fromarray
(
org_img
[:,
:,
::
-
1
])
if
visualization
:
org_img_path
=
get_save_image_name
(
org_img
,
output_dir
,
'image_numpy_{}'
.
format
(
(
handle_id
+
index
)))
org_img_path
=
get_save_image_name
(
org_img
,
output_dir
,
'image_numpy_{}'
.
format
((
handle_id
+
index
)))
org_img
.
save
(
org_img_path
)
org_img_height
=
org_img
.
height
org_img_width
=
org_img
.
width
...
...
@@ -175,13 +157,11 @@ def postprocess(paths,
dt
=
{}
dt
[
'label'
]
=
label_names
[
category_id
]
dt
[
'confidence'
]
=
float
(
confidence
)
dt
[
'left'
],
dt
[
'top'
],
dt
[
'right'
],
dt
[
'bottom'
]
=
clip_bbox
(
bbox
,
org_img_width
,
org_img_height
)
dt
[
'left'
],
dt
[
'top'
],
dt
[
'right'
],
dt
[
'bottom'
]
=
clip_bbox
(
bbox
,
org_img_width
,
org_img_height
)
output_i
[
'data'
].
append
(
dt
)
output
.
append
(
output_i
)
if
visualization
:
output_i
[
'save_path'
]
=
draw_bounding_box_on_image
(
org_img_path
,
output_i
[
'data'
],
output_dir
)
output_i
[
'save_path'
]
=
draw_bounding_box_on_image
(
org_img_path
,
output_i
[
'data'
],
output_dir
)
return
output
modules/image/object_detection/yolov3_resnet50_vd_coco2017/module.py
浏览文件 @
0d60bf5a
...
...
@@ -9,7 +9,10 @@ from functools import partial
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddlehub
as
hub
from
paddle.fluid.core
import
PaddleTensor
,
AnalysisConfig
,
create_paddle_predictor
from
paddle.inference
import
Config
from
paddle.inference
import
create_predictor
from
paddlehub.module.module
import
moduleinfo
,
runnable
,
serving
from
paddlehub.common.paddle_helper
import
add_vars_prefix
...
...
@@ -23,39 +26,63 @@ from yolov3_resnet50_vd_coco2017.yolo_head import MultiClassNMS, YOLOv3Head
name
=
"yolov3_resnet50_vd_coco2017"
,
version
=
"1.0.2"
,
type
=
"CV/object_detection"
,
summary
=
"Baidu's YOLOv3 model for object detection with backbone ResNet50, trained with dataset coco2017."
,
summary
=
"Baidu's YOLOv3 model for object detection with backbone ResNet50, trained with dataset coco2017."
,
author
=
"paddlepaddle"
,
author_email
=
"paddle-dev@baidu.com"
)
class
YOLOv3ResNet50Coco2017
(
hub
.
Module
):
def
_initialize
(
self
):
self
.
default_pretrained_model_path
=
os
.
path
.
join
(
self
.
directory
,
"yolov3_resnet50_model"
)
self
.
label_names
=
load_label_info
(
os
.
path
.
join
(
self
.
directory
,
"label_file.txt"
))
self
.
default_pretrained_model_path
=
os
.
path
.
join
(
self
.
directory
,
"yolov3_resnet50_model"
)
self
.
label_names
=
load_label_info
(
os
.
path
.
join
(
self
.
directory
,
"label_file.txt"
))
self
.
_set_config
()
def
_get_device_id
(
self
,
places
):
try
:
places
=
os
.
environ
[
places
]
id
=
int
(
places
)
except
:
id
=
-
1
return
id
def
_set_config
(
self
):
"""
predictor config setting.
"""
cpu_config
=
AnalysisConfig
(
self
.
default_pretrained_model_path
)
# create default cpu predictor
cpu_config
=
Config
(
self
.
default_pretrained_model_path
)
cpu_config
.
disable_glog_info
()
cpu_config
.
disable_gpu
()
cpu_config
.
switch_ir_optim
(
False
)
self
.
cpu_predictor
=
create_p
addle_p
redictor
(
cpu_config
)
self
.
cpu_predictor
=
create_predictor
(
cpu_config
)
try
:
_places
=
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
int
(
_places
[
0
])
use_gpu
=
True
except
:
use_gpu
=
False
if
use_gpu
:
gpu_config
=
AnalysisConfig
(
self
.
default_pretrained_model_path
)
# create predictors using various types of devices
# npu
npu_id
=
self
.
_get_device_id
(
"FLAGS_selected_npus"
)
if
npu_id
!=
-
1
:
# use npu
npu_config
=
Config
(
self
.
default_pretrained_model_path
)
npu_config
.
disable_glog_info
()
npu_config
.
enable_npu
(
device_id
=
npu_id
)
self
.
npu_predictor
=
create_predictor
(
npu_config
)
# gpu
gpu_id
=
self
.
_get_device_id
(
"CUDA_VISIBLE_DEVICES"
)
if
gpu_id
!=
-
1
:
# use gpu
gpu_config
=
Config
(
self
.
default_pretrained_model_path
)
gpu_config
.
disable_glog_info
()
gpu_config
.
enable_use_gpu
(
memory_pool_init_size_mb
=
500
,
device_id
=
0
)
self
.
gpu_predictor
=
create_paddle_predictor
(
gpu_config
)
gpu_config
.
enable_use_gpu
(
memory_pool_init_size_mb
=
500
,
device_id
=
gpu_id
)
self
.
gpu_predictor
=
create_predictor
(
gpu_config
)
# xpu
xpu_id
=
self
.
_get_device_id
(
"XPU_VISIBLE_DEVICES"
)
if
xpu_id
!=
-
1
:
# use xpu
xpu_config
=
Config
(
self
.
default_pretrained_model_path
)
xpu_config
.
disable_glog_info
()
xpu_config
.
enable_xpu
(
100
)
self
.
xpu_predictor
=
create_predictor
(
xpu_config
)
def
context
(
self
,
trainable
=
True
,
pretrained
=
True
,
get_prediction
=
False
):
"""
...
...
@@ -76,8 +103,7 @@ class YOLOv3ResNet50Coco2017(hub.Module):
with
fluid
.
program_guard
(
context_prog
,
startup_program
):
with
fluid
.
unique_name
.
guard
():
# image
image
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
3
,
608
,
608
],
dtype
=
'float32'
)
image
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
3
,
608
,
608
],
dtype
=
'float32'
)
# backbone
backbone
=
ResNet
(
norm_type
=
'sync_bn'
,
...
...
@@ -91,13 +117,11 @@ class YOLOv3ResNet50Coco2017(hub.Module):
# body_feats
body_feats
=
backbone
(
image
)
# im_size
im_size
=
fluid
.
layers
.
data
(
name
=
'im_size'
,
shape
=
[
2
],
dtype
=
'int32'
)
im_size
=
fluid
.
layers
.
data
(
name
=
'im_size'
,
shape
=
[
2
],
dtype
=
'int32'
)
# yolo_head
yolo_head
=
YOLOv3Head
(
num_classes
=
80
)
# head_features
head_features
,
body_features
=
yolo_head
.
_get_outputs
(
body_feats
,
is_train
=
trainable
)
head_features
,
body_features
=
yolo_head
.
_get_outputs
(
body_feats
,
is_train
=
trainable
)
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
...
...
@@ -106,35 +130,24 @@ class YOLOv3ResNet50Coco2017(hub.Module):
# var_prefix
var_prefix
=
'@HUB_{}@'
.
format
(
self
.
name
)
# name of inputs
inputs
=
{
'image'
:
var_prefix
+
image
.
name
,
'im_size'
:
var_prefix
+
im_size
.
name
}
inputs
=
{
'image'
:
var_prefix
+
image
.
name
,
'im_size'
:
var_prefix
+
im_size
.
name
}
# name of outputs
if
get_prediction
:
bbox_out
=
yolo_head
.
get_prediction
(
head_features
,
im_size
)
outputs
=
{
'bbox_out'
:
[
var_prefix
+
bbox_out
.
name
]}
else
:
outputs
=
{
'head_features'
:
[
var_prefix
+
var
.
name
for
var
in
head_features
],
'body_features'
:
[
var_prefix
+
var
.
name
for
var
in
body_features
]
'head_features'
:
[
var_prefix
+
var
.
name
for
var
in
head_features
],
'body_features'
:
[
var_prefix
+
var
.
name
for
var
in
body_features
]
}
# add_vars_prefix
add_vars_prefix
(
context_prog
,
var_prefix
)
add_vars_prefix
(
fluid
.
default_startup_program
(),
var_prefix
)
# inputs
inputs
=
{
key
:
context_prog
.
global_block
().
vars
[
value
]
for
key
,
value
in
inputs
.
items
()
}
inputs
=
{
key
:
context_prog
.
global_block
().
vars
[
value
]
for
key
,
value
in
inputs
.
items
()}
# outputs
outputs
=
{
key
:
[
context_prog
.
global_block
().
vars
[
varname
]
for
varname
in
value
]
key
:
[
context_prog
.
global_block
().
vars
[
varname
]
for
varname
in
value
]
for
key
,
value
in
outputs
.
items
()
}
# trainable
...
...
@@ -144,14 +157,9 @@ class YOLOv3ResNet50Coco2017(hub.Module):
if
pretrained
:
def
_if_exist
(
var
):
return
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
default_pretrained_model_path
,
var
.
name
))
fluid
.
io
.
load_vars
(
exe
,
self
.
default_pretrained_model_path
,
predicate
=
_if_exist
)
return
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
default_pretrained_model_path
,
var
.
name
))
fluid
.
io
.
load_vars
(
exe
,
self
.
default_pretrained_model_path
,
predicate
=
_if_exist
)
else
:
exe
.
run
(
startup_program
)
...
...
@@ -164,7 +172,8 @@ class YOLOv3ResNet50Coco2017(hub.Module):
use_gpu
=
False
,
output_dir
=
'detection_result'
,
score_thresh
=
0.5
,
visualization
=
True
):
visualization
=
True
,
use_device
=
None
):
"""API of Object Detection.
Args:
...
...
@@ -175,6 +184,7 @@ class YOLOv3ResNet50Coco2017(hub.Module):
output_dir (str): The path to store output images.
visualization (bool): Whether to save image or not.
score_thresh (float): threshold for object detecion.
use_device (str): use cpu, gpu, xpu or npu, overwrites use_gpu flag.
Returns:
res (list[dict]): The result of coco2017 detecion. keys include 'data', 'save_path', the corresponding value is:
...
...
@@ -187,14 +197,24 @@ class YOLOv3ResNet50Coco2017(hub.Module):
confidence (float): The confidence of detection result.
save_path (str, optional): The path to save output images.
"""
if
use_gpu
:
try
:
_places
=
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
int
(
_places
[
0
])
except
:
raise
RuntimeError
(
"Attempt to use GPU for prediction, but environment variable CUDA_VISIBLE_DEVICES was not set correctly."
)
# real predictor to use
if
use_device
is
not
None
:
if
use_device
==
"cpu"
:
predictor
=
self
.
cpu_predictor
elif
use_device
==
"xpu"
:
predictor
=
self
.
xpu_predictor
elif
use_device
==
"npu"
:
predictor
=
self
.
npu_predictor
elif
use_device
==
"gpu"
:
predictor
=
self
.
gpu_predictor
else
:
raise
Exception
(
"Unsupported device: "
+
use_device
)
else
:
# use_device is not set, therefore follow use_gpu
if
use_gpu
:
predictor
=
self
.
gpu_predictor
else
:
predictor
=
self
.
cpu_predictor
paths
=
paths
if
paths
else
list
()
data_reader
=
partial
(
reader
,
paths
,
images
)
...
...
@@ -202,19 +222,27 @@ class YOLOv3ResNet50Coco2017(hub.Module):
res
=
[]
for
iter_id
,
feed_data
in
enumerate
(
batch_reader
()):
feed_data
=
np
.
array
(
feed_data
)
image_tensor
=
PaddleTensor
(
np
.
array
(
list
(
feed_data
[:,
0
])))
im_size_tensor
=
PaddleTensor
(
np
.
array
(
list
(
feed_data
[:,
1
])))
if
use_gpu
:
data_out
=
self
.
gpu_predictor
.
run
(
[
image_tensor
,
im_size_tensor
])
else
:
data_out
=
self
.
cpu_predictor
.
run
(
[
image_tensor
,
im_size_tensor
])
input_names
=
predictor
.
get_input_names
()
image_data
=
np
.
array
(
list
(
feed_data
[:,
0
]))
image_size_data
=
np
.
array
(
list
(
feed_data
[:,
1
]))
image_tensor
=
predictor
.
get_input_handle
(
input_names
[
0
])
image_tensor
.
reshape
(
image_data
.
shape
)
image_tensor
.
copy_from_cpu
(
image_data
.
copy
())
image_size_tensor
=
predictor
.
get_input_handle
(
input_names
[
1
])
image_size_tensor
.
reshape
(
image_size_data
.
shape
)
image_size_tensor
.
copy_from_cpu
(
image_size_data
.
copy
())
predictor
.
run
()
output_names
=
predictor
.
get_output_names
()
output_handle
=
predictor
.
get_output_handle
(
output_names
[
0
])
output
=
postprocess
(
paths
=
paths
,
images
=
images
,
data_out
=
data_out
,
data_out
=
output_handle
,
score_thresh
=
score_thresh
,
label_names
=
self
.
label_names
,
output_dir
=
output_dir
,
...
...
@@ -223,11 +251,7 @@ class YOLOv3ResNet50Coco2017(hub.Module):
res
.
extend
(
output
)
return
res
def
save_inference_model
(
self
,
dirname
,
model_filename
=
None
,
params_filename
=
None
,
combined
=
True
):
def
save_inference_model
(
self
,
dirname
,
model_filename
=
None
,
params_filename
=
None
,
combined
=
True
):
if
combined
:
model_filename
=
"__model__"
if
not
model_filename
else
model_filename
params_filename
=
"__params__"
if
not
params_filename
else
params_filename
...
...
@@ -265,12 +289,9 @@ class YOLOv3ResNet50Coco2017(hub.Module):
prog
=
'hub run {}'
.
format
(
self
.
name
),
usage
=
'%(prog)s'
,
add_help
=
True
)
self
.
arg_input_group
=
self
.
parser
.
add_argument_group
(
title
=
"Input options"
,
description
=
"Input data. Required"
)
self
.
arg_input_group
=
self
.
parser
.
add_argument_group
(
title
=
"Input options"
,
description
=
"Input data. Required"
)
self
.
arg_config_group
=
self
.
parser
.
add_argument_group
(
title
=
"Config options"
,
description
=
"Run configuration for controlling module behavior, not required."
)
title
=
"Config options"
,
description
=
"Run configuration for controlling module behavior, not required."
)
self
.
add_module_config_arg
()
self
.
add_module_input_arg
()
args
=
self
.
parser
.
parse_args
(
argvs
)
...
...
@@ -280,7 +301,8 @@ class YOLOv3ResNet50Coco2017(hub.Module):
use_gpu
=
args
.
use_gpu
,
output_dir
=
args
.
output_dir
,
visualization
=
args
.
visualization
,
score_thresh
=
args
.
score_thresh
)
score_thresh
=
args
.
score_thresh
,
use_device
=
args
.
use_device
)
return
results
def
add_module_config_arg
(
self
):
...
...
@@ -288,34 +310,21 @@ class YOLOv3ResNet50Coco2017(hub.Module):
Add the command config options.
"""
self
.
arg_config_group
.
add_argument
(
'--use_gpu'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether use GPU or not"
)
'--use_gpu'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether use GPU or not"
)
self
.
arg_config_group
.
add_argument
(
'--output_dir'
,
type
=
str
,
default
=
'detection_result'
,
help
=
"The directory to save output images."
)
'--output_dir'
,
type
=
str
,
default
=
'detection_result'
,
help
=
"The directory to save output images."
)
self
.
arg_config_group
.
add_argument
(
'--visualization'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether to save output as images."
)
'--visualization'
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"whether to save output as images."
)
self
.
arg_config_group
.
add_argument
(
'--use_device'
,
choices
=
[
"cpu"
,
"gpu"
,
"xpu"
,
"npu"
],
help
=
"use cpu, gpu, xpu or npu. overwrites use_gpu flag."
)
def
add_module_input_arg
(
self
):
"""
Add the command input options.
"""
self
.
arg_input_group
.
add_argument
(
'--input_path'
,
type
=
str
,
help
=
"path to image."
)
self
.
arg_input_group
.
add_argument
(
'--batch_size'
,
type
=
ast
.
literal_eval
,
default
=
1
,
help
=
"batch size."
)
self
.
arg_input_group
.
add_argument
(
'--input_path'
,
type
=
str
,
help
=
"path to image."
)
self
.
arg_input_group
.
add_argument
(
'--batch_size'
,
type
=
ast
.
literal_eval
,
default
=
1
,
help
=
"batch size."
)
self
.
arg_input_group
.
add_argument
(
'--score_thresh'
,
type
=
ast
.
literal_eval
,
default
=
0.5
,
help
=
"threshold for object detecion."
)
'--score_thresh'
,
type
=
ast
.
literal_eval
,
default
=
0.5
,
help
=
"threshold for object detecion."
)
modules/image/object_detection/yolov3_resnet50_vd_coco2017/processor.py
浏览文件 @
0d60bf5a
...
...
@@ -50,21 +50,15 @@ def draw_bounding_box_on_image(image_path, data_list, save_dir):
image
=
Image
.
open
(
image_path
)
draw
=
ImageDraw
.
Draw
(
image
)
for
data
in
data_list
:
left
,
right
,
top
,
bottom
=
data
[
'left'
],
data
[
'right'
],
data
[
'top'
],
data
[
'bottom'
]
left
,
right
,
top
,
bottom
=
data
[
'left'
],
data
[
'right'
],
data
[
'top'
],
data
[
'bottom'
]
# draw bbox
draw
.
line
([(
left
,
top
),
(
left
,
bottom
),
(
right
,
bottom
),
(
right
,
top
),
(
left
,
top
)],
width
=
2
,
fill
=
'red'
)
draw
.
line
([(
left
,
top
),
(
left
,
bottom
),
(
right
,
bottom
),
(
right
,
top
),
(
left
,
top
)],
width
=
2
,
fill
=
'red'
)
# draw label
if
image
.
mode
==
'RGB'
:
text
=
data
[
'label'
]
+
": %.2f%%"
%
(
100
*
data
[
'confidence'
])
textsize_width
,
textsize_height
=
draw
.
textsize
(
text
=
text
)
draw
.
rectangle
(
xy
=
(
left
,
top
-
(
textsize_height
+
5
),
left
+
textsize_width
+
10
,
top
),
fill
=
(
255
,
255
,
255
))
xy
=
(
left
,
top
-
(
textsize_height
+
5
),
left
+
textsize_width
+
10
,
top
),
fill
=
(
255
,
255
,
255
))
draw
.
text
(
xy
=
(
left
,
top
-
15
),
text
=
text
,
fill
=
(
0
,
0
,
0
))
save_name
=
get_save_image_name
(
image
,
save_dir
,
image_path
)
...
...
@@ -92,14 +86,7 @@ def load_label_info(file_path):
return
label_names
def
postprocess
(
paths
,
images
,
data_out
,
score_thresh
,
label_names
,
output_dir
,
handle_id
,
visualization
=
True
):
def
postprocess
(
paths
,
images
,
data_out
,
score_thresh
,
label_names
,
output_dir
,
handle_id
,
visualization
=
True
):
"""
postprocess the lod_tensor produced by fluid.Executor.run
...
...
@@ -107,8 +94,6 @@ def postprocess(paths,
paths (list[str]): The paths of images.
images (list(numpy.ndarray)): images data, shape of each is [H, W, C]
data_out (lod_tensor): data output of predictor.
batch_size (int): batch size.
use_gpu (bool): Whether to use gpu.
output_dir (str): The path to store output images.
visualization (bool): Whether to save image or not.
score_thresh (float): the low limit of bounding box.
...
...
@@ -126,9 +111,8 @@ def postprocess(paths,
confidence (float): The confidence of detection result.
save_path (str): The path to save output images.
"""
lod_tensor
=
data_out
[
0
]
lod
=
lod_tensor
.
lod
[
0
]
results
=
lod_tensor
.
as_ndarray
()
results
=
data_out
.
copy_to_cpu
()
lod
=
data_out
.
lod
()[
0
]
check_dir
(
output_dir
)
...
...
@@ -146,7 +130,6 @@ def postprocess(paths,
else
:
unhandled_paths_num
=
0
output
=
list
()
for
index
in
range
(
len
(
lod
)
-
1
):
output_i
=
{
'data'
:
[]}
...
...
@@ -158,9 +141,7 @@ def postprocess(paths,
org_img
=
org_img
.
astype
(
np
.
uint8
)
org_img
=
Image
.
fromarray
(
org_img
[:,
:,
::
-
1
])
if
visualization
:
org_img_path
=
get_save_image_name
(
org_img
,
output_dir
,
'image_numpy_{}'
.
format
(
(
handle_id
+
index
)))
org_img_path
=
get_save_image_name
(
org_img
,
output_dir
,
'image_numpy_{}'
.
format
((
handle_id
+
index
)))
org_img
.
save
(
org_img_path
)
org_img_height
=
org_img
.
height
org_img_width
=
org_img
.
width
...
...
@@ -176,13 +157,11 @@ def postprocess(paths,
dt
=
{}
dt
[
'label'
]
=
label_names
[
category_id
]
dt
[
'confidence'
]
=
float
(
confidence
)
dt
[
'left'
],
dt
[
'top'
],
dt
[
'right'
],
dt
[
'bottom'
]
=
clip_bbox
(
bbox
,
org_img_width
,
org_img_height
)
dt
[
'left'
],
dt
[
'top'
],
dt
[
'right'
],
dt
[
'bottom'
]
=
clip_bbox
(
bbox
,
org_img_width
,
org_img_height
)
output_i
[
'data'
].
append
(
dt
)
output
.
append
(
output_i
)
if
visualization
:
output_i
[
'save_path'
]
=
draw_bounding_box_on_image
(
org_img_path
,
output_i
[
'data'
],
output_dir
)
output_i
[
'save_path'
]
=
draw_bounding_box_on_image
(
org_img_path
,
output_i
[
'data'
],
output_dir
)
return
output
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录