未验证 提交 922bc008 编写于 作者: W wangguanzhong 提交者: GitHub

Update export model for face detection (#501)

* support face detection in python inference

* support face detection in python deploy
上级 9c7b2cbc
...@@ -268,7 +268,7 @@ class Config(): ...@@ -268,7 +268,7 @@ class Config():
Args: Args:
model_dir (str): root path of model.yml model_dir (str): root path of model.yml
""" """
support_models = ['YOLO', 'SSD', 'RetinaNet', 'RCNN'] support_models = ['YOLO', 'SSD', 'RetinaNet', 'RCNN', 'Face']
def __init__(self, model_dir): def __init__(self, model_dir):
# parsing Yaml config for Preprocess # parsing Yaml config for Preprocess
...@@ -297,8 +297,8 @@ class Config(): ...@@ -297,8 +297,8 @@ class Config():
if support_model in yml_conf['arch']: if support_model in yml_conf['arch']:
return True return True
raise ValueError( raise ValueError(
"Unsupported arch: {}, expect SSD, YOLO, RetinaNet and RCNN".format( "Unsupported arch: {}, expect SSD, YOLO, RetinaNet, RCNN and Face".
yml_conf['arch'])) format(yml_conf['arch']))
def load_predictor(model_dir, def load_predictor(model_dir,
...@@ -426,7 +426,7 @@ class Detector(): ...@@ -426,7 +426,7 @@ class Detector():
def postprocess(self, np_boxes, np_masks, im_info, threshold=0.5): def postprocess(self, np_boxes, np_masks, im_info, threshold=0.5):
# postprocess output of predictor # postprocess output of predictor
results = {} results = {}
if 'SSD' in self.config.arch: if self.config.arch in ['SSD', 'Face']:
w, h = im_info['origin_shape'] w, h = im_info['origin_shape']
np_boxes[:, 2] *= h np_boxes[:, 2] *= h
np_boxes[:, 3] *= w np_boxes[:, 3] *= w
......
...@@ -75,7 +75,7 @@ def get_extra_info(im, arch, shape, scale): ...@@ -75,7 +75,7 @@ def get_extra_info(im, arch, shape, scale):
im_size = np.array([shape[:2]]).astype('int32') im_size = np.array([shape[:2]]).astype('int32')
logger.info('Extra info: im_size') logger.info('Extra info: im_size')
info.append(im_size) info.append(im_size)
elif 'SSD' in arch: elif arch in ['SSD', 'Face']:
im_shape = np.array([shape[:2]]).astype('int32') im_shape = np.array([shape[:2]]).astype('int32')
logger.info('Extra info: im_shape') logger.info('Extra info: im_shape')
info.append([im_shape]) info.append([im_shape])
...@@ -94,8 +94,8 @@ def get_extra_info(im, arch, shape, scale): ...@@ -94,8 +94,8 @@ def get_extra_info(im, arch, shape, scale):
info.append(im_shape) info.append(im_shape)
else: else:
logger.error( logger.error(
"Unsupported arch: {}, expect YOLO, SSD, RetinaNet and RCNN".format( "Unsupported arch: {}, expect YOLO, SSD, RetinaNet, RCNN and Face".
arch)) format(arch))
return info return info
...@@ -244,6 +244,14 @@ def get_category_info(with_background, label_list): ...@@ -244,6 +244,14 @@ def get_category_info(with_background, label_list):
return clsid2catid, catid2name return clsid2catid, catid2name
def clip_bbox(bbox):
xmin = max(min(bbox[0], 1.), 0.)
ymin = max(min(bbox[1], 1.), 0.)
xmax = max(min(bbox[2], 1.), 0.)
ymax = max(min(bbox[3], 1.), 0.)
return xmin, ymin, xmax, ymax
def bbox2out(results, clsid2catid, is_bbox_normalized=False): def bbox2out(results, clsid2catid, is_bbox_normalized=False):
""" """
Args: Args:
...@@ -457,7 +465,7 @@ def draw_mask(image, masks, threshold, color_list, alpha=0.7): ...@@ -457,7 +465,7 @@ def draw_mask(image, masks, threshold, color_list, alpha=0.7):
def get_bbox_result(output, result, conf, clsid2catid): def get_bbox_result(output, result, conf, clsid2catid):
is_bbox_normalized = True if 'SSD' in conf['arch'] else False is_bbox_normalized = True if conf['arch'] in ['SSD', 'Face'] else False
lengths = offset_to_lengths(output.lod()) lengths = offset_to_lengths(output.lod())
np_data = np.array(output) if conf[ np_data = np.array(output) if conf[
'use_python_inference'] else output.copy_to_cpu() 'use_python_inference'] else output.copy_to_cpu()
...@@ -513,7 +521,7 @@ def infer(): ...@@ -513,7 +521,7 @@ def infer():
"Due to the limitation of tensorRT, the image shape needs to set in export_model" "Due to the limitation of tensorRT, the image shape needs to set in export_model"
) )
img_data = Preprocess(FLAGS.infer_img, conf['arch'], conf['Preprocess']) img_data = Preprocess(FLAGS.infer_img, conf['arch'], conf['Preprocess'])
if 'SSD' in conf['arch']: if conf['arch'] in ['SSD', 'Face']:
img_data, res['im_shape'] = img_data img_data, res['im_shape'] = img_data
img_data = [img_data] img_data = [img_data]
......
...@@ -47,6 +47,12 @@ def parse_reader(reader_cfg, metric, arch): ...@@ -47,6 +47,12 @@ def parse_reader(reader_cfg, metric, arch):
from ppdet.utils.coco_eval import get_category_info from ppdet.utils.coco_eval import get_category_info
if metric == "VOC": if metric == "VOC":
from ppdet.utils.voc_eval import get_category_info from ppdet.utils.voc_eval import get_category_info
if metric == "WIDERFACE":
from ppdet.utils.widerface_eval_utils import get_category_info
else:
raise ValueError(
"metric only supports COCO, VOC, WIDERFACE, but received {}".format(
metric))
clsid2catid, catid2name = get_category_info(anno_file, with_background, clsid2catid, catid2name = get_category_info(anno_file, with_background,
use_default_label) use_default_label)
label_list = [str(cat) for cat in catid2name.values()] label_list = [str(cat) for cat in catid2name.values()]
...@@ -90,7 +96,13 @@ def dump_infer_config(config): ...@@ -90,7 +96,13 @@ def dump_infer_config(config):
'draw_threshold': 0.5, 'draw_threshold': 0.5,
'metric': config['metric'] 'metric': config['metric']
}) })
trt_min_subgraph = {'YOLO': 3, 'SSD': 40, 'RCNN': 40, 'RetinaNet': 40} trt_min_subgraph = {
'YOLO': 3,
'SSD': 3,
'RCNN': 40,
'RetinaNet': 40,
'Face': 3,
}
infer_arch = config['architecture'] infer_arch = config['architecture']
for arch, min_subgraph_size in trt_min_subgraph.items(): for arch, min_subgraph_size in trt_min_subgraph.items():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册