diff --git a/deploy/python/ppshitu_v2/configs/test_cls_config.yaml b/deploy/python/ppshitu_v2/configs/test_cls_config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2b139a84922c0013e6ef29cd633617ccec069ada
--- /dev/null
+++ b/deploy/python/ppshitu_v2/configs/test_cls_config.yaml
@@ -0,0 +1,38 @@
+Global:
+  Engine: POPEngine
+  infer_imgs: "../../images/wangzai.jpg"
+
+
+Modules:
+  - name:
+    type: AlgoMod
+    processors:
+      - name: ImageProcessor
+        type: preprocessor
+        ops:
+          - ResizeImage:
+              resize_short: 256
+          - CropImage:
+              size: 224
+          - NormalizeImage:
+              scale: 0.00392157
+              mean: [0.485, 0.456, 0.406]
+              std: [0.229, 0.224, 0.225]
+              order: hwc
+          - ToCHWImage:
+          - GetShapeInfo:
+              configs:
+                order: chw
+          - ToBatch:
+      - name: PaddlePredictor
+        type: predictor
+        inference_model_dir: "./MobileNetV2_infer"
+        input_names:
+          inputs: image
+        output_names:
+          save_infer_model/scale_0.tmp_1: logits
+      - name: TopK
+        type: postprocessor
+        k: 10
+        class_id_map_file: "../ppcls/utils/imagenet1k_label_list.txt"
+        save_dir: None
\ No newline at end of file
diff --git a/deploy/python/ppshitu_v2/configs/test_det_config.yaml b/deploy/python/ppshitu_v2/configs/test_det_config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..64a421fae827d15fd39603736a4ff2294f21c590
--- /dev/null
+++ b/deploy/python/ppshitu_v2/configs/test_det_config.yaml
@@ -0,0 +1,33 @@
+Global:
+  Engine: POPEngine
+  infer_imgs: "../../images/wangzai.jpg"
+
+Modules:
+  - name:
+    type: AlgoMod
+    processors:
+      - name: ImageProcessor
+        type: preprocessor
+        ops:
+          - ResizeImage:
+              size: [640, 640]
+              interpolation: 2
+          - NormalizeImage:
+              scale: 0.00392157
+              mean: [0.485, 0.456, 0.406]
+              std: [0.229, 0.224, 0.225]
+              order: hwc
+          - ToCHWImage:
+          - GetShapeInfo:
+              configs:
+                order: chw
+          - ToBatch:
+      - name: PaddlePredictor
+        type: predictor
+        inference_model_dir: ./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer/
+      - name: DetPostPro
+        type: postprocessor
+        threshold: 0.2
+        max_det_results: 1
+        label_list:
+          - foreground
\ No newline at end of file
diff --git a/deploy/python/ppshitu_v2/configs/test_rec_config.yaml b/deploy/python/ppshitu_v2/configs/test_rec_config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7ce8e4681f85b8b2f107d229f890355d165cc79d
--- /dev/null
+++ b/deploy/python/ppshitu_v2/configs/test_rec_config.yaml
@@ -0,0 +1,34 @@
+Global:
+  Engine: POPEngine
+  infer_imgs: "../../images/wangzai.jpg"
+
+Modules:
+  - name:
+    type: AlgoMod
+    processors:
+      - name: ImageProcessor
+        type: preprocessor
+        ops:
+          - ResizeImage:
+              resize_short: 256
+          - CropImage:
+              size: 224
+          - NormalizeImage:
+              scale: 0.00392157
+              mean: [0.485, 0.456, 0.406]
+              std: [0.229, 0.224, 0.225]
+              order: hwc
+          - ToCHWImage:
+          - GetShapeInfo:
+              configs:
+                order: chw
+          - ToBatch:
+      - name: PaddlePredictor
+        type: predictor
+        inference_model_dir: models/product_ResNet50_vd_aliproduct_v1.0_infer
+        input_names:
+          x: image
+        output_names:
+          save_infer_model/scale_0.tmp_1: features
+      - name: FeatureNormalizer
+        type: postprocessor
\ No newline at end of file
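The three model configs above share the same preprocessing vocabulary. As a rough hand-rolled sketch (not the actual ImageProcessor implementation), the classification/recognition op chain of resize_short 256, center crop 224, scale/mean/std normalization in HWC order, CHW transpose and batching amounts to:

```python
# Rough equivalent of the preprocessor ops in test_cls_config.yaml /
# test_rec_config.yaml, written with plain cv2/numpy for illustration only.
import cv2
import numpy as np

img = cv2.imread("../../images/wangzai.jpg")                       # HWC, BGR
h, w = img.shape[:2]
scale = 256 / min(h, w)                                            # ResizeImage: resize_short
img = cv2.resize(img, (int(round(w * scale)), int(round(h * scale))))
h, w = img.shape[:2]
top, left = (h - 224) // 2, (w - 224) // 2                         # CropImage: size 224
img = img[top:top + 224, left:left + 224, ::-1]                    # crop + BGR -> RGB
img = (img.astype("float32") * 0.00392157                          # NormalizeImage (hwc order)
       - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
chw = img.transpose((2, 0, 1)).astype("float32")                   # ToCHWImage
batch = np.expand_dims(chw, 0)                                     # ToBatch
print(batch.shape)                                                 # (1, 3, 224, 224)
```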
diff --git a/deploy/python/ppshitu_v2/configs/test_search_config.yaml b/deploy/python/ppshitu_v2/configs/test_search_config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..eaf8486ab522be1ad62debe716b85ed0601bd69a
--- /dev/null
+++ b/deploy/python/ppshitu_v2/configs/test_search_config.yaml
@@ -0,0 +1,16 @@
+Global:
+  Engine: POPEngine
+  infer_imgs: "./vector.npy"
+
+Modules:
+  - name:
+    type: AlgoMod
+    processors:
+      - name: Searcher
+        type: searcher
+        index_dir: "./index"
+        dist_type: "IP"
+        embedding_size: 512
+        batch_size: 32
+        return_k: 5
+        score_thres: 0.5
\ No newline at end of file
diff --git a/deploy/python/ppshitu_v2/examples/predict.py b/deploy/python/ppshitu_v2/examples/predict.py
index 0379808dc1ae52ce8d55280fcb8d34a9fd6111e6..3790ec988c2e2ec82914d5edb2b8476dd87caca4 100644
--- a/deploy/python/ppshitu_v2/examples/predict.py
+++ b/deploy/python/ppshitu_v2/examples/predict.py
@@ -18,8 +18,16 @@ def main():
     image_file = "../../images/wangzai.jpg"
     img = cv2.imread(image_file)[:, :, ::-1]
     input_data = {"input_image": img}
-    output = engine.process(input_data)
-    print(output)
+    data = engine.process(input_data)
+
+    # for det, cls
+    # print(data)
+
+    # for rec
+    # features = data["pred"]["features"]
+    # print(features)
+    # print(features.shape)
+    # print(type(features))
 
 
 if __name__ == '__main__':
diff --git a/deploy/python/ppshitu_v2/examples/test_search.py b/deploy/python/ppshitu_v2/examples/test_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..11b36df739035b6f0dd1ee0b4ed163b595a2c381
--- /dev/null
+++ b/deploy/python/ppshitu_v2/examples/test_search.py
@@ -0,0 +1,31 @@
+import os
+import sys
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
+
+import cv2
+
+from engine import build_engine
+from utils import config
+from utils.get_image_list import get_image_list
+
+import numpy as np
+
+
+def load_vector(path):
+    return np.load(path)
+
+
+def main():
+    args = config.parse_args()
+    config_dict = config.get_config(
+        args.config, overrides=args.override, show=False)
+    config_dict.profiler_options = args.profiler_options
+    engine = build_engine(config_dict)
+    vector = load_vector(config_dict["Global"]["infer_imgs"])
+    output = engine.process({"features": vector})
+    print(output["search_res"])
+
+
+if __name__ == '__main__':
+    main()
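test_search.py feeds whatever `infer_imgs` points at directly into the searcher, so the query file must be a numpy array of shape (n, embedding_size). A minimal way to generate a dummy `./vector.npy` for the search config above; the unit-norm step assumes the gallery was indexed with L2-normalized features, which is what makes inner-product search behave like cosine similarity:

```python
# Create a dummy query vector matching test_search_config.yaml
# (embedding_size: 512, dist_type: "IP"). Purely illustrative data.
import numpy as np

query = np.random.rand(1, 512).astype("float32")
query /= np.linalg.norm(query, axis=1, keepdims=True)  # unit norm for IP/cosine search
np.save("vector.npy", query)
```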
"preprocessor": + # processor = build_preprocessor(processor_config) + # elif processor_type == "predictor": + # processor = build_predictor(processor_config) + # elif processor_type == "postprocessor": + # processor = build_postprocessor(processor_config) + # else: + # raise NotImplemented("processor type {} unknown.".format(processor_type)) self.processors.append(processor) def process(self, input_data): diff --git a/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/__init__.py b/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/__init__.py index 89500a7d4b8668ea0c849f16e18c95f69108bf0e..6567e6d0c8f837061f35543b0f61dc72d3333e38 100644 --- a/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/__init__.py +++ b/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/__init__.py @@ -1,9 +1,10 @@ import importlib +from .classification import TopK from .det import DetPostPro +from .rec import FeatureNormalizer - -def build_postprocessor(config): - processor_mod = importlib.import_module(__name__) - processor_name = config.get("name") - return getattr(processor_mod, processor_name)(config) +# def build_postprocessor(config): +# processor_mod = importlib.import_module(__name__) +# processor_name = config.get("name") +# return getattr(processor_mod, processor_name)(config) diff --git a/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/classification.py b/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/classification.py new file mode 100644 index 0000000000000000000000000000000000000000..6ca945f70902330bccf8f966b16c07ca2f25be77 --- /dev/null +++ b/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/classification.py @@ -0,0 +1,68 @@ +import os + +import numpy as np + +from ...base_processor import BaseProcessor + + +class TopK(BaseProcessor): + def __init__(self, config): + self.topk = config["k"] + assert isinstance(self.topk, (int, )) + + class_id_map_file = config["class_id_map_file"] + self.class_id_map = self.parse_class_id_map(class_id_map_file) + + self.multilabel = config.get("multilabel", False) + + def parse_class_id_map(self, class_id_map_file): + if class_id_map_file is None: + return None + + if not os.path.exists(class_id_map_file): + print( + "Warning: If want to use your own label_dict, please input legal path!\nOtherwise label_names will be empty!" 
diff --git a/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/classification.py b/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ca945f70902330bccf8f966b16c07ca2f25be77
--- /dev/null
+++ b/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/classification.py
@@ -0,0 +1,68 @@
+import os
+
+import numpy as np
+
+from ...base_processor import BaseProcessor
+
+
+class TopK(BaseProcessor):
+    def __init__(self, config):
+        self.topk = config["k"]
+        assert isinstance(self.topk, (int, ))
+
+        class_id_map_file = config["class_id_map_file"]
+        self.class_id_map = self.parse_class_id_map(class_id_map_file)
+
+        self.multilabel = config.get("multilabel", False)
+
+    def parse_class_id_map(self, class_id_map_file):
+        if class_id_map_file is None:
+            return None
+
+        if not os.path.exists(class_id_map_file):
+            print(
+                "Warning: If you want to use your own label_dict, please provide a valid path!\nOtherwise label_names will be empty!"
+            )
+            return None
+
+        try:
+            class_id_map = {}
+            with open(class_id_map_file, "r") as fin:
+                lines = fin.readlines()
+                for line in lines:
+                    partition = line.split("\n")[0].partition(" ")
+                    class_id_map[int(partition[0])] = str(partition[-1])
+        except Exception as ex:
+            print(ex)
+            class_id_map = None
+        return class_id_map
+
+    def process(self, data):
+        x = data["pred"]["logits"]
+        # TODO(gaotingquan): support file_name
+        # if file_names is not None:
+        #     assert x.shape[0] == len(file_names)
+        y = []
+        for idx, probs in enumerate(x):
+            index = probs.argsort(axis=0)[-self.topk:][::-1].astype(
+                "int32") if not self.multilabel else np.where(
+                    probs >= 0.5)[0].astype("int32")
+            clas_id_list = []
+            score_list = []
+            label_name_list = []
+            for i in index:
+                clas_id_list.append(i.item())
+                score_list.append(probs[i].item())
+                if self.class_id_map is not None:
+                    label_name_list.append(self.class_id_map[i.item()])
+            result = {
+                "class_ids": clas_id_list,
+                "scores": np.around(
+                    score_list, decimals=5).tolist(),
+            }
+            # if file_names is not None:
+            #     result["file_name"] = file_names[idx]
+            if label_name_list is not None:
+                result["label_names"] = label_name_list
+            y.append(result)
+        return y
diff --git a/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/rec.py b/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/rec.py
new file mode 100644
index 0000000000000000000000000000000000000000..53c197fdb5d2b912aeb38da79943ed870350f5c1
--- /dev/null
+++ b/deploy/python/ppshitu_v2/processor/algo_mod/postprocessor/rec.py
@@ -0,0 +1,16 @@
+import numpy as np
+
+from ...base_processor import BaseProcessor
+
+
+class FeatureNormalizer(BaseProcessor):
+    def __init__(self, config=None):
+        pass
+
+    def process(self, data):
+        batch_output = data["pred"]["features"]
+        feas_norm = np.sqrt(
+            np.sum(np.square(batch_output), axis=1, keepdims=True))
+        batch_output = np.divide(batch_output, feas_norm)
+        data["pred"]["features"] = batch_output
+        return data
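FeatureNormalizer simply rescales each feature row to unit L2 norm, so downstream inner-product search scores behave like cosine similarity. A quick standalone check of the same arithmetic on dummy data:

```python
# Same arithmetic as FeatureNormalizer.process, on random data.
import numpy as np

feats = np.random.rand(4, 512).astype("float32")
feas_norm = np.sqrt(np.sum(np.square(feats), axis=1, keepdims=True))
normalized = np.divide(feats, feas_norm)
print(np.linalg.norm(normalized, axis=1))  # ~[1. 1. 1. 1.]
```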
config["input_names"]: + self.input_name_mapping = config["input_names"] + else: + self.input_name_mapping = [] + if "output_names" in config and config["output_names"]: + self.output_name_mapping = config["output_names"] + else: + self.output_name_mapping = [] + + def process(self, data): input_names = self.predictor.get_input_names() for input_name in input_names: input_tensor = self.predictor.get_input_handle(input_name) - input_tensor.copy_from_cpu(input_data[input_name]) + name = self.input_name_mapping[ + input_name] if input_name in self.input_name_mapping else input_name + input_tensor.copy_from_cpu(data[name]) self.predictor.run() output_data = {} output_names = self.predictor.get_output_names() for output_name in output_names: output = self.predictor.get_output_handle(output_name) - output_data[output_name] = output.copy_to_cpu() - input_data["pred"] = output_data - return input_data + name = self.output_name_mapping[ + output_name] if output_name in self.output_name_mapping else output_name + output_data[name] = output.copy_to_cpu() + data["pred"] = output_data + return data diff --git a/deploy/python/ppshitu_v2/processor/algo_mod/preprocessor/__init__.py b/deploy/python/ppshitu_v2/processor/algo_mod/preprocessor/__init__.py index ffc9efde853bf32b8fe997d4c7b8402c9354dce1..94b889f43b52ac7fe87f48fdcfa4e18eb8a34a1e 100644 --- a/deploy/python/ppshitu_v2/processor/algo_mod/preprocessor/__init__.py +++ b/deploy/python/ppshitu_v2/processor/algo_mod/preprocessor/__init__.py @@ -2,8 +2,7 @@ import importlib from processor.algo_mod.preprocessor.image_processor import ImageProcessor - -def build_preprocessor(config): - processor_mod = importlib.import_module(__name__) - processor_name = config.get("name") - return getattr(processor_mod, processor_name)(config) +# def build_preprocessor(config): +# processor_mod = importlib.import_module(__name__) +# processor_name = config.get("name") +# return getattr(processor_mod, processor_name)(config) diff --git a/deploy/python/ppshitu_v2/processor/algo_mod/searcher/__init__.py b/deploy/python/ppshitu_v2/processor/algo_mod/searcher/__init__.py index 28849db7a9b0d56552285eaf17c315e2c92c3353..f986738cfac84f6cd90de5e3bcb0724d3565cbbe 100644 --- a/deploy/python/ppshitu_v2/processor/algo_mod/searcher/__init__.py +++ b/deploy/python/ppshitu_v2/processor/algo_mod/searcher/__init__.py @@ -1,4 +1,23 @@ +import os +import pickle +import faiss -def build_searcher(config): - pass + +class Searcher: + def __init__(self, config): + super().__init__() + + self.Searcher = faiss.read_index( + os.path.join(config["index_dir"], "vector.index")) + + with open(os.path.join(config["index_dir"], "id_map.pkl"), "rb") as fd: + self.id_map = pickle.load(fd) + + self.return_k = config["return_k"] + + def process(self, data): + features = data["features"] + scores, docs = self.Searcher.search(features, self.return_k) + data["search_res"] = (scores, docs) + return data