diff --git a/deploy/python/infer.py b/deploy/python/infer.py
index 8e7a261f24406ef2f8b64b698bd98a3ba70203ce..252f3e2b4a57fc2ea3c93895c3ef7bd7c679417f 100644
--- a/deploy/python/infer.py
+++ b/deploy/python/infer.py
@@ -268,7 +268,7 @@ class Config():
     Args:
         model_dir (str): root path of model.yml
     """
-    support_models = ['YOLO', 'SSD', 'RetinaNet', 'RCNN']
+    support_models = ['YOLO', 'SSD', 'RetinaNet', 'RCNN', 'Face']
 
     def __init__(self, model_dir):
         # parsing Yaml config for Preprocess
@@ -297,8 +297,8 @@ class Config():
             if support_model in yml_conf['arch']:
                 return True
         raise ValueError(
-            "Unsupported arch: {}, expect SSD, YOLO, RetinaNet and RCNN".format(
-                yml_conf['arch']))
+            "Unsupported arch: {}, expect SSD, YOLO, RetinaNet, RCNN and Face".
+            format(yml_conf['arch']))
 
 
 def load_predictor(model_dir,
@@ -426,7 +426,7 @@ class Detector():
     def postprocess(self, np_boxes, np_masks, im_info, threshold=0.5):
         # postprocess output of predictor
         results = {}
-        if 'SSD' in self.config.arch:
+        if self.config.arch in ['SSD', 'Face']:
             w, h = im_info['origin_shape']
             np_boxes[:, 2] *= h
             np_boxes[:, 3] *= w
diff --git a/tools/cpp_infer.py b/tools/cpp_infer.py
index 28d0e4790019726d6dd4003ba7931348121b866f..e4591be9e0abc224c4e4ac80363a592a5d388ce2 100644
--- a/tools/cpp_infer.py
+++ b/tools/cpp_infer.py
@@ -75,7 +75,7 @@ def get_extra_info(im, arch, shape, scale):
         im_size = np.array([shape[:2]]).astype('int32')
         logger.info('Extra info: im_size')
         info.append(im_size)
-    elif 'SSD' in arch:
+    elif arch in ['SSD', 'Face']:
         im_shape = np.array([shape[:2]]).astype('int32')
         logger.info('Extra info: im_shape')
         info.append([im_shape])
@@ -94,8 +94,8 @@ def get_extra_info(im, arch, shape, scale):
         info.append(im_shape)
     else:
         logger.error(
-            "Unsupported arch: {}, expect YOLO, SSD, RetinaNet and RCNN".format(
-                arch))
+            "Unsupported arch: {}, expect YOLO, SSD, RetinaNet, RCNN and Face".
+            format(arch))
     return info
 
 
@@ -244,6 +244,14 @@ def get_category_info(with_background, label_list):
     return clsid2catid, catid2name
 
 
+def clip_bbox(bbox):
+    xmin = max(min(bbox[0], 1.), 0.)
+    ymin = max(min(bbox[1], 1.), 0.)
+    xmax = max(min(bbox[2], 1.), 0.)
+    ymax = max(min(bbox[3], 1.), 0.)
+    return xmin, ymin, xmax, ymax
+
+
 def bbox2out(results, clsid2catid, is_bbox_normalized=False):
     """
     Args:
@@ -457,7 +465,7 @@ def draw_mask(image, masks, threshold, color_list, alpha=0.7):
 
 
 def get_bbox_result(output, result, conf, clsid2catid):
-    is_bbox_normalized = True if 'SSD' in conf['arch'] else False
+    is_bbox_normalized = True if conf['arch'] in ['SSD', 'Face'] else False
     lengths = offset_to_lengths(output.lod())
     np_data = np.array(output) if conf[
         'use_python_inference'] else output.copy_to_cpu()
@@ -513,7 +521,7 @@ def infer():
                 "Due to the limitation of tensorRT, the image shape needs to set in export_model"
             )
     img_data = Preprocess(FLAGS.infer_img, conf['arch'], conf['Preprocess'])
-    if 'SSD' in conf['arch']:
+    if conf['arch'] in ['SSD', 'Face']:
         img_data, res['im_shape'] = img_data
         img_data = [img_data]
diff --git a/tools/export_model.py b/tools/export_model.py
index c6b632bd59dd512fe1a423806a47108a3fc8975a..cba13896fda4a0daacba414e216b6a16cf7011e8 100644
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -47,6 +47,12 @@ def parse_reader(reader_cfg, metric, arch):
         from ppdet.utils.coco_eval import get_category_info
     if metric == "VOC":
         from ppdet.utils.voc_eval import get_category_info
+    if metric == "WIDERFACE":
+        from ppdet.utils.widerface_eval_utils import get_category_info
+    if metric not in ["COCO", "VOC", "WIDERFACE"]:
+        raise ValueError(
+            "metric only supports COCO, VOC, WIDERFACE, but received {}".format(
+                metric))
     clsid2catid, catid2name = get_category_info(anno_file, with_background,
                                                 use_default_label)
     label_list = [str(cat) for cat in catid2name.values()]
@@ -90,7 +96,13 @@ def dump_infer_config(config):
         'draw_threshold': 0.5,
         'metric': config['metric']
     })
-    trt_min_subgraph = {'YOLO': 3, 'SSD': 40, 'RCNN': 40, 'RetinaNet': 40}
+    trt_min_subgraph = {
+        'YOLO': 3,
+        'SSD': 3,
+        'RCNN': 40,
+        'RetinaNet': 40,
+        'Face': 3,
+    }
     infer_arch = config['architecture']
     for arch, min_subgraph_size in trt_min_subgraph.items():