From 79616f79d8badf9a087a9bc56c42e2ddda15a891 Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Thu, 26 Mar 2020 23:08:10 +0800 Subject: [PATCH] Refine inference on python (#259) * export config file in export_model.py * add custom label_list & dump_box * add infer_cfg.yml & doc * export label_list to inference config * update PaddleTensor to ZeroCopyTensor --- demo/infer_cfg.yml | 48 ++ .../inference/EXPORT_MODEL.md | 5 +- .../advanced_tutorials/inference/INFERENCE.md | 3 +- ppdet/core/config/yaml_helpers.py | 9 + tools/cpp_demo.yml | 29 -- tools/cpp_infer.py | 411 +++++++++++++++--- tools/export_model.py | 84 +++- 7 files changed, 498 insertions(+), 91 deletions(-) create mode 100644 demo/infer_cfg.yml delete mode 100644 tools/cpp_demo.yml diff --git a/demo/infer_cfg.yml b/demo/infer_cfg.yml new file mode 100644 index 000000000..99f1d63fa --- /dev/null +++ b/demo/infer_cfg.yml @@ -0,0 +1,48 @@ +draw_threshold: 0.5 +use_python_inference: false +mode: fluid +metric: VOC +arch: YOLO +min_subgraph_size: 3 +with_background: false +Preprocess: +- interp: 2 + max_size: 0 + target_size: 608 + type: Resize + use_cv2: true +- is_channel_first: false + is_scale: true + mean: + - 0.485 + - 0.456 + - 0.406 + std: + - 0.229 + - 0.224 + - 0.225 + type: Normalize +- channel_first: true + to_bgr: false + type: Permute +label_list: +- aeroplane +- bicycle +- bird +- boat +- bottle +- bus +- car +- cat +- chair +- cow +- diningtable +- dog +- horse +- motorbike +- person +- pottedplant +- sheep +- sofa +- train +- tvmonitor diff --git a/docs/advanced_tutorials/inference/EXPORT_MODEL.md b/docs/advanced_tutorials/inference/EXPORT_MODEL.md index 8cd73a4bf..47c074e7f 100644 --- a/docs/advanced_tutorials/inference/EXPORT_MODEL.md +++ b/docs/advanced_tutorials/inference/EXPORT_MODEL.md @@ -1,7 +1,6 @@ # 模型导出 -训练得到一个满足要求的模型后,如果想要将该模型接入到C++预测库或者Serving服务,需要通过`tools/export_model.py`导出该模型。 - +训练得到一个满足要求的模型后,如果想要将该模型接入到C++预测库或者Serving服务,需要通过`tools/export_model.py`导出该模型。同时,会导出预测时使用的配置文件,路径与模型保存路径相同, 配置文件名为`infer_cfg.yml`。 **说明:** - 导出模型输入为网络输入图像,即原始图片经过预处理后的图像,具体预处理方式可参考配置文件中TestReader部分。各类检测模型的输入格式分别为: @@ -57,6 +56,6 @@ python tools/export_model.py -c configs/yolov3_darknet.yml \ # 导出SSD模型,输入是3x300x300 python tools/export_model.py -c configs/ssd/ssd_mobilenet_v1_voc.yml \ --output_dir=./inference_model \ - -o weights= https://paddlemodels.bj.bcebos.com/object_detection/ssd_mobilenet_v1_voc.tar \ + -o weights=https://paddlemodels.bj.bcebos.com/object_detection/ssd_mobilenet_v1_voc.tar \ TestReader.inputs_def.image_shape=[3,300,300] ``` diff --git a/docs/advanced_tutorials/inference/INFERENCE.md b/docs/advanced_tutorials/inference/INFERENCE.md index c97db89d0..36505f988 100644 --- a/docs/advanced_tutorials/inference/INFERENCE.md +++ b/docs/advanced_tutorials/inference/INFERENCE.md @@ -22,9 +22,10 @@ python tools/cpp_infer.py --model_path=inference_model/faster_rcnn_r50_1x/ --con - config_path: 参数配置、数据预处理配置文件,注意不是训练时的配置文件 - infer_img: 待预测图片 - visualize: 是否保存可视化结果,默认保存路径为```output/``` +- dump_result: 是否保存预测结果,保存格式为json文件,默认保存路径为```output/``` -更多参数可在```tools/cpp_demo.yml```中查看,主要参数: +更多参数可在```demo/infer_cfg.yml```中查看,主要参数: - use_python_inference: diff --git a/ppdet/core/config/yaml_helpers.py b/ppdet/core/config/yaml_helpers.py index 8a7738b47..1545b6be7 100644 --- a/ppdet/core/config/yaml_helpers.py +++ b/ppdet/core/config/yaml_helpers.py @@ -21,6 +21,15 @@ from .schema import SharedConfig __all__ = ['serializable', 'Callable'] +def represent_dictionary_order(self, dict_data): + return self.represent_mapping('tag:yaml.org,2002:map', dict_data.items()) + + +def setup_orderdict(): + from collections import OrderedDict + yaml.add_representer(OrderedDict, represent_dictionary_order) + + def _make_python_constructor(cls): def python_constructor(loader, node): if isinstance(node, yaml.SequenceNode): diff --git a/tools/cpp_demo.yml b/tools/cpp_demo.yml deleted file mode 100644 index 7fac69840..000000000 --- a/tools/cpp_demo.yml +++ /dev/null @@ -1,29 +0,0 @@ -# demo for cpp_infer.py - -use_python_inference: true # whether to use python inference -mode: fluid # trt_fp32, trt_fp16, trt_int8, fluid -arch: RCNN # YOLO, SSD, RCNN, RetinaNet -min_subgraph_size: 40 # need 3 for YOLO arch - -# visualize the predicted image -metric: COCO # COCO, VOC -draw_threshold: 0.5 - -Preprocess: -- type: Resize - target_size: 640 - max_size: 640 -- type: Normalize - mean: - - 0.485 - - 0.456 - - 0.406 - std: - - 0.229 - - 0.224 - - 0.225 - is_scale: True -- type: Permute - to_bgr: False -- type: PadStride - stride: 0 # set 32 on FPN and 128 on RetinaNet diff --git a/tools/cpp_infer.py b/tools/cpp_infer.py index c7232897f..2dd2cce44 100644 --- a/tools/cpp_infer.py +++ b/tools/cpp_infer.py @@ -2,15 +2,11 @@ import os import time import numpy as np -from PIL import Image +from PIL import Image, ImageDraw import paddle.fluid as fluid import argparse -from ppdet.utils.visualizer import visualize_results, draw_bbox -from ppdet.utils.eval_utils import eval_results -import ppdet.utils.voc_eval as voc_eval -import ppdet.utils.coco_eval as coco_eval import cv2 import yaml import copy @@ -20,8 +16,6 @@ FORMAT = '%(asctime)s-%(levelname)s: %(message)s' logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) -eval_clses = {'COCO': coco_eval, 'VOC': voc_eval} - precision_map = { 'trt_int8': fluid.core.AnalysisConfig.Precision.Int8, 'trt_fp32': fluid.core.AnalysisConfig.Precision.Float32, @@ -34,6 +28,8 @@ def create_config(model_path, mode='fluid', batch_size=1, min_subgraph_size=3): params_file = os.path.join(model_path, '__params__') config = fluid.core.AnalysisConfig(model_file, params_file) config.enable_use_gpu(100, 0) + config.switch_use_feed_fetch_ops(False) + config.switch_specify_input_names(True) logger.info('min_subgraph_size = %d.' % (min_subgraph_size)) if mode in precision_map.keys(): @@ -60,6 +56,8 @@ def offset_to_lengths(lod): def DecodeImage(im_path): + assert os.path.exists(im_path), "Image path {} can not be found".format( + im_path) with open(im_path, 'rb') as f: im = f.read() data = np.frombuffer(im, dtype='uint8') @@ -102,17 +100,21 @@ def get_extra_info(im, arch, shape, scale): class Resize(object): - def __init__(self, target_size, max_size=0, interp=cv2.INTER_LINEAR): + def __init__(self, + target_size, + max_size=0, + interp=cv2.INTER_LINEAR, + use_cv2=True): super(Resize, self).__init__() self.target_size = target_size self.max_size = max_size self.interp = interp + self.use_cv2 = use_cv2 - def __call__(self, im, arch): + def __call__(self, im, use_trt=False): origin_shape = im.shape[:2] im_c = im.shape[2] - scale_set = {'RCNN', 'RetinaNet'} - if self.max_size != 0 and arch in scale_set: + if self.max_size != 0: im_size_min = np.min(origin_shape[0:2]) im_size_max = np.max(origin_shape[0:2]) im_scale = float(self.target_size) / float(im_size_min) @@ -125,15 +127,30 @@ class Resize(object): else: im_scale_x = float(self.target_size) / float(origin_shape[1]) im_scale_y = float(self.target_size) / float(origin_shape[0]) - im = cv2.resize( - im, - None, - None, - fx=im_scale_x, - fy=im_scale_y, - interpolation=self.interp) + resize_w = self.target_size + resize_h = self.target_size + if self.use_cv2: + im = cv2.resize( + im, + None, + None, + fx=im_scale_x, + fy=im_scale_y, + interpolation=self.interp) + else: + if self.max_size != 0: + raise TypeError( + 'If you set max_size to cap the maximum size of image,' + 'please set use_cv2 to True to resize the image.') + im = im.astype('uint8') + im = Image.fromarray(im) + im = im.resize((int(resize_w), int(resize_h)), self.interp) + im = np.array(im) # padding im - if self.max_size != 0 and arch in scale_set: + if self.max_size != 0 and use_trt: + logger.warning('Due to the limitation of tensorRT, padding the ' + 'image shape to {} * {}'.format(self.max_size, + self.max_size)) padding_im = np.zeros( (self.max_size, self.max_size, im_c), dtype=np.float32) im_h, im_w = im.shape[:2] @@ -143,27 +160,36 @@ class Resize(object): class Normalize(object): - def __init__(self, mean, std, is_scale=True): + def __init__(self, mean, std, is_scale=True, is_channel_first=False): super(Normalize, self).__init__() self.mean = mean self.std = std self.is_scale = is_scale + self.is_channel_first = is_channel_first def __call__(self, im): im = im.astype(np.float32, copy=False) + if self.is_channel_first: + mean = np.array(self.mean)[:, np.newaxis, np.newaxis] + std = np.array(self.std)[:, np.newaxis, np.newaxis] + else: + mean = np.array(self.mean)[np.newaxis, np.newaxis, :] + std = np.array(self.std)[np.newaxis, np.newaxis, :] if self.is_scale: im = im / 255.0 - im -= self.mean - im /= self.std + im -= mean + im /= std return im class Permute(object): - def __init__(self, to_bgr=False): + def __init__(self, to_bgr=False, channel_first=True): self.to_bgr = to_bgr + self.channel_first = channel_first def __call__(self, im): - im = im.transpose((2, 0, 1)).copy() + if self.channel_first: + im = im.transpose((2, 0, 1)).copy() if self.to_bgr: im = im[[2, 1, 0], :, :] return im @@ -171,8 +197,9 @@ class Permute(object): class PadStride(object): def __init__(self, stride=0): - assert stride >= 0, "Unsupported stride: {}, the stride in PadStride must be greater or equal to 0".format( - stride) + assert stride >= 0, "Unsupported stride: {}," + " the stride in PadStride must be greater " + "or equal to 0".format(stride) self.coarsest_stride = stride def __call__(self, im): @@ -187,7 +214,7 @@ class PadStride(object): return padding_im -def Preprocess(img_path, arch, config): +def Preprocess(img_path, arch, config, use_trt): img = DecodeImage(img_path) orig_shape = img.shape scale = 1. @@ -197,7 +224,7 @@ def Preprocess(img_path, arch, config): obj = data_aug_conf.pop('type') preprocess = eval(obj)(**data_aug_conf) if obj == 'Resize': - img, scale = preprocess(img, arch) + img, scale = preprocess(img, use_trt) else: img = preprocess(img) @@ -208,6 +235,268 @@ def Preprocess(img_path, arch, config): return data +def get_category_info(with_background, label_list): + if label_list[0] != 'background' and with_background: + label_list.insert(0, 'background') + if label_list[0] == 'background' and not with_background: + label_list = label_list[1:] + clsid2catid = {i: i for i in range(len(label_list))} + catid2name = {i: name for i, name in enumerate(label_list)} + return clsid2catid, catid2name + + +def bbox2out(results, clsid2catid, is_bbox_normalized=False): + """ + Args: + results: request a dict, should include: `bbox`, `im_id`, + if is_bbox_normalized=True, also need `im_shape`. + clsid2catid: class id to category id map of COCO2017 dataset. + is_bbox_normalized: whether or not bbox is normalized. + """ + xywh_res = [] + for t in results: + bboxes = t['bbox'][0] + lengths = t['bbox'][1][0] + if bboxes.shape == (1, 1) or bboxes is None: + continue + + k = 0 + for i in range(len(lengths)): + num = lengths[i] + for j in range(num): + dt = bboxes[k] + clsid, score, xmin, ymin, xmax, ymax = dt.tolist() + catid = (clsid2catid[int(clsid)]) + + if is_bbox_normalized: + xmin, ymin, xmax, ymax = \ + clip_bbox([xmin, ymin, xmax, ymax]) + w = xmax - xmin + h = ymax - ymin + im_shape = t['im_shape'][0][i].tolist() + im_height, im_width = int(im_shape[0]), int(im_shape[1]) + xmin *= im_width + ymin *= im_height + w *= im_width + h *= im_height + else: + w = xmax - xmin + 1 + h = ymax - ymin + 1 + + bbox = [xmin, ymin, w, h] + coco_res = {'category_id': catid, 'bbox': bbox, 'score': score} + xywh_res.append(coco_res) + k += 1 + return xywh_res + + +def expand_boxes(boxes, scale): + """ + Expand an array of boxes by a given scale. + """ + w_half = (boxes[:, 2] - boxes[:, 0]) * .5 + h_half = (boxes[:, 3] - boxes[:, 1]) * .5 + x_c = (boxes[:, 2] + boxes[:, 0]) * .5 + y_c = (boxes[:, 3] + boxes[:, 1]) * .5 + + w_half *= scale + h_half *= scale + + boxes_exp = np.zeros(boxes.shape) + boxes_exp[:, 0] = x_c - w_half + boxes_exp[:, 2] = x_c + w_half + boxes_exp[:, 1] = y_c - h_half + boxes_exp[:, 3] = y_c + h_half + + return boxes_exp + + +def mask2out(results, clsid2catid, resolution, thresh_binarize=0.5): + import pycocotools.mask as mask_util + scale = (resolution + 2.0) / resolution + + segm_res = [] + + for t in results: + bboxes = t['bbox'][0] + lengths = t['bbox'][1][0] + if bboxes.shape == (1, 1) or bboxes is None: + continue + if len(bboxes.tolist()) == 0: + continue + masks = t['mask'][0] + + s = 0 + # for each sample + for i in range(len(lengths)): + num = lengths[i] + im_shape = t['im_shape'][i] + + bbox = bboxes[s:s + num][:, 2:] + clsid_scores = bboxes[s:s + num][:, 0:2] + mask = masks[s:s + num] + s += num + + im_h = int(im_shape[0]) + im_w = int(im_shape[1]) + + expand_bbox = expand_boxes(bbox, scale) + expand_bbox = expand_bbox.astype(np.int32) + + padded_mask = np.zeros( + (resolution + 2, resolution + 2), dtype=np.float32) + + for j in range(num): + xmin, ymin, xmax, ymax = expand_bbox[j].tolist() + clsid, score = clsid_scores[j].tolist() + clsid = int(clsid) + padded_mask[1:-1, 1:-1] = mask[j, clsid, :, :] + + catid = clsid2catid[clsid] + + w = xmax - xmin + 1 + h = ymax - ymin + 1 + w = np.maximum(w, 1) + h = np.maximum(h, 1) + + resized_mask = cv2.resize(padded_mask, (w, h)) + resized_mask = np.array( + resized_mask > thresh_binarize, dtype=np.uint8) + im_mask = np.zeros((im_h, im_w), dtype=np.uint8) + + x0 = min(max(xmin, 0), im_w) + x1 = min(max(xmax + 1, 0), im_w) + y0 = min(max(ymin, 0), im_h) + y1 = min(max(ymax + 1, 0), im_h) + + im_mask[y0:y1, x0:x1] = resized_mask[(y0 - ymin):(y1 - ymin), ( + x0 - xmin):(x1 - xmin)] + segm = mask_util.encode( + np.array( + im_mask[:, :, np.newaxis], order='F'))[0] + catid = clsid2catid[clsid] + segm['counts'] = segm['counts'].decode('utf8') + coco_res = { + 'category_id': catid, + 'segmentation': segm, + 'score': score + } + segm_res.append(coco_res) + return segm_res + + +def color_map(num_classes): + color_map = num_classes * [0, 0, 0] + for i in range(0, num_classes): + j = 0 + lab = i + while lab: + color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j)) + color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j)) + color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j)) + j += 1 + lab >>= 3 + color_map = np.array(color_map).reshape(-1, 3) + return color_map + + +def draw_bbox(image, catid2name, bboxes, threshold, color_list): + """ + draw bbox on image + """ + draw = ImageDraw.Draw(image) + + for dt in np.array(bboxes): + catid, bbox, score = dt['category_id'], dt['bbox'], dt['score'] + if score < threshold: + continue + + xmin, ymin, w, h = bbox + xmax = xmin + w + ymax = ymin + h + + color = tuple(color_list[catid]) + + # draw bbox + draw.line( + [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), + (xmin, ymin)], + width=2, + fill=color) + + # draw label + text = "{} {:.2f}".format(catid2name[catid], score) + tw, th = draw.textsize(text) + draw.rectangle( + [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color) + draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255)) + + return image + + +def draw_mask(image, masks, threshold, color_list, alpha=0.7): + """ + Draw mask on image + """ + mask_color_id = 0 + w_ratio = .4 + img_array = np.array(image).astype('float32') + for dt in np.array(masks): + segm, score = dt['segmentation'], dt['score'] + if score < threshold: + continue + import pycocotools.mask as mask_util + mask = mask_util.decode(segm) * 255 + color_mask = color_list[mask_color_id % len(color_list), 0:3] + mask_color_id += 1 + for c in range(3): + color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255 + idx = np.nonzero(mask) + img_array[idx[0], idx[1], :] *= 1.0 - alpha + img_array[idx[0], idx[1], :] += alpha * color_mask + return Image.fromarray(img_array.astype('uint8')) + + +def get_bbox_result(output, result, conf, clsid2catid): + is_bbox_normalized = True if 'SSD' in conf['arch'] else False + lengths = offset_to_lengths(output.lod()) + np_data = np.array(output) if conf[ + 'use_python_inference'] else output.copy_to_cpu() + result['bbox'] = (np_data, lengths) + result['im_id'] = np.array([[0]]) + + bbox_results = bbox2out([result], clsid2catid, is_bbox_normalized) + return bbox_results + + +def get_mask_result(output, result, conf, clsid2catid): + resolution = conf['mask_resolution'] + bbox_out, mask_out = output + lengths = offset_to_lengths(bbox_out.lod()) + bbox = np.array(bbox_out) if conf[ + 'use_python_inference'] else bbox_out.copy_to_cpu() + mask = np.array(mask_out) if conf[ + 'use_python_inference'] else mask_out.copy_to_cpu() + result['bbox'] = (bbox, lengths) + result['mask'] = (mask, lengths) + mask_results = mask2out([result], clsid2catid, conf['mask_resolution']) + return mask_results + + +def visualize(bbox_results, catid2name, num_classes, mask_results=None): + image = Image.open(FLAGS.infer_img).convert('RGB') + color_list = color_map(num_classes) + image = draw_bbox(image, catid2name, bbox_results, 0.5, color_list) + if mask_results is not None: + image = draw_mask(image, mask_results, 0.5, color_list) + image_path = os.path.split(FLAGS.infer_img)[-1] + if not os.path.exists(FLAGS.output_dir): + os.makedirs(FLAGS.output_dir) + out_path = os.path.join(FLAGS.output_dir, image_path) + image.save(out_path, quality=95) + logger.info('Save visualize result to {}'.format(out_path)) + + def infer(): model_path = FLAGS.model_path config_path = FLAGS.config_path @@ -219,7 +508,9 @@ def infer(): with open(config_path) as f: conf = yaml.safe_load(f) - img_data = Preprocess(FLAGS.infer_img, conf['arch'], conf['Preprocess']) + use_trt = not conf['use_python_inference'] and 'trt' in conf['mode'] + img_data = Preprocess(FLAGS.infer_img, conf['arch'], conf['Preprocess'], + use_trt) if 'SSD' in conf['arch']: img_data, res['im_shape'] = img_data img_data = [img_data] @@ -234,12 +525,15 @@ def infer(): params_filename='__params__') data_dict = {k: v for k, v in zip(feed_var_names, img_data)} else: - inputs = [fluid.core.PaddleTensor(d.copy()) for d in img_data] config = create_config( model_path, mode=conf['mode'], min_subgraph_size=conf['min_subgraph_size']) predict = fluid.core.create_paddle_predictor(config) + input_names = predict.get_input_names() + for ind, d in enumerate(img_data): + input_tensor = predict.get_input_tensor(input_names[ind]) + input_tensor.copy_from_cpu(d.copy()) logger.info('warmup...') for i in range(10): @@ -249,7 +543,7 @@ def infer(): fetch_list=fetch_targets, return_numpy=False) else: - outs = predict.run(inputs) + predict.zero_copy_run() cnt = 100 logger.info('run benchmark...') @@ -261,40 +555,40 @@ def infer(): fetch_list=fetch_targets, return_numpy=False) else: - outs = predict.run(inputs) + outs = [] + predict.zero_copy_run() + output_names = predict.get_output_names() + for o_name in output_names: + outs.append(predict.get_output_tensor(o_name)) t2 = time.time() ms = (t2 - t1) * 1000.0 / float(cnt) print("Inference: {} ms per batch image".format(ms)) - if FLAGS.visualize: - eval_cls = eval_clses[conf['metric']] - - with_background = conf['arch'] != 'YOLO' - clsid2catid, catid2name = eval_cls.get_category_info( - None, with_background, True) + clsid2catid, catid2name = get_category_info(conf['with_background'], + conf['label_list']) + bbox_result = get_bbox_result(outs[0], res, conf, clsid2catid) - is_bbox_normalized = True if 'SSD' in conf['arch'] else False + mask_result = None + if 'mask_resolution' in conf: + res['im_shape'] = img_data[-1] + mask_result = get_mask_result(outs, res, conf, clsid2catid) - out = outs[-1] - lod = out.lod() if conf['use_python_inference'] else out.lod - lengths = offset_to_lengths(lod) - np_data = np.array(out) if conf[ - 'use_python_inference'] else out.as_ndarray() - - res['bbox'] = (np_data, lengths) - res['im_id'] = np.array([[0]]) - - bbox_results = eval_cls.bbox2out([res], clsid2catid, is_bbox_normalized) + if FLAGS.visualize: + visualize(bbox_result, catid2name, len(conf['label_list']), mask_result) - image = Image.open(FLAGS.infer_img).convert('RGB') - image = draw_bbox(image, 0, catid2name, bbox_results, 0.5) - image_path = os.path.split(FLAGS.infer_img)[-1] - if not os.path.exists(FLAGS.output_dir): - os.makedirs(FLAGS.output_dir) - out_path = os.path.join(FLAGS.output_dir, image_path) - image.save(out_path, quality=95) + if FLAGS.dump_result: + import json + bbox_file = os.path.join(FLAGS.output_dir, 'bbox.json') + logger.info('dump bbox to {}'.format(bbox_file)) + with open(bbox_file, 'w') as f: + json.dump(bbox_result, f) + if mask_result is not None: + mask_file = os.path.join(FLAGS.output_dir, 'mask.json') + logger.info('dump mask to {}'.format(mask_file)) + with open(mask_file, 'w') as f: + json.dump(mask_result, f) if __name__ == '__main__': @@ -315,5 +609,10 @@ if __name__ == '__main__': type=str, default="output", help="Directory for storing the output visualization files.") + parser.add_argument( + "--dump_result", + action='store_true', + default=False, + help="Whether to dump result") FLAGS = parser.parse_args() infer() diff --git a/tools/export_model.py b/tools/export_model.py index 9e31aca53..e2e166f91 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -23,13 +23,90 @@ from paddle import fluid from ppdet.core.workspace import load_config, merge_config, create from ppdet.utils.cli import ArgsParser import ppdet.utils.checkpoint as checkpoint - +import yaml import logging +from collections import OrderedDict FORMAT = '%(asctime)s-%(levelname)s: %(message)s' logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) +def parse_reader(reader_cfg, metric, arch): + preprocess_list = [] + + image_shape = reader_cfg['inputs_def'].get('image_shape', [None]) + has_shape_def = not None in image_shape + scale_set = {'RCNN', 'RetinaNet'} + + dataset = reader_cfg['dataset'] + anno_file = dataset.get_anno() + with_background = dataset.with_background + use_default_label = dataset.use_default_label + + if metric == 'COCO': + from ppdet.utils.coco_eval import get_category_info + if metric == "VOC": + from ppdet.utils.voc_eval import get_category_info + clsid2catid, catid2name = get_category_info(anno_file, with_background, + use_default_label) + label_list = [str(cat) for cat in catid2name.values()] + + sample_transforms = reader_cfg['sample_transforms'] + for st in sample_transforms[1:]: + method = st.__class__.__name__ + p = {'type': method.replace('Image', '')} + params = st.__dict__ + params.pop('_id') + if p['type'] == 'Resize' and has_shape_def: + params['target_size'] = image_shape[1] + params['max_size'] = image_shape[2] if arch in scale_set else 0 + + p.update(params) + preprocess_list.append(p) + batch_transforms = reader_cfg.get('batch_transforms', None) + if batch_transforms: + methods = [bt.__class__.__name__ for bt in batch_transforms] + for bt in batch_transforms: + method = bt.__class__.__name__ + if method == 'PadBatch': + preprocess_list.append({'type': 'PadStride'}) + params = bt.__dict__ + preprocess_list[-1].update({'stride': params['pad_to_stride']}) + break + + return with_background, preprocess_list, label_list + + +def dump_infer_config(config): + cfg_name = os.path.basename(FLAGS.config).split('.')[0] + save_dir = os.path.join(FLAGS.output_dir, cfg_name) + from ppdet.core.config.yaml_helpers import setup_orderdict + setup_orderdict() + infer_cfg = OrderedDict({ + 'use_python_inference': False, + 'mode': 'fluid', + 'draw_threshold': 0.5, + 'metric': config['metric'] + }) + trt_min_subgraph = {'YOLO': 3, 'SSD': 40, 'RCNN': 40, 'RetinaNet': 40} + infer_arch = config['architecture'] + + for arch, min_subgraph_size in trt_min_subgraph.items(): + if arch in infer_arch: + infer_cfg['arch'] = arch + infer_cfg['min_subgraph_size'] = min_subgraph_size + break + + if 'Mask' in config['architecture']: + infer_cfg['mask_resolution'] = config['MaskHead']['resolution'] + infer_cfg['with_background'], infer_cfg['Preprocess'], infer_cfg[ + 'label_list'] = parse_reader(config['TestReader'], config['metric'], + infer_cfg['arch']) + yaml.dump(infer_cfg, open(os.path.join(save_dir, 'infer_cfg.yml'), 'w')) + logger.info("Export inference config file to {}".format( + os.path.join(save_dir, 'infer_cfg.yml'))) + + def prune_feed_vars(feeded_var_names, target_vars, prog): """ Filter out feed variables which are not in program, @@ -57,7 +134,8 @@ def save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog): cfg_name = os.path.basename(FLAGS.config).split('.')[0] save_dir = os.path.join(FLAGS.output_dir, cfg_name) feed_var_names = [var.name for var in feed_vars.values()] - target_vars = list(test_fetches.values()) + fetch_list = sorted(test_fetches.items(), key=lambda i: i[0]) + target_vars = [var[1] for var in fetch_list] feed_var_names = prune_feed_vars(feed_var_names, target_vars, infer_prog) logger.info("Export inference model to {}, input: {}, output: " "{}...".format(save_dir, feed_var_names, @@ -101,6 +179,7 @@ def main(): checkpoint.load_params(exe, infer_prog, cfg.weights) save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog) + dump_infer_config(cfg) if __name__ == '__main__': @@ -110,5 +189,6 @@ if __name__ == '__main__': type=str, default="output", help="Directory for storing the output model files.") + FLAGS = parser.parse_args() main() -- GitLab