提交 687c1352 编写于 作者: D dongshuilong

replace vector_search with faiss

上级 800487e4
......@@ -28,11 +28,11 @@ RecPostProcess: null
# indexing engine config
IndexProcess:
index_path: "./recognition_demo_data_v1.0/gallery_cartoon/index/"
index_method: "HNSW32" # supported: HNSW32, IVF, Flat
index_dir: "./recognition_demo_data_v1.0/gallery_cartoon/index/"
image_root: "./recognition_demo_data_v1.0/gallery_cartoon/"
data_file: "./recognition_demo_data_v1.0/gallery_cartoon/data_file.txt"
append_index: False
index_operation: "new" # suported: "append", "remove", "new"
delimiter: "\t"
dist_type: "IP"
pq_size: 100
embedding_size: 2048
......@@ -26,11 +26,11 @@ RecPostProcess: null
# indexing engine config
IndexProcess:
index_path: "./recognition_demo_data_v1.0/gallery_logo/index/"
index_method: "HNSW32" # supported: HNSW32, IVF, Flat
index_dir: "./recognition_demo_data_v1.0/gallery_logo/index/"
image_root: "./recognition_demo_data_v1.0/gallery_logo/"
data_file: "./recognition_demo_data_v1.0/gallery_logo/data_file.txt"
append_index: False
index_operation: "new" # suported: "append", "remove", "new"
delimiter: "\t"
dist_type: "IP"
pq_size: 100
embedding_size: 512
......@@ -26,11 +26,11 @@ RecPostProcess: null
# indexing engine config
IndexProcess:
index_path: "./recognition_demo_data_v1.0/gallery_product/index"
index_method: "HNSW32" # supported: HNSW32, IVF, Flat
index_dir: "./recognition_demo_data_v1.0/gallery_product/index"
image_root: "./recognition_demo_data_v1.0/gallery_product/"
data_file: "./recognition_demo_data_v1.0/gallery_product/data_file.txt"
append_index: False
index_operation: "new" # suported: "append", "remove", "new"
delimiter: "\t"
dist_type: "IP"
pq_size: 100
embedding_size: 512
......@@ -26,11 +26,11 @@ RecPostProcess: null
# indexing engine config
IndexProcess:
index_path: "./recognition_demo_data_v1.0/gallery_vehicle/index/"
index_method: "HNSW32" # supported: HNSW32, IVF, Flat
index_dir: "./recognition_demo_data_v1.0/gallery_vehicle/index/"
image_root: "./recognition_demo_data_v1.0/gallery_vehicle/"
data_file: "./recognition_demo_data_v1.0/gallery_vehicle/data_file.txt"
append_index: False
index_operation: "new" # suported: "append", "remove", "new"
delimiter: "\t"
dist_type: "IP"
pq_size: 100
embedding_size: 512
......@@ -51,8 +51,6 @@ RecPreProcess:
RecPostProcess: null
IndexProcess:
index_path: "./recognition_demo_data_v1.0/gallery_cartoon/index/"
search_budget: 100
index_dir: "./recognition_demo_data_v1.0/gallery_cartoon/index/"
return_k: 5
dist_type: "IP"
score_thres: 0.5
......@@ -50,8 +50,6 @@ RecPostProcess: null
# indexing engine config
IndexProcess:
index_path: "./recognition_demo_data_v1.0/gallery_logo/index/"
search_budget: 100
index_dir: "./recognition_demo_data_v1.0/gallery_logo/index/"
return_k: 5
dist_type: "IP"
score_thres: 0.5
......@@ -50,8 +50,6 @@ RecPostProcess: null
# indexing engine config
IndexProcess:
index_path: "./recognition_demo_data_v1.0/gallery_product/index"
search_budget: 100
index_dir: "./recognition_demo_data_v1.0/gallery_product/index"
return_k: 5
dist_type: "IP"
score_thres: 0.5
......@@ -52,8 +52,6 @@ RecPostProcess: null
# indexing engine config
IndexProcess:
index_path: "./recognition_demo_data_v1.0/gallery_vehicle/index/"
search_budget: 100
index_dir: "./recognition_demo_data_v1.0/gallery_vehicle/index/"
return_k: 5
dist_type: "IP"
score_thres: 0.5
......@@ -17,13 +17,13 @@ import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
import copy
import cv2
import faiss
import numpy as np
from tqdm import tqdm
import pickle
from python.predict_rec import RecPredictor
from vector_search import Graph_Index
from utils import logger
from utils import config
......@@ -31,9 +31,9 @@ from utils import config
def split_datafile(data_file, image_root, delimiter="\t"):
'''
data_file: image path and info, which can be splitted by spacer
data_file: image path and info, which can be splitted by spacer
image_root: image path root
delimiter: delimiter
delimiter: delimiter
'''
gallery_images = []
gallery_docs = []
......@@ -45,9 +45,8 @@ def split_datafile(data_file, image_root, delimiter="\t"):
assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"
image_file = os.path.join(image_root, line[0])
image_doc = line[1]
gallery_images.append(image_file)
gallery_docs.append(image_doc)
gallery_docs.append(ori_line.strip())
return gallery_images, gallery_docs
......@@ -64,9 +63,77 @@ class GalleryBuilder(object):
'''
build index from scratch
'''
operation_method = config.get("index_operation", "new").lower()
gallery_images, gallery_docs = split_datafile(
config['data_file'], config['image_root'], config['delimiter'])
if operation_method != "remove":
gallery_features = self._extract_features(gallery_images, config)
assert operation_method in [
"new", "remove", "append"
], "Only append, remove and new operation are supported"
if operation_method in ["remove", "append"]:
assert os.path.join(
config["index_dir"], "vector.index"
), "The vector.index dose not exist in {} when 'index_operation' is not None".format(
config["index_dir"])
assert os.path.join(
config["index_dir"], "id_map.pkl"
), "The id_map.pkl dose not exist in {} when 'index_operation' is not None".format(
config["index_dir"])
index = faiss.read_index(
os.path.join(config["index_dir"], "vector.index"))
with open(os.path.join(config["index_dir"], "id_map.pkl"),
'rb') as fd:
ids = pickle.load(fd)
assert index.ntotal == len(ids.keys(
)), "data number in index is not equal in in id_map"
else:
if not os.path.exists(config["index_dir"]):
os.makedirs(config["index_dir"], exist_ok=True)
index_method = config.get("index_method", "HNSW32")
if index_method == "IVF":
index_method = index_method + str(
min(int(len(gallery_images) // 8), 65536)) + ",Flat"
dist_type = faiss.METRIC_INNER_PRODUCT if config[
"dist_type"] == "IP" else faiss.METRIC_L2
index = faiss.index_factory(config["embedding_size"], index_method,
dist_type)
index = faiss.IndexIDMap2(index)
ids = {}
if config["index_method"] == "HNSW32":
logger.warning(
"The HNSW32 method dose not support 'remove' operation")
if operation_method != "remove":
start_id = max(ids.keys()) + 1 if ids else 0
ids_now = np.arange(0, len(gallery_images)) + start_id
if operation_method == "new":
index.train(gallery_features)
index.add_with_ids(gallery_features, ids_now)
for i, d in zip(list(ids_now), gallery_docs):
ids[i] = d
else:
if config["index_method"] == "HNSW32":
raise RuntimeError(
"The index_method: HNSW32 dose not support 'remove' operation"
)
remove_ids = list(
filter(lambda k: ids.get(k) in gallery_docs, ids.keys()))
remove_ids = np.asarray(remove_ids)
index.remove_ids(remove_ids)
for k in remove_ids:
del ids[k]
faiss.write_index(index,
os.path.join(config["index_dir"], "vector.index"))
with open(os.path.join(config["index_dir"], "id_map.pkl"), 'wb') as fd:
pickle.dump(ids, fd)
def _extract_features(self, gallery_images, config):
# extract gallery features
gallery_features = np.zeros(
[len(gallery_images), config['embedding_size']], dtype=np.float32)
......@@ -91,19 +158,11 @@ class GalleryBuilder(object):
rec_feat = self.rec_predictor.predict(batch_img)
gallery_features[-len(batch_img):, :] = rec_feat
batch_img = []
# train index
self.Searcher = Graph_Index(dist_type=config['dist_type'])
self.Searcher.build(
gallery_vectors=gallery_features,
gallery_docs=gallery_docs,
pq_size=config['pq_size'],
index_path=config['index_path'],
append_index=config["append_index"])
return gallery_features
def main(config):
system_builder = GalleryBuilder(config)
GalleryBuilder(config)
return
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
......@@ -20,10 +20,11 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
import copy
import cv2
import numpy as np
import faiss
import pickle
from python.predict_rec import RecPredictor
from python.predict_det import DetPredictor
from vector_search import Graph_Index
from utils import logger
from utils import config
......@@ -40,11 +41,16 @@ class SystemPredictor(object):
assert 'IndexProcess' in config.keys(), "Index config not found ... "
self.return_k = self.config['IndexProcess']['return_k']
self.search_budget = self.config['IndexProcess']['search_budget']
self.Searcher = Graph_Index(
dist_type=config['IndexProcess']['dist_type'])
self.Searcher.load(config['IndexProcess']['index_path'])
index_dir = self.config["IndexProcess"]["index_dir"]
assert os.path.exists(os.path.join(
index_dir, "vector.index")), "vector.index not found ..."
assert os.path.exists(os.path.join(
index_dir, "id_map.pkl")), "id_map.pkl not found ... "
self.Searcher = faiss.read_index(
os.path.join(index_dir, "vector.index"))
with open(os.path.join(index_dir, "id_map.pkl"), "rb") as fd:
self.id_map = pickle.load(fd)
def append_self(self, results, shape):
results.append({
......@@ -98,14 +104,11 @@ class SystemPredictor(object):
crop_img = img[ymin:ymax, xmin:xmax, :].copy()
rec_results = self.rec_predictor.predict(crop_img)
preds["bbox"] = [xmin, ymin, xmax, ymax]
scores, docs = self.Searcher.search(
query=rec_results,
return_k=self.return_k,
search_budget=self.search_budget)
scores, docs = self.Searcher.search(rec_results, self.return_k)
# just top-1 result will be returned for the final
if scores[0] >= self.config["IndexProcess"]["score_thres"]:
preds["rec_docs"] = docs[0]
preds["rec_scores"] = scores[0]
if scores[0][0] >= self.config["IndexProcess"]["score_thres"]:
preds["rec_docs"] = self.id_map[docs[0][0]].split()[1]
preds["rec_scores"] = scores[0][0]
output.append(preds)
# st5: nms to the final results to avoid fetching duplicate results
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册