Unverified commit e27674ff authored by W wangguanzhong, committed by GitHub

support rcnn bs=2 in cpp infer (#4854)

* support rcnn bs=2 in cpp infer

* check dynamic shape in padbatch

* replace fluid by paddle

* fix time display in pptracking cpp

* fix timer in bs=2
Parent e2ea3610
......@@ -169,7 +169,7 @@ WITH_KEYPOINT=ON
| --camera_id | Option | Camera ID used for prediction; default is -1 (meaning no camera is used) |
| --device | Device to run on, choose from `CPU/GPU/XPU`; default is `CPU` |
| --gpu_id | GPU device id used for inference (default: 0) |
| --run_mode | When using GPU, default is fluid; options: (fluid/trt_fp32/trt_fp16/trt_int8) |
| --run_mode | When using GPU, default is paddle; options: (paddle/trt_fp32/trt_fp16/trt_int8) |
| --batch_size | Batch size for detection model inference; effective when `image_dir` is specified |
| --batch_size_keypoint | Batch size for keypoint model inference; default is 8 |
| --run_benchmark | Whether to run prediction repeatedly for benchmark timing |
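For example (binary name and paths are placeholders): `./main --model_dir=./rcnn_export --image_dir=./demo_images --device=GPU --run_mode=paddle --batch_size=2` runs the batch-size-2 R-CNN path this commit enables.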
......
......@@ -108,7 +108,7 @@ make
| --camera_id | Option | Camera ID used for prediction; default is -1 (meaning no camera is used) |
| --device | Device to run on, choose from `CPU/GPU/XPU`; default is `CPU` |
| --gpu_id | GPU device id used for inference (default: 0) |
| --run_mode | When using GPU, default is fluid; options: (fluid/trt_fp32/trt_fp16/trt_int8) |
| --run_mode | When using GPU, default is paddle; options: (paddle/trt_fp32/trt_fp16/trt_int8) |
| --batch_size | Batch size for detection model inference; effective when `image_dir` is specified |
| --batch_size_keypoint | Batch size for keypoint model inference; default is 8 |
| --run_benchmark | Whether to run prediction repeatedly for benchmark timing |
......
......@@ -107,7 +107,7 @@ cd D:\projects\PaddleDetection\deploy\cpp\out\build\x64-Release
| --camera_id | Option | Camera ID used for prediction; default is -1 (meaning no camera is used) |
| --device | Device to run on, choose from `CPU/GPU/XPU`; default is `CPU` |
| --gpu_id | GPU device id used for inference (default: 0) |
| --run_mode | When using GPU, default is fluid; options: (fluid/trt_fp32/trt_fp16/trt_int8) |
| --run_mode | When using GPU, default is paddle; options: (paddle/trt_fp32/trt_fp16/trt_int8) |
| --batch_size | Batch size for detection model inference; effective when `image_dir` is specified |
| --batch_size_keypoint | Batch size for keypoint model inference; default is 8 |
| --run_benchmark | Whether to run prediction repeatedly for benchmark timing |
......
......@@ -15,9 +15,9 @@
#pragma once
#include <iostream>
#include <vector>
#include <string>
#include <map>
#include <string>
#include <vector>
#include "yaml-cpp/yaml.h"
......@@ -42,13 +42,12 @@ class ConfigPaser {
YAML::Node config;
config = YAML::LoadFile(model_dir + OS_PATH_SEP + cfg);
// Get runtime mode : fluid, trt_fp16, trt_fp32
// Get runtime mode : paddle, trt_fp16, trt_fp32
if (config["mode"].IsDefined()) {
mode_ = config["mode"].as<std::string>();
} else {
std::cerr << "Please set mode, "
<< "support value : fluid/trt_fp16/trt_fp32."
<< std::endl;
<< "support value : paddle/trt_fp16/trt_fp32." << std::endl;
return false;
}
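As a reference point, a minimal standalone sketch of reading the `mode` key the way ConfigPaser does (the inline YAML string stands in for a real exported infer_cfg.yml; it is illustrative, not an actual config):

#include <iostream>
#include <string>
#include "yaml-cpp/yaml.h"

int main() {
  // Illustrative stand-in for an exported infer_cfg.yml; only `mode` is shown.
  YAML::Node config = YAML::Load("mode: paddle");
  if (config["mode"].IsDefined()) {
    std::string mode = config["mode"].as<std::string>();
    std::cout << "run mode: " << mode << std::endl;  // prints "run mode: paddle"
  }
  return 0;
}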
......@@ -136,4 +135,3 @@ class ConfigPaser {
};
} // namespace PaddleDetection
......@@ -14,39 +14,37 @@
#pragma once
#include <string>
#include <vector>
#include <ctime>
#include <memory>
#include <string>
#include <utility>
#include <ctime>
#include <vector>
#include <opencv2/core/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include "paddle_inference_api.h" // NOLINT
#include "paddle_inference_api.h" // NOLINT
#include "include/preprocess_op.h"
#include "include/config_parser.h"
#include "include/preprocess_op.h"
#include "include/tracker.h"
using namespace paddle_infer;
namespace PaddleDetection {
// JDE Detection Result
struct MOT_Rect
{
float left;
float top;
float right;
float bottom;
struct MOT_Rect {
float left;
float top;
float right;
float bottom;
};
struct MOT_Track
{
int ids;
float score;
MOT_Rect rects;
struct MOT_Track {
int ids;
float score;
MOT_Rect rects;
};
typedef std::vector<MOT_Track> MOT_Result;
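For illustration, how these types compose (all values made up):

MOT_Track track;
track.ids = 7;        // track id
track.score = 0.93f;  // detection confidence
track.rects = {10.f, 20.f, 110.f, 220.f};  // left, top, right, bottom
MOT_Result frame_result = {track};  // all tracks for one frame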
......@@ -56,24 +54,24 @@ cv::Scalar GetColor(int idx);
// Visualization Detection Result
cv::Mat VisualizeTrackResult(const cv::Mat& img,
const MOT_Result& results,
const float fps, const int frame_id);
const MOT_Result& results,
const float fps,
const int frame_id);
class JDEDetector {
public:
explicit JDEDetector(const std::string& model_dir,
const std::string& device="CPU",
bool use_mkldnn=false,
int cpu_threads=1,
const std::string& run_mode="fluid",
const int batch_size=1,
const int gpu_id=0,
const int trt_min_shape=1,
const int trt_max_shape=1280,
const int trt_opt_shape=640,
bool trt_calib_mode=false,
const int min_box_area=200) {
explicit JDEDetector(const std::string& model_dir,
const std::string& device = "CPU",
bool use_mkldnn = false,
int cpu_threads = 1,
const std::string& run_mode = "paddle",
const int batch_size = 1,
const int gpu_id = 0,
const int trt_min_shape = 1,
const int trt_max_shape = 1280,
const int trt_opt_shape = 640,
bool trt_calib_mode = false,
const int min_box_area = 200) {
this->device_ = device;
this->gpu_id_ = gpu_id;
this->cpu_math_library_num_threads_ = cpu_threads;
......@@ -94,18 +92,17 @@ class JDEDetector {
}
// Load Paddle inference model
void LoadModel(
const std::string& model_dir,
const int batch_size = 1,
const std::string& run_mode = "fluid");
void LoadModel(const std::string& model_dir,
const int batch_size = 1,
const std::string& run_mode = "paddle");
// Run predictor
void Predict(const std::vector<cv::Mat> imgs,
const double threshold = 0.5,
const int warmup = 0,
const int repeats = 1,
MOT_Result* result = nullptr,
std::vector<double>* times = nullptr);
const double threshold = 0.5,
const int warmup = 0,
const int repeats = 1,
MOT_Result* result = nullptr,
std::vector<double>* times = nullptr);
private:
std::string device_ = "CPU";
......@@ -121,9 +118,7 @@ class JDEDetector {
// Preprocess image and copy data to input buffer
void Preprocess(const cv::Mat& image_mat);
// Postprocess result
void Postprocess(
const cv::Mat dets, const cv::Mat emb,
MOT_Result* result);
void Postprocess(const cv::Mat dets, const cv::Mat emb, MOT_Result* result);
std::shared_ptr<Predictor> predictor_;
Preprocessor preprocessor_;
......
......@@ -51,7 +51,7 @@ class KeyPointDetector {
const std::string& device = "CPU",
bool use_mkldnn = false,
int cpu_threads = 1,
const std::string& run_mode = "fluid",
const std::string& run_mode = "paddle",
const int batch_size = 1,
const int gpu_id = 0,
const int trt_min_shape = 1,
......@@ -80,7 +80,7 @@ class KeyPointDetector {
// Load Paddle inference model
void LoadModel(const std::string& model_dir,
const int batch_size = 1,
const std::string& run_mode = "fluid");
const std::string& run_mode = "paddle");
// Run predictor
void Predict(const std::vector<cv::Mat> imgs,
......
......@@ -14,23 +14,23 @@
#pragma once
#include <string>
#include <vector>
#include <memory>
#include <utility>
#include <ctime>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
#include <opencv2/core/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include "paddle_inference_api.h" // NOLINT
#include "paddle_inference_api.h" // NOLINT
#include "include/preprocess_op.h"
#include "include/config_parser.h"
#include "include/utils.h"
#include "include/picodet_postprocess.h"
#include "include/preprocess_op.h"
#include "include/utils.h"
using namespace paddle_infer;
......@@ -39,28 +39,27 @@ namespace PaddleDetection {
// Generate visualization colormap for each class
std::vector<int> GenerateColorMap(int num_class);
// Visualization Detection Result
cv::Mat VisualizeResult(const cv::Mat& img,
const std::vector<PaddleDetection::ObjectResult>& results,
const std::vector<std::string>& lables,
const std::vector<int>& colormap,
const bool is_rbox);
cv::Mat VisualizeResult(
const cv::Mat& img,
const std::vector<PaddleDetection::ObjectResult>& results,
const std::vector<std::string>& lables,
const std::vector<int>& colormap,
const bool is_rbox);
class ObjectDetector {
public:
explicit ObjectDetector(const std::string& model_dir,
const std::string& device="CPU",
bool use_mkldnn=false,
int cpu_threads=1,
const std::string& run_mode="fluid",
const int batch_size=1,
const int gpu_id=0,
const int trt_min_shape=1,
const int trt_max_shape=1280,
const int trt_opt_shape=640,
bool trt_calib_mode=false) {
explicit ObjectDetector(const std::string& model_dir,
const std::string& device = "CPU",
bool use_mkldnn = false,
int cpu_threads = 1,
const std::string& run_mode = "paddle",
const int batch_size = 1,
const int gpu_id = 0,
const int trt_min_shape = 1,
const int trt_max_shape = 1280,
const int trt_opt_shape = 640,
bool trt_calib_mode = false) {
this->device_ = device;
this->gpu_id_ = gpu_id;
this->cpu_math_library_num_threads_ = cpu_threads;
......@@ -79,19 +78,18 @@ class ObjectDetector {
}
// Load Paddle inference model
void LoadModel(
const std::string& model_dir,
const int batch_size = 1,
const std::string& run_mode = "fluid");
void LoadModel(const std::string& model_dir,
const int batch_size = 1,
const std::string& run_mode = "paddle");
// Run predictor
void Predict(const std::vector<cv::Mat> imgs,
const double threshold = 0.5,
const int warmup = 0,
const int repeats = 1,
std::vector<PaddleDetection::ObjectResult>* result = nullptr,
std::vector<int>* bbox_num = nullptr,
std::vector<double>* times = nullptr);
const double threshold = 0.5,
const int warmup = 0,
const int repeats = 1,
std::vector<PaddleDetection::ObjectResult>* result = nullptr,
std::vector<int>* bbox_num = nullptr,
std::vector<double>* times = nullptr);
// Get Model Label list
const std::vector<std::string>& GetLabelList() const {
......@@ -112,19 +110,17 @@ class ObjectDetector {
// Preprocess image and copy data to input buffer
void Preprocess(const cv::Mat& image_mat);
// Postprocess result
void Postprocess(
const std::vector<cv::Mat> mats,
std::vector<PaddleDetection::ObjectResult>* result,
std::vector<int> bbox_num,
std::vector<float> output_data_,
bool is_rbox);
void Postprocess(const std::vector<cv::Mat> mats,
std::vector<PaddleDetection::ObjectResult>* result,
std::vector<int> bbox_num,
std::vector<float> output_data_,
bool is_rbox);
std::shared_ptr<Predictor> predictor_;
Preprocessor preprocessor_;
ImageBlob inputs_;
float threshold_;
ConfigPaser config_;
};
} // namespace PaddleDetection
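A hypothetical usage sketch for this class (paths and values are placeholders; batch size 2 mirrors the R-CNN case this commit enables):

#include <vector>
#include <opencv2/highgui/highgui.hpp>
#include "include/object_detector.h"

int main() {
  // Placeholder model path; run_mode "paddle", batch size 2.
  PaddleDetection::ObjectDetector det(
      "/path/to/exported_model", "GPU", false, 1, "paddle", 2);
  std::vector<cv::Mat> batch = {cv::imread("img1.jpg"), cv::imread("img2.jpg")};
  std::vector<PaddleDetection::ObjectResult> result;
  std::vector<int> bbox_num;
  std::vector<double> times;
  // threshold 0.5, no warmup, a single repeat
  det.Predict(batch, 0.5, 0, 1, &result, &bbox_num, &times);
  return 0;
}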
......@@ -17,16 +17,16 @@
#include <glog/logging.h>
#include <yaml-cpp/yaml.h>
#include <vector>
#include <string>
#include <utility>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <iostream>
#include <utility>
#include <vector>
#include <opencv2/core/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
namespace PaddleDetection {
......@@ -40,9 +40,11 @@ class ImageBlob {
// in net data shape(after pad)
std::vector<float> in_net_shape_;
// Evaluation image width and height
//std::vector<float> eval_im_size_f_;
// std::vector<float> eval_im_size_f_;
// Scale factor for image size to origin image size
std::vector<float> scale_factor_;
// in net image after preprocessing
cv::Mat in_net_im_;
};
// Abstraction of preprocessing operation class
......@@ -52,7 +54,7 @@ class PreprocessOp {
virtual void Run(cv::Mat* im, ImageBlob* data) = 0;
};
class InitInfo : public PreprocessOp{
class InitInfo : public PreprocessOp {
public:
virtual void Init(const YAML::Node& item) {}
virtual void Run(cv::Mat* im, ImageBlob* data);
......@@ -79,7 +81,6 @@ class Permute : public PreprocessOp {
public:
virtual void Init(const YAML::Node& item) {}
virtual void Run(cv::Mat* im, ImageBlob* data);
};
class Resize : public PreprocessOp {
......@@ -88,7 +89,7 @@ class Resize : public PreprocessOp {
interp_ = item["interp"].as<int>();
keep_ratio_ = item["keep_ratio"].as<bool>();
target_size_ = item["target_size"].as<std::vector<int>>();
}
}
// Compute best resize scale for x-dimension, y-dimension
std::pair<float, float> GenerateScale(const cv::Mat& im);
......@@ -106,7 +107,7 @@ class LetterBoxResize : public PreprocessOp {
public:
virtual void Init(const YAML::Node& item) {
target_size_ = item["target_size"].as<std::vector<int>>();
}
}
float GenerateScale(const cv::Mat& im);
......@@ -133,7 +134,7 @@ class TopDownEvalAffine : public PreprocessOp {
public:
virtual void Init(const YAML::Node& item) {
trainsize_ = item["trainsize"].as<std::vector<int>>();
}
}
virtual void Run(cv::Mat* im, ImageBlob* data);
......@@ -142,7 +143,18 @@ class TopDownEvalAffine : public PreprocessOp {
std::vector<int> trainsize_;
};
void CropImg(cv::Mat &img, cv::Mat &crop_img, std::vector<int> &area, std::vector<float> &center, std::vector<float> &scale, float expandratio=0.15);
void CropImg(cv::Mat& img,
cv::Mat& crop_img,
std::vector<int>& area,
std::vector<float>& center,
std::vector<float>& scale,
float expandratio = 0.15);
// check whether the input size is dynamic
bool CheckDynamicInput(const std::vector<cv::Mat>& imgs);
// Pad images in batch
std::vector<cv::Mat> PadBatch(const std::vector<cv::Mat>& imgs);
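The two helpers are only declared here. A minimal sketch of one plausible implementation, assuming zero-padding to the batch's maximum height and width (the Sketch suffix marks these as illustrative, not the repo's actual definitions):

#include <algorithm>
#include <vector>
#include <opencv2/core/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>

// Returns true if images in the batch have differing spatial sizes.
bool CheckDynamicInputSketch(const std::vector<cv::Mat>& imgs) {
  for (size_t i = 1; i < imgs.size(); ++i) {
    if (imgs[i].rows != imgs[0].rows || imgs[i].cols != imgs[0].cols) {
      return true;
    }
  }
  return false;
}

// Zero-pads each image (bottom/right) up to the batch max height and width.
std::vector<cv::Mat> PadBatchSketch(const std::vector<cv::Mat>& imgs) {
  int max_h = 0, max_w = 0;
  for (const auto& im : imgs) {
    max_h = std::max(max_h, im.rows);
    max_w = std::max(max_w, im.cols);
  }
  std::vector<cv::Mat> out;
  for (const auto& im : imgs) {
    cv::Mat padded;
    cv::copyMakeBorder(im, padded, 0, max_h - im.rows, 0, max_w - im.cols,
                       cv::BORDER_CONSTANT, cv::Scalar(0));
    out.push_back(padded);
  }
  return out;
}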
class Preprocessor {
public:
......@@ -172,7 +184,8 @@ class Preprocessor {
} else if (name == "TopDownEvalAffine") {
return std::make_shared<TopDownEvalAffine>();
}
std::cerr << "can not find function of OP: " << name << " and return: nullptr" << std::endl;
std::cerr << "can not find function of OP: " << name
<< " and return: nullptr" << std::endl;
return nullptr;
}
......@@ -186,4 +199,3 @@ class Preprocessor {
};
} // namespace PaddleDetection
......@@ -13,19 +13,18 @@
// limitations under the License.
#include <sstream>
// for setprecision
#include <iomanip>
#include <chrono>
#include <iomanip>
#include "include/jde_detector.h"
using namespace paddle_infer;
namespace PaddleDetection {
// Load Model and create model predictor
void JDEDetector::LoadModel(const std::string& model_dir,
const int batch_size,
const std::string& run_mode) {
const int batch_size,
const std::string& run_mode) {
paddle_infer::Config config;
std::string prog_file = model_dir + OS_PATH_SEP + "model.pdmodel";
std::string params_file = model_dir + OS_PATH_SEP + "model.pdiparams";
......@@ -34,47 +33,51 @@ void JDEDetector::LoadModel(const std::string& model_dir,
config.EnableUseGpu(200, this->gpu_id_);
config.SwitchIrOptim(true);
// use tensorrt
if (run_mode != "fluid") {
if (run_mode != "paddle") {
auto precision = paddle_infer::Config::Precision::kFloat32;
if (run_mode == "trt_fp32") {
precision = paddle_infer::Config::Precision::kFloat32;
}
else if (run_mode == "trt_fp16") {
} else if (run_mode == "trt_fp16") {
precision = paddle_infer::Config::Precision::kHalf;
}
else if (run_mode == "trt_int8") {
} else if (run_mode == "trt_int8") {
precision = paddle_infer::Config::Precision::kInt8;
} else {
printf("run_mode should be 'fluid', 'trt_fp32', 'trt_fp16' or 'trt_int8'");
printf(
"run_mode should be 'paddle', 'trt_fp32', 'trt_fp16' or "
"'trt_int8'");
}
// set tensorrt
config.EnableTensorRtEngine(
1 << 30,
batch_size,
this->min_subgraph_size_,
precision,
false,
this->trt_calib_mode_);
config.EnableTensorRtEngine(1 << 30,
batch_size,
this->min_subgraph_size_,
precision,
false,
this->trt_calib_mode_);
// set use dynamic shape
if (this->use_dynamic_shape_) {
// set DynamicShape for image tensor
const std::vector<int> min_input_shape = {1, 3, this->trt_min_shape_, this->trt_min_shape_};
const std::vector<int> max_input_shape = {1, 3, this->trt_max_shape_, this->trt_max_shape_};
const std::vector<int> opt_input_shape = {1, 3, this->trt_opt_shape_, this->trt_opt_shape_};
const std::map<std::string, std::vector<int>> map_min_input_shape = {{"image", min_input_shape}};
const std::map<std::string, std::vector<int>> map_max_input_shape = {{"image", max_input_shape}};
const std::map<std::string, std::vector<int>> map_opt_input_shape = {{"image", opt_input_shape}};
config.SetTRTDynamicShapeInfo(map_min_input_shape,
map_max_input_shape,
map_opt_input_shape);
const std::vector<int> min_input_shape = {
1, 3, this->trt_min_shape_, this->trt_min_shape_};
const std::vector<int> max_input_shape = {
1, 3, this->trt_max_shape_, this->trt_max_shape_};
const std::vector<int> opt_input_shape = {
1, 3, this->trt_opt_shape_, this->trt_opt_shape_};
const std::map<std::string, std::vector<int>> map_min_input_shape = {
{"image", min_input_shape}};
const std::map<std::string, std::vector<int>> map_max_input_shape = {
{"image", max_input_shape}};
const std::map<std::string, std::vector<int>> map_opt_input_shape = {
{"image", opt_input_shape}};
config.SetTRTDynamicShapeInfo(
map_min_input_shape, map_max_input_shape, map_opt_input_shape);
std::cout << "TensorRT dynamic shape enabled" << std::endl;
}
}
} else if (this->device_ == "XPU"){
config.EnableXpu(10*1024*1024);
} else if (this->device_ == "XPU") {
config.EnableXpu(10 * 1024 * 1024);
} else {
config.DisableGpu();
if (this->use_mkldnn_) {
......@@ -94,8 +97,9 @@ void JDEDetector::LoadModel(const std::string& model_dir,
// Visualization results
cv::Mat VisualizeTrackResult(const cv::Mat& img,
const MOT_Result& results,
const float fps, const int frame_id) {
const MOT_Result& results,
const float fps,
const int frame_id) {
cv::Mat vis_img = img.clone();
int im_h = img.rows;
int im_w = img.cols;
......@@ -105,31 +109,34 @@ cv::Mat VisualizeTrackResult(const cv::Mat& img,
std::ostringstream oss;
oss << std::setiosflags(std::ios::fixed) << std::setprecision(4);
oss << "frame: " << frame_id<<" ";
oss << "fps: " << fps<<" ";
oss << "frame: " << frame_id << " ";
oss << "fps: " << fps << " ";
oss << "num: " << results.size();
std::string text = oss.str();
cv::Point origin;
origin.x = 0;
origin.y = int(15 * text_scale);
cv::putText(
vis_img,
text,
origin,
cv::FONT_HERSHEY_PLAIN,
text_scale, cv::Scalar(0, 0, 255), 2);
cv::putText(vis_img,
text,
origin,
cv::FONT_HERSHEY_PLAIN,
text_scale,
cv::Scalar(0, 0, 255),
2);
for (int i = 0; i < results.size(); ++i) {
const int obj_id = results[i].ids;
const float score = results[i].score;
cv::Scalar color = GetColor(obj_id);
cv::Point pt1 = cv::Point(results[i].rects.left, results[i].rects.top);
cv::Point pt2 = cv::Point(results[i].rects.right, results[i].rects.bottom);
cv::Point id_pt = cv::Point(results[i].rects.left, results[i].rects.top + 10);
cv::Point score_pt = cv::Point(results[i].rects.left, results[i].rects.top - 10);
cv::Point id_pt =
cv::Point(results[i].rects.left, results[i].rects.top + 10);
cv::Point score_pt =
cv::Point(results[i].rects.left, results[i].rects.top - 10);
cv::rectangle(vis_img, pt1, pt2, color, line_thickness);
std::ostringstream idoss;
......@@ -157,13 +164,13 @@ cv::Mat VisualizeTrackResult(const cv::Mat& img,
text_scale,
cv::Scalar(0, 255, 255),
text_thickness);
}
return vis_img;
}
void FilterDets(const float conf_thresh, const cv::Mat dets, std::vector<int>* index) {
void FilterDets(const float conf_thresh,
const cv::Mat dets,
std::vector<int>* index) {
for (int i = 0; i < dets.rows; ++i) {
float score = *dets.ptr<float>(i, 4);
if (score > conf_thresh) {
......@@ -178,9 +185,9 @@ void JDEDetector::Preprocess(const cv::Mat& ori_im) {
preprocessor_.Run(&im, &inputs_);
}
void JDEDetector::Postprocess(
const cv::Mat dets, const cv::Mat emb,
MOT_Result* result) {
void JDEDetector::Postprocess(const cv::Mat dets,
const cv::Mat emb,
MOT_Result* result) {
result->clear();
std::vector<Track> tracks;
std::vector<int> valid;
......@@ -193,7 +200,7 @@ void JDEDetector::Postprocess(
JDETracker::instance()->update(new_dets, new_emb, tracks);
if (tracks.size() == 0) {
MOT_Track mot_track;
MOT_Rect ret = {*dets.ptr<float>(0, 0),
MOT_Rect ret = {*dets.ptr<float>(0, 0),
*dets.ptr<float>(0, 1),
*dets.ptr<float>(0, 2),
*dets.ptr<float>(0, 3)};
......@@ -213,26 +220,24 @@ void JDEDetector::Postprocess(
float area = w * h;
if (area > min_box_area_ && !vertical) {
MOT_Track mot_track;
MOT_Rect ret = {titer->ltrb[0],
titer->ltrb[1],
titer->ltrb[2],
titer->ltrb[3]};
MOT_Rect ret = {
titer->ltrb[0], titer->ltrb[1], titer->ltrb[2], titer->ltrb[3]};
mot_track.rects = ret;
mot_track.score = titer->score;
mot_track.ids = titer->id;
result->push_back(mot_track);
}
}
}
}
}
}
void JDEDetector::Predict(const std::vector<cv::Mat> imgs,
const double threshold,
const int warmup,
const int repeats,
MOT_Result* result,
std::vector<double>* times) {
const double threshold,
const int warmup,
const int repeats,
MOT_Result* result,
std::vector<double>* times) {
auto preprocess_start = std::chrono::steady_clock::now();
int batch_size = imgs.size();
......@@ -240,7 +245,7 @@ void JDEDetector::Predict(const std::vector<cv::Mat> imgs,
std::vector<float> in_data_all;
std::vector<float> im_shape_all(batch_size * 2);
std::vector<float> scale_factor_all(batch_size * 2);
// Preprocess image
for (int bs_idx = 0; bs_idx < batch_size; bs_idx++) {
cv::Mat im = imgs.at(bs_idx);
......@@ -252,7 +257,8 @@ void JDEDetector::Predict(const std::vector<cv::Mat> imgs,
scale_factor_all[bs_idx * 2 + 1] = inputs_.scale_factor_[1];
// TODO: reduce time cost
in_data_all.insert(in_data_all.end(), inputs_.im_data_.begin(), inputs_.im_data_.end());
in_data_all.insert(
in_data_all.end(), inputs_.im_data_.begin(), inputs_.im_data_.end());
}
// Prepare input tensor
......@@ -272,14 +278,13 @@ void JDEDetector::Predict(const std::vector<cv::Mat> imgs,
in_tensor->CopyFromCpu(scale_factor_all.data());
}
}
auto preprocess_end = std::chrono::steady_clock::now();
std::vector<int> bbox_shape;
std::vector<int> emb_shape;
// Run predictor
// warmup
for (int i = 0; i < warmup; i++)
{
for (int i = 0; i < warmup; i++) {
predictor_->Run();
// Get output tensor
auto output_names = predictor_->GetOutputNames();
......@@ -299,15 +304,14 @@ void JDEDetector::Predict(const std::vector<cv::Mat> imgs,
}
bbox_data_.resize(bbox_size);
bbox_tensor->CopyToCpu(bbox_data_.data());
bbox_tensor->CopyToCpu(bbox_data_.data());
emb_data_.resize(emb_size);
emb_tensor->CopyToCpu(emb_data_.data());
}
auto inference_start = std::chrono::steady_clock::now();
for (int i = 0; i < repeats; i++)
{
for (int i = 0; i < repeats; i++) {
predictor_->Run();
// Get output tensor
auto output_names = predictor_->GetOutputNames();
......@@ -327,7 +331,7 @@ void JDEDetector::Predict(const std::vector<cv::Mat> imgs,
}
bbox_data_.resize(bbox_size);
bbox_tensor->CopyToCpu(bbox_data_.data());
bbox_tensor->CopyToCpu(bbox_data_.data());
emb_data_.resize(emb_size);
emb_tensor->CopyToCpu(emb_data_.data());
......@@ -344,19 +348,20 @@ void JDEDetector::Predict(const std::vector<cv::Mat> imgs,
auto postprocess_end = std::chrono::steady_clock::now();
std::chrono::duration<float> preprocess_diff = preprocess_end - preprocess_start;
std::chrono::duration<float> preprocess_diff =
preprocess_end - preprocess_start;
(*times)[0] += double(preprocess_diff.count() * 1000);
std::chrono::duration<float> inference_diff = inference_end - inference_start;
(*times)[1] += double(inference_diff.count() * 1000);
std::chrono::duration<float> postprocess_diff = postprocess_end - postprocess_start;
std::chrono::duration<float> postprocess_diff =
postprocess_end - postprocess_start;
(*times)[2] += double(postprocess_diff.count() * 1000);
}
cv::Scalar GetColor(int idx) {
idx = idx * 3;
cv::Scalar color = cv::Scalar((37 * idx) % 255,
(17 * idx) % 255,
(29 * idx) % 255);
cv::Scalar color =
cv::Scalar((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255);
return color;
}
......
......@@ -33,7 +33,7 @@ void KeyPointDetector::LoadModel(const std::string& model_dir,
config.EnableUseGpu(200, this->gpu_id_);
config.SwitchIrOptim(true);
// use tensorrt
if (run_mode != "fluid") {
if (run_mode != "paddle") {
auto precision = paddle_infer::Config::Precision::kFloat32;
if (run_mode == "trt_fp32") {
precision = paddle_infer::Config::Precision::kFloat32;
......@@ -43,7 +43,8 @@ void KeyPointDetector::LoadModel(const std::string& model_dir,
precision = paddle_infer::Config::Precision::kInt8;
} else {
printf(
"run_mode should be 'fluid', 'trt_fp32', 'trt_fp16' or 'trt_int8'");
"run_mode should be 'paddle', 'trt_fp32', 'trt_fp16' or "
"'trt_int8'");
}
// set tensorrt
config.EnableTensorRtEngine(1 << 30,
......@@ -99,22 +100,22 @@ cv::Mat VisualizeKptsResult(const cv::Mat& img,
const std::vector<KeyPointResult>& results,
const std::vector<int>& colormap) {
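// The edge pairs below assume the COCO 17-keypoint skeleton ordering
// (an assumption; the exported model config defines the actual joint order).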
const int edge[][2] = {{0, 1},
{0, 2},
{1, 3},
{2, 4},
{3, 5},
{4, 6},
{5, 7},
{6, 8},
{7, 9},
{8, 10},
{5, 11},
{6, 12},
{11, 13},
{12, 14},
{13, 15},
{14, 16},
{11, 12}};
{0, 2},
{1, 3},
{2, 4},
{3, 5},
{4, 6},
{5, 7},
{6, 8},
{7, 9},
{8, 10},
{5, 11},
{6, 12},
{11, 13},
{12, 14},
{13, 15},
{14, 16},
{11, 12}};
cv::Mat vis_img = img.clone();
for (int batchid = 0; batchid < results.size(); batchid++) {
for (int i = 0; i < results[batchid].num_joints; i++) {
......
......@@ -14,14 +14,14 @@
#include <glog/logging.h>
#include <math.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <algorithm>
#include <iostream>
#include <numeric>
#include <string>
#include <vector>
#include <numeric>
#include <sys/types.h>
#include <sys/stat.h>
#include <math.h>
#include <algorithm>
#ifdef _WIN32
#include <direct.h>
......@@ -31,62 +31,86 @@
#include <sys/stat.h>
#endif
#include "include/object_detector.h"
#include <gflags/gflags.h>
#include "include/object_detector.h"
DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_string(image_file, "", "Path of input image");
DEFINE_string(image_dir, "", "Dir of input image, `image_file` has a higher priority.");
DEFINE_string(image_dir,
"",
"Dir of input image, `image_file` has a higher priority.");
DEFINE_int32(batch_size, 1, "batch_size");
DEFINE_string(video_file, "", "Path of input video, `video_file` or `camera_id` has the highest priority.");
DEFINE_string(
video_file,
"",
"Path of input video, `video_file` or `camera_id` has the highest priority.");
DEFINE_int32(camera_id, -1, "Device id of camera to predict");
DEFINE_bool(use_gpu, false, "Deprecated, please use `--device` to set the device you want to run.");
DEFINE_string(device, "CPU", "Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU.");
DEFINE_bool(
use_gpu,
false,
"Deprecated, please use `--device` to set the device you want to run.");
DEFINE_string(device,
"CPU",
"Choose the device you want to run, it can be: CPU/GPU/XPU, "
"default is CPU.");
DEFINE_double(threshold, 0.5, "Threshold of score.");
DEFINE_string(output_dir, "output", "Directory of output visualization files.");
DEFINE_string(run_mode, "fluid", "Mode of running(fluid/trt_fp32/trt_fp16/trt_int8)");
DEFINE_string(run_mode,
"paddle",
"Mode of running(paddle/trt_fp32/trt_fp16/trt_int8)");
DEFINE_int32(gpu_id, 0, "Device id of GPU to execute");
DEFINE_bool(run_benchmark, false, "Whether to predict an image_file repeatedly for benchmark");
DEFINE_bool(run_benchmark,
false,
"Whether to predict an image_file repeatedly for benchmark");
DEFINE_bool(use_mkldnn, false, "Whether to use mkldnn with CPU");
DEFINE_int32(cpu_threads, 1, "Num of threads with CPU");
DEFINE_int32(trt_min_shape, 1, "Min shape of TRT DynamicShape");
DEFINE_int32(trt_max_shape, 1280, "Max shape of TRT DynamicShape");
DEFINE_int32(trt_opt_shape, 640, "Opt shape of TRT DynamicShape");
DEFINE_bool(trt_calib_mode, false, "If the model is produced by TRT offline quantitative calibration, trt_calib_mode needs to be set to True");
DEFINE_bool(trt_calib_mode,
false,
"If the model is produced by TRT offline quantitative calibration, "
"trt_calib_mode needs to be set to True");
void PrintBenchmarkLog(std::vector<double> det_time, int img_num){
void PrintBenchmarkLog(std::vector<double> det_time, int img_num) {
LOG(INFO) << "----------------------- Config info -----------------------";
LOG(INFO) << "runtime_device: " << FLAGS_device;
LOG(INFO) << "ir_optim: " << "True";
LOG(INFO) << "enable_memory_optim: " << "True";
LOG(INFO) << "ir_optim: "
<< "True";
LOG(INFO) << "enable_memory_optim: "
<< "True";
auto has_trt = FLAGS_run_mode.find("trt");
if (has_trt != std::string::npos) {
LOG(INFO) << "enable_tensorrt: " << "True";
LOG(INFO) << "enable_tensorrt: "
<< "True";
std::string precision = FLAGS_run_mode.substr(4, 8);
LOG(INFO) << "precision: " << precision;
} else {
LOG(INFO) << "enable_tensorrt: " << "False";
LOG(INFO) << "precision: " << "fp32";
LOG(INFO) << "enable_tensorrt: "
<< "False";
LOG(INFO) << "precision: "
<< "fp32";
}
LOG(INFO) << "enable_mkldnn: " << (FLAGS_use_mkldnn ? "True" : "False");
LOG(INFO) << "cpu_math_library_num_threads: " << FLAGS_cpu_threads;
LOG(INFO) << "----------------------- Data info -----------------------";
LOG(INFO) << "batch_size: " << FLAGS_batch_size;
LOG(INFO) << "input_shape: " << "dynamic shape";
LOG(INFO) << "input_shape: "
<< "dynamic shape";
LOG(INFO) << "----------------------- Model info -----------------------";
FLAGS_model_dir.erase(FLAGS_model_dir.find_last_not_of("/") + 1);
LOG(INFO) << "model_name: " << FLAGS_model_dir.substr(FLAGS_model_dir.find_last_of('/') + 1);
LOG(INFO) << "model_name: "
<< FLAGS_model_dir.substr(FLAGS_model_dir.find_last_of('/') + 1);
LOG(INFO) << "----------------------- Perf info ------------------------";
LOG(INFO) << "Total number of predicted data: " << img_num
<< " and total time spent(ms): "
<< std::accumulate(det_time.begin(), det_time.end(), 0.);
LOG(INFO) << "preproce_time(ms): " << det_time[0] / img_num
<< ", inference_time(ms): " << det_time[1] / img_num
<< ", postprocess_time(ms): " << det_time[2];
<< ", postprocess_time(ms): " << det_time[2] / img_num;
}
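Because Predict() accumulates totals across all images, each entry must be divided by img_num when reporting, which is what the postprocess_time fix above does; a toy illustration with made-up numbers:

#include <vector>

int main() {
  // Made-up totals accumulated over 10 images (ms): pre / infer / post.
  std::vector<double> det_time = {120.0, 480.0, 60.0};
  int img_num = 10;
  double avg_infer_ms = det_time[1] / img_num;  // 48 ms per image
  return avg_infer_ms > 0 ? 0 : 1;
}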
static std::string DirName(const std::string &filepath) {
static std::string DirName(const std::string& filepath) {
auto pos = filepath.rfind(OS_PATH_SEP);
if (pos == std::string::npos) {
return "";
......@@ -94,7 +118,7 @@ static std::string DirName(const std::string &filepath) {
return filepath.substr(0, pos);
}
static bool PathExists(const std::string& path){
static bool PathExists(const std::string& path) {
#ifdef _WIN32
struct _stat buffer;
return (_stat(path.c_str(), &buffer) == 0);
......@@ -133,11 +157,12 @@ void PredictVideo(const std::string& video_path,
// Open video
cv::VideoCapture capture;
std::string video_out_name = "output.mp4";
if (FLAGS_camera_id != -1){
if (FLAGS_camera_id != -1) {
capture.open(FLAGS_camera_id);
}else{
} else {
capture.open(video_path.c_str());
video_out_name = video_path.substr(video_path.find_last_of(OS_PATH_SEP) + 1);
video_out_name =
video_path.substr(video_path.find_last_of(OS_PATH_SEP) + 1);
}
if (!capture.isOpened()) {
printf("can not open video : %s\n", video_path.c_str());
......@@ -148,7 +173,8 @@ void PredictVideo(const std::string& video_path,
int video_width = static_cast<int>(capture.get(CV_CAP_PROP_FRAME_WIDTH));
int video_height = static_cast<int>(capture.get(CV_CAP_PROP_FRAME_HEIGHT));
int video_fps = static_cast<int>(capture.get(CV_CAP_PROP_FPS));
int video_frame_count = static_cast<int>(capture.get(CV_CAP_PROP_FRAME_COUNT));
int video_frame_count =
static_cast<int>(capture.get(CV_CAP_PROP_FRAME_COUNT));
printf("fps: %d, frame_count: %d\n", video_fps, video_frame_count);
// Create VideoWriter for output
......@@ -188,35 +214,34 @@ void PredictVideo(const std::string& video_path,
std::vector<PaddleDetection::ObjectResult> out_result;
for (const auto& item : result) {
if (item.confidence < FLAGS_threshold || item.class_id == -1) {
continue;
continue;
}
out_result.push_back(item);
if (item.rect.size() > 6){
is_rbox = true;
printf("class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n",
item.class_id,
item.confidence,
item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3],
item.rect[4],
item.rect[5],
item.rect[6],
item.rect[7]);
}
else{
if (item.rect.size() > 6) {
is_rbox = true;
printf("class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n",
item.class_id,
item.confidence,
item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3],
item.rect[4],
item.rect[5],
item.rect[6],
item.rect[7]);
} else {
printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n",
item.class_id,
item.confidence,
item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3]);
item.class_id,
item.confidence,
item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3]);
}
}
}
cv::Mat out_im = PaddleDetection::VisualizeResult(
cv::Mat out_im = PaddleDetection::VisualizeResult(
frame, out_result, labels, colormap, is_rbox);
video_out.write(out_im);
......@@ -235,7 +260,9 @@ void PredictImage(const std::vector<std::string> all_img_paths,
std::vector<double> det_t = {0, 0, 0};
int steps = ceil(float(all_img_paths.size()) / batch_size);
printf("total images = %d, batch_size = %d, total steps = %d\n",
all_img_paths.size(), batch_size, steps);
all_img_paths.size(),
batch_size,
steps);
for (int idx = 0; idx < steps; idx++) {
std::vector<cv::Mat> batch_imgs;
int left_image_cnt = all_img_paths.size() - idx * batch_size;
......@@ -243,18 +270,19 @@ void PredictImage(const std::vector<std::string> all_img_paths,
left_image_cnt = batch_size;
}
for (int bs = 0; bs < left_image_cnt; bs++) {
std::string image_file_path = all_img_paths.at(idx * batch_size+bs);
std::string image_file_path = all_img_paths.at(idx * batch_size + bs);
cv::Mat im = cv::imread(image_file_path, 1);
batch_imgs.insert(batch_imgs.end(), im);
}
// Store all detected result
std::vector<PaddleDetection::ObjectResult> result;
std::vector<int> bbox_num;
std::vector<double> det_times;
bool is_rbox = false;
if (run_benchmark) {
det->Predict(batch_imgs, threshold, 10, 10, &result, &bbox_num, &det_times);
det->Predict(
batch_imgs, threshold, 10, 10, &result, &bbox_num, &det_times);
} else {
det->Predict(batch_imgs, threshold, 0, 1, &result, &bbox_num, &det_times);
// get labels and colormap
......@@ -274,31 +302,31 @@ void PredictImage(const std::vector<std::string> all_img_paths,
}
detect_num += 1;
im_result.push_back(item);
if (item.rect.size() > 6){
if (item.rect.size() > 6) {
is_rbox = true;
printf("class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n",
item.class_id,
item.confidence,
item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3],
item.rect[4],
item.rect[5],
item.rect[6],
item.rect[7]);
}
else{
item.class_id,
item.confidence,
item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3],
item.rect[4],
item.rect[5],
item.rect[6],
item.rect[7]);
} else {
printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n",
item.class_id,
item.confidence,
item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3]);
item.class_id,
item.confidence,
item.rect[0],
item.rect[1],
item.rect[2],
item.rect[3]);
}
}
std::cout << all_img_paths.at(idx * batch_size + i) << " The number of detected boxes: " << detect_num << std::endl;
std::cout << all_img_paths.at(idx * batch_size + i)
<< " The number of detected boxes: " << detect_num << std::endl;
item_start_idx = item_start_idx + bbox_num[i];
// Visualization result
cv::Mat vis_img = PaddleDetection::VisualizeResult(
......@@ -311,14 +339,16 @@ void PredictImage(const std::vector<std::string> all_img_paths,
output_path += OS_PATH_SEP;
}
std::string image_file_path = all_img_paths.at(idx * batch_size + i);
output_path += image_file_path.substr(image_file_path.find_last_of('/') + 1);
output_path +=
image_file_path.substr(image_file_path.find_last_of('/') + 1);
cv::imwrite(output_path, vis_img, compression_params);
printf("Visualized output saved as %s\n", output_path.c_str());
printf("Visualized output saved as %s\n", output_path.c_str());
}
}
det_t[0] += det_times[0];
det_t[1] += det_times[1];
det_t[2] += det_times[2];
det_times.clear();
}
PrintBenchmarkLog(det_t, all_img_paths.size());
}
......@@ -326,34 +356,48 @@ void PredictImage(const std::vector<std::string> all_img_paths,
int main(int argc, char** argv) {
// Parsing command-line
google::ParseCommandLineFlags(&argc, &argv, true);
if (FLAGS_model_dir.empty()
|| (FLAGS_image_file.empty() && FLAGS_image_dir.empty() && FLAGS_video_file.empty())) {
if (FLAGS_model_dir.empty() ||
(FLAGS_image_file.empty() && FLAGS_image_dir.empty() &&
FLAGS_video_file.empty())) {
std::cout << "Usage: ./main --model_dir=/PATH/TO/INFERENCE_MODEL/ "
<< "--image_file=/PATH/TO/INPUT/IMAGE/" << std::endl;
<< "--image_file=/PATH/TO/INPUT/IMAGE/" << std::endl;
return -1;
}
if (!(FLAGS_run_mode == "fluid" || FLAGS_run_mode == "trt_fp32"
|| FLAGS_run_mode == "trt_fp16" || FLAGS_run_mode == "trt_int8")) {
std::cout << "run_mode should be 'fluid', 'trt_fp32', 'trt_fp16' or 'trt_int8'.";
if (!(FLAGS_run_mode == "paddle" || FLAGS_run_mode == "trt_fp32" ||
FLAGS_run_mode == "trt_fp16" || FLAGS_run_mode == "trt_int8")) {
std::cout
<< "run_mode should be 'paddle', 'trt_fp32', 'trt_fp16' or 'trt_int8'.";
return -1;
}
transform(FLAGS_device.begin(),FLAGS_device.end(),FLAGS_device.begin(),::toupper);
if (!(FLAGS_device == "CPU" || FLAGS_device == "GPU" || FLAGS_device == "XPU")) {
transform(FLAGS_device.begin(),
FLAGS_device.end(),
FLAGS_device.begin(),
::toupper);
if (!(FLAGS_device == "CPU" || FLAGS_device == "GPU" ||
FLAGS_device == "XPU")) {
std::cout << "device should be 'CPU', 'GPU' or 'XPU'.";
return -1;
}
if (FLAGS_use_gpu) {
std::cout << "Deprecated, please use `--device` to set the device you want to run.";
std::cout << "Deprecated, please use `--device` to set the device you want "
"to run.";
return -1;
}
// Load model and create a object detector
PaddleDetection::ObjectDetector det(FLAGS_model_dir, FLAGS_device, FLAGS_use_mkldnn,
FLAGS_cpu_threads, FLAGS_run_mode, FLAGS_batch_size,FLAGS_gpu_id,
FLAGS_trt_min_shape, FLAGS_trt_max_shape, FLAGS_trt_opt_shape,
FLAGS_trt_calib_mode);
PaddleDetection::ObjectDetector det(FLAGS_model_dir,
FLAGS_device,
FLAGS_use_mkldnn,
FLAGS_cpu_threads,
FLAGS_run_mode,
FLAGS_batch_size,
FLAGS_gpu_id,
FLAGS_trt_min_shape,
FLAGS_trt_max_shape,
FLAGS_trt_opt_shape,
FLAGS_trt_calib_mode);
// Do inference on input video or image
if (!PathExists(FLAGS_output_dir)) {
MkDirs(FLAGS_output_dir);
MkDirs(FLAGS_output_dir);
}
if (!FLAGS_video_file.empty() || FLAGS_camera_id != -1) {
PredictVideo(FLAGS_video_file, &det, FLAGS_output_dir);
......@@ -363,17 +407,22 @@ int main(int argc, char** argv) {
if (!FLAGS_image_file.empty()) {
all_img_paths.push_back(FLAGS_image_file);
if (FLAGS_batch_size > 1) {
std::cout << "batch_size should be 1, when set `image_file`." << std::endl;
return -1;
std::cout << "batch_size should be 1, when set `image_file`."
<< std::endl;
return -1;
}
} else {
cv::glob(FLAGS_image_dir, cv_all_img_paths);
for (const auto & img_path : cv_all_img_paths) {
all_img_paths.push_back(img_path);
}
cv::glob(FLAGS_image_dir, cv_all_img_paths);
for (const auto& img_path : cv_all_img_paths) {
all_img_paths.push_back(img_path);
}
}
PredictImage(all_img_paths, FLAGS_batch_size, FLAGS_threshold,
FLAGS_run_benchmark, &det, FLAGS_output_dir);
PredictImage(all_img_paths,
FLAGS_batch_size,
FLAGS_threshold,
FLAGS_run_benchmark,
&det,
FLAGS_output_dir);
}
return 0;
}
......@@ -14,14 +14,14 @@
#include <glog/logging.h>
#include <math.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <algorithm>
#include <iostream>
#include <numeric>
#include <string>
#include <vector>
#include <numeric>
#include <sys/types.h>
#include <sys/stat.h>
#include <math.h>
#include <algorithm>
#ifdef _WIN32
#include <direct.h>
......@@ -31,52 +31,74 @@
#include <sys/stat.h>
#endif
#include "include/object_detector.h"
#include "include/jde_detector.h"
#include <gflags/gflags.h>
#include <opencv2/opencv.hpp>
#include "include/jde_detector.h"
#include "include/object_detector.h"
DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_int32(batch_size, 1, "batch_size");
DEFINE_string(video_file, "", "Path of input video, `video_file` or `camera_id` has the highest priority.");
DEFINE_string(
video_file,
"",
"Path of input video, `video_file` or `camera_id` has the highest priority.");
DEFINE_int32(camera_id, -1, "Device id of camera to predict");
DEFINE_bool(use_gpu, false, "Deprecated, please use `--device` to set the device you want to run.");
DEFINE_string(device, "CPU", "Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU.");
DEFINE_bool(
use_gpu,
false,
"Deprecated, please use `--device` to set the device you want to run.");
DEFINE_string(device,
"CPU",
"Choose the device you want to run, it can be: CPU/GPU/XPU, "
"default is CPU.");
DEFINE_double(threshold, 0.5, "Threshold of score.");
DEFINE_string(output_dir, "output", "Directory of output visualization files.");
DEFINE_string(run_mode, "fluid", "Mode of running(fluid/trt_fp32/trt_fp16/trt_int8)");
DEFINE_string(run_mode,
"paddle",
"Mode of running(paddle/trt_fp32/trt_fp16/trt_int8)");
DEFINE_int32(gpu_id, 0, "Device id of GPU to execute");
DEFINE_bool(run_benchmark, false, "Whether to predict an image_file repeatedly for benchmark");
DEFINE_bool(run_benchmark,
false,
"Whether to predict an image_file repeatedly for benchmark");
DEFINE_bool(use_mkldnn, false, "Whether to use mkldnn with CPU");
DEFINE_int32(cpu_threads, 1, "Num of threads with CPU");
DEFINE_int32(trt_min_shape, 1, "Min shape of TRT DynamicShape");
DEFINE_int32(trt_max_shape, 1280, "Max shape of TRT DynamicShape");
DEFINE_int32(trt_opt_shape, 640, "Opt shape of TRT DynamicShape");
DEFINE_bool(trt_calib_mode, false, "If the model is produced by TRT offline quantitative calibration, trt_calib_mode needs to be set to True");
DEFINE_bool(trt_calib_mode,
false,
"If the model is produced by TRT offline quantitative calibration, "
"trt_calib_mode needs to be set to True");
void PrintBenchmarkLog(std::vector<double> det_time, int img_num){
void PrintBenchmarkLog(std::vector<double> det_time, int img_num) {
LOG(INFO) << "----------------------- Config info -----------------------";
LOG(INFO) << "runtime_device: " << FLAGS_device;
LOG(INFO) << "ir_optim: " << "True";
LOG(INFO) << "enable_memory_optim: " << "True";
LOG(INFO) << "ir_optim: "
<< "True";
LOG(INFO) << "enable_memory_optim: "
<< "True";
auto has_trt = FLAGS_run_mode.find("trt");
if (has_trt != std::string::npos) {
LOG(INFO) << "enable_tensorrt: " << "True";
LOG(INFO) << "enable_tensorrt: "
<< "True";
std::string precision = FLAGS_run_mode.substr(4, 8);
LOG(INFO) << "precision: " << precision;
} else {
LOG(INFO) << "enable_tensorrt: " << "False";
LOG(INFO) << "precision: " << "fp32";
LOG(INFO) << "enable_tensorrt: "
<< "False";
LOG(INFO) << "precision: "
<< "fp32";
}
LOG(INFO) << "enable_mkldnn: " << (FLAGS_use_mkldnn ? "True" : "False");
LOG(INFO) << "cpu_math_library_num_threads: " << FLAGS_cpu_threads;
LOG(INFO) << "----------------------- Data info -----------------------";
LOG(INFO) << "batch_size: " << FLAGS_batch_size;
LOG(INFO) << "input_shape: " << "dynamic shape";
LOG(INFO) << "input_shape: "
<< "dynamic shape";
LOG(INFO) << "----------------------- Model info -----------------------";
FLAGS_model_dir.erase(FLAGS_model_dir.find_last_not_of("/") + 1);
LOG(INFO) << "model_name: " << FLAGS_model_dir.substr(FLAGS_model_dir.find_last_of('/') + 1);
LOG(INFO) << "model_name: "
<< FLAGS_model_dir.substr(FLAGS_model_dir.find_last_of('/') + 1);
LOG(INFO) << "----------------------- Perf info ------------------------";
LOG(INFO) << "Total number of predicted data: " << img_num
<< " and total time spent(ms): "
......@@ -86,7 +108,7 @@ void PrintBenchmarkLog(std::vector<double> det_time, int img_num){
<< ", postprocess_time(ms): " << det_time[2] / img_num;
}
static std::string DirName(const std::string &filepath) {
static std::string DirName(const std::string& filepath) {
auto pos = filepath.rfind(OS_PATH_SEP);
if (pos == std::string::npos) {
return "";
......@@ -94,7 +116,7 @@ static std::string DirName(const std::string &filepath) {
return filepath.substr(0, pos);
}
static bool PathExists(const std::string& path){
static bool PathExists(const std::string& path) {
#ifdef _WIN32
struct _stat buffer;
return (_stat(path.c_str(), &buffer) == 0);
......@@ -133,11 +155,12 @@ void PredictVideo(const std::string& video_path,
// Open video
cv::VideoCapture capture;
std::string video_out_name = "output.mp4";
if (FLAGS_camera_id != -1){
if (FLAGS_camera_id != -1) {
capture.open(FLAGS_camera_id);
}else{
} else {
capture.open(video_path.c_str());
video_out_name = video_path.substr(video_path.find_last_of(OS_PATH_SEP) + 1);
video_out_name =
video_path.substr(video_path.find_last_of(OS_PATH_SEP) + 1);
}
if (!capture.isOpened()) {
printf("can not open video : %s\n", video_path.c_str());
......@@ -148,7 +171,8 @@ void PredictVideo(const std::string& video_path,
int video_width = static_cast<int>(capture.get(CV_CAP_PROP_FRAME_WIDTH));
int video_height = static_cast<int>(capture.get(CV_CAP_PROP_FRAME_HEIGHT));
int video_fps = static_cast<int>(capture.get(CV_CAP_PROP_FPS));
int video_frame_count = static_cast<int>(capture.get(CV_CAP_PROP_FRAME_COUNT));
int video_frame_count =
static_cast<int>(capture.get(CV_CAP_PROP_FRAME_COUNT));
printf("fps: %d, frame_count: %d\n", video_fps, video_frame_count);
// Create VideoWriter for output
......@@ -186,47 +210,59 @@ void PredictVideo(const std::string& video_path,
times = std::accumulate(det_times.begin(), det_times.end(), 0.) / frame_id;
cv::Mat out_im = PaddleDetection::VisualizeTrackResult(
frame, result, 1000./times, frame_id);
frame, result, 1000. / times, frame_id);
video_out.write(out_im);
}
capture.release();
video_out.release();
PrintBenchmarkLog(det_times, frame_id);
printf("Visualized output saved as %s\n", video_out_path.c_str());
printf("Visualized output saved as %s\n", video_out_path.c_str());
}
int main(int argc, char** argv) {
// Parsing command-line
google::ParseCommandLineFlags(&argc, &argv, true);
if (FLAGS_model_dir.empty()
|| FLAGS_video_file.empty()) {
if (FLAGS_model_dir.empty() || FLAGS_video_file.empty()) {
std::cout << "Usage: ./main --model_dir=/PATH/TO/INFERENCE_MODEL/ "
<< "--video_file=/PATH/TO/INPUT/VIDEO/" << std::endl;
<< "--video_file=/PATH/TO/INPUT/VIDEO/" << std::endl;
return -1;
}
if (!(FLAGS_run_mode == "fluid" || FLAGS_run_mode == "trt_fp32"
|| FLAGS_run_mode == "trt_fp16" || FLAGS_run_mode == "trt_int8")) {
std::cout << "run_mode should be 'fluid', 'trt_fp32', 'trt_fp16' or 'trt_int8'.";
if (!(FLAGS_run_mode == "paddle" || FLAGS_run_mode == "trt_fp32" ||
FLAGS_run_mode == "trt_fp16" || FLAGS_run_mode == "trt_int8")) {
std::cout
<< "run_mode should be 'paddle', 'trt_fp32', 'trt_fp16' or 'trt_int8'.";
return -1;
}
transform(FLAGS_device.begin(),FLAGS_device.end(),FLAGS_device.begin(),::toupper);
if (!(FLAGS_device == "CPU" || FLAGS_device == "GPU" || FLAGS_device == "XPU")) {
transform(FLAGS_device.begin(),
FLAGS_device.end(),
FLAGS_device.begin(),
::toupper);
if (!(FLAGS_device == "CPU" || FLAGS_device == "GPU" ||
FLAGS_device == "XPU")) {
std::cout << "device should be 'CPU', 'GPU' or 'XPU'.";
return -1;
}
if (FLAGS_use_gpu) {
std::cout << "Deprecated, please use `--device` to set the device you want to run.";
std::cout << "Deprecated, please use `--device` to set the device you want "
"to run.";
return -1;
}
// Do inference on input video or image
PaddleDetection::JDEDetector mot(FLAGS_model_dir, FLAGS_device, FLAGS_use_mkldnn,
FLAGS_cpu_threads, FLAGS_run_mode, FLAGS_batch_size,FLAGS_gpu_id,
FLAGS_trt_min_shape, FLAGS_trt_max_shape, FLAGS_trt_opt_shape,
FLAGS_trt_calib_mode);
PaddleDetection::JDEDetector mot(FLAGS_model_dir,
FLAGS_device,
FLAGS_use_mkldnn,
FLAGS_cpu_threads,
FLAGS_run_mode,
FLAGS_batch_size,
FLAGS_gpu_id,
FLAGS_trt_min_shape,
FLAGS_trt_max_shape,
FLAGS_trt_opt_shape,
FLAGS_trt_calib_mode);
if (!PathExists(FLAGS_output_dir)) {
MkDirs(FLAGS_output_dir);
MkDirs(FLAGS_output_dir);
}
PredictVideo(FLAGS_video_file, &mot, FLAGS_output_dir);
return 0;
......
......@@ -62,8 +62,8 @@ DEFINE_double(threshold, 0.5, "Threshold of score.");
DEFINE_double(threshold_keypoint, 0.5, "Threshold of score.");
DEFINE_string(output_dir, "output", "Directory of output visualization files.");
DEFINE_string(run_mode,
"fluid",
"Mode of running(fluid/trt_fp32/trt_fp16/trt_int8)");
"paddle",
"Mode of running(paddle/trt_fp32/trt_fp16/trt_int8)");
DEFINE_int32(gpu_id, 0, "Device id of GPU to execute");
DEFINE_bool(run_benchmark,
false,
......@@ -505,10 +505,10 @@ int main(int argc, char** argv) {
<< "--image_file=/PATH/TO/INPUT/IMAGE/" << std::endl;
return -1;
}
if (!(FLAGS_run_mode == "fluid" || FLAGS_run_mode == "trt_fp32" ||
if (!(FLAGS_run_mode == "paddle" || FLAGS_run_mode == "trt_fp32" ||
FLAGS_run_mode == "trt_fp16" || FLAGS_run_mode == "trt_int8")) {
std::cout
<< "run_mode should be 'fluid', 'trt_fp32', 'trt_fp16' or 'trt_int8'.";
<< "run_mode should be 'paddle', 'trt_fp32', 'trt_fp16' or 'trt_int8'.";
return -1;
}
transform(FLAGS_device.begin(),
......
......@@ -13,8 +13,8 @@
// limitations under the License.
#include <sstream>
// for setprecision
#include <iomanip>
#include <chrono>
#include <iomanip>
#include "include/object_detector.h"
using namespace paddle_infer;
......@@ -33,47 +33,51 @@ void ObjectDetector::LoadModel(const std::string& model_dir,
config.EnableUseGpu(200, this->gpu_id_);
config.SwitchIrOptim(true);
// use tensorrt
if (run_mode != "fluid") {
if (run_mode != "paddle") {
auto precision = paddle_infer::Config::Precision::kFloat32;
if (run_mode == "trt_fp32") {
precision = paddle_infer::Config::Precision::kFloat32;
}
else if (run_mode == "trt_fp16") {
} else if (run_mode == "trt_fp16") {
precision = paddle_infer::Config::Precision::kHalf;
}
else if (run_mode == "trt_int8") {
} else if (run_mode == "trt_int8") {
precision = paddle_infer::Config::Precision::kInt8;
} else {
printf("run_mode should be 'fluid', 'trt_fp32', 'trt_fp16' or 'trt_int8'");
printf(
"run_mode should be 'paddle', 'trt_fp32', 'trt_fp16' or "
"'trt_int8'");
}
// set tensorrt
config.EnableTensorRtEngine(
1 << 30,
batch_size,
this->min_subgraph_size_,
precision,
false,
this->trt_calib_mode_);
config.EnableTensorRtEngine(1 << 30,
batch_size,
this->min_subgraph_size_,
precision,
false,
this->trt_calib_mode_);
// set use dynamic shape
if (this->use_dynamic_shape_) {
// set DynamicShape for image tensor
const std::vector<int> min_input_shape = {1, 3, this->trt_min_shape_, this->trt_min_shape_};
const std::vector<int> max_input_shape = {1, 3, this->trt_max_shape_, this->trt_max_shape_};
const std::vector<int> opt_input_shape = {1, 3, this->trt_opt_shape_, this->trt_opt_shape_};
const std::map<std::string, std::vector<int>> map_min_input_shape = {{"image", min_input_shape}};
const std::map<std::string, std::vector<int>> map_max_input_shape = {{"image", max_input_shape}};
const std::map<std::string, std::vector<int>> map_opt_input_shape = {{"image", opt_input_shape}};
const std::vector<int> min_input_shape = {
1, 3, this->trt_min_shape_, this->trt_min_shape_};
const std::vector<int> max_input_shape = {
1, 3, this->trt_max_shape_, this->trt_max_shape_};
const std::vector<int> opt_input_shape = {
1, 3, this->trt_opt_shape_, this->trt_opt_shape_};
const std::map<std::string, std::vector<int>> map_min_input_shape = {
{"image", min_input_shape}};
const std::map<std::string, std::vector<int>> map_max_input_shape = {
{"image", max_input_shape}};
const std::map<std::string, std::vector<int>> map_opt_input_shape = {
{"image", opt_input_shape}};
config.SetTRTDynamicShapeInfo(map_min_input_shape,
map_max_input_shape,
map_opt_input_shape);
config.SetTRTDynamicShapeInfo(
map_min_input_shape, map_max_input_shape, map_opt_input_shape);
std::cout << "TensorRT dynamic shape enabled" << std::endl;
}
}
} else if (this->device_ == "XPU"){
config.EnableXpu(10*1024*1024);
} else if (this->device_ == "XPU") {
config.EnableXpu(10 * 1024 * 1024);
} else {
config.DisableGpu();
if (this->use_mkldnn_) {
......@@ -92,11 +96,12 @@ void ObjectDetector::LoadModel(const std::string& model_dir,
}
// Visualization MaskDetector results
cv::Mat VisualizeResult(const cv::Mat& img,
const std::vector<PaddleDetection::ObjectResult>& results,
const std::vector<std::string>& lables,
const std::vector<int>& colormap,
const bool is_rbox=false) {
cv::Mat VisualizeResult(
const cv::Mat& img,
const std::vector<PaddleDetection::ObjectResult>& results,
const std::vector<std::string>& lables,
const std::vector<int>& colormap,
const bool is_rbox = false) {
cv::Mat vis_img = img.clone();
for (int i = 0; i < results.size(); ++i) {
// Configure color and text size
......@@ -112,32 +117,25 @@ cv::Mat VisualizeResult(const cv::Mat& img,
int font_face = cv::FONT_HERSHEY_COMPLEX_SMALL;
double font_scale = 0.5f;
float thickness = 0.5;
cv::Size text_size = cv::getTextSize(text,
font_face,
font_scale,
thickness,
nullptr);
cv::Size text_size =
cv::getTextSize(text, font_face, font_scale, thickness, nullptr);
cv::Point origin;
if (is_rbox)
{
// Draw object, text, and background
for (int k = 0; k < 4; k++)
{
cv::Point pt1 = cv::Point(results[i].rect[(k * 2) % 8],
results[i].rect[(k * 2 + 1) % 8]);
cv::Point pt2 = cv::Point(results[i].rect[(k * 2 + 2) % 8],
results[i].rect[(k * 2 + 3) % 8]);
cv::line(vis_img, pt1, pt2, roi_color, 2);
}
}
else
{
int w = results[i].rect[2] - results[i].rect[0];
int h = results[i].rect[3] - results[i].rect[1];
cv::Rect roi = cv::Rect(results[i].rect[0], results[i].rect[1], w, h);
// Draw roi object, text, and background
cv::rectangle(vis_img, roi, roi_color, 2);
if (is_rbox) {
// Draw object, text, and background
for (int k = 0; k < 4; k++) {
cv::Point pt1 = cv::Point(results[i].rect[(k * 2) % 8],
results[i].rect[(k * 2 + 1) % 8]);
cv::Point pt2 = cv::Point(results[i].rect[(k * 2 + 2) % 8],
results[i].rect[(k * 2 + 3) % 8]);
cv::line(vis_img, pt1, pt2, roi_color, 2);
}
} else {
int w = results[i].rect[2] - results[i].rect[0];
int h = results[i].rect[3] - results[i].rect[1];
cv::Rect roi = cv::Rect(results[i].rect[0], results[i].rect[1], w, h);
// Draw roi object, text, and background
cv::rectangle(vis_img, roi, roi_color, 2);
}
origin.x = results[i].rect[0];
......@@ -173,7 +171,7 @@ void ObjectDetector::Postprocess(
std::vector<PaddleDetection::ObjectResult>* result,
std::vector<int> bbox_num,
std::vector<float> output_data_,
bool is_rbox=false) {
bool is_rbox = false) {
result->clear();
int start_idx = 0;
for (int im_id = 0; im_id < mats.size(); im_id++) {
......@@ -184,7 +182,7 @@ void ObjectDetector::Postprocess(
rh = raw_mat.rows;
rw = raw_mat.cols;
}
for (int j = start_idx; j < start_idx+bbox_num[im_id]; j++) {
for (int j = start_idx; j < start_idx + bbox_num[im_id]; j++) {
if (is_rbox) {
// Class id
int class_id = static_cast<int>(round(output_data_[0 + j * 10]));
......@@ -198,14 +196,13 @@ void ObjectDetector::Postprocess(
int y3 = (output_data_[7 + j * 10] * rh);
int x4 = (output_data_[8 + j * 10] * rw);
int y4 = (output_data_[9 + j * 10] * rh);
PaddleDetection::ObjectResult result_item;
result_item.rect = {x1, y1, x2, y2, x3, y3, x4, y4};
result_item.class_id = class_id;
result_item.confidence = score;
result->push_back(result_item);
}
else {
} else {
// Class id
int class_id = static_cast<int>(round(output_data_[0 + j * 6]));
// Confidence score
......@@ -216,7 +213,7 @@ void ObjectDetector::Postprocess(
int ymax = (output_data_[5 + j * 6] * rh);
int wd = xmax - xmin;
int hd = ymax - ymin;
PaddleDetection::ObjectResult result_item;
result_item.rect = {xmin, ymin, xmax, ymax};
result_item.class_id = class_id;
......@@ -229,12 +226,12 @@ void ObjectDetector::Postprocess(
}
void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
const double threshold,
const int warmup,
const int repeats,
std::vector<PaddleDetection::ObjectResult>* result,
std::vector<int>* bbox_num,
std::vector<double>* times) {
const double threshold,
const int warmup,
const int repeats,
std::vector<PaddleDetection::ObjectResult>* result,
std::vector<int>* bbox_num,
std::vector<double>* times) {
auto preprocess_start = std::chrono::steady_clock::now();
int batch_size = imgs.size();
......@@ -242,9 +239,12 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
std::vector<float> in_data_all;
std::vector<float> im_shape_all(batch_size * 2);
std::vector<float> scale_factor_all(batch_size * 2);
std::vector<const float *> output_data_list_;
std::vector<const float*> output_data_list_;
std::vector<int> out_bbox_num_data_;
// in_net img for each batch
std::vector<cv::Mat> in_net_img_all(batch_size);
// Preprocess image
for (int bs_idx = 0; bs_idx < batch_size; bs_idx++) {
cv::Mat im = imgs.at(bs_idx);
......@@ -256,11 +256,39 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
scale_factor_all[bs_idx * 2 + 1] = inputs_.scale_factor_[1];
// TODO: reduce time cost
in_data_all.insert(in_data_all.end(), inputs_.im_data_.begin(), inputs_.im_data_.end());
in_data_all.insert(
in_data_all.end(), inputs_.im_data_.begin(), inputs_.im_data_.end());
// collect in_net img
in_net_img_all[bs_idx] = inputs_.in_net_im_;
}
// Pad Batch if batch size > 1
if (batch_size > 1 && CheckDynamicInput(in_net_img_all)) {
in_data_all.clear();
std::vector<cv::Mat> pad_img_all = PadBatch(in_net_img_all);
int rh = pad_img_all[0].rows;
int rw = pad_img_all[0].cols;
int rc = pad_img_all[0].channels();
for (int bs_idx = 0; bs_idx < batch_size; bs_idx++) {
cv::Mat pad_img = pad_img_all[bs_idx];
pad_img.convertTo(pad_img, CV_32FC3);
std::vector<float> pad_data;
pad_data.resize(rc * rh * rw);
float* base = pad_data.data();
for (int i = 0; i < rc; ++i) {
cv::extractChannel(
pad_img, cv::Mat(rh, rw, CV_32FC1, base + i * rh * rw), i);
}
in_data_all.insert(in_data_all.end(), pad_data.begin(), pad_data.end());
}
// update in_net_shape
inputs_.in_net_shape_ = {static_cast<float>(rh), static_cast<float>(rw)};
}
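The channel-extraction loop above writes each padded image into the input buffer in CHW order; a self-contained sketch of the same cv::extractChannel trick, assuming a 3-channel input:

#include <vector>
#include <opencv2/core/core.hpp>

std::vector<float> ToCHW(const cv::Mat& img) {
  cv::Mat f;
  img.convertTo(f, CV_32FC3);
  const int h = f.rows, w = f.cols, c = f.channels();
  std::vector<float> chw(c * h * w);
  for (int i = 0; i < c; ++i) {
    // Each destination plane aliases chw.data() + i*h*w, so extractChannel
    // writes channel i directly into the packed CHW buffer, no extra copy.
    cv::extractChannel(f, cv::Mat(h, w, CV_32FC1, chw.data() + i * h * w), i);
  }
  return chw;
}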
auto preprocess_end = std::chrono::steady_clock::now();
// Prepare input tensor
auto input_names = predictor_->GetInputNames();
for (const auto& tensor_name : input_names) {
auto in_tensor = predictor_->GetInputHandle(tensor_name);
......@@ -277,7 +305,7 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
in_tensor->CopyFromCpu(scale_factor_all.data());
}
}
// Run predictor
std::vector<std::vector<float>> out_tensor_list;
std::vector<std::vector<int>> output_shape_list;
......@@ -292,8 +320,8 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
for (int j = 0; j < output_names.size(); j++) {
auto output_tensor = predictor_->GetOutputHandle(output_names[j]);
std::vector<int> output_shape = output_tensor->shape();
int out_num = std::accumulate(
output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
if (output_tensor->type() == paddle_infer::DataType::INT32) {
out_bbox_num_data_.resize(out_num);
output_tensor->CopyToCpu(out_bbox_num_data_.data());
......@@ -316,8 +344,8 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
for (int j = 0; j < output_names.size(); j++) {
auto output_tensor = predictor_->GetOutputHandle(output_names[j]);
std::vector<int> output_shape = output_tensor->shape();
int out_num = std::accumulate(
output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
output_shape_list.push_back(output_shape);
if (output_tensor->type() == paddle_infer::DataType::INT32) {
out_bbox_num_data_.resize(out_num);
......@@ -343,35 +371,43 @@ void ObjectDetector::Predict(const std::vector<cv::Mat> imgs,
if (i == config_.fpn_stride_.size()) {
reg_max = output_shape_list[i][2] / 4 - 1;
}
float* buffer = new float[out_tensor_list[i].size()];
memcpy(buffer,
&out_tensor_list[i][0],
out_tensor_list[i].size() * sizeof(float));
output_data_list_.push_back(buffer);
}
PaddleDetection::PicoDetPostProcess(
result,
output_data_list_,
config_.fpn_stride_,
inputs_.im_shape_,
inputs_.scale_factor_,
config_.nms_info_["score_threshold"].as<float>(),
config_.nms_info_["nms_threshold"].as<float>(),
num_class,
reg_max);
bbox_num->push_back(result->size());
} else {
is_rbox = output_shape_list[0][output_shape_list[0].size() - 1] % 10 == 0;
Postprocess(imgs, result, out_bbox_num_data_, out_tensor_list[0], is_rbox);
for (int k = 0; k < out_bbox_num_data_.size(); k++) {
int tmp = out_bbox_num_data_[k];
bbox_num->push_back(tmp);
}
}
auto postprocess_end = std::chrono::steady_clock::now();
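// times = {preprocess_ms, inference_ms averaged over repeats, postprocess_ms}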
std::chrono::duration<float> preprocess_diff =
preprocess_end - preprocess_start;
times->push_back(static_cast<double>(preprocess_diff.count() * 1000));
std::chrono::duration<float> inference_diff = inference_end - inference_start;
times->push_back(
static_cast<double>(inference_diff.count() / repeats * 1000));
std::chrono::duration<float> postprocess_diff =
postprocess_end - postprocess_start;
times->push_back(static_cast<double>(postprocess_diff.count() * 1000));
}
std::vector<int> GenerateColorMap(int num_class) {
......
......@@ -12,24 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <thread>
#include <vector>
#include "include/preprocess_op.h"
namespace PaddleDetection {
void InitInfo::Run(cv::Mat* im, ImageBlob* data) {
data->im_shape_ = {static_cast<float>(im->rows),
static_cast<float>(im->cols)};
data->scale_factor_ = {1., 1.};
data->in_net_shape_ = {static_cast<float>(im->rows),
static_cast<float>(im->cols)};
}
void NormalizeImage::Run(cv::Mat* im, ImageBlob* data) {
......@@ -41,11 +37,11 @@ void NormalizeImage::Run(cv::Mat* im, ImageBlob* data) {
for (int h = 0; h < im->rows; h++) {
for (int w = 0; w < im->cols; w++) {
im->at<cv::Vec3f>(h, w)[0] =
(im->at<cv::Vec3f>(h, w)[0] - mean_[0]) / scale_[0];
im->at<cv::Vec3f>(h, w)[1] =
(im->at<cv::Vec3f>(h, w)[1] - mean_[1]) / scale_[1];
im->at<cv::Vec3f>(h, w)[2] =
(im->at<cv::Vec3f>(h, w)[2] - mean_[2]) / scale_[2];
}
}
}
......@@ -64,27 +60,20 @@ void Permute::Run(cv::Mat* im, ImageBlob* data) {
void Resize::Run(cv::Mat* im, ImageBlob* data) {
auto resize_scale = GenerateScale(*im);
data->im_shape_ = {static_cast<float>(im->cols * resize_scale.first),
static_cast<float>(im->rows * resize_scale.second)};
data->in_net_shape_ = {static_cast<float>(im->cols * resize_scale.first),
static_cast<float>(im->rows * resize_scale.second)};
cv::resize(
*im, *im, cv::Size(), resize_scale.first, resize_scale.second, interp_);
data->im_shape_ = {
static_cast<float>(im->rows), static_cast<float>(im->cols),
};
data->scale_factor_ = {
resize_scale.second, resize_scale.first,
};
}
std::pair<float, float> Resize::GenerateScale(const cv::Mat& im) {
std::pair<float, float> resize_scale;
int origin_w = im.cols;
......@@ -93,8 +82,10 @@ std::pair<float, float> Resize::GenerateScale(const cv::Mat& im) {
if (keep_ratio_) {
int im_size_max = std::max(origin_w, origin_h);
int im_size_min = std::min(origin_w, origin_h);
int target_size_max =
*std::max_element(target_size_.begin(), target_size_.end());
int target_size_min =
*std::min_element(target_size_.begin(), target_size_.end());
float scale_min =
static_cast<float>(target_size_min) / static_cast<float>(im_size_min);
float scale_max =
......@@ -114,46 +105,38 @@ void LetterBoxResize::Run(cv::Mat* im, ImageBlob* data) {
float resize_scale = GenerateScale(*im);
int new_shape_w = std::round(im->cols * resize_scale);
int new_shape_h = std::round(im->rows * resize_scale);
data->im_shape_ = {static_cast<float>(new_shape_h),
static_cast<float>(new_shape_w)};
float padw = (target_size_[1] - new_shape_w) / 2.;
float padh = (target_size_[0] - new_shape_h) / 2.;
int top = std::round(padh - 0.1);
int bottom = std::round(padh + 0.1);
int left = std::round(padw - 0.1);
int right = std::round(padw + 0.1);
cv::resize(
*im, *im, cv::Size(new_shape_w, new_shape_h), 0, 0, cv::INTER_AREA);
data->in_net_shape_ = {
static_cast<float>(im->rows), static_cast<float>(im->cols),
};
cv::copyMakeBorder(*im,
*im,
top,
bottom,
left,
right,
cv::BORDER_CONSTANT,
cv::Scalar(127.5));
data->in_net_shape_ = {
static_cast<float>(im->rows), static_cast<float>(im->cols),
};
data->scale_factor_ = {
resize_scale, resize_scale,
};
}
float LetterBoxResize::GenerateScale(const cv::Mat& im) {
......@@ -165,7 +148,7 @@ float LetterBoxResize::GenerateScale(const cv::Mat& im) {
float ratio_h = static_cast<float>(target_h) / static_cast<float>(origin_h);
float ratio_w = static_cast<float>(target_w) / static_cast<float>(origin_w);
float resize_scale = std::min(ratio_h, ratio_w);
return resize_scale;
}
......@@ -179,34 +162,29 @@ void PadStride::Run(cv::Mat* im, ImageBlob* data) {
int nh = (rh / stride_) * stride_ + (rh % stride_ != 0) * stride_;
int nw = (rw / stride_) * stride_ + (rw % stride_ != 0) * stride_;
cv::copyMakeBorder(
*im, *im, 0, nh - rh, 0, nw - rw, cv::BORDER_CONSTANT, cv::Scalar(0));
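// Cache the stride-padded image; Predict() collects it per sample to detect
// dynamic shapes and, if needed, pad the whole batch to one size.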
data->in_net_im_ = im->clone();
data->in_net_shape_ = {
static_cast<float>(im->rows), static_cast<float>(im->cols),
};
}
void TopDownEvalAffine::Run(cv::Mat* im, ImageBlob* data) {
cv::resize(*im, *im, cv::Size(trainsize_[0], trainsize_[1]), 0, 0, interp_);
// todo: Simd::ResizeBilinear();
data->in_net_shape_ = {
static_cast<float>(trainsize_[1]), static_cast<float>(trainsize_[0]),
};
}
// Preprocessor op running order
const std::vector<std::string> Preprocessor::RUN_ORDER = {"InitInfo",
"TopDownEvalAffine",
"Resize",
"LetterBoxResize",
"NormalizeImage",
"PadStride",
"Permute"};
void Preprocessor::Run(cv::Mat* im, ImageBlob* data) {
for (const auto& name : RUN_ORDER) {
......@@ -216,37 +194,87 @@ void Preprocessor::Run(cv::Mat* im, ImageBlob* data) {
}
}
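// Crop `img` around `area`, expanding the shorter edge toward a 3:4 (w:h)
// aspect ratio with an extra `expandratio` margin; also returns the crop
// center and scale used by the downstream top-down keypoint transform.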
void CropImg(cv::Mat& img,
cv::Mat& crop_img,
std::vector<int>& area,
std::vector<float>& center,
std::vector<float>& scale,
float expandratio) {
int crop_x1 = std::max(0, area[0]);
int crop_y1 = std::max(0, area[1]);
int crop_x2 = std::min(img.cols - 1, area[2]);
int crop_y2 = std::min(img.rows - 1, area[3]);
int center_x = (crop_x1 + crop_x2) / 2.;
int center_y = (crop_y1 + crop_y2) / 2.;
int half_h = (crop_y2 - crop_y1) / 2.;
int half_w = (crop_x2 - crop_x1) / 2.;
// adjust h or w to keep image ratio, expand the shorter edge
if (half_h * 3 > half_w * 4) {
half_w = static_cast<int>(half_h * 0.75);
} else {
half_h = static_cast<int>(half_w * 4 / 3);
}
crop_x1 =
std::max(0, center_x - static_cast<int>(half_w * (1 + expandratio)));
crop_y1 =
std::max(0, center_y - static_cast<int>(half_h * (1 + expandratio)));
crop_x2 = std::min(img.cols - 1,
static_cast<int>(center_x + half_w * (1 + expandratio)));
crop_y2 = std::min(img.rows - 1,
static_cast<int>(center_y + half_h * (1 + expandratio)));
crop_img =
img(cv::Range(crop_y1, crop_y2 + 1), cv::Range(crop_x1, crop_x2 + 1));
center.clear();
center.emplace_back((crop_x1 + crop_x2) / 2);
center.emplace_back((crop_y1 + crop_y2) / 2);
scale.clear();
scale.emplace_back((crop_x2 - crop_x1));
scale.emplace_back((crop_y2 - crop_y1));
}
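// Return true when images in the batch differ in height or width, i.e. the
// model runs with dynamic input shapes and the batch needs padding before
// stacking into one tensor.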
bool CheckDynamicInput(const std::vector<cv::Mat>& imgs) {
if (imgs.size() == 1) return false;
int h = imgs.at(0).rows;
int w = imgs.at(0).cols;
for (int i = 1; i < imgs.size(); ++i) {
if (imgs.at(i).rows != h || imgs.at(i).cols != w) {
return true;
}
}
return false;
}
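// Zero-pad every image on the bottom/right to the max height/width found in
// the batch, so all samples share one in-net shape.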
std::vector<cv::Mat> PadBatch(const std::vector<cv::Mat>& imgs) {
std::vector<cv::Mat> out_imgs;
int max_h = 0;
int max_w = 0;
int rh = 0;
int rw = 0;
// find max_h and max_w in batch
for (int i = 0; i < imgs.size(); ++i) {
rh = imgs.at(i).rows;
rw = imgs.at(i).cols;
if (rh > max_h) max_h = rh;
if (rw > max_w) max_w = rw;
}
for (int i = 0; i < imgs.size(); ++i) {
cv::Mat im = imgs.at(i);
cv::copyMakeBorder(im,
im,
0,
max_h - imgs.at(i).rows,
0,
max_w - imgs.at(i).cols,
cv::BORDER_CONSTANT,
cv::Scalar(0));
out_imgs.push_back(im);
}
return out_imgs;
}
} // namespace PaddleDetection
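For orientation, here is a minimal sketch — not part of this commit — showing how the two helpers above compose, mirroring the batch-padding branch in `ObjectDetector::Predict`. It assumes both functions are declared in `include/preprocess_op.h`:

```cpp
#include <vector>

#include <opencv2/core/core.hpp>

#include "include/preprocess_op.h"

// Pad a batch of in-net images only when their shapes actually differ.
std::vector<cv::Mat> PrepareBatch(const std::vector<cv::Mat>& in_net_imgs) {
  if (in_net_imgs.size() > 1 &&
      PaddleDetection::CheckDynamicInput(in_net_imgs)) {
    // Every image comes back zero-padded to the batch-max height/width.
    return PaddleDetection::PadBatch(in_net_imgs);
  }
  return in_net_imgs;  // shapes already uniform; nothing to do
}
```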
......@@ -112,7 +112,7 @@ python tools/export_model.py -c configs/mot/fairmot/fairmot_hrnetv2_w18_dlafpn_3
| --video_file | Path of the video file to predict |
| --device | Device to run on, choose from `CPU/GPU/XPU`, default is `CPU` |
| --gpu_id | GPU device id used for inference (default: 0) |
| --run_mode | Run mode when using GPU, default is paddle, options: (paddle/trt_fp32/trt_fp16/trt_int8) |
| --output_dir | Folder for output images, default is output |
| --use_mkldnn | Whether to enable MKLDNN acceleration for CPU inference |
| --cpu_threads | Number of CPU threads, default is 1 |
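As a usage sketch (the executable name `./main` and the model path are assumptions, not taken from this commit), the flags above combine like:

```
./main --model_dir=./output_inference/your_fairmot_model \
    --video_file=test.mp4 --device=GPU --run_mode=trt_fp16 --gpu_id=0
```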
......
......@@ -42,12 +42,12 @@ class ConfigPaser {
YAML::Node config;
config = YAML::LoadFile(model_dir + OS_PATH_SEP + cfg);
// Get runtime mode : paddle, trt_fp16, trt_fp32
if (config["mode"].IsDefined()) {
mode_ = config["mode"].as<std::string>();
} else {
std::cerr << "Please set mode, "
<< "support value : fluid/trt_fp16/trt_fp32." << std::endl;
<< "support value : paddle/trt_fp16/trt_fp32." << std::endl;
return false;
}
......
......@@ -39,7 +39,7 @@ class JDEPredictor {
explicit JDEPredictor(const std::string& device = "CPU",
const std::string& model_dir = "",
const double threshold = -1.,
const std::string& run_mode = "fluid",
const std::string& run_mode = "paddle",
const int gpu_id = 0,
const bool use_mkldnn = false,
const int cpu_threads = 1,
......@@ -61,7 +61,7 @@ class JDEPredictor {
// Load Paddle inference model
void LoadModel(const std::string& model_dir,
const std::string& run_mode = "fluid");
const std::string& run_mode = "paddle");
// Run predictor
void Predict(const std::vector<cv::Mat> imgs,
......
......@@ -43,7 +43,7 @@ class Pipeline {
explicit Pipeline(const std::string& device,
const double threshold,
const std::string& output_dir,
const std::string& run_mode = "fluid",
const std::string& run_mode = "paddle",
const int gpu_id = 0,
const bool use_mkldnn = false,
const int cpu_threads = 1,
......@@ -127,7 +127,7 @@ class Pipeline {
std::string track_model_dir_;
std::string det_model_dir_;
std::string reid_model_dir_;
std::string run_mode_ = "fluid";
std::string run_mode_ = "paddle";
int gpu_id_ = 0;
bool use_mkldnn_ = false;
int cpu_threads_ = 1;
......
......@@ -42,7 +42,7 @@ class Predictor {
const std::string& det_model_dir = "",
const std::string& reid_model_dir = "",
const double threshold = -1.,
const std::string& run_mode = "fluid",
const std::string& run_mode = "paddle",
const int gpu_id = 0,
const bool use_mkldnn = false,
const int cpu_threads = 1,
......
......@@ -40,7 +40,7 @@ class SDEPredictor {
const std::string& det_model_dir = "",
const std::string& reid_model_dir = "",
const double threshold = -1.,
const std::string& run_mode = "fluid",
const std::string& run_mode = "paddle",
const int gpu_id = 0,
const bool use_mkldnn = false,
const int cpu_threads = 1,
......@@ -67,7 +67,7 @@ class SDEPredictor {
// Load Paddle inference model
void LoadModel(const std::string& det_model_dir,
const std::string& reid_model_dir,
const std::string& run_mode = "fluid");
const std::string& run_mode = "paddle");
// Run predictor
void Predict(const std::vector<cv::Mat> imgs,
......
......@@ -44,8 +44,8 @@ DEFINE_string(device,
DEFINE_double(threshold, 0.5, "Threshold of score.");
DEFINE_string(output_dir, "output", "Directory of output visualization files.");
DEFINE_string(run_mode,
"fluid",
"Mode of running(fluid/trt_fp32/trt_fp16/trt_int8)");
"paddle",
"Mode of running(paddle/trt_fp32/trt_fp16/trt_int8)");
DEFINE_int32(gpu_id, 0, "Device id of GPU to execute");
DEFINE_bool(use_mkldnn, false, "Whether use mkldnn with CPU");
DEFINE_int32(cpu_threads, 1, "Num of threads with CPU");
......@@ -125,10 +125,10 @@ int main(int argc, char** argv) {
return -1;
}
if (!(FLAGS_run_mode == "fluid" || FLAGS_run_mode == "trt_fp32" ||
if (!(FLAGS_run_mode == "paddle" || FLAGS_run_mode == "trt_fp32" ||
FLAGS_run_mode == "trt_fp16" || FLAGS_run_mode == "trt_int8")) {
LOG(ERROR)
<< "run_mode should be 'fluid', 'trt_fp32', 'trt_fp16' or 'trt_int8'.";
<< "run_mode should be 'paddle', 'trt_fp32', 'trt_fp16' or 'trt_int8'.";
return -1;
}
transform(FLAGS_device.begin(),
......
......@@ -206,7 +206,7 @@ void Pipeline::PredictMOT(const std::string& video_path) {
times = total_time / frame_id;
LOG(INFO) << "frame_id: " << frame_id
<< " predict time(s): " << total_time / 1000;
<< " predict time(s): " << times / 1000;
cv::Mat out_img = PaddleDetection::VisualizeTrackResult(
frame, result, 1000. / times, frame_id);
......@@ -301,8 +301,7 @@ void Pipeline::RunMOTStream(const cv::Mat img,
total_time = std::accumulate(det_times.begin(), det_times.end(), 0.);
times = total_time / frame_id;
LOG(INFO) << "frame_id: " << frame_id
<< " predict time(s): " << total_time / 1000;
LOG(INFO) << "frame_id: " << frame_id << " predict time(s): " << times / 1000;
out_img = PaddleDetection::VisualizeTrackResult(
img, result, 1000. / times, frame_id);
......
......@@ -232,7 +232,7 @@ mot_sde_infer.predict_naive(model_dir,
| --video_file | Option | Video file to predict |
| --camera_id | Option | Camera ID used for prediction, default is -1 (camera disabled); can be set to 0 - (number of cameras - 1); press `q` in the visualization window during prediction to exit and write the result to output/output.mp4 |
| --device | Option | Device to run on, choose from `CPU/GPU/XPU`, default is `CPU` |
| --run_mode | Option | Run mode when using GPU, default is paddle, options: (paddle/trt_fp32/trt_fp16/trt_int8) |
| --batch_size | Option | Batch size for prediction, effective when `image_dir` is specified, default is 1 |
| --threshold | Option | Score threshold for predictions, default is 0.5 |
| --output_dir | Option | Root directory for saving visualization results, default is output/ |
......@@ -248,6 +248,6 @@ mot_sde_infer.predict_naive(model_dir,
Notes:
- Argument priority: `camera_id` > `video_file` > `image_dir` > `image_file`
- run_mode: paddle runs inference with AnalysisPredictor at float32 precision; the other values run AnalysisPredictor with TensorRT at the corresponding precision (see the example command below).
- If the installed PaddlePaddle does not support TensorRT-based inference, you need to build it from source; see the [inference library build guide](https://paddleinference.paddlepaddle.org.cn/user_guides/source_compile.html) for details.
- If --run_benchmark is set to True, install the dependencies first: `pip install pynvml psutil GPUtil`
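For example, a hedged invocation (the script name `mot_sde_infer.py` is inferred from this doc's module; the model path is a placeholder):

```
python deploy/pptracking/python/mot_sde_infer.py \
    --model_dir=output_inference/your_mot_model \
    --video_file=test.mp4 --device=GPU --run_mode=trt_fp16
```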
......@@ -47,7 +47,7 @@ class Detector(object):
pred_config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
batch_size (int): size of per batch in inference
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
......@@ -62,7 +62,7 @@ class Detector(object):
pred_config,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1280,
......@@ -180,7 +180,7 @@ class DetectorPicoDet(Detector):
config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
batch_size (int): size of per batch in inference
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
......@@ -195,7 +195,7 @@ class DetectorPicoDet(Detector):
pred_config,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1280,
......@@ -370,7 +370,7 @@ class PredictConfig():
def load_predictor(model_dir,
run_mode='paddle',
batch_size=1,
device='CPU',
min_subgraph_size=3,
......@@ -385,7 +385,7 @@ def load_predictor(model_dir,
Args:
model_dir (str): root path of __model__ and __params__
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16/trt_int8)
use_dynamic_shape (bool): use dynamic shape or not
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
......@@ -397,7 +397,7 @@ def load_predictor(model_dir,
Raises:
ValueError: predict by TensorRT need device == 'GPU'.
"""
if device != 'GPU' and run_mode != 'paddle':
raise ValueError(
"Predict by TensorRT mode: {}, expect device=='GPU', but device == {}"
.format(run_mode, device))
......@@ -570,7 +570,7 @@ def predict_video(detector, camera_id):
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_out_name)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
index = 1
while (1):
......
......@@ -44,7 +44,7 @@ class JDE_Detector(Detector):
pred_config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
batch_size (int): size of per batch in inference, default is 1 in tracking models
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
......@@ -59,7 +59,7 @@ class JDE_Detector(Detector):
pred_config,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1088,
......
......@@ -67,7 +67,7 @@ class SDE_Detector(Detector):
pred_config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
batch_size (int): size of per batch in inference, default is 1 in tracking models
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
......@@ -82,7 +82,7 @@ class SDE_Detector(Detector):
pred_config,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1088,
......@@ -216,7 +216,7 @@ class SDE_DetectorPicoDet(DetectorPicoDet):
pred_config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
batch_size (int): size of per batch in inference, default is 1 in tracking models
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
......@@ -231,7 +231,7 @@ class SDE_DetectorPicoDet(DetectorPicoDet):
pred_config,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1088,
......@@ -367,7 +367,7 @@ class SDE_ReID(object):
pred_config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
batch_size (int): size of per batch in inference, default 50 means at most
50 sub-images can be batched and sent into the ReID model
trt_min_shape (int): min shape for dynamic shape in trt
......@@ -383,7 +383,7 @@ class SDE_ReID(object):
pred_config,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=50,
trt_min_shape=1,
trt_max_shape=1088,
......
......@@ -58,8 +58,8 @@ def argsparser():
parser.add_argument(
"--run_mode",
type=str,
default='paddle',
help="mode of running(paddle/trt_fp32/trt_fp16/trt_int8)")
parser.add_argument(
"--device",
type=str,
......
......@@ -34,7 +34,7 @@ python deploy/python/infer.py --model_dir=./output_inference/yolov3_mobilenet_v1
| --video_file | Option | Video file to predict |
| --camera_id | Option | Camera ID used for prediction, default is -1 (camera disabled); can be set to 0 - (number of cameras - 1); press `q` in the visualization window during prediction to exit and write the result to output/output.mp4 |
| --device | Option | Device to run on, choose from `CPU/GPU/XPU`, default is `CPU` |
| --run_mode | Option | Run mode when using GPU, default is paddle, options: (paddle/trt_fp32/trt_fp16/trt_int8) |
| --batch_size | Option | Batch size for prediction, effective when `image_dir` is specified, default is 1 |
| --threshold | Option | Score threshold for predictions, default is 0.5 |
| --output_dir | Option | Root directory for saving visualization results, default is output/ |
......@@ -46,6 +46,6 @@ python deploy/python/infer.py --model_dir=./output_inference/yolov3_mobilenet_v1
Notes:
- Argument priority: `camera_id` > `video_file` > `image_dir` > `image_file`
- run_mode: paddle runs inference with AnalysisPredictor at float32 precision; the other values run AnalysisPredictor with TensorRT at the corresponding precision (see the example command below).
- If the installed PaddlePaddle does not support TensorRT-based inference, you need to build it from source; see the [inference library build guide](https://paddleinference.paddlepaddle.org.cn/user_guides/source_compile.html) for details.
- If --run_benchmark is set to True, install the dependencies first: `pip install pynvml psutil GPUtil`
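For example, building on the command shown above (the model directory is a placeholder for your exported model):

```
python deploy/python/infer.py \
    --model_dir=./output_inference/your_exported_model \
    --image_file=demo.jpg --device=GPU --run_mode=trt_fp16
```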
......@@ -72,8 +72,8 @@ def argsparser():
parser.add_argument(
"--run_mode",
type=str,
default='paddle',
help="mode of running(paddle/trt_fp32/trt_fp16/trt_int8)")
parser.add_argument(
"--device",
type=str,
......
......@@ -56,7 +56,7 @@ class Detector(object):
pred_config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
batch_size (int): size of per batch in inference
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
......@@ -71,7 +71,7 @@ class Detector(object):
pred_config,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1280,
......@@ -191,7 +191,7 @@ class DetectorSOLOv2(Detector):
config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
batch_size (int): size of per batch in inference
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
......@@ -206,7 +206,7 @@ class DetectorSOLOv2(Detector):
pred_config,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1280,
......@@ -283,7 +283,7 @@ class DetectorPicoDet(Detector):
config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
batch_size (int): size of per batch in inference
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
......@@ -298,7 +298,7 @@ class DetectorPicoDet(Detector):
pred_config,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1280,
......@@ -471,7 +471,7 @@ class PredictConfig():
def load_predictor(model_dir,
run_mode='paddle',
batch_size=1,
device='CPU',
min_subgraph_size=3,
......@@ -486,7 +486,7 @@ def load_predictor(model_dir,
Args:
model_dir (str): root path of __model__ and __params__
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16/trt_int8)
use_dynamic_shape (bool): use dynamic shape or not
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
......@@ -498,7 +498,7 @@ def load_predictor(model_dir,
Raises:
ValueError: predict by TensorRT need device == 'GPU'.
"""
if device != 'GPU' and run_mode != 'paddle':
raise ValueError(
"Predict by TensorRT mode: {}, expect device=='GPU', but device == {}"
.format(run_mode, device))
......
......@@ -46,7 +46,7 @@ class KeyPoint_Detector(Detector):
config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
......@@ -61,7 +61,7 @@ class KeyPoint_Detector(Detector):
pred_config,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1280,
......
......@@ -44,7 +44,7 @@ class JDE_Detector(Detector):
pred_config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
batch_size (int): size of per batch in inference
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
......@@ -59,7 +59,7 @@ class JDE_Detector(Detector):
pred_config,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1088,
......
......@@ -72,8 +72,8 @@ def argsparser():
parser.add_argument(
"--run_mode",
type=str,
default='paddle',
help="mode of running(paddle/trt_fp32/trt_fp16/trt_int8)")
parser.add_argument(
"--device",
type=str,
......
......@@ -104,7 +104,7 @@ class SDE_Detector(Detector):
pred_config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
......@@ -118,7 +118,7 @@ class SDE_Detector(Detector):
pred_config,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1088,
......@@ -238,7 +238,7 @@ class SDE_DetectorPicoDet(DetectorPicoDet):
pred_config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
......@@ -252,7 +252,7 @@ class SDE_DetectorPicoDet(DetectorPicoDet):
pred_config,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1088,
......@@ -380,7 +380,7 @@ class SDE_ReID(object):
pred_config,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=50,
trt_min_shape=1,
trt_max_shape=1088,
......
......@@ -57,8 +57,8 @@ def argsparser():
parser.add_argument(
"--run_mode",
type=str,
default='paddle',
help="mode of running(paddle/trt_fp32/trt_fp16/trt_int8)")
parser.add_argument(
"--device",
type=str,
......