diff --git a/deploy/cpp/src/object_detector.cc b/deploy/cpp/src/object_detector.cc
index c6522f4f113927ea7964320fa462f2b6ba21045e..2de87f5e27170d904eec4c408a71370b70f85bfb 100644
--- a/deploy/cpp/src/object_detector.cc
+++ b/deploy/cpp/src/object_detector.cc
@@ -182,7 +182,12 @@ void ObjectDetector::Predict(const cv::Mat& im,
   // Calculate output length
   int output_size = 1;
   for (int j = 0; j < output_shape.size(); ++j) {
-    output_size *= output_shape[j];
+    output_size *= output_shape[j];
+  }
+
+  if (output_size < 6) {
+    std::cerr << "[WARNING] No object detected." << std::endl;
+    return;
   }
   output_data_.resize(output_size);
   out_tensor->copy_to_cpu(output_data_.data());
diff --git a/deploy/python/README.md b/deploy/python/README.md
index 105f6285228a04afac4369f33a1fa25d27350bf9..9c810ae1de2c21db787ec0a9055b455f886b496c 100644
--- a/deploy/python/README.md
+++ b/deploy/python/README.md
@@ -48,6 +48,7 @@ python deploy/python/infer.py --model_dir=/path/to/models --image_file=/path/to/
 | --run_mode |No|When using GPU, defaults to fluid; options: (fluid/trt_fp32/trt_fp16)|
 | --threshold |No|Score threshold for predictions, defaults to 0.5|
 | --output_dir |No|Root directory for saving visualized results, defaults to output/|
+| --run_benchmark |No|Whether to run a benchmark; --image_file must also be specified|
 
 Notes:
diff --git a/deploy/python/infer.py b/deploy/python/infer.py
index 77d10bf4e41a5fc0c4fc7599e9913a7ede8ffbee..091fb724d0a4c82c4f0391c9c4d99b2f60ffcd8e 100644
--- a/deploy/python/infer.py
+++ b/deploy/python/infer.py
@@ -16,6 +16,8 @@ import os
 import argparse
 import time
 import yaml
+import ast
+from functools import reduce
 
 from PIL import Image
 import cv2
@@ -286,6 +288,7 @@ class Config():
         self.mask_resolution = None
         if 'mask_resolution' in yml_conf:
             self.mask_resolution = yml_conf['mask_resolution']
+        self.print_config()
 
     def check_model(self, yml_conf):
         """
@@ -299,6 +302,15 @@ class Config():
                 "Unsupported arch: {}, expect SSD, YOLO, RetinaNet, RCNN and Face".
                 format(yml_conf['arch']))
 
+    def print_config(self):
+        print('----------- Model Configuration -----------')
+        print('%s: %s' % ('Model Arch', self.arch))
+        print('%s: %s' % ('Use Paddle Executor', self.use_python_inference))
+        print('%s: ' % ('Transform Order'))
+        for op_info in self.preprocess_infos:
+            print('--%s: %s' % ('transform op', op_info['type']))
+        print('--------------------------------------------')
+
 
 def load_predictor(model_dir,
                    run_mode='fluid',
@@ -322,6 +334,7 @@ def load_predictor(model_dir,
         raise ValueError("TensorRT int8 mode is not supported now, "
                          "please use trt_fp32 or trt_fp16 instead.")
     precision_map = {
+        'trt_int8': fluid.core.AnalysisConfig.Precision.Int8,
         'trt_fp32': fluid.core.AnalysisConfig.Precision.Float32,
         'trt_fp16': fluid.core.AnalysisConfig.Precision.Half
     }
@@ -450,7 +463,7 @@ class Detector():
             results['masks'] = np_masks
         return results
 
-    def predict(self, image, threshold=0.5):
+    def predict(self, image, threshold=0.5, warmup=0, repeats=1):
         '''
         Args:
             image (str/np.ndarray): path of image/ np.ndarray read by cv2
@@ -464,13 +477,19 @@
         inputs, im_info = self.preprocess(image)
         np_boxes, np_masks = None, None
         if self.config.use_python_inference:
+            for i in range(warmup):
+                outs = self.executor.run(self.program,
+                                         feed=inputs,
+                                         fetch_list=self.fecth_targets,
+                                         return_numpy=False)
             t1 = time.time()
-            outs = self.executor.run(self.program,
-                                     feed=inputs,
-                                     fetch_list=self.fecth_targets,
-                                     return_numpy=False)
+            for i in range(repeats):
+                outs = self.executor.run(self.program,
+                                         feed=inputs,
+                                         fetch_list=self.fecth_targets,
+                                         return_numpy=False)
             t2 = time.time()
-            ms = (t2 - t1) * 1000.0
+            ms = (t2 - t1) * 1000.0 / repeats
             print("Inference: {} ms per batch image".format(ms))
 
             np_boxes = np.array(outs[0])
@@ -481,35 +500,55 @@
             for i in range(len(inputs)):
                 input_tensor = self.predictor.get_input_tensor(input_names[i])
                 input_tensor.copy_from_cpu(inputs[input_names[i]])
-            t1 = time.time()
-            self.predictor.zero_copy_run()
-            t2 = time.time()
-            output_names = self.predictor.get_output_names()
-            boxes_tensor = self.predictor.get_output_tensor(output_names[0])
-            np_boxes = boxes_tensor.copy_to_cpu()
-            if self.config.mask_resolution is not None:
-                masks_tensor = self.predictor.get_output_tensor(output_names[1])
-                np_masks = masks_tensor.copy_to_cpu()
+            for i in range(warmup):
+                self.predictor.zero_copy_run()
+                output_names = self.predictor.get_output_names()
+                boxes_tensor = self.predictor.get_output_tensor(output_names[0])
+                np_boxes = boxes_tensor.copy_to_cpu()
+                if self.config.mask_resolution is not None:
+                    masks_tensor = self.predictor.get_output_tensor(
+                        output_names[1])
+                    np_masks = masks_tensor.copy_to_cpu()
 
-            ms = (t2 - t1) * 1000.0
+            t1 = time.time()
+            for i in range(repeats):
+                self.predictor.zero_copy_run()
+                output_names = self.predictor.get_output_names()
+                boxes_tensor = self.predictor.get_output_tensor(output_names[0])
+                np_boxes = boxes_tensor.copy_to_cpu()
+                if self.config.mask_resolution is not None:
+                    masks_tensor = self.predictor.get_output_tensor(
+                        output_names[1])
+                    np_masks = masks_tensor.copy_to_cpu()
+            t2 = time.time()
+            ms = (t2 - t1) * 1000.0 / repeats
             print("Inference: {} ms per batch image".format(ms))
 
-        results = self.postprocess(
-            np_boxes, np_masks, im_info, threshold=threshold)
+        if reduce(lambda x, y: x * y, np_boxes.shape) < 6:
+            print('[WARNING] No object detected.')
+            results = {'boxes': np.array([])}
+        else:
+            results = self.postprocess(
+                np_boxes, np_masks, im_info, threshold=threshold)
+
         return results
 
 
 def predict_image():
     detector = Detector(
         FLAGS.model_dir, use_gpu=FLAGS.use_gpu, run_mode=FLAGS.run_mode)
-    results = detector.predict(FLAGS.image_file, FLAGS.threshold)
-    visualize(
-        FLAGS.image_file,
-        results,
-        detector.config.labels,
-        mask_resolution=detector.config.mask_resolution,
-        output_dir=FLAGS.output_dir)
+    if FLAGS.run_benchmark:
+        detector.predict(
+            FLAGS.image_file, FLAGS.threshold, warmup=100, repeats=100)
+    else:
+        results = detector.predict(FLAGS.image_file, FLAGS.threshold)
+        visualize(
+            FLAGS.image_file,
+            results,
+            detector.config.labels,
+            mask_resolution=detector.config.mask_resolution,
+            output_dir=FLAGS.output_dir)
 
 
 def predict_video():
@@ -543,6 +582,13 @@ def predict_video():
     writer.release()
 
 
+def print_arguments(args):
+    print('----------- Running Arguments -----------')
+    for arg, value in sorted(vars(args).items()):
+        print('%s: %s' % (arg, value))
+    print('------------------------------------------')
+
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description=__doc__)
     parser.add_argument(
@@ -562,7 +608,15 @@
         default='fluid',
         help="mode of running(fluid/trt_fp32/trt_fp16)")
     parser.add_argument(
-        "--use_gpu", default=False, help="Whether to predict with GPU.")
+        "--use_gpu",
+        type=ast.literal_eval,
+        default=False,
+        help="Whether to predict with GPU.")
+    parser.add_argument(
+        "--run_benchmark",
+        type=ast.literal_eval,
+        default=False,
+        help="Whether to predict an image_file repeatedly for benchmark")
     parser.add_argument(
         "--threshold", type=float, default=0.5, help="Threshold of score.")
     parser.add_argument(
@@ -572,6 +626,8 @@
         help="Directory of output visualization files.")
     FLAGS = parser.parse_args()
 
+    print_arguments(FLAGS)
+
     if FLAGS.image_file != '' and FLAGS.video_file != '':
         assert "Cannot predict image and video at the same time"
     if FLAGS.image_file != '':
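
Usage sketch (not part of the patch): with this change applied, benchmark mode can be exercised as below. The --model_dir and --image_file paths are hypothetical placeholders; per predict_image() above, --run_benchmark=True makes predict() run 100 warmup iterations plus 100 timed repeats and print the average latency per batch image instead of saving a visualization.

    # Hypothetical paths; substitute a real exported model and test image.
    python deploy/python/infer.py \
        --model_dir=./inference_model \
        --image_file=./demo.jpg \
        --use_gpu=True \
        --run_benchmark=True

Because --use_gpu and --run_benchmark are parsed with ast.literal_eval, their values must be Python literals such as True or False.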