diff --git a/deploy/python/ppshitu_v2/configs/test_cls_config.yaml b/deploy/python/ppshitu_v2/configs/test_cls_config.yaml
index 2b139a84922c0013e6ef29cd633617ccec069ada..7b18e9a08b4891f99eaadad5fc374286482be364 100644
--- a/deploy/python/ppshitu_v2/configs/test_cls_config.yaml
+++ b/deploy/python/ppshitu_v2/configs/test_cls_config.yaml
@@ -27,10 +27,10 @@ Modules:
   - name: PaddlePredictor
     type: predictor
     inference_model_dir: "./MobileNetV2_infer"
-    input_names:
-        inputs: image
-    output_names:
-        save_infer_model/scale_0.tmp_1: logits
+    to_model_names:
+        image: inputs
+    from_model_names:
+        logits: 0
   - name: TopK
     type: postprocessor
     k: 10
diff --git a/deploy/python/ppshitu_v2/configs/test_rec_config.yaml b/deploy/python/ppshitu_v2/configs/test_rec_config.yaml
index 7ce8e4681f85b8b2f107d229f890355d165cc79d..1c986bd809bc3a2d9638f589217b3c0c5ca8dc8a 100644
--- a/deploy/python/ppshitu_v2/configs/test_rec_config.yaml
+++ b/deploy/python/ppshitu_v2/configs/test_rec_config.yaml
@@ -26,9 +26,9 @@ Modules:
   - name: PaddlePredictor
     type: predictor
     inference_model_dir: models/product_ResNet50_vd_aliproduct_v1.0_infer
-    input_names:
-        x: image
-    output_names:
-        save_infer_model/scale_0.tmp_1: features
+    to_model_names:
+        image: x
+    from_model_names:
+        features: 0
   - name: FeatureNormalizer
     type: postprocessor
\ No newline at end of file
diff --git a/deploy/python/ppshitu_v2/examples/predict.py b/deploy/python/ppshitu_v2/examples/predict.py
index 3790ec988c2e2ec82914d5edb2b8476dd87caca4..19d02406d45799b7dedcbba6904b8ab29980cad1 100644
--- a/deploy/python/ppshitu_v2/examples/predict.py
+++ b/deploy/python/ppshitu_v2/examples/predict.py
@@ -20,14 +20,20 @@ def main():
     input_data = {"input_image": img}
     data = engine.process(input_data)
 
-    # for det, cls
-    # print(data)
-
+    # for cls
+    if "classification_res" in data:
+        print(data["classification_res"])
+    # for det
+    elif "detection_res" in data:
+        print(data["detection_res"])
     # for rec
-    # features = data["pred"]["features"]
-    # print(features)
-    # print(features.shape)
-    # print(type(features))
+    elif "features" in data["pred"]:
+        features = data["pred"]["features"]
+        print(features)
+        print(features.shape)
+        print(type(features))
+    else:
+        print("ERROR")
 
 
 if __name__ == '__main__':
diff --git a/deploy/python/ppshitu_v2/processor/algo_mod/__init__.py b/deploy/python/ppshitu_v2/processor/algo_mod/__init__.py
index d5b09fe6fbd2591e5fe5bcbdfd431ce4b2737288..57f82899c37504732a28334340d62b8a1146f114 100644
--- a/deploy/python/ppshitu_v2/processor/algo_mod/__init__.py
+++ b/deploy/python/ppshitu_v2/processor/algo_mod/__init__.py
@@ -1,13 +1,7 @@
-# from .postprocessor import build_postprocessor
-# from .preprocessor import build_preprocessor
-# from .predictor import build_predictor
-
-import importlib
-
-from processor.algo_mod import preprocessor
-from processor.algo_mod import predictor
-from processor.algo_mod import postprocessor
-from processor.algo_mod import searcher
+from .postprocessor import build_postprocessor
+from .preprocessor import build_preprocessor
+from .predictor import build_predictor
+from .searcher import build_searcher
 
 from ..base_processor import BaseProcessor
 
@@ -17,20 +11,18 @@ class AlgoMod(BaseProcessor):
         self.processors = []
         for processor_config in config["processors"]:
             processor_type = processor_config.get("type")
