diff --git a/deploy/configs/build_product_binary.yaml b/deploy/configs/build_product_binary.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e2343df429f4050fb72249c895eb0a420a570a06 --- /dev/null +++ b/deploy/configs/build_product_binary.yaml @@ -0,0 +1,40 @@ +Global: + #rec_inference_model_dir: "./models/product_ResNet50_vd_aliproduct_v1.0_infer" + rec_inference_model_dir: "../inference" + batch_size: 32 + use_gpu: True + enable_mkldnn: True + cpu_num_threads: 10 + enable_benchmark: True + use_fp16: False + ir_optim: True + use_tensorrt: False + gpu_mem: 8000 + enable_profile: False + +RecPreProcess: + transform_ops: + - ResizeImage: + size: 224 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + +RecPostProcess: + main_indicator: Binarize + Binarize: + method: "round" + +# indexing engine config +IndexProcess: + index_method: "Flat" # supported: HNSW32, Flat + index_dir: "./recognition_demo_data_v1.1/gallery_product/index_binary" + image_root: "./recognition_demo_data_v1.1/gallery_product/" + data_file: "./recognition_demo_data_v1.1/gallery_product/data_file.txt" + index_operation: "new" # suported: "append", "remove", "new" + delimiter: "\t" + dist_type: "hamming" + embedding_size: 512 diff --git a/deploy/configs/inference_product_binary.yaml b/deploy/configs/inference_product_binary.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d48c03ca74aca275c6be1699de86ebb77b0813b5 --- /dev/null +++ b/deploy/configs/inference_product_binary.yaml @@ -0,0 +1,60 @@ +Global: + infer_imgs: "./recognition_demo_data_v1.1/test_product/daoxiangcunjinzhubing_6.jpg" + det_inference_model_dir: "./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer" + rec_inference_model_dir: "./models/product_ResNet50_vd_aliproduct_v1.0_infer" + rec_nms_thresold: 0.05 + + batch_size: 1 + image_shape: [3, 640, 640] + threshold: 0.2 + max_det_results: 5 + labe_list: + - foreground + 
+ # inference engine config + use_gpu: True + enable_mkldnn: True + cpu_num_threads: 10 + enable_benchmark: True + use_fp16: False + ir_optim: True + use_tensorrt: False + gpu_mem: 8000 + enable_profile: False + +DetPreProcess: + transform_ops: + - DetResize: + interp: 2 + keep_ratio: false + target_size: [640, 640] + - DetNormalizeImage: + is_scale: true + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + - DetPermute: {} +DetPostProcess: {} + +RecPreProcess: + transform_ops: + - ResizeImage: + size: 224 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + +RecPostProcess: + main_indicator: Binarize + Binarize: + method: "round" + +# indexing engine config +IndexProcess: + binary_index: true + index_dir: "./recognition_demo_data_v1.1/gallery_product/index_binary" + return_k: 5 + score_thres: 0 + diff --git a/deploy/python/build_gallery.py b/deploy/python/build_gallery.py index 8412f99f2b26ae4d6a9eb727bb1fd7730a5bc999..4c8abdeb9d50f2d7b73e208e8c5459fcfcc3dbe3 100644 --- a/deploy/python/build_gallery.py +++ b/deploy/python/build_gallery.py @@ -28,7 +28,6 @@ from python.predict_rec import RecPredictor from utils import logger from utils import config - def split_datafile(data_file, image_root, delimiter="\t"): ''' data_file: image path and info, which can be splitted by spacer @@ -70,8 +69,8 @@ class GalleryBuilder(object): # when remove data in index, do not need extract fatures if operation_method != "remove": - gallery_features = self._extract_features(gallery_images, config) - + gallery_features = self._extract_features(gallery_images, config) #76 * 512 + assert operation_method in [ "new", "remove", "append" ], "Only append, remove and new operation are supported" @@ -104,11 +103,22 @@ class GalleryBuilder(object): if index_method == "IVF": index_method = index_method + str( min(int(len(gallery_images) // 8), 65536)) + ",Flat" + + # for binary index, add B at head of 
index_method + if config["dist_type"] == "hamming": + index_method = "B" + index_method + + #dist_type dist_type = faiss.METRIC_INNER_PRODUCT if config[ "dist_type"] == "IP" else faiss.METRIC_L2 - index = faiss.index_factory(config["embedding_size"], index_method, - dist_type) - index = faiss.IndexIDMap2(index) + + #build index + if config["dist_type"] == "hamming": + index = faiss.index_binary_factory(config["embedding_size"], index_method) + else: + index = faiss.index_factory(config["embedding_size"], index_method, + dist_type) + index = faiss.IndexIDMap2(index) ids = {} if config["index_method"] == "HNSW32": @@ -119,12 +129,17 @@ class GalleryBuilder(object): # calculate id for new data start_id = max(ids.keys()) + 1 if ids else 0 ids_now = ( - np.arange(0, len(gallery_images)) + start_id).astype(np.int64) + np.arange(0, len(gallery_images)) + start_id).astype(np.int64) #ids: just the number sequence # only train when new index file if operation_method == "new": - index.train(gallery_features) - index.add_with_ids(gallery_features, ids_now) + if config["dist_type"] == "hamming": + index.add(gallery_features) + else: + index.train(gallery_features) + + if not config["dist_type"] == "hamming": + index.add_with_ids(gallery_features, ids_now) for i, d in zip(list(ids_now), gallery_docs): ids[i] = d @@ -142,15 +157,25 @@ class GalleryBuilder(object): del ids[k] # store faiss index file and id_map file - faiss.write_index(index, - os.path.join(config["index_dir"], "vector.index")) + if config["dist_type"] == "hamming": + faiss.write_index_binary(index, + os.path.join(config["index_dir"], "vector.index")) + else: + faiss.write_index(index, + os.path.join(config["index_dir"], "vector.index")) + with open(os.path.join(config["index_dir"], "id_map.pkl"), 'wb') as fd: pickle.dump(ids, fd) def _extract_features(self, gallery_images, config): + # extract gallery features - gallery_features = np.zeros( - [len(gallery_images), config['embedding_size']], dtype=np.float32) + if 
config["dist_type"] == "hamming": + gallery_features = np.zeros( + [len(gallery_images), config['embedding_size'] // 8], dtype=np.uint8) + else: + gallery_features = np.zeros( + [len(gallery_images), config['embedding_size']], dtype=np.float32) #construct batch imgs and do inference batch_size = config.get("batch_size", 32) @@ -164,7 +189,7 @@ class GalleryBuilder(object): batch_img.append(img) if (i + 1) % batch_size == 0: - rec_feat = self.rec_predictor.predict(batch_img) + rec_feat = self.rec_predictor.predict(batch_img) #32 * 512 gallery_features[i - batch_size + 1:i + 1, :] = rec_feat batch_img = [] @@ -172,6 +197,7 @@ class GalleryBuilder(object): rec_feat = self.rec_predictor.predict(batch_img) gallery_features[-len(batch_img):, :] = rec_feat batch_img = [] + return gallery_features diff --git a/deploy/python/postprocess.py b/deploy/python/postprocess.py index 17a01f985ad3f2c879a4a93f57475bd03df66b7e..bf26823e3ba40050077631955b081484d5d3ab61 100644 --- a/deploy/python/postprocess.py +++ b/deploy/python/postprocess.py @@ -62,6 +62,7 @@ class Topk(object): def parse_class_id_map(self, class_id_map_file): if class_id_map_file is None: return None + if not os.path.exists(class_id_map_file): print( "Warning: If want to use your own label_dict, please input legal path!\nOtherwise label_names will be empty!" 
class Binarize(object):
    """Threshold real-valued embeddings to bits and pack them into bytes.

    Each row of the input is turned into a 0/1 vector and every group of
    eight consecutive bits is packed into one ``uint8``, most significant
    bit first, so an ``(N, D)`` input becomes an ``(N, D // 8)`` array.

    Args:
        method: Thresholding rule.
            ``"round"``: a value becomes 1 when it rounds to 1 or more and
            0 otherwise (0.5 counts as 1); values outside ``[0, 1]`` are
            clamped to the nearest bit.
            ``"sign"``: strictly positive values become 1, zero and
            negative values become 0.
            Any other value skips thresholding; the input is then expected
            to already hold only 0/1 values.

    Example:
        >>> Binarize()(np.array([[0.9, 0.1, 0.2, 0.7, 0.0, 0.0, 0.0, 1.0]]))
        array([[145]], dtype=uint8)
    """

    def __init__(self, method="round"):
        self.method = method
        # Weight of each bit inside one packed byte (MSB first). The packing
        # itself is done by np.packbits; the attribute is kept so existing
        # readers of ``unit`` keep working.
        self.unit = np.array([[128, 64, 32, 16, 8, 4, 2, 1]]).T

    def __call__(self, x, file_names=None):
        """Binarize and pack ``x``.

        Args:
            x: 2-D array-like of shape ``(N, D)`` with ``D`` a multiple of 8.
            file_names: Unused; accepted for call-signature compatibility.

        Returns:
            ``np.ndarray`` of dtype ``uint8`` and shape ``(N, D // 8)``.

        Raises:
            ValueError: If ``D`` is not a multiple of 8.
        """
        x = np.asarray(x)

        if self.method == "round":
            # np.round sends halves to the nearest even integer, so the
            # "+ 1 ... - 1" shift makes 0.5 round up to 1. The subtraction
            # and the clamp are done before the uint8 cast so that
            # out-of-range inputs cannot wrap around to 255 or yield 2.
            bits = np.clip(np.round(x + 1) - 1, 0, 1)
        elif self.method == "sign":
            # Same result as truncating (sign(x) + 1) / 2: only x > 0 gives 1.
            bits = x > 0
        else:
            bits = x
        bits = bits.astype(np.uint8)

        embedding_size = bits.shape[1]
        if embedding_size % 8 != 0:
            raise ValueError(
                "The Binary index only support vectors with sizes multiple of 8"
            )

        # Default bit order of packbits is big-endian, i.e. the first of
        # every eight bits carries weight 128.
        return np.packbits(bits, axis=1)
import paddle
from paddle import nn


class IdentityHead(nn.Layer):
    """Head that returns the incoming features, optionally packed to bytes.

    With ``binarize_method`` set to ``"round"`` or ``"sign"`` the features
    are thresholded to 0/1 and every eight consecutive bits are packed into
    one ``uint8`` (first bit carries weight 128). Any other value leaves the
    features untouched.

    Args:
        binarize_method: ``"round"``, ``"sign"`` or anything else (e.g. the
            default ``"none"``) for no binarization.
        embedding_size: Length of the feature vector; must be a multiple of
            8 when binarization is enabled.

    Raises:
        ValueError: If binarization is enabled and ``embedding_size`` is not
            a multiple of 8.
    """

    def __init__(self, binarize_method="none", embedding_size=256):
        super(IdentityHead, self).__init__()
        self.binarize_method = binarize_method
        self.embedding_size = embedding_size
        # The bit weights (and the multiple-of-8 requirement) only matter
        # when packing is requested; a plain identity head must accept any
        # embedding size.
        if binarize_method in ("round", "sign"):
            self.multiplier = self._init_multiplier(embedding_size)
        else:
            self.multiplier = None

    def forward(self, x, label=None):
        """Return ``{"features": ..., "logits": None}``; ``label`` is unused."""
        if self.binarize_method == "round":
            # Clamp so activations outside [0, 1] cannot contribute more than
            # one bit weight (or a negative one) to the packed byte.
            x = paddle.clip(paddle.round(x), min=0.0, max=1.0)
        elif self.binarize_method == "sign":
            # (sign(x) + 1) / 2 would turn exact zeros into 0.5, which adds
            # half a bit weight to the byte sum; compare against 0 instead so
            # zeros map to bit 0.
            x = (x > 0).astype("float32")

        if self.binarize_method in ("round", "sign"):
            x = self._binary_to_byte(x, self.multiplier)

        return {"features": x, "logits": None}

    def _init_multiplier(self, embedding_size):
        """Build a ``[1, embedding_size]`` float32 row of repeated bit weights."""
        if embedding_size % 8 != 0:
            raise ValueError(
                "The binary index only support vectors with sizes multiple of 8"
            )
        unit = paddle.to_tensor([128, 64, 32, 16, 8, 4, 2, 1])
        repeat = embedding_size // 8
        unit = paddle.broadcast_to(unit, shape=[repeat, 8])
        multiplier = paddle.reshape(unit, shape=[1, -1]).astype("float32")
        return multiplier

    def _binary_to_byte(self, input_tensor, multiplier):
        """Weight each bit, sum every group of eight and cast to ``uint8``."""
        tmp = paddle.multiply(input_tensor, multiplier)
        tmp = paddle.reshape(tmp, shape=[tmp.shape[0], -1, 8])
        byte = paddle.sum(tmp, axis=-1).astype("uint8")
        return byte
self.softmax is not None: x = self.softmax(x) return x