未验证 提交 c902676e 编写于 作者: W wangguanzhong 提交者: GitHub

[MOT]add model_dir as input (#4543)

* add model_dir as input

* unify log output
上级 d3927dde
......@@ -53,7 +53,10 @@ class Pipeline {
const std::string& scene = "pedestrian",
const bool tiny_obj = false,
const bool is_mtmct = false,
const int secs_interval = 10) {
const int secs_interval = 10,
const std::string track_model_dir = "",
const std::string det_model_dir = "",
const std::string reid_model_dir = "") {
std::vector<std::string> input;
this->input_ = input;
this->device_ = device;
......@@ -67,7 +70,12 @@ class Pipeline {
this->do_entrance_counting_ = do_entrance_counting;
this->secs_interval_ = secs_interval_;
this->save_result_ = save_result;
SelectModel(scene, tiny_obj, is_mtmct);
SelectModel(scene,
tiny_obj,
is_mtmct,
track_model_dir,
det_model_dir,
reid_model_dir);
InitPredictor();
}
......@@ -102,7 +110,10 @@ class Pipeline {
// Select model according to scenes, it must execute before Run()
void SelectModel(const std::string& scene = "pedestrian",
const bool tiny_obj = false,
const bool is_mtmct = false);
const bool is_mtmct = false,
const std::string track_model_dir = "",
const std::string det_model_dir = "",
const std::string reid_model_dir = "");
void InitPredictor();
std::shared_ptr<PaddleDetection::JDEPredictor> jde_sct_;
......
......@@ -54,7 +54,8 @@ DEFINE_bool(trt_calib_mode,
"If the model is produced by TRT offline quantitative calibration, "
"trt_calib_mode need to set True");
DEFINE_bool(tiny_obj, false, "Whether tracking tiny object");
DEFINE_bool(do_entrance_counting, false,
DEFINE_bool(do_entrance_counting,
false,
"Whether counting the numbers of identifiers entering "
"or getting out from the entrance.");
DEFINE_int32(secs_interval, 10, "The seconds interval to count after tracking");
......@@ -64,6 +65,9 @@ DEFINE_string(
"",
"scene of tracking system, it can be : pedestrian/vehicle/multiclass");
DEFINE_bool(is_mtmct, false, "Whether use multi-target multi-camera tracking");
DEFINE_string(track_model_dir, "", "Path of tracking model");
DEFINE_string(det_model_dir, "", "Path of detection model");
DEFINE_string(reid_model_dir, "", "Path of reid model");
static std::string DirName(const std::string& filepath) {
auto pos = filepath.rfind(OS_PATH_SEP);
......@@ -109,15 +113,21 @@ static void MkDirs(const std::string& path) {
int main(int argc, char** argv) {
// Parsing command-line
google::ParseCommandLineFlags(&argc, &argv, true);
if (FLAGS_video_file.empty() || FLAGS_scene.empty()) {
std::cout << "Usage: ./main "
<< "-video_file=/PATH/TO/INPUT/IMAGE/ "
<< "-scene=pedestrian/vehicle/multiclass" << std::endl;
bool has_model_dir =
!(FLAGS_track_model_dir.empty() && FLAGS_det_model_dir.empty() &&
FLAGS_reid_model_dir.empty());
if (FLAGS_video_file.empty() || (FLAGS_scene.empty() && !has_model_dir)) {
LOG(ERROR) << "Usage: \n"
<< "1. ./main -video_file=/PATH/TO/INPUT/IMAGE/ "
<< "-scene=pedestrian/vehicle/multiclass\n"
<< "2. ./main -video_file=/PATH/TO/INPUT/IMAGE/ "
<< "-track_model_dir=/PATH/TO/MODEL_DIR" << std::endl;
return -1;
}
if (!(FLAGS_run_mode == "fluid" || FLAGS_run_mode == "trt_fp32" ||
FLAGS_run_mode == "trt_fp16" || FLAGS_run_mode == "trt_int8")) {
std::cout
LOG(ERROR)
<< "run_mode should be 'fluid', 'trt_fp32', 'trt_fp16' or 'trt_int8'.";
return -1;
}
......@@ -127,7 +137,7 @@ int main(int argc, char** argv) {
::toupper);
if (!(FLAGS_device == "CPU" || FLAGS_device == "GPU" ||
FLAGS_device == "XPU")) {
std::cout << "device should be 'CPU', 'GPU' or 'XPU'.";
LOG(ERROR) << "device should be 'CPU', 'GPU' or 'XPU'.";
return -1;
}
......@@ -148,7 +158,10 @@ int main(int argc, char** argv) {
FLAGS_scene,
FLAGS_tiny_obj,
FLAGS_is_mtmct,
FLAGS_secs_interval);
FLAGS_secs_interval,
FLAGS_track_model_dir,
FLAGS_det_model_dir,
FLAGS_reid_model_dir);
pipeline.SetInput(FLAGS_video_file);
if (!FLAGS_video_other_file.empty()) {
......
......@@ -36,7 +36,21 @@ void Pipeline::ClearInput() {
void Pipeline::SelectModel(const std::string& scene,
const bool tiny_obj,
const bool is_mtmct) {
const bool is_mtmct,
const std::string track_model_dir,
const std::string det_model_dir,
const std::string reid_model_dir) {
// model_dir has higher priority
if (!track_model_dir.empty()) {
track_model_dir_ = track_model_dir;
return;
}
if (!det_model_dir.empty() && !reid_model_dir.empty()) {
det_model_dir_ = det_model_dir;
reid_model_dir_ = reid_model_dir;
return;
}
// Single camera model, based on FairMot
if (scene == "pedestrian") {
if (tiny_obj) {
......@@ -100,11 +114,11 @@ void Pipeline::InitPredictor() {
void Pipeline::Run() {
if (track_model_dir_.empty() && det_model_dir_.empty()) {
std::cout << "Pipeline must use SelectModel before Run";
LOG(ERROR) << "Pipeline must use SelectModel before Run";
return;
}
if (input_.size() == 0) {
std::cout << "Pipeline must use SetInput before Run";
LOG(ERROR) << "Pipeline must use SetInput before Run";
return;
}
......@@ -165,7 +179,8 @@ void Pipeline::PredictMOT(const std::string& video_path) {
std::vector<int> in_id_list;
std::vector<int> out_id_list;
std::map<int, std::vector<float>> prev_center;
Rect entrance = {0, static_cast<float>(video_height) / 2,
Rect entrance = {0,
static_cast<float>(video_height) / 2,
static_cast<float>(video_width),
static_cast<float>(video_height) / 2};
double times;
......@@ -195,12 +210,20 @@ void Pipeline::PredictMOT(const std::string& video_path) {
cv::Mat out_img = PaddleDetection::VisualizeTrackResult(
frame, result, 1000. / times, frame_id);
// TODO: the entrance line can be set by users
PaddleDetection::FlowStatistic(
result, frame_id, secs_interval_, do_entrance_counting_, video_fps, entrance,
&id_set, &interval_id_set, &in_id_list, &out_id_list,
&prev_center, &flow_records);
// TODO(qianhui): the entrance line can be set by users
PaddleDetection::FlowStatistic(result,
frame_id,
secs_interval_,
do_entrance_counting_,
video_fps,
entrance,
&id_set,
&interval_id_set,
&in_id_list,
&out_id_list,
&prev_center,
&flow_records);
if (save_result_) {
PaddleDetection::SaveMOTResult(result, frame_id, &records);
......@@ -228,7 +251,7 @@ void Pipeline::PredictMOT(const std::string& video_path) {
fclose(fp);
LOG(INFO) << "txt result output saved as " << result_output_path.c_str();
result_output_path = output_dir_ + OS_PATH_SEP + "flow_statistic.txt";
if ((fp = fopen(result_output_path.c_str(), "w+")) == NULL) {
printf("Open %s error.\n", result_output_path);
......@@ -273,15 +296,23 @@ void Pipeline::RunMOTStream(const cv::Mat img,
LOG(INFO) << "frame_id: " << frame_id
<< " predict time(s): " << total_time / 1000;
out_img = PaddleDetection::VisualizeTrackResult(img, result, 1000. / times,
frame_id);
out_img = PaddleDetection::VisualizeTrackResult(
img, result, 1000. / times, frame_id);
// Count total number
// Count in & out number
PaddleDetection::FlowStatistic(result, frame_id, secs_interval_, do_entrance_counting_,
video_fps, entrance, id_set,
interval_id_set, in_id_list,
out_id_list, prev_center, flow_records);
PaddleDetection::FlowStatistic(result,
frame_id,
secs_interval_,
do_entrance_counting_,
video_fps,
entrance,
id_set,
interval_id_set,
in_id_list,
out_id_list,
prev_center,
flow_records);
PrintBenchmarkLog(det_times, frame_id);
if (save_result_) {
......@@ -326,4 +357,4 @@ void Pipeline::PrintBenchmarkLog(const std::vector<double> det_time,
<< ", postprocess_time(ms): " << det_time[2] / num;
}
} // namespace PaddleDetection
} // namespace PaddleDetection
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册