提交 8baf879a 编写于 作者: S stephon

support for binary index build and search

上级 a368e3eb
Global:
#rec_inference_model_dir: "./models/product_ResNet50_vd_aliproduct_v1.0_infer"
rec_inference_model_dir: "../inference"
batch_size: 32
use_gpu: True
enable_mkldnn: True
cpu_num_threads: 10
enable_benchmark: True
use_fp16: False
ir_optim: True
use_tensorrt: False
gpu_mem: 8000
enable_profile: False
RecPreProcess:
transform_ops:
- ResizeImage:
size: 224
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
RecPostProcess:
main_indicator: Binarize
Binarize:
method: "round"
# indexing engine config
IndexProcess:
index_method: "Flat" # supported: HNSW32, Flat
index_dir: "./recognition_demo_data_v1.1/gallery_product/index_binary"
image_root: "./recognition_demo_data_v1.1/gallery_product/"
data_file: "./recognition_demo_data_v1.1/gallery_product/data_file.txt"
index_operation: "new" # suported: "append", "remove", "new"
delimiter: "\t"
dist_type: "hamming"
embedding_size: 512
Global:
infer_imgs: "./recognition_demo_data_v1.1/test_product/daoxiangcunjinzhubing_6.jpg"
det_inference_model_dir: "./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer"
rec_inference_model_dir: "./models/product_ResNet50_vd_aliproduct_v1.0_infer"
rec_nms_thresold: 0.05
batch_size: 1
image_shape: [3, 640, 640]
threshold: 0.2
max_det_results: 5
labe_list:
- foreground
# inference engine config
use_gpu: True
enable_mkldnn: True
cpu_num_threads: 10
enable_benchmark: True
use_fp16: False
ir_optim: True
use_tensorrt: False
gpu_mem: 8000
enable_profile: False
DetPreProcess:
transform_ops:
- DetResize:
interp: 2
keep_ratio: false
target_size: [640, 640]
- DetNormalizeImage:
is_scale: true
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
- DetPermute: {}
DetPostProcess: {}
RecPreProcess:
transform_ops:
- ResizeImage:
size: 224
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
RecPostProcess:
main_indicator: Binarize
Binarize:
method: "round"
# indexing engine config
IndexProcess:
binary_index: true
index_dir: "./recognition_demo_data_v1.1/gallery_product/index_binary"
return_k: 5
score_thres: 0
...@@ -28,7 +28,6 @@ from python.predict_rec import RecPredictor ...@@ -28,7 +28,6 @@ from python.predict_rec import RecPredictor
from utils import logger from utils import logger
from utils import config from utils import config
def split_datafile(data_file, image_root, delimiter="\t"): def split_datafile(data_file, image_root, delimiter="\t"):
''' '''
data_file: image path and info, which can be splitted by spacer data_file: image path and info, which can be splitted by spacer
...@@ -70,8 +69,8 @@ class GalleryBuilder(object): ...@@ -70,8 +69,8 @@ class GalleryBuilder(object):
# when remove data in index, do not need extract fatures # when remove data in index, do not need extract fatures
if operation_method != "remove": if operation_method != "remove":
gallery_features = self._extract_features(gallery_images, config) gallery_features = self._extract_features(gallery_images, config) #76 * 512
assert operation_method in [ assert operation_method in [
"new", "remove", "append" "new", "remove", "append"
], "Only append, remove and new operation are supported" ], "Only append, remove and new operation are supported"
...@@ -104,11 +103,22 @@ class GalleryBuilder(object): ...@@ -104,11 +103,22 @@ class GalleryBuilder(object):
if index_method == "IVF": if index_method == "IVF":
index_method = index_method + str( index_method = index_method + str(
min(int(len(gallery_images) // 8), 65536)) + ",Flat" min(int(len(gallery_images) // 8), 65536)) + ",Flat"
# for binary index, add B at head of index_method
if config["dist_type"] == "hamming":
index_method = "B" + index_method
#dist_type
dist_type = faiss.METRIC_INNER_PRODUCT if config[ dist_type = faiss.METRIC_INNER_PRODUCT if config[
"dist_type"] == "IP" else faiss.METRIC_L2 "dist_type"] == "IP" else faiss.METRIC_L2
index = faiss.index_factory(config["embedding_size"], index_method,
dist_type) #build index
index = faiss.IndexIDMap2(index) if config["dist_type"] == "hamming":
index = faiss.index_binary_factory(config["embedding_size"], index_method)
else:
index = faiss.index_factory(config["embedding_size"], index_method,
dist_type)
index = faiss.IndexIDMap2(index)
ids = {} ids = {}
if config["index_method"] == "HNSW32": if config["index_method"] == "HNSW32":
...@@ -119,12 +129,17 @@ class GalleryBuilder(object): ...@@ -119,12 +129,17 @@ class GalleryBuilder(object):
# calculate id for new data # calculate id for new data
start_id = max(ids.keys()) + 1 if ids else 0 start_id = max(ids.keys()) + 1 if ids else 0
ids_now = ( ids_now = (
np.arange(0, len(gallery_images)) + start_id).astype(np.int64) np.arange(0, len(gallery_images)) + start_id).astype(np.int64) #ids: just the number sequence
# only train when new index file # only train when new index file
if operation_method == "new": if operation_method == "new":
index.train(gallery_features) if config["dist_type"] == "hamming":
index.add_with_ids(gallery_features, ids_now) index.add(gallery_features)
else:
index.train(gallery_features)
if not config["dist_type"] == "hamming":
index.add_with_ids(gallery_features, ids_now)
for i, d in zip(list(ids_now), gallery_docs): for i, d in zip(list(ids_now), gallery_docs):
ids[i] = d ids[i] = d
...@@ -142,15 +157,25 @@ class GalleryBuilder(object): ...@@ -142,15 +157,25 @@ class GalleryBuilder(object):
del ids[k] del ids[k]
# store faiss index file and id_map file # store faiss index file and id_map file
faiss.write_index(index, if config["dist_type"] == "hamming":
os.path.join(config["index_dir"], "vector.index")) faiss.write_index_binary(index,
os.path.join(config["index_dir"], "vector.index"))
else:
faiss.write_index(index,
os.path.join(config["index_dir"], "vector.index"))
with open(os.path.join(config["index_dir"], "id_map.pkl"), 'wb') as fd: with open(os.path.join(config["index_dir"], "id_map.pkl"), 'wb') as fd:
pickle.dump(ids, fd) pickle.dump(ids, fd)
def _extract_features(self, gallery_images, config): def _extract_features(self, gallery_images, config):
# extract gallery features # extract gallery features
gallery_features = np.zeros( if config["dist_type"] == "hamming":
[len(gallery_images), config['embedding_size']], dtype=np.float32) gallery_features = np.zeros(
[len(gallery_images), config['embedding_size'] // 8], dtype=np.uint8)
else:
gallery_features = np.zeros(
[len(gallery_images), config['embedding_size']], dtype=np.float32)
#construct batch imgs and do inference #construct batch imgs and do inference
batch_size = config.get("batch_size", 32) batch_size = config.get("batch_size", 32)
...@@ -164,7 +189,7 @@ class GalleryBuilder(object): ...@@ -164,7 +189,7 @@ class GalleryBuilder(object):
batch_img.append(img) batch_img.append(img)
if (i + 1) % batch_size == 0: if (i + 1) % batch_size == 0:
rec_feat = self.rec_predictor.predict(batch_img) rec_feat = self.rec_predictor.predict(batch_img) #32 * 512
gallery_features[i - batch_size + 1:i + 1, :] = rec_feat gallery_features[i - batch_size + 1:i + 1, :] = rec_feat
batch_img = [] batch_img = []
...@@ -172,6 +197,7 @@ class GalleryBuilder(object): ...@@ -172,6 +197,7 @@ class GalleryBuilder(object):
rec_feat = self.rec_predictor.predict(batch_img) rec_feat = self.rec_predictor.predict(batch_img)
gallery_features[-len(batch_img):, :] = rec_feat gallery_features[-len(batch_img):, :] = rec_feat
batch_img = [] batch_img = []
return gallery_features return gallery_features
......
...@@ -62,6 +62,7 @@ class Topk(object): ...@@ -62,6 +62,7 @@ class Topk(object):
def parse_class_id_map(self, class_id_map_file): def parse_class_id_map(self, class_id_map_file):
if class_id_map_file is None: if class_id_map_file is None:
return None return None
if not os.path.exists(class_id_map_file): if not os.path.exists(class_id_map_file):
print( print(
"Warning: If want to use your own label_dict, please input legal path!\nOtherwise label_names will be empty!" "Warning: If want to use your own label_dict, please input legal path!\nOtherwise label_names will be empty!"
...@@ -126,3 +127,42 @@ class SavePreLabel(object): ...@@ -126,3 +127,42 @@ class SavePreLabel(object):
output_dir = self.save_dir(str(id)) output_dir = self.save_dir(str(id))
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
shutil.copy(image_file, output_dir) shutil.copy(image_file, output_dir)
class Binarize(object):
def __init__(self, method = "round"):
self.method = method
self.unit = np.array([[128, 64, 32, 16, 8, 4, 2, 1]]).T
def __call__(self, x, file_names=None):
if self.method == "round":
x = np.round(x + 1).astype("uint8") - 1
if self.method == "sign":
x = ((np.sign(x) + 1) / 2).astype("uint8")
embedding_size = x.shape[1]
assert embedding_size % 8 == 0, "The Binary index only support vectors with sizes multiple of 8"
byte = np.zeros([x.shape[0], embedding_size // 8], dtype=np.uint8)
for i in range(embedding_size // 8):
byte[:, i:i+1] = np.dot(x[:, i * 8: (i + 1)* 8], self.unit)
return byte
if __name__== "__main__":
a = Binarize()
x = np.random.random((31, 64)).astype('float32')
y = a(x)
print(y)
print(y.shape)
...@@ -47,8 +47,14 @@ class SystemPredictor(object): ...@@ -47,8 +47,14 @@ class SystemPredictor(object):
index_dir, "vector.index")), "vector.index not found ..." index_dir, "vector.index")), "vector.index not found ..."
assert os.path.exists(os.path.join( assert os.path.exists(os.path.join(
index_dir, "id_map.pkl")), "id_map.pkl not found ... " index_dir, "id_map.pkl")), "id_map.pkl not found ... "
self.Searcher = faiss.read_index(
os.path.join(index_dir, "vector.index")) if config['IndexProcess'].get("binary_index", False):
self.Searcher = faiss.read_index_binary(
os.path.join(index_dir, "vector.index"))
else:
self.Searcher = faiss.read_index(
os.path.join(index_dir, "vector.index"))
with open(os.path.join(index_dir, "id_map.pkl"), "rb") as fd: with open(os.path.join(index_dir, "id_map.pkl"), "rb") as fd:
self.id_map = pickle.load(fd) self.id_map = pickle.load(fd)
...@@ -105,6 +111,7 @@ class SystemPredictor(object): ...@@ -105,6 +111,7 @@ class SystemPredictor(object):
rec_results = self.rec_predictor.predict(crop_img) rec_results = self.rec_predictor.predict(crop_img)
preds["bbox"] = [xmin, ymin, xmax, ymax] preds["bbox"] = [xmin, ymin, xmax, ymax]
scores, docs = self.Searcher.search(rec_results, self.return_k) scores, docs = self.Searcher.search(rec_results, self.return_k)
# just top-1 result will be returned for the final # just top-1 result will be returned for the final
if scores[0][0] >= self.config["IndexProcess"]["score_thres"]: if scores[0][0] >= self.config["IndexProcess"]["score_thres"]:
preds["rec_docs"] = self.id_map[docs[0][0]].split()[1] preds["rec_docs"] = self.id_map[docs[0][0]].split()[1]
......
from paddle import nn from paddle import nn
import paddle
class IdentityHead(nn.Layer): class IdentityHead(nn.Layer):
def __init__(self): def __init__(self):
super(IdentityHead, self).__init__() super(IdentityHead, self).__init__()
def forward(self, x, label=None): def forward(self, x, label=None):
return {"features": x, "logits": None} return {"features": x, "logits": None}
\ No newline at end of file
...@@ -378,7 +378,6 @@ class ExportModel(nn.Layer): ...@@ -378,7 +378,6 @@ class ExportModel(nn.Layer):
self.infer_output_key = config.get("infer_output_key", None) self.infer_output_key = config.get("infer_output_key", None)
if self.infer_output_key == "features" and isinstance(self.base_model, if self.infer_output_key == "features" and isinstance(self.base_model,
RecModel): RecModel):
self.base_model.head = IdentityHead()
if config.get("infer_add_softmax", True): if config.get("infer_add_softmax", True):
self.softmax = nn.Softmax(axis=-1) self.softmax = nn.Softmax(axis=-1)
else: else:
...@@ -394,10 +393,13 @@ class ExportModel(nn.Layer): ...@@ -394,10 +393,13 @@ class ExportModel(nn.Layer):
x = self.base_model(x) x = self.base_model(x)
if isinstance(x, list): if isinstance(x, list):
x = x[0] x = x[0]
if self.infer_model_name is not None: if self.infer_model_name is not None:
x = x[self.infer_model_name] x = x[self.infer_model_name]
if self.infer_output_key is not None: if self.infer_output_key is not None:
x = x[self.infer_output_key] x = x[self.infer_output_key]
if self.softmax is not None: if self.softmax is not None:
x = self.softmax(x) x = self.softmax(x)
return x return x
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册