Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
9d3c22f5
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9d3c22f5
编写于
10月 27, 2022
作者:
悟、
提交者:
Tingquan Gao
10月 28, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update whl for shitu
上级
7e097f3c
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
56 addition
and
18 deletion
+56
-18
paddleclas.py
paddleclas.py
+56
-18
未找到文件。
paddleclas.py
浏览文件 @
9d3c22f5
...
...
@@ -33,6 +33,7 @@ from .ppcls.utils import logger
from
.deploy.python.predict_cls
import
ClsPredictor
from
.deploy.python.predict_system
import
SystemPredictor
from
.deploy.python.build_gallery
import
GalleryBuilder
from
.deploy.utils.get_image_list
import
get_image_list
from
.deploy.utils
import
config
...
...
@@ -227,7 +228,9 @@ class InputModelError(Exception):
def
init_config
(
model_type
,
model_name
,
inference_model_dir
,
**
kwargs
):
if
model_type
==
"pulc"
:
if
kwargs
.
get
(
"build_gallery"
,
False
):
cfg_path
=
"deploy/configs/inference_general.yaml"
elif
model_type
==
"pulc"
:
cfg_path
=
f
"deploy/configs/PULC/
{
model_name
}
/inference_
{
model_name
}
.yaml"
elif
model_type
==
"shitu"
:
cfg_path
=
"deploy/configs/inference_general.yaml"
...
...
@@ -236,7 +239,8 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs):
__dir__
=
os
.
path
.
dirname
(
__file__
)
cfg_path
=
os
.
path
.
join
(
__dir__
,
cfg_path
)
cfg
=
config
.
get_config
(
cfg_path
,
show
=
False
)
cfg
=
config
.
get_config
(
cfg_path
,
overrides
=
kwargs
.
get
(
"override"
,
None
),
show
=
False
)
if
cfg
.
Global
.
get
(
"inference_model_dir"
):
cfg
.
Global
.
inference_model_dir
=
inference_model_dir
else
:
...
...
@@ -337,10 +341,15 @@ def args_cfg():
parser
.
add_argument
(
"--infer_imgs"
,
type
=
str
,
required
=
Tru
e
,
required
=
Fals
e
,
help
=
"The image(s) to be predicted."
)
parser
.
add_argument
(
"--model_name"
,
type
=
str
,
help
=
"The model name to be used."
)
parser
.
add_argument
(
"--predict_type"
,
type
=
str
,
default
=
"cls"
,
help
=
"The predict type to be selected."
)
parser
.
add_argument
(
"--inference_model_dir"
,
type
=
str
,
...
...
@@ -395,7 +404,17 @@ def args_cfg():
parser
.
add_argument
(
"--resize_short"
,
type
=
int
,
help
=
"Resize according to short size."
)
parser
.
add_argument
(
"--crop_size"
,
type
=
int
,
help
=
"Centor crop size."
)
parser
.
add_argument
(
"--build_gallery"
,
type
=
str2bool
,
default
=
False
,
help
=
"Whether build gallery."
)
parser
.
add_argument
(
'-o'
,
'--override'
,
action
=
'append'
,
default
=
[],
help
=
'config options to be overridden'
)
args
=
parser
.
parse_args
()
return
vars
(
args
)
...
...
@@ -549,14 +568,27 @@ class PaddleClas(object):
"""
super
().
__init__
()
self
.
model_type
,
inference_model_dir
=
self
.
_check_input_model
(
model_name
,
inference_model_dir
)
self
.
_config
=
init_config
(
self
.
model_type
,
model_name
,
inference_model_dir
,
**
kwargs
)
if
self
.
model_type
==
"shitu"
:
self
.
predictor
=
SystemPredictor
(
self
.
_config
)
if
kwargs
.
get
(
"build_gallery"
,
False
):
self
.
model_type
,
inference_model_dir
=
self
.
_check_input_model
(
model_name
if
model_name
else
"PP-ShiTuV2"
,
inference_model_dir
)
self
.
_config
=
init_config
(
self
.
model_type
,
model_name
if
model_name
else
"PP-ShiTuV2"
,
inference_model_dir
,
**
kwargs
)
logger
.
info
(
"Building Gallery..."
)
GalleryBuilder
(
self
.
_config
)
else
:
self
.
predictor
=
ClsPredictor
(
self
.
_config
)
self
.
model_type
,
inference_model_dir
=
self
.
_check_input_model
(
model_name
,
inference_model_dir
)
self
.
_config
=
init_config
(
self
.
model_type
,
model_name
,
inference_model_dir
,
**
kwargs
)
if
self
.
model_type
==
"shitu"
:
self
.
predictor
=
SystemPredictor
(
self
.
_config
)
else
:
self
.
predictor
=
ClsPredictor
(
self
.
_config
)
def
get_config
(
self
):
"""Get the config.
...
...
@@ -700,6 +732,9 @@ class PaddleClas(object):
prediction result(s) is zipped as a dict, that includs topk "class_ids", "scores" and "label_names".
The format of batch prediction result(s) is as follow: [{"class_ids": [...], "scores": [...], "label_names": [...]}, ...]
"""
if
input_data
==
None
and
self
.
_config
.
Global
.
infer_imgs
:
input_data
=
self
.
_config
.
Global
.
infer_imgs
if
isinstance
(
input_data
,
np
.
ndarray
):
yield
self
.
predictor
.
predict
(
input_data
)
elif
isinstance
(
input_data
,
str
):
...
...
@@ -742,6 +777,8 @@ class PaddleClas(object):
input_data
:
Union
[
str
,
np
.
array
],
print_pred
:
bool
=
False
,
predict_type
=
"cls"
):
assert
predict_type
in
[
"cls"
,
"shitu"
],
"Predict type should be 'cls' or 'shitu'."
if
predict_type
==
"cls"
:
return
self
.
predict_cls
(
input_data
,
print_pred
)
elif
predict_type
==
"shitu"
:
...
...
@@ -760,13 +797,14 @@ def main():
print_info
()
cfg
=
args_cfg
()
clas_engine
=
PaddleClas
(
**
cfg
)
res
=
clas_engine
.
predict
(
cfg
[
"infer_imgs"
],
print_pred
=
True
,
predict_type
=
"cls"
if
"PP-ShiTu"
not
in
cfg
[
"model_name"
]
else
"shitu"
)
for
_
in
res
:
pass
logger
.
info
(
"Predict complete!"
)
if
cfg
[
"build_gallery"
]
==
False
:
res
=
clas_engine
.
predict
(
cfg
[
"infer_imgs"
],
print_pred
=
True
,
predict_type
=
cfg
[
"predict_type"
])
for
_
in
res
:
pass
logger
.
info
(
"Predict complete!"
)
return
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录