diff --git a/deploy/configs/build_product_binary.yaml b/deploy/configs/build_product_binary.yaml new file mode 100644 index 0000000000000000000000000000000000000000..21ebfbc6fd18c6798369a3a553f92da0dbb8017e --- /dev/null +++ b/deploy/configs/build_product_binary.yaml @@ -0,0 +1,39 @@ +Global: + rec_inference_model_dir: "./models/product_MV3_x1_0_aliproduct_bin_v1.0_infer" + batch_size: 32 + use_gpu: True + enable_mkldnn: True + cpu_num_threads: 10 + enable_benchmark: True + use_fp16: False + ir_optim: True + use_tensorrt: False + gpu_mem: 8000 + enable_profile: False + +RecPreProcess: + transform_ops: + - ResizeImage: + size: 224 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + +RecPostProcess: + main_indicator: Binarize + Binarize: + method: "round" + +# indexing engine config +IndexProcess: + index_method: "Flat" # supported: HNSW32, Flat + index_dir: "./recognition_demo_data_v1.1/gallery_product/index_binary" + image_root: "./recognition_demo_data_v1.1/gallery_product/" + data_file: "./recognition_demo_data_v1.1/gallery_product/data_file.txt" + index_operation: "new" # supported: "append", "remove", "new" + delimiter: "\t" + dist_type: "hamming" + embedding_size: 512 diff --git a/deploy/configs/inference_product_binary.yaml b/deploy/configs/inference_product_binary.yaml new file mode 100644 index 0000000000000000000000000000000000000000..37ad6b3963e21ea16f1d2ba97c7cfd0cf6a792d8 --- /dev/null +++ b/deploy/configs/inference_product_binary.yaml @@ -0,0 +1,60 @@ +Global: + infer_imgs: "./recognition_demo_data_v1.1/test_product/daoxiangcunjinzhubing_6.jpg" + det_inference_model_dir: "./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer" + rec_inference_model_dir: "./models/product_MV3_x1_0_aliproduct_bin_v1.0_infer" + rec_nms_thresold: 0.05 + + batch_size: 1 + image_shape: [3, 640, 640] + threshold: 0.2 + max_det_results: 5 + labe_list: + - foreground + + # inference engine config + use_gpu: 
True + enable_mkldnn: True + cpu_num_threads: 10 + enable_benchmark: True + use_fp16: False + ir_optim: True + use_tensorrt: False + gpu_mem: 8000 + enable_profile: False + +DetPreProcess: + transform_ops: + - DetResize: + interp: 2 + keep_ratio: false + target_size: [640, 640] + - DetNormalizeImage: + is_scale: true + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + - DetPermute: {} +DetPostProcess: {} + +RecPreProcess: + transform_ops: + - ResizeImage: + size: 224 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + +RecPostProcess: + main_indicator: Binarize + Binarize: + method: "round" + +# indexing engine config +IndexProcess: + binary_index: true + index_dir: "./recognition_demo_data_v1.1/gallery_product/index_binary" + return_k: 5 + score_thres: 0 + diff --git a/deploy/python/build_gallery.py b/deploy/python/build_gallery.py index 8412f99f2b26ae4d6a9eb727bb1fd7730a5bc999..7b69a04d77422d428c33278da52e8ae10d34a8ca 100644 --- a/deploy/python/build_gallery.py +++ b/deploy/python/build_gallery.py @@ -71,7 +71,6 @@ class GalleryBuilder(object): # when remove data in index, do not need extract fatures if operation_method != "remove": gallery_features = self._extract_features(gallery_images, config) - assert operation_method in [ "new", "remove", "append" ], "Only append, remove and new operation are supported" @@ -104,11 +103,23 @@ class GalleryBuilder(object): if index_method == "IVF": index_method = index_method + str( min(int(len(gallery_images) // 8), 65536)) + ",Flat" + + # for binary index, add B at head of index_method + if config["dist_type"] == "hamming": + index_method = "B" + index_method + + #dist_type dist_type = faiss.METRIC_INNER_PRODUCT if config[ "dist_type"] == "IP" else faiss.METRIC_L2 - index = faiss.index_factory(config["embedding_size"], index_method, - dist_type) - index = faiss.IndexIDMap2(index) + + #build index + if config["dist_type"] == "hamming": + 
index = faiss.index_binary_factory(config["embedding_size"], + index_method) + else: + index = faiss.index_factory(config["embedding_size"], + index_method, dist_type) + index = faiss.IndexIDMap2(index) ids = {} if config["index_method"] == "HNSW32": @@ -123,8 +134,13 @@ class GalleryBuilder(object): # only train when new index file if operation_method == "new": - index.train(gallery_features) - index.add_with_ids(gallery_features, ids_now) + if config["dist_type"] == "hamming": + index.add(gallery_features) + else: + index.train(gallery_features) + + if not config["dist_type"] == "hamming": + index.add_with_ids(gallery_features, ids_now) for i, d in zip(list(ids_now), gallery_docs): ids[i] = d @@ -142,15 +158,26 @@ class GalleryBuilder(object): del ids[k] # store faiss index file and id_map file - faiss.write_index(index, - os.path.join(config["index_dir"], "vector.index")) + if config["dist_type"] == "hamming": + faiss.write_index_binary( + index, os.path.join(config["index_dir"], "vector.index")) + else: + faiss.write_index( + index, os.path.join(config["index_dir"], "vector.index")) + with open(os.path.join(config["index_dir"], "id_map.pkl"), 'wb') as fd: pickle.dump(ids, fd) def _extract_features(self, gallery_images, config): # extract gallery features - gallery_features = np.zeros( - [len(gallery_images), config['embedding_size']], dtype=np.float32) + if config["dist_type"] == "hamming": + gallery_features = np.zeros( + [len(gallery_images), config['embedding_size'] // 8], + dtype=np.uint8) + else: + gallery_features = np.zeros( + [len(gallery_images), config['embedding_size']], + dtype=np.float32) #construct batch imgs and do inference batch_size = config.get("batch_size", 32) @@ -172,6 +199,7 @@ class GalleryBuilder(object): rec_feat = self.rec_predictor.predict(batch_img) gallery_features[-len(batch_img):, :] = rec_feat batch_img = [] + return gallery_features diff --git a/deploy/python/postprocess.py b/deploy/python/postprocess.py index 
class Binarize(object):
    """Convert float embeddings into bit-packed ``uint8`` codes.

    Every feature is reduced to a single bit, and each run of 8 consecutive
    bits in a row is packed into one byte, most-significant bit first (the
    first feature of a group has weight 128, the last has weight 1).

    Args:
        method: How a feature becomes a bit.
            ``"round"``: bit is 1 when ``round(x + 1)`` reaches 2, i.e. for
            ``x`` of roughly 0.5 and above; values outside the expected
            [0, 1] range saturate to 0 or 1 instead of wrapping around.
            ``"sign"``: bit is 1 when ``x`` is strictly positive.

    Raises:
        ValueError: If ``method`` is not a supported name.
    """

    _METHODS = ("round", "sign")

    def __init__(self, method="round"):
        if method not in self._METHODS:
            raise ValueError(
                "Unsupported binarize method %r, expected one of %s" %
                (method, ", ".join(self._METHODS)))
        self.method = method
        # Per-bit weights, MSB first. Packing now goes through np.packbits,
        # but the attribute is kept so existing readers of it keep working.
        self.unit = np.array([[128, 64, 32, 16, 8, 4, 2, 1]]).T

    def __call__(self, x, file_names=None):
        """Binarize and bit-pack a batch of embeddings.

        Args:
            x: 2-D array-like of shape ``(batch, embedding_size)``;
                ``embedding_size`` must be a multiple of 8.
            file_names: Unused by this class.
                NOTE(review): presumably kept so the call signature matches
                the other post-processors in this module; confirm.

        Returns:
            ``np.uint8`` array of shape ``(batch, embedding_size // 8)``.

        Raises:
            ValueError: If ``x`` is not 2-D, its width is not a multiple of
                8, or ``self.method`` was changed to an unsupported value.
        """
        x = np.asarray(x)
        if x.ndim != 2:
            raise ValueError(
                "Binarize expects a 2-D array of shape (batch, embedding_size)"
            )
        if x.shape[1] % 8 != 0:
            # Explicit raise instead of assert: asserts vanish under -O.
            raise ValueError(
                "The Binary index only support vectors with sizes multiple of 8"
            )

        if self.method == "round":
            # Same +1 shift as before (keeps in-range results bit-identical),
            # but clip before the cast so out-of-range features saturate
            # rather than wrapping in uint8 and spilling into other bits.
            bits = np.clip(np.round(x + 1), 1, 2) - 1
        elif self.method == "sign":
            # Equivalent to ((sign(x) + 1) / 2) truncated to uint8:
            # only strictly positive values map to 1.
            bits = x > 0
        else:
            raise ValueError(
                "Unsupported binarize method %r, expected one of %s" %
                (self.method, ", ".join(self._METHODS)))

        # Big-endian packing matches the [128, 64, ..., 1] weighting.
        return np.packbits(bits.astype(np.uint8), axis=1)
" - self.Searcher = faiss.read_index( - os.path.join(index_dir, "vector.index")) + + if config['IndexProcess'].get("binary_index", False): + self.Searcher = faiss.read_index_binary( + os.path.join(index_dir, "vector.index")) + else: + self.Searcher = faiss.read_index( + os.path.join(index_dir, "vector.index")) + with open(os.path.join(index_dir, "id_map.pkl"), "rb") as fd: self.id_map = pickle.load(fd) @@ -105,6 +111,7 @@ class SystemPredictor(object): rec_results = self.rec_predictor.predict(crop_img) preds["bbox"] = [xmin, ymin, xmax, ymax] scores, docs = self.Searcher.search(rec_results, self.return_k) + # just top-1 result will be returned for the final if scores[0][0] >= self.config["IndexProcess"]["score_thres"]: preds["rec_docs"] = self.id_map[docs[0][0]].split()[1]