未验证 提交 6cd62518 编写于 作者: W Walter 提交者: GitHub

Merge pull request #2236 from RainFrost1/lite_shitu

update build_gallery and add android demo index support
...@@ -12,16 +12,14 @@ ...@@ -12,16 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
import pickle
import cv2 import cv2
import faiss import faiss
import numpy as np import numpy as np
from tqdm import tqdm
import pickle
from paddleclas.deploy.utils import logger, config
from paddleclas.deploy.python.predict_rec import RecPredictor
from paddleclas.deploy.python.predict_rec import RecPredictor from paddleclas.deploy.python.predict_rec import RecPredictor
from paddleclas.deploy.utils import config, logger
from tqdm import tqdm
def split_datafile(data_file, image_root, delimiter="\t"): def split_datafile(data_file, image_root, delimiter="\t"):
...@@ -52,6 +50,7 @@ class GalleryBuilder(object): ...@@ -52,6 +50,7 @@ class GalleryBuilder(object):
self.config = config self.config = config
self.rec_predictor = RecPredictor(config) self.rec_predictor = RecPredictor(config)
assert 'IndexProcess' in config.keys(), "Index config not found ... " assert 'IndexProcess' in config.keys(), "Index config not found ... "
self.android_demo = config["Global"].get("android_demo", False)
self.build(config['IndexProcess']) self.build(config['IndexProcess'])
def build(self, config): def build(self, config):
...@@ -70,10 +69,86 @@ class GalleryBuilder(object): ...@@ -70,10 +69,86 @@ class GalleryBuilder(object):
"new", "remove", "append" "new", "remove", "append"
], "Only append, remove and new operation are supported" ], "Only append, remove and new operation are supported"
if self.android_demo:
self._create_index_for_android_demo(config, gallery_features, gallery_docs)
return
# vector.index: faiss index file # vector.index: faiss index file
# id_map.pkl: use this file to map id to image_doc # id_map.pkl: use this file to map id to image_doc
index, ids = None, None
if operation_method in ["remove", "append"]: if operation_method in ["remove", "append"]:
# if remove or append, vector.index and id_map.pkl must exist # if remove or append, load vector.index and id_map.pkl
index, ids = self._load_index(config)
index_method = config.get("index_method", "HNSW32")
else:
index_method, index, ids = self._create_index(config)
if index_method == "HNSW32":
logger.warning(
"The HNSW32 method dose not support 'remove' operation")
if operation_method != "remove":
# calculate id for new data
index, ids = self._add_gallery(index, ids, gallery_features, gallery_docs, config, operation_method)
else:
if index_method == "HNSW32":
raise RuntimeError(
"The index_method: HNSW32 dose not support 'remove' operation"
)
# remove ids in id_map, remove index data in faiss index
index, ids = self._rm_id_in_galllery(index, ids, gallery_docs)
# store faiss index file and id_map file
self._save_gallery(config, index, ids)
def _create_index_for_android_demo(self, config, gallery_features, gallery_docs):
    """Build a flat inner-product index for the Android demo and save it.

    Args:
        config: index configuration; reads ``index_dir`` and
            ``embedding_size`` here and is forwarded to ``_save_gallery``.
        gallery_features: feature matrix added to the index in one call.
        gallery_docs: one document entry per gallery item; entry ``k`` is
            stored under id ``k`` in the id map.

    NOTE(review): the index is always ``faiss.IndexFlatIP``; how this
    interacts with ``dist_type == "hamming"`` in ``_save_gallery`` is not
    visible from here -- confirm.
    """
    index_dir = config["index_dir"]
    if not os.path.exists(index_dir):
        os.makedirs(index_dir, exist_ok=True)

    # Exhaustive (non-trained) inner-product index over the raw features.
    flat_index = faiss.IndexFlatIP(config["embedding_size"])
    flat_index.add(gallery_features)

    # Ids are simply 0..N-1 (as int64), paired positionally with the docs.
    doc_ids = np.arange(0, len(gallery_docs)).astype(np.int64)
    id_map = dict(zip(list(doc_ids), gallery_docs))

    self._save_gallery(config, flat_index, id_map)
