未验证 提交 f7a9a0ae 编写于 作者: H HydrogenSulfate 提交者: GitHub

Merge pull request #2289 from RainFrost1/develop

update pp-shituv2 doc
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from PyQt5 import QtCore, QtGui, QtWidgets
import mod.mainwindow
"""
完整的index库如下:
root_path/ # 库存储目录
|-- image_list.txt # 图像列表,每行:image_path label。由前端生成及修改。后端只读
|-- features.pkl # 建库之后,保存的embedding向量,后端生成,前端无需操作
|-- images # 图像存储目录,由前端生成及增删查等操作。后端只读
| |-- md5.jpg
| |-- md5.jpg
| |-- ……
|-- index # 真正的生成的index库存储目录,后端生成及操作,前端无需操作。
| |-- vector.index # faiss生成的索引库
| |-- id_map.pkl # 索引文件
"""
def FrontInterface(server_ip=None, server_port=None):
front = QtWidgets.QApplication([])
main_window = mod.mainwindow.MainWindow(ip=server_ip, port=server_port)
main_window.showMaximized()
sys.exit(front.exec_())
if __name__ == '__main__':
server_ip = None
server_port = None
if len(sys.argv) == 2 and len(sys.argv[1].split(' ')) == 2:
[server_ip, server_port] = sys.argv[1].split(' ')
FrontInterface(server_ip, server_port)
......@@ -13,22 +13,10 @@
# limitations under the License.
import os
import sys
from PyQt5 import QtCore, QtGui, QtWidgets
import mod.mainwindow
from paddleclas.deploy.utils import config, logger
from paddleclas.deploy.python.predict_rec import RecPredictor
from fastapi import FastAPI
import uvicorn
import numpy as np
import faiss
from typing import List
import pickle
import cv2
import socket
import json
import operator
from multiprocessing import Process
import subprocess
import shlex
import psutil
import time
"""
完整的index库如下:
root_path/ # 库存储目录
......@@ -43,307 +31,34 @@ root_path/ # 库存储目录
| |-- id_map.pkl # 索引文件
"""
class ShiTuIndexManager(object):
def __init__(self, config):
self.root_path = None
self.image_list_path = "image_list.txt"
self.image_dir = "images"
self.index_path = "index/vector.index"
self.id_map_path = "index/id_map.pkl"
self.features_path = "features.pkl"
self.index = None
self.id_map = None
self.features = None
self.config = config
self.predictor = RecPredictor(config)
def _load_pickle(self, path):
if os.path.exists(path):
return pickle.load(open(path, 'rb'))
else:
return None
def _save_pickle(self, path, data):
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, 'wb') as fd:
pickle.dump(data, fd)
def _load_index(self):
self.index = faiss.read_index(
os.path.join(self.root_path, self.index_path))
self.id_map = self._load_pickle(
os.path.join(self.root_path, self.id_map_path))
self.features = self._load_pickle(
os.path.join(self.root_path, self.features_path))
def _save_index(self, index, id_map, features):
faiss.write_index(index, os.path.join(self.root_path, self.index_path))
self._save_pickle(os.path.join(self.root_path, self.id_map_path),
id_map)
self._save_pickle(os.path.join(self.root_path, self.features_path),
features)
def _update_path(self, root_path, image_list_path=None):
if root_path == self.root_path:
pass
else:
self.root_path = root_path
if not os.path.exists(os.path.join(root_path, "index")):
os.mkdir(os.path.join(root_path, "index"))
if image_list_path is not None:
self.image_list_path = image_list_path
def _cal_featrue(self, image_list):
batch_images = []
featrures = None
cnt = 0
for idx, image_path in enumerate(image_list):
image = cv2.imread(image_path)
if image is None:
return "{} is broken or not exist. Stop"
else:
image = image[:, :, ::-1]
batch_images.append(image)
cnt += 1
if cnt % self.config["Global"]["batch_size"] == 0 or (
idx + 1) == len(image_list):
if len(batch_images) == 0:
continue
batch_results = self.predictor.predict(batch_images)
featrures = batch_results if featrures is None else np.concatenate(
(featrures, batch_results), axis=0)
batch_images = []
return featrures
def _split_datafile(self, data_file, image_root):
'''
data_file: image path and info, which can be splitted by spacer
image_root: image path root
delimiter: delimiter
'''
gallery_images = []
gallery_docs = []
gallery_ids = []
with open(data_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
for _, ori_line in enumerate(lines):
line = ori_line.strip().split()
text_num = len(line)
assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"
image_file = os.path.join(image_root, line[0])
gallery_images.append(image_file)
gallery_docs.append(ori_line.strip())
gallery_ids.append(os.path.basename(line[0]).split(".")[0])
return gallery_images, gallery_docs, gallery_ids
def create_index(self,
image_list: str,
index_method: str = "HNSW32",
image_root: str = None):
if not os.path.exists(image_list):
return "{} is not exist".format(image_list)
if index_method.lower() not in ['hnsw32', 'ivf', 'flat']:
return "The index method Only support: HNSW32, IVF, Flat"
self._update_path(os.path.dirname(image_list), image_list)
# get image_paths
image_root = image_root if image_root is not None else self.root_path
gallery_images, gallery_docs, image_ids = self._split_datafile(
image_list, image_root)
# gernerate index
if index_method == "IVF":
index_method = index_method + str(
min(max(int(len(gallery_images) // 32), 2), 65536)) + ",Flat"
index = faiss.index_factory(
self.config["IndexProcess"]["embedding_size"], index_method,
faiss.METRIC_INNER_PRODUCT)
self.index = faiss.IndexIDMap2(index)
features = self._cal_featrue(gallery_images)
self.index.train(features)
index_ids = np.arange(0, len(gallery_images)).astype(np.int64)
self.index.add_with_ids(features, index_ids)
self.id_map = dict()
for i, d in zip(list(index_ids), gallery_docs):
self.id_map[i] = d
self.features = {
"features": features,
"index_method": index_method,
"image_ids": image_ids,
"index_ids": index_ids.tolist()
}
self._save_index(self.index, self.id_map, self.features)
def open_index(self, root_path: str, image_list_path: str) -> str:
self._update_path(root_path)
_, _, image_ids = self._split_datafile(image_list_path, root_path)
if os.path.exists(os.path.join(self.root_path, self.index_path)) and \
os.path.exists(os.path.join(self.root_path, self.id_map_path)) and \
os.path.exists(os.path.join(self.root_path, self.features_path)):
self._update_path(root_path)
self._load_index()
if operator.eq(set(image_ids), set(self.features['image_ids'])):
return ""
else:
return "The image list is different from index, Please update index"
else:
return "File not exist: features.pkl, vector.index, id_map.pkl"
def update_index(self, image_list: str, image_root: str = None) -> str:
if self.index and self.id_map and self.features:
image_paths, image_docs, image_ids = self._split_datafile(
image_list,
image_root if image_root is not None else self.root_path)
# for add image
add_ids = list(
set(image_ids).difference(set(self.features["image_ids"])))
add_indexes = [i for i, x in enumerate(image_ids) if x in add_ids]
add_image_paths = [image_paths[i] for i in add_indexes]
add_image_docs = [image_docs[i] for i in add_indexes]
add_image_ids = [image_ids[i] for i in add_indexes]
self._add_index(add_image_paths, add_image_docs, add_image_ids)
# delete images
delete_ids = list(
set(self.features["image_ids"]).difference(set(image_ids)))
self._delete_index(delete_ids)
self._save_index(self.index, self.id_map, self.features)
return ""
else:
return "Failed. Please create or open index first"
def _add_index(self, image_list: List, image_docs: List, image_ids: List):
if len(image_ids) == 0:
return
featrures = self._cal_featrue(image_list)
index_ids = (np.arange(0, len(image_list)) + max(self.id_map.keys()) +
1).astype(np.int64)
self.index.add_with_ids(featrures, index_ids)
for i, d in zip(index_ids, image_docs):
self.id_map[i] = d
self.features['features'] = np.concatenate(
[self.features['features'], featrures], axis=0)
self.features['image_ids'].extend(image_ids)
self.features['index_ids'].extend(index_ids.tolist())
def _delete_index(self, image_ids: List):
if len(image_ids) == 0:
return
indexes = [
i for i, x in enumerate(self.features['image_ids'])
if x in image_ids
]
self.features["features"] = np.delete(self.features["features"],
indexes,
axis=0)
self.features["image_ids"] = np.delete(np.asarray(
self.features["image_ids"]),
indexes,
axis=0).tolist()
index_ids = np.delete(np.asarray(self.features["index_ids"]),
indexes,
axis=0).tolist()
id_map_values = [self.id_map[i] for i in index_ids]
self.index.reset()
ids = np.arange(0, len(id_map_values)).astype(np.int64)
self.index.add_with_ids(self.features['features'], ids)
self.id_map.clear()
for i, d in zip(ids, id_map_values):
self.id_map[i] = d
self.features["index_ids"] = ids
app = FastAPI()
@app.get("/new_index")
def new_index(image_list_path: str,
index_method: str = "HNSW32",
index_root_path: str = None,
force: bool = False):
result = ""
try:
if index_root_path is not None:
image_list_path = os.path.join(index_root_path, image_list_path)
index_path = os.path.join(index_root_path, "index", "vector.index")
id_map_path = os.path.join(index_root_path, "index", "id_map.pkl")
if not (os.path.exists(index_path)
and os.path.exists(id_map_path)) or force:
manager.create_index(image_list_path, index_method, index_root_path)
else:
result = "There alrealy has index in {}".format(index_root_path)
except Exception as e:
result = e.__str__()
data = {"error_message": result}
return json.dumps(data).encode()
@app.get("/open_index")
def open_index(index_root_path: str, image_list_path: str):
result = ""
try:
image_list_path = os.path.join(index_root_path, image_list_path)
result = manager.open_index(index_root_path, image_list_path)
except Exception as e:
result = e.__str__()
data = {"error_message": result}
return json.dumps(data).encode()
@app.get("/update_index")
def update_index(image_list_path: str, index_root_path: str = None):
result = ""
try:
if index_root_path is not None:
image_list_path = os.path.join(index_root_path, image_list_path)
result = manager.update_index(image_list=image_list_path,
image_root=index_root_path)
except Exception as e:
result = e.__str__()
data = {"error_message": result}
return json.dumps(data).encode()
def FrontInterface(server_process=None):
front = QtWidgets.QApplication([])
main_window = mod.mainwindow.MainWindow(process=server_process)
main_window.showMaximized()
sys.exit(front.exec_())
def Server(args):
[app, host, port] = args
uvicorn.run(app, host=host, port=port)
if __name__ == '__main__':
args = config.parse_args()
model_config = config.get_config(args.config,
overrides=args.override,
show=True)
manager = ShiTuIndexManager(model_config)
if not (len(sys.argv) == 3 or len(sys.argv) == 5):
print("start example:")
print(" python index_manager.py -c xxx.yaml")
print(" python index_manager.py -c xxx.yaml -p port")
yaml_path = sys.argv[2]
if len(sys.argv) == 5:
port = sys.argv[4]
else:
port = 8000
assert int(port) > 1024 and int(
port) < 65536, "The port should be bigger than 1024 and \
smaller than 65536"
try:
ip = socket.gethostbyname(socket.gethostname())
except:
ip = '127.0.0.1'
port = 8000
p_server = Process(target=Server, args=([app, ip, port],))
p_server.start()
# p_client = Process(target=FrontInterface, args=())
# p_client.start()
# p_client.join()
FrontInterface(p_server)
p_server.terminate()
sys.exit(0)
server_cmd = "python server.py -c {} -o ip={} -o port={}".format(yaml_path,
ip, port)
server_proc = subprocess.Popen(shlex.split(server_cmd))
client_proc = subprocess.Popen(
["python", "client.py", "{} {}".format(ip, port)])
try:
while psutil.Process(client_proc.pid).status() == "running":
time.sleep(0.5)
except:
pass
client_proc.terminate()
server_proc.terminate()
......@@ -22,8 +22,6 @@ try:
DEFAULT_HOST = socket.gethostbyname(socket.gethostname())
except:
DEFAULT_HOST = '127.0.0.1'
# DEFAULT_HOST = "localhost"
DEFAULT_PORT = 8000
PADDLECLAS_DOC_URL = "https://gitee.com/paddlepaddle/PaddleClas/docs/zh_CN/inference_deployment/shitu_gallery_manager.md"
......@@ -35,12 +33,17 @@ class MainWindow(QtWidgets.QMainWindow):
updateIndexMsg = QtCore.pyqtSignal(str) # 更新索引库线程信号
importImageCount = QtCore.pyqtSignal(int) # 导入图像数量信号
def __init__(self, process=None):
def __init__(self, ip=None, port=None):
super(MainWindow, self).__init__()
self.server_process = process
if ip is not None and port is not None:
self.server_ip = ip
self.server_port = port
else:
self.server_ip = DEFAULT_HOST
self.server_port = DEFAULT_PORT
self.ui = ui_mainwindow.Ui_MainWindow()
self.ui.setupUi(self) # 初始化主窗口界面
self.__imageListMgr = image_list_manager.ImageListManager()
self.__appMenu = QtWidgets.QMenu() # 应用菜单
......@@ -115,8 +118,7 @@ class MainWindow(QtWidgets.QMainWindow):
self.ui.saveImageLibraryBtn.clicked.connect(self.saveImageLibrary)
self.__setToolButton(self.ui.addClassifyBtn, "添加分类",
"./resource/add_classify.png",
TOOL_BTN_ICON_SIZE)
"./resource/add_classify.png", TOOL_BTN_ICON_SIZE)
self.ui.addClassifyBtn.clicked.connect(
self.__classifyUiContext.addClassify)
......@@ -145,7 +147,10 @@ class MainWindow(QtWidgets.QMainWindow):
self.ui.searchClassifyHistoryCmb.setToolTip("查找分类历史")
self.ui.imageScaleSlider.setToolTip("图片缩放")
def __setToolButton(self, button, tool_tip: str, icon_path: str,
def __setToolButton(self,
button,
tool_tip: str,
icon_path: str,
icon_size: int):
"""设置工具按钮"""
button.setToolTip(tool_tip)
......@@ -160,9 +165,9 @@ class MainWindow(QtWidgets.QMainWindow):
self.__libraryAppendMenu.setTitle("导入图像")
utils.setMenu(self.__libraryAppendMenu, "导入 image_list 图像",
self.importImageListImage)
self.importImageListImage)
utils.setMenu(self.__libraryAppendMenu, "导入多文件夹图像",
self.importDirsImage)
self.importDirsImage)
self.__appMenu.addMenu(self.__libraryAppendMenu)
self.__appMenu.addSeparator()
......@@ -179,16 +184,16 @@ class MainWindow(QtWidgets.QMainWindow):
def __initWaitDialog(self):
"""初始化等待对话框"""
self.__waitDialogUi.setupUi(self.__waitDialog)
self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog
| QtCore.Qt.FramelessWindowHint)
self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog |
QtCore.Qt.FramelessWindowHint)
def __startWait(self, msg: str):
"""开始显示等待对话框"""
self.setEnabled(False)
self.__waitDialogUi.msgLabel.setText(msg)
self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog
| QtCore.Qt.FramelessWindowHint
| QtCore.Qt.WindowStaysOnTopHint)
self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog |
QtCore.Qt.FramelessWindowHint |
QtCore.Qt.WindowStaysOnTopHint)
self.__waitDialog.show()
self.__waitDialog.repaint()
......@@ -196,9 +201,9 @@ class MainWindow(QtWidgets.QMainWindow):
"""停止显示等待对话框"""
self.setEnabled(True)
self.__waitDialogUi.msgLabel.setText("执行完毕!")
self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog
| QtCore.Qt.FramelessWindowHint
| QtCore.Qt.CustomizeWindowHint)
self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog |
QtCore.Qt.FramelessWindowHint |
QtCore.Qt.CustomizeWindowHint)
self.__waitDialog.close()
def __connectSignal(self):
......@@ -290,8 +295,8 @@ class MainWindow(QtWidgets.QMainWindow):
def __importImageListImageThread(self, from_path: str, to_path: str):
"""导入 image_list 图像 线程"""
count = utils.oneKeyImportFromFile(from_path=from_path,
to_path=to_path)
count = utils.oneKeyImportFromFile(
from_path=from_path, to_path=to_path)
if count == None:
count = -1
self.importImageCount.emit(count)
......@@ -308,9 +313,9 @@ class MainWindow(QtWidgets.QMainWindow):
return
from_mgr = image_list_manager.ImageListManager(from_path)
self.__startWait("正在导入图像,请等待。。。")
thread = threading.Thread(target=self.__importImageListImageThread,
args=(from_mgr.filePath,
self.__imageListMgr.filePath))
thread = threading.Thread(
target=self.__importImageListImageThread,
args=(from_mgr.filePath, self.__imageListMgr.filePath))
thread.start()
def __importDirsImageThread(self, from_dir: str, to_image_list_path: str):
......@@ -333,21 +338,25 @@ class MainWindow(QtWidgets.QMainWindow):
QtWidgets.QMessageBox.information(self, "提示", "打开的目录不存在")
return
self.__startWait("正在导入图像,请等待。。。")
thread = threading.Thread(target=self.__importDirsImageThread,
args=(dir_path,
self.__imageListMgr.filePath))
thread = threading.Thread(
target=self.__importDirsImageThread,
args=(dir_path, self.__imageListMgr.filePath))
thread.start()
def __newIndexThread(self, index_root_path: str, image_list_path: str,
index_method: str, force: bool):
def __newIndexThread(self,
index_root_path: str,
image_list_path: str,
index_method: str,
force: bool):
"""新建重建索引库线程"""
try:
client = index_http_client.IndexHttpClient(
DEFAULT_HOST, DEFAULT_PORT)
err_msg = client.new_index(image_list_path=image_list_path,
index_root_path=index_root_path,
index_method=index_method,
force=force)
client = index_http_client.IndexHttpClient(self.server_ip,
self.server_port)
err_msg = client.new_index(
image_list_path=image_list_path,
index_root_path=index_root_path,
index_method=index_method,
force=force)
if err_msg == None:
err_msg = ""
self.newIndexMsg.emit(err_msg)
......@@ -375,19 +384,20 @@ class MainWindow(QtWidgets.QMainWindow):
force = ui.resetCheckBox.isChecked()
if result == QtWidgets.QDialog.Accepted:
self.__startWait("正在 新建/重建 索引库,请等待。。。")
thread = threading.Thread(target=self.__newIndexThread,
args=(self.__imageListMgr.dirName,
"image_list.txt", index_method,
force))
thread = threading.Thread(
target=self.__newIndexThread,
args=(self.__imageListMgr.dirName, "image_list.txt",
index_method, force))
thread.start()
def __openIndexThread(self, index_root_path: str, image_list_path: str):
"""打开索引库线程"""
try:
client = index_http_client.IndexHttpClient(
DEFAULT_HOST, DEFAULT_PORT)
err_msg = client.open_index(index_root_path=index_root_path,
image_list_path=image_list_path)
client = index_http_client.IndexHttpClient(self.server_ip,
self.server_port)
err_msg = client.open_index(
index_root_path=index_root_path,
image_list_path=image_list_path)
if err_msg == None:
err_msg = ""
self.openIndexMsg.emit(err_msg)
......@@ -408,18 +418,19 @@ class MainWindow(QtWidgets.QMainWindow):
QtWidgets.QMessageBox.information(self, "提示", "请先打开正确的图像库")
return
self.__startWait("正在打开索引库,请等待。。。")
thread = threading.Thread(target=self.__openIndexThread,
args=(self.__imageListMgr.dirName,
"image_list.txt"))
thread = threading.Thread(
target=self.__openIndexThread,
args=(self.__imageListMgr.dirName, "image_list.txt"))
thread.start()
def __updateIndexThread(self, index_root_path: str, image_list_path: str):
"""更新索引库线程"""
try:
client = index_http_client.IndexHttpClient(
DEFAULT_HOST, DEFAULT_PORT)
err_msg = client.update_index(image_list_path=image_list_path,
index_root_path=index_root_path)
client = index_http_client.IndexHttpClient(self.server_ip,
self.server_port)
err_msg = client.update_index(
image_list_path=image_list_path,
index_root_path=index_root_path)
if err_msg == None:
err_msg = ""
self.updateIndexMsg.emit(err_msg)
......@@ -440,9 +451,9 @@ class MainWindow(QtWidgets.QMainWindow):
QtWidgets.QMessageBox.information(self, "提示", "请先打开正确的图像库")
return
self.__startWait("正在更新索引库,请等待。。。")
thread = threading.Thread(target=self.__updateIndexThread,
args=(self.__imageListMgr.dirName,
"image_list.txt"))
thread = threading.Thread(
target=self.__updateIndexThread,
args=(self.__imageListMgr.dirName, "image_list.txt"))
thread.start()
def searchClassify(self):
......@@ -471,9 +482,6 @@ class MainWindow(QtWidgets.QMainWindow):
def exitApp(self):
"""退出应用"""
if isinstance(self.server_process, Process):
self.server_process.terminate()
# os.kill(self.server_pid)
sys.exit(0)
def __setPathBar(self, msg: str):
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from PyQt5 import QtCore, QtGui, QtWidgets
import mod.mainwindow
from paddleclas.deploy.utils import config, logger
from paddleclas.deploy.python.predict_rec import RecPredictor
from fastapi import FastAPI
import uvicorn
import numpy as np
import faiss
from typing import List
import pickle
import cv2
import socket
import json
import operator
from multiprocessing import Process
"""
完整的index库如下:
root_path/ # 库存储目录
|-- image_list.txt # 图像列表,每行:image_path label。由前端生成及修改。后端只读
|-- features.pkl # 建库之后,保存的embedding向量,后端生成,前端无需操作
|-- images # 图像存储目录,由前端生成及增删查等操作。后端只读
| |-- md5.jpg
| |-- md5.jpg
| |-- ……
|-- index # 真正的生成的index库存储目录,后端生成及操作,前端无需操作。
| |-- vector.index # faiss生成的索引库
| |-- id_map.pkl # 索引文件
"""
class ShiTuIndexManager(object):
def __init__(self, config):
self.root_path = None
self.image_list_path = "image_list.txt"
self.image_dir = "images"
self.index_path = "index/vector.index"
self.id_map_path = "index/id_map.pkl"
self.features_path = "features.pkl"
self.index = None
self.id_map = None
self.features = None
self.config = config
self.predictor = RecPredictor(config)
def _load_pickle(self, path):
if os.path.exists(path):
return pickle.load(open(path, 'rb'))
else:
return None
def _save_pickle(self, path, data):
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, 'wb') as fd:
pickle.dump(data, fd)
def _load_index(self):
self.index = faiss.read_index(
os.path.join(self.root_path, self.index_path))
self.id_map = self._load_pickle(
os.path.join(self.root_path, self.id_map_path))
self.features = self._load_pickle(
os.path.join(self.root_path, self.features_path))
def _save_index(self, index, id_map, features):
faiss.write_index(index, os.path.join(self.root_path, self.index_path))
self._save_pickle(
os.path.join(self.root_path, self.id_map_path), id_map)
self._save_pickle(
os.path.join(self.root_path, self.features_path), features)
def _update_path(self, root_path, image_list_path=None):
if root_path == self.root_path:
pass
else:
self.root_path = root_path
if not os.path.exists(os.path.join(root_path, "index")):
os.mkdir(os.path.join(root_path, "index"))
if image_list_path is not None:
self.image_list_path = image_list_path
def _cal_featrue(self, image_list):
batch_images = []
featrures = None
cnt = 0
for idx, image_path in enumerate(image_list):
image = cv2.imread(image_path)
if image is None:
return "{} is broken or not exist. Stop"
else:
image = image[:, :, ::-1]
batch_images.append(image)
cnt += 1
if cnt % self.config["Global"]["batch_size"] == 0 or (
idx + 1) == len(image_list):
if len(batch_images) == 0:
continue
batch_results = self.predictor.predict(batch_images)
featrures = batch_results if featrures is None else np.concatenate(
(featrures, batch_results), axis=0)
batch_images = []
return featrures
def _split_datafile(self, data_file, image_root):
'''
data_file: image path and info, which can be splitted by spacer
image_root: image path root
delimiter: delimiter
'''
gallery_images = []
gallery_docs = []
gallery_ids = []
with open(data_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
for _, ori_line in enumerate(lines):
line = ori_line.strip().split()
text_num = len(line)
assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"
image_file = os.path.join(image_root, line[0])
gallery_images.append(image_file)
gallery_docs.append(ori_line.strip())
gallery_ids.append(os.path.basename(line[0]).split(".")[0])
return gallery_images, gallery_docs, gallery_ids
def create_index(self,
image_list: str,
index_method: str="HNSW32",
image_root: str=None):
if not os.path.exists(image_list):
return "{} is not exist".format(image_list)
if index_method.lower() not in ['hnsw32', 'ivf', 'flat']:
return "The index method Only support: HNSW32, IVF, Flat"
self._update_path(os.path.dirname(image_list), image_list)
# get image_paths
image_root = image_root if image_root is not None else self.root_path
gallery_images, gallery_docs, image_ids = self._split_datafile(
image_list, image_root)
# gernerate index
if index_method == "IVF":
index_method = index_method + str(
min(max(int(len(gallery_images) // 32), 2), 65536)) + ",Flat"
index = faiss.index_factory(
self.config["IndexProcess"]["embedding_size"], index_method,
faiss.METRIC_INNER_PRODUCT)
self.index = faiss.IndexIDMap2(index)
features = self._cal_featrue(gallery_images)
self.index.train(features)
index_ids = np.arange(0, len(gallery_images)).astype(np.int64)
self.index.add_with_ids(features, index_ids)
self.id_map = dict()
for i, d in zip(list(index_ids), gallery_docs):
self.id_map[i] = d
self.features = {
"features": features,
"index_method": index_method,
"image_ids": image_ids,
"index_ids": index_ids.tolist()
}
self._save_index(self.index, self.id_map, self.features)
def open_index(self, root_path: str, image_list_path: str) -> str:
self._update_path(root_path)
_, _, image_ids = self._split_datafile(image_list_path, root_path)
if os.path.exists(os.path.join(self.root_path, self.index_path)) and \
os.path.exists(os.path.join(self.root_path, self.id_map_path)) and \
os.path.exists(os.path.join(self.root_path, self.features_path)):
self._update_path(root_path)
self._load_index()
if operator.eq(set(image_ids), set(self.features['image_ids'])):
return ""
else:
return "The image list is different from index, Please update index"
else:
return "File not exist: features.pkl, vector.index, id_map.pkl"
def update_index(self, image_list: str, image_root: str=None) -> str:
if self.index and self.id_map and self.features:
image_paths, image_docs, image_ids = self._split_datafile(
image_list, image_root
if image_root is not None else self.root_path)
# for add image
add_ids = list(
set(image_ids).difference(set(self.features["image_ids"])))
add_indexes = [i for i, x in enumerate(image_ids) if x in add_ids]
add_image_paths = [image_paths[i] for i in add_indexes]
add_image_docs = [image_docs[i] for i in add_indexes]
add_image_ids = [image_ids[i] for i in add_indexes]
self._add_index(add_image_paths, add_image_docs, add_image_ids)
# delete images
delete_ids = list(
set(self.features["image_ids"]).difference(set(image_ids)))
self._delete_index(delete_ids)
self._save_index(self.index, self.id_map, self.features)
return ""
else:
return "Failed. Please create or open index first"
def _add_index(self, image_list: List, image_docs: List, image_ids: List):
if len(image_ids) == 0:
return
featrures = self._cal_featrue(image_list)
index_ids = (
np.arange(0, len(image_list)) + max(self.id_map.keys()) + 1
).astype(np.int64)
self.index.add_with_ids(featrures, index_ids)
for i, d in zip(index_ids, image_docs):
self.id_map[i] = d
self.features['features'] = np.concatenate(
[self.features['features'], featrures], axis=0)
self.features['image_ids'].extend(image_ids)
self.features['index_ids'].extend(index_ids.tolist())
def _delete_index(self, image_ids: List):
if len(image_ids) == 0:
return
indexes = [
i for i, x in enumerate(self.features['image_ids'])
if x in image_ids
]
self.features["features"] = np.delete(
self.features["features"], indexes, axis=0)
self.features["image_ids"] = np.delete(
np.asarray(self.features["image_ids"]), indexes, axis=0).tolist()
index_ids = np.delete(
np.asarray(self.features["index_ids"]), indexes, axis=0).tolist()
id_map_values = [self.id_map[i] for i in index_ids]
self.index.reset()
ids = np.arange(0, len(id_map_values)).astype(np.int64)
self.index.add_with_ids(self.features['features'], ids)
self.id_map.clear()
for i, d in zip(ids, id_map_values):
self.id_map[i] = d
self.features["index_ids"] = ids
app = FastAPI()
@app.get("/new_index")
def new_index(image_list_path: str,
index_method: str="HNSW32",
index_root_path: str=None,
force: bool=False):
result = ""
try:
if index_root_path is not None:
image_list_path = os.path.join(index_root_path, image_list_path)
index_path = os.path.join(index_root_path, "index", "vector.index")
id_map_path = os.path.join(index_root_path, "index", "id_map.pkl")
if not (os.path.exists(index_path) and
os.path.exists(id_map_path)) or force:
manager.create_index(image_list_path, index_method,
index_root_path)
else:
result = "There alrealy has index in {}".format(index_root_path)
except Exception as e:
result = e.__str__()
data = {"error_message": result}
return json.dumps(data).encode()
@app.get("/open_index")
def open_index(index_root_path: str, image_list_path: str):
result = ""
try:
image_list_path = os.path.join(index_root_path, image_list_path)
result = manager.open_index(index_root_path, image_list_path)
except Exception as e:
result = e.__str__()
data = {"error_message": result}
return json.dumps(data).encode()
@app.get("/update_index")
def update_index(image_list_path: str, index_root_path: str=None):
result = ""
try:
if index_root_path is not None:
image_list_path = os.path.join(index_root_path, image_list_path)
result = manager.update_index(
image_list=image_list_path, image_root=index_root_path)
except Exception as e:
result = e.__str__()
data = {"error_message": result}
return json.dumps(data).encode()
def FrontInterface(server_process=None):
front = QtWidgets.QApplication([])
main_window = mod.mainwindow.MainWindow(process=server_process)
main_window.showMaximized()
sys.exit(front.exec_())
def Server(app, host, port):
uvicorn.run(app, host=host, port=port)
if __name__ == '__main__':
args = config.parse_args()
model_config = config.get_config(
args.config, overrides=args.override, show=True)
manager = ShiTuIndexManager(model_config)
ip = model_config.get('ip', None)
port = model_config.get('port', None)
if ip is None or port is None:
try:
ip = socket.gethostbyname(socket.gethostname())
except:
ip = '127.0.0.1'
port = 8000
Server(app, ip, port)
docs/images/structure.jpg

1.8 MB | W: | H:

docs/images/structure.jpg

193.3 KB | W: | H:

docs/images/structure.jpg
docs/images/structure.jpg
docs/images/structure.jpg
docs/images/structure.jpg
  • 2-up
  • Swipe
  • Onion skin
......@@ -24,30 +24,32 @@ PP-ShiTu对原数据集进行了`Gallery`库和`Query`库划分,并生成了
| 场景 |示例图|场景简介|Recall@1|场景库下载地址|原数据集下载地址|
|:---:|:---:|:---:|:---:|:---:|:---:|
| 球类 | <img src="../../images/ppshitu_application_scenarios/Ball.jpg" height = "100"> |各种球类识别 | 0.9769 | [Balls](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/Balls.tar) | [原数据下载地址](https://www.kaggle.com/datasets/gpiosenka/balls-image-classification) |
| 狗识别 | <img src="../../images/ppshitu_application_scenarios/DogBreeds.jpg" height = "100"> | 狗细分类识别,包括69种狗的图像 | 0.9606 | [DogBreeds](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/DogBreeds.tar) | [原数据下载地址](https://www.kaggle.com/datasets/gpiosenka/70-dog-breedsimage-data-set) |
| 宝石 | <img src="../../images/ppshitu_application_scenarios/Gemstones.jpg" height = "100"> | 宝石种类识别 | 0.9653 | [Gemstones](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/Gemstones.tar) | [原数据下载地址](https://www.kaggle.com/datasets/lsind18/gemstones-images) |
| 动物 | <img src="../../images/ppshitu_application_scenarios/AnimalImageDataset.jpg" height = "100"> |各种动物识别 | 0.9078 | [AnimalImageDataset](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/AnimalImageDataset.tar) | [原数据下载地址](https://www.kaggle.com/datasets/iamsouravbanerjee/animal-image-dataset-90-different-animals) |
| 鸟类 | <img src="../../images/ppshitu_application_scenarios/Bird400.jpg" height = "100"> |鸟细分类识别,包括400种各种姿态的鸟类图像 | 0.9673 | [Bird400](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/Bird400.tar) | [原数据下载地址](https://www.kaggle.com/datasets/gpiosenka/100-bird-species) |
| 花 | <img src="../../images/ppshitu_application_scenarios/104flowers.jpeg" height = "100"> |104种花细分类识别 | 0.9788 | [104flowers](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/104flowers.tar) | [原数据下载地址](https://www.kaggle.com/datasets/msheriey/104-flowers-garden-of-eden) |
| 交通工具 | <img src="../../images/ppshitu_application_scenarios/Vechicles.jpg" height = "100"> |车、船等交通工具粗分类识别 | 0.9307 | [Vechicles](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/Vechicles.tar) | [原数据下载地址](https://www.kaggle.com/datasets/rishabkoul1/vechicle-dataset) |
| 花 | <img src="../../images/ppshitu_application_scenarios/104flowers.jpeg" height = "100"> |104种花细分类识别 | 0.9788 | [104flowers](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/104flowrs.tar) | [原数据下载地址](https://www.kaggle.com/datasets/msheriey/104-flowers-garden-of-eden) |
| 运动种类 | <img src="../../images/ppshitu_application_scenarios/100sports.jpg" height = "100"> |100种运动图像识别 | 0.9413 | [100sports](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/100sports.tar) | [原数据下载地址](https://www.kaggle.com/datasets/gpiosenka/sports-classification) |
| 乐器 | <img src="../../images/ppshitu_application_scenarios/MusicInstruments.jpg" height = "100"> |30种不同乐器种类识别 | 0.9467 | [MusicInstruments](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/MusicInstruments.tar) | [原数据下载地址](https://www.kaggle.com/datasets/gpiosenka/musical-instruments-image-classification) |
| 宝可梦 | <img src="../../images/ppshitu_application_scenarios/Pokemon.png" height = "100"> |宝可梦神奇宝贝识别 | 0.9236 | [Pokemon](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/Pokemon.tar) | [原数据下载地址](https://www.kaggle.com/datasets/lantian773030/pokemonclassification) |
| 船 | <img src="../../images/ppshitu_application_scenarios/Boat.jpg" height = "100"> |船种类识别 |0.9242 | [Boat](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/Boat.tar) | [原数据下载地址](https://www.kaggle.com/datasets/imsparsh/dockship-boat-type-classification) |
| 鞋子 | <img src="../../images/ppshitu_application_scenarios/Shoes.jpeg" height = "100"> |鞋子种类识别,包括靴子、拖鞋等 | 0.9000 | [Shoes](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/Shoes.tar) | [原数据下载地址](https://www.kaggle.com/datasets/noobyogi0100/shoe-dataset) |
| 巴黎建筑 | <img src="../../images/ppshitu_application_scenarios/Paris.jpg" height = "100"> |巴黎著名建筑景点识别,如:巴黎铁塔、圣母院等 | 1.000 | [Paris](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/Paris.tar) | [原数据下载地址](https://www.kaggle.com/datasets/skylord/oxbuildings) |
| 蝴蝶 | <img src="../../images/ppshitu_application_scenarios/Butterfly.jpg" height = "100"> |75种蝴蝶细分类识别 | 0.9360 | [Butterfly](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/Butterfly.tar) | [原数据下载地址](https://www.kaggle.com/datasets/gpiosenka/butterfly-images40-species) |
| 野外植物 | <img src="../../images/ppshitu_application_scenarios/WildEdiblePlants.jpg" height = "100"> |野外植物识别 | 0.9758 | [WildEdiblePlants](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/WildEdiblePlants.tar) | [原数据下载地址](https://www.kaggle.com/datasets/ryanpartridge01/wild-edible-plants) |
| 天气 | <img src="../../images/ppshitu_application_scenarios/WeatherImageRecognition.jpg" height = "100"> |各种天气场景识别,如:雨天、打雷、下雪等 | 0.9924 | [WeatherImageRecognition](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/WeatherImageRecognition.tar) | [原数据下载地址](https://www.kaggle.com/datasets/jehanbhathena/weather-dataset) |
| 坚果 | <img src="../../images/ppshitu_application_scenarios/TreeNuts.jpg" height = "100"> |各种坚果种类识别 | 0.9412 | [TreeNuts](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/TreeNuts.tar) | [原数据下载地址](https://www.kaggle.com/datasets/gpiosenka/tree-nuts-image-classification) |
| 动物 | <img src="../../images/ppshitu_application_scenarios/AnimalImageDataset.jpg" height = "100"> |各种动物识别 | 0.9078 | [AnimalImageDataset](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/AnimalImageDataset.tar) | [原数据下载地址](https://www.kaggle.com/datasets/iamsouravbanerjee/animal-image-dataset-90-different-animals) |
| 时装 | <img src="../../images/ppshitu_application_scenarios/FashionProductsImage.jpg" height = "100"> |首饰、挎包、化妆品等时尚商品识别 | 0.9555 | [FashionProductImageSmall](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/FashionProductImageSmall.tar) | [原数据下载地址](https://www.kaggle.com/datasets/paramaggarwal/fashion-product-images-small) |
| 垃圾 | <img src="../../images/ppshitu_application_scenarios/Garbage12.jpg" height = "100"> |12种垃圾分类识别 | 0.9845 | [Garbage12](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/Garbage12.tar) | [原数据下载地址](https://www.kaggle.com/datasets/mostafaabla/garbage-classification) |
| 航拍场景 | <img src="../../images/ppshitu_application_scenarios/AID.jpg" height = "100"> |各种航拍场景识别,如机场、火车站等 | 0.9797 | [AID](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/AID.tar) | [原数据下载地址](https://www.kaggle.com/datasets/jiayuanchengala/aid-scene-classification-datasets) |
| 蔬菜 | <img src="../../images/ppshitu_application_scenarios/Veg200.jpg" height = "100"> |各种蔬菜识别 | 0.8929 | [Veg200](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/Veg200.tar) | [原数据下载地址](https://www.kaggle.com/datasets/zhaoyj688/vegfru) |
| 商标 | <img src="../../images/ppshitu_application_scenarios/Logo3K.jpg" height = "100"> |两千多种logo识别 | 0.9313 | [Logo3k](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/Logo3k.tar) | [原数据下载地址](https://github.com/Wangjing1551/LogoDet-3K-Dataset) |
| 球类 | <img src="../../images/ppshitu_application_scenarios/Ball.jpg" height = "100"> |各种球类识别 | 0.9769 | [Balls](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/PP-ShiTuV2_application_dataset/Balls.tar) | [原数据下载地址](https://www.kaggle.com/datasets/gpiosenka/balls-image-classification) |
| 狗识别 | <img src="../../images/ppshitu_application_scenarios/DogBreeds.jpg" height = "100"> | 狗细分类识别,包括69种狗的图像 | 0.9606 | 添加技术交流群后即可获得 | [原数据下载地址](https://www.kaggle.com/datasets/gpiosenka/70-dog-breedsimage-data-set) |
| 宝石 | <img src="../../images/ppshitu_application_scenarios/Gemstones.jpg" height = "100"> | 宝石种类识别 | 0.9653 | 添加技术交流群后即可获得 | [原数据下载地址](https://www.kaggle.com/datasets/lsind18/gemstones-images) |
| 鸟类 | <img src="../../images/ppshitu_application_scenarios/Bird400.jpg" height = "100"> |鸟细分类识别,包括400种各种姿态的鸟类图像 | 0.9673 | 添加技术交流群后即可获得 | [原数据下载地址](https://www.kaggle.com/datasets/gpiosenka/100-bird-species) |
| 运动种类 | <img src="../../images/ppshitu_application_scenarios/100sports.jpg" height = "100"> |100种运动图像识别 | 0.9413 | 添加技术交流群后即可获得 | [原数据下载地址](https://www.kaggle.com/datasets/gpiosenka/sports-classification) |
| 乐器 | <img src="../../images/ppshitu_application_scenarios/MusicInstruments.jpg" height = "100"> |30种不同乐器种类识别 | 0.9467 | 添加技术交流群后即可获得 | [原数据下载地址](https://www.kaggle.com/datasets/gpiosenka/musical-instruments-image-classification) |
| 宝可梦 | <img src="../../images/ppshitu_application_scenarios/Pokemon.png" height = "100"> |宝可梦神奇宝贝识别 | 0.9236 | 添加技术交流群后即可获得 | [原数据下载地址](https://www.kaggle.com/datasets/lantian773030/pokemonclassification) |
| 船 | <img src="../../images/ppshitu_application_scenarios/Boat.jpg" height = "100"> |船种类识别 |0.9242 | 添加技术交流群后即可获得 | [原数据下载地址](https://www.kaggle.com/datasets/imsparsh/dockship-boat-type-classification) |
| 鞋子 | <img src="../../images/ppshitu_application_scenarios/Shoes.jpeg" height = "100"> |鞋子种类识别,包括靴子、拖鞋等 | 0.9000 | 添加技术交流群后即可获得 | [原数据下载地址](https://www.kaggle.com/datasets/noobyogi0100/shoe-dataset) |
| 巴黎建筑 | <img src="../../images/ppshitu_application_scenarios/Paris.jpg" height = "100"> |巴黎著名建筑景点识别,如:巴黎铁塔、圣母院等 | 1.000 | 添加技术交流群后即可获得 | [原数据下载地址](https://www.kaggle.com/datasets/skylord/oxbuildings) |
| 蝴蝶 | <img src="../../images/ppshitu_application_scenarios/Butterfly.jpg" height = "100"> |75种蝴蝶细分类识别 | 0.9360 | 添加技术交流群后即可获得 | [原数据下载地址](https://www.kaggle.com/datasets/gpiosenka/butterfly-images40-species) |
| 野外植物 | <img src="../../images/ppshitu_application_scenarios/WildEdiblePlants.jpg" height = "100"> |野外植物识别 | 0.9758 | 添加技术交流群后即可获得 | [原数据下载地址](https://www.kaggle.com/datasets/ryanpartridge01/wild-edible-plants) |
| 天气 | <img src="../../images/ppshitu_application_scenarios/WeatherImageRecognition.jpg" height = "100"> |各种天气场景识别,如:雨天、打雷、下雪等 | 0.9924 | 添加技术交流群后即可获得 | [原数据下载地址](https://www.kaggle.com/datasets/jehanbhathena/weather-dataset) |
| 坚果 | <img src="../../images/ppshitu_application_scenarios/TreeNuts.jpg" height = "100"> |各种坚果种类识别 | 0.9412 | 添加技术交流群后即可获得 | [原数据下载地址](https://www.kaggle.com/datasets/gpiosenka/tree-nuts-image-classification) |
| 垃圾 | <img src="../../images/ppshitu_application_scenarios/Garbage12.jpg" height = "100"> |12种垃圾分类识别 | 0.9845 | 添加技术交流群后即可获得 | [原数据下载地址](https://www.kaggle.com/datasets/mostafaabla/garbage-classification) |
| 航拍场景 | <img src="../../images/ppshitu_application_scenarios/AID.jpg" height = "100"> |各种航拍场景识别,如机场、火车站等 | 0.9797 | 添加技术交流群后即可获得 | [原数据下载地址](https://www.kaggle.com/datasets/jiayuanchengala/aid-scene-classification-datasets) |
| 蔬菜 | <img src="../../images/ppshitu_application_scenarios/Veg200.jpg" height = "100"> |各种蔬菜识别 | 0.8929 | 添加技术交流群后即可获得 | [原数据下载地址](https://www.kaggle.com/datasets/zhaoyj688/vegfru) |
<div align="center">
<img src="https://user-images.githubusercontent.com/11568925/189877049-d17ddcea-22d2-44ab-91fe-36d12af3add8.png" width = "200" height = "200"/>
<p>添加二维码获取相关场景库及大礼包</p>
</div>
<a name="2. 使用说明"></a>
......
......@@ -22,7 +22,7 @@
- [2. 使用说明](#2)
- [2.1 环境安装](#2.1)
- [2.2 模型准备](#2.2)
- [2.2 模型及数据准备](#2.2)
- [2.3运行使用](#2.3)
- [3.生成文件介绍](#3)
......@@ -123,13 +123,25 @@
pip install fastapi
pip install uvicorn
pip install pyqt5
pip install psutil
```
<a name="2.2"></a>
### 2.2 模型准备
### 2.2 模型及数据准备
请按照[PP-ShiTu快速体验](../../quick_start/quick_start_recognition.md#2.2.1)中下载及准备inference model,并修改好`${PaddleClas}/deploy/configs/inference_drink.yaml`的相关参数。
请按照[PP-ShiTu快速体验](../quick_start/quick_start_recognition.md#2.2.1)中下载及准备inference model,并修改好`${PaddleClas}/deploy/configs/inference_drink.yaml`的相关参数,同时准备好数据集。在具体使用时,请替换好自己的数据集及模型文件。
```shell
cd ${PaddleClas}/deploy/shitu_index_manager
mkdir models
cd models
# 下载及解压识别模型
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/PP-ShiTuV2/general_PPLCNetV2_base_pretrained_v1.0_infer.tar && tar -xf general_PPLCNetV2_base_pretrained_v1.0_infer.tar
cd ..
# 下载及解压示例数据集
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_dataset_v2.0.tar && tar -xf drink_dataset_v2.0.tar
```
<a name="2.3"></a>
......@@ -139,9 +151,26 @@ pip install pyqt5
```shell
cd ${PaddleClas}/deploy/shitu_index_manager
python index_manager.py -c ../configs/inference_drink.yaml
cp ../configs/inference_drink.yaml .
# 注意如果没有按照2.2中准备数据集及代码,请手动修改inference_drink.yaml,做好适配
python index_manager.py -c inference_drink.yaml
```
运行成功后,会自动跳转到工具界面,可以按照如下步骤,生成新的index库。
1. 点击菜单栏`新建图像库`,会提示打开一个文件夹,此时请创建一个**新的文件夹**,并打开。如在`${PaddleClas}/deploy/shitu_index_manager`下新建一个`drink_index`文件夹
2. 导入图像,或者如上面功能介绍,自己手动新增类别和相应的图像,下面介绍两种导入图像方式,操作时,二选一即可。
- 点击`导入图像`->`导入image_list图像`,打开`${PaddleClas}/deploy/shitu_index_manager/drink_dataset_v2.0/gallery/drink_label.txt`,此时就可以将`drink_label.txt`中的图像全部导入进来,图像类别就是`drink_label.txt`中记录的类别。
- 点击`导入图像`->`导入多文件夹图像`,打开`${PaddleClas}/deploy/shitu_index_manager/drink_dataset_v2.0/gallery/`文件夹,此时就将`gallery`文件夹下,所有子文件夹都导入进来,图像类别就是子文件夹的名字。
3. 点击菜单栏中`新建/重建 索引库`,此时就会开始生成索引库。如果图片较多或者使用cpu来进行特征提取,那么耗时会比较长,请耐心等待。
4. 生成索引库成功后,会发现在`drink_index`文件夹下生成如[3](#3) 中介绍的文件,此时`index`子文件夹下生出的文件,就是`PP-ShiTu`所使用的索引文件。
**注意**
- 利用此工具生成的index库,如`drink_index`文件夹,请妥善存储。之后,可以继续使用此工具中`打开图像库`功能,打开`drink_index`文件夹,继续对index库进行增删改查操作,具体功能可以查看[功能介绍](#1)
- 打开一个生成好的库,在其上面进行增删改查操作后,请及时保存。保存后并及时使用菜单中`更新索引库`功能,对索引库进行更新
- 如果要使用自己的图像库文件,图像生成格式如示例数据格式,生成`image_list.txt`或者多文件夹存储,二选一。
<a name="3"></a>
## 3. 生成文件介绍
......
......@@ -124,11 +124,15 @@ PaddleClas 提供了转换并优化后的推理模型,可以直接参考下方
```shell
# 进入lite_ppshitu目录
cd $PaddleClas/deploy/lite_shitu
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/lite/ppshitu_lite_models_v1.2.tar
tar -xf ppshitu_lite_models_v1.2.tar
rm -f ppshitu_lite_models_v1.2.tar
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/lite/ppshituv2_lite_models_v1.0.tar
tar -xf ppshituv2_lite_models_v1.0.tar
rm -f ppshituv2_lite_models_v1.0.tar
```
解压后的文件夹中包含PP-ShiTu V2的检测和特征模型的量化模型,即解压后的两个`.nb`模型。 存储大小约为14M,相对于非量化模型,模型加速一倍左右,精度降低不到一个点。如果想用非量化模型,可以参考[使用其他模型](#2.1.2)
同时,文件中存在`general_PPLCNetV2_base_quant_v1.0_lite_inference_model`子文件夹,子文件夹中存储的是`general_PPLCNetV2_base_quant_v1.0_lite.nb`对应的`inference model`,在实际部署前,先需要用此`inference model`生成好对应的`index`库。实际推理时,只需要生成好的`index`库和两个量化模型来完成部署。
<a name="2.1.2"></a>
#### 2.1.2 使用其他模型
......@@ -188,6 +192,8 @@ Paddle-Lite 提供了多种策略来自动优化原始的模型,其中包括
下面介绍使用`paddle_lite_opt`完成主体检测模型和识别模型的预训练模型,转成inference模型,最终转换成Paddle-Lite的优化模型的过程。
此示例是使用`fp32`的模型直接生成`.nb`模型,相对于量化模型,精度会更高,但是速度会降低。
1. 转换主体检测模型
```shell
......@@ -214,13 +220,13 @@ cp $code_path/PaddleDetection/inference/picodet_lcnet_x2_5_640_mainbody/mainbody
```shell
# 识别模型下载
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/general_PPLCNet_x2_5_lite_v1.0_infer.tar
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/PP-ShiTuV2/general_PPLCNetV2_base_pretrained_v1.0_infer.tar
# 解压模型
tar -xf general_PPLCNet_x2_5_lite_v1.0_infer.tar
tar -xf general_PPLCNetV2_base_pretrained_v1.0_infer.tar
# 转换为Paddle-Lite模型
paddle_lite_opt --model_file=general_PPLCNet_x2_5_lite_v1.0_infer/inference.pdmodel --param_file=general_PPLCNet_x2_5_lite_v1.0_infer/inference.pdiparams --optimize_out=general_PPLCNet_x2_5_lite_v1.0_infer/rec
paddle_lite_opt --model_file=general_PPLCNetV2_base_pretrained_v1.0_infer/inference.pdmodel --param_file=general_PPLCNetV2_base_pretrained_v1.0_infer/inference.pdiparams --optimize_out=general_PPLCNetV2_base_pretrained_v1.0_infer/rec
# 将模型文件拷贝到lite_shitu下
cp general_PPLCNet_x2_5_lite_v1.0_infer/rec.nb deploy/lite_shitu/models/
cp general_PPLCNetV2_base_pretrained_v1.0_infer/rec.nb deploy/lite_shitu/models/
```
**注意**`--optimize_out` 参数为优化后模型的保存路径,无需加后缀`.nb``--model_file` 参数为模型结构信息文件的路径,`--param_file` 参数为模型权重信息文件的路径,请注意文件名。
......@@ -242,17 +248,12 @@ cd $PaddleClas
python setup.py install
cd deploy
# 下载瓶装饮料数据集
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_dataset_v1.0.tar && tar -xf drink_dataset_v1.0.tar
rm -rf drink_dataset_v1.0.tar
rm -rf drink_dataset_v1.0/index
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_dataset_v2.0.tar && tar -xf drink_dataset_v2.0.tar
rm -rf drink_dataset_v2.0.tar
rm -rf drink_dataset_v2.0/index
# 安装1.5.3版本的faiss
pip install faiss-cpu==1.5.3
# 下载通用识别模型,可替换成自己的inference model
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/general_PPLCNet_x2_5_lite_v1.0_infer.tar
tar -xf general_PPLCNet_x2_5_lite_v1.0_infer.tar
rm -rf general_PPLCNet_x2_5_lite_v1.0_infer.tar
```
<a name="2.2.2"></a>
......@@ -262,11 +263,11 @@ rm -rf general_PPLCNet_x2_5_lite_v1.0_infer.tar
```shell
# 生成新的index库,注意指定好识别模型的路径,同时将index_mothod修改成Flat,HNSW32和IVF在此版本中可能存在bug,请慎重使用。
# 如果使用自己的识别模型,对应的修改inference model的目录
python python/build_gallery.py -c configs/inference_drink.yaml -o Global.rec_inference_model_dir=general_PPLCNet_x2_5_lite_v1.0_infer -o IndexProcess.index_method=Flat
python python/build_gallery.py -c configs/inference_drink.yaml -o Global.rec_inference_model_dir=lite_shitu/ppshituv2_lite_models_v1.0/general_PPLCNetV2_base_quant_v1.0_lite_inference_model -o IndexProcess.index_method=Flat
# 进入到lite_shitu目录
cd lite_shitu
mv ../drink_dataset_v1.0 .
mv ../drink_dataset_v2.0 .
```
<a name="2.3"></a>
......@@ -274,14 +275,17 @@ mv ../drink_dataset_v1.0 .
### 2.3 将yaml文件转换成json文件
```shell
# 如果是使用自己生成的.nb模型,请修改好对应的det_model_path和rec_model_path的路径
# 如果测试单张图像,路径使用相对路径
python generate_json_config.py --det_model_path ppshitu_lite_models_v1.2/mainbody_PPLCNet_x2_5_640_v1.2_lite.nb --rec_model_path ppshitu_lite_models_v1.2/general_PPLCNet_x2_5_lite_v1.2_infer.nb --img_path images/demo.jpeg
python generate_json_config.py --det_model_path ppshituv2_lite_models_v1.0/mainbody_PPLCNet_x2_5_640_quant_v1.0_lite.nb --rec_model_path ppshituv2_lite_models_v1.0/general_PPLCNetV2_base_quant_v1.0_lite.nb --img_path images/demo.jpeg
# or
# 如果测试多张图像
python generate_json_config.py --det_model_path ppshitu_lite_models_v1.2/mainbody_PPLCNet_x2_5_640_v1.2_lite.nb --rec_model_path ppshitu_lite_models_v1.2/general_PPLCNet_x2_5_lite_v1.2_infer.nb --img_dir images
python generate_json_config.py --det_model_path ppshituv2_lite_models_v1.0/mainbody_PPLCNet_x2_5_640_quant_v1.0_lite.nb --rec_model_path ppshituv2_lite_models_v1.0/general_PPLCNetV2_base_quant_v1.0_lite.nb --img_dir images
# 执行完成后,会在lit_shitu下生成shitu_config.json配置文件
```
**注意**:生成json文件的时候,请注意修改好模型路径,特别是使用自己生成的`.nb`模型时
<a name="2.4"></a>
### 2.4 index字典转换
......@@ -289,7 +293,6 @@ python generate_json_config.py --det_model_path ppshitu_lite_models_v1.2/mainbod
由于python的检索库字典,使用`pickle`进行的序列化存储,导致C++不方便读取,因此需要进行转换
```shell
# 转化id_map.pkl为id_map.txt
python transform_id_map.py -c ../configs/inference_drink.yaml
```
......@@ -349,8 +352,8 @@ make ARM_ABI=arm8
```shell
mkdir deploy
# 移动的模型路径要和之前生成的json文件中模型路径一致
mv ppshitu_lite_models_v1.2 deploy/
mv drink_dataset_v1.0 deploy/
mv ppshituv2_lite_models_v1.0 deploy/
mv drink_dataset_v2.0 deploy/
mv images deploy/
mv shitu_config.json deploy/
cp pp_shitu deploy/
......@@ -363,12 +366,13 @@ cp ../../../cxx/lib/libpaddle_light_api_shared.so deploy/
```shell
deploy/
|-- ppshitu_lite_models_v1.1/
| |--mainbody_PPLCNet_x2_5_640_quant_v1.1_lite.nb 优化后的主体检测模型文件
| |--general_PPLCNet_x2_5_lite_v1.1_infer.nb 优化后的识别模型文件
|-- ppshituv2_lite_models_v1.0/
| |--mainbody_PPLCNet_x2_5_640_quant_v1.0_lite.nb 主体检测lite模型文件
| |--general_PPLCNetV2_base_quant_v1.0_lite.nb 识别lite模型文件
| |--general_PPLCNetV2_base_quant_v1.0_lite_inference_model 识别对应的inference model文件夹
|-- images/
| |--demo.jpg 图片文件
|-- drink_dataset_v1.0/ 瓶装饮料demo数据
|-- drink_dataset_v2.0/ 瓶装饮料demo数据
| |--index 检索index目录
|-- pp_shitu 生成的移动端执行文件
|-- shitu_config.json 执行时参数配置文件
......@@ -399,8 +403,7 @@ chmod 777 pp_shitu
运行效果如下:
```
images/demo.jpeg:
result0: bbox[344, 98, 527, 593], score: 0.811656, label: 红牛-强化型
result1: bbox[0, 0, 600, 600], score: 0.729664, label: 红牛-强化型
result1: bbox[0, 0, 600, 600], score: 0.696782, label: 红牛-强化型
```
<a name="FAQ"></a>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册