未验证 提交 ace2e5e1 编写于 作者: C channings 提交者: GitHub

deploy add log & benchmark (#791)

* python deploy add log
* deploy add log & benchmark
* Update infer.py
上级 b4ea2699
...@@ -182,7 +182,12 @@ void ObjectDetector::Predict(const cv::Mat& im, ...@@ -182,7 +182,12 @@ void ObjectDetector::Predict(const cv::Mat& im,
// Calculate output length // Calculate output length
int output_size = 1; int output_size = 1;
for (int j = 0; j < output_shape.size(); ++j) { for (int j = 0; j < output_shape.size(); ++j) {
output_size *= output_shape[j]; output_size *= output_shape[j];
}
if (output_size < 6) {
std::cerr << "[WARNING] No object detected." << std::endl;
return true;
} }
output_data_.resize(output_size); output_data_.resize(output_size);
out_tensor->copy_to_cpu(output_data_.data()); out_tensor->copy_to_cpu(output_data_.data());
......
...@@ -48,6 +48,7 @@ python deploy/python/infer.py --model_dir=/path/to/models --image_file=/path/to/ ...@@ -48,6 +48,7 @@ python deploy/python/infer.py --model_dir=/path/to/models --image_file=/path/to/
| --run_mode |No|使用GPU时,默认为fluid, 可选(fluid/trt_fp32/trt_fp16)| | --run_mode |No|使用GPU时,默认为fluid, 可选(fluid/trt_fp32/trt_fp16)|
| --threshold |No|预测得分的阈值,默认为0.5| | --threshold |No|预测得分的阈值,默认为0.5|
| --output_dir |No|可视化结果保存的根目录,默认为output/| | --output_dir |No|可视化结果保存的根目录,默认为output/|
| --run_benchmark |No|是否运行benchmark,同时需指定--image_file|
说明: 说明:
......
...@@ -16,6 +16,8 @@ import os ...@@ -16,6 +16,8 @@ import os
import argparse import argparse
import time import time
import yaml import yaml
import ast
from functools import reduce
from PIL import Image from PIL import Image
import cv2 import cv2
...@@ -286,6 +288,7 @@ class Config(): ...@@ -286,6 +288,7 @@ class Config():
self.mask_resolution = None self.mask_resolution = None
if 'mask_resolution' in yml_conf: if 'mask_resolution' in yml_conf:
self.mask_resolution = yml_conf['mask_resolution'] self.mask_resolution = yml_conf['mask_resolution']
self.print_config()
def check_model(self, yml_conf): def check_model(self, yml_conf):
""" """
...@@ -299,6 +302,15 @@ class Config(): ...@@ -299,6 +302,15 @@ class Config():
"Unsupported arch: {}, expect SSD, YOLO, RetinaNet, RCNN and Face". "Unsupported arch: {}, expect SSD, YOLO, RetinaNet, RCNN and Face".
format(yml_conf['arch'])) format(yml_conf['arch']))
def print_config(self):
print('----------- Model Configuration -----------')
print('%s: %s' % ('Model Arch', self.arch))
print('%s: %s' % ('Use Padddle Executor', self.use_python_inference))
print('%s: ' % ('Transform Order'))
for op_info in self.preprocess_infos:
print('--%s: %s' % ('transform op', op_info['type']))
print('--------------------------------------------')
def load_predictor(model_dir, def load_predictor(model_dir,
run_mode='fluid', run_mode='fluid',
...@@ -322,6 +334,7 @@ def load_predictor(model_dir, ...@@ -322,6 +334,7 @@ def load_predictor(model_dir,
raise ValueError("TensorRT int8 mode is not supported now, " raise ValueError("TensorRT int8 mode is not supported now, "
"please use trt_fp32 or trt_fp16 instead.") "please use trt_fp32 or trt_fp16 instead.")
precision_map = { precision_map = {
'trt_int8': fluid.core.AnalysisConfig.Precision.Int8,
'trt_fp32': fluid.core.AnalysisConfig.Precision.Float32, 'trt_fp32': fluid.core.AnalysisConfig.Precision.Float32,
'trt_fp16': fluid.core.AnalysisConfig.Precision.Half 'trt_fp16': fluid.core.AnalysisConfig.Precision.Half
} }
...@@ -450,7 +463,7 @@ class Detector(): ...@@ -450,7 +463,7 @@ class Detector():
results['masks'] = np_masks results['masks'] = np_masks
return results return results
def predict(self, image, threshold=0.5): def predict(self, image, threshold=0.5, warmup=0, repeats=1):
''' '''
Args: Args:
image (str/np.ndarray): path of image/ np.ndarray read by cv2 image (str/np.ndarray): path of image/ np.ndarray read by cv2
...@@ -464,13 +477,19 @@ class Detector(): ...@@ -464,13 +477,19 @@ class Detector():
inputs, im_info = self.preprocess(image) inputs, im_info = self.preprocess(image)
np_boxes, np_masks = None, None np_boxes, np_masks = None, None
if self.config.use_python_inference: if self.config.use_python_inference:
for i in range(warmup):
outs = self.executor.run(self.program,
feed=inputs,
fetch_list=self.fecth_targets,
return_numpy=False)
t1 = time.time() t1 = time.time()
outs = self.executor.run(self.program, for i in range(repeats):
feed=inputs, outs = self.executor.run(self.program,
fetch_list=self.fecth_targets, feed=inputs,
return_numpy=False) fetch_list=self.fecth_targets,
return_numpy=False)
t2 = time.time() t2 = time.time()
ms = (t2 - t1) * 1000.0 ms = (t2 - t1) * 1000.0 / repeats
print("Inference: {} ms per batch image".format(ms)) print("Inference: {} ms per batch image".format(ms))
np_boxes = np.array(outs[0]) np_boxes = np.array(outs[0])
...@@ -481,35 +500,55 @@ class Detector(): ...@@ -481,35 +500,55 @@ class Detector():
for i in range(len(inputs)): for i in range(len(inputs)):
input_tensor = self.predictor.get_input_tensor(input_names[i]) input_tensor = self.predictor.get_input_tensor(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]]) input_tensor.copy_from_cpu(inputs[input_names[i]])
t1 = time.time()
self.predictor.zero_copy_run()
t2 = time.time()
output_names = self.predictor.get_output_names() for i in range(warmup):
boxes_tensor = self.predictor.get_output_tensor(output_names[0]) self.predictor.zero_copy_run()
np_boxes = boxes_tensor.copy_to_cpu() output_names = self.predictor.get_output_names()
if self.config.mask_resolution is not None: boxes_tensor = self.predictor.get_output_tensor(output_names[0])
masks_tensor = self.predictor.get_output_tensor(output_names[1]) np_boxes = boxes_tensor.copy_to_cpu()
np_masks = masks_tensor.copy_to_cpu() if self.config.mask_resolution is not None:
masks_tensor = self.predictor.get_output_tensor(
output_names[1])
np_masks = masks_tensor.copy_to_cpu()
ms = (t2 - t1) * 1000.0 t1 = time.time()
for i in range(repeats):
self.predictor.zero_copy_run()
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_tensor(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu()
if self.config.mask_resolution is not None:
masks_tensor = self.predictor.get_output_tensor(
output_names[1])
np_masks = masks_tensor.copy_to_cpu()
t2 = time.time()
ms = (t2 - t1) * 1000.0 / repeats
print("Inference: {} ms per batch image".format(ms)) print("Inference: {} ms per batch image".format(ms))
results = self.postprocess( if reduce(lambda x, y: x * y, np_boxes.shape) < 6:
np_boxes, np_masks, im_info, threshold=threshold) print('[WARNNING] No object detected.')
results = {'boxes': np.array([])}
else:
results = self.postprocess(
np_boxes, np_masks, im_info, threshold=threshold)
return results return results
def predict_image(): def predict_image():
detector = Detector( detector = Detector(
FLAGS.model_dir, use_gpu=FLAGS.use_gpu, run_mode=FLAGS.run_mode) FLAGS.model_dir, use_gpu=FLAGS.use_gpu, run_mode=FLAGS.run_mode)
results = detector.predict(FLAGS.image_file, FLAGS.threshold) if FLAGS.run_benchmark:
visualize( detector.predict(
FLAGS.image_file, FLAGS.image_file, FLAGS.threshold, warmup=100, repeats=100)
results, else:
detector.config.labels, results = detector.predict(FLAGS.image_file, FLAGS.threshold)
mask_resolution=detector.config.mask_resolution, visualize(
output_dir=FLAGS.output_dir) FLAGS.image_file,
results,
detector.config.labels,
mask_resolution=detector.config.mask_resolution,
output_dir=FLAGS.output_dir)
def predict_video(): def predict_video():
...@@ -543,6 +582,13 @@ def predict_video(): ...@@ -543,6 +582,13 @@ def predict_video():
writer.release() writer.release()
def print_arguments(args):
print('----------- Running Arguments -----------')
for arg, value in sorted(vars(args).items()):
print('%s: %s' % (arg, value))
print('------------------------------------------')
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument( parser.add_argument(
...@@ -562,7 +608,15 @@ if __name__ == '__main__': ...@@ -562,7 +608,15 @@ if __name__ == '__main__':
default='fluid', default='fluid',
help="mode of running(fluid/trt_fp32/trt_fp16)") help="mode of running(fluid/trt_fp32/trt_fp16)")
parser.add_argument( parser.add_argument(
"--use_gpu", default=False, help="Whether to predict with GPU.") "--use_gpu",
type=ast.literal_eval,
default=False,
help="Whether to predict with GPU.")
parser.add_argument(
"--run_benchmark",
type=ast.literal_eval,
default=False,
help="Whether to predict a image_file repeatedly for benchmark")
parser.add_argument( parser.add_argument(
"--threshold", type=float, default=0.5, help="Threshold of score.") "--threshold", type=float, default=0.5, help="Threshold of score.")
parser.add_argument( parser.add_argument(
...@@ -572,6 +626,8 @@ if __name__ == '__main__': ...@@ -572,6 +626,8 @@ if __name__ == '__main__':
help="Directory of output visualization files.") help="Directory of output visualization files.")
FLAGS = parser.parse_args() FLAGS = parser.parse_args()
print_arguments(FLAGS)
if FLAGS.image_file != '' and FLAGS.video_file != '': if FLAGS.image_file != '' and FLAGS.video_file != '':
assert "Cannot predict image and video at the same time" assert "Cannot predict image and video at the same time"
if FLAGS.image_file != '': if FLAGS.image_file != '':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册