Unverified · Commit c73e8070 · authored by wangxinxin08, committed by GitHub

modify cpp inference code (#1401)

Co-authored-by: qingqing01 <dangqingqing@baidu.com>
Parent: 6b924717
@@ -18,6 +18,7 @@
 #include <vector>
 #include <memory>
 #include <utility>
+#include <ctime>
 #include <opencv2/core/core.hpp>
 #include <opencv2/imgproc/imgproc.hpp>
@@ -74,9 +75,12 @@ class ObjectDetector {
       const int gpu_id=0);
   // Run predictor
-  void Predict(
-      const cv::Mat& img,
-      std::vector<ObjectResult>* result);
+  void Predict(const cv::Mat& im,
+      const double threshold = 0.5,
+      const int warmup = 0,
+      const int repeats = 1,
+      const bool run_benchmark = false,
+      std::vector<ObjectResult>* result = nullptr);
   // Get Model Label list
   const std::vector<std::string>& GetLabelList() const {
...
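For reference, a minimal sketch of how a caller might drive the extended Predict interface. File and model paths here are illustrative; the constructor arguments mirror the usage in main.cc below.

#include <vector>
#include <opencv2/opencv.hpp>
#include "object_detector.h"  // illustrative include path

int main() {
  // Constructor arguments as in main.cc: model dir, use_gpu, run_mode, gpu_id.
  PaddleDetection::ObjectDetector det("inference_model", false, "fluid", 0);
  cv::Mat im = cv::imread("demo.jpg", 1);
  std::vector<PaddleDetection::ObjectResult> result;
  // Normal prediction: default threshold, no warmup, a single run.
  det.Predict(im, 0.5, 0, 1, false, &result);
  // Benchmark mode: warmup plus repeated timed runs; postprocessing is skipped.
  det.Predict(im, 0.5, 100, 100, true, &result);
  return 0;
}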
@@ -25,14 +25,24 @@ DEFINE_string(model_dir, "", "Path of inference model");
 DEFINE_string(image_path, "", "Path of input image");
 DEFINE_string(video_path, "", "Path of input video");
 DEFINE_bool(use_gpu, false, "Infering with GPU or CPU");
+DEFINE_bool(use_camera, false, "Use camera or not");
 DEFINE_string(run_mode, "fluid", "Mode of running(fluid/trt_fp32/trt_fp16)");
 DEFINE_int32(gpu_id, 0, "Device id of GPU to execute");
+DEFINE_int32(camera_id, -1, "Device id of camera to predict");
+DEFINE_bool(run_benchmark, false, "Whether to predict a image_file repeatedly for benchmark");
+DEFINE_double(threshold, 0.5, "Threshold of score.");
+DEFINE_string(output_dir, "output", "Directory of output visualization files.");
 void PredictVideo(const std::string& video_path,
                   PaddleDetection::ObjectDetector* det) {
   // Open video
   cv::VideoCapture capture;
-  capture.open(video_path.c_str());
+  if (FLAGS_camera_id != -1){
+    capture.open(FLAGS_camera_id);
+  }else{
+    capture.open(video_path.c_str());
+  }
   if (!capture.isOpened()) {
     printf("can not open video : %s\n", video_path.c_str());
     return;
@@ -66,7 +76,7 @@ void PredictVideo(const std::string& video_path,
     if (frame.empty()) {
       break;
     }
-    det->Predict(frame, &result);
+    det->Predict(frame, 0.5, 0, 1, false, &result);
     cv::Mat out_im = PaddleDetection::VisualizeResult(
         frame, result, labels, colormap);
     for (const auto& item : result) {
@@ -87,31 +97,40 @@ void PredictVideo(const std::string& video_path,
 }

 void PredictImage(const std::string& image_path,
-                  PaddleDetection::ObjectDetector* det) {
+                  const double threshold,
+                  const bool run_benchmark,
+                  PaddleDetection::ObjectDetector* det,
+                  const std::string& output_dir = "output") {
   // Open input image as an opencv cv::Mat object
   cv::Mat im = cv::imread(image_path, 1);
   // Store all detected result
   std::vector<PaddleDetection::ObjectResult> result;
-  det->Predict(im, &result);
-  for (const auto& item : result) {
-    printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n",
-        item.class_id,
-        item.confidence,
-        item.rect[0],
-        item.rect[1],
-        item.rect[2],
-        item.rect[3]);
+  if (run_benchmark)
+  {
+    det->Predict(im, threshold, 100, 100, run_benchmark, &result);
+  }else
+  {
+    det->Predict(im, 0.5, 0, 1, run_benchmark, &result);
+    for (const auto& item : result) {
+      printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n",
+          item.class_id,
+          item.confidence,
+          item.rect[0],
+          item.rect[1],
+          item.rect[2],
+          item.rect[3]);
+    }
+    // Visualization result
+    auto labels = det->GetLabelList();
+    auto colormap = PaddleDetection::GenerateColorMap(labels.size());
+    cv::Mat vis_img = PaddleDetection::VisualizeResult(
+        im, result, labels, colormap);
+    std::vector<int> compression_params;
+    compression_params.push_back(CV_IMWRITE_JPEG_QUALITY);
+    compression_params.push_back(95);
+    cv::imwrite(output_dir + "/output.jpg", vis_img, compression_params);
+    printf("Visualized output saved as output.jpg\n");
   }
-  // Visualization result
-  auto labels = det->GetLabelList();
-  auto colormap = PaddleDetection::GenerateColorMap(labels.size());
-  cv::Mat vis_img = PaddleDetection::VisualizeResult(
-      im, result, labels, colormap);
-  std::vector<int> compression_params;
-  compression_params.push_back(CV_IMWRITE_JPEG_QUALITY);
-  compression_params.push_back(95);
-  cv::imwrite("output.jpg", vis_img, compression_params);
-  printf("Visualized output saved as output.jpeg\n");
 }
 int main(int argc, char** argv) {
@@ -133,10 +152,10 @@ int main(int argc, char** argv) {
   PaddleDetection::ObjectDetector det(FLAGS_model_dir, FLAGS_use_gpu,
                                       FLAGS_run_mode, FLAGS_gpu_id);
   // Do inference on input video or image
-  if (!FLAGS_video_path.empty()) {
+  if (!FLAGS_video_path.empty() || FLAGS_use_camera) {
     PredictVideo(FLAGS_video_path, &det);
   } else if (!FLAGS_image_path.empty()) {
-    PredictImage(FLAGS_image_path, &det);
+    PredictImage(FLAGS_image_path, FLAGS_threshold, FLAGS_run_benchmark, &det, FLAGS_output_dir);
   }
   return 0;
 }
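With the new flags, a plain image run, a benchmark run, and a camera run might look like the following. The binary name and paths are illustrative; only the flags themselves come from this commit.

./main --model_dir=inference_model --image_path=demo.jpg --use_gpu=true
./main --model_dir=inference_model --image_path=demo.jpg --run_benchmark=true
./main --model_dir=inference_model --use_camera=true --camera_id=0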
@@ -39,7 +39,7 @@ void ObjectDetector::LoadModel(const std::string& model_dir,
     printf("TensorRT int8 mode is not supported now, "
            "please use 'trt_fp32' or 'trt_fp16' instead");
   } else {
-    if (run_mode != "trt_32") {
+    if (run_mode != "trt_fp32") {
       printf("run_mode should be 'fluid', 'trt_fp32' or 'trt_fp16'");
     }
   }
@@ -56,6 +56,7 @@ void ObjectDetector::LoadModel(const std::string& model_dir,
   }
   config.SwitchUseFeedFetchOps(false);
   config.SwitchSpecifyInputNames(true);
+  config.DisableGlogInfo();
   // Memory optimization
   config.EnableMemoryOptim();
   predictor_ = std::move(CreatePaddlePredictor(config));
@@ -155,7 +156,11 @@ void ObjectDetector::Postprocess(
 }

 void ObjectDetector::Predict(const cv::Mat& im,
-                             std::vector<ObjectResult>* result) {
+                             const double threshold,
+                             const int warmup,
+                             const int repeats,
+                             const bool run_benchmark,
+                             std::vector<ObjectResult>* result) {
   // Preprocess image
   Preprocess(im);
   // Prepare input tensor
@@ -182,24 +187,53 @@ void ObjectDetector::Predict(const cv::Mat& im,
     }
   }
   // Run predictor
-  predictor_->ZeroCopyRun();
-  // Get output tensor
-  auto output_names = predictor_->GetOutputNames();
-  auto out_tensor = predictor_->GetOutputTensor(output_names[0]);
-  std::vector<int> output_shape = out_tensor->shape();
-  // Calculate output length
-  int output_size = 1;
-  for (int j = 0; j < output_shape.size(); ++j) {
-    output_size *= output_shape[j];
-  }
-  if (output_size < 6) {
-    std::cerr << "[WARNING] No object detected." << std::endl;
+  for (int i = 0; i < warmup; i++)
+  {
+    predictor_->ZeroCopyRun();
+    // Get output tensor
+    auto output_names = predictor_->GetOutputNames();
+    auto out_tensor = predictor_->GetOutputTensor(output_names[0]);
+    std::vector<int> output_shape = out_tensor->shape();
+    // Calculate output length
+    int output_size = 1;
+    for (int j = 0; j < output_shape.size(); ++j) {
+      output_size *= output_shape[j];
+    }
+    if (output_size < 6) {
+      std::cerr << "[WARNING] No object detected." << std::endl;
+    }
+    output_data_.resize(output_size);
+    out_tensor->copy_to_cpu(output_data_.data());
   }
-  output_data_.resize(output_size);
-  out_tensor->copy_to_cpu(output_data_.data());
+  std::clock_t start = clock();
+  for (int i = 0; i < repeats; i++)
+  {
+    predictor_->ZeroCopyRun();
+    // Get output tensor
+    auto output_names = predictor_->GetOutputNames();
+    auto out_tensor = predictor_->GetOutputTensor(output_names[0]);
+    std::vector<int> output_shape = out_tensor->shape();
+    // Calculate output length
+    int output_size = 1;
+    for (int j = 0; j < output_shape.size(); ++j) {
+      output_size *= output_shape[j];
+    }
+    if (output_size < 6) {
+      std::cerr << "[WARNING] No object detected." << std::endl;
+    }
+    output_data_.resize(output_size);
+    out_tensor->copy_to_cpu(output_data_.data());
+  }
+  std::clock_t end = clock();
+  float ms = static_cast<float>(end - start) / CLOCKS_PER_SEC / repeats * 1000.;
+  printf("Inference: %f ms per batch image\n", ms);
   // Postprocessing result
-  Postprocess(im, result);
+  if(!run_benchmark) {
+    Postprocess(im, result);
+  }
 }

 std::vector<int> GenerateColorMap(int num_class) {
...
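A caveat on the timing above: std::clock() measures CPU time rather than wall-clock time, so for GPU inference the printed per-batch latency can differ from elapsed real time. A minimal wall-clock sketch using std::chrono, for comparison (the helper name is hypothetical and not part of this commit):

#include <chrono>

// Hypothetical helper: runs `run_once` `repeats` times and returns the
// mean wall-clock latency in milliseconds.
template <typename Fn>
float MeanLatencyMs(Fn&& run_once, int repeats) {
  auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < repeats; ++i) {
    run_once();
  }
  std::chrono::duration<float, std::milli> total =
      std::chrono::steady_clock::now() - start;
  return total.count() / repeats;
}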