Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
6fce3751
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
接近 2 年 前同步成功
通知
116
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6fce3751
编写于
6月 08, 2021
作者:
F
Felix
提交者:
GitHub
6月 08, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update predict_system.py
上级
f6b768ed
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
12 addition
and
47 deletion
+12
-47
deploy/python/predict_system.py
deploy/python/predict_system.py
+12
-47
未找到文件。
deploy/python/predict_system.py
浏览文件 @
6fce3751
...
...
@@ -29,23 +29,6 @@ from utils import logger
from
utils
import
config
from
utils.get_image_list
import
get_image_list
def
split_datafile
(
data_file
,
image_root
):
gallery_images
=
[]
gallery_docs
=
[]
with
open
(
data_file
)
as
f
:
lines
=
f
.
readlines
()
for
i
,
line
in
enumerate
(
lines
):
line
=
line
.
strip
().
split
(
"
\t
"
)
if
line
[
0
]
==
'image_id'
:
continue
image_file
=
os
.
path
.
join
(
image_root
,
line
[
3
])
image_doc
=
line
[
1
]
gallery_images
.
append
(
image_file
)
gallery_docs
.
append
(
image_doc
)
return
gallery_images
,
gallery_docs
class
SystemPredictor
(
object
):
def
__init__
(
self
,
config
):
...
...
@@ -54,48 +37,30 @@ class SystemPredictor(object):
self
.
det_predictor
=
DetPredictor
(
config
)
assert
'IndexProcess'
in
config
.
keys
(),
"Index config not found ... "
self
.
indexer
(
config
[
'IndexProcess'
])
self
.
return_k
=
self
.
config
[
'IndexProcess'
][
'infer'
][
'return_k'
]
self
.
search_budget
=
self
.
config
[
'IndexProcess'
][
'infer'
][
'search_budget'
]
def
indexer
(
self
,
config
):
if
'build'
in
config
.
keys
()
and
config
[
'build'
][
'enable'
]:
# build the index from scratch
with
open
(
config
[
'build'
][
'data_file'
])
as
f
:
lines
=
f
.
readlines
()
gallery_images
,
gallery_docs
=
split_datafile
(
config
[
'build'
][
'data_file'
],
config
[
'build'
][
'image_root'
])
# extract gallery features
gallery_features
=
np
.
zeros
([
len
(
gallery_images
),
config
[
'build'
][
'embedding_size'
]],
dtype
=
np
.
float32
)
for
i
,
image_file
in
enumerate
(
gallery_images
):
img
=
cv2
.
imread
(
image_file
)[:,
:,
::
-
1
]
rec_feat
=
self
.
rec_predictor
.
predict
(
img
)
gallery_features
[
i
,:]
=
rec_feat
# train index
self
.
Searcher
=
Graph_Index
(
dist_type
=
config
[
'build'
][
'dist_type'
])
self
.
Searcher
.
build
(
gallery_vectors
=
gallery_features
,
gallery_docs
=
gallery_docs
,
pq_size
=
config
[
'build'
][
'pq_size'
],
index_path
=
config
[
'build'
][
'index_path'
])
else
:
# load local index
self
.
Searcher
=
Graph_Index
(
dist_type
=
config
[
'build'
][
'dist_type'
])
self
.
Searcher
.
load
(
config
[
'infer'
][
'index_path'
])
self
.
return_k
=
self
.
config
[
'IndexProcess'
][
'return_k'
]
self
.
search_budget
=
self
.
config
[
'IndexProcess'
][
'search_budget'
]
self
.
Searcher
=
Graph_Index
(
dist_type
=
config
[
'IndexProcess'
][
'dist_type'
])
self
.
Searcher
.
load
(
config
[
'IndexProcess'
][
'index_path'
])
def
predict
(
self
,
img
):
output
=
[]
results
=
self
.
det_predictor
.
predict
(
img
)
for
result
in
results
:
preds
=
{}
xmin
,
ymin
,
xmax
,
ymax
=
result
[
"bbox"
].
astype
(
"int"
)
crop_img
=
img
[
xmin
:
xmax
,
ymin
:
ymax
,
:].
copy
()
rec_results
=
self
.
rec_predictor
.
predict
(
crop_img
)
result
[
"feature"
]
=
rec_results
#preds
["feature"] = rec_results
preds
[
"bbox"
]
=
[
xmin
,
ymin
,
xmax
,
ymax
]
scores
,
docs
=
self
.
Searcher
.
search
(
query
=
rec_results
,
return_k
=
self
.
return_k
,
search_budget
=
self
.
search_budget
)
result
[
"ret
_docs"
]
=
docs
result
[
"ret
_scores"
]
=
scores
preds
[
"rec
_docs"
]
=
docs
preds
[
"rec
_scores"
]
=
scores
output
.
append
(
result
)
output
.
append
(
preds
)
return
output
def
main
(
config
):
system_predictor
=
SystemPredictor
(
config
)
image_list
=
get_image_list
(
config
[
"Global"
][
"infer_imgs"
])
...
...
@@ -104,7 +69,7 @@ def main(config):
for
idx
,
image_file
in
enumerate
(
image_list
):
img
=
cv2
.
imread
(
image_file
)[:,
:,
::
-
1
]
output
=
system_predictor
.
predict
(
img
)
#
print(output)
print
(
output
)
return
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录