未验证 提交 ace2e5e1 编写于 作者: C channings 提交者: GitHub

deploy add log & benchmark (#791)

* python deploy add log
* deploy add log & benchmark
* Update infer.py
上级 b4ea2699
......@@ -182,7 +182,12 @@ void ObjectDetector::Predict(const cv::Mat& im,
// Calculate output length
int output_size = 1;
for (int j = 0; j < output_shape.size(); ++j) {
output_size *= output_shape[j];
output_size *= output_shape[j];
}
if (output_size < 6) {
std::cerr << "[WARNING] No object detected." << std::endl;
return true;
}
output_data_.resize(output_size);
out_tensor->copy_to_cpu(output_data_.data());
......
......@@ -48,6 +48,7 @@ python deploy/python/infer.py --model_dir=/path/to/models --image_file=/path/to/
| --run_mode |No|使用GPU时,默认为fluid, 可选(fluid/trt_fp32/trt_fp16)|
| --threshold |No|预测得分的阈值,默认为0.5|
| --output_dir |No|可视化结果保存的根目录,默认为output/|
| --run_benchmark |No|是否运行benchmark,同时需指定--image_file|
说明:
......
......@@ -16,6 +16,8 @@ import os
import argparse
import time
import yaml
import ast
from functools import reduce
from PIL import Image
import cv2
......@@ -286,6 +288,7 @@ class Config():
self.mask_resolution = None
if 'mask_resolution' in yml_conf:
self.mask_resolution = yml_conf['mask_resolution']
self.print_config()
def check_model(self, yml_conf):
"""
......@@ -299,6 +302,15 @@ class Config():
"Unsupported arch: {}, expect SSD, YOLO, RetinaNet, RCNN and Face".
format(yml_conf['arch']))
def print_config(self):
    """Print a summary of the loaded model configuration to stdout.

    Emits the model architecture (``self.arch``), the value of
    ``self.use_python_inference``, and the ``type`` of every entry in
    ``self.preprocess_infos`` in the order they are stored.
    """
    print('----------- Model Configuration -----------')
    print('%s: %s' % ('Model Arch', self.arch))
    # The label used to be misspelled as "Padddle".
    print('%s: %s' % ('Use Paddle Executor', self.use_python_inference))
    print('%s: ' % ('Transform Order'))
    for op_info in self.preprocess_infos:
        print('--%s: %s' % ('transform op', op_info['type']))
    print('--------------------------------------------')
def load_predictor(model_dir,
run_mode='fluid',
......@@ -322,6 +334,7 @@ def load_predictor(model_dir,
raise ValueError("TensorRT int8 mode is not supported now, "
"please use trt_fp32 or trt_fp16 instead.")
precision_map = {
'trt_int8': fluid.core.AnalysisConfig.Precision.Int8,
'trt_fp32': fluid.core.AnalysisConfig.Precision.Float32,
'trt_fp16': fluid.core.AnalysisConfig.Precision.Half
}
......@@ -450,7 +463,7 @@ class Detector():
results['masks'] = np_masks
return results
def predict(self, image, threshold=0.5):
def predict(self, image, threshold=0.5, warmup=0, repeats=1):
'''
Args:
image (str/np.ndarray): path of image/ np.ndarray read by cv2
......@@ -464,13 +477,19 @@ class Detector():
inputs, im_info = self.preprocess(image)
np_boxes, np_masks = None, None
if self.config.use_python_inference:
for i in range(warmup):
outs = self.executor.run(self.program,
feed=inputs,
fetch_list=self.fecth_targets,
return_numpy=False)
t1 = time.time()
outs = self.executor.run(self.program,
feed=inputs,
fetch_list=self.fecth_targets,
return_numpy=False)
for i in range(repeats):
outs = self.executor.run(self.program,
feed=inputs,
fetch_list=self.fecth_targets,
return_numpy=False)
t2 = time.time()
ms = (t2 - t1) * 1000.0
ms = (t2 - t1) * 1000.0 / repeats
print("Inference: {} ms per batch image".format(ms))
np_boxes = np.array(outs[0])
......@@ -481,35 +500,55 @@ class Detector():
for i in range(len(inputs)):
input_tensor = self.predictor.get_input_tensor(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
t1 = time.time()
self.predictor.zero_copy_run()
t2 = time.time()
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_tensor(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu()
if self.config.mask_resolution is not None:
masks_tensor = self.predictor.get_output_tensor(output_names[1])
np_masks = masks_tensor.copy_to_cpu()
for i in range(warmup):
self.predictor.zero_copy_run()
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_tensor(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu()
if self.config.mask_resolution is not None:
masks_tensor = self.predictor.get_output_tensor(
output_names[1])
np_masks = masks_tensor.copy_to_cpu()
ms = (t2 - t1) * 1000.0
t1 = time.time()
for i in range(repeats):
self.predictor.zero_copy_run()
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_tensor(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu()
if self.config.mask_resolution is not None:
masks_tensor = self.predictor.get_output_tensor(
output_names[1])
np_masks = masks_tensor.copy_to_cpu()
t2 = time.time()
ms = (t2 - t1) * 1000.0 / repeats
print("Inference: {} ms per batch image".format(ms))
results = self.postprocess(
np_boxes, np_masks, im_info, threshold=threshold)
if reduce(lambda x, y: x * y, np_boxes.shape) < 6:
print('[WARNNING] No object detected.')
results = {'boxes': np.array([])}
else:
results = self.postprocess(
np_boxes, np_masks, im_info, threshold=threshold)
return results
def predict_image():
    """Predict ``FLAGS.image_file`` once, or benchmark it repeatedly.

    A ``Detector`` is built from ``FLAGS.model_dir``, ``FLAGS.use_gpu`` and
    ``FLAGS.run_mode``. When ``FLAGS.run_benchmark`` is truthy the image is
    predicted with ``warmup=100`` and ``repeats=100`` and the results are
    discarded; otherwise a single prediction is made and handed to
    ``visualize`` together with ``FLAGS.output_dir``.
    """
    detector = Detector(
        FLAGS.model_dir, use_gpu=FLAGS.use_gpu, run_mode=FLAGS.run_mode)
    if FLAGS.run_benchmark:
        # Benchmark mode: no extra untimed prediction and no visualization.
        detector.predict(
            FLAGS.image_file, FLAGS.threshold, warmup=100, repeats=100)
        return
    # Normal mode: exactly one prediction followed by one visualization.
    results = detector.predict(FLAGS.image_file, FLAGS.threshold)
    visualize(
        FLAGS.image_file,
        results,
        detector.config.labels,
        mask_resolution=detector.config.mask_resolution,
        output_dir=FLAGS.output_dir)
def predict_video():
......@@ -543,6 +582,13 @@ def predict_video():
writer.release()
def print_arguments(args):
    """Print each attribute of ``args`` as ``name: value``, sorted by name.

    Args:
        args: An object whose ``vars()`` mapping holds the parsed
            command-line options (e.g. an ``argparse.Namespace``).
    """
    print('----------- Running Arguments -----------')
    options = vars(args)
    # Attribute names are unique, so ordering by name alone matches
    # ordering the (name, value) pairs.
    for name in sorted(options):
        print('%s: %s' % (name, options[name]))
    print('------------------------------------------')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
......@@ -562,7 +608,15 @@ if __name__ == '__main__':
default='fluid',
help="mode of running(fluid/trt_fp32/trt_fp16)")
parser.add_argument(
"--use_gpu", default=False, help="Whether to predict with GPU.")
"--use_gpu",
type=ast.literal_eval,
default=False,
help="Whether to predict with GPU.")
parser.add_argument(
"--run_benchmark",
type=ast.literal_eval,
default=False,
help="Whether to predict a image_file repeatedly for benchmark")
parser.add_argument(
"--threshold", type=float, default=0.5, help="Threshold of score.")
parser.add_argument(
......@@ -572,6 +626,8 @@ if __name__ == '__main__':
help="Directory of output visualization files.")
FLAGS = parser.parse_args()
print_arguments(FLAGS)
if FLAGS.image_file != '' and FLAGS.video_file != '':
assert "Cannot predict image and video at the same time"
if FLAGS.image_file != '':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册