未验证 提交 2ecbbe59 编写于 作者: G Guanghua Yu 提交者: GitHub

remove unnecessary code and sort out code in deploy/infer (#1378)

上级 57a02fd6
......@@ -25,6 +25,24 @@ import numpy as np
import paddle.fluid as fluid
from visualize import visualize_box_mask
# Global dictionary
RESIZE_SCALE_SET = {
'RCNN',
'RetinaNet',
'FCOS',
}
SUPPORT_MODELS = {
'YOLO',
'SSD',
'RetinaNet',
'EfficientDet',
'RCNN',
'Face',
'TTF',
'FCOS',
}
def decode_image(im_file, im_info):
"""read rgb image
......@@ -70,11 +88,10 @@ class Resize(object):
interp=cv2.INTER_LINEAR):
self.target_size = target_size
self.max_size = max_size
self.image_shape = image_shape,
self.image_shape = image_shape
self.arch = arch
self.use_cv2 = use_cv2
self.interp = interp
self.scale_set = {'RCNN', 'RetinaNet', 'FCOS'}
def __call__(self, im, im_info):
"""
......@@ -129,7 +146,7 @@ class Resize(object):
"""
origin_shape = im.shape[:2]
im_c = im.shape[2]
if self.max_size != 0 and self.arch in self.scale_set:
if self.max_size != 0 and self.arch in RESIZE_SCALE_SET:
im_size_min = np.min(origin_shape[0:2])
im_size_max = np.max(origin_shape[0:2])
im_scale = float(self.target_size) / float(im_size_min)
......@@ -255,7 +272,7 @@ def create_inputs(im, im_info, model_arch='YOLO'):
if 'YOLO' in model_arch:
im_size = np.array([origin_shape]).astype('int32')
inputs['im_size'] = im_size
elif 'RetinaNet' in model_arch:
elif 'RetinaNet' or 'EfficientDet' in model_arch:
scale = scale_x
im_info = np.array([resize_shape + [scale]]).astype('float32')
inputs['im_info'] = im_info
......@@ -276,15 +293,6 @@ class Config():
Args:
model_dir (str): root path of model.yml
"""
support_models = [
'YOLO',
'SSD',
'RetinaNet',
'RCNN',
'Face',
'TTF',
'FCOS',
]
def __init__(self, model_dir):
# parsing Yaml config for Preprocess
......@@ -307,11 +315,11 @@ class Config():
Raises:
ValueError: loaded model not in supported model type
"""
for support_model in self.support_models:
for support_model in SUPPORT_MODELS:
if support_model in yml_conf['arch']:
return True
raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
'arch'], self.support_models))
'arch'], SUPPORT_MODELS))
def print_config(self):
print('----------- Model Configuration -----------')
......
......@@ -36,13 +36,29 @@ FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
# Global dictionary
TRT_MIN_SUBGRAPH = {
'YOLO': 3,
'SSD': 3,
'RCNN': 40,
'RetinaNet': 40,
'EfficientDet': 40,
'Face': 3,
'TTFNet': 3,
'FCOS': 3,
}
RESIZE_SCALE_SET = {
'RCNN',
'RetinaNet',
'FCOS',
}
def parse_reader(reader_cfg, metric, arch):
preprocess_list = []
image_shape = reader_cfg['inputs_def'].get('image_shape', [3, None, None])
has_shape_def = not None in image_shape
scale_set = {'RCNN', 'RetinaNet'}
dataset = reader_cfg['dataset']
anno_file = dataset.get_anno()
......@@ -72,9 +88,9 @@ def parse_reader(reader_cfg, metric, arch):
params.pop('_id')
if p['type'] == 'Resize' and has_shape_def:
params['target_size'] = min(image_shape[
1:]) if arch in scale_set else image_shape[1]
1:]) if arch in RESIZE_SCALE_SET else image_shape[1]
params['max_size'] = max(image_shape[
1:]) if arch in scale_set else 0
1:]) if arch in RESIZE_SCALE_SET else 0
params['image_shape'] = image_shape[1:]
if 'target_dim' in params:
params.pop('target_dim')
......@@ -114,19 +130,9 @@ def dump_infer_config(FLAGS, config):
'draw_threshold': 0.5,
'metric': config['metric']
})
trt_min_subgraph = {
'YOLO': 3,
'SSD': 3,
'RCNN': 40,
'RetinaNet': 40,
'Face': 3,
'TTFNet': 3,
'FCOS': 3,
}
infer_arch = config['architecture']
infer_arch = 'RetinaNet' if infer_arch == 'EfficientDet' else infer_arch
for arch, min_subgraph_size in trt_min_subgraph.items():
for arch, min_subgraph_size in TRT_MIN_SUBGRAPH.items():
if arch in infer_arch:
infer_cfg['arch'] = arch
infer_cfg['min_subgraph_size'] = min_subgraph_size
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册