提交 9eaa3353 编写于 作者: D dongshuilong

update build_gallery and add android demo index support

上级 291015f4
......@@ -12,16 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
import cv2
import faiss
import numpy as np
from tqdm import tqdm
import pickle
from paddleclas.deploy.utils import logger, config
from paddleclas.deploy.python.predict_rec import RecPredictor
from paddleclas.deploy.python.predict_rec import RecPredictor
from paddleclas.deploy.utils import config, logger
from tqdm import tqdm
def split_datafile(data_file, image_root, delimiter="\t"):
......@@ -53,6 +51,7 @@ class GalleryBuilder(object):
self.rec_predictor = RecPredictor(config)
assert 'IndexProcess' in config.keys(), "Index config not found ... "
self.build(config['IndexProcess'])
self.android_demo = config["Global"].get("android_demo", False)
def build(self, config):
'''
......@@ -70,10 +69,88 @@ class GalleryBuilder(object):
"new", "remove", "append"
], "Only append, remove and new operation are supported"
if self.android_demo:
self._create_index_for_android_demo(config, gallery_features, gallery_docs)
return
# vector.index: faiss index file
# id_map.pkl: use this file to map id to image_doc
index, ids = None, None
if operation_method in ["remove", "append"]:
# if remove or append, vector.index and id_map.pkl must exist
# if remove or append, load vector.index and id_map.pkl
index, ids = self._load_index()
else:
index_method, index, ids = self._create_index(config)
if index_method == "HNSW32":
logger.warning(
"The HNSW32 method dose not support 'remove' operation")
if operation_method != "remove":
# calculate id for new data
index, ids = self._add_gallery(index, ids, gallery_features, gallery_docs)
else:
if index_method == "HNSW32":
raise RuntimeError(
"The index_method: HNSW32 dose not support 'remove' operation"
)
# remove ids in id_map, remove index data in faiss index
index, ids = self._rm_id_in_galllery(index, ids, gallery_docs)
# store faiss index file and id_map file
self._save_gallery(config, index, ids)
def _create_index_for_android_demo(self, config, gallery_features, gallery_docs):
    """Build and save a flat inner-product index for the android demo.

    A fresh ``faiss.IndexFlatIP`` is created, every row of
    ``gallery_features`` is added to it, and an id map assigning the
    sequential ids 0..len(gallery_docs)-1 to the entries of
    ``gallery_docs`` is built. Both are then handed to
    ``self._save_gallery``.

    Args:
        config: index configuration; ``index_dir`` and ``embedding_size``
            are read here, and the whole mapping is forwarded to
            ``self._save_gallery``.
        gallery_features: feature matrix added to the index as-is.
        gallery_docs: one doc entry per gallery feature; they become the
            values of the id map.
    """
    # exist_ok=True already covers the "directory exists" case, so no
    # separate os.path.exists() check is needed.
    os.makedirs(config["index_dir"], exist_ok=True)

    # build index
    index = faiss.IndexFlatIP(config["embedding_size"])
    index.add(gallery_features)

    # calculate id for new data
    # The id range must be sized from gallery_docs: this method has no
    # `gallery_images` name in scope, so the previous
    # len(gallery_images) raised NameError.
    ids_now = np.arange(0, len(gallery_docs)).astype(np.int64)
    ids = {}
    for i, d in zip(list(ids_now), gallery_docs):
        ids[i] = d

    self._save_gallery(config, index, ids)
def _extract_features(self, gallery_images, config):
    """Run the recognition predictor over all gallery images in batches.

    Args:
        gallery_images: sequence of image file paths readable by
            ``cv2.imread``.
        config: index configuration; ``dist_type`` and ``embedding_size``
            are required, ``batch_size`` is optional (default 32).

    Returns:
        A 2-D numpy array with one row per gallery image. For
        ``dist_type == "hamming"`` it is uint8 with
        ``embedding_size // 8`` columns, otherwise float32 with
        ``embedding_size`` columns.

    The process is terminated via ``exit()`` (after logging an error) as
    soon as an image cannot be read.
    """
    # pick the output layout from the distance type
    use_hamming = config["dist_type"] == "hamming"
    num_images = len(gallery_images)
    if use_hamming:
        feat_dim, feat_dtype = config['embedding_size'] // 8, np.uint8
    else:
        feat_dim, feat_dtype = config['embedding_size'], np.float32
    gallery_features = np.zeros([num_images, feat_dim], dtype=feat_dtype)

    # collect images into batches and run inference per full batch
    batch_size = config.get("batch_size", 32)
    pending = []
    for idx, image_file in enumerate(tqdm(gallery_images)):
        img = cv2.imread(image_file)
        if img is None:
            logger.error("img empty, please check {}".format(image_file))
            exit()
        # reverse the channel axis of the loaded image
        pending.append(img[:, :, ::-1])

        stop = idx + 1
        if stop % batch_size != 0:
            continue
        rec_feat = self.rec_predictor.predict(pending)
        gallery_features[stop - batch_size:stop, :] = rec_feat
        pending = []

    # leftover images that did not fill a whole batch go in the tail rows
    if pending:
        rec_feat = self.rec_predictor.predict(pending)
        gallery_features[-len(pending):, :] = rec_feat

    return gallery_features
def _load_index(self, config):
assert os.path.join(
config["index_dir"], "vector.index"
), "The vector.index dose not exist in {} when 'index_operation' is not None".format(
......@@ -89,7 +166,9 @@ class GalleryBuilder(object):
ids = pickle.load(fd)
assert index.ntotal == len(ids.keys(
)), "data number in index is not equal in in id_map"
else:
return index, ids
def _create_index(self, config):
if not os.path.exists(config["index_dir"]):
os.makedirs(config["index_dir"], exist_ok=True)
index_method = config.get("index_method", "HNSW32")
......@@ -116,16 +195,12 @@ class GalleryBuilder(object):
index_method, dist_type)
index = faiss.IndexIDMap2(index)
ids = {}
return index_method, index, ids
if config["index_method"] == "HNSW32":
logger.warning(
"The HNSW32 method dose not support 'remove' operation")
if operation_method != "remove":
# calculate id for new data
def _add_gallery(self, index, ids, gallery_features, gallery_docs):
start_id = max(ids.keys()) + 1 if ids else 0
ids_now = (
np.arange(0, len(gallery_images)) + start_id).astype(np.int64)
np.arange(0, len(gallery_docs)) + start_id).astype(np.int64)
# only train when new index file
if operation_method == "new":
......@@ -139,12 +214,9 @@ class GalleryBuilder(object):
for i, d in zip(list(ids_now), gallery_docs):
ids[i] = d
else:
if config["index_method"] == "HNSW32":
raise RuntimeError(
"The index_method: HNSW32 dose not support 'remove' operation"
)
# remove ids in id_map, remove index data in faiss index
return index, ids
def _rm_id_in_galllery(self, index, ids, gallery_docs)
remove_ids = list(
filter(lambda k: ids.get(k) in gallery_docs, ids.keys()))
remove_ids = np.asarray(remove_ids)
......@@ -152,7 +224,9 @@ class GalleryBuilder(object):
for k in remove_ids:
del ids[k]
# store faiss index file and id_map file
return index, ids
def _save_gallery(self, config, index, ids):
if config["dist_type"] == "hamming":
faiss.write_index_binary(
index, os.path.join(config["index_dir"], "vector.index"))
......@@ -163,40 +237,6 @@ class GalleryBuilder(object):
with open(os.path.join(config["index_dir"], "id_map.pkl"), 'wb') as fd:
pickle.dump(ids, fd)
def _extract_features(self, gallery_images, config):
    """Compute the feature matrix for the gallery, one row per image.

    Images are read with ``cv2.imread``, grouped into batches of
    ``config.get("batch_size", 32)`` and passed to
    ``self.rec_predictor.predict``; the results are written into a
    pre-allocated array which is returned.

    The array is uint8 with ``embedding_size // 8`` columns when
    ``config["dist_type"]`` is ``"hamming"``, and float32 with
    ``embedding_size`` columns otherwise. An unreadable image is logged
    and the process exits via ``exit()``.
    """
    # allocate the output according to the distance type
    is_binary = config["dist_type"] == "hamming"
    rows = len(gallery_images)
    if is_binary:
        cols, dtype = config['embedding_size'] // 8, np.uint8
    else:
        cols, dtype = config['embedding_size'], np.float32
    gallery_features = np.zeros([rows, cols], dtype=dtype)

    # build batches and run inference whenever one is full
    batch_size = config.get("batch_size", 32)
    batch = []
    for pos, image_file in enumerate(tqdm(gallery_images)):
        img = cv2.imread(image_file)
        if img is None:
            logger.error("img empty, please check {}".format(image_file))
            exit()
        # reverse the channel axis of the loaded image
        batch.append(img[:, :, ::-1])

        end = pos + 1
        if end % batch_size != 0:
            continue
        rec_feat = self.rec_predictor.predict(batch)
        gallery_features[end - batch_size:end, :] = rec_feat
        batch = []

    # a final partial batch fills the last len(batch) rows
    if batch:
        rec_feat = self.rec_predictor.predict(batch)
        gallery_features[-len(batch):, :] = rec_feat

    return gallery_features
def main(config):
GalleryBuilder(config)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册