未验证 提交 2ecbbe59 编写于 作者: G Guanghua Yu 提交者: GitHub

remove unnecessary code and sort out code in deploy/infer (#1378)

上级 57a02fd6
...@@ -25,6 +25,24 @@ import numpy as np ...@@ -25,6 +25,24 @@ import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from visualize import visualize_box_mask from visualize import visualize_box_mask
# Global dictionary
RESIZE_SCALE_SET = {
'RCNN',
'RetinaNet',
'FCOS',
}
SUPPORT_MODELS = {
'YOLO',
'SSD',
'RetinaNet',
'EfficientDet',
'RCNN',
'Face',
'TTF',
'FCOS',
}
def decode_image(im_file, im_info): def decode_image(im_file, im_info):
"""read rgb image """read rgb image
...@@ -70,11 +88,10 @@ class Resize(object): ...@@ -70,11 +88,10 @@ class Resize(object):
interp=cv2.INTER_LINEAR): interp=cv2.INTER_LINEAR):
self.target_size = target_size self.target_size = target_size
self.max_size = max_size self.max_size = max_size
self.image_shape = image_shape, self.image_shape = image_shape
self.arch = arch self.arch = arch
self.use_cv2 = use_cv2 self.use_cv2 = use_cv2
self.interp = interp self.interp = interp
self.scale_set = {'RCNN', 'RetinaNet', 'FCOS'}
def __call__(self, im, im_info): def __call__(self, im, im_info):
""" """
...@@ -129,7 +146,7 @@ class Resize(object): ...@@ -129,7 +146,7 @@ class Resize(object):
""" """
origin_shape = im.shape[:2] origin_shape = im.shape[:2]
im_c = im.shape[2] im_c = im.shape[2]
if self.max_size != 0 and self.arch in self.scale_set: if self.max_size != 0 and self.arch in RESIZE_SCALE_SET:
im_size_min = np.min(origin_shape[0:2]) im_size_min = np.min(origin_shape[0:2])
im_size_max = np.max(origin_shape[0:2]) im_size_max = np.max(origin_shape[0:2])
im_scale = float(self.target_size) / float(im_size_min) im_scale = float(self.target_size) / float(im_size_min)
...@@ -255,7 +272,7 @@ def create_inputs(im, im_info, model_arch='YOLO'): ...@@ -255,7 +272,7 @@ def create_inputs(im, im_info, model_arch='YOLO'):
if 'YOLO' in model_arch: if 'YOLO' in model_arch:
im_size = np.array([origin_shape]).astype('int32') im_size = np.array([origin_shape]).astype('int32')
inputs['im_size'] = im_size inputs['im_size'] = im_size
elif 'RetinaNet' in model_arch: elif 'RetinaNet' or 'EfficientDet' in model_arch:
scale = scale_x scale = scale_x
im_info = np.array([resize_shape + [scale]]).astype('float32') im_info = np.array([resize_shape + [scale]]).astype('float32')
inputs['im_info'] = im_info inputs['im_info'] = im_info
...@@ -276,15 +293,6 @@ class Config(): ...@@ -276,15 +293,6 @@ class Config():
Args: Args:
model_dir (str): root path of model.yml model_dir (str): root path of model.yml
""" """
support_models = [
'YOLO',
'SSD',
'RetinaNet',
'RCNN',
'Face',
'TTF',
'FCOS',
]
def __init__(self, model_dir): def __init__(self, model_dir):
# parsing Yaml config for Preprocess # parsing Yaml config for Preprocess
...@@ -307,11 +315,11 @@ class Config(): ...@@ -307,11 +315,11 @@ class Config():
Raises: Raises:
ValueError: loaded model not in supported model type ValueError: loaded model not in supported model type
""" """
for support_model in self.support_models: for support_model in SUPPORT_MODELS:
if support_model in yml_conf['arch']: if support_model in yml_conf['arch']:
return True return True
raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[ raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
'arch'], self.support_models)) 'arch'], SUPPORT_MODELS))
def print_config(self): def print_config(self):
print('----------- Model Configuration -----------') print('----------- Model Configuration -----------')
......
...@@ -36,13 +36,29 @@ FORMAT = '%(asctime)s-%(levelname)s: %(message)s' ...@@ -36,13 +36,29 @@ FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT) logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Global dictionary
TRT_MIN_SUBGRAPH = {
'YOLO': 3,
'SSD': 3,
'RCNN': 40,
'RetinaNet': 40,
'EfficientDet': 40,
'Face': 3,
'TTFNet': 3,
'FCOS': 3,
}
RESIZE_SCALE_SET = {
'RCNN',
'RetinaNet',
'FCOS',
}
def parse_reader(reader_cfg, metric, arch): def parse_reader(reader_cfg, metric, arch):
preprocess_list = [] preprocess_list = []
image_shape = reader_cfg['inputs_def'].get('image_shape', [3, None, None]) image_shape = reader_cfg['inputs_def'].get('image_shape', [3, None, None])
has_shape_def = not None in image_shape has_shape_def = not None in image_shape
scale_set = {'RCNN', 'RetinaNet'}
dataset = reader_cfg['dataset'] dataset = reader_cfg['dataset']
anno_file = dataset.get_anno() anno_file = dataset.get_anno()
...@@ -72,9 +88,9 @@ def parse_reader(reader_cfg, metric, arch): ...@@ -72,9 +88,9 @@ def parse_reader(reader_cfg, metric, arch):
params.pop('_id') params.pop('_id')
if p['type'] == 'Resize' and has_shape_def: if p['type'] == 'Resize' and has_shape_def:
params['target_size'] = min(image_shape[ params['target_size'] = min(image_shape[
1:]) if arch in scale_set else image_shape[1] 1:]) if arch in RESIZE_SCALE_SET else image_shape[1]
params['max_size'] = max(image_shape[ params['max_size'] = max(image_shape[
1:]) if arch in scale_set else 0 1:]) if arch in RESIZE_SCALE_SET else 0
params['image_shape'] = image_shape[1:] params['image_shape'] = image_shape[1:]
if 'target_dim' in params: if 'target_dim' in params:
params.pop('target_dim') params.pop('target_dim')
...@@ -114,19 +130,9 @@ def dump_infer_config(FLAGS, config): ...@@ -114,19 +130,9 @@ def dump_infer_config(FLAGS, config):
'draw_threshold': 0.5, 'draw_threshold': 0.5,
'metric': config['metric'] 'metric': config['metric']
}) })
trt_min_subgraph = {
'YOLO': 3,
'SSD': 3,
'RCNN': 40,
'RetinaNet': 40,
'Face': 3,
'TTFNet': 3,
'FCOS': 3,
}
infer_arch = config['architecture'] infer_arch = config['architecture']
infer_arch = 'RetinaNet' if infer_arch == 'EfficientDet' else infer_arch
for arch, min_subgraph_size in trt_min_subgraph.items(): for arch, min_subgraph_size in TRT_MIN_SUBGRAPH.items():
if arch in infer_arch: if arch in infer_arch:
infer_cfg['arch'] = arch infer_cfg['arch'] = arch
infer_cfg['min_subgraph_size'] = min_subgraph_size infer_cfg['min_subgraph_size'] = min_subgraph_size
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册