Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
9b40ee0e
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9b40ee0e
编写于
9月 07, 2022
作者:
H
HydrogenSulfate
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add shitu whl
上级
6b218caf
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
159 addition
and
29 deletion
+159
-29
paddleclas.py
paddleclas.py
+159
-29
未找到文件。
paddleclas.py
浏览文件 @
9b40ee0e
...
...
@@ -32,6 +32,7 @@ from .ppcls.arch import backbone
from
.ppcls.utils
import
logger
from
.deploy.python.predict_cls
import
ClsPredictor
from
.deploy.python.predict_system
import
SystemPredictor
from
.deploy.utils.get_image_list
import
get_image_list
from
.deploy.utils
import
config
...
...
@@ -194,6 +195,14 @@ PULC_MODELS = [
"textline_orientation"
,
"traffic_sign"
,
"vehicle_attribute"
]
SHITU_MODEL_BASE_DOWNLOAD_URL
=
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/{}_infer.tar"
SHITU_MODELS
=
[
# "picodet_PPLCNet_x2_5_mainbody_lite_v1.0", # ShiTuV1(V2)_mainbody_det
# "general_PPLCNet_x2_5_lite_v1.0" # ShiTuV1_general_rec
# "PP-ShiTuV2/general_PPLCNetV2_base_pretrained_v1.0", # ShiTuV2_general_rec TODO(hesensen): add lite model
"PP-ShiTuV2"
]
class
ImageTypeError
(
Exception
):
"""ImageTypeError.
...
...
@@ -213,12 +222,24 @@ class InputModelError(Exception):
def
init_config
(
model_type
,
model_name
,
inference_model_dir
,
**
kwargs
):
cfg_path
=
f
"deploy/configs/PULC/
{
model_name
}
/inference_
{
model_name
}
.yaml"
if
model_type
==
"pulc"
else
"deploy/configs/inference_cls.yaml"
if
model_type
==
"pulc"
:
cfg_path
=
f
"deploy/configs/PULC/
{
model_name
}
/inference_
{
model_name
}
.yaml"
elif
model_type
==
"shitu"
:
cfg_path
=
"deploy/configs/inference_general.yaml"
else
:
cfg_path
=
"deploy/configs/inference_cls.yaml"
__dir__
=
os
.
path
.
dirname
(
__file__
)
cfg_path
=
os
.
path
.
join
(
__dir__
,
cfg_path
)
cfg
=
config
.
get_config
(
cfg_path
,
show
=
False
)
if
cfg
.
Global
.
get
(
"inference_model_dir"
):
cfg
.
Global
.
inference_model_dir
=
inference_model_dir
else
:
cfg
.
Global
.
rec_inference_model_dir
=
os
.
path
.
join
(
inference_model_dir
,
"PP-ShiTuV2/general_PPLCNetV2_base_pretrained_v1.0"
)
cfg
.
Global
.
det_inference_model_dir
=
os
.
path
.
join
(
inference_model_dir
,
"picodet_PPLCNet_x2_5_mainbody_lite_v1.0"
)
if
"batch_size"
in
kwargs
and
kwargs
[
"batch_size"
]:
cfg
.
Global
.
batch_size
=
kwargs
[
"batch_size"
]
...
...
@@ -232,6 +253,10 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs):
if
"infer_imgs"
in
kwargs
and
kwargs
[
"infer_imgs"
]:
cfg
.
Global
.
infer_imgs
=
kwargs
[
"infer_imgs"
]
if
"index_dir"
in
kwargs
and
kwargs
[
"index_dir"
]:
cfg
.
IndexProcess
.
index_dir
=
kwargs
[
"index_dir"
]
if
"data_file"
in
kwargs
and
kwargs
[
"data_file"
]:
cfg
.
IndexProcess
.
data_file
=
kwargs
[
"data_file"
]
if
"enable_mkldnn"
in
kwargs
and
kwargs
[
"enable_mkldnn"
]:
cfg
.
Global
.
enable_mkldnn
=
kwargs
[
"enable_mkldnn"
]
if
"cpu_num_threads"
in
kwargs
and
kwargs
[
"cpu_num_threads"
]:
...
...
@@ -253,6 +278,7 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs):
if
"thresh"
in
kwargs
and
kwargs
[
"thresh"
]
and
"ThreshOutput"
in
cfg
.
PostProcess
:
cfg
.
PostProcess
.
ThreshOutput
.
thresh
=
kwargs
[
"thresh"
]
if
cfg
.
get
(
"PostProcess"
):
if
"Topk"
in
cfg
.
PostProcess
:
if
"topk"
in
kwargs
and
kwargs
[
"topk"
]:
cfg
.
PostProcess
.
Topk
.
topk
=
kwargs
[
"topk"
]
...
...
@@ -295,6 +321,13 @@ def args_cfg():
type
=
str
,
help
=
"The directory of model files. Valid when model_name not specifed."
)
parser
.
add_argument
(
"--index_dir"
,
type
=
str
,
required
=
False
,
help
=
"The index directory path."
)
parser
.
add_argument
(
"--data_file"
,
type
=
str
,
required
=
False
,
help
=
"The label file path."
)
parser
.
add_argument
(
"--use_gpu"
,
type
=
str2bool
,
help
=
"Whether use GPU."
)
parser
.
add_argument
(
"--gpu_mem"
,
...
...
@@ -347,6 +380,7 @@ def print_info():
"""
imn_table
=
PrettyTable
([
"IMN Model Series"
,
"Model Name"
])
pulc_table
=
PrettyTable
([
"PULC Models"
])
shitu_table
=
PrettyTable
([
"PP-ShiTu Models"
])
try
:
sz
=
os
.
get_terminal_size
()
total_width
=
sz
.
columns
...
...
@@ -365,11 +399,16 @@ def print_info():
textwrap
.
fill
(
" "
.
join
(
PULC_MODELS
),
width
=
total_width
).
center
(
table_width
-
4
)
])
shitu_table
.
add_row
([
textwrap
.
fill
(
" "
.
join
(
SHITU_MODELS
),
width
=
total_width
).
center
(
table_width
-
4
)
])
print
(
"{}"
.
format
(
"-"
*
table_width
))
print
(
"Models supported by PaddleClas"
.
center
(
table_width
))
print
(
imn_table
)
print
(
pulc_table
)
print
(
shitu_table
)
print
(
"Powered by PaddlePaddle!"
.
rjust
(
table_width
))
print
(
"{}"
.
format
(
"-"
*
table_width
))
...
...
@@ -425,6 +464,10 @@ def check_model_file(model_type, model_name):
storage_directory
=
partial
(
os
.
path
.
join
,
BASE_INFERENCE_MODEL_DIR
,
"PULC"
,
model_name
)
url
=
PULC_MODEL_BASE_DOWNLOAD_URL
.
format
(
model_name
)
elif
model_type
==
"shitu"
:
storage_directory
=
partial
(
os
.
path
.
join
,
BASE_INFERENCE_MODEL_DIR
,
"PP-ShiTu"
,
model_name
)
url
=
SHITU_MODEL_BASE_DOWNLOAD_URL
.
format
(
model_name
)
else
:
storage_directory
=
partial
(
os
.
path
.
join
,
BASE_INFERENCE_MODEL_DIR
,
"IMN"
,
model_name
)
...
...
@@ -485,8 +528,10 @@ class PaddleClas(object):
model_name
,
inference_model_dir
)
self
.
_config
=
init_config
(
self
.
model_type
,
model_name
,
inference_model_dir
,
**
kwargs
)
self
.
cls_predictor
=
ClsPredictor
(
self
.
_config
)
if
self
.
model_type
==
"shitu"
:
self
.
predictor
=
SystemPredictor
(
self
.
_config
)
else
:
self
.
predictor
=
ClsPredictor
(
self
.
_config
)
def
get_config
(
self
):
"""Get the config.
...
...
@@ -498,6 +543,7 @@ class PaddleClas(object):
"""
all_imn_model_names
=
get_imn_model_names
()
all_pulc_model_names
=
PULC_MODELS
all_shitu_model_names
=
SHITU_MODELS
if
model_name
:
if
model_name
in
all_imn_model_names
:
...
...
@@ -506,6 +552,15 @@ class PaddleClas(object):
elif
model_name
in
all_pulc_model_names
:
inference_model_dir
=
check_model_file
(
"pulc"
,
model_name
)
return
"pulc"
,
inference_model_dir
elif
model_name
in
all_shitu_model_names
:
inference_model_dir
=
check_model_file
(
"shitu"
,
"PP-ShiTuV2/general_PPLCNetV2_base_pretrained_v1.0"
)
inference_model_dir
=
check_model_file
(
"shitu"
,
"picodet_PPLCNet_x2_5_mainbody_lite_v1.0"
)
inference_model_dir
=
os
.
path
.
abspath
(
os
.
path
.
dirname
(
inference_model_dir
))
return
"shitu"
,
inference_model_dir
else
:
similar_imn_names
=
similar_model_names
(
model_name
,
all_imn_model_names
)
...
...
@@ -526,11 +581,12 @@ class PaddleClas(object):
raise
InputModelError
(
err
)
return
"custom"
,
inference_model_dir
else
:
err
=
f
"Please specify the model name supported by PaddleClas or directory contained model files(inference.pdmodel, inference.pdiparams)."
err
=
"Please specify the model name supported by PaddleClas or directory contained model files(inference.pdmodel, inference.pdiparams)."
raise
InputModelError
(
err
)
return
None
def
predict
(
self
,
input_data
:
Union
[
str
,
np
.
array
],
def
predict_cls
(
self
,
input_data
:
Union
[
str
,
np
.
array
],
print_pred
:
bool
=
False
)
->
Generator
[
list
,
None
,
None
]:
"""Predict input_data.
...
...
@@ -551,7 +607,7 @@ class PaddleClas(object):
"""
if
isinstance
(
input_data
,
np
.
ndarray
):
yield
self
.
cls_
predictor
.
predict
(
input_data
)
yield
self
.
predictor
.
predict
(
input_data
)
elif
isinstance
(
input_data
,
str
):
if
input_data
.
startswith
(
"http"
)
or
input_data
.
startswith
(
"https"
):
image_storage_dir
=
partial
(
os
.
path
.
join
,
BASE_IMAGES_DIR
)
...
...
@@ -583,7 +639,7 @@ class PaddleClas(object):
cnt
+=
1
if
cnt
%
batch_size
==
0
or
(
idx_img
+
1
)
==
len
(
image_list
):
preds
=
self
.
cls_
predictor
.
predict
(
img_list
)
preds
=
self
.
predictor
.
predict
(
img_list
)
if
preds
:
for
idx_pred
,
pred
in
enumerate
(
preds
):
...
...
@@ -600,6 +656,77 @@ class PaddleClas(object):
raise
ImageTypeError
(
err
)
return
def
predict_shitu
(
self
,
input_data
:
Union
[
str
,
np
.
array
],
print_pred
:
bool
=
False
)
->
Generator
[
list
,
None
,
None
]:
"""Predict input_data.
Args:
input_data (Union[str, np.array]):
When the type is str, it is the path of image, or the directory containing images, or the URL of image from Internet.
When the type is np.array, it is the image data whose channel order is RGB.
print_pred (bool, optional): Whether print the prediction result. Defaults to False.
Raises:
ImageTypeError: Illegal input_data.
Yields:
Generator[list, None, None]:
The prediction result(s) of input_data by batch_size. For every one image,
prediction result(s) is zipped as a dict, that includs topk "class_ids", "scores" and "label_names".
The format of batch prediction result(s) is as follow: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...]
"""
if
isinstance
(
input_data
,
np
.
ndarray
):
yield
self
.
predictor
.
predict
(
input_data
)
elif
isinstance
(
input_data
,
str
):
if
input_data
.
startswith
(
"http"
)
or
input_data
.
startswith
(
"https"
):
image_storage_dir
=
partial
(
os
.
path
.
join
,
BASE_IMAGES_DIR
)
if
not
os
.
path
.
exists
(
image_storage_dir
()):
os
.
makedirs
(
image_storage_dir
())
image_save_path
=
image_storage_dir
(
"tmp.jpg"
)
download_with_progressbar
(
input_data
,
image_save_path
)
logger
.
info
(
f
"Image to be predicted from Internet:
{
input_data
}
, has been saved to:
{
image_save_path
}
"
)
input_data
=
image_save_path
image_list
=
get_image_list
(
input_data
)
cnt
=
0
for
idx_img
,
img_path
in
enumerate
(
image_list
):
img
=
cv2
.
imread
(
img_path
)
if
img
is
None
:
logger
.
warning
(
f
"Image file failed to read and has been skipped. The path:
{
img_path
}
"
)
continue
img
=
img
[:,
:,
::
-
1
]
cnt
+=
1
preds
=
self
.
predictor
.
predict
(
img
)
# [dict1, dict2, ..., dictn]
if
preds
:
if
print_pred
:
logger
.
info
(
f
"
{
preds
}
, filename:
{
img_path
}
"
)
yield
preds
else
:
err
=
"Please input legal image! The type of image supported by PaddleClas are: NumPy.ndarray and string of local path or Ineternet URL"
raise
ImageTypeError
(
err
)
return
def
predict
(
self
,
input_data
:
Union
[
str
,
np
.
array
],
print_pred
:
bool
=
False
,
predict_type
=
"cls"
):
if
predict_type
==
"cls"
:
return
self
.
predict_cls
(
input_data
,
print_pred
)
elif
predict_type
==
"shitu"
:
assert
not
isinstance
(
input_data
,
(
list
,
tuple
)),
"PP-ShiTu predictor only support single image as input now."
return
self
.
predict_shitu
(
input_data
,
print_pred
)
else
:
raise
ModuleNotFoundError
# for CLI
def
main
():
...
...
@@ -608,7 +735,10 @@ def main():
print_info
()
cfg
=
args_cfg
()
clas_engine
=
PaddleClas
(
**
cfg
)
res
=
clas_engine
.
predict
(
cfg
[
"infer_imgs"
],
print_pred
=
True
)
res
=
clas_engine
.
predict
(
cfg
[
"infer_imgs"
],
print_pred
=
True
,
predict_type
=
"cls"
if
"PP-ShiTu"
not
in
cfg
[
"model_name"
]
else
"shitu"
)
for
_
in
res
:
pass
logger
.
info
(
"Predict complete!"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录