Unverified Commit bb409b0b authored by channings, committed by GitHub

deploy add log & benchmark (#827)

* python deploy add log

* deploy add log & benchmark

* Update infer.py

* fix bug cannot load video

* fix cpp bug cannot load video

* fix bug of build
Parent 3f1e7855
......@@ -85,7 +85,7 @@ make
```shell
sh ./scripts/build.sh
```
**Note**: OpenCV depends on OpenBLAS. Ubuntu users should check whether `libopenblas.so` already exists on the system; if not, it can be installed with `apt-get install libopenblas-dev`.
### Step5: Prediction and Visualization
After a successful build, the prediction entry program is `build/main`, and its main command-line parameters are described below:
......@@ -95,9 +95,10 @@ make
| image_path | Path of the image file to predict |
| video_path | Path of the video file to predict |
| use_gpu | Whether to use GPU for prediction; supported values are 0 or 1 (default: 0) |
| gpu_id | GPU device id to use for inference (default: 0) |
| --run_mode | When using GPU, default is fluid; options: fluid/trt_fp32/trt_fp16 |
**Note** If both `video_path` and `image_path` are set, the program only predicts `video_path`.
**Note**: If both `video_path` and `image_path` are set, the program only predicts `video_path`.
`Example 1`:
......@@ -111,7 +112,7 @@ make
`Example 2`:
```shell
#Predict video `/root/projects/videos/test.avi` with `GPU`
./build/main --model_dir=/root/projects/models/yolov3_darknet --video_path=/root/projects/images/test.avi --use_gpu=1
#Predict video `/root/projects/videos/test.mp4` with `GPU`
./build/main --model_dir=/root/projects/models/yolov3_darknet --video_path=/root/projects/images/test.mp4 --use_gpu=1
```
For video files, the `visualized prediction results` are saved to `output.avi` in the current directory.
Prediction currently supports videos in `.mp4` format only; the `visualized prediction results` are saved to `output.mp4` in the current directory.
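For reference, the read-predict-write loop behind this behavior looks roughly like the minimal Python sketch below, using OpenCV's `cv2` bindings. The input path and the `mp4v` codec tag are illustrative assumptions, not the exact settings compiled into `build/main`:

```python
import cv2

# Open the input video (illustrative path).
capture = cv2.VideoCapture("test.mp4")
fps = capture.get(cv2.CAP_PROP_FPS)
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))

# 'mp4v' is an MPEG-4 codec tag OpenCV accepts for writing .mp4 files.
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
writer = cv2.VideoWriter("output.mp4", fourcc, fps, (width, height))

while True:
    ret, frame = capture.read()
    if not ret:
        break
    # ... run detection and draw the results on `frame` here ...
    writer.write(frame)

capture.release()
writer.release()
```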
......@@ -96,6 +96,7 @@ cd D:\projects\PaddleDetection\deploy\cpp\out\build\x64-Release
| image_path | Path of the image file to predict |
| video_path | Path of the video file to predict |
| use_gpu | Whether to use GPU for prediction; supported values are 0 or 1 (default: 0) |
| gpu_id | GPU device id to use for inference (default: 0) |
**Note**: If both `video_path` and `image_path` are set, the program only predicts `video_path`.
......@@ -111,8 +112,8 @@ cd D:\projects\PaddleDetection\deploy\cpp\out\build\x64-Release
`Example 2`:
```shell
#Test video `D:\\videos\\test.avi` with `GPU`
.\main --model_dir=D:\\models\\yolov3_darknet --video_path=D:\\videos\\test.jpeg --use_gpu=1
#Test video `D:\\videos\\test.mp4` with `GPU`
.\main --model_dir=D:\\models\\yolov3_darknet --video_path=D:\\videos\\test.mp4 --use_gpu=1
```
For video files, the `visualized prediction results` are saved to `output.avi` in the current directory.
Prediction currently supports videos in `.mp4` format only; the `visualized prediction results` are saved to `output.mp4` in the current directory.
......@@ -54,12 +54,14 @@ cv::Mat VisualizeResult(const cv::Mat& img,
class ObjectDetector {
public:
explicit ObjectDetector(const std::string& model_dir, bool use_gpu = false,
const std::string& run_mode = "fluid") {
explicit ObjectDetector(const std::string& model_dir,
bool use_gpu=false,
const std::string& run_mode="fluid",
const int gpu_id=0) {
config_.load_config(model_dir);
threshold_ = config_.draw_threshold_;
preprocessor_.Init(config_.preprocess_info_, config_.arch_);
LoadModel(model_dir, use_gpu, config_.min_subgraph_size_, 1, run_mode);
LoadModel(model_dir, use_gpu, config_.min_subgraph_size_, 1, run_mode, gpu_id);
}
// Load Paddle inference model
......@@ -68,7 +70,8 @@ class ObjectDetector {
bool use_gpu,
const int min_subgraph_size,
const int batch_size = 1,
const std::string& run_mode = "fluid");
const std::string& run_mode = "fluid",
const int gpu_id=0);
// Run predictor
void Predict(
......
# download pre-compiled opencv lib
OPENCV_URL=https://paddleseg.bj.bcebos.com/deploy/docker/opencv3gcc4.8.tar.bz2
if [ ! -d "./deps/opencv3gcc4.8" ]; then
OPENCV_URL=https://bj.bcebos.com/paddleseg/deploy/opencv3.4.6gcc4.8ffmpeg.tar.gz2
if [ ! -d "./deps/opencv3.4.6gcc4.8ffmpeg/" ]; then
mkdir -p deps
cd deps
wget -c ${OPENCV_URL}
tar xvfj opencv3gcc4.8.tar.bz2
rm -rf opencv3gcc4.8.tar.bz2
tar xvfj opencv3.4.6gcc4.8ffmpeg.tar.gz2
cd ..
fi
......@@ -18,7 +18,7 @@ CUDNN_LIB=/path/to/cudnn/lib/
# OPENCV path; no changes needed if the bundled pre-compiled version is used
sh $(pwd)/scripts/bootstrap.sh # download the pre-compiled opencv
OPENCV_DIR=$(pwd)/deps/opencv3gcc4.8/
OPENCV_DIR=$(pwd)/deps/opencv3.4.6gcc4.8ffmpeg/
# No changes needed below
rm -rf build
......
......@@ -25,13 +25,15 @@ DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_string(image_path, "", "Path of input image");
DEFINE_string(video_path, "", "Path of input video");
DEFINE_bool(use_gpu, false, "Inferring with GPU or CPU");
DEFINE_string(run_mode, "fluid", "mode of running(fluid/trt_fp32/trt_fp16)");
DEFINE_string(run_mode, "fluid", "Mode of running(fluid/trt_fp32/trt_fp16)");
DEFINE_int32(gpu_id, 0, "Device id of GPU to execute");
void PredictVideo(const std::string& video_path,
PaddleDetection::ObjectDetector* det) {
// Open video
cv::VideoCapture capture;
capture.open(video_path.c_str());
//capture.open(video_path.c_str());
capture.open(video_path);
if (!capture.isOpened()) {
printf("can not open video : %s\n", video_path.c_str());
return;
......@@ -44,9 +46,9 @@ void PredictVideo(const std::string& video_path,
// Create VideoWriter for output
cv::VideoWriter video_out;
std::string video_out_path = "output.avi";
std::string video_out_path = "output.mp4";
video_out.open(video_out_path.c_str(),
CV_FOURCC('M', 'J', 'P', 'G'),
0x00000021,
video_fps,
cv::Size(video_width, video_height),
true);
......@@ -60,6 +62,7 @@ void PredictVideo(const std::string& video_path,
auto colormap = PaddleDetection::GenerateColorMap(labels.size());
// Capture all frames and do inference
cv::Mat frame;
int frame_id = 0;
while (capture.read(frame)) {
if (frame.empty()) {
break;
......@@ -67,7 +70,18 @@ void PredictVideo(const std::string& video_path,
det->Predict(frame, &result);
cv::Mat out_im = PaddleDetection::VisualizeResult(
frame, result, labels, colormap);
for (const auto& item : result) {
printf("In frame id %d, we detect: class=%d confidence=%.2f rect=[%d %d %d %d]\n",
frame_id,
item.class_id,
item.confidence,
item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3]);
}
video_out.write(out_im);
frame_id += 1;
}
capture.release();
video_out.release();
......@@ -97,7 +111,7 @@ void PredictImage(const std::string& image_path,
std::vector<int> compression_params;
compression_params.push_back(CV_IMWRITE_JPEG_QUALITY);
compression_params.push_back(95);
cv::imwrite("output.jpeg", vis_img, compression_params);
cv::imwrite("output.jpg", vis_img, compression_params);
printf("Visualized output saved as output.jpeg\n");
}
......@@ -118,7 +132,7 @@ int main(int argc, char** argv) {
// Load model and create a object detector
PaddleDetection::ObjectDetector det(FLAGS_model_dir, FLAGS_use_gpu,
FLAGS_run_mode);
FLAGS_run_mode, FLAGS_gpu_id);
// Do inference on input video or image
if (!FLAGS_video_path.empty()) {
PredictVideo(FLAGS_video_path, &det);
......
......@@ -21,13 +21,14 @@ void ObjectDetector::LoadModel(const std::string& model_dir,
bool use_gpu,
const int min_subgraph_size,
const int batch_size,
const std::string& run_mode) {
const std::string& run_mode,
const int gpu_id) {
paddle::AnalysisConfig config;
std::string prog_file = model_dir + OS_PATH_SEP + "__model__";
std::string params_file = model_dir + OS_PATH_SEP + "__params__";
config.SetModel(prog_file, params_file);
if (use_gpu) {
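// EnableUseGpu(memory_pool_init_size_mb, device_id): reserve a 100 MB
// initial GPU memory pool on the selected device.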
config.EnableUseGpu(100, 0);
config.EnableUseGpu(100, gpu_id);
if (run_mode != "fluid") {
auto precision = paddle::AnalysisConfig::Precision::kFloat32;
if (run_mode == "trt_fp16") {
......@@ -182,7 +183,11 @@ void ObjectDetector::Predict(const cv::Mat& im,
// Calculate output length
int output_size = 1;
for (int j = 0; j < output_shape.size(); ++j) {
output_size *= output_shape[j];
output_size *= output_shape[j];
}
if (output_size < 6) {
std::cerr << "[WARNING] No object detected." << std::endl;
}
output_data_.resize(output_size);
out_tensor->copy_to_cpu(output_data_.data());
......
......@@ -48,6 +48,7 @@ python deploy/python/infer.py --model_dir=/path/to/models --image_file=/path/to/
| --run_mode |No| When using GPU, default is fluid; options: fluid/trt_fp32/trt_fp16 |
| --threshold |No| Score threshold for prediction; default is 0.5 |
| --output_dir |No| Root directory for saving visualization results; default is output/ |
| --run_benchmark |No| Whether to run a benchmark; --image_file must also be specified |
Notes:
......
......@@ -16,6 +16,8 @@ import os
import argparse
import time
import yaml
import ast
from functools import reduce
from PIL import Image
import cv2
......@@ -286,6 +288,7 @@ class Config():
self.mask_resolution = None
if 'mask_resolution' in yml_conf:
self.mask_resolution = yml_conf['mask_resolution']
self.print_config()
def check_model(self, yml_conf):
"""
......@@ -299,6 +302,15 @@ class Config():
"Unsupported arch: {}, expect SSD, YOLO, RetinaNet, RCNN and Face".
format(yml_conf['arch']))
def print_config(self):
print('----------- Model Configuration -----------')
print('%s: %s' % ('Model Arch', self.arch))
print('%s: %s' % ('Use Paddle Executor', self.use_python_inference))
print('%s: ' % ('Transform Order'))
for op_info in self.preprocess_infos:
print('--%s: %s' % ('transform op', op_info['type']))
print('--------------------------------------------')
def load_predictor(model_dir,
run_mode='fluid',
......@@ -322,6 +334,7 @@ def load_predictor(model_dir,
raise ValueError("TensorRT int8 mode is not supported now, "
"please use trt_fp32 or trt_fp16 instead.")
precision_map = {
'trt_int8': fluid.core.AnalysisConfig.Precision.Int8,
'trt_fp32': fluid.core.AnalysisConfig.Precision.Float32,
'trt_fp16': fluid.core.AnalysisConfig.Precision.Half
}
......@@ -450,7 +463,7 @@ class Detector():
results['masks'] = np_masks
return results
def predict(self, image, threshold=0.5):
def predict(self, image, threshold=0.5, warmup=0, repeats=1):
'''
Args:
image (str/np.ndarray): path of image/ np.ndarray read by cv2
......@@ -464,13 +477,19 @@ class Detector():
inputs, im_info = self.preprocess(image)
np_boxes, np_masks = None, None
if self.config.use_python_inference:
for i in range(warmup):
outs = self.executor.run(self.program,
feed=inputs,
fetch_list=self.fecth_targets,
return_numpy=False)
t1 = time.time()
outs = self.executor.run(self.program,
feed=inputs,
fetch_list=self.fecth_targets,
return_numpy=False)
for i in range(repeats):
outs = self.executor.run(self.program,
feed=inputs,
fetch_list=self.fecth_targets,
return_numpy=False)
t2 = time.time()
ms = (t2 - t1) * 1000.0
ms = (t2 - t1) * 1000.0 / repeats
print("Inference: {} ms per batch image".format(ms))
np_boxes = np.array(outs[0])
......@@ -481,35 +500,55 @@ class Detector():
for i in range(len(inputs)):
input_tensor = self.predictor.get_input_tensor(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
t1 = time.time()
self.predictor.zero_copy_run()
t2 = time.time()
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_tensor(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu()
if self.config.mask_resolution is not None:
masks_tensor = self.predictor.get_output_tensor(output_names[1])
np_masks = masks_tensor.copy_to_cpu()
for i in range(warmup):
self.predictor.zero_copy_run()
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_tensor(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu()
if self.config.mask_resolution is not None:
masks_tensor = self.predictor.get_output_tensor(
output_names[1])
np_masks = masks_tensor.copy_to_cpu()
ms = (t2 - t1) * 1000.0
t1 = time.time()
for i in range(repeats):
self.predictor.zero_copy_run()
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_tensor(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu()
if self.config.mask_resolution is not None:
masks_tensor = self.predictor.get_output_tensor(
output_names[1])
np_masks = masks_tensor.copy_to_cpu()
t2 = time.time()
ms = (t2 - t1) * 1000.0 / repeats
print("Inference: {} ms per batch image".format(ms))
results = self.postprocess(
np_boxes, np_masks, im_info, threshold=threshold)
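# A single detection occupies 6 values ([class_id, score, x1, y1, x2, y2]),
# so fewer than 6 output elements means nothing was detected.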
if reduce(lambda x, y: x * y, np_boxes.shape) < 6:
print('[WARNING] No object detected.')
results = {'boxes': np.array([])}
else:
results = self.postprocess(
np_boxes, np_masks, im_info, threshold=threshold)
return results
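The warmup/repeats pattern above is the standard way to time inference: the warmup iterations are discarded while caches and (for the trt_* run modes) TensorRT engines initialize, and latency is then averaged over the timed repeats. Below is a minimal standalone sketch of the same pattern, where `run_inference` is a hypothetical stand-in for a call such as `self.predictor.zero_copy_run()`:

```python
import time

def benchmark(run_inference, warmup=100, repeats=100):
    """Average per-iteration latency in milliseconds, excluding warmup."""
    for _ in range(warmup):   # untimed: let caches/engines initialize
        run_inference()
    t1 = time.time()
    for _ in range(repeats):  # timed region
        run_inference()
    t2 = time.time()
    ms = (t2 - t1) * 1000.0 / repeats
    print("Inference: {} ms per batch image".format(ms))
    return ms
```

With `--run_benchmark=True`, `predict_image` below calls `predict(FLAGS.image_file, FLAGS.threshold, warmup=100, repeats=100)`, which follows exactly this pattern.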
def predict_image():
detector = Detector(
FLAGS.model_dir, use_gpu=FLAGS.use_gpu, run_mode=FLAGS.run_mode)
results = detector.predict(FLAGS.image_file, FLAGS.threshold)
visualize(
FLAGS.image_file,
results,
detector.config.labels,
mask_resolution=detector.config.mask_resolution,
output_dir=FLAGS.output_dir)
if FLAGS.run_benchmark:
detector.predict(
FLAGS.image_file, FLAGS.threshold, warmup=100, repeats=100)
else:
results = detector.predict(FLAGS.image_file, FLAGS.threshold)
visualize(
FLAGS.image_file,
results,
detector.config.labels,
mask_resolution=detector.config.mask_resolution,
output_dir=FLAGS.output_dir)
def predict_video():
......@@ -543,6 +582,13 @@ def predict_video():
writer.release()
def print_arguments(args):
print('----------- Running Arguments -----------')
for arg, value in sorted(vars(args).items()):
print('%s: %s' % (arg, value))
print('------------------------------------------')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
......@@ -562,7 +608,15 @@ if __name__ == '__main__':
default='fluid',
help="mode of running(fluid/trt_fp32/trt_fp16)")
parser.add_argument(
"--use_gpu", default=False, help="Whether to predict with GPU.")
"--use_gpu",
type=ast.literal_eval,
default=False,
help="Whether to predict with GPU.")
parser.add_argument(
"--run_benchmark",
type=ast.literal_eval,
default=False,
help="Whether to predict a image_file repeatedly for benchmark")
parser.add_argument(
"--threshold", type=float, default=0.5, help="Threshold of score.")
parser.add_argument(
......@@ -572,6 +626,8 @@ if __name__ == '__main__':
help="Directory of output visualization files.")
FLAGS = parser.parse_args()
print_arguments(FLAGS)
if FLAGS.image_file != '' and FLAGS.video_file != '':
    assert False, "Cannot predict image and video at the same time"
if FLAGS.image_file != '':
......