Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
8211c039
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
1 年多 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8211c039
编写于
9月 20, 2022
作者:
D
dongshuilong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix shitu_index manager bug
上级
3e0f7767
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
515 addition
and
379 deletion
+515
-379
deploy/shitu_index_manager/client.py
deploy/shitu_index_manager/client.py
+45
-0
deploy/shitu_index_manager/index_manager.py
deploy/shitu_index_manager/index_manager.py
+30
-315
deploy/shitu_index_manager/mod/mainwindow.py
deploy/shitu_index_manager/mod/mainwindow.py
+63
-55
deploy/shitu_index_manager/server.py
deploy/shitu_index_manager/server.py
+340
-0
docs/zh_CN/inference_deployment/shitu_gallery_manager.md
docs/zh_CN/inference_deployment/shitu_gallery_manager.md
+37
-9
未找到文件。
deploy/shitu_index_manager/client.py
0 → 100644
浏览文件 @
8211c039
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
from
PyQt5
import
QtCore
,
QtGui
,
QtWidgets
import
mod.mainwindow
"""
完整的index库如下:
root_path/ # 库存储目录
|-- image_list.txt # 图像列表,每行:image_path label。由前端生成及修改。后端只读
|-- features.pkl # 建库之后,保存的embedding向量,后端生成,前端无需操作
|-- images # 图像存储目录,由前端生成及增删查等操作。后端只读
| |-- md5.jpg
| |-- md5.jpg
| |-- ……
|-- index # 真正的生成的index库存储目录,后端生成及操作,前端无需操作。
| |-- vector.index # faiss生成的索引库
| |-- id_map.pkl # 索引文件
"""
def
FrontInterface
(
server_ip
=
None
,
server_port
=
None
):
front
=
QtWidgets
.
QApplication
([])
main_window
=
mod
.
mainwindow
.
MainWindow
(
ip
=
server_ip
,
port
=
server_port
)
main_window
.
showMaximized
()
sys
.
exit
(
front
.
exec_
())
if
__name__
==
'__main__'
:
server_ip
=
None
server_port
=
None
if
len
(
sys
.
argv
)
==
2
and
len
(
sys
.
argv
[
1
].
split
(
' '
))
==
2
:
[
server_ip
,
server_port
]
=
sys
.
argv
[
1
].
split
(
' '
)
FrontInterface
(
server_ip
,
server_port
)
deploy/shitu_index_manager/index_manager.py
浏览文件 @
8211c039
...
@@ -13,22 +13,10 @@
...
@@ -13,22 +13,10 @@
# limitations under the License.
# limitations under the License.
import
os
import
os
import
sys
import
sys
from
PyQt5
import
QtCore
,
QtGui
,
QtWidgets
import
subprocess
import
mod.mainwindow
import
shlex
import
psutil
from
paddleclas.deploy.utils
import
config
,
logger
import
time
from
paddleclas.deploy.python.predict_rec
import
RecPredictor
from
fastapi
import
FastAPI
import
uvicorn
import
numpy
as
np
import
faiss
from
typing
import
List
import
pickle
import
cv2
import
socket
import
json
import
operator
from
multiprocessing
import
Process
"""
"""
完整的index库如下:
完整的index库如下:
root_path/ # 库存储目录
root_path/ # 库存储目录
...
@@ -43,307 +31,34 @@ root_path/ # 库存储目录
...
@@ -43,307 +31,34 @@ root_path/ # 库存储目录
| |-- id_map.pkl # 索引文件
| |-- id_map.pkl # 索引文件
"""
"""
if
__name__
==
'__main__'
:
class
ShiTuIndexManager
(
object
):
if
not
(
len
(
sys
.
argv
)
==
3
or
len
(
sys
.
argv
)
==
5
):
print
(
"start example:"
)
def
__init__
(
self
,
config
):
print
(
" python index_manager.py -c xxx.yaml"
)
self
.
root_path
=
None
print
(
" python index_manager.py -c xxx.yaml -p port"
)
self
.
image_list_path
=
"image_list.txt"
yaml_path
=
sys
.
argv
[
2
]
self
.
image_dir
=
"images"
if
len
(
sys
.
argv
)
==
5
:
self
.
index_path
=
"index/vector.index"
port
=
sys
.
argv
[
4
]
self
.
id_map_path
=
"index/id_map.pkl"
self
.
features_path
=
"features.pkl"
self
.
index
=
None
self
.
id_map
=
None
self
.
features
=
None
self
.
config
=
config
self
.
predictor
=
RecPredictor
(
config
)
def
_load_pickle
(
self
,
path
):
if
os
.
path
.
exists
(
path
):
return
pickle
.
load
(
open
(
path
,
'rb'
))
else
:
return
None
def
_save_pickle
(
self
,
path
,
data
):
if
not
os
.
path
.
exists
(
os
.
path
.
dirname
(
path
)):
os
.
makedirs
(
os
.
path
.
dirname
(
path
),
exist_ok
=
True
)
with
open
(
path
,
'wb'
)
as
fd
:
pickle
.
dump
(
data
,
fd
)
def
_load_index
(
self
):
self
.
index
=
faiss
.
read_index
(
os
.
path
.
join
(
self
.
root_path
,
self
.
index_path
))
self
.
id_map
=
self
.
_load_pickle
(
os
.
path
.
join
(
self
.
root_path
,
self
.
id_map_path
))
self
.
features
=
self
.
_load_pickle
(
os
.
path
.
join
(
self
.
root_path
,
self
.
features_path
))
def
_save_index
(
self
,
index
,
id_map
,
features
):
faiss
.
write_index
(
index
,
os
.
path
.
join
(
self
.
root_path
,
self
.
index_path
))
self
.
_save_pickle
(
os
.
path
.
join
(
self
.
root_path
,
self
.
id_map_path
),
id_map
)
self
.
_save_pickle
(
os
.
path
.
join
(
self
.
root_path
,
self
.
features_path
),
features
)
def
_update_path
(
self
,
root_path
,
image_list_path
=
None
):
if
root_path
==
self
.
root_path
:
pass
else
:
self
.
root_path
=
root_path
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
root_path
,
"index"
)):
os
.
mkdir
(
os
.
path
.
join
(
root_path
,
"index"
))
if
image_list_path
is
not
None
:
self
.
image_list_path
=
image_list_path
def
_cal_featrue
(
self
,
image_list
):
batch_images
=
[]
featrures
=
None
cnt
=
0
for
idx
,
image_path
in
enumerate
(
image_list
):
image
=
cv2
.
imread
(
image_path
)
if
image
is
None
:
return
"{} is broken or not exist. Stop"
else
:
image
=
image
[:,
:,
::
-
1
]
batch_images
.
append
(
image
)
cnt
+=
1
if
cnt
%
self
.
config
[
"Global"
][
"batch_size"
]
==
0
or
(
idx
+
1
)
==
len
(
image_list
):
if
len
(
batch_images
)
==
0
:
continue
batch_results
=
self
.
predictor
.
predict
(
batch_images
)
featrures
=
batch_results
if
featrures
is
None
else
np
.
concatenate
(
(
featrures
,
batch_results
),
axis
=
0
)
batch_images
=
[]
return
featrures
def
_split_datafile
(
self
,
data_file
,
image_root
):
'''
data_file: image path and info, which can be splitted by spacer
image_root: image path root
delimiter: delimiter
'''
gallery_images
=
[]
gallery_docs
=
[]
gallery_ids
=
[]
with
open
(
data_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
lines
=
f
.
readlines
()
for
_
,
ori_line
in
enumerate
(
lines
):
line
=
ori_line
.
strip
().
split
()
text_num
=
len
(
line
)
assert
text_num
>=
2
,
f
"line(
{
ori_line
}
) must be splitted into at least 2 parts, but got
{
text_num
}
"
image_file
=
os
.
path
.
join
(
image_root
,
line
[
0
])
gallery_images
.
append
(
image_file
)
gallery_docs
.
append
(
ori_line
.
strip
())
gallery_ids
.
append
(
os
.
path
.
basename
(
line
[
0
]).
split
(
"."
)[
0
])
return
gallery_images
,
gallery_docs
,
gallery_ids
def
create_index
(
self
,
image_list
:
str
,
index_method
:
str
=
"HNSW32"
,
image_root
:
str
=
None
):
if
not
os
.
path
.
exists
(
image_list
):
return
"{} is not exist"
.
format
(
image_list
)
if
index_method
.
lower
()
not
in
[
'hnsw32'
,
'ivf'
,
'flat'
]:
return
"The index method Only support: HNSW32, IVF, Flat"
self
.
_update_path
(
os
.
path
.
dirname
(
image_list
),
image_list
)
# get image_paths
image_root
=
image_root
if
image_root
is
not
None
else
self
.
root_path
gallery_images
,
gallery_docs
,
image_ids
=
self
.
_split_datafile
(
image_list
,
image_root
)
# gernerate index
if
index_method
==
"IVF"
:
index_method
=
index_method
+
str
(
min
(
max
(
int
(
len
(
gallery_images
)
//
32
),
2
),
65536
))
+
",Flat"
index
=
faiss
.
index_factory
(
self
.
config
[
"IndexProcess"
][
"embedding_size"
],
index_method
,
faiss
.
METRIC_INNER_PRODUCT
)
self
.
index
=
faiss
.
IndexIDMap2
(
index
)
features
=
self
.
_cal_featrue
(
gallery_images
)
self
.
index
.
train
(
features
)
index_ids
=
np
.
arange
(
0
,
len
(
gallery_images
)).
astype
(
np
.
int64
)
self
.
index
.
add_with_ids
(
features
,
index_ids
)
self
.
id_map
=
dict
()
for
i
,
d
in
zip
(
list
(
index_ids
),
gallery_docs
):
self
.
id_map
[
i
]
=
d
self
.
features
=
{
"features"
:
features
,
"index_method"
:
index_method
,
"image_ids"
:
image_ids
,
"index_ids"
:
index_ids
.
tolist
()
}
self
.
_save_index
(
self
.
index
,
self
.
id_map
,
self
.
features
)
def
open_index
(
self
,
root_path
:
str
,
image_list_path
:
str
)
->
str
:
self
.
_update_path
(
root_path
)
_
,
_
,
image_ids
=
self
.
_split_datafile
(
image_list_path
,
root_path
)
if
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
root_path
,
self
.
index_path
))
and
\
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
root_path
,
self
.
id_map_path
))
and
\
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
root_path
,
self
.
features_path
)):
self
.
_update_path
(
root_path
)
self
.
_load_index
()
if
operator
.
eq
(
set
(
image_ids
),
set
(
self
.
features
[
'image_ids'
])):
return
""
else
:
return
"The image list is different from index, Please update index"
else
:
return
"File not exist: features.pkl, vector.index, id_map.pkl"
def
update_index
(
self
,
image_list
:
str
,
image_root
:
str
=
None
)
->
str
:
if
self
.
index
and
self
.
id_map
and
self
.
features
:
image_paths
,
image_docs
,
image_ids
=
self
.
_split_datafile
(
image_list
,
image_root
if
image_root
is
not
None
else
self
.
root_path
)
# for add image
add_ids
=
list
(
set
(
image_ids
).
difference
(
set
(
self
.
features
[
"image_ids"
])))
add_indexes
=
[
i
for
i
,
x
in
enumerate
(
image_ids
)
if
x
in
add_ids
]
add_image_paths
=
[
image_paths
[
i
]
for
i
in
add_indexes
]
add_image_docs
=
[
image_docs
[
i
]
for
i
in
add_indexes
]
add_image_ids
=
[
image_ids
[
i
]
for
i
in
add_indexes
]
self
.
_add_index
(
add_image_paths
,
add_image_docs
,
add_image_ids
)
# delete images
delete_ids
=
list
(
set
(
self
.
features
[
"image_ids"
]).
difference
(
set
(
image_ids
)))
self
.
_delete_index
(
delete_ids
)
self
.
_save_index
(
self
.
index
,
self
.
id_map
,
self
.
features
)
return
""
else
:
return
"Failed. Please create or open index first"
def
_add_index
(
self
,
image_list
:
List
,
image_docs
:
List
,
image_ids
:
List
):
if
len
(
image_ids
)
==
0
:
return
featrures
=
self
.
_cal_featrue
(
image_list
)
index_ids
=
(
np
.
arange
(
0
,
len
(
image_list
))
+
max
(
self
.
id_map
.
keys
())
+
1
).
astype
(
np
.
int64
)
self
.
index
.
add_with_ids
(
featrures
,
index_ids
)
for
i
,
d
in
zip
(
index_ids
,
image_docs
):
self
.
id_map
[
i
]
=
d
self
.
features
[
'features'
]
=
np
.
concatenate
(
[
self
.
features
[
'features'
],
featrures
],
axis
=
0
)
self
.
features
[
'image_ids'
].
extend
(
image_ids
)
self
.
features
[
'index_ids'
].
extend
(
index_ids
.
tolist
())
def
_delete_index
(
self
,
image_ids
:
List
):
if
len
(
image_ids
)
==
0
:
return
indexes
=
[
i
for
i
,
x
in
enumerate
(
self
.
features
[
'image_ids'
])
if
x
in
image_ids
]
self
.
features
[
"features"
]
=
np
.
delete
(
self
.
features
[
"features"
],
indexes
,
axis
=
0
)
self
.
features
[
"image_ids"
]
=
np
.
delete
(
np
.
asarray
(
self
.
features
[
"image_ids"
]),
indexes
,
axis
=
0
).
tolist
()
index_ids
=
np
.
delete
(
np
.
asarray
(
self
.
features
[
"index_ids"
]),
indexes
,
axis
=
0
).
tolist
()
id_map_values
=
[
self
.
id_map
[
i
]
for
i
in
index_ids
]
self
.
index
.
reset
()
ids
=
np
.
arange
(
0
,
len
(
id_map_values
)).
astype
(
np
.
int64
)
self
.
index
.
add_with_ids
(
self
.
features
[
'features'
],
ids
)
self
.
id_map
.
clear
()
for
i
,
d
in
zip
(
ids
,
id_map_values
):
self
.
id_map
[
i
]
=
d
self
.
features
[
"index_ids"
]
=
ids
app
=
FastAPI
()
@
app
.
get
(
"/new_index"
)
def
new_index
(
image_list_path
:
str
,
index_method
:
str
=
"HNSW32"
,
index_root_path
:
str
=
None
,
force
:
bool
=
False
):
result
=
""
try
:
if
index_root_path
is
not
None
:
image_list_path
=
os
.
path
.
join
(
index_root_path
,
image_list_path
)
index_path
=
os
.
path
.
join
(
index_root_path
,
"index"
,
"vector.index"
)
id_map_path
=
os
.
path
.
join
(
index_root_path
,
"index"
,
"id_map.pkl"
)
if
not
(
os
.
path
.
exists
(
index_path
)
and
os
.
path
.
exists
(
id_map_path
))
or
force
:
manager
.
create_index
(
image_list_path
,
index_method
,
index_root_path
)
else
:
else
:
result
=
"There alrealy has index in {}"
.
format
(
index_root_path
)
port
=
8000
except
Exception
as
e
:
assert
int
(
port
)
>
1024
and
int
(
result
=
e
.
__str__
()
port
)
<
65536
,
"The port should be bigger than 1024 and
\
data
=
{
"error_message"
:
result
}
smaller than 65536"
return
json
.
dumps
(
data
).
encode
()
@
app
.
get
(
"/open_index"
)
def
open_index
(
index_root_path
:
str
,
image_list_path
:
str
):
result
=
""
try
:
image_list_path
=
os
.
path
.
join
(
index_root_path
,
image_list_path
)
result
=
manager
.
open_index
(
index_root_path
,
image_list_path
)
except
Exception
as
e
:
result
=
e
.
__str__
()
data
=
{
"error_message"
:
result
}
return
json
.
dumps
(
data
).
encode
()
@
app
.
get
(
"/update_index"
)
def
update_index
(
image_list_path
:
str
,
index_root_path
:
str
=
None
):
result
=
""
try
:
if
index_root_path
is
not
None
:
image_list_path
=
os
.
path
.
join
(
index_root_path
,
image_list_path
)
result
=
manager
.
update_index
(
image_list
=
image_list_path
,
image_root
=
index_root_path
)
except
Exception
as
e
:
result
=
e
.
__str__
()
data
=
{
"error_message"
:
result
}
return
json
.
dumps
(
data
).
encode
()
def
FrontInterface
(
server_process
=
None
):
front
=
QtWidgets
.
QApplication
([])
main_window
=
mod
.
mainwindow
.
MainWindow
(
process
=
server_process
)
main_window
.
showMaximized
()
sys
.
exit
(
front
.
exec_
())
def
Server
(
args
):
[
app
,
host
,
port
]
=
args
uvicorn
.
run
(
app
,
host
=
host
,
port
=
port
)
if
__name__
==
'__main__'
:
args
=
config
.
parse_args
()
model_config
=
config
.
get_config
(
args
.
config
,
overrides
=
args
.
override
,
show
=
True
)
manager
=
ShiTuIndexManager
(
model_config
)
try
:
try
:
ip
=
socket
.
gethostbyname
(
socket
.
gethostname
())
ip
=
socket
.
gethostbyname
(
socket
.
gethostname
())
except
:
except
:
ip
=
'127.0.0.1'
ip
=
'127.0.0.1'
port
=
8000
server_cmd
=
"python server.py -c {} -o ip={} -o port={}"
.
format
(
yaml_path
,
p_server
=
Process
(
target
=
Server
,
args
=
([
app
,
ip
,
port
],))
ip
,
port
)
p_server
.
start
()
server_proc
=
subprocess
.
Popen
(
shlex
.
split
(
server_cmd
))
# p_client = Process(target=FrontInterface, args=())
client_proc
=
subprocess
.
Popen
(
# p_client.start()
[
"python"
,
"client.py"
,
"{} {}"
.
format
(
ip
,
port
)])
# p_client.join()
try
:
FrontInterface
(
p_server
)
while
psutil
.
Process
(
client_proc
.
pid
).
status
()
==
"running"
:
p_server
.
terminate
()
time
.
sleep
(
0.5
)
sys
.
exit
(
0
)
except
:
pass
client_proc
.
terminate
()
server_proc
.
terminate
()
deploy/shitu_index_manager/mod/mainwindow.py
浏览文件 @
8211c039
...
@@ -22,8 +22,6 @@ try:
...
@@ -22,8 +22,6 @@ try:
DEFAULT_HOST
=
socket
.
gethostbyname
(
socket
.
gethostname
())
DEFAULT_HOST
=
socket
.
gethostbyname
(
socket
.
gethostname
())
except
:
except
:
DEFAULT_HOST
=
'127.0.0.1'
DEFAULT_HOST
=
'127.0.0.1'
# DEFAULT_HOST = "localhost"
DEFAULT_PORT
=
8000
DEFAULT_PORT
=
8000
PADDLECLAS_DOC_URL
=
"https://gitee.com/paddlepaddle/PaddleClas/docs/zh_CN/inference_deployment/shitu_gallery_manager.md"
PADDLECLAS_DOC_URL
=
"https://gitee.com/paddlepaddle/PaddleClas/docs/zh_CN/inference_deployment/shitu_gallery_manager.md"
...
@@ -35,12 +33,17 @@ class MainWindow(QtWidgets.QMainWindow):
...
@@ -35,12 +33,17 @@ class MainWindow(QtWidgets.QMainWindow):
updateIndexMsg
=
QtCore
.
pyqtSignal
(
str
)
# 更新索引库线程信号
updateIndexMsg
=
QtCore
.
pyqtSignal
(
str
)
# 更新索引库线程信号
importImageCount
=
QtCore
.
pyqtSignal
(
int
)
# 导入图像数量信号
importImageCount
=
QtCore
.
pyqtSignal
(
int
)
# 导入图像数量信号
def
__init__
(
self
,
process
=
None
):
def
__init__
(
self
,
ip
=
None
,
port
=
None
):
super
(
MainWindow
,
self
).
__init__
()
super
(
MainWindow
,
self
).
__init__
()
self
.
server_process
=
process
if
ip
is
not
None
and
port
is
not
None
:
self
.
server_ip
=
ip
self
.
server_port
=
port
else
:
self
.
server_ip
=
DEFAULT_HOST
self
.
server_port
=
DEFAULT_PORT
self
.
ui
=
ui_mainwindow
.
Ui_MainWindow
()
self
.
ui
=
ui_mainwindow
.
Ui_MainWindow
()
self
.
ui
.
setupUi
(
self
)
# 初始化主窗口界面
self
.
ui
.
setupUi
(
self
)
# 初始化主窗口界面
self
.
__imageListMgr
=
image_list_manager
.
ImageListManager
()
self
.
__imageListMgr
=
image_list_manager
.
ImageListManager
()
self
.
__appMenu
=
QtWidgets
.
QMenu
()
# 应用菜单
self
.
__appMenu
=
QtWidgets
.
QMenu
()
# 应用菜单
...
@@ -115,8 +118,7 @@ class MainWindow(QtWidgets.QMainWindow):
...
@@ -115,8 +118,7 @@ class MainWindow(QtWidgets.QMainWindow):
self
.
ui
.
saveImageLibraryBtn
.
clicked
.
connect
(
self
.
saveImageLibrary
)
self
.
ui
.
saveImageLibraryBtn
.
clicked
.
connect
(
self
.
saveImageLibrary
)
self
.
__setToolButton
(
self
.
ui
.
addClassifyBtn
,
"添加分类"
,
self
.
__setToolButton
(
self
.
ui
.
addClassifyBtn
,
"添加分类"
,
"./resource/add_classify.png"
,
"./resource/add_classify.png"
,
TOOL_BTN_ICON_SIZE
)
TOOL_BTN_ICON_SIZE
)
self
.
ui
.
addClassifyBtn
.
clicked
.
connect
(
self
.
ui
.
addClassifyBtn
.
clicked
.
connect
(
self
.
__classifyUiContext
.
addClassify
)
self
.
__classifyUiContext
.
addClassify
)
...
@@ -145,7 +147,10 @@ class MainWindow(QtWidgets.QMainWindow):
...
@@ -145,7 +147,10 @@ class MainWindow(QtWidgets.QMainWindow):
self
.
ui
.
searchClassifyHistoryCmb
.
setToolTip
(
"查找分类历史"
)
self
.
ui
.
searchClassifyHistoryCmb
.
setToolTip
(
"查找分类历史"
)
self
.
ui
.
imageScaleSlider
.
setToolTip
(
"图片缩放"
)
self
.
ui
.
imageScaleSlider
.
setToolTip
(
"图片缩放"
)
def
__setToolButton
(
self
,
button
,
tool_tip
:
str
,
icon_path
:
str
,
def
__setToolButton
(
self
,
button
,
tool_tip
:
str
,
icon_path
:
str
,
icon_size
:
int
):
icon_size
:
int
):
"""设置工具按钮"""
"""设置工具按钮"""
button
.
setToolTip
(
tool_tip
)
button
.
setToolTip
(
tool_tip
)
...
@@ -179,16 +184,16 @@ class MainWindow(QtWidgets.QMainWindow):
...
@@ -179,16 +184,16 @@ class MainWindow(QtWidgets.QMainWindow):
def
__initWaitDialog
(
self
):
def
__initWaitDialog
(
self
):
"""初始化等待对话框"""
"""初始化等待对话框"""
self
.
__waitDialogUi
.
setupUi
(
self
.
__waitDialog
)
self
.
__waitDialogUi
.
setupUi
(
self
.
__waitDialog
)
self
.
__waitDialog
.
setWindowFlags
(
QtCore
.
Qt
.
Dialog
self
.
__waitDialog
.
setWindowFlags
(
QtCore
.
Qt
.
Dialog
|
|
QtCore
.
Qt
.
FramelessWindowHint
)
QtCore
.
Qt
.
FramelessWindowHint
)
def
__startWait
(
self
,
msg
:
str
):
def
__startWait
(
self
,
msg
:
str
):
"""开始显示等待对话框"""
"""开始显示等待对话框"""
self
.
setEnabled
(
False
)
self
.
setEnabled
(
False
)
self
.
__waitDialogUi
.
msgLabel
.
setText
(
msg
)
self
.
__waitDialogUi
.
msgLabel
.
setText
(
msg
)
self
.
__waitDialog
.
setWindowFlags
(
QtCore
.
Qt
.
Dialog
self
.
__waitDialog
.
setWindowFlags
(
QtCore
.
Qt
.
Dialog
|
|
QtCore
.
Qt
.
FramelessWindowHint
QtCore
.
Qt
.
FramelessWindowHint
|
|
QtCore
.
Qt
.
WindowStaysOnTopHint
)
QtCore
.
Qt
.
WindowStaysOnTopHint
)
self
.
__waitDialog
.
show
()
self
.
__waitDialog
.
show
()
self
.
__waitDialog
.
repaint
()
self
.
__waitDialog
.
repaint
()
...
@@ -196,9 +201,9 @@ class MainWindow(QtWidgets.QMainWindow):
...
@@ -196,9 +201,9 @@ class MainWindow(QtWidgets.QMainWindow):
"""停止显示等待对话框"""
"""停止显示等待对话框"""
self
.
setEnabled
(
True
)
self
.
setEnabled
(
True
)
self
.
__waitDialogUi
.
msgLabel
.
setText
(
"执行完毕!"
)
self
.
__waitDialogUi
.
msgLabel
.
setText
(
"执行完毕!"
)
self
.
__waitDialog
.
setWindowFlags
(
QtCore
.
Qt
.
Dialog
self
.
__waitDialog
.
setWindowFlags
(
QtCore
.
Qt
.
Dialog
|
|
QtCore
.
Qt
.
FramelessWindowHint
QtCore
.
Qt
.
FramelessWindowHint
|
|
QtCore
.
Qt
.
CustomizeWindowHint
)
QtCore
.
Qt
.
CustomizeWindowHint
)
self
.
__waitDialog
.
close
()
self
.
__waitDialog
.
close
()
def
__connectSignal
(
self
):
def
__connectSignal
(
self
):
...
@@ -290,8 +295,8 @@ class MainWindow(QtWidgets.QMainWindow):
...
@@ -290,8 +295,8 @@ class MainWindow(QtWidgets.QMainWindow):
def
__importImageListImageThread
(
self
,
from_path
:
str
,
to_path
:
str
):
def
__importImageListImageThread
(
self
,
from_path
:
str
,
to_path
:
str
):
"""导入 image_list 图像 线程"""
"""导入 image_list 图像 线程"""
count
=
utils
.
oneKeyImportFromFile
(
from_path
=
from_path
,
count
=
utils
.
oneKeyImportFromFile
(
to_path
=
to_path
)
from_path
=
from_path
,
to_path
=
to_path
)
if
count
==
None
:
if
count
==
None
:
count
=
-
1
count
=
-
1
self
.
importImageCount
.
emit
(
count
)
self
.
importImageCount
.
emit
(
count
)
...
@@ -308,9 +313,9 @@ class MainWindow(QtWidgets.QMainWindow):
...
@@ -308,9 +313,9 @@ class MainWindow(QtWidgets.QMainWindow):
return
return
from_mgr
=
image_list_manager
.
ImageListManager
(
from_path
)
from_mgr
=
image_list_manager
.
ImageListManager
(
from_path
)
self
.
__startWait
(
"正在导入图像,请等待。。。"
)
self
.
__startWait
(
"正在导入图像,请等待。。。"
)
thread
=
threading
.
Thread
(
target
=
self
.
__importImageListImageThread
,
thread
=
threading
.
Thread
(
args
=
(
from_mgr
.
filePath
,
target
=
self
.
__importImageListImageThread
,
self
.
__imageListMgr
.
filePath
))
args
=
(
from_mgr
.
filePath
,
self
.
__imageListMgr
.
filePath
))
thread
.
start
()
thread
.
start
()
def
__importDirsImageThread
(
self
,
from_dir
:
str
,
to_image_list_path
:
str
):
def
__importDirsImageThread
(
self
,
from_dir
:
str
,
to_image_list_path
:
str
):
...
@@ -333,18 +338,22 @@ class MainWindow(QtWidgets.QMainWindow):
...
@@ -333,18 +338,22 @@ class MainWindow(QtWidgets.QMainWindow):
QtWidgets
.
QMessageBox
.
information
(
self
,
"提示"
,
"打开的目录不存在"
)
QtWidgets
.
QMessageBox
.
information
(
self
,
"提示"
,
"打开的目录不存在"
)
return
return
self
.
__startWait
(
"正在导入图像,请等待。。。"
)
self
.
__startWait
(
"正在导入图像,请等待。。。"
)
thread
=
threading
.
Thread
(
target
=
self
.
__importDirsImageThread
,
thread
=
threading
.
Thread
(
args
=
(
dir_path
,
target
=
self
.
__importDirsImageThread
,
self
.
__imageListMgr
.
filePath
))
args
=
(
dir_path
,
self
.
__imageListMgr
.
filePath
))
thread
.
start
()
thread
.
start
()
def
__newIndexThread
(
self
,
index_root_path
:
str
,
image_list_path
:
str
,
def
__newIndexThread
(
self
,
index_method
:
str
,
force
:
bool
):
index_root_path
:
str
,
image_list_path
:
str
,
index_method
:
str
,
force
:
bool
):
"""新建重建索引库线程"""
"""新建重建索引库线程"""
try
:
try
:
client
=
index_http_client
.
IndexHttpClient
(
client
=
index_http_client
.
IndexHttpClient
(
self
.
server_ip
,
DEFAULT_HOST
,
DEFAULT_PORT
)
self
.
server_port
)
err_msg
=
client
.
new_index
(
image_list_path
=
image_list_path
,
err_msg
=
client
.
new_index
(
image_list_path
=
image_list_path
,
index_root_path
=
index_root_path
,
index_root_path
=
index_root_path
,
index_method
=
index_method
,
index_method
=
index_method
,
force
=
force
)
force
=
force
)
...
@@ -375,18 +384,19 @@ class MainWindow(QtWidgets.QMainWindow):
...
@@ -375,18 +384,19 @@ class MainWindow(QtWidgets.QMainWindow):
force
=
ui
.
resetCheckBox
.
isChecked
()
force
=
ui
.
resetCheckBox
.
isChecked
()
if
result
==
QtWidgets
.
QDialog
.
Accepted
:
if
result
==
QtWidgets
.
QDialog
.
Accepted
:
self
.
__startWait
(
"正在 新建/重建 索引库,请等待。。。"
)
self
.
__startWait
(
"正在 新建/重建 索引库,请等待。。。"
)
thread
=
threading
.
Thread
(
target
=
self
.
__newIndexThread
,
thread
=
threading
.
Thread
(
args
=
(
self
.
__imageListMgr
.
dirName
,
target
=
self
.
__newIndexThread
,
"image_list.txt"
,
index_method
,
args
=
(
self
.
__imageListMgr
.
dirName
,
"image_list.txt"
,
force
))
index_method
,
force
))
thread
.
start
()
thread
.
start
()
def
__openIndexThread
(
self
,
index_root_path
:
str
,
image_list_path
:
str
):
def
__openIndexThread
(
self
,
index_root_path
:
str
,
image_list_path
:
str
):
"""打开索引库线程"""
"""打开索引库线程"""
try
:
try
:
client
=
index_http_client
.
IndexHttpClient
(
client
=
index_http_client
.
IndexHttpClient
(
self
.
server_ip
,
DEFAULT_HOST
,
DEFAULT_PORT
)
self
.
server_port
)
err_msg
=
client
.
open_index
(
index_root_path
=
index_root_path
,
err_msg
=
client
.
open_index
(
index_root_path
=
index_root_path
,
image_list_path
=
image_list_path
)
image_list_path
=
image_list_path
)
if
err_msg
==
None
:
if
err_msg
==
None
:
err_msg
=
""
err_msg
=
""
...
@@ -408,17 +418,18 @@ class MainWindow(QtWidgets.QMainWindow):
...
@@ -408,17 +418,18 @@ class MainWindow(QtWidgets.QMainWindow):
QtWidgets
.
QMessageBox
.
information
(
self
,
"提示"
,
"请先打开正确的图像库"
)
QtWidgets
.
QMessageBox
.
information
(
self
,
"提示"
,
"请先打开正确的图像库"
)
return
return
self
.
__startWait
(
"正在打开索引库,请等待。。。"
)
self
.
__startWait
(
"正在打开索引库,请等待。。。"
)
thread
=
threading
.
Thread
(
target
=
self
.
__openIndexThread
,
thread
=
threading
.
Thread
(
args
=
(
self
.
__imageListMgr
.
dirName
,
target
=
self
.
__openIndexThread
,
"image_list.txt"
))
args
=
(
self
.
__imageListMgr
.
dirName
,
"image_list.txt"
))
thread
.
start
()
thread
.
start
()
def
__updateIndexThread
(
self
,
index_root_path
:
str
,
image_list_path
:
str
):
def
__updateIndexThread
(
self
,
index_root_path
:
str
,
image_list_path
:
str
):
"""更新索引库线程"""
"""更新索引库线程"""
try
:
try
:
client
=
index_http_client
.
IndexHttpClient
(
client
=
index_http_client
.
IndexHttpClient
(
self
.
server_ip
,
DEFAULT_HOST
,
DEFAULT_PORT
)
self
.
server_port
)
err_msg
=
client
.
update_index
(
image_list_path
=
image_list_path
,
err_msg
=
client
.
update_index
(
image_list_path
=
image_list_path
,
index_root_path
=
index_root_path
)
index_root_path
=
index_root_path
)
if
err_msg
==
None
:
if
err_msg
==
None
:
err_msg
=
""
err_msg
=
""
...
@@ -440,9 +451,9 @@ class MainWindow(QtWidgets.QMainWindow):
...
@@ -440,9 +451,9 @@ class MainWindow(QtWidgets.QMainWindow):
QtWidgets
.
QMessageBox
.
information
(
self
,
"提示"
,
"请先打开正确的图像库"
)
QtWidgets
.
QMessageBox
.
information
(
self
,
"提示"
,
"请先打开正确的图像库"
)
return
return
self
.
__startWait
(
"正在更新索引库,请等待。。。"
)
self
.
__startWait
(
"正在更新索引库,请等待。。。"
)
thread
=
threading
.
Thread
(
target
=
self
.
__updateIndexThread
,
thread
=
threading
.
Thread
(
args
=
(
self
.
__imageListMgr
.
dirName
,
target
=
self
.
__updateIndexThread
,
"image_list.txt"
))
args
=
(
self
.
__imageListMgr
.
dirName
,
"image_list.txt"
))
thread
.
start
()
thread
.
start
()
def
searchClassify
(
self
):
def
searchClassify
(
self
):
...
@@ -471,9 +482,6 @@ class MainWindow(QtWidgets.QMainWindow):
...
@@ -471,9 +482,6 @@ class MainWindow(QtWidgets.QMainWindow):
def
exitApp
(
self
):
def
exitApp
(
self
):
"""退出应用"""
"""退出应用"""
if
isinstance
(
self
.
server_process
,
Process
):
self
.
server_process
.
terminate
()
# os.kill(self.server_pid)
sys
.
exit
(
0
)
sys
.
exit
(
0
)
def
__setPathBar
(
self
,
msg
:
str
):
def
__setPathBar
(
self
,
msg
:
str
):
...
...
deploy/shitu_index_manager/server.py
0 → 100644
浏览文件 @
8211c039
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
from
PyQt5
import
QtCore
,
QtGui
,
QtWidgets
import
mod.mainwindow
from
paddleclas.deploy.utils
import
config
,
logger
from
paddleclas.deploy.python.predict_rec
import
RecPredictor
from
fastapi
import
FastAPI
import
uvicorn
import
numpy
as
np
import
faiss
from
typing
import
List
import
pickle
import
cv2
import
socket
import
json
import
operator
from
multiprocessing
import
Process
"""
完整的index库如下:
root_path/ # 库存储目录
|-- image_list.txt # 图像列表,每行:image_path label。由前端生成及修改。后端只读
|-- features.pkl # 建库之后,保存的embedding向量,后端生成,前端无需操作
|-- images # 图像存储目录,由前端生成及增删查等操作。后端只读
| |-- md5.jpg
| |-- md5.jpg
| |-- ……
|-- index # 真正的生成的index库存储目录,后端生成及操作,前端无需操作。
| |-- vector.index # faiss生成的索引库
| |-- id_map.pkl # 索引文件
"""
class
ShiTuIndexManager
(
object
):
def
__init__
(
self
,
config
):
self
.
root_path
=
None
self
.
image_list_path
=
"image_list.txt"
self
.
image_dir
=
"images"
self
.
index_path
=
"index/vector.index"
self
.
id_map_path
=
"index/id_map.pkl"
self
.
features_path
=
"features.pkl"
self
.
index
=
None
self
.
id_map
=
None
self
.
features
=
None
self
.
config
=
config
self
.
predictor
=
RecPredictor
(
config
)
def
_load_pickle
(
self
,
path
):
if
os
.
path
.
exists
(
path
):
return
pickle
.
load
(
open
(
path
,
'rb'
))
else
:
return
None
def
_save_pickle
(
self
,
path
,
data
):
if
not
os
.
path
.
exists
(
os
.
path
.
dirname
(
path
)):
os
.
makedirs
(
os
.
path
.
dirname
(
path
),
exist_ok
=
True
)
with
open
(
path
,
'wb'
)
as
fd
:
pickle
.
dump
(
data
,
fd
)
def
_load_index
(
self
):
self
.
index
=
faiss
.
read_index
(
os
.
path
.
join
(
self
.
root_path
,
self
.
index_path
))
self
.
id_map
=
self
.
_load_pickle
(
os
.
path
.
join
(
self
.
root_path
,
self
.
id_map_path
))
self
.
features
=
self
.
_load_pickle
(
os
.
path
.
join
(
self
.
root_path
,
self
.
features_path
))
def
_save_index
(
self
,
index
,
id_map
,
features
):
faiss
.
write_index
(
index
,
os
.
path
.
join
(
self
.
root_path
,
self
.
index_path
))
self
.
_save_pickle
(
os
.
path
.
join
(
self
.
root_path
,
self
.
id_map_path
),
id_map
)
self
.
_save_pickle
(
os
.
path
.
join
(
self
.
root_path
,
self
.
features_path
),
features
)
def
_update_path
(
self
,
root_path
,
image_list_path
=
None
):
if
root_path
==
self
.
root_path
:
pass
else
:
self
.
root_path
=
root_path
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
root_path
,
"index"
)):
os
.
mkdir
(
os
.
path
.
join
(
root_path
,
"index"
))
if
image_list_path
is
not
None
:
self
.
image_list_path
=
image_list_path
def
_cal_featrue
(
self
,
image_list
):
batch_images
=
[]
featrures
=
None
cnt
=
0
for
idx
,
image_path
in
enumerate
(
image_list
):
image
=
cv2
.
imread
(
image_path
)
if
image
is
None
:
return
"{} is broken or not exist. Stop"
else
:
image
=
image
[:,
:,
::
-
1
]
batch_images
.
append
(
image
)
cnt
+=
1
if
cnt
%
self
.
config
[
"Global"
][
"batch_size"
]
==
0
or
(
idx
+
1
)
==
len
(
image_list
):
if
len
(
batch_images
)
==
0
:
continue
batch_results
=
self
.
predictor
.
predict
(
batch_images
)
featrures
=
batch_results
if
featrures
is
None
else
np
.
concatenate
(
(
featrures
,
batch_results
),
axis
=
0
)
batch_images
=
[]
return
featrures
def
_split_datafile
(
self
,
data_file
,
image_root
):
'''
data_file: image path and info, which can be splitted by spacer
image_root: image path root
delimiter: delimiter
'''
gallery_images
=
[]
gallery_docs
=
[]
gallery_ids
=
[]
with
open
(
data_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
lines
=
f
.
readlines
()
for
_
,
ori_line
in
enumerate
(
lines
):
line
=
ori_line
.
strip
().
split
()
text_num
=
len
(
line
)
assert
text_num
>=
2
,
f
"line(
{
ori_line
}
) must be splitted into at least 2 parts, but got
{
text_num
}
"
image_file
=
os
.
path
.
join
(
image_root
,
line
[
0
])
gallery_images
.
append
(
image_file
)
gallery_docs
.
append
(
ori_line
.
strip
())
gallery_ids
.
append
(
os
.
path
.
basename
(
line
[
0
]).
split
(
"."
)[
0
])
return
gallery_images
,
gallery_docs
,
gallery_ids
def
create_index
(
self
,
image_list
:
str
,
index_method
:
str
=
"HNSW32"
,
image_root
:
str
=
None
):
if
not
os
.
path
.
exists
(
image_list
):
return
"{} is not exist"
.
format
(
image_list
)
if
index_method
.
lower
()
not
in
[
'hnsw32'
,
'ivf'
,
'flat'
]:
return
"The index method Only support: HNSW32, IVF, Flat"
self
.
_update_path
(
os
.
path
.
dirname
(
image_list
),
image_list
)
# get image_paths
image_root
=
image_root
if
image_root
is
not
None
else
self
.
root_path
gallery_images
,
gallery_docs
,
image_ids
=
self
.
_split_datafile
(
image_list
,
image_root
)
# gernerate index
if
index_method
==
"IVF"
:
index_method
=
index_method
+
str
(
min
(
max
(
int
(
len
(
gallery_images
)
//
32
),
2
),
65536
))
+
",Flat"
index
=
faiss
.
index_factory
(
self
.
config
[
"IndexProcess"
][
"embedding_size"
],
index_method
,
faiss
.
METRIC_INNER_PRODUCT
)
self
.
index
=
faiss
.
IndexIDMap2
(
index
)
features
=
self
.
_cal_featrue
(
gallery_images
)
self
.
index
.
train
(
features
)
index_ids
=
np
.
arange
(
0
,
len
(
gallery_images
)).
astype
(
np
.
int64
)
self
.
index
.
add_with_ids
(
features
,
index_ids
)
self
.
id_map
=
dict
()
for
i
,
d
in
zip
(
list
(
index_ids
),
gallery_docs
):
self
.
id_map
[
i
]
=
d
self
.
features
=
{
"features"
:
features
,
"index_method"
:
index_method
,
"image_ids"
:
image_ids
,
"index_ids"
:
index_ids
.
tolist
()
}
self
.
_save_index
(
self
.
index
,
self
.
id_map
,
self
.
features
)
def
open_index
(
self
,
root_path
:
str
,
image_list_path
:
str
)
->
str
:
self
.
_update_path
(
root_path
)
_
,
_
,
image_ids
=
self
.
_split_datafile
(
image_list_path
,
root_path
)
if
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
root_path
,
self
.
index_path
))
and
\
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
root_path
,
self
.
id_map_path
))
and
\
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
root_path
,
self
.
features_path
)):
self
.
_update_path
(
root_path
)
self
.
_load_index
()
if
operator
.
eq
(
set
(
image_ids
),
set
(
self
.
features
[
'image_ids'
])):
return
""
else
:
return
"The image list is different from index, Please update index"
else
:
return
"File not exist: features.pkl, vector.index, id_map.pkl"
def
update_index
(
self
,
image_list
:
str
,
image_root
:
str
=
None
)
->
str
:
if
self
.
index
and
self
.
id_map
and
self
.
features
:
image_paths
,
image_docs
,
image_ids
=
self
.
_split_datafile
(
image_list
,
image_root
if
image_root
is
not
None
else
self
.
root_path
)
# for add image
add_ids
=
list
(
set
(
image_ids
).
difference
(
set
(
self
.
features
[
"image_ids"
])))
add_indexes
=
[
i
for
i
,
x
in
enumerate
(
image_ids
)
if
x
in
add_ids
]
add_image_paths
=
[
image_paths
[
i
]
for
i
in
add_indexes
]
add_image_docs
=
[
image_docs
[
i
]
for
i
in
add_indexes
]
add_image_ids
=
[
image_ids
[
i
]
for
i
in
add_indexes
]
self
.
_add_index
(
add_image_paths
,
add_image_docs
,
add_image_ids
)
# delete images
delete_ids
=
list
(
set
(
self
.
features
[
"image_ids"
]).
difference
(
set
(
image_ids
)))
self
.
_delete_index
(
delete_ids
)
self
.
_save_index
(
self
.
index
,
self
.
id_map
,
self
.
features
)
return
""
else
:
return
"Failed. Please create or open index first"
def
_add_index
(
self
,
image_list
:
List
,
image_docs
:
List
,
image_ids
:
List
):
if
len
(
image_ids
)
==
0
:
return
featrures
=
self
.
_cal_featrue
(
image_list
)
index_ids
=
(
np
.
arange
(
0
,
len
(
image_list
))
+
max
(
self
.
id_map
.
keys
())
+
1
).
astype
(
np
.
int64
)
self
.
index
.
add_with_ids
(
featrures
,
index_ids
)
for
i
,
d
in
zip
(
index_ids
,
image_docs
):
self
.
id_map
[
i
]
=
d
self
.
features
[
'features'
]
=
np
.
concatenate
(
[
self
.
features
[
'features'
],
featrures
],
axis
=
0
)
self
.
features
[
'image_ids'
].
extend
(
image_ids
)
self
.
features
[
'index_ids'
].
extend
(
index_ids
.
tolist
())
def
_delete_index
(
self
,
image_ids
:
List
):
if
len
(
image_ids
)
==
0
:
return
indexes
=
[
i
for
i
,
x
in
enumerate
(
self
.
features
[
'image_ids'
])
if
x
in
image_ids
]
self
.
features
[
"features"
]
=
np
.
delete
(
self
.
features
[
"features"
],
indexes
,
axis
=
0
)
self
.
features
[
"image_ids"
]
=
np
.
delete
(
np
.
asarray
(
self
.
features
[
"image_ids"
]),
indexes
,
axis
=
0
).
tolist
()
index_ids
=
np
.
delete
(
np
.
asarray
(
self
.
features
[
"index_ids"
]),
indexes
,
axis
=
0
).
tolist
()
id_map_values
=
[
self
.
id_map
[
i
]
for
i
in
index_ids
]
self
.
index
.
reset
()
ids
=
np
.
arange
(
0
,
len
(
id_map_values
)).
astype
(
np
.
int64
)
self
.
index
.
add_with_ids
(
self
.
features
[
'features'
],
ids
)
self
.
id_map
.
clear
()
for
i
,
d
in
zip
(
ids
,
id_map_values
):
self
.
id_map
[
i
]
=
d
self
.
features
[
"index_ids"
]
=
ids
app
=
FastAPI
()
@
app
.
get
(
"/new_index"
)
def
new_index
(
image_list_path
:
str
,
index_method
:
str
=
"HNSW32"
,
index_root_path
:
str
=
None
,
force
:
bool
=
False
):
result
=
""
try
:
if
index_root_path
is
not
None
:
image_list_path
=
os
.
path
.
join
(
index_root_path
,
image_list_path
)
index_path
=
os
.
path
.
join
(
index_root_path
,
"index"
,
"vector.index"
)
id_map_path
=
os
.
path
.
join
(
index_root_path
,
"index"
,
"id_map.pkl"
)
if
not
(
os
.
path
.
exists
(
index_path
)
and
os
.
path
.
exists
(
id_map_path
))
or
force
:
manager
.
create_index
(
image_list_path
,
index_method
,
index_root_path
)
else
:
result
=
"There alrealy has index in {}"
.
format
(
index_root_path
)
except
Exception
as
e
:
result
=
e
.
__str__
()
data
=
{
"error_message"
:
result
}
return
json
.
dumps
(
data
).
encode
()
@
app
.
get
(
"/open_index"
)
def
open_index
(
index_root_path
:
str
,
image_list_path
:
str
):
result
=
""
try
:
image_list_path
=
os
.
path
.
join
(
index_root_path
,
image_list_path
)
result
=
manager
.
open_index
(
index_root_path
,
image_list_path
)
except
Exception
as
e
:
result
=
e
.
__str__
()
data
=
{
"error_message"
:
result
}
return
json
.
dumps
(
data
).
encode
()
@
app
.
get
(
"/update_index"
)
def
update_index
(
image_list_path
:
str
,
index_root_path
:
str
=
None
):
result
=
""
try
:
if
index_root_path
is
not
None
:
image_list_path
=
os
.
path
.
join
(
index_root_path
,
image_list_path
)
result
=
manager
.
update_index
(
image_list
=
image_list_path
,
image_root
=
index_root_path
)
except
Exception
as
e
:
result
=
e
.
__str__
()
data
=
{
"error_message"
:
result
}
return
json
.
dumps
(
data
).
encode
()
def
FrontInterface
(
server_process
=
None
):
front
=
QtWidgets
.
QApplication
([])
main_window
=
mod
.
mainwindow
.
MainWindow
(
process
=
server_process
)
main_window
.
showMaximized
()
sys
.
exit
(
front
.
exec_
())
def
Server
(
app
,
host
,
port
):
uvicorn
.
run
(
app
,
host
=
host
,
port
=
port
)
if
__name__
==
'__main__'
:
args
=
config
.
parse_args
()
model_config
=
config
.
get_config
(
args
.
config
,
overrides
=
args
.
override
,
show
=
True
)
manager
=
ShiTuIndexManager
(
model_config
)
ip
=
model_config
.
get
(
'ip'
,
None
)
port
=
model_config
.
get
(
'port'
,
None
)
if
ip
is
None
or
port
is
None
:
try
:
ip
=
socket
.
gethostbyname
(
socket
.
gethostname
())
except
:
ip
=
'127.0.0.1'
port
=
8000
Server
(
app
,
ip
,
port
)
docs/zh_CN/inference_deployment/shitu_gallery_manager.md
浏览文件 @
8211c039
...
@@ -22,7 +22,7 @@
...
@@ -22,7 +22,7 @@
-
[
2. 使用说明
](
#2
)
-
[
2. 使用说明
](
#2
)
-
[
2.1 环境安装
](
#2.1
)
-
[
2.1 环境安装
](
#2.1
)
-
[
2.2 模型准备
](
#2.2
)
-
[
2.2 模型
及数据
准备
](
#2.2
)
-
[
2.3运行使用
](
#2.3
)
-
[
2.3运行使用
](
#2.3
)
-
[
3.生成文件介绍
](
#3
)
-
[
3.生成文件介绍
](
#3
)
...
@@ -123,13 +123,25 @@
...
@@ -123,13 +123,25 @@
pip
install
fastapi
pip
install
fastapi
pip
install
uvicorn
pip
install
uvicorn
pip
install
pyqt5
pip
install
pyqt5
pip
install
psutil
```
```
<a
name=
"2.2"
></a>
<a
name=
"2.2"
></a>
### 2.2 模型准备
### 2.2 模型
及数据
准备
请按照
[
PP-ShiTu快速体验
](
../quick_start/quick_start_recognition.md#2.2.1
)
中下载及准备inference model,并修改好
`${PaddleClas}/deploy/configs/inference_drink.yaml`
的相关参数。
请按照
[
PP-ShiTu快速体验
](
../quick_start/quick_start_recognition.md#2.2.1
)
中下载及准备inference model,并修改好
`${PaddleClas}/deploy/configs/inference_drink.yaml`
的相关参数,同时准备好数据集。在具体使用时,请替换好自己的数据集及模型文件。
```
shell
cd
${
PaddleClas
}
/deploy/shitu_index_manager
mkdir
models
cd
models
# 下载及解压识别模型
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/PP-ShiTuV2/general_PPLCNetV2_base_pretrained_v1.0_infer.tar
&&
tar
-xf
general_PPLCNetV2_base_pretrained_v1.0_infer.tar
cd
..
# 下载及解压示例数据集
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_dataset_v2.0.tar
&&
tar
-xf
drink_dataset_v2.0.tar
```
<a
name=
"2.3"
></a>
<a
name=
"2.3"
></a>
...
@@ -139,9 +151,26 @@ pip install pyqt5
...
@@ -139,9 +151,26 @@ pip install pyqt5
```
shell
```
shell
cd
${
PaddleClas
}
/deploy/shitu_index_manager
cd
${
PaddleClas
}
/deploy/shitu_index_manager
python index_manager.py
-c
../configs/inference_drink.yaml
cp
../configs/inference_drink.yaml
.
# 注意如果没有按照2.2中准备数据集及代码,请手动修改inference_drink.yaml,做好适配
python index_manager.py
-c
inference_drink.yaml
```
```
运行成功后,会自动跳转到工具界面,可以按照如下步骤,生成新的index库。
1.
点击菜单栏
`新建图像库`
,会提示打开一个文件夹,此时请创建一个
**新的文件夹**
,并打开。如在
`${PaddleClas}/deploy/shitu_index_manager`
下新建一个
`drink_index`
文件夹
2.
导入图像,或者如上面功能介绍,自己手动新增类别和相应的图像,下面介绍两种导入图像方式,操作时,二选一即可。
-
点击
`导入图像`
->
`导入image_list图像`
,打开
`${PaddleClas}/deploy/shitu_index_manager/drink_dataset_v2.0/gallery/drink_label.txt`
,此时就可以将
`drink_label.txt`
中的图像全部导入进来,图像类别就是
`drink_label.txt`
中记录的类别。
-
点击
`导入图像`
->
`导入多文件夹图像`
,打开
`${PaddleClas}/deploy/shitu_index_manager/drink_dataset_v2.0/gallery/`
文件夹,此时就将
`gallery`
文件夹下,所有子文件夹都导入进来,图像类别就是子文件夹的名字。
3.
点击菜单栏中
`新建/重建 索引库`
,此时就会开始生成索引库。如果图片较多或者使用cpu来进行特征提取,那么耗时会比较长,请耐心等待。
4.
生成索引库成功后,会发现在
`drink_index`
文件夹下生成如
[
3
](
#3
)
中介绍的文件,此时
`index`
子文件夹下生出的文件,就是
`PP-ShiTu`
所使用的索引文件。
**注意**
:
-
利用此工具生成的index库,如
`drink_index`
文件夹,请妥善存储。之后,可以继续使用此工具中
`打开图像库`
功能,打开
`drink_index`
文件夹,继续对index库进行增删改查操作,具体功能可以查看
[
功能介绍
](
#1
)
。
-
打开一个生成好的库,在其上面进行增删改查操作后,请及时保存。保存后并及时使用菜单中
`更新索引库`
功能,对索引库进行更新
-
如果要使用自己的图像库文件,图像生成格式如示例数据格式,生成
`image_list.txt`
或者多文件夹存储,二选一。
<a
name=
"3"
></a>
<a
name=
"3"
></a>
## 3. 生成文件介绍
## 3. 生成文件介绍
...
@@ -192,4 +221,3 @@ index_root/ # 库存储目录
...
@@ -192,4 +221,3 @@ index_root/ # 库存储目录
-
问题4: 报错 图像与index库不一致
-
问题4: 报错 图像与index库不一致
答:可能用户自己修改了image_list.txt,修改完成后,请及时更新index库,保证其一致。
答:可能用户自己修改了image_list.txt,修改完成后,请及时更新index库,保证其一致。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录