-            processor_name = processor_config.get("name")
-            _mod = importlib.import_module(__name__)
-            processor = getattr(
-                getattr(_mod, processor_type),
-                processor_name)(processor_config)
-            # if processor_type == "preprocessor":
-            #     processor = build_preprocessor(processor_config)
-            # elif processor_type == "predictor":
-            #     processor = build_predictor(processor_config)
-            # elif processor_type == "postprocessor":
-            #     processor = build_postprocessor(processor_config)
-            # else:
-            #     raise NotImplemented("processor type {} unknown.".format(processor_type))
+            if processor_type == "preprocessor":
+                processor = build_preprocessor(processor_config)
+            elif processor_type == "predictor":
+                processor = build_predictor(processor_config)
+            elif processor_type == "postprocessor":
+                processor = build_postprocessor(processor_config)
+            elif processor_type == "searcher":
+                processor = build_searcher(processor_config)
+            else:
+                raise NotImplementedError("processor type {} unknown.".format(
+                    processor_type))
             self.processors.append(processor)
 
     def process(self, input_data):
diff --git a/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/__init__.py b/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/__init__.py
index 6567e6d0c8f837061f35543b0f61dc72d3333e38..9edb322b59ac89b62b3967b0852b522d4ab898db 100644
--- a/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/__init__.py
+++ b/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/__init__.py
@@ -4,7 +4,8 @@ from .classification import TopK
 from .det import DetPostPro
 from .rec import FeatureNormalizer
 
-# def build_postprocessor(config):
-#     processor_mod = importlib.import_module(__name__)
-#     processor_name = config.get("name")
-#     return getattr(processor_mod, processor_name)(config)
+
+def build_postprocessor(config):
+    processor_mod = importlib.import_module(__name__)
+    processor_name = config.get("name")
+    return getattr(processor_mod, processor_name)(config)
diff --git a/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/classification.py b/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/classification.py
index 6ca945f70902330bccf8f966b16c07ca2f25be77..4c15b957f534ef82ba3ac705c8e03cae170e78b6 100644
--- a/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/classification.py
+++ b/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/classification.py
@@ -2,6 +2,7 @@ import os
 
 import numpy as np
 
+from utils import logger
 from ...base_processor import BaseProcessor
 
 
@@ -20,8 +21,8 @@ class TopK(BaseProcessor):
             return None
 
         if not os.path.exists(class_id_map_file):
