提交 77f46bc1 编写于 作者: W weishengyu

update format

上级 cfee84c4
Global:
Engine: POPEngine
infer_imgs: "../../images/wangzai.jpg"
AlgoModule:
- Module:
preprocess:
name: ImageProcessor
processors:
Modules:
- name: Detector
type: AlgoMod
processors:
- name: ImageProcessor
type: preprocessor
ops:
- ResizeImage:
size: [640, 640]
interpolation: 2
......@@ -17,14 +15,16 @@ AlgoModule:
order: hwc
- ToCHWImage:
- GetShapeInfo:
order: chw
configs:
order: chw
- ToBatch:
predictor:
- name: PaddlePredictor
type: predictor
inference_model_dir: ./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer/
input_names:
input_names:
output_names:
postprocess:
name: DetPostProcessor
- name: PPYOLOv2PostPro
type: postprocessor
threshold: 0.2
max_det_results: 1
label_list:
......
import importlib
from processor.algo_mod import AlgoMod
class POPEngine:
def __init__(self, config):
self.algo_list = []
# last_algo_type = "start"
for algo_config in config["AlgoModule"]:
# algo_config["last_algo_type"] = last_algo_type
self.algo_list.append(AlgoMod(algo_config["Module"]))
# last_algo_type = algo_config["type"]
current_mod = importlib.import_module(__name__)
for mod_config in config["Modules"]:
mod_type = mod_config.get("type")
mod = getattr(current_mod, mod_type)(mod_config)
self.algo_list.append(mod)
def process(self, x):
def process(self, input_data):
for algo_module in self.algo_list:
x = algo_module.process(x)
return x
input_data = algo_module.process(input_data)
return input_data
......@@ -7,7 +7,6 @@ import cv2
from engine import build_engine
from utils import config
from utils.get_image_list import get_image_list
def main():
......@@ -16,13 +15,11 @@ def main():
args.config, overrides=args.override, show=False)
config_dict.profiler_options = args.profiler_options
engine = build_engine(config_dict)
image_list = get_image_list(config_dict["Global"]["infer_imgs"])
for idx, image_file in enumerate(image_list):
img = cv2.imread(image_file)[:, :, ::-1]
input_data = {"input_image": img}
output = engine.process(input_data)
print(output)
image_file = "../../images/wangzai.jpg"
img = cv2.imread(image_file)[:, :, ::-1]
input_data = {"input_image": img}
output = engine.process(input_data)
print(output)
if __name__ == '__main__':
......
from abc import ABC, abstractmethod
from processor.algo_mod import searcher
from processor.algo_mod.predictors import build_predictor
# def build_processor(config):
# print(config)
# processor_type = config.get("processor_type")
# processor_mod = locals()[processor_type]
# processor_name = config.get("processor_name")
# return getattr(processor_mod, processor_name)
# class BaseProcessor(ABC):
# @abstractmethod
# def __init__(self, config):
# pass
# @abstractmethod
# def process(self, input_data):
# pass
from .algo_mod import AlgoMod
from processor.algo_mod.data_processor import ImageProcessor
from processor.algo_mod.post_processor.det import DetPostProcessor
from processor.algo_mod.predictors import build_predictor
from .postprocessor import build_postprocessor
from .preprocessor import build_preprocessor
from .predictor import build_predictor
from ..base_processor import BaseProcessor
def build_processor(config):
# processor_type = config.get("processor_type")
# processor_mod = locals()[processor_type]
processor_name = config.get("name")
return eval(processor_name)(config)
class AlgoMod(object):
class AlgoMod(BaseProcessor):
def __init__(self, config):
self.pre_processor = build_processor(config["preprocess"])
self.predictor = build_predictor(config["predictor"])
self.post_processor = build_processor(config["postprocess"])
self.processors = []
for processor_config in config["processors"]:
processor_type = processor_config.get("type")
if processor_type == "preprocessor":
processor = build_preprocessor(processor_config)
elif processor_type == "predictor":
processor = build_predictor(processor_config)
elif processor_type == "postprocessor":
processor = build_postprocessor(processor_config)
else:
raise NotImplemented("processor type {} unknown.".format(processor_type))
self.processors.append(processor)
def process(self, input_data):
input_data = self.pre_processor.process(input_data)
input_data = self.predictor.process(input_data)
input_data = self.post_processor.process(input_data)
for processor in self.processors:
input_data = processor.process(input_data)
return input_data
from processor.algo_mod.data_processor.image_processor import ImageProcessor
from processor.algo_mod.data_processor.bbox_cropper import BBoxCropper
import importlib
from .det import PPYOLOv2PostPro
def build_postprocessor(config):
processor_mod = importlib.import_module(__name__)
processor_name = config.get("name")
return getattr(processor_mod, processor_name)(config)
from functools import reduce
import numpy as np
from utils import logger
from ...base_processor import BaseProcessor
class DetPostProcessor(object):
class PPYOLOv2PostPro(BaseProcessor):
def __init__(self, config):
super().__init__()
self.threshold = config["threshold"]
self.label_list = config["label_list"]
self.max_det_results = config["max_det_results"]
def process(self, pred):
np_boxes = pred["save_infer_model/scale_0.tmp_1"]
def process(self, input_data):
pred = input_data["pred"]
np_boxes = pred[list(pred.keys())[0]]
if reduce(lambda x, y: x * y, np_boxes.shape) < 6:
print('[WARNNING] No object detected.')
logger.warning('[Detector] No object detected.')
np_boxes = np.array([])
keep_indexes = np_boxes[:, 1].argsort()[::-1][:self.max_det_results]
......
import importlib
from processor.algo_mod.predictor.paddle_predictor import PaddlePredictor
from processor.algo_mod.predictor.onnx_predictor import ONNXPredictor
def build_predictor(config):
processor_mod = importlib.import_module(__name__)
processor_name = config.get("name")
return getattr(processor_mod, processor_name)(config)
from ...base_processor import BaseProcessor
class ONNXPredictor(BaseProcessor):
def __init__(self, config):
pass
def process(self, input_data):
raise NotImplemented("ONNXPredictor Not supported yet")
import os
import platform
from paddle.inference import create_predictor
from paddle.inference import Config as PaddleConfig
from ...base_processor import BaseProcessor
class Predictor(object):
class PaddlePredictor(BaseProcessor):
def __init__(self, config):
super().__init__()
# HALF precission predict only work when using tensorrt
if config.get("use_fp16", False):
assert config.get("use_tensorrt", False) is True
......@@ -61,5 +60,5 @@ class Predictor(object):
for output_name in output_names:
output = self.predictor.get_output_handle(output_name)
output_data[output_name] = output.copy_to_cpu()
return output_data
input_data["pred"] = output_data
return input_data
from processor.algo_mod.predictors.paddle_predictor import Predictor as paddle_predictor
from processor.algo_mod.predictors.onnx_predictor import Predictor as onnx_predictor
def build_predictor(config):
# if use paddle backend
if True:
return paddle_predictor(config)
# if use onnx backend
else:
return onnx_predictor(config)
\ No newline at end of file
class Predictor(object):
def __init__(self, config):
super().__init__()
import importlib
from processor.algo_mod.preprocessor.image_processor import ImageProcessor
def build_preprocessor(config):
processor_mod = importlib.import_module(__name__)
processor_name = config.get("name")
return getattr(processor_mod, processor_name)(config)
......@@ -3,32 +3,20 @@ import cv2
import numpy as np
import importlib
from PIL import Image
import paddle
from utils import logger
# from processor import BaseProcessor
from processor.base_processor import BaseProcessor
from abc import ABC, abstractmethod
class BaseProcessor(ABC):
@abstractmethod
def __init__(self, *args, **kwargs):
pass
@abstractmethod
def process(self, input_data):
pass
class ImageProcessor(object):
class ImageProcessor(BaseProcessor):
def __init__(self, config):
self.processors = []
for processor_config in config.get("processors"):
mod = importlib.import_module(__name__)
for processor_config in config.get("ops"):
name = list(processor_config)[0]
param = {} if processor_config[name] is None else processor_config[
name]
op = eval(name)(**param)
op = getattr(mod, name)(**param)
self.processors.append(op)
def process(self, input_data):
......@@ -39,13 +27,13 @@ class ImageProcessor(object):
input_data = processor.process(input_data)
else:
image = processor(image)
input_data["image"] = image
return input_data
class GetShapeInfo(BaseProcessor):
def __init__(self, order="hwc"):
super().__init__()
self.order = order
def __init__(self, configs):
self.order = configs.get("order")
def process(self, input_data):
input_image = input_data["input_image"]
......@@ -69,43 +57,22 @@ class GetShapeInfo(BaseProcessor):
],
dtype=np.float32)
input_data['input_shape'] = np.array(image.shape[:2], dtype=np.float32)
print(image.shape[0])
return input_data
# class ToTensor(BaseProcessor):
# def __init__(self):
# super().__init__()
# def process(self, input_data):
# image = input_data["image"]
# input_data["input_tensor"] = paddle.to_tensor(image)
# return input_data
class ToBatch(BaseProcessor):
def __init__(self):
super().__init__()
def process(self, input_data):
image = input_data["image"]
input_data["image"] = image[np.newaxis, :, :, :]
return input_data
class ToBatch:
def __call__(self, img):
img = img[np.newaxis, :, :, :]
return img
class ToRGB:
def __init__(self):
pass
def __call__(self, img):
img = img[:, :, ::-1]
return img
class ToCHWImage:
def __init__(self):
pass
def __call__(self, img, img_info=None):
img = img.transpose((2, 0, 1))
return img
......
from processor.algo_mod.data_processor.image_processor import BaseProcessor
from abc import ABC, abstractmethod
class BBoxCropper(BaseProcessor):
class BaseProcessor(ABC):
@abstractmethod
def __init__(self, config):
pass
@abstractmethod
def process(self, input_data):
pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册