diff --git a/docs/INFERENCE.md b/docs/INFERENCE.md
new file mode 100644
index 0000000000000000000000000000000000000000..154bb4441449e159b8e770abeec7267964886e2a
--- /dev/null
+++ b/docs/INFERENCE.md
@@ -0,0 +1,56 @@
+# Model Inference
+
+Save an inference_model as described in [Model Export](EXPORT_MODEL.md), then use the method below to run prediction with the saved model and benchmark the prediction speed of the different inference modes.
+
+## Usage
+
+```bash
+export CUDA_VISIBLE_DEVICES=0
+python tools/cpp_infer.py --model_path=output/yolov3_mobilenet_v1/ --config_path=tools/cpp_demo.yml --infer_img=demo/000000570688.jpg --visualize
+```
+
+
+Main arguments:
+
+1. model_path: directory where the inference_model is saved
+2. config_path: configuration file for data preprocessing
+3. infer_img: image to run prediction on
+4. visualize: whether to save the visualized result; the default save path is ```output/```.
+
+
+More parameters can be found in ```tools/cpp_demo.yml```
+
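+For example, to run a YOLO model with TensorRT FP16, edit the corresponding fields in ```tools/cpp_demo.yml``` (a sketch; the exact values depend on the exported model):
+
+```yaml
+mode: trt_fp16        # one of trt_fp32, trt_fp16, trt_int8, fluid
+arch: YOLO            # architecture of the exported model
+min_subgraph_size: 3  # the YOLO arch needs 3
+```
+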
+
+## Setting up the Paddle environment
+
+A TensorRT-enabled Paddle must be compiled from the develop branch, with the TensorRT path specified in the build command:
+
+```
+cmake .. -DWITH_MKL=ON \
+         -DWITH_GPU=ON \
+         -DWITH_TESTING=ON \
+         -DCMAKE_BUILD_TYPE=Release \
+         -DCUDA_ARCH_NAME=Auto \
+         -DCMAKE_INSTALL_PREFIX=`pwd`/output \
+         -DON_INFER=ON \
+         -DTENSORRT_ROOT=${PATH_TO_TensorRT}
+
+make -j20
+make install
+```
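+
+After compiling, install the resulting Python package into the environment used for prediction. A sketch, assuming the wheel is generated under ```python/dist/``` in the build directory (the exact file name depends on the Paddle version and build options):
+
+```
+pip install -U python/dist/paddlepaddle_gpu-*.whl
+```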
+
+
diff --git a/ppdet/utils/coco_eval.py b/ppdet/utils/coco_eval.py
index f67f356c460f8dd9e52dbff6864c68926f37ae5f..d4596d2fb6b1cc8c1b4374ad5adba84a00f26b35 100644
--- a/ppdet/utils/coco_eval.py
+++ b/ppdet/utils/coco_eval.py
@@ -156,14 +156,15 @@ def proposal2out(results, is_bbox_normalized=False):
     for t in results:
         bboxes = t['proposal'][0]
         lengths = t['proposal'][1][0]
-        im_ids = np.array(t['im_id'][0])
+        im_ids = np.array(t['im_id'][0]).flatten()
+        assert len(lengths) == im_ids.size
         if bboxes.shape == (1, 1) or bboxes is None:
             continue
 
         k = 0
         for i in range(len(lengths)):
             num = lengths[i]
-            im_id = int(im_ids[i][0])
+            im_id = int(im_ids[i])
             for j in range(num):
                 dt = bboxes[k]
                 xmin, ymin, xmax, ymax = dt.tolist()
@@ -201,14 +202,14 @@ def bbox2out(results, clsid2catid, is_bbox_normalized=False):
     for t in results:
         bboxes = t['bbox'][0]
         lengths = t['bbox'][1][0]
-        im_ids = np.array(t['im_id'][0])
+        im_ids = np.array(t['im_id'][0]).flatten()
         if bboxes.shape == (1, 1) or bboxes is None:
             continue
 
         k = 0
         for i in range(len(lengths)):
             num = lengths[i]
-            im_id = int(im_ids[i][0])
+            im_id = int(im_ids[i])
             for j in range(num):
                 dt = bboxes[k]
                 clsid, score, xmin, ymin, xmax, ymax = dt.tolist()
diff --git a/tools/cpp_demo.yml b/tools/cpp_demo.yml
new file mode 100644
index 0000000000000000000000000000000000000000..98f2496f0c69d50e62574c2756592b17ba4489f3
--- /dev/null
+++ b/tools/cpp_demo.yml
@@ -0,0 +1,27 @@
+# demo config for tools/cpp_infer.py
+
+mode: trt_fp32 # trt_fp32, trt_fp16, trt_int8, fluid
+arch: RCNN # YOLO, SSD, RCNN, RetinaNet
+min_subgraph_size: 20 # the YOLO arch needs 3
+use_python_inference: False # whether to use python inference
+
+# visualize the predicted image
+metric: COCO # COCO, VOC
+draw_threshold: 0.5
+
+Preprocess:
+- type: Resize
+  target_size: 640
+  max_size: 640
+- type: Normalize
+  mean:
+  - 0.485
+  - 0.456
+  - 0.406
+  std:
+  - 0.229
+  - 0.224
+  - 0.225
+  is_scale: True
+- type: Permute
+  to_bgr: False
diff --git a/tools/cpp_infer.py b/tools/cpp_infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d17d29611810ffdaf27dea839bfecf307b7aae0
--- /dev/null
+++ b/tools/cpp_infer.py
@@ -0,0 +1,295 @@
+import os
+import time
+
+import numpy as np
+from PIL import Image
+
+import paddle.fluid as fluid
+
+import argparse
+from ppdet.utils.visualizer import visualize_results, draw_bbox
+from ppdet.utils.eval_utils import eval_results
+import ppdet.utils.voc_eval as voc_eval
+import ppdet.utils.coco_eval as coco_eval
+import cv2
+import yaml
+
+import logging
+FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+logger = logging.getLogger(__name__)
+
+eval_clses = {'COCO': coco_eval, 'VOC': voc_eval}
+
+precision_map = {
+    'trt_int8': fluid.core.AnalysisConfig.Precision.Int8,
+    'trt_fp32': fluid.core.AnalysisConfig.Precision.Float32,
+    'trt_fp16': fluid.core.AnalysisConfig.Precision.Half
+}
+
+
+def create_config(model_path, mode='fluid', batch_size=1, min_subgraph_size=3):
+    model_file = os.path.join(model_path, '__model__')
+    params_file = os.path.join(model_path, '__params__')
+    config = fluid.core.AnalysisConfig(model_file, params_file)
+    config.enable_use_gpu(100, 0)
+    logger.info('min_subgraph_size = %d.' % (min_subgraph_size))
+
+    if mode in precision_map.keys():
+        config.enable_tensorrt_engine(
+            workspace_size=1 << 30,
+            max_batch_size=batch_size,
+            min_subgraph_size=min_subgraph_size,
+            precision_mode=precision_map[mode],
+            use_static=False,
+            use_calib_mode=mode == 'trt_int8')
+        logger.info('Run inference by {}.'.format(mode))
+    elif mode == 'fluid':
+        logger.info('Run inference by Fluid FP32.')
+    else:
+        logger.fatal(
+            'Wrong mode, only support trt_int8, trt_fp32, trt_fp16, fluid.')
+    return config
+
+
+def offset_to_lengths(lod):
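+    # For example (illustrative values): lod = [[0, 2, 5]] describes two
+    # images with 2 and 3 predicted boxes, so the result here is [[2, 3]].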
+    offset = lod[0]
+    lengths = [offset[i + 1] - offset[i] for i in range(len(offset) - 1)]
+    return [lengths]
+
+
+def DecodeImage(im_path):
+    with open(im_path, 'rb') as f:
+        im = f.read()
+    data = np.frombuffer(im, dtype='uint8')
+    im = cv2.imdecode(data, 1)  # BGR mode, but need RGB mode
+    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+    return im
+
+
+def get_extra_info(im, arch, shape, scale):
+    info = []
+    input_shape = []
+    im_shape = []
+    logger.info('The architecture is {}'.format(arch))
+    if 'YOLO' in arch:
+        im_size = np.array([shape[:2]]).astype('int32')
+        logger.info('Extra info: im_size')
+        info.append(im_size)
+    elif 'SSD' in arch:
+        pass
+    elif 'RetinaNet' in arch:
+        input_shape.extend(im.shape[2:])
+        im_info = np.array([input_shape + [scale]]).astype('float32')
+        logger.info('Extra info: im_info')
+        info.append(im_info)
+    elif 'RCNN' in arch:
+        input_shape.extend(im.shape[2:])
+        im_shape.extend(shape[:2])
+        im_info = np.array([input_shape + [scale]]).astype('float32')
+        im_shape = np.array([im_shape + [1.]]).astype('float32')
+        logger.info('Extra info: im_info, im_shape')
+        info.append(im_info)
+        info.append(im_shape)
+    else:
+        logger.error(
+            "Unsupported arch: {}, expect YOLO, SSD, RetinaNet and RCNN".format(
+                arch))
+    return info
+
+
+class Resize(object):
+    def __init__(self, target_size, max_size=0, interp=cv2.INTER_LINEAR):
+        super(Resize, self).__init__()
+        self.target_size = target_size
+        self.max_size = max_size
+        self.interp = interp
+
+    def __call__(self, im):
+        origin_shape = im.shape[:2]
+        im_c = im.shape[2]
+        if self.max_size != 0:
+            im_size_min = np.min(origin_shape[0:2])
+            im_size_max = np.max(origin_shape[0:2])
+            im_scale = float(self.target_size) / float(im_size_min)
+            if np.round(im_scale * im_size_max) > self.max_size:
+                im_scale = float(self.max_size) / float(im_size_max)
+            im_scale_x = im_scale
+            im_scale_y = im_scale
+            resize_w = int(im_scale_x * float(origin_shape[1]))
+            resize_h = int(im_scale_y * float(origin_shape[0]))
+        else:
+            im_scale_x = float(self.target_size) / float(origin_shape[1])
+            im_scale_y = float(self.target_size) / float(origin_shape[0])
+        im = cv2.resize(
+            im,
+            None,
+            None,
+            fx=im_scale_x,
+            fy=im_scale_y,
+            interpolation=self.interp)
+        # padding im
+        if self.max_size != 0:
+            padding_im = np.zeros(
+                (self.max_size, self.max_size, im_c), dtype=np.float32)
+            im_h, im_w = im.shape[:2]
+            padding_im[:im_h, :im_w, :] = im
+            im = padding_im
+        return im, im_scale_x
+
+
+class Normalize(object):
+    def __init__(self, mean, std, is_scale=True):
+        super(Normalize, self).__init__()
+        self.mean = mean
+        self.std = std
+        self.is_scale = is_scale
+
+    def __call__(self, im):
+        im = im.astype(np.float32, copy=False)
+        if self.is_scale:
+            im = im / 255.0
+        im -= self.mean
+        im /= self.std
+        return im
+
+
+class Permute(object):
+    def __init__(self, to_bgr=False):
+        self.to_bgr = to_bgr
+
+    def __call__(self, im):
+        im = im.transpose((2, 0, 1)).copy()
+        if self.to_bgr:
+            im = im[[2, 1, 0], :, :]
+        return im
+
+
+def Preprocess(img_path, arch, config):
+    img = DecodeImage(img_path)
+    orig_shape = img.shape
+    scale = 1.
+    data = []
+    for data_aug_conf in config:
+        obj = data_aug_conf.pop('type')
+        preprocess = eval(obj)(**data_aug_conf)
+        if obj == 'Resize':
+            img, scale = preprocess(img)
+        else:
+            img = preprocess(img)
+
+    img = img[np.newaxis, :]  # N, C, H, W
+    data.append(img)
+    extra_info = get_extra_info(img, arch, orig_shape, scale)
+    data += extra_info
+    return data
+
+
+def infer():
+    model_path = FLAGS.model_path
+    config_path = FLAGS.config_path
+    assert model_path is not None, \
+        "model_path should be specified, but got {}".format(model_path)
+    assert config_path is not None, \
+        "config_path should be specified, but got {}".format(config_path)
+    with open(config_path) as f:
+        conf = yaml.safe_load(f)
+
+    img_data = Preprocess(FLAGS.infer_img, conf['arch'], conf['Preprocess'])
+
+    if conf['use_python_inference']:
+        place = fluid.CUDAPlace(0)
+        exe = fluid.Executor(place)
+        infer_prog, feed_var_names, fetch_targets = fluid.io.load_inference_model(
+            dirname=model_path,
+            executor=exe,
+            model_filename='__model__',
+            params_filename='__params__')
+        data_dict = {k: v for k, v in zip(feed_var_names, img_data)}
+    else:
+        inputs = [fluid.core.PaddleTensor(d.copy()) for d in img_data]
+        config = create_config(
+            model_path,
+            mode=conf['mode'],
+            min_subgraph_size=conf['min_subgraph_size'])
+        predict = fluid.core.create_paddle_predictor(config)
+
+    logger.info('warmup...')
+    for i in range(10):
+        if conf['use_python_inference']:
+            outs = exe.run(infer_prog,
+                           feed=[data_dict],
+                           fetch_list=fetch_targets,
+                           return_numpy=False)
+        else:
+            outs = predict.run(inputs)
+
+    cnt = 100
+    logger.info('run benchmark...')
+    t1 = time.time()
+    for i in range(cnt):
+        if conf['use_python_inference']:
+            outs = exe.run(infer_prog,
+                           feed=[data_dict],
+                           fetch_list=fetch_targets,
+                           return_numpy=False)
+        else:
+            outs = predict.run(inputs)
+    t2 = time.time()
+
+    ms = (t2 - t1) * 1000.0 / float(cnt)
+
+    print("Inference: {} ms per batch image".format(ms))
+
+    if FLAGS.visualize:
+        eval_cls = eval_clses[conf['metric']]
+
+        with_background = conf['arch'] != 'YOLO'
+        clsid2catid, catid2name = eval_cls.get_category_info(
+            None, with_background, True)
+
+        is_bbox_normalized = True if 'SSD' in conf['arch'] else False
+
+        out = outs[-1]
+        res = {}
+        lod = out.lod() if conf['use_python_inference'] else out.lod
+        lengths = offset_to_lengths(lod)
+        np_data = np.array(out) if conf[
+            'use_python_inference'] else out.as_ndarray()
+
+        res['bbox'] = (np_data, lengths)
+        res['im_id'] = np.array([[0]])
+
+        bbox_results = eval_cls.bbox2out([res], clsid2catid, is_bbox_normalized)
+
+        image = Image.open(FLAGS.infer_img).convert('RGB')
+        image = draw_bbox(image, 0, catid2name, bbox_results, 0.5)
+        image_path = os.path.split(FLAGS.infer_img)[-1]
+        if not os.path.exists(FLAGS.output_dir):
+            os.makedirs(FLAGS.output_dir)
+        out_path = os.path.join(FLAGS.output_dir, image_path)
+        image.save(out_path, quality=95)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        "--model_path", type=str, default=None, help="model path.")
+    parser.add_argument(
+        "--config_path", type=str, default=None, help="preprocess config path.")
+    parser.add_argument(
+        "--infer_img", type=str, default=None, help="Image path")
+    parser.add_argument(
+        "--visualize",
+        action='store_true',
+        default=False,
+        help="Whether to visualize detection output")
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="output",
+        help="Directory for storing the output visualization files.")
+    FLAGS = parser.parse_args()
+    infer()
diff --git a/tools/export_model.py b/tools/export_model.py
index c7de582f28cbdb9bb5df612af4421e873ba8c66b..9e31aca5390ca7e76e12d90b91e57c57264e038b 100644
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -21,7 +21,6 @@ import os
 
 from paddle import fluid
 
 from ppdet.core.workspace import load_config, merge_config, create
-from ppdet.modeling.model_input import create_feed
 from ppdet.utils.cli import ArgsParser
 import ppdet.utils.checkpoint as checkpoint