diff --git a/deploy/python/build_gallery.py b/deploy/python/build_gallery.py index 439e9ab0af13d330d07b821090fdab252671bab0..63c411c64ead923cf77d3ed1b870642c58be9d92 100644 --- a/deploy/python/build_gallery.py +++ b/deploy/python/build_gallery.py @@ -50,8 +50,8 @@ class GalleryBuilder(object): self.config = config self.rec_predictor = RecPredictor(config) assert 'IndexProcess' in config.keys(), "Index config not found ... " - self.build(config['IndexProcess']) self.android_demo = config["Global"].get("android_demo", False) + self.build(config['IndexProcess']) def build(self, config): ''' @@ -78,7 +78,8 @@ class GalleryBuilder(object): index, ids = None, None if operation_method in ["remove", "append"]: # if remove or append, load vector.index and id_map.pkl - index, ids = self._load_index() + index, ids = self._load_index(config) + index_method = config.get("index_method", "HNSW32") else: index_method, index, ids = self._create_index(config) if index_method == "HNSW32": @@ -87,7 +88,7 @@ class GalleryBuilder(object): if operation_method != "remove": # calculate id for new data - index, ids = self._add_gallery(index, ids, gallery_features, gallery_docs) + index, ids = self._add_gallery(index, ids, gallery_features, gallery_docs, config, operation_method) else: if index_method == "HNSW32": raise RuntimeError( @@ -104,14 +105,11 @@ class GalleryBuilder(object): os.makedirs(config["index_dir"], exist_ok=True) #build index index = faiss.IndexFlatIP(config["embedding_size"]) - ids = {} - - # calculate id for new data - ids_now = ( - np.arange(0, len(gallery_images))).astype(np.int64) - index.add(gallery_features) + # calculate id for data + ids_now = (np.arange(0, len(gallery_docs))).astype(np.int64) + ids = {} for i, d in zip(list(ids_now), gallery_docs): ids[i] = d self._save_gallery(config, index, ids) @@ -197,7 +195,7 @@ class GalleryBuilder(object): ids = {} return index_method, index, ids - def _add_gallery(self, index, ids, gallery_features, gallery_docs): + def _add_gallery(self, index, ids, gallery_features, gallery_docs, config, operation_method): start_id = max(ids.keys()) + 1 if ids else 0 ids_now = ( np.arange(0, len(gallery_docs)) + start_id).astype(np.int64) @@ -216,7 +214,7 @@ class GalleryBuilder(object): ids[i] = d return index, ids - def _rm_id_in_galllery(self, index, ids, gallery_docs) + def _rm_id_in_galllery(self, index, ids, gallery_docs): remove_ids = list( filter(lambda k: ids.get(k) in gallery_docs, ids.keys())) remove_ids = np.asarray(remove_ids)