diff --git a/deploy/shitu_index_manager/client.py b/deploy/shitu_index_manager/client.py
new file mode 100644
index 0000000000000000000000000000000000000000..229691a06df14a90fa8f5497d0a4396024521e5f
--- /dev/null
+++ b/deploy/shitu_index_manager/client.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+from PyQt5 import QtCore, QtGui, QtWidgets
+import mod.mainwindow
+"""
+完整的index库如下:
+root_path/ # 库存储目录
+|-- image_list.txt # 图像列表,每行:image_path label。由前端生成及修改。后端只读
+|-- features.pkl # 建库之后,保存的embedding向量,后端生成,前端无需操作
+|-- images # 图像存储目录,由前端生成及增删查等操作。后端只读
+| |-- md5.jpg
+| |-- md5.jpg
+| |-- ……
+|-- index # 真正的生成的index库存储目录,后端生成及操作,前端无需操作。
+| |-- vector.index # faiss生成的索引库
+| |-- id_map.pkl # 索引文件
+"""
+
+
+def FrontInterface(server_ip=None, server_port=None):
+    """Show the main window and run the Qt event loop until it finishes.
+
+    Args:
+        server_ip: forwarded to ``MainWindow`` as ``ip``.
+        server_port: forwarded to ``MainWindow`` as ``port``.
+
+    This function does not return: the interpreter exits with the value
+    returned by the Qt event loop.
+    """
+    qt_app = QtWidgets.QApplication([])
+    window = mod.mainwindow.MainWindow(ip=server_ip, port=server_port)
+    window.showMaximized()
+    exit_code = qt_app.exec_()
+    sys.exit(exit_code)
+
+
+if __name__ == '__main__':
+    # The server endpoint may be passed as ONE argument of the form
+    # "<ip> <port>"; anything else leaves both values as None.
+    server_ip, server_port = None, None
+    if len(sys.argv) == 2:
+        endpoint = sys.argv[1].split(' ')
+        if len(endpoint) == 2:
+            server_ip, server_port = endpoint
+    FrontInterface(server_ip, server_port)
diff --git a/deploy/shitu_index_manager/index_manager.py b/deploy/shitu_index_manager/index_manager.py
index 97e3eec561cf7a45476bd750624d721dfd85fdb9..bfcfc1369711fef841e11018fb101629e32c23f1 100644
--- a/deploy/shitu_index_manager/index_manager.py
+++ b/deploy/shitu_index_manager/index_manager.py
@@ -13,22 +13,10 @@
# limitations under the License.
import os
import sys
-from PyQt5 import QtCore, QtGui, QtWidgets
-import mod.mainwindow
-
-from paddleclas.deploy.utils import config, logger
-from paddleclas.deploy.python.predict_rec import RecPredictor
-from fastapi import FastAPI
-import uvicorn
-import numpy as np
-import faiss
-from typing import List
-import pickle
-import cv2
-import socket
-import json
-import operator
-from multiprocessing import Process
+import shlex
+import socket
+import subprocess
+import time
+
+import psutil
"""
完整的index库如下:
root_path/ # 库存储目录
@@ -43,307 +31,34 @@ root_path/ # 库存储目录
| |-- id_map.pkl # 索引文件
"""
-
-class ShiTuIndexManager(object):
-
- def __init__(self, config):
- self.root_path = None
- self.image_list_path = "image_list.txt"
- self.image_dir = "images"
- self.index_path = "index/vector.index"
- self.id_map_path = "index/id_map.pkl"
- self.features_path = "features.pkl"
- self.index = None
- self.id_map = None
- self.features = None
- self.config = config
- self.predictor = RecPredictor(config)
-
- def _load_pickle(self, path):
- if os.path.exists(path):
- return pickle.load(open(path, 'rb'))
- else:
- return None
-
- def _save_pickle(self, path, data):
- if not os.path.exists(os.path.dirname(path)):
- os.makedirs(os.path.dirname(path), exist_ok=True)
- with open(path, 'wb') as fd:
- pickle.dump(data, fd)
-
- def _load_index(self):
- self.index = faiss.read_index(
- os.path.join(self.root_path, self.index_path))
- self.id_map = self._load_pickle(
- os.path.join(self.root_path, self.id_map_path))
- self.features = self._load_pickle(
- os.path.join(self.root_path, self.features_path))
-
- def _save_index(self, index, id_map, features):
- faiss.write_index(index, os.path.join(self.root_path, self.index_path))
- self._save_pickle(os.path.join(self.root_path, self.id_map_path),
- id_map)
- self._save_pickle(os.path.join(self.root_path, self.features_path),
- features)
-
- def _update_path(self, root_path, image_list_path=None):
- if root_path == self.root_path:
- pass
- else:
- self.root_path = root_path
- if not os.path.exists(os.path.join(root_path, "index")):
- os.mkdir(os.path.join(root_path, "index"))
- if image_list_path is not None:
- self.image_list_path = image_list_path
-
- def _cal_featrue(self, image_list):
- batch_images = []
- featrures = None
- cnt = 0
- for idx, image_path in enumerate(image_list):
- image = cv2.imread(image_path)
- if image is None:
- return "{} is broken or not exist. Stop"
- else:
- image = image[:, :, ::-1]
- batch_images.append(image)
- cnt += 1
- if cnt % self.config["Global"]["batch_size"] == 0 or (
- idx + 1) == len(image_list):
- if len(batch_images) == 0:
- continue
- batch_results = self.predictor.predict(batch_images)
- featrures = batch_results if featrures is None else np.concatenate(
- (featrures, batch_results), axis=0)
- batch_images = []
- return featrures
-
- def _split_datafile(self, data_file, image_root):
- '''
- data_file: image path and info, which can be splitted by spacer
- image_root: image path root
- delimiter: delimiter
- '''
- gallery_images = []
- gallery_docs = []
- gallery_ids = []
- with open(data_file, 'r', encoding='utf-8') as f:
- lines = f.readlines()
- for _, ori_line in enumerate(lines):
- line = ori_line.strip().split()
- text_num = len(line)
- assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"
- image_file = os.path.join(image_root, line[0])
-
- gallery_images.append(image_file)
- gallery_docs.append(ori_line.strip())
- gallery_ids.append(os.path.basename(line[0]).split(".")[0])
-
- return gallery_images, gallery_docs, gallery_ids
-
- def create_index(self,
- image_list: str,
- index_method: str = "HNSW32",
- image_root: str = None):
- if not os.path.exists(image_list):
- return "{} is not exist".format(image_list)
- if index_method.lower() not in ['hnsw32', 'ivf', 'flat']:
- return "The index method Only support: HNSW32, IVF, Flat"
- self._update_path(os.path.dirname(image_list), image_list)
-
- # get image_paths
- image_root = image_root if image_root is not None else self.root_path
- gallery_images, gallery_docs, image_ids = self._split_datafile(
- image_list, image_root)
-
- # gernerate index
- if index_method == "IVF":
- index_method = index_method + str(
- min(max(int(len(gallery_images) // 32), 2), 65536)) + ",Flat"
- index = faiss.index_factory(
- self.config["IndexProcess"]["embedding_size"], index_method,
- faiss.METRIC_INNER_PRODUCT)
- self.index = faiss.IndexIDMap2(index)
- features = self._cal_featrue(gallery_images)
- self.index.train(features)
- index_ids = np.arange(0, len(gallery_images)).astype(np.int64)
- self.index.add_with_ids(features, index_ids)
-
- self.id_map = dict()
- for i, d in zip(list(index_ids), gallery_docs):
- self.id_map[i] = d
-
- self.features = {
- "features": features,
- "index_method": index_method,
- "image_ids": image_ids,
- "index_ids": index_ids.tolist()
- }
- self._save_index(self.index, self.id_map, self.features)
-
- def open_index(self, root_path: str, image_list_path: str) -> str:
- self._update_path(root_path)
- _, _, image_ids = self._split_datafile(image_list_path, root_path)
- if os.path.exists(os.path.join(self.root_path, self.index_path)) and \
- os.path.exists(os.path.join(self.root_path, self.id_map_path)) and \
- os.path.exists(os.path.join(self.root_path, self.features_path)):
- self._update_path(root_path)
- self._load_index()
- if operator.eq(set(image_ids), set(self.features['image_ids'])):
- return ""
- else:
- return "The image list is different from index, Please update index"
- else:
- return "File not exist: features.pkl, vector.index, id_map.pkl"
-
- def update_index(self, image_list: str, image_root: str = None) -> str:
- if self.index and self.id_map and self.features:
- image_paths, image_docs, image_ids = self._split_datafile(
- image_list,
- image_root if image_root is not None else self.root_path)
-
- # for add image
- add_ids = list(
- set(image_ids).difference(set(self.features["image_ids"])))
- add_indexes = [i for i, x in enumerate(image_ids) if x in add_ids]
- add_image_paths = [image_paths[i] for i in add_indexes]
- add_image_docs = [image_docs[i] for i in add_indexes]
- add_image_ids = [image_ids[i] for i in add_indexes]
- self._add_index(add_image_paths, add_image_docs, add_image_ids)
-
- # delete images
- delete_ids = list(
- set(self.features["image_ids"]).difference(set(image_ids)))
- self._delete_index(delete_ids)
- self._save_index(self.index, self.id_map, self.features)
- return ""
- else:
- return "Failed. Please create or open index first"
-
- def _add_index(self, image_list: List, image_docs: List, image_ids: List):
- if len(image_ids) == 0:
- return
- featrures = self._cal_featrue(image_list)
- index_ids = (np.arange(0, len(image_list)) + max(self.id_map.keys()) +
- 1).astype(np.int64)
- self.index.add_with_ids(featrures, index_ids)
-
- for i, d in zip(index_ids, image_docs):
- self.id_map[i] = d
-
- self.features['features'] = np.concatenate(
- [self.features['features'], featrures], axis=0)
- self.features['image_ids'].extend(image_ids)
- self.features['index_ids'].extend(index_ids.tolist())
-
- def _delete_index(self, image_ids: List):
- if len(image_ids) == 0:
- return
- indexes = [
- i for i, x in enumerate(self.features['image_ids'])
- if x in image_ids
- ]
- self.features["features"] = np.delete(self.features["features"],
- indexes,
- axis=0)
- self.features["image_ids"] = np.delete(np.asarray(
- self.features["image_ids"]),
- indexes,
- axis=0).tolist()
- index_ids = np.delete(np.asarray(self.features["index_ids"]),
- indexes,
- axis=0).tolist()
- id_map_values = [self.id_map[i] for i in index_ids]
- self.index.reset()
- ids = np.arange(0, len(id_map_values)).astype(np.int64)
- self.index.add_with_ids(self.features['features'], ids)
- self.id_map.clear()
- for i, d in zip(ids, id_map_values):
- self.id_map[i] = d
- self.features["index_ids"] = ids
-
-
-app = FastAPI()
-
-
-@app.get("/new_index")
-def new_index(image_list_path: str,
- index_method: str = "HNSW32",
- index_root_path: str = None,
- force: bool = False):
- result = ""
- try:
- if index_root_path is not None:
- image_list_path = os.path.join(index_root_path, image_list_path)
- index_path = os.path.join(index_root_path, "index", "vector.index")
- id_map_path = os.path.join(index_root_path, "index", "id_map.pkl")
-
- if not (os.path.exists(index_path)
- and os.path.exists(id_map_path)) or force:
- manager.create_index(image_list_path, index_method, index_root_path)
- else:
- result = "There alrealy has index in {}".format(index_root_path)
- except Exception as e:
- result = e.__str__()
- data = {"error_message": result}
- return json.dumps(data).encode()
-
-
-@app.get("/open_index")
-def open_index(index_root_path: str, image_list_path: str):
- result = ""
- try:
- image_list_path = os.path.join(index_root_path, image_list_path)
- result = manager.open_index(index_root_path, image_list_path)
- except Exception as e:
- result = e.__str__()
-
- data = {"error_message": result}
- return json.dumps(data).encode()
-
-
-@app.get("/update_index")
-def update_index(image_list_path: str, index_root_path: str = None):
- result = ""
- try:
- if index_root_path is not None:
- image_list_path = os.path.join(index_root_path, image_list_path)
- result = manager.update_index(image_list=image_list_path,
- image_root=index_root_path)
- except Exception as e:
- result = e.__str__()
- data = {"error_message": result}
- return json.dumps(data).encode()
-
-
-def FrontInterface(server_process=None):
- front = QtWidgets.QApplication([])
- main_window = mod.mainwindow.MainWindow(process=server_process)
- main_window.showMaximized()
- sys.exit(front.exec_())
-
-
-def Server(args):
- [app, host, port] = args
- uvicorn.run(app, host=host, port=port)
-
-
if __name__ == '__main__':
- args = config.parse_args()
- model_config = config.get_config(args.config,
- overrides=args.override,
- show=True)
- manager = ShiTuIndexManager(model_config)
+    # Expected command lines (only the argument count is checked):
+    #   python index_manager.py -c xxx.yaml
+    #   python index_manager.py -c xxx.yaml -p port
+    if len(sys.argv) not in (3, 5):
+        print("start example:")
+        print(" python index_manager.py -c xxx.yaml")
+        print(" python index_manager.py -c xxx.yaml -p port")
+        # Stop here; continuing would fail on sys.argv[2] below.
+        sys.exit(1)
+    yaml_path = sys.argv[2]
+    try:
+        port = int(sys.argv[4]) if len(sys.argv) == 5 else 8000
+    except ValueError:
+        sys.exit("The port should be an integer, but got {!r}".format(
+            sys.argv[4]))
+    # Explicit check instead of assert, which is removed under "python -O".
+    if not 1024 < port < 65536:
+        sys.exit("The port should be bigger than 1024 and smaller than 65536")
+
try:
ip = socket.gethostbyname(socket.gethostname())
except:
ip = '127.0.0.1'
- port = 8000
- p_server = Process(target=Server, args=([app, ip, port],))
- p_server.start()
- # p_client = Process(target=FrontInterface, args=())
- # p_client.start()
- # p_client.join()
- FrontInterface(p_server)
- p_server.terminate()
- sys.exit(0)
+    # Start both children with the interpreter that runs this script, and
+    # pass arguments as lists so a yaml path containing spaces or
+    # backslashes is handed over unchanged.
+    python_bin = sys.executable or "python"
+    server_proc = subprocess.Popen([
+        python_bin, "server.py", "-c", yaml_path, "-o", "ip={}".format(ip),
+        "-o", "port={}".format(port)
+    ])
+    client_proc = None
+    try:
+        # client.py expects "<ip> <port>" as a single argument.
+        client_proc = subprocess.Popen(
+            [python_bin, "client.py", "{} {}".format(ip, port)])
+        # Block until the GUI exits. Polling the psutil status for
+        # "running" is not reliable: an idle process is typically reported
+        # as "sleeping", which would end the loop while the GUI is open.
+        client_proc.wait()
+    except KeyboardInterrupt:
+        pass
+    finally:
+        # Always stop whatever is still alive, the server in particular.
+        for proc in (client_proc, server_proc):
+            if proc is not None and proc.poll() is None:
+                proc.terminate()
diff --git a/deploy/shitu_index_manager/mod/mainwindow.py b/deploy/shitu_index_manager/mod/mainwindow.py
index 40d11f6c480619b537cb0c738e99ede89a8fe50c..879161e4250d8b4a3f5cb4a131517ddbb4c693c3 100644
--- a/deploy/shitu_index_manager/mod/mainwindow.py
+++ b/deploy/shitu_index_manager/mod/mainwindow.py
@@ -22,8 +22,6 @@ try:
DEFAULT_HOST = socket.gethostbyname(socket.gethostname())
except:
DEFAULT_HOST = '127.0.0.1'
-
-# DEFAULT_HOST = "localhost"
DEFAULT_PORT = 8000
PADDLECLAS_DOC_URL = "https://gitee.com/paddlepaddle/PaddleClas/docs/zh_CN/inference_deployment/shitu_gallery_manager.md"
@@ -35,12 +33,17 @@ class MainWindow(QtWidgets.QMainWindow):
updateIndexMsg = QtCore.pyqtSignal(str) # 更新索引库线程信号
importImageCount = QtCore.pyqtSignal(int) # 导入图像数量信号
- def __init__(self, process=None):
+ def __init__(self, ip=None, port=None):
super(MainWindow, self).__init__()
- self.server_process = process
+ if ip is not None and port is not None:
+ self.server_ip = ip
+ self.server_port = port
+ else:
+ self.server_ip = DEFAULT_HOST
+ self.server_port = DEFAULT_PORT
+
self.ui = ui_mainwindow.Ui_MainWindow()
self.ui.setupUi(self) # 初始化主窗口界面
-
self.__imageListMgr = image_list_manager.ImageListManager()
self.__appMenu = QtWidgets.QMenu() # 应用菜单
@@ -115,8 +118,7 @@ class MainWindow(QtWidgets.QMainWindow):
self.ui.saveImageLibraryBtn.clicked.connect(self.saveImageLibrary)
self.__setToolButton(self.ui.addClassifyBtn, "添加分类",
- "./resource/add_classify.png",
- TOOL_BTN_ICON_SIZE)
+ "./resource/add_classify.png", TOOL_BTN_ICON_SIZE)
self.ui.addClassifyBtn.clicked.connect(
self.__classifyUiContext.addClassify)
@@ -145,7 +147,10 @@ class MainWindow(QtWidgets.QMainWindow):
self.ui.searchClassifyHistoryCmb.setToolTip("查找分类历史")
self.ui.imageScaleSlider.setToolTip("图片缩放")
- def __setToolButton(self, button, tool_tip: str, icon_path: str,
+ def __setToolButton(self,
+ button,
+ tool_tip: str,
+ icon_path: str,
icon_size: int):
"""设置工具按钮"""
button.setToolTip(tool_tip)
@@ -160,9 +165,9 @@ class MainWindow(QtWidgets.QMainWindow):
self.__libraryAppendMenu.setTitle("导入图像")
utils.setMenu(self.__libraryAppendMenu, "导入 image_list 图像",
- self.importImageListImage)
+ self.importImageListImage)
utils.setMenu(self.__libraryAppendMenu, "导入多文件夹图像",
- self.importDirsImage)
+ self.importDirsImage)
self.__appMenu.addMenu(self.__libraryAppendMenu)
self.__appMenu.addSeparator()
@@ -179,16 +184,16 @@ class MainWindow(QtWidgets.QMainWindow):
def __initWaitDialog(self):
"""初始化等待对话框"""
self.__waitDialogUi.setupUi(self.__waitDialog)
- self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog
- | QtCore.Qt.FramelessWindowHint)
+ self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog |
+ QtCore.Qt.FramelessWindowHint)
def __startWait(self, msg: str):
"""开始显示等待对话框"""
self.setEnabled(False)
self.__waitDialogUi.msgLabel.setText(msg)
- self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog
- | QtCore.Qt.FramelessWindowHint
- | QtCore.Qt.WindowStaysOnTopHint)
+ self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog |
+ QtCore.Qt.FramelessWindowHint |
+ QtCore.Qt.WindowStaysOnTopHint)
self.__waitDialog.show()
self.__waitDialog.repaint()
@@ -196,9 +201,9 @@ class MainWindow(QtWidgets.QMainWindow):
"""停止显示等待对话框"""
self.setEnabled(True)
self.__waitDialogUi.msgLabel.setText("执行完毕!")
- self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog
- | QtCore.Qt.FramelessWindowHint
- | QtCore.Qt.CustomizeWindowHint)
+ self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog |
+ QtCore.Qt.FramelessWindowHint |
+ QtCore.Qt.CustomizeWindowHint)
self.__waitDialog.close()
def __connectSignal(self):
@@ -290,8 +295,8 @@ class MainWindow(QtWidgets.QMainWindow):
def __importImageListImageThread(self, from_path: str, to_path: str):
"""导入 image_list 图像 线程"""
- count = utils.oneKeyImportFromFile(from_path=from_path,
- to_path=to_path)
+ count = utils.oneKeyImportFromFile(
+ from_path=from_path, to_path=to_path)
if count == None:
count = -1
self.importImageCount.emit(count)
@@ -308,9 +313,9 @@ class MainWindow(QtWidgets.QMainWindow):
return
from_mgr = image_list_manager.ImageListManager(from_path)
self.__startWait("正在导入图像,请等待。。。")
- thread = threading.Thread(target=self.__importImageListImageThread,
- args=(from_mgr.filePath,
- self.__imageListMgr.filePath))
+ thread = threading.Thread(
+ target=self.__importImageListImageThread,
+ args=(from_mgr.filePath, self.__imageListMgr.filePath))
thread.start()
def __importDirsImageThread(self, from_dir: str, to_image_list_path: str):
@@ -333,21 +338,25 @@ class MainWindow(QtWidgets.QMainWindow):
QtWidgets.QMessageBox.information(self, "提示", "打开的目录不存在")
return
self.__startWait("正在导入图像,请等待。。。")
- thread = threading.Thread(target=self.__importDirsImageThread,
- args=(dir_path,
- self.__imageListMgr.filePath))
+ thread = threading.Thread(
+ target=self.__importDirsImageThread,
+ args=(dir_path, self.__imageListMgr.filePath))
thread.start()
- def __newIndexThread(self, index_root_path: str, image_list_path: str,
- index_method: str, force: bool):
+ def __newIndexThread(self,
+ index_root_path: str,
+ image_list_path: str,
+ index_method: str,
+ force: bool):
"""新建重建索引库线程"""
try:
- client = index_http_client.IndexHttpClient(
- DEFAULT_HOST, DEFAULT_PORT)
- err_msg = client.new_index(image_list_path=image_list_path,
- index_root_path=index_root_path,
- index_method=index_method,
- force=force)
+ client = index_http_client.IndexHttpClient(self.server_ip,
+ self.server_port)
+ err_msg = client.new_index(
+ image_list_path=image_list_path,
+ index_root_path=index_root_path,
+ index_method=index_method,
+ force=force)
if err_msg == None:
err_msg = ""
self.newIndexMsg.emit(err_msg)
@@ -375,19 +384,20 @@ class MainWindow(QtWidgets.QMainWindow):
force = ui.resetCheckBox.isChecked()
if result == QtWidgets.QDialog.Accepted:
self.__startWait("正在 新建/重建 索引库,请等待。。。")
- thread = threading.Thread(target=self.__newIndexThread,
- args=(self.__imageListMgr.dirName,
- "image_list.txt", index_method,
- force))
+ thread = threading.Thread(
+ target=self.__newIndexThread,
+ args=(self.__imageListMgr.dirName, "image_list.txt",
+ index_method, force))
thread.start()
def __openIndexThread(self, index_root_path: str, image_list_path: str):
"""打开索引库线程"""
try:
- client = index_http_client.IndexHttpClient(
- DEFAULT_HOST, DEFAULT_PORT)
- err_msg = client.open_index(index_root_path=index_root_path,
- image_list_path=image_list_path)
+ client = index_http_client.IndexHttpClient(self.server_ip,
+ self.server_port)
+ err_msg = client.open_index(
+ index_root_path=index_root_path,
+ image_list_path=image_list_path)
if err_msg == None:
err_msg = ""
self.openIndexMsg.emit(err_msg)
@@ -408,18 +418,19 @@ class MainWindow(QtWidgets.QMainWindow):
QtWidgets.QMessageBox.information(self, "提示", "请先打开正确的图像库")
return
self.__startWait("正在打开索引库,请等待。。。")
- thread = threading.Thread(target=self.__openIndexThread,
- args=(self.__imageListMgr.dirName,
- "image_list.txt"))
+ thread = threading.Thread(
+ target=self.__openIndexThread,
+ args=(self.__imageListMgr.dirName, "image_list.txt"))
thread.start()
def __updateIndexThread(self, index_root_path: str, image_list_path: str):
"""更新索引库线程"""
try:
- client = index_http_client.IndexHttpClient(
- DEFAULT_HOST, DEFAULT_PORT)
- err_msg = client.update_index(image_list_path=image_list_path,
- index_root_path=index_root_path)
+ client = index_http_client.IndexHttpClient(self.server_ip,
+ self.server_port)
+ err_msg = client.update_index(
+ image_list_path=image_list_path,
+ index_root_path=index_root_path)
if err_msg == None:
err_msg = ""
self.updateIndexMsg.emit(err_msg)
@@ -440,9 +451,9 @@ class MainWindow(QtWidgets.QMainWindow):
QtWidgets.QMessageBox.information(self, "提示", "请先打开正确的图像库")
return
self.__startWait("正在更新索引库,请等待。。。")
- thread = threading.Thread(target=self.__updateIndexThread,
- args=(self.__imageListMgr.dirName,
- "image_list.txt"))
+ thread = threading.Thread(
+ target=self.__updateIndexThread,
+ args=(self.__imageListMgr.dirName, "image_list.txt"))
thread.start()
def searchClassify(self):
@@ -471,9 +482,6 @@ class MainWindow(QtWidgets.QMainWindow):
def exitApp(self):
"""退出应用"""
- if isinstance(self.server_process, Process):
- self.server_process.terminate()
- # os.kill(self.server_pid)
sys.exit(0)
def __setPathBar(self, msg: str):
diff --git a/deploy/shitu_index_manager/server.py b/deploy/shitu_index_manager/server.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed57208560262036a4ef6e15260dea65d2954fae
--- /dev/null
+++ b/deploy/shitu_index_manager/server.py
@@ -0,0 +1,340 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+from PyQt5 import QtCore, QtGui, QtWidgets
+import mod.mainwindow
+
+from paddleclas.deploy.utils import config, logger
+from paddleclas.deploy.python.predict_rec import RecPredictor
+from fastapi import FastAPI
+import uvicorn
+import numpy as np
+import faiss
+from typing import List
+import pickle
+import cv2
+import socket
+import json
+import operator
+from multiprocessing import Process
+"""
+完整的index库如下:
+root_path/ # 库存储目录
+|-- image_list.txt # 图像列表,每行:image_path label。由前端生成及修改。后端只读
+|-- features.pkl # 建库之后,保存的embedding向量,后端生成,前端无需操作
+|-- images # 图像存储目录,由前端生成及增删查等操作。后端只读
+| |-- md5.jpg
+| |-- md5.jpg
+| |-- ……
+|-- index # 真正的生成的index库存储目录,后端生成及操作,前端无需操作。
+| |-- vector.index # faiss生成的索引库
+| |-- id_map.pkl # 索引文件
+"""
+
+
+class ShiTuIndexManager(object):
+    """Create, open and incrementally update a faiss gallery index.
+
+    Files handled below ``root_path`` (relative paths set in ``__init__``):
+        image_list.txt      one "image_path label" entry per line
+        features.pkl        pickled dict with the embeddings and id lists
+        index/vector.index  index written with ``faiss.write_index``
+        index/id_map.pkl    pickled dict: index id -> image_list line
+    """
+
+    # Accepted index methods (matched case-insensitively) mapped to the
+    # spelling handed to ``faiss.index_factory``.
+    _INDEX_METHODS = {"hnsw32": "HNSW32", "ivf": "IVF", "flat": "Flat"}
+
+    def __init__(self, config):
+        """Keep ``config`` and build the ``RecPredictor`` from it."""
+        self.root_path = None
+        self.image_list_path = "image_list.txt"
+        self.image_dir = "images"
+        self.index_path = "index/vector.index"
+        self.id_map_path = "index/id_map.pkl"
+        self.features_path = "features.pkl"
+        self.index = None
+        self.id_map = None
+        self.features = None
+        self.config = config
+        self.predictor = RecPredictor(config)
+
+    def _load_pickle(self, path):
+        """Return the unpickled content of ``path``, or None if missing."""
+        if not os.path.exists(path):
+            return None
+        # NOTE: unpickling can execute arbitrary code; only load files that
+        # were written by this tool.
+        with open(path, 'rb') as fd:
+            return pickle.load(fd)
+
+    def _save_pickle(self, path, data):
+        """Pickle ``data`` to ``path``, creating the parent directory."""
+        parent = os.path.dirname(path)
+        if parent:
+            os.makedirs(parent, exist_ok=True)
+        with open(path, 'wb') as fd:
+            pickle.dump(data, fd)
+
+    def _load_index(self):
+        """Load index, id map and features from below ``root_path``."""
+        self.index = faiss.read_index(
+            os.path.join(self.root_path, self.index_path))
+        self.id_map = self._load_pickle(
+            os.path.join(self.root_path, self.id_map_path))
+        self.features = self._load_pickle(
+            os.path.join(self.root_path, self.features_path))
+
+    def _save_index(self, index, id_map, features):
+        """Write index, id map and features below ``root_path``."""
+        faiss.write_index(index, os.path.join(self.root_path, self.index_path))
+        self._save_pickle(
+            os.path.join(self.root_path, self.id_map_path), id_map)
+        self._save_pickle(
+            os.path.join(self.root_path, self.features_path), features)
+
+    def _update_path(self, root_path, image_list_path=None):
+        """Switch to ``root_path`` and ensure its ``index`` directory exists.
+
+        ``image_list_path`` replaces the stored list path when given.
+        """
+        self.root_path = root_path
+        # Checked on every call so a directory removed in the meantime is
+        # created again before the index is written.
+        index_dir = os.path.join(root_path, "index")
+        if not os.path.isdir(index_dir):
+            os.mkdir(index_dir)
+        if image_list_path is not None:
+            self.image_list_path = image_list_path
+
+    def _cal_featrue(self, image_list):
+        """Run the predictor over ``image_list`` in batches.
+
+        Returns:
+            The batch results concatenated along axis 0, or None when
+            ``image_list`` is empty.
+
+        Raises:
+            ValueError: if an image cannot be read by ``cv2.imread``.
+        """
+        batch_size = self.config["Global"]["batch_size"]
+        last_idx = len(image_list) - 1
+        batch_images = []
+        batch_features = []
+        for idx, image_path in enumerate(image_list):
+            image = cv2.imread(image_path)
+            if image is None:
+                # Raise with the offending path instead of returning a
+                # message string that callers would pass on to faiss.
+                raise ValueError("{} is broken or not exist. Stop".format(
+                    image_path))
+            # Reverse the channel axis of the array returned by cv2.
+            batch_images.append(image[:, :, ::-1])
+            if len(batch_images) >= batch_size or idx == last_idx:
+                batch_features.append(self.predictor.predict(batch_images))
+                batch_images = []
+        if not batch_features:
+            return None
+        # One concatenation at the end instead of one per batch.
+        return np.concatenate(batch_features, axis=0)
+
+    def _split_datafile(self, data_file, image_root):
+        '''
+        Parse an image list file.
+
+        data_file: text file, each line "image_path label ..." separated
+            by whitespace; blank lines are skipped
+        image_root: directory the image paths are joined onto
+
+        Returns (image paths, stripped lines, image ids), where an image id
+        is the file name up to its first ".".
+        Raises ValueError for a line with fewer than 2 fields.
+        '''
+        gallery_images = []
+        gallery_docs = []
+        gallery_ids = []
+        with open(data_file, 'r', encoding='utf-8') as f:
+            for ori_line in f:
+                line = ori_line.strip().split()
+                text_num = len(line)
+                if text_num == 0:
+                    continue
+                if text_num < 2:
+                    raise ValueError(
+                        f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"
+                    )
+                gallery_images.append(os.path.join(image_root, line[0]))
+                gallery_docs.append(ori_line.strip())
+                gallery_ids.append(os.path.basename(line[0]).split(".")[0])
+
+        return gallery_images, gallery_docs, gallery_ids
+
+    def create_index(self,
+                     image_list: str,
+                     index_method: str="HNSW32",
+                     image_root: str=None):
+        """Build a new index from ``image_list`` and save it.
+
+        Args:
+            image_list: path of the image list file; its directory becomes
+                ``root_path``.
+            index_method: "HNSW32", "IVF" or "Flat" (case-insensitive).
+            image_root: directory of the images, ``root_path`` when None.
+
+        Returns:
+            "" on success, otherwise a message describing the problem.
+        """
+        if not os.path.exists(image_list):
+            return "{} is not exist".format(image_list)
+        # Normalise the spelling so that e.g. "ivf" passes the check AND
+        # is understood by faiss.
+        index_method = self._INDEX_METHODS.get(index_method.lower())
+        if index_method is None:
+            return "The index method Only support: HNSW32, IVF, Flat"
+        self._update_path(os.path.dirname(image_list), image_list)
+
+        # get image_paths
+        image_root = image_root if image_root is not None else self.root_path
+        gallery_images, gallery_docs, image_ids = self._split_datafile(
+            image_list, image_root)
+        if not gallery_images:
+            return "{} contains no image".format(image_list)
+
+        # generate index
+        if index_method == "IVF":
+            index_method = index_method + str(
+                min(max(int(len(gallery_images) // 32), 2), 65536)) + ",Flat"
+        index = faiss.index_factory(
+            self.config["IndexProcess"]["embedding_size"], index_method,
+            faiss.METRIC_INNER_PRODUCT)
+        self.index = faiss.IndexIDMap2(index)
+        features = self._cal_featrue(gallery_images)
+        self.index.train(features)
+        index_ids = np.arange(0, len(gallery_images)).astype(np.int64)
+        self.index.add_with_ids(features, index_ids)
+
+        self.id_map = dict(zip(index_ids.tolist(), gallery_docs))
+        self.features = {
+            "features": features,
+            "index_method": index_method,
+            "image_ids": image_ids,
+            "index_ids": index_ids.tolist()
+        }
+        self._save_index(self.index, self.id_map, self.features)
+        return ""
+
+    def open_index(self, root_path: str, image_list_path: str) -> str:
+        """Load the index stored below ``root_path``.
+
+        Returns "" when the index was loaded and holds exactly the image
+        ids listed in ``image_list_path``, otherwise a message.
+        """
+        self._update_path(root_path)
+        _, _, image_ids = self._split_datafile(image_list_path, root_path)
+        required_files = (self.index_path, self.id_map_path,
+                          self.features_path)
+        if not all(
+                os.path.exists(os.path.join(self.root_path, name))
+                for name in required_files):
+            return "File not exist: features.pkl, vector.index, id_map.pkl"
+        self._load_index()
+        if set(image_ids) == set(self.features['image_ids']):
+            return ""
+        return "The image list is different from index, Please update index"
+
+    def update_index(self, image_list: str, image_root: str=None) -> str:
+        """Synchronise the opened index with ``image_list`` and save it.
+
+        Images listed but not yet indexed are added; indexed images that
+        are no longer listed are removed.
+
+        Returns "" on success, otherwise a message.
+        """
+        # Compare against None: an id_map emptied by earlier deletions is
+        # still a valid, opened index.
+        if self.index is None or self.id_map is None or self.features is None:
+            return "Failed. Please create or open index first"
+        image_paths, image_docs, image_ids = self._split_datafile(
+            image_list, image_root
+            if image_root is not None else self.root_path)
+
+        # for add image
+        indexed_ids = set(self.features["image_ids"])
+        listed_ids = set(image_ids)
+        add_indexes = [
+            i for i, x in enumerate(image_ids) if x not in indexed_ids
+        ]
+        self._add_index([image_paths[i] for i in add_indexes],
+                        [image_docs[i] for i in add_indexes],
+                        [image_ids[i] for i in add_indexes])
+
+        # delete images
+        self._delete_index(list(indexed_ids - listed_ids))
+        self._save_index(self.index, self.id_map, self.features)
+        return ""
+
+    def _add_index(self, image_list: List, image_docs: List, image_ids: List):
+        """Append the given images to the index, id map and features."""
+        if len(image_ids) == 0:
+            return
+        featrures = self._cal_featrue(image_list)
+        # New ids continue after the largest id in use (0 when empty).
+        first_id = max(self.id_map.keys(), default=-1) + 1
+        index_ids = (np.arange(0, len(image_list)) + first_id).astype(
+            np.int64)
+        self.index.add_with_ids(featrures, index_ids)
+
+        for i, d in zip(index_ids.tolist(), image_docs):
+            self.id_map[i] = d
+
+        self.features['features'] = np.concatenate(
+            [self.features['features'], featrures], axis=0)
+        self.features['image_ids'].extend(image_ids)
+        # index_ids may have been stored as an ndarray by an earlier
+        # version of _delete_index; convert before appending.
+        self.features['index_ids'] = np.asarray(self.features[
+            'index_ids']).tolist() + index_ids.tolist()
+
+    def _delete_index(self, image_ids: List):
+        """Remove the given image ids and renumber the remaining entries."""
+        if len(image_ids) == 0:
+            return
+        removed = set(image_ids)
+        indexes = [
+            i for i, x in enumerate(self.features['image_ids'])
+            if x in removed
+        ]
+        self.features["features"] = np.delete(
+            self.features["features"], indexes, axis=0)
+        self.features["image_ids"] = np.delete(
+            np.asarray(self.features["image_ids"]), indexes, axis=0).tolist()
+        index_ids = np.delete(
+            np.asarray(self.features["index_ids"]), indexes, axis=0).tolist()
+        id_map_values = [self.id_map[i] for i in index_ids]
+        # Rebuild the index content with consecutive ids 0..n-1.
+        self.index.reset()
+        ids = np.arange(0, len(id_map_values)).astype(np.int64)
+        self.index.add_with_ids(self.features['features'], ids)
+        self.id_map.clear()
+        for i, d in zip(ids.tolist(), id_map_values):
+            self.id_map[i] = d
+        # Keep a plain list so _add_index can append to it afterwards.
+        self.features["index_ids"] = ids.tolist()
+
+
+app = FastAPI()
+
+
+@app.get("/new_index")
+def new_index(image_list_path: str,
+              index_method: str="HNSW32",
+              index_root_path: str=None,
+              force: bool=False):
+    """Create an index unless one already exists (or ``force`` is set).
+
+    Args:
+        image_list_path: image list file, joined onto ``index_root_path``
+            when that is given.
+        index_method: forwarded to ``manager.create_index``.
+        index_root_path: library directory; when omitted, the directory of
+            ``image_list_path`` is checked for an existing index.
+        force: rebuild even if ``vector.index`` and ``id_map.pkl`` exist.
+
+    Returns:
+        UTF-8 encoded JSON ``{"error_message": ...}``; the message is ""
+        on success.
+    """
+    result = ""
+    try:
+        root = index_root_path
+        if root is not None:
+            image_list_path = os.path.join(root, image_list_path)
+        else:
+            root = os.path.dirname(image_list_path)
+        index_path = os.path.join(root, "index", "vector.index")
+        id_map_path = os.path.join(root, "index", "id_map.pkl")
+
+        has_index = os.path.exists(index_path) and os.path.exists(id_map_path)
+        if force or not has_index:
+            # Report the message returned by create_index; it may return
+            # None, which is mapped to the empty "no error" message.
+            result = manager.create_index(image_list_path, index_method,
+                                          index_root_path) or ""
+        else:
+            result = "There alrealy has index in {}".format(root)
+    except Exception as e:
+        result = str(e)
+    data = {"error_message": result}
+    return json.dumps(data).encode()
+
+
+@app.get("/open_index")
+def open_index(index_root_path: str, image_list_path: str):
+    """Open the index below ``index_root_path`` through ``manager``.
+
+    ``image_list_path`` is joined onto ``index_root_path`` first. Returns
+    UTF-8 encoded JSON ``{"error_message": ...}`` holding either the
+    manager's result or the text of a raised exception.
+    """
+    try:
+        list_file = os.path.join(index_root_path, image_list_path)
+        message = manager.open_index(index_root_path, list_file)
+    except Exception as err:
+        message = str(err)
+
+    return json.dumps({"error_message": message}).encode()
+
+
+@app.get("/update_index")
+def update_index(image_list_path: str, index_root_path: str=None):
+    """Update the opened index from the image list through ``manager``.
+
+    ``image_list_path`` is joined onto ``index_root_path`` when the latter
+    is given. Returns UTF-8 encoded JSON ``{"error_message": ...}`` holding
+    either the manager's result or the text of a raised exception.
+    """
+    try:
+        list_file = image_list_path
+        if index_root_path is not None:
+            list_file = os.path.join(index_root_path, image_list_path)
+        message = manager.update_index(
+            image_list=list_file, image_root=index_root_path)
+    except Exception as err:
+        message = str(err)
+    return json.dumps({"error_message": message}).encode()
+
+
+def FrontInterface(server_process=None):
+    """Show the main window and run the Qt event loop until it finishes.
+
+    Args:
+        server_process: accepted for backward compatibility and ignored;
+            ``MainWindow`` takes ``ip``/``port`` and no longer accepts a
+            ``process`` argument, so it is created with its defaults.
+
+    This function does not return: the interpreter exits with the value
+    returned by the Qt event loop.
+    """
+    front = QtWidgets.QApplication([])
+    main_window = mod.mainwindow.MainWindow()
+    main_window.showMaximized()
+    sys.exit(front.exec_())
+
+
+def Server(app, host, port):
+    """Serve ``app`` with uvicorn on ``host``:``port``."""
+    listen_options = {"host": host, "port": port}
+    uvicorn.run(app, **listen_options)
+
+
+if __name__ == '__main__':
+    args = config.parse_args()
+    model_config = config.get_config(
+        args.config, overrides=args.override, show=True)
+    manager = ShiTuIndexManager(model_config)
+    # Each setting falls back on its own, so a configured port is kept
+    # when only the ip is missing (and the other way round).
+    ip = model_config.get('ip', None)
+    port = model_config.get('port', None)
+    if ip is None:
+        try:
+            ip = socket.gethostbyname(socket.gethostname())
+        except OSError:
+            # Host name could not be resolved; use the loopback address.
+            ip = '127.0.0.1'
+    if port is None:
+        port = 8000
+    Server(app, ip, port)
diff --git a/docs/zh_CN/inference_deployment/shitu_gallery_manager.md b/docs/zh_CN/inference_deployment/shitu_gallery_manager.md
index 4023ff9c4e87a79c33d30fdbc16d4f479a88f51b..1146444da267ce026b5e5a74c1e764c3a4870321 100644
--- a/docs/zh_CN/inference_deployment/shitu_gallery_manager.md
+++ b/docs/zh_CN/inference_deployment/shitu_gallery_manager.md
@@ -22,7 +22,7 @@
- [2. 使用说明](#2)
- [2.1 环境安装](#2.1)
- - [2.2 模型准备](#2.2)
+ - [2.2 模型及数据准备](#2.2)
- [2.3运行使用](#2.3)
- [3.生成文件介绍](#3)
@@ -90,7 +90,7 @@
在打开图像库或者新建图像库完成后,可以使用导入图像功能,即导入用户自己生成好的图像库。具体有支持两种导入格式
- image_list格式:打开具体的`.txt`文件。`.txt`文件中每一行格式: `image_path label`。跟据文件路径及label导入
-- 多文件夹格式:打开`具体文件夹`,此文件夹下存储多个子文件夹,每个子文件夹名字为`label_name`,每个子文件夹中存储对应的图像数据。
+- 多文件夹格式:打开`具体文件夹`,此文件夹下存储多个子文件夹,每个子文件夹名字为`label_name`,每个子文件夹中存储对应的图像数据。
@@ -123,13 +123,25 @@
pip install fastapi
pip install uvicorn
pip install pyqt5
+pip install psutil
```
-### 2.2 模型准备
+### 2.2 模型及数据准备
-请按照[PP-ShiTu快速体验](../quick_start/quick_start_recognition.md#2.2.1)中下载及准备inference model,并修改好`${PaddleClas}/deploy/configs/inference_drink.yaml`的相关参数。
+请按照[PP-ShiTu快速体验](../quick_start/quick_start_recognition.md#2.2.1)中下载及准备inference model,并修改好`${PaddleClas}/deploy/configs/inference_drink.yaml`的相关参数,同时准备好数据集。在具体使用时,请替换好自己的数据集及模型文件。
+
+```shell
+cd ${PaddleClas}/deploy/shitu_index_manager
+mkdir models
+cd models
+# 下载及解压识别模型
+wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/PP-ShiTuV2/general_PPLCNetV2_base_pretrained_v1.0_infer.tar && tar -xf general_PPLCNetV2_base_pretrained_v1.0_infer.tar
+cd ..
+# 下载及解压示例数据集
+wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_dataset_v2.0.tar && tar -xf drink_dataset_v2.0.tar
+```
@@ -139,9 +151,26 @@ pip install pyqt5
```shell
cd ${PaddleClas}/deploy/shitu_index_manager
-python index_manager.py -c ../configs/inference_drink.yaml
+cp ../configs/inference_drink.yaml .
+# 注意:如果没有按照2.2中的步骤准备模型及数据集,请手动修改inference_drink.yaml,做好适配
+python index_manager.py -c inference_drink.yaml
```
+运行成功后,会自动跳转到工具界面,可以按照如下步骤,生成新的index库。
+
+1. 点击菜单栏`新建图像库`,会提示打开一个文件夹,此时请创建一个**新的文件夹**,并打开。如在`${PaddleClas}/deploy/shitu_index_manager`下新建一个`drink_index`文件夹
+2. 导入图像,或者如上面功能介绍,自己手动新增类别和相应的图像,下面介绍两种导入图像方式,操作时,二选一即可。
+ - 点击`导入图像`->`导入image_list图像`,打开`${PaddleClas}/deploy/shitu_index_manager/drink_dataset_v2.0/gallery/drink_label.txt`,此时就可以将`drink_label.txt`中的图像全部导入进来,图像类别就是`drink_label.txt`中记录的类别。
+ - 点击`导入图像`->`导入多文件夹图像`,打开`${PaddleClas}/deploy/shitu_index_manager/drink_dataset_v2.0/gallery/`文件夹,此时就将`gallery`文件夹下,所有子文件夹都导入进来,图像类别就是子文件夹的名字。
+3. 点击菜单栏中`新建/重建 索引库`,此时就会开始生成索引库。如果图片较多或者使用cpu来进行特征提取,那么耗时会比较长,请耐心等待。
+4. 生成索引库成功后,会发现在`drink_index`文件夹下生成如[3](#3) 中介绍的文件,此时`index`子文件夹下生成的文件,就是`PP-ShiTu`所使用的索引文件。
+
+**注意**:
+
+- 利用此工具生成的index库,如`drink_index`文件夹,请妥善存储。之后,可以继续使用此工具中`打开图像库`功能,打开`drink_index`文件夹,继续对index库进行增删改查操作,具体功能可以查看[功能介绍](#1)。
+- 打开一个生成好的库,在其上面进行增删改查操作后,请及时保存。保存后请及时使用菜单中`更新索引库`功能,对索引库进行更新。
+- 如果要使用自己的图像库文件,图像生成格式如示例数据格式,生成`image_list.txt`或者多文件夹存储,二选一。
+
## 3. 生成文件介绍
@@ -150,10 +179,10 @@ python index_manager.py -c ../configs/inference_drink.yaml
```shell
index_root/ # 库存储目录
-|-- image_list.txt # 图像列表,每行:image_path label。由前端生成及修改,后端只读
+|-- image_list.txt # 图像列表,每行:image_path label。由前端生成及修改,后端只读
|-- images # 图像存储目录,由前端生成及增删查等操作。后端只读
-| |-- md5.jpg
-| |-- md5.jpg
+| |-- md5.jpg
+| |-- md5.jpg
| |-- ……
|-- features.pkl # 建库之后,保存的embedding向量,后端生成,前端无需操作
|-- index # 真正的生成的index库存储目录,后端生成及操作,前端无需操作。
@@ -192,4 +221,3 @@ index_root/ # 库存储目录
- 问题4: 报错 图像与index库不一致
答:可能用户自己修改了image_list.txt,修改完成后,请及时更新index库,保证其一致。
-