diff --git a/deploy/python/build_gallery.py b/deploy/python/build_gallery.py
index 8184d59608d4f6593a7170f9f933794d85ef675e..63c411c64ead923cf77d3ed1b870642c58be9d92 100644
--- a/deploy/python/build_gallery.py
+++ b/deploy/python/build_gallery.py
@@ -12,16 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import pickle
 
 import cv2
 import faiss
 import numpy as np
-from tqdm import tqdm
-import pickle
-
-from paddleclas.deploy.utils import logger, config
-from paddleclas.deploy.python.predict_rec import RecPredictor
 from paddleclas.deploy.python.predict_rec import RecPredictor
+from paddleclas.deploy.utils import config, logger
+from tqdm import tqdm
 
 
 def split_datafile(data_file, image_root, delimiter="\t"):
@@ -52,6 +50,7 @@ class GalleryBuilder(object):
         self.config = config
         self.rec_predictor = RecPredictor(config)
         assert 'IndexProcess' in config.keys(), "Index config not found ... "
+        self.android_demo = config["Global"].get("android_demo", False)
         self.build(config['IndexProcess'])
 
     def build(self, config):
@@ -70,98 +69,51 @@ class GalleryBuilder(object):
             "new", "remove", "append"
         ], "Only append, remove and new operation are supported"
 
+        if self.android_demo:
+            self._create_index_for_android_demo(config, gallery_features, gallery_docs)
+            return
+
         # vector.index: faiss index file
         # id_map.pkl: use this file to map id to image_doc
+        index, ids = None, None
         if operation_method in ["remove", "append"]:
-            # if remove or append, vector.index and id_map.pkl must exist
-            assert os.path.join(
-                config["index_dir"], "vector.index"
-            ), "The vector.index dose not exist in {} when 'index_operation' is not None".format(
-                config["index_dir"])
-            assert os.path.join(
-                config["index_dir"], "id_map.pkl"
-            ), "The id_map.pkl dose not exist in {} when 'index_operation' is not None".format(
-                config["index_dir"])
-            index = faiss.read_index(
-                os.path.join(config["index_dir"], "vector.index"))
-            with open(os.path.join(config["index_dir"], "id_map.pkl"),
-                      'rb') as fd:
-                ids = pickle.load(fd)
-            assert index.ntotal == len(ids.keys(
-            )), "data number in index is not equal in in id_map"
-        else:
-            if not os.path.exists(config["index_dir"]):
-                os.makedirs(config["index_dir"], exist_ok=True)
+            # if remove or append, load vector.index and id_map.pkl
+            index, ids = self._load_index(config)
             index_method = config.get("index_method", "HNSW32")
