diff --git a/deploy/python/ppshitu_v2/configs/test_config.yml b/deploy/python/ppshitu_v2/configs/test_config.yml index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..75d33f70a936da3ec801cff1302f931257785e2c 100644 --- a/deploy/python/ppshitu_v2/configs/test_config.yml +++ b/deploy/python/ppshitu_v2/configs/test_config.yml @@ -0,0 +1,13 @@ +AlgoModule: + - preprocess: + - processor_type: data_processor + processor_name: image_processor + image_processors: + - ResizeImage: + size: [640, 640] + interpolation: 2 + - NormalizeImage: + scale: 0.00392157 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + - ToRGB diff --git a/deploy/python/ppshitu_v2/processor/__init__.py b/deploy/python/ppshitu_v2/processor/__init__.py index 8ef7d3b7f89905a91a0ffbbb1fe4798cb9138b2c..e970ce983d87be700df5cb7e303a431d0a229537 100644 --- a/deploy/python/ppshitu_v2/processor/__init__.py +++ b/deploy/python/ppshitu_v2/processor/__init__.py @@ -1,20 +1,13 @@ from abc import ABC, abstractmethod -from algo_mod import build_algo_mod -from searcher import build_searcher -from data_processor import build_data_processor +from processor.algo_mod import predictors, searcher def build_processor(config): processor_type = config.get("processor_type") - if processor_type == "algo_mod": - return build_algo_mod(config) - elif processor_type == "searcher": - return build_searcher(config) - elif processor_type == "data_processor": - return build_data_processor(config) - else: - raise NotImplemented("processor_type {} not implemented.".format(processor_type)) + processor_mod = locals()[processor_type] + processor_name = config.get("processor_name") + return getattr(processor_mod, processor_name) class BaseProcessor(ABC): diff --git a/deploy/python/ppshitu_v2/processor/algo_mod/__init__.py b/deploy/python/ppshitu_v2/processor/algo_mod/__init__.py index 32ef848ff7b2f473d299c5bde6a924be14f2d94d..841ea7126dae482f501b41818d9746303b268687 100644 --- a/deploy/python/ppshitu_v2/processor/algo_mod/__init__.py +++ b/deploy/python/ppshitu_v2/processor/algo_mod/__init__.py @@ -1,7 +1,14 @@ -from .fake_cls import FakeClassifier +from .. import BaseProcessor, build_processor -def build_algo_mod(config): - algo_name = config.get("algo_name") - if algo_name == "fake_clas": - return FakeClassifier(config) +class AlgoMod(BaseProcessor): + def __init__(self, config): + self.pre_processor = build_processor(config["pre_processor"]) + self.predictor = build_processor(config["predictor"]) + self.post_processor = build_processor(config["post_processor"]) + + def process(self, input_data): + input_data = self.pre_processor(input_data) + input_data = self.predictor(input_data) + input_data = self.post_processor(input_data) + return input_data diff --git a/deploy/python/ppshitu_v2/processor/algo_mod/data_processor/__init__.py b/deploy/python/ppshitu_v2/processor/algo_mod/data_processor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7600ca6b2e156b8f6fcfc210ed960ec79865824e --- /dev/null +++ b/deploy/python/ppshitu_v2/processor/algo_mod/data_processor/__init__.py @@ -0,0 +1 @@ +from image_processor import ImageProcessor diff --git a/deploy/python/ppshitu_v2/processor/data_processor/bbox_cropper.py b/deploy/python/ppshitu_v2/processor/algo_mod/data_processor/bbox_cropper.py similarity index 78% rename from deploy/python/ppshitu_v2/processor/data_processor/bbox_cropper.py rename to deploy/python/ppshitu_v2/processor/algo_mod/data_processor/bbox_cropper.py index 9f15fc268c0f1df1603f1e5aed95f3051d7673d1..a3deeaef13fb3ed7afffef7af110d762476ed5eb 100644 --- a/deploy/python/ppshitu_v2/processor/data_processor/bbox_cropper.py +++ b/deploy/python/ppshitu_v2/processor/algo_mod/data_processor/bbox_cropper.py @@ -1,4 +1,4 @@ -from .. import BaseProcessor +from processor import BaseProcessor class BBoxCropper(BaseProcessor): diff --git a/deploy/python/ppshitu_v2/processor/data_processor/preprocess.py b/deploy/python/ppshitu_v2/processor/algo_mod/data_processor/image_processor.py similarity index 52% rename from deploy/python/ppshitu_v2/processor/data_processor/preprocess.py rename to deploy/python/ppshitu_v2/processor/algo_mod/data_processor/image_processor.py index f4a202fb5276dbcea5574f02630fbb93fdc9c446..f7e758f38b64f138a17e006df1417592cedcf2d8 100644 --- a/deploy/python/ppshitu_v2/processor/data_processor/preprocess.py +++ b/deploy/python/ppshitu_v2/processor/algo_mod/data_processor/image_processor.py @@ -1,85 +1,93 @@ from functools import partial -import six -import math -import random import cv2 import numpy as np import importlib from PIL import Image +import paddle from utils import logger +from processor import BaseProcessor -class PreProcesser(object): +class ImageProcessor(BaseProcessor): def __init__(self, config): - """Image PreProcesser - - Args: - config (list): A list consisting of Dict object that describe an image processer operator. - """ - super().__init__() - self.ops = self.create_ops(config) - - def create_ops(self, config): - if not isinstance(config, list): - msg = "The preprocess config should be a list consisting of Dict object." - logger.error(msg) - raise Exception(msg) - mod = importlib.import_module(__name__) - ops = [] - for op_config in config: - name = list(op_config)[0] - param = {} if op_config[name] is None else op_config[name] - op = getattr(mod, name)(**param) - ops.append(op) - return ops + self.processors = [] + for processor_config in config.get("image_processors"): + name = list(processor_config)[0] + param = {} if processor_config[name] is None else processor_config[name] + op = locals()[name](**param) + self.processors.append(op) + + def process(self, input_data): + image = input_data["input_image"] + for processor in self.processors: + if isinstance(processor, BaseProcessor): + input_data["image"] = image + input_data = processor.process(input_data) + else: + image = processor(image) + return input_data + + +class GetShapeInfo(BaseProcessor): + def __init__(self): + pass + + def process(self, input_data): + input_image = input_data["input_image"] + image = input_data["image"] + input_data['im_shape'] = np.array(input_image.shape[:2], dtype=np.float32) + input_data['input_shape'] = np.array(image.shape[:2], dtype=np.float32) + input_data['scale_factor'] = np.array([image.shape[0] / input_image.shape[0], + image.shape[1] / input_image.shape[1]], dtype=np.float32) - def __call__(self, img, img_info=None): - if img_info: - for op in self.ops: - img, img_info = op(img, img_info) - return img, img_info - else: - for op in self.ops: - img = op(img) - return img +class ToTensor(BaseProcessor): + def __init__(self, config): + pass + + def process(self, input_data): + image = input_data["image"] + input_data["input_tensor"] = paddle.to_tensor(image) + return input_data -class DecodeImage(object): - """ decode image """ - def __init__(self, to_rgb=True, to_np=False, channel_first=False): - self.to_rgb = to_rgb - self.to_np = to_np # to numpy - self.channel_first = channel_first # only enabled when to_np is True +class ToRGB: + def __init__(self): + pass + + def __call__(self, img): + img = img[:, :, ::-1] + return img + + +class ToCHWImage: + def __init__(self): + pass def __call__(self, img, img_info=None): - if six.PY2: - assert type(img) is str and len( - img) > 0, "invalid input 'img' in DecodeImage" - else: - assert type(img) is bytes and len( - img) > 0, "invalid input 'img' in DecodeImage" - data = np.frombuffer(img, dtype='uint8') - img = cv2.imdecode(data, 1) - if self.to_rgb: - assert img.shape[2] == 3, 'invalid shape of image[%s]' % ( - img.shape) - img = img[:, :, ::-1] - - if self.channel_first: - img = img.transpose((2, 0, 1)) - - if img_info: - img_info["im_shape"] = np.array(img.shape[:2], dtype=np.float32) - img_info["scale_factor"] = np.array([1., 1.], dtype=np.float32) - return img, img_info - else: - return img + img = img.transpose((2, 0, 1)) + return img + +class ResizeImage: + def __init__(self, + size=None, + resize_short=None, + interpolation=None, + backend="cv2"): + if resize_short is not None and resize_short > 0: + self.resize_short = resize_short + self.w = None + self.h = None + elif size is not None: + self.resize_short = None + self.w = size if type(size) is int else size[0] + self.h = size if type(size) is int else size[1] + else: + raise Exception("invalid params for ReisizeImage for '\ + 'both 'size' and 'resize_short' are None") -class UnifiedResize(object): - def __init__(self, interpolation=None, backend="cv2"): _cv2_interp_from_str = { 'nearest': cv2.INTER_NEAREST, 'bilinear': cv2.INTER_LINEAR, @@ -114,38 +122,12 @@ class UnifiedResize(object): self.resize_func = partial(_pil_resize, resample=interpolation) else: logger.warning( - f"The backend of Resize only support \"cv2\" or \"PIL\". \"f{backend}\" is unavailable. Use \"cv2\" instead." + f"The backend of Resize only support \"cv2\" or \"PIL\". \"f{backend}\" is unavailable. " + f"Use \"cv2\" instead." ) self.resize_func = cv2.resize - def __call__(self, src, size): - return self.resize_func(src, size) - - -class ResizeImage(object): - """ resize image """ - - def __init__(self, - size=None, - resize_short=None, - interpolation=None, - backend="cv2"): - if resize_short is not None and resize_short > 0: - self.resize_short = resize_short - self.w = None - self.h = None - elif size is not None: - self.resize_short = None - self.w = size if type(size) is int else size[0] - self.h = size if type(size) is int else size[1] - else: - raise Exception("invalid params for ReisizeImage for '\ - 'both 'size' and 'resize_short' are None") - - self._resize_func = UnifiedResize( - interpolation=interpolation, backend=backend) - - def __call__(self, img, img_info=None): + def __call__(self, img): img_h, img_w = img.shape[:2] if self.resize_short is not None: percent = float(self.resize_short) / min(img_w, img_h) @@ -154,17 +136,11 @@ class ResizeImage(object): else: w = self.w h = self.h - img = self._resize_func(img, (w, h)) - if img_info: - img_info["input_shape"] = img.shape[:2] - img_info["scale_factor"] = np.array( - [img.shape[0] / img_h, img.shape[1] / img_w]).astype("float32") - return img, img_info - else: - return img + img = self.resize_func(img, (w, h)) + return img -class CropImage(object): +class CropImage: """ crop image """ def __init__(self, size): @@ -173,34 +149,25 @@ class CropImage(object): else: self.size = size # (h, w) - def __call__(self, img, img_info=None): + def __call__(self, img): w, h = self.size img_h, img_w = img.shape[:2] if img_h < h or img_w < w: raise Exception( - f"The size({h}, {w}) of CropImage must be greater than size({img_h}, {img_w}) of image. Please check image original size and size of ResizeImage if used." + f"The size({h}, {w}) of CropImage must be greater than size({img_h}, {img_w}) of image. " + f"Please check image original size and size of ResizeImage if used." ) - w_start = (img_w - w) // 2 h_start = (img_h - h) // 2 w_end = w_start + w h_end = h_start + h img = img[h_start:h_end, w_start:w_end, :] - if img_info: - img_info["input_shape"] = img.shape[:2] - # TODO(gaotingquan): im_shape is needed to update? - img_info["im_shape"] = np.array(img.shape[:2], dtype=np.float32) - return img, img_info - else: - return img - + return img -class NormalizeImage(object): - """ normalize image such as substract mean, divide std - """ +class NormalizeImage: def __init__(self, scale=None, mean=None, @@ -210,9 +177,8 @@ class NormalizeImage(object): channel_num=3): if isinstance(scale, str): scale = eval(scale) - assert channel_num in [ - 3, 4 - ], "channel number of input image should be set to 3 or 4." + assert channel_num in [3, 4], \ + "channel number of input image should be set to 3 or 4." self.channel_num = channel_num self.output_dtype = 'float16' if output_fp16 else 'float32' self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) @@ -224,12 +190,8 @@ class NormalizeImage(object): self.mean = np.array(mean).reshape(shape).astype('float32') self.std = np.array(std).reshape(shape).astype('float32') - def __call__(self, img, img_info=None): - if isinstance(img, Image.Image): - img = np.array(img) - - assert isinstance(img, - np.ndarray), "invalid input 'img' in NormalizeImage" + def __call__(self, img): + assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage" img = (img.astype('float32') * self.scale - self.mean) / self.std @@ -244,25 +206,4 @@ class NormalizeImage(object): if self.order == 'chw' else np.concatenate( (img, pad_zeros), axis=2)) img = img.astype(self.output_dtype) - if img_info: - return img, img_info - else: - return img - - -class ToCHWImage(object): - """ convert hwc image to chw image - """ - - def __init__(self): - pass - - def __call__(self, img, img_info=None): - if isinstance(img, Image.Image): - img = np.array(img) - - img = img.transpose((2, 0, 1)) - if img_info: - return img, img_info - else: - return img + return img diff --git a/deploy/python/ppshitu_v2/processor/post_process/__init__.py b/deploy/python/ppshitu_v2/processor/algo_mod/post_processor/__init__.py similarity index 100% rename from deploy/python/ppshitu_v2/processor/post_process/__init__.py rename to deploy/python/ppshitu_v2/processor/algo_mod/post_processor/__init__.py diff --git a/deploy/python/ppshitu_v2/processor/post_process/det.py b/deploy/python/ppshitu_v2/processor/algo_mod/post_processor/det.py similarity index 100% rename from deploy/python/ppshitu_v2/processor/post_process/det.py rename to deploy/python/ppshitu_v2/processor/algo_mod/post_processor/det.py diff --git a/deploy/python/ppshitu_v2/processor/algo_mod/predictors/__init__.py b/deploy/python/ppshitu_v2/processor/algo_mod/predictors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..32ef848ff7b2f473d299c5bde6a924be14f2d94d --- /dev/null +++ b/deploy/python/ppshitu_v2/processor/algo_mod/predictors/__init__.py @@ -0,0 +1,7 @@ +from .fake_cls import FakeClassifier + + +def build_algo_mod(config): + algo_name = config.get("algo_name") + if algo_name == "fake_clas": + return FakeClassifier(config) diff --git a/deploy/python/ppshitu_v2/processor/algo_mod/fake_cls.py b/deploy/python/ppshitu_v2/processor/algo_mod/predictors/fake_cls.py similarity index 78% rename from deploy/python/ppshitu_v2/processor/algo_mod/fake_cls.py rename to deploy/python/ppshitu_v2/processor/algo_mod/predictors/fake_cls.py index 7d1d2a504028208b82aa16cc6943eb86a8c3c184..5a7b2ac0ab70efb3128ed6f3e875e11e657953e8 100644 --- a/deploy/python/ppshitu_v2/processor/algo_mod/fake_cls.py +++ b/deploy/python/ppshitu_v2/processor/algo_mod/predictors/fake_cls.py @@ -1,4 +1,4 @@ -from .. import BaseProcessor +from processor import BaseProcessor class FakeClassifier(BaseProcessor): diff --git a/deploy/python/ppshitu_v2/processor/algo_mod/fake_det.py b/deploy/python/ppshitu_v2/processor/algo_mod/predictors/fake_det.py similarity index 100% rename from deploy/python/ppshitu_v2/processor/algo_mod/fake_det.py rename to deploy/python/ppshitu_v2/processor/algo_mod/predictors/fake_det.py diff --git a/deploy/python/ppshitu_v2/processor/algo_mod/onnx_predictor.py b/deploy/python/ppshitu_v2/processor/algo_mod/predictors/onnx_predictor.py similarity index 100% rename from deploy/python/ppshitu_v2/processor/algo_mod/onnx_predictor.py rename to deploy/python/ppshitu_v2/processor/algo_mod/predictors/onnx_predictor.py diff --git a/deploy/python/ppshitu_v2/processor/algo_mod/paddle_predictor.py b/deploy/python/ppshitu_v2/processor/algo_mod/predictors/paddle_predictor.py similarity index 100% rename from deploy/python/ppshitu_v2/processor/algo_mod/paddle_predictor.py rename to deploy/python/ppshitu_v2/processor/algo_mod/predictors/paddle_predictor.py diff --git a/deploy/python/ppshitu_v2/processor/searcher/__init__.py b/deploy/python/ppshitu_v2/processor/algo_mod/searcher/__init__.py similarity index 100% rename from deploy/python/ppshitu_v2/processor/searcher/__init__.py rename to deploy/python/ppshitu_v2/processor/algo_mod/searcher/__init__.py diff --git a/deploy/python/ppshitu_v2/processor/data_processor/__init__.py b/deploy/python/ppshitu_v2/processor/data_processor/__init__.py deleted file mode 100644 index 4e8226ccaabf64a0e0e248239ccef870efcdb9d0..0000000000000000000000000000000000000000 --- a/deploy/python/ppshitu_v2/processor/data_processor/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# from bbox_cropper import - - -def build_data_processor(config): - return diff --git a/deploy/python/ppshitu_v2/processor/data_processor/image_reader.py b/deploy/python/ppshitu_v2/processor/data_processor/image_reader.py deleted file mode 100644 index b434a8eec5ece5c5c26785d10bbd7928a19dea81..0000000000000000000000000000000000000000 --- a/deploy/python/ppshitu_v2/processor/data_processor/image_reader.py +++ /dev/null @@ -1,9 +0,0 @@ -from .. import BaseProcessor - - -class ImageReader(BaseProcessor): - def __init__(self): - pass - - def process(self, input_data): - pass