提交 8211c039 编写于 作者: D dongshuilong

fix shitu_index manager bug

上级 3e0f7767
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from PyQt5 import QtCore, QtGui, QtWidgets
import mod.mainwindow
"""
完整的index库如下:
root_path/ # 库存储目录
|-- image_list.txt # 图像列表,每行:image_path label。由前端生成及修改。后端只读
|-- features.pkl # 建库之后,保存的embedding向量,后端生成,前端无需操作
|-- images # 图像存储目录,由前端生成及增删查等操作。后端只读
| |-- md5.jpg
| |-- md5.jpg
| |-- ……
|-- index # 真正的生成的index库存储目录,后端生成及操作,前端无需操作。
| |-- vector.index # faiss生成的索引库
| |-- id_map.pkl # 索引文件
"""
def FrontInterface(server_ip=None, server_port=None):
front = QtWidgets.QApplication([])
main_window = mod.mainwindow.MainWindow(ip=server_ip, port=server_port)
main_window.showMaximized()
sys.exit(front.exec_())
if __name__ == '__main__':
server_ip = None
server_port = None
if len(sys.argv) == 2 and len(sys.argv[1].split(' ')) == 2:
[server_ip, server_port] = sys.argv[1].split(' ')
FrontInterface(server_ip, server_port)
...@@ -13,22 +13,10 @@ ...@@ -13,22 +13,10 @@
# limitations under the License. # limitations under the License.
import os import os
import sys import sys
from PyQt5 import QtCore, QtGui, QtWidgets import subprocess
import mod.mainwindow import shlex
import psutil
from paddleclas.deploy.utils import config, logger import time
from paddleclas.deploy.python.predict_rec import RecPredictor
from fastapi import FastAPI
import uvicorn
import numpy as np
import faiss
from typing import List
import pickle
import cv2
import socket
import json
import operator
from multiprocessing import Process
""" """
完整的index库如下: 完整的index库如下:
root_path/ # 库存储目录 root_path/ # 库存储目录
...@@ -43,307 +31,34 @@ root_path/ # 库存储目录 ...@@ -43,307 +31,34 @@ root_path/ # 库存储目录
| |-- id_map.pkl # 索引文件 | |-- id_map.pkl # 索引文件
""" """
if __name__ == '__main__':
class ShiTuIndexManager(object): if not (len(sys.argv) == 3 or len(sys.argv) == 5):
print("start example:")
def __init__(self, config): print(" python index_manager.py -c xxx.yaml")
self.root_path = None print(" python index_manager.py -c xxx.yaml -p port")
self.image_list_path = "image_list.txt" yaml_path = sys.argv[2]
self.image_dir = "images" if len(sys.argv) == 5:
self.index_path = "index/vector.index" port = sys.argv[4]
self.id_map_path = "index/id_map.pkl"
self.features_path = "features.pkl"
self.index = None
self.id_map = None
self.features = None
self.config = config
self.predictor = RecPredictor(config)
def _load_pickle(self, path):
if os.path.exists(path):
return pickle.load(open(path, 'rb'))
else:
return None
def _save_pickle(self, path, data):
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, 'wb') as fd:
pickle.dump(data, fd)
def _load_index(self):
self.index = faiss.read_index(
os.path.join(self.root_path, self.index_path))
self.id_map = self._load_pickle(
os.path.join(self.root_path, self.id_map_path))
self.features = self._load_pickle(
os.path.join(self.root_path, self.features_path))
def _save_index(self, index, id_map, features):
faiss.write_index(index, os.path.join(self.root_path, self.index_path))
self._save_pickle(os.path.join(self.root_path, self.id_map_path),
id_map)
self._save_pickle(os.path.join(self.root_path, self.features_path),
features)
def _update_path(self, root_path, image_list_path=None):
if root_path == self.root_path:
pass
else:
self.root_path = root_path
if not os.path.exists(os.path.join(root_path, "index")):
os.mkdir(os.path.join(root_path, "index"))
if image_list_path is not None:
self.image_list_path = image_list_path
def _cal_featrue(self, image_list):
batch_images = []
featrures = None
cnt = 0
for idx, image_path in enumerate(image_list):
image = cv2.imread(image_path)
if image is None:
return "{} is broken or not exist. Stop"
else:
image = image[:, :, ::-1]
batch_images.append(image)
cnt += 1
if cnt % self.config["Global"]["batch_size"] == 0 or (
idx + 1) == len(image_list):
if len(batch_images) == 0:
continue
batch_results = self.predictor.predict(batch_images)
featrures = batch_results if featrures is None else np.concatenate(
(featrures, batch_results), axis=0)
batch_images = []
return featrures
def _split_datafile(self, data_file, image_root):
'''
data_file: image path and info, which can be splitted by spacer
image_root: image path root
delimiter: delimiter
'''
gallery_images = []
gallery_docs = []
gallery_ids = []
with open(data_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
for _, ori_line in enumerate(lines):
line = ori_line.strip().split()
text_num = len(line)
assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"
image_file = os.path.join(image_root, line[0])
gallery_images.append(image_file)
gallery_docs.append(ori_line.strip())
gallery_ids.append(os.path.basename(line[0]).split(".")[0])
return gallery_images, gallery_docs, gallery_ids
def create_index(self,
image_list: str,
index_method: str = "HNSW32",
image_root: str = None):
if not os.path.exists(image_list):
return "{} is not exist".format(image_list)
if index_method.lower() not in ['hnsw32', 'ivf', 'flat']:
return "The index method Only support: HNSW32, IVF, Flat"
self._update_path(os.path.dirname(image_list), image_list)
# get image_paths
image_root = image_root if image_root is not None else self.root_path
gallery_images, gallery_docs, image_ids = self._split_datafile(
image_list, image_root)
# gernerate index
if index_method == "IVF":
index_method = index_method + str(
min(max(int(len(gallery_images) // 32), 2), 65536)) + ",Flat"
index = faiss.index_factory(
self.config["IndexProcess"]["embedding_size"], index_method,
faiss.METRIC_INNER_PRODUCT)
self.index = faiss.IndexIDMap2(index)
features = self._cal_featrue(gallery_images)
self.index.train(features)
index_ids = np.arange(0, len(gallery_images)).astype(np.int64)
self.index.add_with_ids(features, index_ids)
self.id_map = dict()
for i, d in zip(list(index_ids), gallery_docs):
self.id_map[i] = d
self.features = {
"features": features,
"index_method": index_method,
"image_ids": image_ids,
"index_ids": index_ids.tolist()
}
self._save_index(self.index, self.id_map, self.features)
def open_index(self, root_path: str, image_list_path: str) -> str:
self._update_path(root_path)
_, _, image_ids = self._split_datafile(image_list_path, root_path)
if os.path.exists(os.path.join(self.root_path, self.index_path)) and \
os.path.exists(os.path.join(self.root_path, self.id_map_path)) and \
os.path.exists(os.path.join(self.root_path, self.features_path)):
self._update_path(root_path)
self._load_index()
if operator.eq(set(image_ids), set(self.features['image_ids'])):
return ""
else:
return "The image list is different from index, Please update index"
else:
return "File not exist: features.pkl, vector.index, id_map.pkl"
def update_index(self, image_list: str, image_root: str = None) -> str:
if self.index and self.id_map and self.features:
image_paths, image_docs, image_ids = self._split_datafile(
image_list,
image_root if image_root is not None else self.root_path)
# for add image
add_ids = list(
set(image_ids).difference(set(self.features["image_ids"])))
add_indexes = [i for i, x in enumerate(image_ids) if x in add_ids]
add_image_paths = [image_paths[i] for i in add_indexes]
add_image_docs = [image_docs[i] for i in add_indexes]
add_image_ids = [image_ids[i] for i in add_indexes]
self._add_index(add_image_paths, add_image_docs, add_image_ids)
# delete images
delete_ids = list(
set(self.features["image_ids"]).difference(set(image_ids)))
self._delete_index(delete_ids)
self._save_index(self.index, self.id_map, self.features)
return ""
else:
return "Failed. Please create or open index first"
def _add_index(self, image_list: List, image_docs: List, image_ids: List):
if len(image_ids) == 0:
return
featrures = self._cal_featrue(image_list)
index_ids = (np.arange(0, len(image_list)) + max(self.id_map.keys()) +
1).astype(np.int64)
self.index.add_with_ids(featrures, index_ids)
for i, d in zip(index_ids, image_docs):
self.id_map[i] = d
self.features['features'] = np.concatenate(
[self.features['features'], featrures], axis=0)
self.features['image_ids'].extend(image_ids)
self.features['index_ids'].extend(index_ids.tolist())
def _delete_index(self, image_ids: List):
if len(image_ids) == 0:
return
indexes = [
i for i, x in enumerate(self.features['image_ids'])
if x in image_ids
]
self.features["features"] = np.delete(self.features["features"],
indexes,
axis=0)
self.features["image_ids"] = np.delete(np.asarray(
self.features["image_ids"]),
indexes,
axis=0).tolist()
index_ids = np.delete(np.asarray(self.features["index_ids"]),
indexes,
axis=0).tolist()
id_map_values = [self.id_map[i] for i in index_ids]
self.index.reset()
ids = np.arange(0, len(id_map_values)).astype(np.int64)
self.index.add_with_ids(self.features['features'], ids)
self.id_map.clear()
for i, d in zip(ids, id_map_values):
self.id_map[i] = d
self.features["index_ids"] = ids
app = FastAPI()
@app.get("/new_index")
def new_index(image_list_path: str,
index_method: str = "HNSW32",
index_root_path: str = None,
force: bool = False):
result = ""
try:
if index_root_path is not None:
image_list_path = os.path.join(index_root_path, image_list_path)
index_path = os.path.join(index_root_path, "index", "vector.index")
id_map_path = os.path.join(index_root_path, "index", "id_map.pkl")
if not (os.path.exists(index_path)
and os.path.exists(id_map_path)) or force:
manager.create_index(image_list_path, index_method, index_root_path)
else: else:
result = "There alrealy has index in {}".format(index_root_path) port = 8000
except Exception as e: assert int(port) > 1024 and int(
result = e.__str__() port) < 65536, "The port should be bigger than 1024 and \
data = {"error_message": result} smaller than 65536"
return json.dumps(data).encode()
@app.get("/open_index")
def open_index(index_root_path: str, image_list_path: str):
result = ""
try:
image_list_path = os.path.join(index_root_path, image_list_path)
result = manager.open_index(index_root_path, image_list_path)
except Exception as e:
result = e.__str__()
data = {"error_message": result}
return json.dumps(data).encode()
@app.get("/update_index")
def update_index(image_list_path: str, index_root_path: str = None):
result = ""
try:
if index_root_path is not None:
image_list_path = os.path.join(index_root_path, image_list_path)
result = manager.update_index(image_list=image_list_path,
image_root=index_root_path)
except Exception as e:
result = e.__str__()
data = {"error_message": result}
return json.dumps(data).encode()
def FrontInterface(server_process=None):
front = QtWidgets.QApplication([])
main_window = mod.mainwindow.MainWindow(process=server_process)
main_window.showMaximized()
sys.exit(front.exec_())
def Server(args):
[app, host, port] = args
uvicorn.run(app, host=host, port=port)
if __name__ == '__main__':
args = config.parse_args()
model_config = config.get_config(args.config,
overrides=args.override,
show=True)
manager = ShiTuIndexManager(model_config)
try: try:
ip = socket.gethostbyname(socket.gethostname()) ip = socket.gethostbyname(socket.gethostname())
except: except:
ip = '127.0.0.1' ip = '127.0.0.1'
port = 8000 server_cmd = "python server.py -c {} -o ip={} -o port={}".format(yaml_path,
p_server = Process(target=Server, args=([app, ip, port],)) ip, port)
p_server.start() server_proc = subprocess.Popen(shlex.split(server_cmd))
# p_client = Process(target=FrontInterface, args=()) client_proc = subprocess.Popen(
# p_client.start() ["python", "client.py", "{} {}".format(ip, port)])
# p_client.join() try:
FrontInterface(p_server) while psutil.Process(client_proc.pid).status() == "running":
p_server.terminate() time.sleep(0.5)
sys.exit(0) except:
pass
client_proc.terminate()
server_proc.terminate()
...@@ -22,8 +22,6 @@ try: ...@@ -22,8 +22,6 @@ try:
DEFAULT_HOST = socket.gethostbyname(socket.gethostname()) DEFAULT_HOST = socket.gethostbyname(socket.gethostname())
except: except:
DEFAULT_HOST = '127.0.0.1' DEFAULT_HOST = '127.0.0.1'
# DEFAULT_HOST = "localhost"
DEFAULT_PORT = 8000 DEFAULT_PORT = 8000
PADDLECLAS_DOC_URL = "https://gitee.com/paddlepaddle/PaddleClas/docs/zh_CN/inference_deployment/shitu_gallery_manager.md" PADDLECLAS_DOC_URL = "https://gitee.com/paddlepaddle/PaddleClas/docs/zh_CN/inference_deployment/shitu_gallery_manager.md"
...@@ -35,12 +33,17 @@ class MainWindow(QtWidgets.QMainWindow): ...@@ -35,12 +33,17 @@ class MainWindow(QtWidgets.QMainWindow):
updateIndexMsg = QtCore.pyqtSignal(str) # 更新索引库线程信号 updateIndexMsg = QtCore.pyqtSignal(str) # 更新索引库线程信号
importImageCount = QtCore.pyqtSignal(int) # 导入图像数量信号 importImageCount = QtCore.pyqtSignal(int) # 导入图像数量信号
def __init__(self, process=None): def __init__(self, ip=None, port=None):
super(MainWindow, self).__init__() super(MainWindow, self).__init__()
self.server_process = process if ip is not None and port is not None:
self.server_ip = ip
self.server_port = port
else:
self.server_ip = DEFAULT_HOST
self.server_port = DEFAULT_PORT
self.ui = ui_mainwindow.Ui_MainWindow() self.ui = ui_mainwindow.Ui_MainWindow()
self.ui.setupUi(self) # 初始化主窗口界面 self.ui.setupUi(self) # 初始化主窗口界面
self.__imageListMgr = image_list_manager.ImageListManager() self.__imageListMgr = image_list_manager.ImageListManager()
self.__appMenu = QtWidgets.QMenu() # 应用菜单 self.__appMenu = QtWidgets.QMenu() # 应用菜单
...@@ -115,8 +118,7 @@ class MainWindow(QtWidgets.QMainWindow): ...@@ -115,8 +118,7 @@ class MainWindow(QtWidgets.QMainWindow):
self.ui.saveImageLibraryBtn.clicked.connect(self.saveImageLibrary) self.ui.saveImageLibraryBtn.clicked.connect(self.saveImageLibrary)
self.__setToolButton(self.ui.addClassifyBtn, "添加分类", self.__setToolButton(self.ui.addClassifyBtn, "添加分类",
"./resource/add_classify.png", "./resource/add_classify.png", TOOL_BTN_ICON_SIZE)
TOOL_BTN_ICON_SIZE)
self.ui.addClassifyBtn.clicked.connect( self.ui.addClassifyBtn.clicked.connect(
self.__classifyUiContext.addClassify) self.__classifyUiContext.addClassify)
...@@ -145,7 +147,10 @@ class MainWindow(QtWidgets.QMainWindow): ...@@ -145,7 +147,10 @@ class MainWindow(QtWidgets.QMainWindow):
self.ui.searchClassifyHistoryCmb.setToolTip("查找分类历史") self.ui.searchClassifyHistoryCmb.setToolTip("查找分类历史")
self.ui.imageScaleSlider.setToolTip("图片缩放") self.ui.imageScaleSlider.setToolTip("图片缩放")
def __setToolButton(self, button, tool_tip: str, icon_path: str, def __setToolButton(self,
button,
tool_tip: str,
icon_path: str,
icon_size: int): icon_size: int):
"""设置工具按钮""" """设置工具按钮"""
button.setToolTip(tool_tip) button.setToolTip(tool_tip)
...@@ -179,16 +184,16 @@ class MainWindow(QtWidgets.QMainWindow): ...@@ -179,16 +184,16 @@ class MainWindow(QtWidgets.QMainWindow):
def __initWaitDialog(self): def __initWaitDialog(self):
"""初始化等待对话框""" """初始化等待对话框"""
self.__waitDialogUi.setupUi(self.__waitDialog) self.__waitDialogUi.setupUi(self.__waitDialog)
self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog |
| QtCore.Qt.FramelessWindowHint) QtCore.Qt.FramelessWindowHint)
def __startWait(self, msg: str): def __startWait(self, msg: str):
"""开始显示等待对话框""" """开始显示等待对话框"""
self.setEnabled(False) self.setEnabled(False)
self.__waitDialogUi.msgLabel.setText(msg) self.__waitDialogUi.msgLabel.setText(msg)
self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog |
| QtCore.Qt.FramelessWindowHint QtCore.Qt.FramelessWindowHint |
| QtCore.Qt.WindowStaysOnTopHint) QtCore.Qt.WindowStaysOnTopHint)
self.__waitDialog.show() self.__waitDialog.show()
self.__waitDialog.repaint() self.__waitDialog.repaint()
...@@ -196,9 +201,9 @@ class MainWindow(QtWidgets.QMainWindow): ...@@ -196,9 +201,9 @@ class MainWindow(QtWidgets.QMainWindow):
"""停止显示等待对话框""" """停止显示等待对话框"""
self.setEnabled(True) self.setEnabled(True)
self.__waitDialogUi.msgLabel.setText("执行完毕!") self.__waitDialogUi.msgLabel.setText("执行完毕!")
self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog |
| QtCore.Qt.FramelessWindowHint QtCore.Qt.FramelessWindowHint |
| QtCore.Qt.CustomizeWindowHint) QtCore.Qt.CustomizeWindowHint)
self.__waitDialog.close() self.__waitDialog.close()
def __connectSignal(self): def __connectSignal(self):
...@@ -290,8 +295,8 @@ class MainWindow(QtWidgets.QMainWindow): ...@@ -290,8 +295,8 @@ class MainWindow(QtWidgets.QMainWindow):
def __importImageListImageThread(self, from_path: str, to_path: str): def __importImageListImageThread(self, from_path: str, to_path: str):
"""导入 image_list 图像 线程""" """导入 image_list 图像 线程"""
count = utils.oneKeyImportFromFile(from_path=from_path, count = utils.oneKeyImportFromFile(
to_path=to_path) from_path=from_path, to_path=to_path)
if count == None: if count == None:
count = -1 count = -1
self.importImageCount.emit(count) self.importImageCount.emit(count)
...@@ -308,9 +313,9 @@ class MainWindow(QtWidgets.QMainWindow): ...@@ -308,9 +313,9 @@ class MainWindow(QtWidgets.QMainWindow):
return return
from_mgr = image_list_manager.ImageListManager(from_path) from_mgr = image_list_manager.ImageListManager(from_path)
self.__startWait("正在导入图像,请等待。。。") self.__startWait("正在导入图像,请等待。。。")
thread = threading.Thread(target=self.__importImageListImageThread, thread = threading.Thread(
args=(from_mgr.filePath, target=self.__importImageListImageThread,
self.__imageListMgr.filePath)) args=(from_mgr.filePath, self.__imageListMgr.filePath))
thread.start() thread.start()
def __importDirsImageThread(self, from_dir: str, to_image_list_path: str): def __importDirsImageThread(self, from_dir: str, to_image_list_path: str):
...@@ -333,18 +338,22 @@ class MainWindow(QtWidgets.QMainWindow): ...@@ -333,18 +338,22 @@ class MainWindow(QtWidgets.QMainWindow):
QtWidgets.QMessageBox.information(self, "提示", "打开的目录不存在") QtWidgets.QMessageBox.information(self, "提示", "打开的目录不存在")
return return
self.__startWait("正在导入图像,请等待。。。") self.__startWait("正在导入图像,请等待。。。")
thread = threading.Thread(target=self.__importDirsImageThread, thread = threading.Thread(
args=(dir_path, target=self.__importDirsImageThread,
self.__imageListMgr.filePath)) args=(dir_path, self.__imageListMgr.filePath))
thread.start() thread.start()
def __newIndexThread(self, index_root_path: str, image_list_path: str, def __newIndexThread(self,
index_method: str, force: bool): index_root_path: str,
image_list_path: str,
index_method: str,
force: bool):
"""新建重建索引库线程""" """新建重建索引库线程"""
try: try:
client = index_http_client.IndexHttpClient( client = index_http_client.IndexHttpClient(self.server_ip,
DEFAULT_HOST, DEFAULT_PORT) self.server_port)
err_msg = client.new_index(image_list_path=image_list_path, err_msg = client.new_index(
image_list_path=image_list_path,
index_root_path=index_root_path, index_root_path=index_root_path,
index_method=index_method, index_method=index_method,
force=force) force=force)
...@@ -375,18 +384,19 @@ class MainWindow(QtWidgets.QMainWindow): ...@@ -375,18 +384,19 @@ class MainWindow(QtWidgets.QMainWindow):
force = ui.resetCheckBox.isChecked() force = ui.resetCheckBox.isChecked()
if result == QtWidgets.QDialog.Accepted: if result == QtWidgets.QDialog.Accepted:
self.__startWait("正在 新建/重建 索引库,请等待。。。") self.__startWait("正在 新建/重建 索引库,请等待。。。")
thread = threading.Thread(target=self.__newIndexThread, thread = threading.Thread(
args=(self.__imageListMgr.dirName, target=self.__newIndexThread,
"image_list.txt", index_method, args=(self.__imageListMgr.dirName, "image_list.txt",
force)) index_method, force))
thread.start() thread.start()
def __openIndexThread(self, index_root_path: str, image_list_path: str): def __openIndexThread(self, index_root_path: str, image_list_path: str):
"""打开索引库线程""" """打开索引库线程"""
try: try:
client = index_http_client.IndexHttpClient( client = index_http_client.IndexHttpClient(self.server_ip,
DEFAULT_HOST, DEFAULT_PORT) self.server_port)
err_msg = client.open_index(index_root_path=index_root_path, err_msg = client.open_index(
index_root_path=index_root_path,
image_list_path=image_list_path) image_list_path=image_list_path)
if err_msg == None: if err_msg == None:
err_msg = "" err_msg = ""
...@@ -408,17 +418,18 @@ class MainWindow(QtWidgets.QMainWindow): ...@@ -408,17 +418,18 @@ class MainWindow(QtWidgets.QMainWindow):
QtWidgets.QMessageBox.information(self, "提示", "请先打开正确的图像库") QtWidgets.QMessageBox.information(self, "提示", "请先打开正确的图像库")
return return
self.__startWait("正在打开索引库,请等待。。。") self.__startWait("正在打开索引库,请等待。。。")
thread = threading.Thread(target=self.__openIndexThread, thread = threading.Thread(
args=(self.__imageListMgr.dirName, target=self.__openIndexThread,
"image_list.txt")) args=(self.__imageListMgr.dirName, "image_list.txt"))
thread.start() thread.start()
def __updateIndexThread(self, index_root_path: str, image_list_path: str): def __updateIndexThread(self, index_root_path: str, image_list_path: str):
"""更新索引库线程""" """更新索引库线程"""
try: try:
client = index_http_client.IndexHttpClient( client = index_http_client.IndexHttpClient(self.server_ip,
DEFAULT_HOST, DEFAULT_PORT) self.server_port)
err_msg = client.update_index(image_list_path=image_list_path, err_msg = client.update_index(
image_list_path=image_list_path,
index_root_path=index_root_path) index_root_path=index_root_path)
if err_msg == None: if err_msg == None:
err_msg = "" err_msg = ""
...@@ -440,9 +451,9 @@ class MainWindow(QtWidgets.QMainWindow): ...@@ -440,9 +451,9 @@ class MainWindow(QtWidgets.QMainWindow):
QtWidgets.QMessageBox.information(self, "提示", "请先打开正确的图像库") QtWidgets.QMessageBox.information(self, "提示", "请先打开正确的图像库")
return return
self.__startWait("正在更新索引库,请等待。。。") self.__startWait("正在更新索引库,请等待。。。")
thread = threading.Thread(target=self.__updateIndexThread, thread = threading.Thread(
args=(self.__imageListMgr.dirName, target=self.__updateIndexThread,
"image_list.txt")) args=(self.__imageListMgr.dirName, "image_list.txt"))
thread.start() thread.start()
def searchClassify(self): def searchClassify(self):
...@@ -471,9 +482,6 @@ class MainWindow(QtWidgets.QMainWindow): ...@@ -471,9 +482,6 @@ class MainWindow(QtWidgets.QMainWindow):
def exitApp(self): def exitApp(self):
"""退出应用""" """退出应用"""
if isinstance(self.server_process, Process):
self.server_process.terminate()
# os.kill(self.server_pid)
sys.exit(0) sys.exit(0)
def __setPathBar(self, msg: str): def __setPathBar(self, msg: str):
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from PyQt5 import QtCore, QtGui, QtWidgets
import mod.mainwindow
from paddleclas.deploy.utils import config, logger
from paddleclas.deploy.python.predict_rec import RecPredictor
from fastapi import FastAPI
import uvicorn
import numpy as np
import faiss
from typing import List
import pickle
import cv2
import socket
import json
import operator
from multiprocessing import Process
"""
完整的index库如下:
root_path/ # 库存储目录
|-- image_list.txt # 图像列表,每行:image_path label。由前端生成及修改。后端只读
|-- features.pkl # 建库之后,保存的embedding向量,后端生成,前端无需操作
|-- images # 图像存储目录,由前端生成及增删查等操作。后端只读
| |-- md5.jpg
| |-- md5.jpg
| |-- ……
|-- index # 真正的生成的index库存储目录,后端生成及操作,前端无需操作。
| |-- vector.index # faiss生成的索引库
| |-- id_map.pkl # 索引文件
"""
class ShiTuIndexManager(object):
def __init__(self, config):
self.root_path = None
self.image_list_path = "image_list.txt"
self.image_dir = "images"
self.index_path = "index/vector.index"
self.id_map_path = "index/id_map.pkl"
self.features_path = "features.pkl"
self.index = None
self.id_map = None
self.features = None
self.config = config
self.predictor = RecPredictor(config)
def _load_pickle(self, path):
if os.path.exists(path):
return pickle.load(open(path, 'rb'))
else:
return None
def _save_pickle(self, path, data):
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, 'wb') as fd:
pickle.dump(data, fd)
def _load_index(self):
self.index = faiss.read_index(
os.path.join(self.root_path, self.index_path))
self.id_map = self._load_pickle(
os.path.join(self.root_path, self.id_map_path))
self.features = self._load_pickle(
os.path.join(self.root_path, self.features_path))
def _save_index(self, index, id_map, features):
faiss.write_index(index, os.path.join(self.root_path, self.index_path))
self._save_pickle(
os.path.join(self.root_path, self.id_map_path), id_map)
self._save_pickle(
os.path.join(self.root_path, self.features_path), features)
def _update_path(self, root_path, image_list_path=None):
if root_path == self.root_path:
pass
else:
self.root_path = root_path
if not os.path.exists(os.path.join(root_path, "index")):
os.mkdir(os.path.join(root_path, "index"))
if image_list_path is not None:
self.image_list_path = image_list_path
def _cal_featrue(self, image_list):
batch_images = []
featrures = None
cnt = 0
for idx, image_path in enumerate(image_list):
image = cv2.imread(image_path)
if image is None:
return "{} is broken or not exist. Stop"
else:
image = image[:, :, ::-1]
batch_images.append(image)
cnt += 1
if cnt % self.config["Global"]["batch_size"] == 0 or (
idx + 1) == len(image_list):
if len(batch_images) == 0:
continue
batch_results = self.predictor.predict(batch_images)
featrures = batch_results if featrures is None else np.concatenate(
(featrures, batch_results), axis=0)
batch_images = []
return featrures
def _split_datafile(self, data_file, image_root):
'''
data_file: image path and info, which can be splitted by spacer
image_root: image path root
delimiter: delimiter
'''
gallery_images = []
gallery_docs = []
gallery_ids = []
with open(data_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
for _, ori_line in enumerate(lines):
line = ori_line.strip().split()
text_num = len(line)
assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"
image_file = os.path.join(image_root, line[0])
gallery_images.append(image_file)
gallery_docs.append(ori_line.strip())
gallery_ids.append(os.path.basename(line[0]).split(".")[0])
return gallery_images, gallery_docs, gallery_ids
def create_index(self,
image_list: str,
index_method: str="HNSW32",
image_root: str=None):
if not os.path.exists(image_list):
return "{} is not exist".format(image_list)
if index_method.lower() not in ['hnsw32', 'ivf', 'flat']:
return "The index method Only support: HNSW32, IVF, Flat"
self._update_path(os.path.dirname(image_list), image_list)
# get image_paths
image_root = image_root if image_root is not None else self.root_path
gallery_images, gallery_docs, image_ids = self._split_datafile(
image_list, image_root)
# gernerate index
if index_method == "IVF":
index_method = index_method + str(
min(max(int(len(gallery_images) // 32), 2), 65536)) + ",Flat"
index = faiss.index_factory(
self.config["IndexProcess"]["embedding_size"], index_method,
faiss.METRIC_INNER_PRODUCT)
self.index = faiss.IndexIDMap2(index)
features = self._cal_featrue(gallery_images)
self.index.train(features)
index_ids = np.arange(0, len(gallery_images)).astype(np.int64)
self.index.add_with_ids(features, index_ids)
self.id_map = dict()
for i, d in zip(list(index_ids), gallery_docs):
self.id_map[i] = d
self.features = {
"features": features,
"index_method": index_method,
"image_ids": image_ids,
"index_ids": index_ids.tolist()
}
self._save_index(self.index, self.id_map, self.features)
def open_index(self, root_path: str, image_list_path: str) -> str:
self._update_path(root_path)
_, _, image_ids = self._split_datafile(image_list_path, root_path)
if os.path.exists(os.path.join(self.root_path, self.index_path)) and \
os.path.exists(os.path.join(self.root_path, self.id_map_path)) and \
os.path.exists(os.path.join(self.root_path, self.features_path)):
self._update_path(root_path)
self._load_index()
if operator.eq(set(image_ids), set(self.features['image_ids'])):
return ""
else:
return "The image list is different from index, Please update index"
else:
return "File not exist: features.pkl, vector.index, id_map.pkl"
def update_index(self, image_list: str, image_root: str=None) -> str:
if self.index and self.id_map and self.features:
image_paths, image_docs, image_ids = self._split_datafile(
image_list, image_root
if image_root is not None else self.root_path)
# for add image
add_ids = list(
set(image_ids).difference(set(self.features["image_ids"])))
add_indexes = [i for i, x in enumerate(image_ids) if x in add_ids]
add_image_paths = [image_paths[i] for i in add_indexes]
add_image_docs = [image_docs[i] for i in add_indexes]
add_image_ids = [image_ids[i] for i in add_indexes]
self._add_index(add_image_paths, add_image_docs, add_image_ids)
# delete images
delete_ids = list(
set(self.features["image_ids"]).difference(set(image_ids)))
self._delete_index(delete_ids)
self._save_index(self.index, self.id_map, self.features)
return ""
else:
return "Failed. Please create or open index first"
def _add_index(self, image_list: List, image_docs: List, image_ids: List):
if len(image_ids) == 0:
return
featrures = self._cal_featrue(image_list)
index_ids = (
np.arange(0, len(image_list)) + max(self.id_map.keys()) + 1
).astype(np.int64)
self.index.add_with_ids(featrures, index_ids)
for i, d in zip(index_ids, image_docs):
self.id_map[i] = d
self.features['features'] = np.concatenate(
[self.features['features'], featrures], axis=0)
self.features['image_ids'].extend(image_ids)
self.features['index_ids'].extend(index_ids.tolist())
def _delete_index(self, image_ids: List):
if len(image_ids) == 0:
return
indexes = [
i for i, x in enumerate(self.features['image_ids'])
if x in image_ids
]
self.features["features"] = np.delete(
self.features["features"], indexes, axis=0)
self.features["image_ids"] = np.delete(
np.asarray(self.features["image_ids"]), indexes, axis=0).tolist()
index_ids = np.delete(
np.asarray(self.features["index_ids"]), indexes, axis=0).tolist()
id_map_values = [self.id_map[i] for i in index_ids]
self.index.reset()
ids = np.arange(0, len(id_map_values)).astype(np.int64)
self.index.add_with_ids(self.features['features'], ids)
self.id_map.clear()
for i, d in zip(ids, id_map_values):
self.id_map[i] = d
self.features["index_ids"] = ids
app = FastAPI()
@app.get("/new_index")
def new_index(image_list_path: str,
index_method: str="HNSW32",
index_root_path: str=None,
force: bool=False):
result = ""
try:
if index_root_path is not None:
image_list_path = os.path.join(index_root_path, image_list_path)
index_path = os.path.join(index_root_path, "index", "vector.index")
id_map_path = os.path.join(index_root_path, "index", "id_map.pkl")
if not (os.path.exists(index_path) and
os.path.exists(id_map_path)) or force:
manager.create_index(image_list_path, index_method,
index_root_path)
else:
result = "There alrealy has index in {}".format(index_root_path)
except Exception as e:
result = e.__str__()
data = {"error_message": result}
return json.dumps(data).encode()
@app.get("/open_index")
def open_index(index_root_path: str, image_list_path: str):
result = ""
try:
image_list_path = os.path.join(index_root_path, image_list_path)
result = manager.open_index(index_root_path, image_list_path)
except Exception as e:
result = e.__str__()
data = {"error_message": result}
return json.dumps(data).encode()
@app.get("/update_index")
def update_index(image_list_path: str, index_root_path: str=None):
result = ""
try:
if index_root_path is not None:
image_list_path = os.path.join(index_root_path, image_list_path)
result = manager.update_index(
image_list=image_list_path, image_root=index_root_path)
except Exception as e:
result = e.__str__()
data = {"error_message": result}
return json.dumps(data).encode()
def FrontInterface(server_process=None):
front = QtWidgets.QApplication([])
main_window = mod.mainwindow.MainWindow(process=server_process)
main_window.showMaximized()
sys.exit(front.exec_())
def Server(app, host, port):
uvicorn.run(app, host=host, port=port)
if __name__ == '__main__':
args = config.parse_args()
model_config = config.get_config(
args.config, overrides=args.override, show=True)
manager = ShiTuIndexManager(model_config)
ip = model_config.get('ip', None)
port = model_config.get('port', None)
if ip is None or port is None:
try:
ip = socket.gethostbyname(socket.gethostname())
except:
ip = '127.0.0.1'
port = 8000
Server(app, ip, port)
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
- [2. 使用说明](#2) - [2. 使用说明](#2)
- [2.1 环境安装](#2.1) - [2.1 环境安装](#2.1)
- [2.2 模型准备](#2.2) - [2.2 模型及数据准备](#2.2)
- [2.3运行使用](#2.3) - [2.3运行使用](#2.3)
- [3.生成文件介绍](#3) - [3.生成文件介绍](#3)
...@@ -123,13 +123,25 @@ ...@@ -123,13 +123,25 @@
pip install fastapi pip install fastapi
pip install uvicorn pip install uvicorn
pip install pyqt5 pip install pyqt5
pip install psutil
``` ```
<a name="2.2"></a> <a name="2.2"></a>
### 2.2 模型准备 ### 2.2 模型及数据准备
请按照[PP-ShiTu快速体验](../quick_start/quick_start_recognition.md#2.2.1)中下载及准备inference model,并修改好`${PaddleClas}/deploy/configs/inference_drink.yaml`的相关参数。 请按照[PP-ShiTu快速体验](../quick_start/quick_start_recognition.md#2.2.1)中下载及准备inference model,并修改好`${PaddleClas}/deploy/configs/inference_drink.yaml`的相关参数,同时准备好数据集。在具体使用时,请替换好自己的数据集及模型文件。
```shell
cd ${PaddleClas}/deploy/shitu_index_manager
mkdir models
cd models
# 下载及解压识别模型
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/PP-ShiTuV2/general_PPLCNetV2_base_pretrained_v1.0_infer.tar && tar -xf general_PPLCNetV2_base_pretrained_v1.0_infer.tar
cd ..
# 下载及解压示例数据集
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_dataset_v2.0.tar && tar -xf drink_dataset_v2.0.tar
```
<a name="2.3"></a> <a name="2.3"></a>
...@@ -139,9 +151,26 @@ pip install pyqt5 ...@@ -139,9 +151,26 @@ pip install pyqt5
```shell ```shell
cd ${PaddleClas}/deploy/shitu_index_manager cd ${PaddleClas}/deploy/shitu_index_manager
python index_manager.py -c ../configs/inference_drink.yaml cp ../configs/inference_drink.yaml .
# 注意如果没有按照2.2中准备数据集及代码,请手动修改inference_drink.yaml,做好适配
python index_manager.py -c inference_drink.yaml
``` ```
运行成功后,会自动跳转到工具界面,可以按照如下步骤,生成新的index库。
1. 点击菜单栏`新建图像库`,会提示打开一个文件夹,此时请创建一个**新的文件夹**,并打开。如在`${PaddleClas}/deploy/shitu_index_manager`下新建一个`drink_index`文件夹
2. 导入图像,或者如上面功能介绍,自己手动新增类别和相应的图像,下面介绍两种导入图像方式,操作时,二选一即可。
- 点击`导入图像`->`导入image_list图像`,打开`${PaddleClas}/deploy/shitu_index_manager/drink_dataset_v2.0/gallery/drink_label.txt`,此时就可以将`drink_label.txt`中的图像全部导入进来,图像类别就是`drink_label.txt`中记录的类别。
- 点击`导入图像`->`导入多文件夹图像`,打开`${PaddleClas}/deploy/shitu_index_manager/drink_dataset_v2.0/gallery/`文件夹,此时就将`gallery`文件夹下,所有子文件夹都导入进来,图像类别就是子文件夹的名字。
3. 点击菜单栏中`新建/重建 索引库`,此时就会开始生成索引库。如果图片较多或者使用cpu来进行特征提取,那么耗时会比较长,请耐心等待。
4. 生成索引库成功后,会发现在`drink_index`文件夹下生成如[3](#3) 中介绍的文件,此时`index`子文件夹下生出的文件,就是`PP-ShiTu`所使用的索引文件。
**注意**
- 利用此工具生成的index库,如`drink_index`文件夹,请妥善存储。之后,可以继续使用此工具中`打开图像库`功能,打开`drink_index`文件夹,继续对index库进行增删改查操作,具体功能可以查看[功能介绍](#1)
- 打开一个生成好的库,在其上面进行增删改查操作后,请及时保存。保存后并及时使用菜单中`更新索引库`功能,对索引库进行更新
- 如果要使用自己的图像库文件,图像生成格式如示例数据格式,生成`image_list.txt`或者多文件夹存储,二选一。
<a name="3"></a> <a name="3"></a>
## 3. 生成文件介绍 ## 3. 生成文件介绍
...@@ -192,4 +221,3 @@ index_root/ # 库存储目录 ...@@ -192,4 +221,3 @@ index_root/ # 库存储目录
- 问题4: 报错 图像与index库不一致 - 问题4: 报错 图像与index库不一致
答:可能用户自己修改了image_list.txt,修改完成后,请及时更新index库,保证其一致。 答:可能用户自己修改了image_list.txt,修改完成后,请及时更新index库,保证其一致。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册