-
-            # if IVF method, cal ivf number automaticlly
-            if index_method == "IVF":
-                index_method = index_method + str(
-                    min(int(len(gallery_images) // 8), 65536)) + ",Flat"
-
-            # for binary index, add B at head of index_method
-            if config["dist_type"] == "hamming":
-                index_method = "B" + index_method
-
-            #dist_type
-            dist_type = faiss.METRIC_INNER_PRODUCT if config[
-                "dist_type"] == "IP" else faiss.METRIC_L2
-
-            #build index
-            if config["dist_type"] == "hamming":
-                index = faiss.index_binary_factory(config["embedding_size"],
-                                                   index_method)
-            else:
-                index = faiss.index_factory(config["embedding_size"],
-                                            index_method, dist_type)
-                index = faiss.IndexIDMap2(index)
-            ids = {}
-
-        if config["index_method"] == "HNSW32":
+        else:
+            index_method, index, ids = self._create_index(config, len(gallery_images))
+        if index_method == "HNSW32":
             logger.warning(
                 "The HNSW32 method dose not support 'remove' operation")
 
         if operation_method != "remove":
             # calculate id for new data
-            start_id = max(ids.keys()) + 1 if ids else 0
-            ids_now = (
-                np.arange(0, len(gallery_images)) + start_id).astype(np.int64)
-
-            # only train when new index file
-            if operation_method == "new":
-                if config["dist_type"] == "hamming":
-                    index.add(gallery_features)
-                else:
-                    index.train(gallery_features)
-
-            if not config["dist_type"] == "hamming":
-                index.add_with_ids(gallery_features, ids_now)
-
-            for i, d in zip(list(ids_now), gallery_docs):
-                ids[i] = d
+            index, ids = self._add_gallery(index, ids, gallery_features, gallery_docs, config, operation_method)
         else:
-            if config["index_method"] == "HNSW32":
+            if index_method == "HNSW32":
                 raise RuntimeError(
                     "The index_method: HNSW32 dose not support 'remove' operation"
                 )
             # remove ids in id_map, remove index data in faiss index
-            remove_ids = list(
-                filter(lambda k: ids.get(k) in gallery_docs, ids.keys()))
-            remove_ids = np.asarray(remove_ids)
-            index.remove_ids(remove_ids)
-            for k in remove_ids:
-                del ids[k]
+            index, ids = self._rm_id_in_galllery(index, ids, gallery_docs)
 
         # store faiss index file and id_map file
-        if config["dist_type"] == "hamming":
-            faiss.write_index_binary(
-                index, os.path.join(config["index_dir"], "vector.index"))
-        else:
-            faiss.write_index(
-                index, os.path.join(config["index_dir"], "vector.index"))
-
-        with open(os.path.join(config["index_dir"], "id_map.pkl"), 'wb') as fd:
-            pickle.dump(ids, fd)
+        self._save_gallery(config, index, ids)
+
+    def _create_index_for_android_demo(self, config, gallery_features, gallery_docs):
+        """Build a flat inner-product index with sequential ids and save it."""
+        if not os.path.exists(config["index_dir"]):
+            os.makedirs(config["index_dir"], exist_ok=True)
+        #build index
+        index = faiss.IndexFlatIP(config["embedding_size"])
+        index.add(gallery_features)
+
+        # calculate id for data
+        ids_now = (np.arange(0, len(gallery_docs))).astype(np.int64)
+        ids = {}
+        for i, d in zip(list(ids_now), gallery_docs):
+            ids[i] = d
+        self._save_gallery(config, index, ids)
 
     def _extract_features(self, gallery_images, config):
         # extract gallery features
@@ -197,6 +149,98 @@ class GalleryBuilder(object):
 
         return gallery_features
 
+    def _load_index(self, config):
+        """Load the existing faiss index and id map from config["index_dir"]."""
+        assert os.path.exists(os.path.join(
+            config["index_dir"], "vector.index"
+        )), "The vector.index dose not exist in {} when 'index_operation' is not None".format(
+            config["index_dir"])
+        assert os.path.exists(os.path.join(
+            config["index_dir"], "id_map.pkl"
+        )), "The id_map.pkl dose not exist in {} when 'index_operation' is not None".format(
+            config["index_dir"])
+        index = faiss.read_index(
+            os.path.join(config["index_dir"], "vector.index"))
+        with open(os.path.join(config["index_dir"], "id_map.pkl"),
+                  'rb') as fd:
+            ids = pickle.load(fd)
+        assert index.ntotal == len(ids.keys(
+        )), "data number in index is not equal in in id_map"
+        return index, ids
+
+    def _create_index(self, config, gallery_num=0):
+        """Create an empty faiss index; gallery_num sizes the IVF cell count."""
+        if not os.path.exists(config["index_dir"]):
+            os.makedirs(config["index_dir"], exist_ok=True)
+        index_method = config.get("index_method", "HNSW32")
+
+        # if IVF method, cal ivf number automaticlly
+        if index_method == "IVF":
+            index_method = index_method + str(
+                min(int(gallery_num // 8), 65536)) + ",Flat"
+
+        # for binary index, add B at head of index_method
+        if config["dist_type"] == "hamming":
+            index_method = "B" + index_method
+
+        #dist_type
+        dist_type = faiss.METRIC_INNER_PRODUCT if config[
+            "dist_type"] == "IP" else faiss.METRIC_L2
+
+        #build index
+        if config["dist_type"] == "hamming":
+            index = faiss.index_binary_factory(config["embedding_size"],
+                                               index_method)
+        else:
+            index = faiss.index_factory(config["embedding_size"],
+                                        index_method, dist_type)
+            index = faiss.IndexIDMap2(index)
+        ids = {}
+        return index_method, index, ids
+
+    def _add_gallery(self, index, ids, gallery_features, gallery_docs, config, operation_method):
+        """Train/add gallery features and map the new ids to gallery_docs."""
+        start_id = max(ids.keys()) + 1 if ids else 0
+        ids_now = (
+            np.arange(0, len(gallery_docs)) + start_id).astype(np.int64)
+
+        # only train when new index file
+        if operation_method == "new":
+            if config["dist_type"] == "hamming":
+                index.add(gallery_features)
+            else:
+                index.train(gallery_features)
+
+        if not config["dist_type"] == "hamming":
+            index.add_with_ids(gallery_features, ids_now)
+
+        for i, d in zip(list(ids_now), gallery_docs):
+            ids[i] = d
+        return index, ids
+
+    def _rm_id_in_galllery(self, index, ids, gallery_docs):
+        """Remove every id whose doc is in gallery_docs from index and id map."""
+        remove_ids = list(
+            filter(lambda k: ids.get(k) in gallery_docs, ids.keys()))
+        remove_ids = np.asarray(remove_ids, dtype=np.int64)
+        index.remove_ids(remove_ids)
+        for k in remove_ids:
+            del ids[k]
+
+        return index, ids
+
+    def _save_gallery(self, config, index, ids):
+        """Write vector.index and id_map.pkl into config["index_dir"]."""
+        if config["dist_type"] == "hamming":
+            faiss.write_index_binary(
+                index, os.path.join(config["index_dir"], "vector.index"))
+        else:
+            faiss.write_index(
+                index, os.path.join(config["index_dir"], "vector.index"))
+
+        with open(os.path.join(config["index_dir"], "id_map.pkl"), 'wb') as fd:
+            pickle.dump(ids, fd)
+
 
 def main(config):
     GalleryBuilder(config)