-            print(
-                "Warning: If want to use your own label_dict, please input legal path!\nOtherwise label_names will be empty!"
+            logger.warning(
+                "[Classification] If you want to use your own label_dict, please provide a valid path!\nOtherwise label_names will be empty!"
             )
             return None
@@ -33,36 +34,31 @@ class TopK(BaseProcessor):
                     partition = line.split("\n")[0].partition(" ")
                     class_id_map[int(partition[0])] = str(partition[-1])
         except Exception as ex:
-            print(ex)
+            logger.warning(f"[Classification] {ex}")
             class_id_map = None
         return class_id_map
 
     def process(self, data):
-        x = data["pred"]["logits"]
-        # TODO(gaotingquan): support file_name
-        # if file_names is not None:
-        #     assert x.shape[0] == len(file_names)
-        y = []
-        for idx, probs in enumerate(x):
-            index = probs.argsort(axis=0)[-self.topk:][::-1].astype(
-                "int32") if not self.multilabel else np.where(
-                    probs >= 0.5)[0].astype("int32")
-            clas_id_list = []
-            score_list = []
-            label_name_list = []
-            for i in index:
-                clas_id_list.append(i.item())
-                score_list.append(probs[i].item())
-                if self.class_id_map is not None:
-                    label_name_list.append(self.class_id_map[i.item()])
-            result = {
-                "class_ids": clas_id_list,
-                "scores": np.around(
-                    score_list, decimals=5).tolist(),
-            }
-            # if file_names is not None:
-            #     result["file_name"] = file_names[idx]
-            if label_name_list is not None:
-                result["label_names"] = label_name_list
-            y.append(result)
-        return y
+        # TODO(gaotingquan): only support bs==1 when 'connector' is not implemented.
+        probs = data["pred"]["logits"][0]
+        index = probs.argsort(axis=0)[-self.topk:][::-1].astype(
+            "int32") if not self.multilabel else np.where(
+                probs >= 0.5)[0].astype("int32")
+        clas_id_list = []
+        score_list = []
+        label_name_list = []
+        for i in index:
+            clas_id_list.append(i.item())
+            score_list.append(probs[i].item())
+            if self.class_id_map is not None:
+                label_name_list.append(self.class_id_map[i.item()])
+        result = {
+            "class_ids": clas_id_list,
+            "scores": np.around(
+                score_list, decimals=5).tolist(),
+        }
+        if label_name_list:
+            result["label_names"] = label_name_list
+
+        data["classification_res"] = result
+        return data
diff --git a/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/det.py b/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/det.py
index c69a9cf1762b7c526b3021ee877801e7aab7eb56..58743063726ad51c6957c3ccd505ebcf0c27864d 100644
--- a/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/det.py
+++ b/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/det.py
@@ -11,27 +11,34 @@ class DetPostPro(BaseProcessor):
         self.label_list = config["label_list"]
         self.max_det_results = config["max_det_results"]
 
-    def process(self, input_data):
-        pred = input_data["pred"]
+    def process(self, data):
+        pred = data["pred"]
         np_boxes = pred[list(pred.keys())[0]]
-        if reduce(lambda x, y: x * y, np_boxes.shape) < 6:
-            logger.warning('[Detector] No object detected.')
-            np_boxes = np.array([])
-
-        keep_indexes = np_boxes[:, 1].argsort()[::-1][:self.max_det_results]
-        results = []
-        for idx in keep_indexes:
-            single_res = np_boxes[idx]
+        if reduce(lambda x, y: x * y, np_boxes.shape) >= 6:
+            keep_indexes = np_boxes[:, 1].argsort()[::-1][:
+                                                          self.max_det_results]
+            # TODO(gaotingquan): only support bs==1
+            single_res = np_boxes[keep_indexes[0]]
             class_id = int(single_res[0])
             score = single_res[1]
             bbox = single_res[2:]
-            if score < self.threshold:
-                continue
-            label_name = self.label_list[class_id]
-            results.append({
-                "class_id": class_id,
-                "score": score,
-                "bbox": bbox,
-                "label_name": label_name,
-            })
-        return results
+            if score > self.threshold:
+                label_name = self.label_list[class_id]
+                results = {
+                    "class_id": class_id,
+                    "score": score,
+                    "bbox": bbox,
+                    "label_name": label_name,
+                }
+                data["detection_res"] = results
+                return data
+
+        logger.warning('[Detector] No object detected.')
+        results = {
+            "class_id": None,
+            "score": None,
+            "bbox": None,
+            "label_name": None,
+        }
+        data["detection_res"] = results
+        return data
diff --git a/deploy/python/ppshitu_v2/processor/algo_mod/predictor/__init__.py b/deploy/python/ppshitu_v2/processor/algo_mod/predictor/__init__.py
index 4e27b1176e34bdfee3d17d7c646d650acdf0b984..2913771b32ac0f43b66ad1dad0815b58d9a4544f 100644
--- a/deploy/python/ppshitu_v2/processor/algo_mod/predictor/__init__.py
+++ b/deploy/python/ppshitu_v2/processor/algo_mod/predictor/__init__.py
@@ -3,7 +3,8 @@ import importlib
 from processor.algo_mod.predictor.paddle_predictor import PaddlePredictor
 from processor.algo_mod.predictor.onnx_predictor import ONNXPredictor
 
-# def build_predictor(config):
-#     processor_mod = importlib.import_module(__name__)
-#     processor_name = config.get("name")
-#     return getattr(processor_mod, processor_name)(config)
+
+def build_predictor(config):
+    processor_mod = importlib.import_module(__name__)
+    processor_name = config.get("name")
+    return getattr(processor_mod, processor_name)(config)
diff --git a/deploy/python/ppshitu_v2/processor/algo_mod/predictor/paddle_predictor.py b/deploy/python/ppshitu_v2/processor/algo_mod/predictor/paddle_predictor.py
index d8c61e930b084c364adcd91e67a85d20594e5643..0a10a44349f685c27786f69413135b8bb6f818a5 100644
--- a/deploy/python/ppshitu_v2/processor/algo_mod/predictor/paddle_predictor.py
+++ b/deploy/python/ppshitu_v2/processor/algo_mod/predictor/paddle_predictor.py
@@ -48,30 +48,40 @@ class PaddlePredictor(BaseProcessor):
         paddle_config.switch_use_feed_fetch_ops(False)
         self.predictor = create_predictor(paddle_config)
 
-        if "input_names" in config and config["input_names"]:
-            self.input_name_mapping = config["input_names"]
+        if "to_model_names" in config and config["to_model_names"]:
+            self.input_name_map = {
+                v: k
+                for k, v in config["to_model_names"].items()
+            }
         else:
-            self.input_name_mapping = []
-        if "output_names" in config and config["output_names"]:
-            self.output_name_mapping = config["output_names"]
+            self.input_name_map = {}
+        if "from_model_names" in config and config["from_model_names"]:
+            self.output_name_map = config["from_model_names"]
         else:
-            self.output_name_mapping = []
+            self.output_name_map = {}
 
     def process(self, data):
         input_names = self.predictor.get_input_names()
         for input_name in input_names:
             input_tensor = self.predictor.get_input_handle(input_name)
-            name = self.input_name_mapping[
-                input_name] if input_name in self.input_name_mapping else input_name
+            name = self.input_name_map[
+                input_name] if input_name in self.input_name_map else input_name
             input_tensor.copy_from_cpu(data[name])
         self.predictor.run()
 
-        output_data = {}
+        model_output = []
         output_names = self.predictor.get_output_names()
         for output_name in output_names:
             output = self.predictor.get_output_handle(output_name)
-            name = self.output_name_mapping[
-                output_name] if output_name in self.output_name_mapping else output_name
-            output_data[name] = output.copy_to_cpu()
+            model_output.append((output_name, output.copy_to_cpu()))
+
+        if self.output_name_map:
+            output_data = {}
+            for name in self.output_name_map:
+                idx = self.output_name_map[name]
+                output_data[name] = model_output[idx][1]
+        else:
+            output_data = dict(model_output)
+
         data["pred"] = output_data
         return data
diff --git a/deploy/python/ppshitu_v2/processor/algo_mod/preprocessor/__init__.py b/deploy/python/ppshitu_v2/processor/algo_mod/preprocessor/__init__.py
index 94b889f43b52ac7fe87f48fdcfa4e18eb8a34a1e..ffc9efde853bf32b8fe997d4c7b8402c9354dce1 100644
--- a/deploy/python/ppshitu_v2/processor/algo_mod/preprocessor/__init__.py
+++ b/deploy/python/ppshitu_v2/processor/algo_mod/preprocessor/__init__.py
@@ -2,7 +2,8 @@ import importlib
 
 from processor.algo_mod.preprocessor.image_processor import ImageProcessor
 
-# def build_preprocessor(config):
-#     processor_mod = importlib.import_module(__name__)
-#     processor_name = config.get("name")
-#     return getattr(processor_mod, processor_name)(config)
+
+def build_preprocessor(config):
+    processor_mod = importlib.import_module(__name__)
+    processor_name = config.get("name")
+    return getattr(processor_mod, processor_name)(config)
diff --git a/deploy/python/ppshitu_v2/processor/algo_mod/searcher/__init__.py b/deploy/python/ppshitu_v2/processor/algo_mod/searcher/__init__.py
index f986738cfac84f6cd90de5e3bcb0724d3565cbbe..6bc378c8fa8b93b47eb07b2536711c4a9931505b 100644
--- a/deploy/python/ppshitu_v2/processor/algo_mod/searcher/__init__.py
+++ b/deploy/python/ppshitu_v2/processor/algo_mod/searcher/__init__.py
@@ -4,11 +4,15 @@ import pickle
 import faiss
 
 
+def build_searcher(config):
+    return Searcher(config)
+
+
 class Searcher:
     def __init__(self, config):
         super().__init__()
 
-        self.Searcher = faiss.read_index(
+        self.faiss_searcher = faiss.read_index(
             os.path.join(config["index_dir"], "vector.index"))
         with open(os.path.join(config["index_dir"], "id_map.pkl"),
                   "rb") as fd:
@@ -18,6 +22,11 @@ class Searcher:
 
     def process(self, data):
         features = data["features"]
-        scores, docs = self.Searcher.search(features, self.return_k)
-        data["search_res"] = (scores, docs)
+        scores, docs = self.faiss_searcher.search(features, self.return_k)
+
+        preds = {}
+        preds["rec_docs"] = self.id_map[docs[0][0]].split()[1]
+        preds["rec_scores"] = scores[0][0]
+
+        data["search_res"] = preds
         return data