提交 52ba23c8 编写于 作者: D dongshuilong

fix build_gallery bug

上级 9eaa3353
......@@ -50,8 +50,8 @@ class GalleryBuilder(object):
self.config = config
self.rec_predictor = RecPredictor(config)
assert 'IndexProcess' in config.keys(), "Index config not found ... "
self.build(config['IndexProcess'])
self.android_demo = config["Global"].get("android_demo", False)
self.build(config['IndexProcess'])
def build(self, config):
'''
......@@ -78,7 +78,8 @@ class GalleryBuilder(object):
index, ids = None, None
if operation_method in ["remove", "append"]:
# if remove or append, load vector.index and id_map.pkl
index, ids = self._load_index()
index, ids = self._load_index(config)
index_method = config.get("index_method", "HNSW32")
else:
index_method, index, ids = self._create_index(config)
if index_method == "HNSW32":
......@@ -87,7 +88,7 @@ class GalleryBuilder(object):
if operation_method != "remove":
# calculate id for new data
index, ids = self._add_gallery(index, ids, gallery_features, gallery_docs)
index, ids = self._add_gallery(index, ids, gallery_features, gallery_docs, config, operation_method)
else:
if index_method == "HNSW32":
raise RuntimeError(
......@@ -104,14 +105,11 @@ class GalleryBuilder(object):
os.makedirs(config["index_dir"], exist_ok=True)
#build index
index = faiss.IndexFlatIP(config["embedding_size"])
ids = {}
# calculate id for new data
ids_now = (
np.arange(0, len(gallery_images))).astype(np.int64)
index.add(gallery_features)
# calculate id for data
ids_now = (np.arange(0, len(gallery_docs))).astype(np.int64)
ids = {}
for i, d in zip(list(ids_now), gallery_docs):
ids[i] = d
self._save_gallery(config, index, ids)
......@@ -197,7 +195,7 @@ class GalleryBuilder(object):
ids = {}
return index_method, index, ids
def _add_gallery(self, index, ids, gallery_features, gallery_docs):
def _add_gallery(self, index, ids, gallery_features, gallery_docs, config, operation_method):
start_id = max(ids.keys()) + 1 if ids else 0
ids_now = (
np.arange(0, len(gallery_docs)) + start_id).astype(np.int64)
......@@ -216,7 +214,7 @@ class GalleryBuilder(object):
ids[i] = d
return index, ids
def _rm_id_in_galllery(self, index, ids, gallery_docs)
def _rm_id_in_galllery(self, index, ids, gallery_docs):
remove_ids = list(
filter(lambda k: ids.get(k) in gallery_docs, ids.keys()))
remove_ids = np.asarray(remove_ids)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册