未验证 提交 d896574f 编写于 作者: W wangguanzhong 提交者: GitHub

[MOT] add FairMot c++ deploy (#4322)

* add fairmot deploy

* separate main_jde

* update copyright & add PrintBenchmark

* refine lap
上级 48db9a8c
...@@ -25,7 +25,7 @@ EvalMOTReader: ...@@ -25,7 +25,7 @@ EvalMOTReader:
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
- LetterBoxResize: {target_size: [608, 1088]} - LetterBoxResize: {target_size: [608, 1088]}
- NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1]} - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
- Permute: {} - Permute: {}
batch_size: 1 batch_size: 1
...@@ -36,6 +36,6 @@ TestMOTReader: ...@@ -36,6 +36,6 @@ TestMOTReader:
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
- LetterBoxResize: {target_size: [608, 1088]} - LetterBoxResize: {target_size: [608, 1088]}
- NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1]} - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
- Permute: {} - Permute: {}
batch_size: 1 batch_size: 1
...@@ -36,6 +36,6 @@ TestMOTReader: ...@@ -36,6 +36,6 @@ TestMOTReader:
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
- LetterBoxResize: {target_size: [320, 576]} - LetterBoxResize: {target_size: [320, 576]}
- NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1]} - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
- Permute: {} - Permute: {}
batch_size: 1 batch_size: 1
...@@ -36,6 +36,6 @@ TestMOTReader: ...@@ -36,6 +36,6 @@ TestMOTReader:
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
- LetterBoxResize: {target_size: [480, 864]} - LetterBoxResize: {target_size: [480, 864]}
- NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1]} - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
- Permute: {} - Permute: {}
batch_size: 1 batch_size: 1
...@@ -6,6 +6,7 @@ option(WITH_GPU "Compile demo with GPU/CPU, default use CPU." ...@@ -6,6 +6,7 @@ option(WITH_GPU "Compile demo with GPU/CPU, default use CPU."
option(WITH_TENSORRT "Compile demo with TensorRT." OFF) option(WITH_TENSORRT "Compile demo with TensorRT." OFF)
option(WITH_KEYPOINT "Whether to Compile KeyPoint detector" OFF) option(WITH_KEYPOINT "Whether to Compile KeyPoint detector" OFF)
option(WITH_MOT "Whether to Compile MOT detector" OFF)
SET(PADDLE_DIR "" CACHE PATH "Location of libraries") SET(PADDLE_DIR "" CACHE PATH "Location of libraries")
SET(PADDLE_LIB_NAME "" CACHE STRING "libpaddle_inference") SET(PADDLE_LIB_NAME "" CACHE STRING "libpaddle_inference")
...@@ -23,6 +24,8 @@ link_directories("${CMAKE_CURRENT_BINARY_DIR}/ext/yaml-cpp/lib") ...@@ -23,6 +24,8 @@ link_directories("${CMAKE_CURRENT_BINARY_DIR}/ext/yaml-cpp/lib")
if (WITH_KEYPOINT) if (WITH_KEYPOINT)
set(SRCS src/main_keypoint.cc src/preprocess_op.cc src/object_detector.cc src/picodet_postprocess.cc src/utils.cc src/keypoint_detector.cc src/keypoint_postprocess.cc) set(SRCS src/main_keypoint.cc src/preprocess_op.cc src/object_detector.cc src/picodet_postprocess.cc src/utils.cc src/keypoint_detector.cc src/keypoint_postprocess.cc)
elseif (WITH_MOT)
set(SRCS src/main_jde.cc src/preprocess_op.cc src/object_detector.cc src/jde_detector.cc src/tracker.cc src/trajectory.cc src/lapjv.cpp src/picodet_postprocess.cc src/utils.cc)
else () else ()
set(SRCS src/main.cc src/preprocess_op.cc src/object_detector.cc src/picodet_postprocess.cc src/utils.cc) set(SRCS src/main.cc src/preprocess_op.cc src/object_detector.cc src/picodet_postprocess.cc src/utils.cc)
endif() endif()
......
...@@ -99,6 +99,16 @@ class ConfigPaser { ...@@ -99,6 +99,16 @@ class ConfigPaser {
return false; return false;
} }
// Get conf_thresh for tracker
if (config["tracker"].IsDefined()) {
if (config["tracker"]["conf_thres"].IsDefined()) {
conf_thresh_ = config["tracker"]["conf_thres"].as<float>();
} else {
std::cerr << "Please set conf_thres in tracker." << std::endl;
return false;
}
}
// Get NMS for postprocess // Get NMS for postprocess
if (config["NMS"].IsDefined()) { if (config["NMS"].IsDefined()) {
nms_info_ = config["NMS"]; nms_info_ = config["NMS"];
...@@ -122,6 +132,7 @@ class ConfigPaser { ...@@ -122,6 +132,7 @@ class ConfigPaser {
std::vector<std::string> label_list_; std::vector<std::string> label_list_;
std::vector<int> fpn_stride_; std::vector<int> fpn_stride_;
bool use_dynamic_shape_; bool use_dynamic_shape_;
float conf_thresh_;
}; };
} // namespace PaddleDetection } // namespace PaddleDetection
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include <memory>
#include <utility>
#include <ctime>
#include <opencv2/core/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/highgui/highgui.hpp>
#include "paddle_inference_api.h" // NOLINT
#include "include/preprocess_op.h"
#include "include/config_parser.h"
#include "include/tracker.h"
using namespace paddle_infer;
namespace PaddleDetection {
// JDE Detection Result
struct MOT_Rect
{
float left;
float top;
float right;
float bottom;
};
struct MOT_Track
{
int ids;
float score;
MOT_Rect rects;
};
typedef std::vector<MOT_Track> MOT_Result;
// Generate visualization color
cv::Scalar GetColor(int idx);
// Visualiztion Detection Result
cv::Mat VisualizeTrackResult(const cv::Mat& img,
const MOT_Result& results,
const float fps, const int frame_id);
class JDEDetector {
public:
explicit JDEDetector(const std::string& model_dir,
const std::string& device="CPU",
bool use_mkldnn=false,
int cpu_threads=1,
const std::string& run_mode="fluid",
const int batch_size=1,
const int gpu_id=0,
const int trt_min_shape=1,
const int trt_max_shape=1280,
const int trt_opt_shape=640,
bool trt_calib_mode=false,
const int min_box_area=200) {
this->device_ = device;
this->gpu_id_ = gpu_id;
this->cpu_math_library_num_threads_ = cpu_threads;
this->use_mkldnn_ = use_mkldnn;
this->trt_min_shape_ = trt_min_shape;
this->trt_max_shape_ = trt_max_shape;
this->trt_opt_shape_ = trt_opt_shape;
this->trt_calib_mode_ = trt_calib_mode;
config_.load_config(model_dir);
this->use_dynamic_shape_ = config_.use_dynamic_shape_;
this->min_subgraph_size_ = config_.min_subgraph_size_;
threshold_ = config_.draw_threshold_;
preprocessor_.Init(config_.preprocess_info_);
LoadModel(model_dir, batch_size, run_mode);
this->min_box_area_ = min_box_area;
this->conf_thresh_ = config_.conf_thresh_;
}
// Load Paddle inference model
void LoadModel(
const std::string& model_dir,
const int batch_size = 1,
const std::string& run_mode = "fluid");
// Run predictor
void Predict(const std::vector<cv::Mat> imgs,
const double threshold = 0.5,
const int warmup = 0,
const int repeats = 1,
MOT_Result* result = nullptr,
std::vector<double>* times = nullptr);
private:
std::string device_ = "CPU";
int gpu_id_ = 0;
int cpu_math_library_num_threads_ = 1;
bool use_mkldnn_ = false;
int min_subgraph_size_ = 3;
bool use_dynamic_shape_ = false;
int trt_min_shape_ = 1;
int trt_max_shape_ = 1280;
int trt_opt_shape_ = 640;
bool trt_calib_mode_ = false;
// Preprocess image and copy data to input buffer
void Preprocess(const cv::Mat& image_mat);
// Postprocess result
void Postprocess(
const cv::Mat dets, const cv::Mat emb,
MOT_Result* result);
std::shared_ptr<Predictor> predictor_;
Preprocessor preprocessor_;
ImageBlob inputs_;
std::vector<float> bbox_data_;
std::vector<float> emb_data_;
float threshold_;
ConfigPaser config_;
float min_box_area_;
float conf_thresh_;
};
} // namespace PaddleDetection
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef LAPJV_H
#define LAPJV_H
#define LARGE 1000000
#if !defined TRUE
#define TRUE 1
#endif
#if !defined FALSE
#define FALSE 0
#endif
#define NEW(x, t, n) if ((x = (t *)malloc(sizeof(t) * (n))) == 0) {return -1;}
#define FREE(x) if (x != 0) { free(x); x = 0; }
#define SWAP_INDICES(a, b) { int_t _temp_index = a; a = b; b = _temp_index; }
#include <opencv2/opencv.hpp>
namespace PaddleDetection {
typedef signed int int_t;
typedef unsigned int uint_t;
typedef double cost_t;
typedef char boolean;
typedef enum fp_t { FP_1 = 1, FP_2 = 2, FP_DYNAMIC = 3 } fp_t;
int lapjv_internal(
const cv::Mat &cost, const bool extend_cost, const float cost_limit,
int *x, int *y);
} // namespace PaddleDetection
#endif // LAPJV_H
...@@ -102,6 +102,20 @@ class Resize : public PreprocessOp { ...@@ -102,6 +102,20 @@ class Resize : public PreprocessOp {
std::vector<int> in_net_shape_; std::vector<int> in_net_shape_;
}; };
class LetterBoxResize : public PreprocessOp {
public:
virtual void Init(const YAML::Node& item) {
target_size_ = item["target_size"].as<std::vector<int>>();
}
float GenerateScale(const cv::Mat& im);
virtual void Run(cv::Mat* im, ImageBlob* data);
private:
std::vector<int> target_size_;
std::vector<int> in_net_shape_;
};
// Models with FPN need input shape % stride == 0 // Models with FPN need input shape % stride == 0
class PadStride : public PreprocessOp { class PadStride : public PreprocessOp {
public: public:
...@@ -146,6 +160,8 @@ class Preprocessor { ...@@ -146,6 +160,8 @@ class Preprocessor {
std::shared_ptr<PreprocessOp> CreateOp(const std::string& name) { std::shared_ptr<PreprocessOp> CreateOp(const std::string& name) {
if (name == "Resize") { if (name == "Resize") {
return std::make_shared<Resize>(); return std::make_shared<Resize>();
} else if (name == "LetterBoxResize") {
return std::make_shared<LetterBoxResize>();
} else if (name == "Permute") { } else if (name == "Permute") {
return std::make_shared<Permute>(); return std::make_shared<Permute>();
} else if (name == "NormalizeImage") { } else if (name == "NormalizeImage") {
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <vector>
#include <opencv2/opencv.hpp>
#include "trajectory.h"
namespace PaddleDetection {
typedef std::map<int, int> Match;
typedef std::map<int, int>::iterator MatchIterator;
struct Track
{
int id;
float score;
cv::Vec4f ltrb;
};
class JDETracker
{
public:
static JDETracker *instance(void);
virtual bool update(const cv::Mat &dets, const cv::Mat &emb, std::vector<Track> &tracks);
private:
JDETracker(void);
virtual ~JDETracker(void) {}
cv::Mat motion_distance(const TrajectoryPtrPool &a, const TrajectoryPool &b);
void linear_assignment(const cv::Mat &cost, float cost_limit, Match &matches,
std::vector<int> &mismatch_row, std::vector<int> &mismatch_col);
void remove_duplicate_trajectory(TrajectoryPool &a, TrajectoryPool &b, float iou_thresh=0.15f);
private:
static JDETracker *me;
int timestamp;
TrajectoryPool tracked_trajectories;
TrajectoryPool lost_trajectories;
TrajectoryPool removed_trajectories;
int max_lost_time;
float lambda;
float det_thresh;
};
} // namespace PaddleDetection
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include <opencv2/opencv.hpp>
namespace PaddleDetection {
typedef enum
{
New = 0,
Tracked = 1,
Lost = 2,
Removed = 3
} TrajectoryState;
class Trajectory;
typedef std::vector<Trajectory> TrajectoryPool;
typedef std::vector<Trajectory>::iterator TrajectoryPoolIterator;
typedef std::vector<Trajectory *>TrajectoryPtrPool;
typedef std::vector<Trajectory *>::iterator TrajectoryPtrPoolIterator;
class TKalmanFilter : public cv::KalmanFilter
{
public:
TKalmanFilter(void);
virtual ~TKalmanFilter(void) {}
virtual void init(const cv::Mat &measurement);
virtual const cv::Mat &predict();
virtual const cv::Mat &correct(const cv::Mat &measurement);
virtual void project(cv::Mat &mean, cv::Mat &covariance) const;
private:
float std_weight_position;
float std_weight_velocity;
};
inline TKalmanFilter::TKalmanFilter(void) : cv::KalmanFilter(8, 4)
{
cv::KalmanFilter::transitionMatrix = cv::Mat::eye(8, 8, CV_32F);
for (int i = 0; i < 4; ++i)
cv::KalmanFilter::transitionMatrix.at<float>(i, i + 4) = 1;
cv::KalmanFilter::measurementMatrix = cv::Mat::eye(4, 8, CV_32F);
std_weight_position = 1/20.f;
std_weight_velocity = 1/160.f;
}
class Trajectory : public TKalmanFilter
{
public:
Trajectory();
Trajectory(cv::Vec4f &ltrb, float score, const cv::Mat &embedding);
Trajectory(const Trajectory &other);
Trajectory &operator=(const Trajectory &rhs);
virtual ~Trajectory(void) {};
static int next_id();
virtual const cv::Mat &predict(void);
virtual void update(Trajectory &traj, int timestamp, bool update_embedding=true);
virtual void activate(int timestamp);
virtual void reactivate(Trajectory &traj, int timestamp, bool newid=false);
virtual void mark_lost(void);
virtual void mark_removed(void);
friend TrajectoryPool operator+(const TrajectoryPool &a, const TrajectoryPool &b);
friend TrajectoryPool operator+(const TrajectoryPool &a, const TrajectoryPtrPool &b);
friend TrajectoryPool &operator+=(TrajectoryPool &a, const TrajectoryPtrPool &b);
friend TrajectoryPool operator-(const TrajectoryPool &a, const TrajectoryPool &b);
friend TrajectoryPool &operator-=(TrajectoryPool &a, const TrajectoryPool &b);
friend TrajectoryPtrPool operator+(const TrajectoryPtrPool &a, const TrajectoryPtrPool &b);
friend TrajectoryPtrPool operator+(const TrajectoryPtrPool &a, TrajectoryPool &b);
friend TrajectoryPtrPool operator-(const TrajectoryPtrPool &a, const TrajectoryPtrPool &b);
friend cv::Mat embedding_distance(const TrajectoryPool &a, const TrajectoryPool &b);
friend cv::Mat embedding_distance(const TrajectoryPtrPool &a, const TrajectoryPtrPool &b);
friend cv::Mat embedding_distance(const TrajectoryPtrPool &a, const TrajectoryPool &b);
friend cv::Mat mahalanobis_distance(const TrajectoryPool &a, const TrajectoryPool &b);
friend cv::Mat mahalanobis_distance(const TrajectoryPtrPool &a, const TrajectoryPtrPool &b);
friend cv::Mat mahalanobis_distance(const TrajectoryPtrPool &a, const TrajectoryPool &b);
friend cv::Mat iou_distance(const TrajectoryPool &a, const TrajectoryPool &b);
friend cv::Mat iou_distance(const TrajectoryPtrPool &a, const TrajectoryPtrPool &b);
friend cv::Mat iou_distance(const TrajectoryPtrPool &a, const TrajectoryPool &b);
private:
void update_embedding(const cv::Mat &embedding);
public:
TrajectoryState state;
cv::Vec4f ltrb;
cv::Mat smooth_embedding;
int id;
bool is_activated;
int timestamp;
int starttime;
float score;
private:
static int count;
cv::Vec4f xyah;
cv::Mat current_embedding;
float eta;
int length;
};
inline cv::Vec4f ltrb2xyah(cv::Vec4f &ltrb)
{
cv::Vec4f xyah;
xyah[0] = (ltrb[0] + ltrb[2]) * 0.5f;
xyah[1] = (ltrb[1] + ltrb[3]) * 0.5f;
xyah[3] = ltrb[3] - ltrb[1];
xyah[2] = (ltrb[2] - ltrb[0]) / xyah[3];
return xyah;
}
inline Trajectory::Trajectory() :
state(New), ltrb(cv::Vec4f()), smooth_embedding(cv::Mat()), id(0),
is_activated(false), timestamp(0), starttime(0), score(0), eta(0.9), length(0)
{
}
inline Trajectory::Trajectory(cv::Vec4f &ltrb_, float score_, const cv::Mat &embedding) :
state(New), ltrb(ltrb_), smooth_embedding(cv::Mat()), id(0),
is_activated(false), timestamp(0), starttime(0), score(score_), eta(0.9), length(0)
{
xyah = ltrb2xyah(ltrb);
update_embedding(embedding);
}
inline Trajectory::Trajectory(const Trajectory &other):
state(other.state), ltrb(other.ltrb), id(other.id), is_activated(other.is_activated),
timestamp(other.timestamp), starttime(other.starttime), xyah(other.xyah),
score(other.score), eta(other.eta), length(other.length)
{
other.smooth_embedding.copyTo(smooth_embedding);
other.current_embedding.copyTo(current_embedding);
// copy state in KalmanFilter
other.statePre.copyTo(cv::KalmanFilter::statePre);
other.statePost.copyTo(cv::KalmanFilter::statePost);
other.errorCovPre.copyTo(cv::KalmanFilter::errorCovPre);
other.errorCovPost.copyTo(cv::KalmanFilter::errorCovPost);
}
inline Trajectory &Trajectory::operator=(const Trajectory &rhs)
{
this->state = rhs.state;
this->ltrb = rhs.ltrb;
rhs.smooth_embedding.copyTo(this->smooth_embedding);
this->id = rhs.id;
this->is_activated = rhs.is_activated;
this->timestamp = rhs.timestamp;
this->starttime = rhs.starttime;
this->xyah = rhs.xyah;
this->score = rhs.score;
rhs.current_embedding.copyTo(this->current_embedding);
this->eta = rhs.eta;
this->length = rhs.length;
// copy state in KalmanFilter
rhs.statePre.copyTo(cv::KalmanFilter::statePre);
rhs.statePost.copyTo(cv::KalmanFilter::statePost);
rhs.errorCovPre.copyTo(cv::KalmanFilter::errorCovPre);
rhs.errorCovPost.copyTo(cv::KalmanFilter::errorCovPost);
return *this;
}
inline int Trajectory::next_id()
{
++count;
return count;
}
inline void Trajectory::mark_lost(void)
{
state = Lost;
}
inline void Trajectory::mark_removed(void)
{
state = Removed;
}
} // namespace PaddleDetection
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sstream>
// for setprecision
#include <iomanip>
#include <chrono>
#include "include/jde_detector.h"
using namespace paddle_infer;
namespace PaddleDetection {
// Load Model and create model predictor
void JDEDetector::LoadModel(const std::string& model_dir,
const int batch_size,
const std::string& run_mode) {
paddle_infer::Config config;
std::string prog_file = model_dir + OS_PATH_SEP + "model.pdmodel";
std::string params_file = model_dir + OS_PATH_SEP + "model.pdiparams";
config.SetModel(prog_file, params_file);
if (this->device_ == "GPU") {
config.EnableUseGpu(200, this->gpu_id_);
config.SwitchIrOptim(true);
// use tensorrt
if (run_mode != "fluid") {
auto precision = paddle_infer::Config::Precision::kFloat32;
if (run_mode == "trt_fp32") {
precision = paddle_infer::Config::Precision::kFloat32;
}
else if (run_mode == "trt_fp16") {
precision = paddle_infer::Config::Precision::kHalf;
}
else if (run_mode == "trt_int8") {
precision = paddle_infer::Config::Precision::kInt8;
} else {
printf("run_mode should be 'fluid', 'trt_fp32', 'trt_fp16' or 'trt_int8'");
}
// set tensorrt
config.EnableTensorRtEngine(
1 << 30,
batch_size,
this->min_subgraph_size_,
precision,
false,
this->trt_calib_mode_);
// set use dynamic shape
if (this->use_dynamic_shape_) {
// set DynamicShsape for image tensor
const std::vector<int> min_input_shape = {1, 3, this->trt_min_shape_, this->trt_min_shape_};
const std::vector<int> max_input_shape = {1, 3, this->trt_max_shape_, this->trt_max_shape_};
const std::vector<int> opt_input_shape = {1, 3, this->trt_opt_shape_, this->trt_opt_shape_};
const std::map<std::string, std::vector<int>> map_min_input_shape = {{"image", min_input_shape}};
const std::map<std::string, std::vector<int>> map_max_input_shape = {{"image", max_input_shape}};
const std::map<std::string, std::vector<int>> map_opt_input_shape = {{"image", opt_input_shape}};
config.SetTRTDynamicShapeInfo(map_min_input_shape,
map_max_input_shape,
map_opt_input_shape);
std::cout << "TensorRT dynamic shape enabled" << std::endl;
}
}
} else if (this->device_ == "XPU"){
config.EnableXpu(10*1024*1024);
} else {
config.DisableGpu();
if (this->use_mkldnn_) {
config.EnableMKLDNN();
// cache 10 different shapes for mkldnn to avoid memory leak
config.SetMkldnnCacheCapacity(10);
}
config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
}
config.SwitchUseFeedFetchOps(false);
config.SwitchIrOptim(true);
config.DisableGlogInfo();
// Memory optimization
config.EnableMemoryOptim();
predictor_ = std::move(CreatePredictor(config));
}
// Visualiztion results
cv::Mat VisualizeTrackResult(const cv::Mat& img,
const MOT_Result& results,
const float fps, const int frame_id) {
cv::Mat vis_img = img.clone();
int im_h = img.rows;
int im_w = img.cols;
float text_scale = std::max(1, int(im_w / 1600.));
float text_thickness = 2.;
float line_thickness = std::max(1, int(im_w / 500.));
std::ostringstream oss;
oss << std::setiosflags(std::ios::fixed) << std::setprecision(4);
oss << "frame: " << frame_id<<" ";
oss << "fps: " << fps<<" ";
oss << "num: " << results.size();
std::string text = oss.str();
cv::Point origin;
origin.x = 0;
origin.y = int(15 * text_scale);
cv::putText(
vis_img,
text,
origin,
cv::FONT_HERSHEY_PLAIN,
text_scale, (0, 0, 255), 2);
for (int i = 0; i < results.size(); ++i) {
const int obj_id = results[i].ids;
const float score = results[i].score;
cv::Scalar color = GetColor(obj_id);
cv::Point pt1 = cv::Point(results[i].rects.left, results[i].rects.top);
cv::Point pt2 = cv::Point(results[i].rects.right, results[i].rects.bottom);
cv::Point id_pt = cv::Point(results[i].rects.left, results[i].rects.top + 10);
cv::Point score_pt = cv::Point(results[i].rects.left, results[i].rects.top - 10);
cv::rectangle(vis_img, pt1, pt2, color, line_thickness);
std::ostringstream idoss;
idoss << std::setiosflags(std::ios::fixed) << std::setprecision(4);
idoss << obj_id;
std::string id_text = idoss.str();
cv::putText(vis_img,
id_text,
id_pt,
cv::FONT_HERSHEY_PLAIN,
text_scale,
cv::Scalar(0, 255, 255),
text_thickness);
std::ostringstream soss;
soss << std::setiosflags(std::ios::fixed) << std::setprecision(2);
soss << score;
std::string score_text = soss.str();
cv::putText(vis_img,
score_text,
score_pt,
cv::FONT_HERSHEY_PLAIN,
text_scale,
cv::Scalar(0, 255, 255),
text_thickness);
}
return vis_img;
}
void FilterDets(const float conf_thresh, const cv::Mat dets, std::vector<int>* index) {
for (int i = 0; i < dets.rows; ++i) {
float score = *dets.ptr<float>(i, 4);
if (score > conf_thresh) {
index->push_back(i);
}
}
}
void JDEDetector::Preprocess(const cv::Mat& ori_im) {
// Clone the image : keep the original mat for postprocess
cv::Mat im = ori_im.clone();
preprocessor_.Run(&im, &inputs_);
}
void JDEDetector::Postprocess(
const cv::Mat dets, const cv::Mat emb,
MOT_Result* result) {
result->clear();
std::vector<Track> tracks;
std::vector<int> valid;
FilterDets(conf_thresh_, dets, &valid);
cv::Mat new_dets, new_emb;
for (int i = 0; i < valid.size(); ++i) {
new_dets.push_back(dets.row(valid[i]));
new_emb.push_back(emb.row(valid[i]));
}
JDETracker::instance()->update(new_dets, new_emb, tracks);
if (tracks.size() == 0) {
MOT_Track mot_track;
MOT_Rect ret = {*dets.ptr<float>(0, 0),
*dets.ptr<float>(0, 1),
*dets.ptr<float>(0, 2),
*dets.ptr<float>(0, 3)};
mot_track.ids = 1;
mot_track.score = *dets.ptr<float>(0, 4);
mot_track.rects = ret;
result->push_back(mot_track);
} else {
std::vector<Track>::iterator titer;
for (titer = tracks.begin(); titer != tracks.end(); ++titer) {
if (titer->score < threshold_) {
continue;
} else {
float w = titer->ltrb[2] - titer->ltrb[0];
float h = titer->ltrb[3] - titer->ltrb[1];
bool vertical = w / h > 1.6;
float area = w * h;
if (area > min_box_area_ && !vertical) {
MOT_Track mot_track;
MOT_Rect ret = {titer->ltrb[0],
titer->ltrb[1],
titer->ltrb[2],
titer->ltrb[3]};
mot_track.rects = ret;
mot_track.score = titer->score;
mot_track.ids = titer->id;
result->push_back(mot_track);
}
}
}
}
}
void JDEDetector::Predict(const std::vector<cv::Mat> imgs,
const double threshold,
const int warmup,
const int repeats,
MOT_Result* result,
std::vector<double>* times) {
auto preprocess_start = std::chrono::steady_clock::now();
int batch_size = imgs.size();
// in_data_batch
std::vector<float> in_data_all;
std::vector<float> im_shape_all(batch_size * 2);
std::vector<float> scale_factor_all(batch_size * 2);
// Preprocess image
for (int bs_idx = 0; bs_idx < batch_size; bs_idx++) {
cv::Mat im = imgs.at(bs_idx);
Preprocess(im);
im_shape_all[bs_idx * 2] = inputs_.im_shape_[0];
im_shape_all[bs_idx * 2 + 1] = inputs_.im_shape_[1];
scale_factor_all[bs_idx * 2] = inputs_.scale_factor_[0];
scale_factor_all[bs_idx * 2 + 1] = inputs_.scale_factor_[1];
// TODO: reduce cost time
in_data_all.insert(in_data_all.end(), inputs_.im_data_.begin(), inputs_.im_data_.end());
}
// Prepare input tensor
auto input_names = predictor_->GetInputNames();
for (const auto& tensor_name : input_names) {
auto in_tensor = predictor_->GetInputHandle(tensor_name);
if (tensor_name == "image") {
int rh = inputs_.in_net_shape_[0];
int rw = inputs_.in_net_shape_[1];
in_tensor->Reshape({batch_size, 3, rh, rw});
in_tensor->CopyFromCpu(in_data_all.data());
} else if (tensor_name == "im_shape") {
in_tensor->Reshape({batch_size, 2});
in_tensor->CopyFromCpu(im_shape_all.data());
} else if (tensor_name == "scale_factor") {
in_tensor->Reshape({batch_size, 2});
in_tensor->CopyFromCpu(scale_factor_all.data());
}
}
auto preprocess_end = std::chrono::steady_clock::now();
std::vector<int> bbox_shape;
std::vector<int> emb_shape;
// Run predictor
// warmup
for (int i = 0; i < warmup; i++)
{
predictor_->Run();
// Get output tensor
auto output_names = predictor_->GetOutputNames();
auto bbox_tensor = predictor_->GetOutputHandle(output_names[0]);
bbox_shape = bbox_tensor->shape();
auto emb_tensor = predictor_->GetOutputHandle(output_names[1]);
emb_shape = emb_tensor->shape();
// Calculate bbox length
int bbox_size = 1;
for (int j = 0; j < bbox_shape.size(); ++j) {
bbox_size *= bbox_shape[j];
}
// Calculate emb length
int emb_size = 1;
for (int j = 0; j < emb_shape.size(); ++j) {
emb_size *= emb_shape[j];
}
bbox_data_.resize(bbox_size);
bbox_tensor->CopyToCpu(bbox_data_.data());
emb_data_.resize(emb_size);
emb_tensor->CopyToCpu(emb_data_.data());
}
auto inference_start = std::chrono::steady_clock::now();
for (int i = 0; i < repeats; i++)
{
predictor_->Run();
// Get output tensor
auto output_names = predictor_->GetOutputNames();
auto bbox_tensor = predictor_->GetOutputHandle(output_names[0]);
bbox_shape = bbox_tensor->shape();
auto emb_tensor = predictor_->GetOutputHandle(output_names[1]);
emb_shape = emb_tensor->shape();
// Calculate bbox length
int bbox_size = 1;
for (int j = 0; j < bbox_shape.size(); ++j) {
bbox_size *= bbox_shape[j];
}
// Calculate emb length
int emb_size = 1;
for (int j = 0; j < emb_shape.size(); ++j) {
emb_size *= emb_shape[j];
}
bbox_data_.resize(bbox_size);
bbox_tensor->CopyToCpu(bbox_data_.data());
emb_data_.resize(emb_size);
emb_tensor->CopyToCpu(emb_data_.data());
}
auto inference_end = std::chrono::steady_clock::now();
auto postprocess_start = std::chrono::steady_clock::now();
// Postprocessing result
result->clear();
cv::Mat dets(bbox_shape[0], 6, CV_32FC1, bbox_data_.data());
cv::Mat emb(bbox_shape[0], emb_shape[1], CV_32FC1, emb_data_.data());
Postprocess(dets, emb, result);
auto postprocess_end = std::chrono::steady_clock::now();
std::chrono::duration<float> preprocess_diff = preprocess_end - preprocess_start;
(*times)[0] += double(preprocess_diff.count() * 1000);
std::chrono::duration<float> inference_diff = inference_end - inference_start;
(*times)[1] += double(inference_diff.count() * 1000);
std::chrono::duration<float> postprocess_diff = postprocess_end - postprocess_start;
(*times)[2] += double(postprocess_diff.count() * 1000);
}
cv::Scalar GetColor(int idx) {
idx = idx * 3;
cv::Scalar color = cv::Scalar((37 * idx) % 255,
(17 * idx) % 255,
(29 * idx) % 255);
return color;
}
} // namespace PaddleDetection
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "include/lapjv.h"
namespace PaddleDetection {
/** Column-reduction and reduction transfer for a dense cost matrix.
*/
int _ccrrt_dense(const int n, float *cost[],
int *free_rows, int *x, int *y, float *v)
{
int n_free_rows;
bool *unique;
for (int i = 0; i < n; i++) {
x[i] = -1;
v[i] = LARGE;
y[i] = 0;
}
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
const float c = cost[i][j];
if (c < v[j]) {
v[j] = c;
y[j] = i;
}
}
}
NEW(unique, bool, n);
memset(unique, TRUE, n);
{
int j = n;
do {
j--;
const int i = y[j];
if (x[i] < 0) {
x[i] = j;
} else {
unique[i] = FALSE;
y[j] = -1;
}
} while (j > 0);
}
n_free_rows = 0;
for (int i = 0; i < n; i++) {
if (x[i] < 0) {
free_rows[n_free_rows++] = i;
} else if (unique[i]) {
const int j = x[i];
float min = LARGE;
for (int j2 = 0; j2 < n; j2++) {
if (j2 == (int)j) {
continue;
}
const float c = cost[i][j2] - v[j2];
if (c < min) {
min = c;
}
}
v[j] -= min;
}
}
FREE(unique);
return n_free_rows;
}
/** Augmenting row reduction for a dense cost matrix.
*/
int _carr_dense(
const int n, float *cost[],
const int n_free_rows,
int *free_rows, int *x, int *y, float *v)
{
int current = 0;
int new_free_rows = 0;
int rr_cnt = 0;
while (current < n_free_rows) {
int i0;
int j1, j2;
float v1, v2, v1_new;
bool v1_lowers;
rr_cnt++;
const int free_i = free_rows[current++];
j1 = 0;
v1 = cost[free_i][0] - v[0];
j2 = -1;
v2 = LARGE;
for (int j = 1; j < n; j++) {
const float c = cost[free_i][j] - v[j];
if (c < v2) {
if (c >= v1) {
v2 = c;
j2 = j;
} else {
v2 = v1;
v1 = c;
j2 = j1;
j1 = j;
}
}
}
i0 = y[j1];
v1_new = v[j1] - (v2 - v1);
v1_lowers = v1_new < v[j1];
if (rr_cnt < current * n) {
if (v1_lowers) {
v[j1] = v1_new;
} else if (i0 >= 0 && j2 >= 0) {
j1 = j2;
i0 = y[j2];
}
if (i0 >= 0) {
if (v1_lowers) {
free_rows[--current] = i0;
} else {
free_rows[new_free_rows++] = i0;
}
}
} else {
if (i0 >= 0) {
free_rows[new_free_rows++] = i0;
}
}
x[free_i] = j1;
y[j1] = free_i;
}
return new_free_rows;
}
/** Find columns with minimum d[j] and put them on the SCAN list.
*/
int _find_dense(const int n, int lo, float *d, int *cols, int *y)
{
int hi = lo + 1;
float mind = d[cols[lo]];
for (int k = hi; k < n; k++) {
int j = cols[k];
if (d[j] <= mind) {
if (d[j] < mind) {
hi = lo;
mind = d[j];
}
cols[k] = cols[hi];
cols[hi++] = j;
}
}
return hi;
}
// Scan all columns in TODO starting from arbitrary column in SCAN
// and try to decrease d of the TODO columns using the SCAN column.
int _scan_dense(const int n, float *cost[],
int *plo, int*phi,
float *d, int *cols, int *pred,
int *y, float *v)
{
int lo = *plo;
int hi = *phi;
float h, cred_ij;
while (lo != hi) {
int j = cols[lo++];
const int i = y[j];
const float mind = d[j];
h = cost[i][j] - v[j] - mind;
// For all columns in TODO
for (int k = hi; k < n; k++) {
j = cols[k];
cred_ij = cost[i][j] - v[j] - h;
if (cred_ij < d[j]) {
d[j] = cred_ij;
pred[j] = i;
if (cred_ij == mind) {
if (y[j] < 0) {
return j;
}
cols[k] = cols[hi];
cols[hi++] = j;
}
}
}
}
*plo = lo;
*phi = hi;
return -1;
}
/** Single iteration of modified Dijkstra shortest path algorithm as explained in the JV paper.
*
* This is a dense matrix version.
*
* \return The closest free column index.
*/
int find_path_dense(
const int n, float *cost[],
const int start_i,
int *y, float *v,
int *pred)
{
int lo = 0, hi = 0;
int final_j = -1;
int n_ready = 0;
int *cols;
float *d;
NEW(cols, int, n);
NEW(d, float, n);
for (int i = 0; i < n; i++) {
cols[i] = i;
pred[i] = start_i;
d[i] = cost[start_i][i] - v[i];
}
while (final_j == -1) {
// No columns left on the SCAN list.
if (lo == hi) {
n_ready = lo;
hi = _find_dense(n, lo, d, cols, y);
for (int k = lo; k < hi; k++) {
const int j = cols[k];
if (y[j] < 0) {
final_j = j;
}
}
}
if (final_j == -1) {
final_j = _scan_dense(
n, cost, &lo, &hi, d, cols, pred, y, v);
}
}
{
const float mind = d[cols[lo]];
for (int k = 0; k < n_ready; k++) {
const int j = cols[k];
v[j] += d[j] - mind;
}
}
FREE(cols);
FREE(d);
return final_j;
}
/** Augment for a dense cost matrix.
*/
int _ca_dense(
const int n, float *cost[],
const int n_free_rows,
int *free_rows, int *x, int *y, float *v)
{
int *pred;
NEW(pred, int, n);
for (int *pfree_i = free_rows; pfree_i < free_rows + n_free_rows; pfree_i++) {
int i = -1, j;
int k = 0;
j = find_path_dense(n, cost, *pfree_i, y, v, pred);
while (i != *pfree_i) {
i = pred[j];
y[j] = i;
SWAP_INDICES(j, x[i]);
k++;
}
}
FREE(pred);
return 0;
}
/** Solve dense sparse LAP.
*/
int lapjv_internal(
const cv::Mat &cost, const bool extend_cost, const float cost_limit,
int *x, int *y ) {
int n_rows = cost.rows;
int n_cols = cost.cols;
int n;
if (n_rows == n_cols) {
n = n_rows;
} else if (!extend_cost) {
throw std::invalid_argument("Square cost array expected. If cost is intentionally non-square, pass extend_cost=True.");
}
// Get extend cost
if (extend_cost || cost_limit < LARGE) {
n = n_rows + n_cols;
}
cv::Mat cost_expand(n, n, CV_32F);
float expand_value;
if (cost_limit < LARGE) {
expand_value = cost_limit / 2;
} else {
double max_v;
minMaxLoc(cost, nullptr, &max_v);
expand_value = (float)max_v + 1;
}
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
cost_expand.at<float>(i, j) = expand_value;
if (i >= n_rows && j >= n_cols) {
cost_expand.at<float>(i, j) = 0;
} else if (i < n_rows && j < n_cols) {
cost_expand.at<float>(i, j) = cost.at<float>(i, j);
}
}
}
// Convert Mat to pointer array
float **cost_ptr;
NEW(cost_ptr, float *, n);
for (int i = 0; i < n; ++i) {
NEW(cost_ptr[i], float, n);
}
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
cost_ptr[i][j] = cost_expand.at<float>(i, j);
}
}
int ret;
int *free_rows;
float *v;
int *x_c;
int *y_c;
NEW(free_rows, int, n);
NEW(v, float, n);
NEW(x_c, int, n);
NEW(y_c, int, n);
ret = _ccrrt_dense(n, cost_ptr, free_rows, x_c, y_c, v);
int i = 0;
while (ret > 0 && i < 2) {
ret = _carr_dense(n, cost_ptr, ret, free_rows, x_c, y_c, v);
i++;
}
if (ret > 0) {
ret = _ca_dense(n, cost_ptr, ret, free_rows, x_c, y_c, v);
}
FREE(v);
FREE(free_rows);
for (int i = 0; i < n; ++i) {
FREE(cost_ptr[i]);
}
FREE(cost_ptr);
if (ret != 0) {
if (ret == -1){
throw "Out of memory.";
}
throw "Unknown error (lapjv_internal)";
}
// Get output of x, y, opt
for (int i = 0; i < n; ++i) {
if (i < n_rows) {
x[i] = x_c[i];
if (x[i] >= n_cols) {
x[i] = -1;
}
}
if (i < n_cols) {
y[i] = y_c[i];
if (y[i] >= n_rows) {
y[i] = -1;
}
}
}
FREE(x_c);
FREE(y_c);
return ret;
}
} // namespace PaddleDetection
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <iostream>
#include <string>
#include <vector>
#include <numeric>
#include <sys/types.h>
#include <sys/stat.h>
#include <math.h>
#include <algorithm>
#ifdef _WIN32
#include <direct.h>
#include <io.h>
#elif LINUX
#include <stdarg.h>
#include <sys/stat.h>
#endif
#include "include/object_detector.h"
#include "include/jde_detector.h"
#include <gflags/gflags.h>
#include <opencv2/opencv.hpp>
DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_int32(batch_size, 1, "batch_size");
DEFINE_string(video_file, "", "Path of input video, `video_file` or `camera_id` has a highest priority.");
DEFINE_int32(camera_id, -1, "Device id of camera to predict");
DEFINE_bool(use_gpu, false, "Deprecated, please use `--device` to set the device you want to run.");
DEFINE_string(device, "CPU", "Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU.");
DEFINE_double(threshold, 0.5, "Threshold of score.");
DEFINE_string(output_dir, "output", "Directory of output visualization files.");
DEFINE_string(run_mode, "fluid", "Mode of running(fluid/trt_fp32/trt_fp16/trt_int8)");
DEFINE_int32(gpu_id, 0, "Device id of GPU to execute");
DEFINE_bool(run_benchmark, false, "Whether to predict a image_file repeatedly for benchmark");
DEFINE_bool(use_mkldnn, false, "Whether use mkldnn with CPU");
DEFINE_int32(cpu_threads, 1, "Num of threads with CPU");
DEFINE_int32(trt_min_shape, 1, "Min shape of TRT DynamicShapeI");
DEFINE_int32(trt_max_shape, 1280, "Max shape of TRT DynamicShapeI");
DEFINE_int32(trt_opt_shape, 640, "Opt shape of TRT DynamicShapeI");
DEFINE_bool(trt_calib_mode, false, "If the model is produced by TRT offline quantitative calibration, trt_calib_mode need to set True");
void PrintBenchmarkLog(std::vector<double> det_time, int img_num){
LOG(INFO) << "----------------------- Config info -----------------------";
LOG(INFO) << "runtime_device: " << FLAGS_device;
LOG(INFO) << "ir_optim: " << "True";
LOG(INFO) << "enable_memory_optim: " << "True";
int has_trt = FLAGS_run_mode.find("trt");
if (has_trt >= 0) {
LOG(INFO) << "enable_tensorrt: " << "True";
std::string precision = FLAGS_run_mode.substr(4, 8);
LOG(INFO) << "precision: " << precision;
} else {
LOG(INFO) << "enable_tensorrt: " << "False";
LOG(INFO) << "precision: " << "fp32";
}
LOG(INFO) << "enable_mkldnn: " << (FLAGS_use_mkldnn ? "True" : "False");
LOG(INFO) << "cpu_math_library_num_threads: " << FLAGS_cpu_threads;
LOG(INFO) << "----------------------- Data info -----------------------";
LOG(INFO) << "batch_size: " << FLAGS_batch_size;
LOG(INFO) << "input_shape: " << "dynamic shape";
LOG(INFO) << "----------------------- Model info -----------------------";
FLAGS_model_dir.erase(FLAGS_model_dir.find_last_not_of("/") + 1);
LOG(INFO) << "model_name: " << FLAGS_model_dir.substr(FLAGS_model_dir.find_last_of('/') + 1);
LOG(INFO) << "----------------------- Perf info ------------------------";
LOG(INFO) << "Total number of predicted data: " << img_num
<< " and total time spent(ms): "
<< std::accumulate(det_time.begin(), det_time.end(), 0);
LOG(INFO) << "preproce_time(ms): " << det_time[0] / img_num
<< ", inference_time(ms): " << det_time[1] / img_num
<< ", postprocess_time(ms): " << det_time[2] / img_num;
}
static std::string DirName(const std::string &filepath) {
auto pos = filepath.rfind(OS_PATH_SEP);
if (pos == std::string::npos) {
return "";
}
return filepath.substr(0, pos);
}
static bool PathExists(const std::string& path){
#ifdef _WIN32
struct _stat buffer;
return (_stat(path.c_str(), &buffer) == 0);
#else
struct stat buffer;
return (stat(path.c_str(), &buffer) == 0);
#endif // !_WIN32
}
static void MkDir(const std::string& path) {
if (PathExists(path)) return;
int ret = 0;
#ifdef _WIN32
ret = _mkdir(path.c_str());
#else
ret = mkdir(path.c_str(), 0755);
#endif // !_WIN32
if (ret != 0) {
std::string path_error(path);
path_error += " mkdir failed!";
throw std::runtime_error(path_error);
}
}
static void MkDirs(const std::string& path) {
if (path.empty()) return;
if (PathExists(path)) return;
MkDirs(DirName(path));
MkDir(path);
}
void PredictVideo(const std::string& video_path,
PaddleDetection::JDEDetector* mot) {
// Open video
cv::VideoCapture capture;
if (FLAGS_camera_id != -1){
capture.open(FLAGS_camera_id);
}else{
capture.open(video_path.c_str());
}
if (!capture.isOpened()) {
printf("can not open video : %s\n", video_path.c_str());
return;
}
// Get Video info : resolution, fps
int video_width = static_cast<int>(capture.get(CV_CAP_PROP_FRAME_WIDTH));
int video_height = static_cast<int>(capture.get(CV_CAP_PROP_FRAME_HEIGHT));
int video_fps = static_cast<int>(capture.get(CV_CAP_PROP_FPS));
// Create VideoWriter for output
cv::VideoWriter video_out;
std::string video_out_path = "mot_output.mp4";
video_out.open(video_out_path.c_str(),
0x00000021,
video_fps,
cv::Size(video_width, video_height),
true);
if (!video_out.isOpened()) {
printf("create video writer failed!\n");
return;
}
PaddleDetection::MOT_Result result;
std::vector<double> det_times(3);
double times;
// Capture all frames and do inference
cv::Mat frame;
int frame_id = 0;
while (capture.read(frame)) {
if (frame.empty()) {
break;
}
std::vector<cv::Mat> imgs;
imgs.push_back(frame);
mot->Predict(imgs, 0.5, 0, 1, &result, &det_times);
frame_id += 1;
times = std::accumulate(det_times.begin(), det_times.end(), 0) / frame_id;
cv::Mat out_im = PaddleDetection::VisualizeTrackResult(
frame, result, 1000./times, frame_id);
video_out.write(out_im);
}
capture.release();
video_out.release();
PrintBenchmarkLog(det_times, frame_id);
printf("Visualized output saved as %s\n", video_out_path.c_str());
}
int main(int argc, char** argv) {
// Parsing command-line
google::ParseCommandLineFlags(&argc, &argv, true);
if (FLAGS_model_dir.empty()
|| FLAGS_video_file.empty()) {
std::cout << "Usage: ./main --model_dir=/PATH/TO/INFERENCE_MODEL/ "
<< "--video_file=/PATH/TO/INPUT/VIDEO/" << std::endl;
return -1;
}
if (!(FLAGS_run_mode == "fluid" || FLAGS_run_mode == "trt_fp32"
|| FLAGS_run_mode == "trt_fp16" || FLAGS_run_mode == "trt_int8")) {
std::cout << "run_mode should be 'fluid', 'trt_fp32', 'trt_fp16' or 'trt_int8'.";
return -1;
}
transform(FLAGS_device.begin(),FLAGS_device.end(),FLAGS_device.begin(),::toupper);
if (!(FLAGS_device == "CPU" || FLAGS_device == "GPU" || FLAGS_device == "XPU")) {
std::cout << "device should be 'CPU', 'GPU' or 'XPU'.";
return -1;
}
if (FLAGS_use_gpu) {
std::cout << "Deprecated, please use `--device` to set the device you want to run.";
return -1;
}
// Do inference on input video or image
PaddleDetection::JDEDetector mot(FLAGS_model_dir, FLAGS_device, FLAGS_use_mkldnn,
FLAGS_cpu_threads, FLAGS_run_mode, FLAGS_batch_size,FLAGS_gpu_id,
FLAGS_trt_min_shape, FLAGS_trt_max_shape, FLAGS_trt_opt_shape,
FLAGS_trt_calib_mode);
PredictVideo(FLAGS_video_file, &mot);
return 0;
}
...@@ -84,6 +84,7 @@ void Resize::Run(cv::Mat* im, ImageBlob* data) { ...@@ -84,6 +84,7 @@ void Resize::Run(cv::Mat* im, ImageBlob* data) {
}; };
} }
std::pair<float, float> Resize::GenerateScale(const cv::Mat& im) { std::pair<float, float> Resize::GenerateScale(const cv::Mat& im) {
std::pair<float, float> resize_scale; std::pair<float, float> resize_scale;
int origin_w = im.cols; int origin_w = im.cols;
...@@ -109,6 +110,65 @@ std::pair<float, float> Resize::GenerateScale(const cv::Mat& im) { ...@@ -109,6 +110,65 @@ std::pair<float, float> Resize::GenerateScale(const cv::Mat& im) {
return resize_scale; return resize_scale;
} }
void LetterBoxResize::Run(cv::Mat* im, ImageBlob* data) {
float resize_scale = GenerateScale(*im);
int new_shape_w = std::round(im->cols * resize_scale);
int new_shape_h = std::round(im->rows * resize_scale);
data->im_shape_ = {
static_cast<float>(new_shape_h),
static_cast<float>(new_shape_w)
};
float padw = (target_size_[1] - new_shape_w) / 2.;
float padh = (target_size_[0] - new_shape_h) / 2.;
int top = std::round(padh - 0.1);
int bottom = std::round(padh + 0.1);
int left = std::round(padw - 0.1);
int right = std::round(padw + 0.1);
cv::resize(
*im, *im, cv::Size(new_shape_w, new_shape_h), 0, 0, cv::INTER_AREA);
data->in_net_shape_ = {
static_cast<float>(im->rows),
static_cast<float>(im->cols),
};
cv::copyMakeBorder(
*im,
*im,
top,
bottom,
left,
right,
cv::BORDER_CONSTANT,
cv::Scalar(127.5));
data->in_net_shape_ = {
static_cast<float>(im->rows),
static_cast<float>(im->cols),
};
data->scale_factor_ = {
resize_scale,
resize_scale,
};
}
float LetterBoxResize::GenerateScale(const cv::Mat& im) {
int origin_w = im.cols;
int origin_h = im.rows;
int target_h = target_size_[0];
int target_w = target_size_[1];
float ratio_h = static_cast<float>(target_h) / static_cast<float>(origin_h);
float ratio_w = static_cast<float>(target_w) / static_cast<float>(origin_w);
float resize_scale = std::min(ratio_h, ratio_w);
return resize_scale;
}
void PadStride::Run(cv::Mat* im, ImageBlob* data) { void PadStride::Run(cv::Mat* im, ImageBlob* data) {
if (stride_ <= 0) { if (stride_ <= 0) {
return; return;
...@@ -145,7 +205,7 @@ void TopDownEvalAffine::Run(cv::Mat* im, ImageBlob* data) { ...@@ -145,7 +205,7 @@ void TopDownEvalAffine::Run(cv::Mat* im, ImageBlob* data) {
// Preprocessor op running order // Preprocessor op running order
const std::vector<std::string> Preprocessor::RUN_ORDER = { const std::vector<std::string> Preprocessor::RUN_ORDER = {
"InitInfo", "TopDownEvalAffine", "Resize", "NormalizeImage", "PadStride", "Permute" "InitInfo", "TopDownEvalAffine", "Resize", "LetterBoxResize", "NormalizeImage", "PadStride", "Permute"
}; };
void Preprocessor::Run(cv::Mat* im, ImageBlob* data) { void Preprocessor::Run(cv::Mat* im, ImageBlob* data) {
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <map>
#include <stdio.h>
#include <limits.h>
#include <algorithm>
#include "include/lapjv.h"
#include "include/tracker.h"
#define mat2vec4f(m) cv::Vec4f(*m.ptr<float>(0,0), *m.ptr<float>(0,1), *m.ptr<float>(0,2), *m.ptr<float>(0,3))
namespace PaddleDetection {
static std::map<int, float> chi2inv95 = {
{1, 3.841459f},
{2, 5.991465f},
{3, 7.814728f},
{4, 9.487729f},
{5, 11.070498f},
{6, 12.591587f},
{7, 14.067140f},
{8, 15.507313f},
{9, 16.918978f}
};
JDETracker *JDETracker::me = new JDETracker;
JDETracker *JDETracker::instance(void)
{
return me;
}
JDETracker::JDETracker(void) : timestamp(0), max_lost_time(30), lambda(0.98f), det_thresh(0.3f)
{
}
bool JDETracker::update(const cv::Mat &dets, const cv::Mat &emb, std::vector<Track> &tracks)
{
++timestamp;
TrajectoryPool candidates(dets.rows);
for (int i = 0; i < dets.rows; ++i)
{
float score = *dets.ptr<float>(i, 4);
const cv::Mat &ltrb_ = dets(cv::Rect(0, i, 4, 1));
cv::Vec4f ltrb = mat2vec4f(ltrb_);
const cv::Mat &embedding = emb(cv::Rect(0, i, emb.cols, 1));
candidates[i] = Trajectory(ltrb, score, embedding);
}
TrajectoryPtrPool tracked_trajectories;
TrajectoryPtrPool unconfirmed_trajectories;
for (size_t i = 0; i < this->tracked_trajectories.size(); ++i)
{
if (this->tracked_trajectories[i].is_activated)
tracked_trajectories.push_back(&this->tracked_trajectories[i]);
else
unconfirmed_trajectories.push_back(&this->tracked_trajectories[i]);
}
TrajectoryPtrPool trajectory_pool = tracked_trajectories + this->lost_trajectories;
for (size_t i = 0; i < trajectory_pool.size(); ++i)
trajectory_pool[i]->predict();
Match matches;
std::vector<int> mismatch_row;
std::vector<int> mismatch_col;
cv::Mat cost = motion_distance(trajectory_pool, candidates);
linear_assignment(cost, 0.7f, matches, mismatch_row, mismatch_col);
MatchIterator miter;
TrajectoryPtrPool activated_trajectories;
TrajectoryPtrPool retrieved_trajectories;
for (miter = matches.begin(); miter != matches.end(); miter++)
{
Trajectory *pt = trajectory_pool[miter->first];
Trajectory &ct = candidates[miter->second];
if (pt->state == Tracked)
{
pt->update(ct, timestamp);
activated_trajectories.push_back(pt);
}
else
{
pt->reactivate(ct, timestamp);
retrieved_trajectories.push_back(pt);
}
}
TrajectoryPtrPool next_candidates(mismatch_col.size());
for (size_t i = 0; i < mismatch_col.size(); ++i)
next_candidates[i] = &candidates[mismatch_col[i]];
TrajectoryPtrPool next_trajectory_pool;
for (size_t i = 0; i < mismatch_row.size(); ++i)
{
int j = mismatch_row[i];
if (trajectory_pool[j]->state == Tracked)
next_trajectory_pool.push_back(trajectory_pool[j]);
}
cost = iou_distance(next_trajectory_pool, next_candidates);
linear_assignment(cost, 0.5f, matches, mismatch_row, mismatch_col);
for (miter = matches.begin(); miter != matches.end(); miter++)
{
Trajectory *pt = next_trajectory_pool[miter->first];
Trajectory *ct = next_candidates[miter->second];
if (pt->state == Tracked)
{
pt->update(*ct, timestamp);
activated_trajectories.push_back(pt);
}
else
{
pt->reactivate(*ct, timestamp);
retrieved_trajectories.push_back(pt);
}
}
TrajectoryPtrPool lost_trajectories;
for (size_t i = 0; i < mismatch_row.size(); ++i)
{
Trajectory *pt = next_trajectory_pool[mismatch_row[i]];
if (pt->state != Lost)
{
pt->mark_lost();
lost_trajectories.push_back(pt);
}
}
TrajectoryPtrPool nnext_candidates(mismatch_col.size());
for (size_t i = 0; i < mismatch_col.size(); ++i)
nnext_candidates[i] = next_candidates[mismatch_col[i]];
cost = iou_distance(unconfirmed_trajectories, nnext_candidates);
linear_assignment(cost, 0.7f, matches, mismatch_row, mismatch_col);
for (miter = matches.begin(); miter != matches.end(); miter++)
{
unconfirmed_trajectories[miter->first]->update(*nnext_candidates[miter->second], timestamp);
activated_trajectories.push_back(unconfirmed_trajectories[miter->first]);
}
TrajectoryPtrPool removed_trajectories;
for (size_t i = 0; i < mismatch_row.size(); ++i)
{
unconfirmed_trajectories[mismatch_row[i]]->mark_removed();
removed_trajectories.push_back(unconfirmed_trajectories[mismatch_row[i]]);
}
for (size_t i = 0; i < mismatch_col.size(); ++i)
{
if (nnext_candidates[mismatch_col[i]]->score < det_thresh) continue;
nnext_candidates[mismatch_col[i]]->activate(timestamp);
activated_trajectories.push_back(nnext_candidates[mismatch_col[i]]);
}
for (size_t i = 0; i < this->lost_trajectories.size(); ++i)
{
Trajectory &lt = this->lost_trajectories[i];
if (timestamp - lt.timestamp > max_lost_time)
{
lt.mark_removed();
removed_trajectories.push_back(&lt);
}
}
TrajectoryPoolIterator piter;
for (piter = this->tracked_trajectories.begin(); piter != this->tracked_trajectories.end(); )
{
if (piter->state != Tracked)
piter = this->tracked_trajectories.erase(piter);
else
++piter;
}
this->tracked_trajectories += activated_trajectories;
this->tracked_trajectories += retrieved_trajectories;
this->lost_trajectories -= this->tracked_trajectories;
this->lost_trajectories += lost_trajectories;
this->lost_trajectories -= this->removed_trajectories;
this->removed_trajectories += removed_trajectories;
remove_duplicate_trajectory(this->tracked_trajectories, this->lost_trajectories);
tracks.clear();
for (size_t i = 0; i < this->tracked_trajectories.size(); ++i)
{
if (this->tracked_trajectories[i].is_activated)
{
Track track = {
.id = this->tracked_trajectories[i].id,
.score = this->tracked_trajectories[i].score,
.ltrb = this->tracked_trajectories[i].ltrb};
tracks.push_back(track);
}
}
return 0;
}
cv::Mat JDETracker::motion_distance(const TrajectoryPtrPool &a, const TrajectoryPool &b)
{
if (0 == a.size() || 0 == b.size())
return cv::Mat(a.size(), b.size(), CV_32F);
cv::Mat edists = embedding_distance(a, b);
cv::Mat mdists = mahalanobis_distance(a, b);
cv::Mat fdists = lambda * edists + (1 - lambda) * mdists;
const float gate_thresh = chi2inv95[4];
for (int i = 0; i < fdists.rows; ++i)
{
for (int j = 0; j < fdists.cols; ++j)
{
if (*mdists.ptr<float>(i, j) > gate_thresh)
*fdists.ptr<float>(i, j) = FLT_MAX;
}
}
return fdists;
}
void JDETracker::linear_assignment(const cv::Mat &cost, float cost_limit, Match &matches,
std::vector<int> &mismatch_row, std::vector<int> &mismatch_col)
{
matches.clear();
mismatch_row.clear();
mismatch_col.clear();
if (cost.empty())
{
for (int i = 0; i < cost.rows; ++i)
mismatch_row.push_back(i);
for (int i = 0; i < cost.cols; ++i)
mismatch_col.push_back(i);
return;
}
float opt = 0;
cv::Mat x(cost.rows, 1, CV_32S);
cv::Mat y(cost.cols, 1, CV_32S);
lapjv_internal(cost, true, cost_limit,
(int *)x.data, (int *)y.data);
for (int i = 0; i < x.rows; ++i)
{
int j = *x.ptr<int>(i);
if (j >= 0)
matches.insert({i, j});
else
mismatch_row.push_back(i);
}
for (int i = 0; i < y.rows; ++i)
{
int j = *y.ptr<int>(i);
if (j < 0)
mismatch_col.push_back(i);
}
return;
}
void JDETracker::remove_duplicate_trajectory(TrajectoryPool &a, TrajectoryPool &b, float iou_thresh)
{
if (0 == a.size() || 0 == b.size())
return;
cv::Mat dist = iou_distance(a, b);
cv::Mat mask = dist < iou_thresh;
std::vector<cv::Point> idx;
cv::findNonZero(mask, idx);
std::vector<int> da;
std::vector<int> db;
for (size_t i = 0; i < idx.size(); ++i)
{
int ta = a[idx[i].y].timestamp - a[idx[i].y].starttime;
int tb = b[idx[i].x].timestamp - b[idx[i].x].starttime;
if (ta > tb)
db.push_back(idx[i].x);
else
da.push_back(idx[i].y);
}
int id = 0;
TrajectoryPoolIterator piter;
for (piter = a.begin(); piter != a.end(); )
{
std::vector<int>::iterator iter = find(da.begin(), da.end(), id++);
if (iter != da.end())
piter = a.erase(piter);
else
++piter;
}
id = 0;
for (piter = b.begin(); piter != b.end(); )
{
std::vector<int>::iterator iter = find(db.begin(), db.end(), id++);
if (iter != db.end())
piter = b.erase(piter);
else
++piter;
}
}
} // namespace PaddleDetection
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "include/trajectory.h"
namespace PaddleDetection {
void TKalmanFilter::init(const cv::Mat &measurement)
{
measurement.copyTo(statePost(cv::Rect(0, 0, 1, 4)));
statePost(cv::Rect(0, 4, 1, 4)).setTo(0);
statePost.copyTo(statePre);
float varpos = 2 * std_weight_position * (*measurement.ptr<float>(3));
varpos *= varpos;
float varvel = 10 * std_weight_velocity * (*measurement.ptr<float>(3));
varvel *= varvel;
errorCovPost.setTo(0);
*errorCovPost.ptr<float>(0, 0) = varpos;
*errorCovPost.ptr<float>(1, 1) = varpos;
*errorCovPost.ptr<float>(2, 2) = 1e-4f;
*errorCovPost.ptr<float>(3, 3) = varpos;
*errorCovPost.ptr<float>(4, 4) = varvel;
*errorCovPost.ptr<float>(5, 5) = varvel;
*errorCovPost.ptr<float>(6, 6) = 1e-10f;
*errorCovPost.ptr<float>(7, 7) = varvel;
errorCovPost.copyTo(errorCovPre);
}
const cv::Mat &TKalmanFilter::predict()
{
float varpos = std_weight_position * (*statePre.ptr<float>(3));
varpos *= varpos;
float varvel = std_weight_velocity * (*statePre.ptr<float>(3));
varvel *= varvel;
processNoiseCov.setTo(0);
*processNoiseCov.ptr<float>(0, 0) = varpos;
*processNoiseCov.ptr<float>(1, 1) = varpos;
*processNoiseCov.ptr<float>(2, 2) = 1e-4f;
*processNoiseCov.ptr<float>(3, 3) = varpos;
*processNoiseCov.ptr<float>(4, 4) = varvel;
*processNoiseCov.ptr<float>(5, 5) = varvel;
*processNoiseCov.ptr<float>(6, 6) = 1e-10f;
*processNoiseCov.ptr<float>(7, 7) = varvel;
return cv::KalmanFilter::predict();
}
const cv::Mat &TKalmanFilter::correct(const cv::Mat &measurement)
{
float varpos = std_weight_position * (*measurement.ptr<float>(3));
varpos *= varpos;
measurementNoiseCov.setTo(0);
*measurementNoiseCov.ptr<float>(0, 0) = varpos;
*measurementNoiseCov.ptr<float>(1, 1) = varpos;
*measurementNoiseCov.ptr<float>(2, 2) = 1e-2f;
*measurementNoiseCov.ptr<float>(3, 3) = varpos;
return cv::KalmanFilter::correct(measurement);
}
void TKalmanFilter::project(cv::Mat &mean, cv::Mat &covariance) const
{
float varpos = std_weight_position * (*statePost.ptr<float>(3));
varpos *= varpos;
cv::Mat measurementNoiseCov_ = cv::Mat::eye(4, 4, CV_32F);
*measurementNoiseCov_.ptr<float>(0, 0) = varpos;
*measurementNoiseCov_.ptr<float>(1, 1) = varpos;
*measurementNoiseCov_.ptr<float>(2, 2) = 1e-2f;
*measurementNoiseCov_.ptr<float>(3, 3) = varpos;
mean = measurementMatrix * statePost;
cv::Mat temp = measurementMatrix * errorCovPost;
gemm(temp, measurementMatrix, 1, measurementNoiseCov_, 1, covariance, cv::GEMM_2_T);
}
int Trajectory::count = 0;
const cv::Mat &Trajectory::predict(void)
{
if (state != Tracked)
*cv::KalmanFilter::statePost.ptr<float>(7) = 0;
return TKalmanFilter::predict();
}
void Trajectory::update(Trajectory &traj, int timestamp_, bool update_embedding_)
{
timestamp = timestamp_;
++length;
ltrb = traj.ltrb;
xyah = traj.xyah;
TKalmanFilter::correct(cv::Mat(traj.xyah));
state = Tracked;
is_activated = true;
score = traj.score;
if (update_embedding_)
update_embedding(traj.current_embedding);
}
void Trajectory::activate(int timestamp_)
{
id = next_id();
TKalmanFilter::init(cv::Mat(xyah));
length = 0;
state = Tracked;
if (timestamp_ == 1) {
is_activated = true;
}
timestamp = timestamp_;
starttime = timestamp_;
}
void Trajectory::reactivate(Trajectory &traj, int timestamp_, bool newid)
{
TKalmanFilter::correct(cv::Mat(traj.xyah));
update_embedding(traj.current_embedding);
length = 0;
state = Tracked;
is_activated = true;
timestamp = timestamp_;
if (newid)
id = next_id();
}
void Trajectory::update_embedding(const cv::Mat &embedding)
{
current_embedding = embedding / cv::norm(embedding);
if (smooth_embedding.empty())
{
smooth_embedding = current_embedding;
}
else
{
smooth_embedding = eta * smooth_embedding + (1 - eta) * current_embedding;
}
smooth_embedding = smooth_embedding / cv::norm(smooth_embedding);
}
TrajectoryPool operator+(const TrajectoryPool &a, const TrajectoryPool &b)
{
TrajectoryPool sum;
sum.insert(sum.end(), a.begin(), a.end());
std::vector<int> ids(a.size());
for (size_t i = 0; i < a.size(); ++i)
ids[i] = a[i].id;
for (size_t i = 0; i < b.size(); ++i)
{
std::vector<int>::iterator iter = find(ids.begin(), ids.end(), b[i].id);
if (iter == ids.end())
{
sum.push_back(b[i]);
ids.push_back(b[i].id);
}
}
return sum;
}
TrajectoryPool operator+(const TrajectoryPool &a, const TrajectoryPtrPool &b)
{
TrajectoryPool sum;
sum.insert(sum.end(), a.begin(), a.end());
std::vector<int> ids(a.size());
for (size_t i = 0; i < a.size(); ++i)
ids[i] = a[i].id;
for (size_t i = 0; i < b.size(); ++i)
{
std::vector<int>::iterator iter = find(ids.begin(), ids.end(), b[i]->id);
if (iter == ids.end())
{
sum.push_back(*b[i]);
ids.push_back(b[i]->id);
}
}
return sum;
}
TrajectoryPool &operator+=(TrajectoryPool &a, const TrajectoryPtrPool &b)
{
std::vector<int> ids(a.size());
for (size_t i = 0; i < a.size(); ++i)
ids[i] = a[i].id;
for (size_t i = 0; i < b.size(); ++i)
{
if (b[i]->smooth_embedding.empty())
continue;
std::vector<int>::iterator iter = find(ids.begin(), ids.end(), b[i]->id);
if (iter == ids.end())
{
a.push_back(*b[i]);
ids.push_back(b[i]->id);
}
}
return a;
}
TrajectoryPool operator-(const TrajectoryPool &a, const TrajectoryPool &b)
{
TrajectoryPool dif;
std::vector<int> ids(b.size());
for (size_t i = 0; i < b.size(); ++i)
ids[i] = b[i].id;
for (size_t i = 0; i < a.size(); ++i)
{
std::vector<int>::iterator iter = find(ids.begin(), ids.end(), a[i].id);
if (iter == ids.end())
dif.push_back(a[i]);
}
return dif;
}
TrajectoryPool &operator-=(TrajectoryPool &a, const TrajectoryPool &b)
{
std::vector<int> ids(b.size());
for (size_t i = 0; i < b.size(); ++i)
ids[i] = b[i].id;
TrajectoryPoolIterator piter;
for (piter = a.begin(); piter != a.end(); )
{
std::vector<int>::iterator iter = find(ids.begin(), ids.end(), piter->id);
if (iter == ids.end())
++piter;
else
piter = a.erase(piter);
}
return a;
}
TrajectoryPtrPool operator+(const TrajectoryPtrPool &a, const TrajectoryPtrPool &b)
{
TrajectoryPtrPool sum;
sum.insert(sum.end(), a.begin(), a.end());
std::vector<int> ids(a.size());
for (size_t i = 0; i < a.size(); ++i)
ids[i] = a[i]->id;
for (size_t i = 0; i < b.size(); ++i)
{
std::vector<int>::iterator iter = find(ids.begin(), ids.end(), b[i]->id);
if (iter == ids.end())
{
sum.push_back(b[i]);
ids.push_back(b[i]->id);
}
}
return sum;
}
TrajectoryPtrPool operator+(const TrajectoryPtrPool &a, TrajectoryPool &b)
{
TrajectoryPtrPool sum;
sum.insert(sum.end(), a.begin(), a.end());
std::vector<int> ids(a.size());
for (size_t i = 0; i < a.size(); ++i)
ids[i] = a[i]->id;
for (size_t i = 0; i < b.size(); ++i)
{
std::vector<int>::iterator iter = find(ids.begin(), ids.end(), b[i].id);
if (iter == ids.end())
{
sum.push_back(&b[i]);
ids.push_back(b[i].id);
}
}
return sum;
}
TrajectoryPtrPool operator-(const TrajectoryPtrPool &a, const TrajectoryPtrPool &b)
{
TrajectoryPtrPool dif;
std::vector<int> ids(b.size());
for (size_t i = 0; i < b.size(); ++i)
ids[i] = b[i]->id;
for (size_t i = 0; i < a.size(); ++i)
{
std::vector<int>::iterator iter = find(ids.begin(), ids.end(), a[i]->id);
if (iter == ids.end())
dif.push_back(a[i]);
}
return dif;
}
cv::Mat embedding_distance(const TrajectoryPool &a, const TrajectoryPool &b)
{
cv::Mat dists(a.size(), b.size(), CV_32F);
for (size_t i = 0; i < a.size(); ++i)
{
float *distsi = dists.ptr<float>(i);
for (size_t j = 0; j < b.size(); ++j)
{
cv::Mat u = a[i].smooth_embedding;
cv::Mat v = b[j].smooth_embedding;
double uv = u.dot(v);
double uu = u.dot(u);
double vv = v.dot(v);
double dist = std::abs(1. - uv / std::sqrt(uu * vv));
//double dist = cv::norm(a[i].smooth_embedding, b[j].smooth_embedding, cv::NORM_L2);
distsi[j] = static_cast<float>(std::max(std::min(dist, 2.), 0.));
}
}
return dists;
}
cv::Mat embedding_distance(const TrajectoryPtrPool &a, const TrajectoryPtrPool &b)
{
cv::Mat dists(a.size(), b.size(), CV_32F);
for (size_t i = 0; i < a.size(); ++i)
{
float *distsi = dists.ptr<float>(i);
for (size_t j = 0; j < b.size(); ++j)
{
//double dist = cv::norm(a[i]->smooth_embedding, b[j]->smooth_embedding, cv::NORM_L2);
//distsi[j] = static_cast<float>(dist);
cv::Mat u = a[i]->smooth_embedding;
cv::Mat v = b[j]->smooth_embedding;
double uv = u.dot(v);
double uu = u.dot(u);
double vv = v.dot(v);
double dist = std::abs(1. - uv / std::sqrt(uu * vv));
distsi[j] = static_cast<float>(std::max(std::min(dist, 2.), 0.));
}
}
return dists;
}
cv::Mat embedding_distance(const TrajectoryPtrPool &a, const TrajectoryPool &b)
{
cv::Mat dists(a.size(), b.size(), CV_32F);
for (size_t i = 0; i < a.size(); ++i)
{
float *distsi = dists.ptr<float>(i);
for (size_t j = 0; j < b.size(); ++j)
{
//double dist = cv::norm(a[i]->smooth_embedding, b[j].smooth_embedding, cv::NORM_L2);
//distsi[j] = static_cast<float>(dist);
cv::Mat u = a[i]->smooth_embedding;
cv::Mat v = b[j].smooth_embedding;
double uv = u.dot(v);
double uu = u.dot(u);
double vv = v.dot(v);
double dist = std::abs(1. - uv / std::sqrt(uu * vv));
distsi[j] = static_cast<float>(std::max(std::min(dist, 2.), 0.));
}
}
return dists;
}
cv::Mat mahalanobis_distance(const TrajectoryPool &a, const TrajectoryPool &b)
{
std::vector<cv::Mat> means(a.size());
std::vector<cv::Mat> icovariances(a.size());
for (size_t i = 0; i < a.size(); ++i)
{
cv::Mat covariance;
a[i].project(means[i], covariance);
cv::invert(covariance, icovariances[i]);
}
cv::Mat dists(a.size(), b.size(), CV_32F);
for (size_t i = 0; i < a.size(); ++i)
{
float *distsi = dists.ptr<float>(i);
for (size_t j = 0; j < b.size(); ++j)
{
const cv::Mat x(b[j].xyah);
float dist = static_cast<float>(cv::Mahalanobis(x, means[i], icovariances[i]));
distsi[j] = dist * dist;
}
}
return dists;
}
cv::Mat mahalanobis_distance(const TrajectoryPtrPool &a, const TrajectoryPtrPool &b)
{
std::vector<cv::Mat> means(a.size());
std::vector<cv::Mat> icovariances(a.size());
for (size_t i = 0; i < a.size(); ++i)
{
cv::Mat covariance;
a[i]->project(means[i], covariance);
cv::invert(covariance, icovariances[i]);
}
cv::Mat dists(a.size(), b.size(), CV_32F);
for (size_t i = 0; i < a.size(); ++i)
{
float *distsi = dists.ptr<float>(i);
for (size_t j = 0; j < b.size(); ++j)
{
const cv::Mat x(b[j]->xyah);
float dist = static_cast<float>(cv::Mahalanobis(x, means[i], icovariances[i]));
distsi[j] = dist * dist;
}
}
return dists;
}
cv::Mat mahalanobis_distance(const TrajectoryPtrPool &a, const TrajectoryPool &b)
{
std::vector<cv::Mat> means(a.size());
std::vector<cv::Mat> icovariances(a.size());
for (size_t i = 0; i < a.size(); ++i)
{
cv::Mat covariance;
a[i]->project(means[i], covariance);
cv::invert(covariance, icovariances[i]);
}
cv::Mat dists(a.size(), b.size(), CV_32F);
for (size_t i = 0; i < a.size(); ++i)
{
float *distsi = dists.ptr<float>(i);
for (size_t j = 0; j < b.size(); ++j)
{
const cv::Mat x(b[j].xyah);
float dist = static_cast<float>(cv::Mahalanobis(x, means[i], icovariances[i]));
distsi[j] = dist * dist;
}
}
return dists;
}
static inline float calc_inter_area(const cv::Vec4f &a, const cv::Vec4f &b)
{
if (a[2] < b[0] || a[0] > b[2] || a[3] < b[1] || a[1] > b[3])
return 0.f;
float w = std::min(a[2], b[2]) - std::max(a[0], b[0]);
float h = std::min(a[3], b[3]) - std::max(a[1], b[1]);
return w * h;
}
cv::Mat iou_distance(const TrajectoryPool &a, const TrajectoryPool &b)
{
std::vector<float> areaa(a.size());
for (size_t i = 0; i < a.size(); ++i)
{
float w = a[i].ltrb[2] - a[i].ltrb[0];
float h = a[i].ltrb[3] - a[i].ltrb[1];
areaa[i] = w * h;
}
std::vector<float> areab(b.size());
for (size_t j = 0; j < b.size(); ++j)
{
float w = b[j].ltrb[2] - b[j].ltrb[0];
float h = b[j].ltrb[3] - b[j].ltrb[1];
areab[j] = w * h;
}
cv::Mat dists(a.size(), b.size(), CV_32F);
for (size_t i = 0; i < a.size(); ++i)
{
const cv::Vec4f &boxa = a[i].ltrb;
float *distsi = dists.ptr<float>(i);
for (size_t j = 0; j < b.size(); ++j)
{
const cv::Vec4f &boxb = b[j].ltrb;
float inters = calc_inter_area(boxa, boxb);
distsi[j] = 1.f - inters / (areaa[i] + areab[j] - inters);
}
}
return dists;
}
cv::Mat iou_distance(const TrajectoryPtrPool &a, const TrajectoryPtrPool &b)
{
std::vector<float> areaa(a.size());
for (size_t i = 0; i < a.size(); ++i)
{
float w = a[i]->ltrb[2] - a[i]->ltrb[0];
float h = a[i]->ltrb[3] - a[i]->ltrb[1];
areaa[i] = w * h;
}
std::vector<float> areab(b.size());
for (size_t j = 0; j < b.size(); ++j)
{
float w = b[j]->ltrb[2] - b[j]->ltrb[0];
float h = b[j]->ltrb[3] - b[j]->ltrb[1];
areab[j] = w * h;
}
cv::Mat dists(a.size(), b.size(), CV_32F);
for (size_t i = 0; i < a.size(); ++i)
{
const cv::Vec4f &boxa = a[i]->ltrb;
float *distsi = dists.ptr<float>(i);
for (size_t j = 0; j < b.size(); ++j)
{
const cv::Vec4f &boxb = b[j]->ltrb;
float inters = calc_inter_area(boxa, boxb);
distsi[j] = 1.f - inters / (areaa[i] + areab[j] - inters);
}
}
return dists;
}
cv::Mat iou_distance(const TrajectoryPtrPool &a, const TrajectoryPool &b)
{
std::vector<float> areaa(a.size());
for (size_t i = 0; i < a.size(); ++i)
{
float w = a[i]->ltrb[2] - a[i]->ltrb[0];
float h = a[i]->ltrb[3] - a[i]->ltrb[1];
areaa[i] = w * h;
}
std::vector<float> areab(b.size());
for (size_t j = 0; j < b.size(); ++j)
{
float w = b[j].ltrb[2] - b[j].ltrb[0];
float h = b[j].ltrb[3] - b[j].ltrb[1];
areab[j] = w * h;
}
cv::Mat dists(a.size(), b.size(), CV_32F);
for (size_t i = 0; i < a.size(); ++i)
{
const cv::Vec4f &boxa = a[i]->ltrb;
float *distsi = dists.ptr<float>(i);
for (size_t j = 0; j < b.size(); ++j)
{
const cv::Vec4f &boxb = b[j].ltrb;
float inters = calc_inter_area(boxa, boxb);
distsi[j] = 1.f - inters / (areaa[i] + areab[j] - inters);
}
}
return dists;
}
} // namespace PaddleDetection
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册