From 2ecbbe59259f51987388686eaf2b3b7111c70726 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Wed, 9 Sep 2020 19:59:43 +0800 Subject: [PATCH] remove unnecessary code and sort out code in deploy/infer (#1378) --- deploy/python/infer.py | 42 +++++++++++++++++++++++++----------------- tools/export_model.py | 34 ++++++++++++++++++++-------------- 2 files changed, 45 insertions(+), 31 deletions(-) diff --git a/deploy/python/infer.py b/deploy/python/infer.py index 4f3c5ef21..57214f577 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -25,6 +25,24 @@ import numpy as np import paddle.fluid as fluid from visualize import visualize_box_mask +# Global dictionary +RESIZE_SCALE_SET = { + 'RCNN', + 'RetinaNet', + 'FCOS', +} + +SUPPORT_MODELS = { + 'YOLO', + 'SSD', + 'RetinaNet', + 'EfficientDet', + 'RCNN', + 'Face', + 'TTF', + 'FCOS', +} + def decode_image(im_file, im_info): """read rgb image @@ -70,11 +88,10 @@ class Resize(object): interp=cv2.INTER_LINEAR): self.target_size = target_size self.max_size = max_size - self.image_shape = image_shape, + self.image_shape = image_shape self.arch = arch self.use_cv2 = use_cv2 self.interp = interp - self.scale_set = {'RCNN', 'RetinaNet', 'FCOS'} def __call__(self, im, im_info): """ @@ -124,12 +141,12 @@ class Resize(object): Args: im (np.ndarray): image (np.ndarray) Returns: - im_scale_x: the resize ratio of X - im_scale_y: the resize ratio of Y + im_scale_x: the resize ratio of X + im_scale_y: the resize ratio of Y """ origin_shape = im.shape[:2] im_c = im.shape[2] - if self.max_size != 0 and self.arch in self.scale_set: + if self.max_size != 0 and self.arch in RESIZE_SCALE_SET: im_size_min = np.min(origin_shape[0:2]) im_size_max = np.max(origin_shape[0:2]) im_scale = float(self.target_size) / float(im_size_min) @@ -255,7 +272,7 @@ def create_inputs(im, im_info, model_arch='YOLO'): if 'YOLO' in model_arch: im_size = np.array([origin_shape]).astype('int32') inputs['im_size'] = im_size - elif 'RetinaNet' in model_arch: + elif 'RetinaNet' or 'EfficientDet' in model_arch: scale = scale_x im_info = np.array([resize_shape + [scale]]).astype('float32') inputs['im_info'] = im_info @@ -276,15 +293,6 @@ class Config(): Args: model_dir (str): root path of model.yml """ - support_models = [ - 'YOLO', - 'SSD', - 'RetinaNet', - 'RCNN', - 'Face', - 'TTF', - 'FCOS', - ] def __init__(self, model_dir): # parsing Yaml config for Preprocess @@ -307,11 +315,11 @@ class Config(): Raises: ValueError: loaded model not in supported model type """ - for support_model in self.support_models: + for support_model in SUPPORT_MODELS: if support_model in yml_conf['arch']: return True raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[ - 'arch'], self.support_models)) + 'arch'], SUPPORT_MODELS)) def print_config(self): print('----------- Model Configuration -----------') diff --git a/tools/export_model.py b/tools/export_model.py index 9570c3c06..1cf6549c7 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -36,13 +36,29 @@ FORMAT = '%(asctime)s-%(levelname)s: %(message)s' logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) +# Global dictionary +TRT_MIN_SUBGRAPH = { + 'YOLO': 3, + 'SSD': 3, + 'RCNN': 40, + 'RetinaNet': 40, + 'EfficientDet': 40, + 'Face': 3, + 'TTFNet': 3, + 'FCOS': 3, +} +RESIZE_SCALE_SET = { + 'RCNN', + 'RetinaNet', + 'FCOS', +} + def parse_reader(reader_cfg, metric, arch): preprocess_list = [] image_shape = reader_cfg['inputs_def'].get('image_shape', [3, None, None]) has_shape_def = not None in image_shape - scale_set = {'RCNN', 'RetinaNet'} dataset = reader_cfg['dataset'] anno_file = dataset.get_anno() @@ -72,9 +88,9 @@ def parse_reader(reader_cfg, metric, arch): params.pop('_id') if p['type'] == 'Resize' and has_shape_def: params['target_size'] = min(image_shape[ - 1:]) if arch in scale_set else image_shape[1] + 1:]) if arch in RESIZE_SCALE_SET else image_shape[1] params['max_size'] = max(image_shape[ - 1:]) if arch in scale_set else 0 + 1:]) if arch in RESIZE_SCALE_SET else 0 params['image_shape'] = image_shape[1:] if 'target_dim' in params: params.pop('target_dim') @@ -114,19 +130,9 @@ def dump_infer_config(FLAGS, config): 'draw_threshold': 0.5, 'metric': config['metric'] }) - trt_min_subgraph = { - 'YOLO': 3, - 'SSD': 3, - 'RCNN': 40, - 'RetinaNet': 40, - 'Face': 3, - 'TTFNet': 3, - 'FCOS': 3, - } infer_arch = config['architecture'] - infer_arch = 'RetinaNet' if infer_arch == 'EfficientDet' else infer_arch - for arch, min_subgraph_size in trt_min_subgraph.items(): + for arch, min_subgraph_size in TRT_MIN_SUBGRAPH.items(): if arch in infer_arch: infer_cfg['arch'] = arch infer_cfg['min_subgraph_size'] = min_subgraph_size -- GitLab