提交 52ba23c8 编写于 作者: D dongshuilong

fix build_gallery bug

上级 9eaa3353
...@@ -50,8 +50,8 @@ class GalleryBuilder(object): ...@@ -50,8 +50,8 @@ class GalleryBuilder(object):
self.config = config self.config = config
self.rec_predictor = RecPredictor(config) self.rec_predictor = RecPredictor(config)
assert 'IndexProcess' in config.keys(), "Index config not found ... " assert 'IndexProcess' in config.keys(), "Index config not found ... "
self.build(config['IndexProcess'])
self.android_demo = config["Global"].get("android_demo", False) self.android_demo = config["Global"].get("android_demo", False)
self.build(config['IndexProcess'])
def build(self, config): def build(self, config):
''' '''
...@@ -78,7 +78,8 @@ class GalleryBuilder(object): ...@@ -78,7 +78,8 @@ class GalleryBuilder(object):
index, ids = None, None index, ids = None, None
if operation_method in ["remove", "append"]: if operation_method in ["remove", "append"]:
# if remove or append, load vector.index and id_map.pkl # if remove or append, load vector.index and id_map.pkl
index, ids = self._load_index() index, ids = self._load_index(config)
index_method = config.get("index_method", "HNSW32")
else: else:
index_method, index, ids = self._create_index(config) index_method, index, ids = self._create_index(config)
if index_method == "HNSW32": if index_method == "HNSW32":
...@@ -87,7 +88,7 @@ class GalleryBuilder(object): ...@@ -87,7 +88,7 @@ class GalleryBuilder(object):
if operation_method != "remove": if operation_method != "remove":
# calculate id for new data # calculate id for new data
index, ids = self._add_gallery(index, ids, gallery_features, gallery_docs) index, ids = self._add_gallery(index, ids, gallery_features, gallery_docs, config, operation_method)
else: else:
if index_method == "HNSW32": if index_method == "HNSW32":
raise RuntimeError( raise RuntimeError(
...@@ -104,14 +105,11 @@ class GalleryBuilder(object): ...@@ -104,14 +105,11 @@ class GalleryBuilder(object):
os.makedirs(config["index_dir"], exist_ok=True) os.makedirs(config["index_dir"], exist_ok=True)
#build index #build index
index = faiss.IndexFlatIP(config["embedding_size"]) index = faiss.IndexFlatIP(config["embedding_size"])
ids = {}
# calculate id for new data
ids_now = (
np.arange(0, len(gallery_images))).astype(np.int64)
index.add(gallery_features) index.add(gallery_features)
# calculate id for data
ids_now = (np.arange(0, len(gallery_docs))).astype(np.int64)
ids = {}
for i, d in zip(list(ids_now), gallery_docs): for i, d in zip(list(ids_now), gallery_docs):
ids[i] = d ids[i] = d
self._save_gallery(config, index, ids) self._save_gallery(config, index, ids)
...@@ -197,7 +195,7 @@ class GalleryBuilder(object): ...@@ -197,7 +195,7 @@ class GalleryBuilder(object):
ids = {} ids = {}
return index_method, index, ids return index_method, index, ids
def _add_gallery(self, index, ids, gallery_features, gallery_docs): def _add_gallery(self, index, ids, gallery_features, gallery_docs, config, operation_method):
start_id = max(ids.keys()) + 1 if ids else 0 start_id = max(ids.keys()) + 1 if ids else 0
ids_now = ( ids_now = (
np.arange(0, len(gallery_docs)) + start_id).astype(np.int64) np.arange(0, len(gallery_docs)) + start_id).astype(np.int64)
...@@ -216,7 +214,7 @@ class GalleryBuilder(object): ...@@ -216,7 +214,7 @@ class GalleryBuilder(object):
ids[i] = d ids[i] = d
return index, ids return index, ids
def _rm_id_in_galllery(self, index, ids, gallery_docs) def _rm_id_in_galllery(self, index, ids, gallery_docs):
remove_ids = list( remove_ids = list(
filter(lambda k: ids.get(k) in gallery_docs, ids.keys())) filter(lambda k: ids.get(k) in gallery_docs, ids.keys()))
remove_ids = np.asarray(remove_ids) remove_ids = np.asarray(remove_ids)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册