未验证 提交 6cd62518 编写于 作者: W Walter 提交者: GitHub

Merge pull request #2236 from RainFrost1/lite_shitu

update build_gallery and add android demo index support
...@@ -12,16 +12,14 @@ ...@@ -12,16 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
import pickle
import cv2 import cv2
import faiss import faiss
import numpy as np import numpy as np
from tqdm import tqdm
import pickle
from paddleclas.deploy.utils import logger, config
from paddleclas.deploy.python.predict_rec import RecPredictor
from paddleclas.deploy.python.predict_rec import RecPredictor from paddleclas.deploy.python.predict_rec import RecPredictor
from paddleclas.deploy.utils import config, logger
from tqdm import tqdm
def split_datafile(data_file, image_root, delimiter="\t"): def split_datafile(data_file, image_root, delimiter="\t"):
...@@ -52,6 +50,7 @@ class GalleryBuilder(object): ...@@ -52,6 +50,7 @@ class GalleryBuilder(object):
self.config = config self.config = config
self.rec_predictor = RecPredictor(config) self.rec_predictor = RecPredictor(config)
assert 'IndexProcess' in config.keys(), "Index config not found ... " assert 'IndexProcess' in config.keys(), "Index config not found ... "
self.android_demo = config["Global"].get("android_demo", False)
self.build(config['IndexProcess']) self.build(config['IndexProcess'])
def build(self, config): def build(self, config):
...@@ -70,98 +69,50 @@ class GalleryBuilder(object): ...@@ -70,98 +69,50 @@ class GalleryBuilder(object):
"new", "remove", "append" "new", "remove", "append"
], "Only append, remove and new operation are supported" ], "Only append, remove and new operation are supported"
if self.android_demo:
self._create_index_for_android_demo(config, gallery_features, gallery_docs)
return
# vector.index: faiss index file # vector.index: faiss index file
# id_map.pkl: use this file to map id to image_doc # id_map.pkl: use this file to map id to image_doc
index, ids = None, None
if operation_method in ["remove", "append"]: if operation_method in ["remove", "append"]:
# if remove or append, vector.index and id_map.pkl must exist # if remove or append, load vector.index and id_map.pkl
assert os.path.join( index, ids = self._load_index(config)
config["index_dir"], "vector.index"
), "The vector.index dose not exist in {} when 'index_operation' is not None".format(
config["index_dir"])
assert os.path.join(
config["index_dir"], "id_map.pkl"
), "The id_map.pkl dose not exist in {} when 'index_operation' is not None".format(
config["index_dir"])
index = faiss.read_index(
os.path.join(config["index_dir"], "vector.index"))
with open(os.path.join(config["index_dir"], "id_map.pkl"),
'rb') as fd:
ids = pickle.load(fd)
assert index.ntotal == len(ids.keys(
)), "data number in index is not equal in in id_map"
else:
if not os.path.exists(config["index_dir"]):
os.makedirs(config["index_dir"], exist_ok=True)
index_method = config.get("index_method", "HNSW32") index_method = config.get("index_method", "HNSW32")
else:
# if IVF method, cal ivf number automaticlly index_method, index, ids = self._create_index(config)
if index_method == "IVF": if index_method == "HNSW32":
index_method = index_method + str(
min(int(len(gallery_images) // 8), 65536)) + ",Flat"
# for binary index, add B at head of index_method
if config["dist_type"] == "hamming":
index_method = "B" + index_method
#dist_type
dist_type = faiss.METRIC_INNER_PRODUCT if config[
"dist_type"] == "IP" else faiss.METRIC_L2
#build index
if config["dist_type"] == "hamming":
index = faiss.index_binary_factory(config["embedding_size"],
index_method)
else:
index = faiss.index_factory(config["embedding_size"],
index_method, dist_type)
index = faiss.IndexIDMap2(index)
ids = {}
if config["index_method"] == "HNSW32":
logger.warning( logger.warning(
"The HNSW32 method dose not support 'remove' operation") "The HNSW32 method dose not support 'remove' operation")
if operation_method != "remove": if operation_method != "remove":
# calculate id for new data # calculate id for new data
start_id = max(ids.keys()) + 1 if ids else 0 index, ids = self._add_gallery(index, ids, gallery_features, gallery_docs, config, operation_method)
ids_now = (
np.arange(0, len(gallery_images)) + start_id).astype(np.int64)
# only train when new index file
if operation_method == "new":
if config["dist_type"] == "hamming":
index.add(gallery_features)
else:
index.train(gallery_features)
if not config["dist_type"] == "hamming":
index.add_with_ids(gallery_features, ids_now)
for i, d in zip(list(ids_now), gallery_docs):
ids[i] = d
else: else:
if config["index_method"] == "HNSW32": if index_method == "HNSW32":
raise RuntimeError( raise RuntimeError(
"The index_method: HNSW32 dose not support 'remove' operation" "The index_method: HNSW32 dose not support 'remove' operation"
) )
# remove ids in id_map, remove index data in faiss index # remove ids in id_map, remove index data in faiss index
remove_ids = list( index, ids = self._rm_id_in_galllery(index, ids, gallery_docs)
filter(lambda k: ids.get(k) in gallery_docs, ids.keys()))
remove_ids = np.asarray(remove_ids)
index.remove_ids(remove_ids)
for k in remove_ids:
del ids[k]
# store faiss index file and id_map file # store faiss index file and id_map file
if config["dist_type"] == "hamming": self._save_gallery(config, index, ids)
faiss.write_index_binary(
index, os.path.join(config["index_dir"], "vector.index")) def _create_index_for_android_demo(self, config, gallery_features, gallery_docs):
else: if not os.path.exists(config["index_dir"]):
faiss.write_index( os.makedirs(config["index_dir"], exist_ok=True)
index, os.path.join(config["index_dir"], "vector.index")) #build index
index = faiss.IndexFlatIP(config["embedding_size"])
with open(os.path.join(config["index_dir"], "id_map.pkl"), 'wb') as fd: index.add(gallery_features)
pickle.dump(ids, fd)
# calculate id for data
ids_now = (np.arange(0, len(gallery_docs))).astype(np.int64)
ids = {}
for i, d in zip(list(ids_now), gallery_docs):
ids[i] = d
self._save_gallery(config, index, ids)
def _extract_features(self, gallery_images, config): def _extract_features(self, gallery_images, config):
# extract gallery features # extract gallery features
...@@ -197,6 +148,93 @@ class GalleryBuilder(object): ...@@ -197,6 +148,93 @@ class GalleryBuilder(object):
return gallery_features return gallery_features
def _load_index(self, config):
    """Load an existing faiss index and its id map from ``config["index_dir"]``.

    Reads ``vector.index`` (the faiss index file) and ``id_map.pkl`` (the
    pickled dict that maps an index id to its image doc).

    Args:
        config: mapping that provides ``"index_dir"`` and, optionally,
            ``"dist_type"``.

    Returns:
        A ``(index, ids)`` tuple.

    Raises:
        FileNotFoundError: if ``vector.index`` or ``id_map.pkl`` is missing.
        AssertionError: if the number of vectors in the index differs from
            the number of entries in the id map.
    """
    index_path = os.path.join(config["index_dir"], "vector.index")
    id_map_path = os.path.join(config["index_dir"], "id_map.pkl")

    # The previous check asserted on the joined path string itself, which is
    # always truthy; test for the files on disk instead.
    if not os.path.exists(index_path):
        raise FileNotFoundError(
            "The vector.index dose not exist in {} when 'index_operation' is not None".
            format(config["index_dir"]))
    if not os.path.exists(id_map_path):
        raise FileNotFoundError(
            "The id_map.pkl dose not exist in {} when 'index_operation' is not None".
            format(config["index_dir"]))

    # Hamming indexes are stored with faiss.write_index_binary, so they have
    # to be read back with the matching binary reader.
    if config.get("dist_type") == "hamming":
        index = faiss.read_index_binary(index_path)
    else:
        index = faiss.read_index(index_path)

    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # point index_dir at galleries produced by a trusted source.
    with open(id_map_path, 'rb') as fd:
        ids = pickle.load(fd)

    assert index.ntotal == len(
        ids), "data number in index is not equal in in id_map"
    return index, ids
def _create_index(self, config, gallery_images=None):
    """Create an empty faiss index described by ``config``.

    Creates ``config["index_dir"]`` if it does not exist, resolves the faiss
    factory string and builds the index. Non-hamming indexes are wrapped in
    ``faiss.IndexIDMap2``.

    Args:
        config: mapping that provides ``"index_dir"``, ``"dist_type"``,
            ``"embedding_size"`` and, optionally, ``"index_method"``
            (defaults to ``"HNSW32"``).
        gallery_images: optional sized collection of gallery images. Only
            used when ``index_method`` is ``"IVF"``, to derive the number of
            IVF lists. When omitted, ``self.gallery_images`` is tried.

    Returns:
        An ``(index_method, index, ids)`` tuple, where ``index_method`` is
        the resolved factory string and ``ids`` is an empty dict.

    Raises:
        ValueError: if ``index_method`` is ``"IVF"`` and no gallery size is
            available.
    """
    if not os.path.exists(config["index_dir"]):
        os.makedirs(config["index_dir"], exist_ok=True)
    index_method = config.get("index_method", "HNSW32")

    # if IVF method, calculate the ivf number automatically
    if index_method == "IVF":
        # The gallery was previously read from a name that is not defined in
        # this scope; take it from the argument or from the instance.
        if gallery_images is None:
            gallery_images = getattr(self, "gallery_images", None)
        if gallery_images is None:
            raise ValueError(
                "index_method 'IVF' needs the gallery size: pass "
                "gallery_images to _create_index")
        index_method = index_method + str(
            min(int(len(gallery_images) // 8), 65536)) + ",Flat"

    is_binary = config["dist_type"] == "hamming"

    # for binary index, add B at head of index_method
    if is_binary:
        index_method = "B" + index_method

    # dist_type: inner product for "IP", L2 for everything else
    dist_type = faiss.METRIC_INNER_PRODUCT if config[
        "dist_type"] == "IP" else faiss.METRIC_L2

    # build index
    if is_binary:
        index = faiss.index_binary_factory(config["embedding_size"],
                                           index_method)
    else:
        index = faiss.index_factory(config["embedding_size"], index_method,
                                    dist_type)
        index = faiss.IndexIDMap2(index)
    ids = {}
    return index_method, index, ids
def _add_gallery(self, index, ids, gallery_features, gallery_docs, config, operation_method):
    """Add gallery features to ``index`` and record their docs in ``ids``.

    New ids continue after the largest key already in ``ids`` (or start at 0
    when ``ids`` is empty).

    Args:
        index: index object; ``add`` is called on it for hamming, otherwise
            ``add_with_ids`` (preceded by ``train`` for a new index).
        ids: dict mapping id to image doc; updated in place.
        gallery_features: features to add, passed through to the index.
        gallery_docs: docs paired with the features, one per new id.
        config: mapping that provides ``"dist_type"``.
        operation_method: ``"new"`` triggers training for non-hamming
            indexes; any other value skips training.

    Returns:
        The ``(index, ids)`` pair that was passed in, after mutation.
    """
    # calculate id for new data
    start_id = max(ids) + 1 if ids else 0
    ids_now = (
        np.arange(0, len(gallery_docs)) + start_id).astype(np.int64)

    if config["dist_type"] == "hamming":
        # Previously hamming vectors were only added when operation_method
        # was "new", while the id map below was extended regardless, which
        # left index and id map out of sync on append.
        index.add(gallery_features)
    else:
        # only train when new index file
        if operation_method == "new":
            index.train(gallery_features)
        index.add_with_ids(gallery_features, ids_now)

    for i, d in zip(list(ids_now), gallery_docs):
        ids[i] = d
    return index, ids
def _rm_id_in_galllery(self, index, ids, gallery_docs):
    """Remove every entry whose doc appears in ``gallery_docs``.

    The matching ids are removed from ``index`` via ``index.remove_ids`` and
    deleted from the ``ids`` map.

    Args:
        index: index object exposing ``remove_ids``.
        ids: dict mapping id to image doc; updated in place.
        gallery_docs: collection of docs to remove.

    Returns:
        The ``(index, ids)`` pair that was passed in, after mutation.
    """
    stale_keys = [key for key, doc in ids.items() if doc in gallery_docs]
    # Force int64, the dtype used for ids when they are added: without it an
    # empty match list would become a float64 array.
    remove_ids = np.asarray(stale_keys, dtype=np.int64)
    index.remove_ids(remove_ids)
    for key in stale_keys:
        del ids[key]
    return index, ids
def _save_gallery(self, config, index, ids):
    """Persist the faiss index and the id map into ``config["index_dir"]``.

    Writes ``vector.index`` (with the binary writer when ``dist_type`` is
    ``"hamming"``, the regular writer otherwise) and pickles ``ids`` to
    ``id_map.pkl``.

    Args:
        config: mapping that provides ``"dist_type"`` and ``"index_dir"``.
        index: faiss index to store.
        ids: id-to-doc map to pickle.
    """
    # store faiss index file
    use_binary_writer = config["dist_type"] == "hamming"
    index_file = os.path.join(config["index_dir"], "vector.index")
    if use_binary_writer:
        faiss.write_index_binary(index, index_file)
    else:
        faiss.write_index(index, index_file)

    # store id_map file
    id_map_file = os.path.join(config["index_dir"], "id_map.pkl")
    with open(id_map_file, 'wb') as fd:
        pickle.dump(ids, fd)
def main(config): def main(config):
GalleryBuilder(config) GalleryBuilder(config)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册