From 8211c0391eb67b5c090979cca0ed35a1bcb13045 Mon Sep 17 00:00:00 2001 From: dongshuilong Date: Tue, 20 Sep 2022 17:11:56 +0800 Subject: [PATCH] fix shitu_index manager bug --- deploy/shitu_index_manager/client.py | 45 +++ deploy/shitu_index_manager/index_manager.py | 345 ++---------------- deploy/shitu_index_manager/mod/mainwindow.py | 118 +++--- deploy/shitu_index_manager/server.py | 340 +++++++++++++++++ .../shitu_gallery_manager.md | 46 ++- 5 files changed, 515 insertions(+), 379 deletions(-) create mode 100644 deploy/shitu_index_manager/client.py create mode 100644 deploy/shitu_index_manager/server.py diff --git a/deploy/shitu_index_manager/client.py b/deploy/shitu_index_manager/client.py new file mode 100644 index 00000000..229691a0 --- /dev/null +++ b/deploy/shitu_index_manager/client.py @@ -0,0 +1,45 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import sys +from PyQt5 import QtCore, QtGui, QtWidgets +import mod.mainwindow +""" +完整的index库如下: +root_path/ # 库存储目录 +|-- image_list.txt # 图像列表,每行:image_path label。由前端生成及修改。后端只读 +|-- features.pkl # 建库之后,保存的embedding向量,后端生成,前端无需操作 +|-- images # 图像存储目录,由前端生成及增删查等操作。后端只读 +| |-- md5.jpg +| |-- md5.jpg +| |-- …… +|-- index # 真正的生成的index库存储目录,后端生成及操作,前端无需操作。 +| |-- vector.index # faiss生成的索引库 +| |-- id_map.pkl # 索引文件 +""" + + +def FrontInterface(server_ip=None, server_port=None): + front = QtWidgets.QApplication([]) + main_window = mod.mainwindow.MainWindow(ip=server_ip, port=server_port) + main_window.showMaximized() + sys.exit(front.exec_()) + + +if __name__ == '__main__': + server_ip = None + server_port = None + if len(sys.argv) == 2 and len(sys.argv[1].split(' ')) == 2: + [server_ip, server_port] = sys.argv[1].split(' ') + FrontInterface(server_ip, server_port) diff --git a/deploy/shitu_index_manager/index_manager.py b/deploy/shitu_index_manager/index_manager.py index 97e3eec5..bfcfc136 100644 --- a/deploy/shitu_index_manager/index_manager.py +++ b/deploy/shitu_index_manager/index_manager.py @@ -13,22 +13,10 @@ # limitations under the License. 
import os import sys -from PyQt5 import QtCore, QtGui, QtWidgets -import mod.mainwindow - -from paddleclas.deploy.utils import config, logger -from paddleclas.deploy.python.predict_rec import RecPredictor -from fastapi import FastAPI -import uvicorn -import numpy as np -import faiss -from typing import List -import pickle -import cv2 -import socket -import json -import operator -from multiprocessing import Process +import subprocess +import shlex +import psutil +import time """ 完整的index库如下: root_path/ # 库存储目录 @@ -43,307 +31,34 @@ root_path/ # 库存储目录 | |-- id_map.pkl # 索引文件 """ - -class ShiTuIndexManager(object): - - def __init__(self, config): - self.root_path = None - self.image_list_path = "image_list.txt" - self.image_dir = "images" - self.index_path = "index/vector.index" - self.id_map_path = "index/id_map.pkl" - self.features_path = "features.pkl" - self.index = None - self.id_map = None - self.features = None - self.config = config - self.predictor = RecPredictor(config) - - def _load_pickle(self, path): - if os.path.exists(path): - return pickle.load(open(path, 'rb')) - else: - return None - - def _save_pickle(self, path, data): - if not os.path.exists(os.path.dirname(path)): - os.makedirs(os.path.dirname(path), exist_ok=True) - with open(path, 'wb') as fd: - pickle.dump(data, fd) - - def _load_index(self): - self.index = faiss.read_index( - os.path.join(self.root_path, self.index_path)) - self.id_map = self._load_pickle( - os.path.join(self.root_path, self.id_map_path)) - self.features = self._load_pickle( - os.path.join(self.root_path, self.features_path)) - - def _save_index(self, index, id_map, features): - faiss.write_index(index, os.path.join(self.root_path, self.index_path)) - self._save_pickle(os.path.join(self.root_path, self.id_map_path), - id_map) - self._save_pickle(os.path.join(self.root_path, self.features_path), - features) - - def _update_path(self, root_path, image_list_path=None): - if root_path == self.root_path: - pass - else: - 
self.root_path = root_path - if not os.path.exists(os.path.join(root_path, "index")): - os.mkdir(os.path.join(root_path, "index")) - if image_list_path is not None: - self.image_list_path = image_list_path - - def _cal_featrue(self, image_list): - batch_images = [] - featrures = None - cnt = 0 - for idx, image_path in enumerate(image_list): - image = cv2.imread(image_path) - if image is None: - return "{} is broken or not exist. Stop" - else: - image = image[:, :, ::-1] - batch_images.append(image) - cnt += 1 - if cnt % self.config["Global"]["batch_size"] == 0 or ( - idx + 1) == len(image_list): - if len(batch_images) == 0: - continue - batch_results = self.predictor.predict(batch_images) - featrures = batch_results if featrures is None else np.concatenate( - (featrures, batch_results), axis=0) - batch_images = [] - return featrures - - def _split_datafile(self, data_file, image_root): - ''' - data_file: image path and info, which can be splitted by spacer - image_root: image path root - delimiter: delimiter - ''' - gallery_images = [] - gallery_docs = [] - gallery_ids = [] - with open(data_file, 'r', encoding='utf-8') as f: - lines = f.readlines() - for _, ori_line in enumerate(lines): - line = ori_line.strip().split() - text_num = len(line) - assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}" - image_file = os.path.join(image_root, line[0]) - - gallery_images.append(image_file) - gallery_docs.append(ori_line.strip()) - gallery_ids.append(os.path.basename(line[0]).split(".")[0]) - - return gallery_images, gallery_docs, gallery_ids - - def create_index(self, - image_list: str, - index_method: str = "HNSW32", - image_root: str = None): - if not os.path.exists(image_list): - return "{} is not exist".format(image_list) - if index_method.lower() not in ['hnsw32', 'ivf', 'flat']: - return "The index method Only support: HNSW32, IVF, Flat" - self._update_path(os.path.dirname(image_list), image_list) - - # get image_paths 
- image_root = image_root if image_root is not None else self.root_path - gallery_images, gallery_docs, image_ids = self._split_datafile( - image_list, image_root) - - # gernerate index - if index_method == "IVF": - index_method = index_method + str( - min(max(int(len(gallery_images) // 32), 2), 65536)) + ",Flat" - index = faiss.index_factory( - self.config["IndexProcess"]["embedding_size"], index_method, - faiss.METRIC_INNER_PRODUCT) - self.index = faiss.IndexIDMap2(index) - features = self._cal_featrue(gallery_images) - self.index.train(features) - index_ids = np.arange(0, len(gallery_images)).astype(np.int64) - self.index.add_with_ids(features, index_ids) - - self.id_map = dict() - for i, d in zip(list(index_ids), gallery_docs): - self.id_map[i] = d - - self.features = { - "features": features, - "index_method": index_method, - "image_ids": image_ids, - "index_ids": index_ids.tolist() - } - self._save_index(self.index, self.id_map, self.features) - - def open_index(self, root_path: str, image_list_path: str) -> str: - self._update_path(root_path) - _, _, image_ids = self._split_datafile(image_list_path, root_path) - if os.path.exists(os.path.join(self.root_path, self.index_path)) and \ - os.path.exists(os.path.join(self.root_path, self.id_map_path)) and \ - os.path.exists(os.path.join(self.root_path, self.features_path)): - self._update_path(root_path) - self._load_index() - if operator.eq(set(image_ids), set(self.features['image_ids'])): - return "" - else: - return "The image list is different from index, Please update index" - else: - return "File not exist: features.pkl, vector.index, id_map.pkl" - - def update_index(self, image_list: str, image_root: str = None) -> str: - if self.index and self.id_map and self.features: - image_paths, image_docs, image_ids = self._split_datafile( - image_list, - image_root if image_root is not None else self.root_path) - - # for add image - add_ids = list( - set(image_ids).difference(set(self.features["image_ids"]))) - 
add_indexes = [i for i, x in enumerate(image_ids) if x in add_ids] - add_image_paths = [image_paths[i] for i in add_indexes] - add_image_docs = [image_docs[i] for i in add_indexes] - add_image_ids = [image_ids[i] for i in add_indexes] - self._add_index(add_image_paths, add_image_docs, add_image_ids) - - # delete images - delete_ids = list( - set(self.features["image_ids"]).difference(set(image_ids))) - self._delete_index(delete_ids) - self._save_index(self.index, self.id_map, self.features) - return "" - else: - return "Failed. Please create or open index first" - - def _add_index(self, image_list: List, image_docs: List, image_ids: List): - if len(image_ids) == 0: - return - featrures = self._cal_featrue(image_list) - index_ids = (np.arange(0, len(image_list)) + max(self.id_map.keys()) + - 1).astype(np.int64) - self.index.add_with_ids(featrures, index_ids) - - for i, d in zip(index_ids, image_docs): - self.id_map[i] = d - - self.features['features'] = np.concatenate( - [self.features['features'], featrures], axis=0) - self.features['image_ids'].extend(image_ids) - self.features['index_ids'].extend(index_ids.tolist()) - - def _delete_index(self, image_ids: List): - if len(image_ids) == 0: - return - indexes = [ - i for i, x in enumerate(self.features['image_ids']) - if x in image_ids - ] - self.features["features"] = np.delete(self.features["features"], - indexes, - axis=0) - self.features["image_ids"] = np.delete(np.asarray( - self.features["image_ids"]), - indexes, - axis=0).tolist() - index_ids = np.delete(np.asarray(self.features["index_ids"]), - indexes, - axis=0).tolist() - id_map_values = [self.id_map[i] for i in index_ids] - self.index.reset() - ids = np.arange(0, len(id_map_values)).astype(np.int64) - self.index.add_with_ids(self.features['features'], ids) - self.id_map.clear() - for i, d in zip(ids, id_map_values): - self.id_map[i] = d - self.features["index_ids"] = ids - - -app = FastAPI() - - -@app.get("/new_index") -def new_index(image_list_path: str, 
- index_method: str = "HNSW32", - index_root_path: str = None, - force: bool = False): - result = "" - try: - if index_root_path is not None: - image_list_path = os.path.join(index_root_path, image_list_path) - index_path = os.path.join(index_root_path, "index", "vector.index") - id_map_path = os.path.join(index_root_path, "index", "id_map.pkl") - - if not (os.path.exists(index_path) - and os.path.exists(id_map_path)) or force: - manager.create_index(image_list_path, index_method, index_root_path) - else: - result = "There alrealy has index in {}".format(index_root_path) - except Exception as e: - result = e.__str__() - data = {"error_message": result} - return json.dumps(data).encode() - - -@app.get("/open_index") -def open_index(index_root_path: str, image_list_path: str): - result = "" - try: - image_list_path = os.path.join(index_root_path, image_list_path) - result = manager.open_index(index_root_path, image_list_path) - except Exception as e: - result = e.__str__() - - data = {"error_message": result} - return json.dumps(data).encode() - - -@app.get("/update_index") -def update_index(image_list_path: str, index_root_path: str = None): - result = "" - try: - if index_root_path is not None: - image_list_path = os.path.join(index_root_path, image_list_path) - result = manager.update_index(image_list=image_list_path, - image_root=index_root_path) - except Exception as e: - result = e.__str__() - data = {"error_message": result} - return json.dumps(data).encode() - - -def FrontInterface(server_process=None): - front = QtWidgets.QApplication([]) - main_window = mod.mainwindow.MainWindow(process=server_process) - main_window.showMaximized() - sys.exit(front.exec_()) - - -def Server(args): - [app, host, port] = args - uvicorn.run(app, host=host, port=port) - - if __name__ == '__main__': - args = config.parse_args() - model_config = config.get_config(args.config, - overrides=args.override, - show=True) - manager = ShiTuIndexManager(model_config) + if not 
(len(sys.argv) == 3 or len(sys.argv) == 5): + print("start example:") + print(" python index_manager.py -c xxx.yaml") + print(" python index_manager.py -c xxx.yaml -p port") + yaml_path = sys.argv[2] + if len(sys.argv) == 5: + port = sys.argv[4] + else: + port = 8000 + assert int(port) > 1024 and int( + port) < 65536, "The port should be bigger than 1024 and \ + smaller than 65536" + try: ip = socket.gethostbyname(socket.gethostname()) except: ip = '127.0.0.1' - port = 8000 - p_server = Process(target=Server, args=([app, ip, port],)) - p_server.start() - # p_client = Process(target=FrontInterface, args=()) - # p_client.start() - # p_client.join() - FrontInterface(p_server) - p_server.terminate() - sys.exit(0) + server_cmd = "python server.py -c {} -o ip={} -o port={}".format(yaml_path, + ip, port) + server_proc = subprocess.Popen(shlex.split(server_cmd)) + client_proc = subprocess.Popen( + ["python", "client.py", "{} {}".format(ip, port)]) + try: + while psutil.Process(client_proc.pid).status() == "running": + time.sleep(0.5) + except: + pass + + client_proc.terminate() + server_proc.terminate() diff --git a/deploy/shitu_index_manager/mod/mainwindow.py b/deploy/shitu_index_manager/mod/mainwindow.py index 40d11f6c..879161e4 100644 --- a/deploy/shitu_index_manager/mod/mainwindow.py +++ b/deploy/shitu_index_manager/mod/mainwindow.py @@ -22,8 +22,6 @@ try: DEFAULT_HOST = socket.gethostbyname(socket.gethostname()) except: DEFAULT_HOST = '127.0.0.1' - -# DEFAULT_HOST = "localhost" DEFAULT_PORT = 8000 PADDLECLAS_DOC_URL = "https://gitee.com/paddlepaddle/PaddleClas/docs/zh_CN/inference_deployment/shitu_gallery_manager.md" @@ -35,12 +33,17 @@ class MainWindow(QtWidgets.QMainWindow): updateIndexMsg = QtCore.pyqtSignal(str) # 更新索引库线程信号 importImageCount = QtCore.pyqtSignal(int) # 导入图像数量信号 - def __init__(self, process=None): + def __init__(self, ip=None, port=None): super(MainWindow, self).__init__() - self.server_process = process + if ip is not None and port is not None: + 
self.server_ip = ip + self.server_port = port + else: + self.server_ip = DEFAULT_HOST + self.server_port = DEFAULT_PORT + self.ui = ui_mainwindow.Ui_MainWindow() self.ui.setupUi(self) # 初始化主窗口界面 - self.__imageListMgr = image_list_manager.ImageListManager() self.__appMenu = QtWidgets.QMenu() # 应用菜单 @@ -115,8 +118,7 @@ class MainWindow(QtWidgets.QMainWindow): self.ui.saveImageLibraryBtn.clicked.connect(self.saveImageLibrary) self.__setToolButton(self.ui.addClassifyBtn, "添加分类", - "./resource/add_classify.png", - TOOL_BTN_ICON_SIZE) + "./resource/add_classify.png", TOOL_BTN_ICON_SIZE) self.ui.addClassifyBtn.clicked.connect( self.__classifyUiContext.addClassify) @@ -145,7 +147,10 @@ class MainWindow(QtWidgets.QMainWindow): self.ui.searchClassifyHistoryCmb.setToolTip("查找分类历史") self.ui.imageScaleSlider.setToolTip("图片缩放") - def __setToolButton(self, button, tool_tip: str, icon_path: str, + def __setToolButton(self, + button, + tool_tip: str, + icon_path: str, icon_size: int): """设置工具按钮""" button.setToolTip(tool_tip) @@ -160,9 +165,9 @@ class MainWindow(QtWidgets.QMainWindow): self.__libraryAppendMenu.setTitle("导入图像") utils.setMenu(self.__libraryAppendMenu, "导入 image_list 图像", - self.importImageListImage) + self.importImageListImage) utils.setMenu(self.__libraryAppendMenu, "导入多文件夹图像", - self.importDirsImage) + self.importDirsImage) self.__appMenu.addMenu(self.__libraryAppendMenu) self.__appMenu.addSeparator() @@ -179,16 +184,16 @@ class MainWindow(QtWidgets.QMainWindow): def __initWaitDialog(self): """初始化等待对话框""" self.__waitDialogUi.setupUi(self.__waitDialog) - self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog - | QtCore.Qt.FramelessWindowHint) + self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog | + QtCore.Qt.FramelessWindowHint) def __startWait(self, msg: str): """开始显示等待对话框""" self.setEnabled(False) self.__waitDialogUi.msgLabel.setText(msg) - self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog - | QtCore.Qt.FramelessWindowHint - | QtCore.Qt.WindowStaysOnTopHint) + 
self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog | + QtCore.Qt.FramelessWindowHint | + QtCore.Qt.WindowStaysOnTopHint) self.__waitDialog.show() self.__waitDialog.repaint() @@ -196,9 +201,9 @@ class MainWindow(QtWidgets.QMainWindow): """停止显示等待对话框""" self.setEnabled(True) self.__waitDialogUi.msgLabel.setText("执行完毕!") - self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog - | QtCore.Qt.FramelessWindowHint - | QtCore.Qt.CustomizeWindowHint) + self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog | + QtCore.Qt.FramelessWindowHint | + QtCore.Qt.CustomizeWindowHint) self.__waitDialog.close() def __connectSignal(self): @@ -290,8 +295,8 @@ class MainWindow(QtWidgets.QMainWindow): def __importImageListImageThread(self, from_path: str, to_path: str): """导入 image_list 图像 线程""" - count = utils.oneKeyImportFromFile(from_path=from_path, - to_path=to_path) + count = utils.oneKeyImportFromFile( + from_path=from_path, to_path=to_path) if count == None: count = -1 self.importImageCount.emit(count) @@ -308,9 +313,9 @@ class MainWindow(QtWidgets.QMainWindow): return from_mgr = image_list_manager.ImageListManager(from_path) self.__startWait("正在导入图像,请等待。。。") - thread = threading.Thread(target=self.__importImageListImageThread, - args=(from_mgr.filePath, - self.__imageListMgr.filePath)) + thread = threading.Thread( + target=self.__importImageListImageThread, + args=(from_mgr.filePath, self.__imageListMgr.filePath)) thread.start() def __importDirsImageThread(self, from_dir: str, to_image_list_path: str): @@ -333,21 +338,25 @@ class MainWindow(QtWidgets.QMainWindow): QtWidgets.QMessageBox.information(self, "提示", "打开的目录不存在") return self.__startWait("正在导入图像,请等待。。。") - thread = threading.Thread(target=self.__importDirsImageThread, - args=(dir_path, - self.__imageListMgr.filePath)) + thread = threading.Thread( + target=self.__importDirsImageThread, + args=(dir_path, self.__imageListMgr.filePath)) thread.start() - def __newIndexThread(self, index_root_path: str, image_list_path: str, - index_method: 
str, force: bool): + def __newIndexThread(self, + index_root_path: str, + image_list_path: str, + index_method: str, + force: bool): """新建重建索引库线程""" try: - client = index_http_client.IndexHttpClient( - DEFAULT_HOST, DEFAULT_PORT) - err_msg = client.new_index(image_list_path=image_list_path, - index_root_path=index_root_path, - index_method=index_method, - force=force) + client = index_http_client.IndexHttpClient(self.server_ip, + self.server_port) + err_msg = client.new_index( + image_list_path=image_list_path, + index_root_path=index_root_path, + index_method=index_method, + force=force) if err_msg == None: err_msg = "" self.newIndexMsg.emit(err_msg) @@ -375,19 +384,20 @@ class MainWindow(QtWidgets.QMainWindow): force = ui.resetCheckBox.isChecked() if result == QtWidgets.QDialog.Accepted: self.__startWait("正在 新建/重建 索引库,请等待。。。") - thread = threading.Thread(target=self.__newIndexThread, - args=(self.__imageListMgr.dirName, - "image_list.txt", index_method, - force)) + thread = threading.Thread( + target=self.__newIndexThread, + args=(self.__imageListMgr.dirName, "image_list.txt", + index_method, force)) thread.start() def __openIndexThread(self, index_root_path: str, image_list_path: str): """打开索引库线程""" try: - client = index_http_client.IndexHttpClient( - DEFAULT_HOST, DEFAULT_PORT) - err_msg = client.open_index(index_root_path=index_root_path, - image_list_path=image_list_path) + client = index_http_client.IndexHttpClient(self.server_ip, + self.server_port) + err_msg = client.open_index( + index_root_path=index_root_path, + image_list_path=image_list_path) if err_msg == None: err_msg = "" self.openIndexMsg.emit(err_msg) @@ -408,18 +418,19 @@ class MainWindow(QtWidgets.QMainWindow): QtWidgets.QMessageBox.information(self, "提示", "请先打开正确的图像库") return self.__startWait("正在打开索引库,请等待。。。") - thread = threading.Thread(target=self.__openIndexThread, - args=(self.__imageListMgr.dirName, - "image_list.txt")) + thread = threading.Thread( + target=self.__openIndexThread, + 
args=(self.__imageListMgr.dirName, "image_list.txt")) thread.start() def __updateIndexThread(self, index_root_path: str, image_list_path: str): """更新索引库线程""" try: - client = index_http_client.IndexHttpClient( - DEFAULT_HOST, DEFAULT_PORT) - err_msg = client.update_index(image_list_path=image_list_path, - index_root_path=index_root_path) + client = index_http_client.IndexHttpClient(self.server_ip, + self.server_port) + err_msg = client.update_index( + image_list_path=image_list_path, + index_root_path=index_root_path) if err_msg == None: err_msg = "" self.updateIndexMsg.emit(err_msg) @@ -440,9 +451,9 @@ class MainWindow(QtWidgets.QMainWindow): QtWidgets.QMessageBox.information(self, "提示", "请先打开正确的图像库") return self.__startWait("正在更新索引库,请等待。。。") - thread = threading.Thread(target=self.__updateIndexThread, - args=(self.__imageListMgr.dirName, - "image_list.txt")) + thread = threading.Thread( + target=self.__updateIndexThread, + args=(self.__imageListMgr.dirName, "image_list.txt")) thread.start() def searchClassify(self): @@ -471,9 +482,6 @@ class MainWindow(QtWidgets.QMainWindow): def exitApp(self): """退出应用""" - if isinstance(self.server_process, Process): - self.server_process.terminate() - # os.kill(self.server_pid) sys.exit(0) def __setPathBar(self, msg: str): diff --git a/deploy/shitu_index_manager/server.py b/deploy/shitu_index_manager/server.py new file mode 100644 index 00000000..ed572085 --- /dev/null +++ b/deploy/shitu_index_manager/server.py @@ -0,0 +1,340 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +from PyQt5 import QtCore, QtGui, QtWidgets +import mod.mainwindow + +from paddleclas.deploy.utils import config, logger +from paddleclas.deploy.python.predict_rec import RecPredictor +from fastapi import FastAPI +import uvicorn +import numpy as np +import faiss +from typing import List +import pickle +import cv2 +import socket +import json +import operator +from multiprocessing import Process +""" +完整的index库如下: +root_path/ # 库存储目录 +|-- image_list.txt # 图像列表,每行:image_path label。由前端生成及修改。后端只读 +|-- features.pkl # 建库之后,保存的embedding向量,后端生成,前端无需操作 +|-- images # 图像存储目录,由前端生成及增删查等操作。后端只读 +| |-- md5.jpg +| |-- md5.jpg +| |-- …… +|-- index # 真正的生成的index库存储目录,后端生成及操作,前端无需操作。 +| |-- vector.index # faiss生成的索引库 +| |-- id_map.pkl # 索引文件 +""" + + +class ShiTuIndexManager(object): + def __init__(self, config): + self.root_path = None + self.image_list_path = "image_list.txt" + self.image_dir = "images" + self.index_path = "index/vector.index" + self.id_map_path = "index/id_map.pkl" + self.features_path = "features.pkl" + self.index = None + self.id_map = None + self.features = None + self.config = config + self.predictor = RecPredictor(config) + + def _load_pickle(self, path): + if os.path.exists(path): + return pickle.load(open(path, 'rb')) + else: + return None + + def _save_pickle(self, path, data): + if not os.path.exists(os.path.dirname(path)): + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, 'wb') as fd: + pickle.dump(data, fd) + + def _load_index(self): + self.index = faiss.read_index( + os.path.join(self.root_path, 
self.index_path)) + self.id_map = self._load_pickle( + os.path.join(self.root_path, self.id_map_path)) + self.features = self._load_pickle( + os.path.join(self.root_path, self.features_path)) + + def _save_index(self, index, id_map, features): + faiss.write_index(index, os.path.join(self.root_path, self.index_path)) + self._save_pickle( + os.path.join(self.root_path, self.id_map_path), id_map) + self._save_pickle( + os.path.join(self.root_path, self.features_path), features) + + def _update_path(self, root_path, image_list_path=None): + if root_path == self.root_path: + pass + else: + self.root_path = root_path + if not os.path.exists(os.path.join(root_path, "index")): + os.mkdir(os.path.join(root_path, "index")) + if image_list_path is not None: + self.image_list_path = image_list_path + + def _cal_featrue(self, image_list): + batch_images = [] + featrures = None + cnt = 0 + for idx, image_path in enumerate(image_list): + image = cv2.imread(image_path) + if image is None: + return "{} is broken or not exist. 
Stop" + else: + image = image[:, :, ::-1] + batch_images.append(image) + cnt += 1 + if cnt % self.config["Global"]["batch_size"] == 0 or ( + idx + 1) == len(image_list): + if len(batch_images) == 0: + continue + batch_results = self.predictor.predict(batch_images) + featrures = batch_results if featrures is None else np.concatenate( + (featrures, batch_results), axis=0) + batch_images = [] + return featrures + + def _split_datafile(self, data_file, image_root): + ''' + data_file: image path and info, which can be splitted by spacer + image_root: image path root + delimiter: delimiter + ''' + gallery_images = [] + gallery_docs = [] + gallery_ids = [] + with open(data_file, 'r', encoding='utf-8') as f: + lines = f.readlines() + for _, ori_line in enumerate(lines): + line = ori_line.strip().split() + text_num = len(line) + assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}" + image_file = os.path.join(image_root, line[0]) + + gallery_images.append(image_file) + gallery_docs.append(ori_line.strip()) + gallery_ids.append(os.path.basename(line[0]).split(".")[0]) + + return gallery_images, gallery_docs, gallery_ids + + def create_index(self, + image_list: str, + index_method: str="HNSW32", + image_root: str=None): + if not os.path.exists(image_list): + return "{} is not exist".format(image_list) + if index_method.lower() not in ['hnsw32', 'ivf', 'flat']: + return "The index method Only support: HNSW32, IVF, Flat" + self._update_path(os.path.dirname(image_list), image_list) + + # get image_paths + image_root = image_root if image_root is not None else self.root_path + gallery_images, gallery_docs, image_ids = self._split_datafile( + image_list, image_root) + + # gernerate index + if index_method == "IVF": + index_method = index_method + str( + min(max(int(len(gallery_images) // 32), 2), 65536)) + ",Flat" + index = faiss.index_factory( + self.config["IndexProcess"]["embedding_size"], index_method, + 
faiss.METRIC_INNER_PRODUCT) + self.index = faiss.IndexIDMap2(index) + features = self._cal_featrue(gallery_images) + self.index.train(features) + index_ids = np.arange(0, len(gallery_images)).astype(np.int64) + self.index.add_with_ids(features, index_ids) + + self.id_map = dict() + for i, d in zip(list(index_ids), gallery_docs): + self.id_map[i] = d + + self.features = { + "features": features, + "index_method": index_method, + "image_ids": image_ids, + "index_ids": index_ids.tolist() + } + self._save_index(self.index, self.id_map, self.features) + + def open_index(self, root_path: str, image_list_path: str) -> str: + self._update_path(root_path) + _, _, image_ids = self._split_datafile(image_list_path, root_path) + if os.path.exists(os.path.join(self.root_path, self.index_path)) and \ + os.path.exists(os.path.join(self.root_path, self.id_map_path)) and \ + os.path.exists(os.path.join(self.root_path, self.features_path)): + self._update_path(root_path) + self._load_index() + if operator.eq(set(image_ids), set(self.features['image_ids'])): + return "" + else: + return "The image list is different from index, Please update index" + else: + return "File not exist: features.pkl, vector.index, id_map.pkl" + + def update_index(self, image_list: str, image_root: str=None) -> str: + if self.index and self.id_map and self.features: + image_paths, image_docs, image_ids = self._split_datafile( + image_list, image_root + if image_root is not None else self.root_path) + + # for add image + add_ids = list( + set(image_ids).difference(set(self.features["image_ids"]))) + add_indexes = [i for i, x in enumerate(image_ids) if x in add_ids] + add_image_paths = [image_paths[i] for i in add_indexes] + add_image_docs = [image_docs[i] for i in add_indexes] + add_image_ids = [image_ids[i] for i in add_indexes] + self._add_index(add_image_paths, add_image_docs, add_image_ids) + + # delete images + delete_ids = list( + set(self.features["image_ids"]).difference(set(image_ids))) + 
self._delete_index(delete_ids) + self._save_index(self.index, self.id_map, self.features) + return "" + else: + return "Failed. Please create or open index first" + + def _add_index(self, image_list: List, image_docs: List, image_ids: List): + if len(image_ids) == 0: + return + featrures = self._cal_featrue(image_list) + index_ids = ( + np.arange(0, len(image_list)) + max(self.id_map.keys()) + 1 + ).astype(np.int64) + self.index.add_with_ids(featrures, index_ids) + + for i, d in zip(index_ids, image_docs): + self.id_map[i] = d + + self.features['features'] = np.concatenate( + [self.features['features'], featrures], axis=0) + self.features['image_ids'].extend(image_ids) + self.features['index_ids'].extend(index_ids.tolist()) + + def _delete_index(self, image_ids: List): + if len(image_ids) == 0: + return + indexes = [ + i for i, x in enumerate(self.features['image_ids']) + if x in image_ids + ] + self.features["features"] = np.delete( + self.features["features"], indexes, axis=0) + self.features["image_ids"] = np.delete( + np.asarray(self.features["image_ids"]), indexes, axis=0).tolist() + index_ids = np.delete( + np.asarray(self.features["index_ids"]), indexes, axis=0).tolist() + id_map_values = [self.id_map[i] for i in index_ids] + self.index.reset() + ids = np.arange(0, len(id_map_values)).astype(np.int64) + self.index.add_with_ids(self.features['features'], ids) + self.id_map.clear() + for i, d in zip(ids, id_map_values): + self.id_map[i] = d + self.features["index_ids"] = ids + + +app = FastAPI() + + +@app.get("/new_index") +def new_index(image_list_path: str, + index_method: str="HNSW32", + index_root_path: str=None, + force: bool=False): + result = "" + try: + if index_root_path is not None: + image_list_path = os.path.join(index_root_path, image_list_path) + index_path = os.path.join(index_root_path, "index", "vector.index") + id_map_path = os.path.join(index_root_path, "index", "id_map.pkl") + + if not (os.path.exists(index_path) and + 
os.path.exists(id_map_path)) or force: + manager.create_index(image_list_path, index_method, + index_root_path) + else: + result = "There alrealy has index in {}".format(index_root_path) + except Exception as e: + result = e.__str__() + data = {"error_message": result} + return json.dumps(data).encode() + + +@app.get("/open_index") +def open_index(index_root_path: str, image_list_path: str): + result = "" + try: + image_list_path = os.path.join(index_root_path, image_list_path) + result = manager.open_index(index_root_path, image_list_path) + except Exception as e: + result = e.__str__() + + data = {"error_message": result} + return json.dumps(data).encode() + + +@app.get("/update_index") +def update_index(image_list_path: str, index_root_path: str=None): + result = "" + try: + if index_root_path is not None: + image_list_path = os.path.join(index_root_path, image_list_path) + result = manager.update_index( + image_list=image_list_path, image_root=index_root_path) + except Exception as e: + result = e.__str__() + data = {"error_message": result} + return json.dumps(data).encode() + + +def FrontInterface(server_process=None): + front = QtWidgets.QApplication([]) + main_window = mod.mainwindow.MainWindow(process=server_process) + main_window.showMaximized() + sys.exit(front.exec_()) + + +def Server(app, host, port): + uvicorn.run(app, host=host, port=port) + + +if __name__ == '__main__': + args = config.parse_args() + model_config = config.get_config( + args.config, overrides=args.override, show=True) + manager = ShiTuIndexManager(model_config) + ip = model_config.get('ip', None) + port = model_config.get('port', None) + if ip is None or port is None: + try: + ip = socket.gethostbyname(socket.gethostname()) + except: + ip = '127.0.0.1' + port = 8000 + Server(app, ip, port) diff --git a/docs/zh_CN/inference_deployment/shitu_gallery_manager.md b/docs/zh_CN/inference_deployment/shitu_gallery_manager.md index 4023ff9c..1146444d 100644 --- 
a/docs/zh_CN/inference_deployment/shitu_gallery_manager.md +++ b/docs/zh_CN/inference_deployment/shitu_gallery_manager.md @@ -22,7 +22,7 @@ - [2. 使用说明](#2) - [2.1 环境安装](#2.1) - - [2.2 模型准备](#2.2) + - [2.2 模型及数据准备](#2.2) - [2.3运行使用](#2.3) - [3.生成文件介绍](#3) @@ -90,7 +90,7 @@ 在打开图像库或者新建图像库完成后,可以使用导入图像功能,即导入用户自己生成好的图像库。具体有支持两种导入格式 - image_list格式:打开具体的`.txt`文件。`.txt`文件中每一行格式: `image_path label`。跟据文件路径及label导入 -- 多文件夹格式:打开`具体文件夹`,此文件夹下存储多个子文件夹,每个子文件夹名字为`label_name`,每个子文件夹中存储对应的图像数据。 +- 多文件夹格式:打开`具体文件夹`,此文件夹下存储多个子文件夹,每个子文件夹名字为`label_name`,每个子文件夹中存储对应的图像数据。 @@ -123,13 +123,25 @@ pip install fastapi pip install uvicorn pip install pyqt5 +pip install psutil ``` -### 2.2 模型准备 +### 2.2 模型及数据准备 -请按照[PP-ShiTu快速体验](../quick_start/quick_start_recognition.md#2.2.1)中下载及准备inference model,并修改好`${PaddleClas}/deploy/configs/inference_drink.yaml`的相关参数。 +请按照[PP-ShiTu快速体验](../quick_start/quick_start_recognition.md#2.2.1)中下载及准备inference model,并修改好`${PaddleClas}/deploy/configs/inference_drink.yaml`的相关参数,同时准备好数据集。在具体使用时,请替换好自己的数据集及模型文件。 + +```shell +cd ${PaddleClas}/deploy/shitu_index_manager +mkdir models +cd models +# 下载及解压识别模型 +wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/PP-ShiTuV2/general_PPLCNetV2_base_pretrained_v1.0_infer.tar && tar -xf general_PPLCNetV2_base_pretrained_v1.0_infer.tar +cd .. +# 下载及解压示例数据集 +wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_dataset_v2.0.tar && tar -xf drink_dataset_v2.0.tar +``` @@ -139,9 +151,26 @@ pip install pyqt5 ```shell cd ${PaddleClas}/deploy/shitu_index_manager -python index_manager.py -c ../configs/inference_drink.yaml +cp ../configs/inference_drink.yaml . +# 注意如果没有按照2.2中准备数据集及代码,请手动修改inference_drink.yaml,做好适配 +python index_manager.py -c inference_drink.yaml ``` +运行成功后,会自动跳转到工具界面,可以按照如下步骤,生成新的index库。 + +1. 点击菜单栏`新建图像库`,会提示打开一个文件夹,此时请创建一个**新的文件夹**,并打开。如在`${PaddleClas}/deploy/shitu_index_manager`下新建一个`drink_index`文件夹 +2. 
导入图像,或者如上面功能介绍,自己手动新增类别和相应的图像,下面介绍两种导入图像方式,操作时,二选一即可。 + - 点击`导入图像`->`导入image_list图像`,打开`${PaddleClas}/deploy/shitu_index_manager/drink_dataset_v2.0/gallery/drink_label.txt`,此时就可以将`drink_label.txt`中的图像全部导入进来,图像类别就是`drink_label.txt`中记录的类别。 + - 点击`导入图像`->`导入多文件夹图像`,打开`${PaddleClas}/deploy/shitu_index_manager/drink_dataset_v2.0/gallery/`文件夹,此时就将`gallery`文件夹下,所有子文件夹都导入进来,图像类别就是子文件夹的名字。 +3. 点击菜单栏中`新建/重建 索引库`,此时就会开始生成索引库。如果图片较多或者使用cpu来进行特征提取,那么耗时会比较长,请耐心等待。 +4. 生成索引库成功后,会发现在`drink_index`文件夹下生成如[3](#3) 中介绍的文件,此时`index`子文件夹下生成的文件,就是`PP-ShiTu`所使用的索引文件。 + +**注意**: + +- 利用此工具生成的index库,如`drink_index`文件夹,请妥善存储。之后,可以继续使用此工具中`打开图像库`功能,打开`drink_index`文件夹,继续对index库进行增删改查操作,具体功能可以查看[功能介绍](#1)。 +- 打开一个生成好的库,在其上面进行增删改查操作后,请及时保存。保存后请及时使用菜单中`更新索引库`功能,对索引库进行更新。 +- 如果要使用自己的图像库文件,图像生成格式如示例数据格式,生成`image_list.txt`或者多文件夹存储,二选一。 + ## 3. 生成文件介绍 @@ -150,10 +179,10 @@ python index_manager.py -c ../configs/inference_drink.yaml ```shell index_root/ # 库存储目录 -|-- image_list.txt # 图像列表,每行:image_path label。由前端生成及修改,后端只读 +|-- image_list.txt # 图像列表,每行:image_path label。由前端生成及修改,后端只读 |-- images # 图像存储目录,由前端生成及增删查等操作。后端只读 -| |-- md5.jpg -| |-- md5.jpg +| |-- md5.jpg +| |-- md5.jpg | |-- …… |-- features.pkl # 建库之后,保存的embedding向量,后端生成,前端无需操作 |-- index # 真正的生成的index库存储目录,后端生成及操作,前端无需操作。 @@ -192,4 +221,3 @@ index_root/ # 库存储目录 - 问题4: 报错 图像与index库不一致 答:可能用户自己修改了image_list.txt,修改完成后,请及时更新index库,保证其一致。 - -- GitLab