未验证 提交 df3c8cd7 编写于 作者: C channings 提交者: GitHub

add runtime log& script support set run_mode (#630)

* add runtime log& script support set run_mode
* update code
上级 0d127b2b
## PaddleDetection Python 预测部署方案
本篇教程使用AnalysisPredictor对[导出模型](../../docs/advanced_tutorials/inference/EXPORT_MODEL.md)进行高性能预测。
在PaddlePaddle中预测引擎和训练引擎底层有着不同的优化方法, 下面列出了两种不同的预测方式。Executor同时支持训练和预测,AnalysisPredictor则专门针对推理进行了优化,是基于[C++预测库](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/inference_deployment/inference/native_infer.html)的Python接口,该引擎可以对模型进行多项图优化,减少不必要的内存拷贝。如果用户在部署已训练模型的过程中对性能有较高的要求,于是我们提供了独立于PaddleDetection的预测脚本,方便用户直接集成部署。
在PaddlePaddle中预测引擎和训练引擎底层有着不同的优化方法, 下面列出了两种不同的预测方式。Executor同时支持训练和预测,AnalysisPredictor则专门针对推理进行了优化,是基于[C++预测库](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/inference_deployment/inference/native_infer.html)的Python接口,该引擎可以对模型进行多项图优化,减少不必要的内存拷贝。如果用户在部署已训练模型的过程中对性能有较高的要求,我们提供了独立于PaddleDetection的预测脚本,方便用户直接集成部署。
- Executor:[Executor](https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/basic_concept/executor.html#executor)
- AnalysisPredictor:[AnalysisPredictor](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/inference_deployment/inference/python_infer_cn.html#analysispredictor)
......@@ -40,10 +40,13 @@ python deploy/python/infer.py --model_dir=/path/to/models --image_file=/path/to/
| --image_file | Yes |需要预测的图片 |
| --video_file | Yes |需要预测的视频 |
| --use_gpu |No|是否使用GPU,默认为False|
| --run_mode |No|推理运行模式,使用GPU时可选(fluid/trt_fp32/trt_fp16/trt_int8),默认为fluid|
| --threshold |No|预测得分的阈值,默认为0.5|
| --visualize |No|是否可视化结果,默认为False|
| --output_dir |No|可视化结果保存的根目录,默认为output/|
说明:
run_mode:fluid代表使用AnalysisPredictor,以float32精度进行推理;其他取值(trt_fp32/trt_fp16/trt_int8)代表使用AnalysisPredictor并启用TensorRT,以对应的精度进行推理。
## 3. 部署性能对比测试
对比AnalysisPredictor相对Executor的推理速度
......
......@@ -14,7 +14,9 @@
import os
import argparse
import time
import yaml
from PIL import Image
import cv2
import numpy as np
......@@ -279,7 +281,6 @@ class Config():
self.arch = yml_conf['arch']
self.preprocess_infos = yml_conf['Preprocess']
self.use_python_inference = yml_conf['use_python_inference']
self.run_mode = yml_conf['mode']
self.min_subgraph_size = yml_conf['min_subgraph_size']
self.labels = yml_conf['label_list']
if not yml_conf['with_background']:
......@@ -337,7 +338,7 @@ def load_predictor(model_dir,
if run_mode in precision_map.keys():
config.enable_tensorrt_engine(
workspace_size=1 << 30,
workspace_size=1 << 10,
max_batch_size=batch_size,
min_subgraph_size=min_subgraph_size,
precision_mode=precision_map[run_mode],
......@@ -391,7 +392,11 @@ class Detector():
use_gpu (bool): whether use gpu
"""
def __init__(self, model_dir, use_gpu=False, threshold=0.5):
def __init__(self,
model_dir,
use_gpu=False,
run_mode='fluid',
threshold=0.5):
self.config = Config(model_dir)
if self.config.use_python_inference:
self.executor, self.program, self.fecth_targets = load_executor(
......@@ -399,7 +404,7 @@ class Detector():
else:
self.predictor = load_predictor(
model_dir,
run_mode=self.config.run_mode,
run_mode=run_mode,
min_subgraph_size=self.config.min_subgraph_size,
use_gpu=use_gpu)
self.preprocess_ops = []
......@@ -459,19 +464,29 @@ class Detector():
inputs, im_info = self.preprocess(image)
np_boxes, np_masks = None, None
if self.config.use_python_inference:
t1 = time.time()
outs = self.executor.run(self.program,
feed=inputs,
fetch_list=self.fecth_targets,
return_numpy=False)
t2 = time.time()
ms = (t2 - t1) * 1000.0
print("Inference: {} ms per batch image".format(ms))
np_boxes = np.array(outs[0])
if self.config.mask_resolution is not None:
np_masks = np.arrya(outs[1])
np_masks = np.array(outs[1])
else:
input_names = self.predictor.get_input_names()
for i in range(len(inputs)):
input_tensor = self.predictor.get_input_tensor(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
t1 = time.time()
self.predictor.zero_copy_run()
t2 = time.time()
ms = (t2 - t1) * 1000.0
print("Inference: {} ms per batch image".format(ms))
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_tensor(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu()
......@@ -484,7 +499,8 @@ class Detector():
def predict_image():
detector = Detector(FLAGS.model_dir, use_gpu=FLAGS.use_gpu)
detector = Detector(
FLAGS.model_dir, use_gpu=FLAGS.use_gpu, run_mode=FLAGS.run_mode)
results = detector.predict(FLAGS.image_file, FLAGS.threshold)
visualize(
FLAGS.image_file,
......@@ -495,12 +511,13 @@ def predict_image():
def predict_video():
detector = Detector(FLAGS.model_dir, use_gpu=FLAGS.use_gpu)
detector = Detector(
FLAGS.model_dir, use_gpu=FLAGS.use_gpu, run_mode=FLAGS.run_mode)
capture = cv2.VideoCapture(FLAGS.video_file)
fps = 30
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
video_name = os.path.split(FLAGS.video_file)[-1]
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGES.output_dir)
......@@ -537,6 +554,11 @@ if __name__ == '__main__':
"--image_file", type=str, default='', help="Path of image file.")
parser.add_argument(
"--video_file", type=str, default='', help="Path of video file.")
parser.add_argument(
"--run_mode",
type=str,
default='fluid',
help="mode of running(fluid/trt_fp32/trt_fp16/trt_int8)")
parser.add_argument(
"--use_gpu", default=False, help="Whether to predict with GPU.")
parser.add_argument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册