未验证 提交 46a08f1d 编写于 作者: B Bin Lu 提交者: GitHub

Merge pull request #1218 from Intsigstephon/binaryindex_build_search

support for binary index build and search
Global:
rec_inference_model_dir: "./models/product_MV3_x1_0_aliproduct_bin_v1.0_infer"
batch_size: 32
use_gpu: True
enable_mkldnn: True
cpu_num_threads: 10
enable_benchmark: True
use_fp16: False
ir_optim: True
use_tensorrt: False
gpu_mem: 8000
enable_profile: False
RecPreProcess:
transform_ops:
- ResizeImage:
size: 224
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
RecPostProcess:
main_indicator: Binarize
Binarize:
method: "round"
# indexing engine config
IndexProcess:
index_method: "Flat" # supported: HNSW32, Flat
index_dir: "./recognition_demo_data_v1.1/gallery_product/index_binary"
image_root: "./recognition_demo_data_v1.1/gallery_product/"
data_file: "./recognition_demo_data_v1.1/gallery_product/data_file.txt"
index_operation: "new" # supported: "append", "remove", "new"
delimiter: "\t"
dist_type: "hamming"
embedding_size: 512
Global:
infer_imgs: "./recognition_demo_data_v1.1/test_product/daoxiangcunjinzhubing_6.jpg"
det_inference_model_dir: "./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer"
rec_inference_model_dir: "./models/product_MV3_x1_0_aliproduct_bin_v1.0_infer"
rec_nms_thresold: 0.05
batch_size: 1
image_shape: [3, 640, 640]
threshold: 0.2
max_det_results: 5
labe_list:
- foreground
# inference engine config
use_gpu: True
enable_mkldnn: True
cpu_num_threads: 10
enable_benchmark: True
use_fp16: False
ir_optim: True
use_tensorrt: False
gpu_mem: 8000
enable_profile: False
DetPreProcess:
transform_ops:
- DetResize:
interp: 2
keep_ratio: false
target_size: [640, 640]
- DetNormalizeImage:
is_scale: true
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
- DetPermute: {}
DetPostProcess: {}
RecPreProcess:
transform_ops:
- ResizeImage:
size: 224
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
RecPostProcess:
main_indicator: Binarize
Binarize:
method: "round"
# indexing engine config
IndexProcess:
binary_index: true
index_dir: "./recognition_demo_data_v1.1/gallery_product/index_binary"
return_k: 5
score_thres: 0
......@@ -71,7 +71,6 @@ class GalleryBuilder(object):
# when removing data from the index, there is no need to extract features
if operation_method != "remove":
gallery_features = self._extract_features(gallery_images, config)
assert operation_method in [
"new", "remove", "append"
], "Only append, remove and new operation are supported"
......@@ -104,11 +103,23 @@ class GalleryBuilder(object):
if index_method == "IVF":
index_method = index_method + str(
min(int(len(gallery_images) // 8), 65536)) + ",Flat"
# for binary index, add B at head of index_method
if config["dist_type"] == "hamming":
index_method = "B" + index_method
#dist_type
dist_type = faiss.METRIC_INNER_PRODUCT if config[
"dist_type"] == "IP" else faiss.METRIC_L2
index = faiss.index_factory(config["embedding_size"], index_method,
dist_type)
index = faiss.IndexIDMap2(index)
#build index
if config["dist_type"] == "hamming":
index = faiss.index_binary_factory(config["embedding_size"],
index_method)
else:
index = faiss.index_factory(config["embedding_size"],
index_method, dist_type)
index = faiss.IndexIDMap2(index)
ids = {}
if config["index_method"] == "HNSW32":
......@@ -123,8 +134,13 @@ class GalleryBuilder(object):
# only train when new index file
if operation_method == "new":
index.train(gallery_features)
index.add_with_ids(gallery_features, ids_now)
if config["dist_type"] == "hamming":
index.add(gallery_features)
else:
index.train(gallery_features)
if not config["dist_type"] == "hamming":
index.add_with_ids(gallery_features, ids_now)
for i, d in zip(list(ids_now), gallery_docs):
ids[i] = d
......@@ -142,15 +158,26 @@ class GalleryBuilder(object):
del ids[k]
# store faiss index file and id_map file
faiss.write_index(index,
os.path.join(config["index_dir"], "vector.index"))
if config["dist_type"] == "hamming":
faiss.write_index_binary(
index, os.path.join(config["index_dir"], "vector.index"))
else:
faiss.write_index(
index, os.path.join(config["index_dir"], "vector.index"))
with open(os.path.join(config["index_dir"], "id_map.pkl"), 'wb') as fd:
pickle.dump(ids, fd)
def _extract_features(self, gallery_images, config):
# extract gallery features
gallery_features = np.zeros(
[len(gallery_images), config['embedding_size']], dtype=np.float32)
if config["dist_type"] == "hamming":
gallery_features = np.zeros(
[len(gallery_images), config['embedding_size'] // 8],
dtype=np.uint8)
else:
gallery_features = np.zeros(
[len(gallery_images), config['embedding_size']],
dtype=np.float32)
#construct batch imgs and do inference
batch_size = config.get("batch_size", 32)
......@@ -172,6 +199,7 @@ class GalleryBuilder(object):
rec_feat = self.rec_predictor.predict(batch_img)
gallery_features[-len(batch_img):, :] = rec_feat
batch_img = []
return gallery_features
......
......@@ -62,6 +62,7 @@ class Topk(object):
def parse_class_id_map(self, class_id_map_file):
if class_id_map_file is None:
return None
if not os.path.exists(class_id_map_file):
print(
"Warning: If want to use your own label_dict, please input legal path!\nOtherwise label_names will be empty!"
......@@ -126,3 +127,24 @@ class SavePreLabel(object):
output_dir = self.save_dir(str(id))
os.makedirs(output_dir, exist_ok=True)
shutil.copy(image_file, output_dir)
class Binarize(object):
    """Convert real-valued embeddings into bit-packed ``uint8`` codes.

    Each row of the input is first reduced to one bit per dimension using
    ``method`` and then every 8 consecutive bits are packed into one byte,
    most significant bit first. An input of shape ``(n, d)`` therefore
    yields an output of shape ``(n, d // 8)``.

    Supported methods:
        * ``"round"``: ``round(x + 1)`` cast to ``uint8``, minus 1.
          NOTE(review): this only produces 0/1 bits for inputs roughly in
          ``[-0.5, 1.5)``; presumably the model emits values in ``[0, 1]``
          -- confirm against the recognition model.
        * ``"sign"``: strictly positive values become 1, zero and negative
          values become 0.
    """

    _SUPPORTED_METHODS = ("round", "sign")

    def __init__(self, method="round"):
        """Store the binarization method.

        Args:
            method: ``"round"`` or ``"sign"``.

        Raises:
            ValueError: if ``method`` is not a supported method. Failing
                here avoids silently packing non-binarized values later.
        """
        if method not in self._SUPPORTED_METHODS:
            raise ValueError(
                "Unsupported binarize method {!r}; expected one of {}".format(
                    method, self._SUPPORTED_METHODS))
        self.method = method
        # Column vector of bit weights, most significant bit first.
        self.unit = np.array([[128, 64, 32, 16, 8, 4, 2, 1]]).T

    def __call__(self, x, file_names=None):
        """Binarize and bit-pack a batch of embeddings.

        Args:
            x: 2-D array-like of shape ``(n, d)`` with ``d`` a multiple of 8.
            file_names: unused by this post-processor.

        Returns:
            ``numpy.ndarray`` of dtype ``uint8`` and shape ``(n, d // 8)``.

        Raises:
            ValueError: if ``x`` is not 2-D, if ``d`` is not a multiple of
                8, or if ``self.method`` was changed to an unsupported value.
        """
        x = np.asarray(x)

        if self.method == "round":
            x = np.round(x + 1).astype("uint8") - 1
        elif self.method == "sign":
            x = ((np.sign(x) + 1) / 2).astype("uint8")
        else:
            raise ValueError(
                "Unsupported binarize method {!r}; expected one of {}".format(
                    self.method, self._SUPPORTED_METHODS))

        if x.ndim != 2:
            raise ValueError(
                "Binarize expects a 2-D array, got {} dimension(s)".format(
                    x.ndim))

        num_rows, embedding_size = x.shape
        # Explicit check instead of `assert`, which is stripped under -O.
        if embedding_size % 8 != 0:
            raise ValueError(
                "The Binary index only support vectors with sizes multiple of 8"
            )

        num_bytes = embedding_size // 8
        # Group the bits 8 at a time and weight them in a single dot product;
        # this matches the per-byte loop it replaces.
        packed = np.dot(x.reshape(num_rows, num_bytes, 8), self.unit)
        return packed.reshape(num_rows, num_bytes).astype(np.uint8)
......@@ -47,8 +47,14 @@ class SystemPredictor(object):
index_dir, "vector.index")), "vector.index not found ..."
assert os.path.exists(os.path.join(
index_dir, "id_map.pkl")), "id_map.pkl not found ... "
self.Searcher = faiss.read_index(
os.path.join(index_dir, "vector.index"))
if config['IndexProcess'].get("binary_index", False):
self.Searcher = faiss.read_index_binary(
os.path.join(index_dir, "vector.index"))
else:
self.Searcher = faiss.read_index(
os.path.join(index_dir, "vector.index"))
with open(os.path.join(index_dir, "id_map.pkl"), "rb") as fd:
self.id_map = pickle.load(fd)
......@@ -105,6 +111,7 @@ class SystemPredictor(object):
rec_results = self.rec_predictor.predict(crop_img)
preds["bbox"] = [xmin, ymin, xmax, ymax]
scores, docs = self.Searcher.search(rec_results, self.return_k)
# only the top-1 result is returned as the final prediction
if scores[0][0] >= self.config["IndexProcess"]["score_thres"]:
preds["rec_docs"] = self.id_map[docs[0][0]].split()[1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册