def _extract_features(self, gallery_images, config):
    """Run the recognition predictor over every gallery image.

    Args:
        gallery_images: sequence of image file paths, read with
            ``cv2.imread``.
        config: index configuration; reads ``dist_type``,
            ``embedding_size`` and the optional ``batch_size``
            (default 32).

    Returns:
        A numpy array with one row per gallery image: ``uint8`` with
        ``embedding_size // 8`` columns when ``dist_type`` is
        ``"hamming"``, otherwise ``float32`` with ``embedding_size``
        columns.

    Raises:
        SystemExit: with status 1 when an image cannot be read.
    """
    num_images = len(gallery_images)
    if config["dist_type"] == "hamming":
        gallery_features = np.zeros(
            [num_images, config['embedding_size'] // 8], dtype=np.uint8)
    else:
        gallery_features = np.zeros(
            [num_images, config['embedding_size']], dtype=np.float32)

    # Accumulate images into batches and run inference batch by batch.
    batch_size = config.get("batch_size", 32)
    batch_img = []
    for i, image_file in enumerate(tqdm(gallery_images)):
        img = cv2.imread(image_file)
        if img is None:
            logger.error("img empty, please check {}".format(image_file))
            # The site-provided exit() is meant for interactive use, may be
            # missing (python -S, frozen apps) and reports success (status
            # 0); raise SystemExit directly with a failing status instead.
            raise SystemExit(1)
        # Reverse the channel axis (OpenCV loads images as BGR).
        batch_img.append(img[:, :, ::-1])

        if (i + 1) % batch_size == 0:
            # assumes predict() returns one feature row per input image,
            # in input order -- TODO confirm against RecPredictor
            rec_feat = self.rec_predictor.predict(batch_img)
            gallery_features[i - batch_size + 1:i + 1, :] = rec_feat
            batch_img = []

    # Flush the final, partially filled batch into the trailing rows.
    if batch_img:
        rec_feat = self.rec_predictor.predict(batch_img)
        gallery_features[-len(batch_img):, :] = rec_feat

    return gallery_features
def _load_index(self, config):
assert os.path.join( assert os.path.join(
config["index_dir"], "vector.index" config["index_dir"], "vector.index"
), "The vector.index dose not exist in {} when 'index_operation' is not None".format( ), "The vector.index dose not exist in {} when 'index_operation' is not None".format(
...@@ -89,7 +164,9 @@ class GalleryBuilder(object): ...@@ -89,7 +164,9 @@ class GalleryBuilder(object):
ids = pickle.load(fd) ids = pickle.load(fd)
assert index.ntotal == len(ids.keys( assert index.ntotal == len(ids.keys(
)), "data number in index is not equal in in id_map" )), "data number in index is not equal in in id_map"
else: return index, ids
def _create_index(self, config):
if not os.path.exists(config["index_dir"]): if not os.path.exists(config["index_dir"]):
os.makedirs(config["index_dir"], exist_ok=True) os.makedirs(config["index_dir"], exist_ok=True)
index_method = config.get("index_method", "HNSW32") index_method = config.get("index_method", "HNSW32")
...@@ -116,16 +193,12 @@ class GalleryBuilder(object): ...@@ -116,16 +193,12 @@ class GalleryBuilder(object):
index_method, dist_type) index_method, dist_type)
index = faiss.IndexIDMap2(index) index = faiss.IndexIDMap2(index)
ids = {} ids = {}
return index_method, index, ids
if config["index_method"] == "HNSW32": def _add_gallery(self, index, ids, gallery_features, gallery_docs, config, operation_method):
logger.warning(
"The HNSW32 method dose not support 'remove' operation")
if operation_method != "remove":
# calculate id for new data
start_id = max(ids.keys()) + 1 if ids else 0 start_id = max(ids.keys()) + 1 if ids else 0
ids_now = ( ids_now = (
np.arange(0, len(gallery_images)) + start_id).astype(np.int64) np.arange(0, len(gallery_docs)) + start_id).astype(np.int64)
# only train when new index file # only train when new index file
if operation_method == "new": if operation_method == "new":
...@@ -139,12 +212,9 @@ class GalleryBuilder(object): ...@@ -139,12 +212,9 @@ class GalleryBuilder(object):
for i, d in zip(list(ids_now), gallery_docs): for i, d in zip(list(ids_now), gallery_docs):
ids[i] = d ids[i] = d
else: return index, ids
if config["index_method"] == "HNSW32":
raise RuntimeError( def _rm_id_in_galllery(self, index, ids, gallery_docs):
"The index_method: HNSW32 dose not support 'remove' operation"
)
# remove ids in id_map, remove index data in faiss index
remove_ids = list( remove_ids = list(
filter(lambda k: ids.get(k) in gallery_docs, ids.keys())) filter(lambda k: ids.get(k) in gallery_docs, ids.keys()))
remove_ids = np.asarray(remove_ids) remove_ids = np.asarray(remove_ids)
...@@ -152,7 +222,9 @@ class GalleryBuilder(object): ...@@ -152,7 +222,9 @@ class GalleryBuilder(object):
for k in remove_ids: for k in remove_ids:
del ids[k] del ids[k]
# store faiss index file and id_map file return index, ids
def _save_gallery(self, config, index, ids):
if config["dist_type"] == "hamming": if config["dist_type"] == "hamming":
faiss.write_index_binary( faiss.write_index_binary(
index, os.path.join(config["index_dir"], "vector.index")) index, os.path.join(config["index_dir"], "vector.index"))
...@@ -163,40 +235,6 @@ class GalleryBuilder(object): ...@@ -163,40 +235,6 @@ class GalleryBuilder(object):
with open(os.path.join(config["index_dir"], "id_map.pkl"), 'wb') as fd: with open(os.path.join(config["index_dir"], "id_map.pkl"), 'wb') as fd:
pickle.dump(ids, fd) pickle.dump(ids, fd)
def _extract_features(self, gallery_images, config):
    """Run the recognition predictor over every gallery image.

    Args:
        gallery_images: sequence of image file paths, read with
            ``cv2.imread``.
        config: index configuration; reads ``dist_type``,
            ``embedding_size`` and the optional ``batch_size``
            (default 32).

    Returns:
        A numpy array with one row per gallery image: ``uint8`` with
        ``embedding_size // 8`` columns when ``dist_type`` is
        ``"hamming"``, otherwise ``float32`` with ``embedding_size``
        columns.

    Raises:
        SystemExit: with status 1 when an image cannot be read.
    """
    num_images = len(gallery_images)
    if config["dist_type"] == "hamming":
        gallery_features = np.zeros(
            [num_images, config['embedding_size'] // 8], dtype=np.uint8)
    else:
        gallery_features = np.zeros(
            [num_images, config['embedding_size']], dtype=np.float32)

    # Accumulate images into batches and run inference batch by batch.
    batch_size = config.get("batch_size", 32)
    batch_img = []
    for i, image_file in enumerate(tqdm(gallery_images)):
        img = cv2.imread(image_file)
        if img is None:
            logger.error("img empty, please check {}".format(image_file))
            # The site-provided exit() is meant for interactive use, may be
            # missing (python -S, frozen apps) and reports success (status
            # 0); raise SystemExit directly with a failing status instead.
            raise SystemExit(1)
        # Reverse the channel axis (OpenCV loads images as BGR).
        batch_img.append(img[:, :, ::-1])

        if (i + 1) % batch_size == 0:
            # assumes predict() returns one feature row per input image,
            # in input order -- TODO confirm against RecPredictor
            rec_feat = self.rec_predictor.predict(batch_img)
            gallery_features[i - batch_size + 1:i + 1, :] = rec_feat
            batch_img = []

    # Flush the final, partially filled batch into the trailing rows.
    if batch_img:
        rec_feat = self.rec_predictor.predict(batch_img)
        gallery_features[-len(batch_img):, :] = rec_feat

    return gallery_features
def main(config): def main(config):
GalleryBuilder(config) GalleryBuilder(config)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册