//   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// The code is based on:
// https://github.com/CnybTseng/JDE/blob/master/platforms/common/jdetracker.h
// The copyright of CnybTseng/JDE is as follows:
// MIT License

#pragma once

#include <map>
#include <vector>

#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include "include/trajectory.h"

namespace PaddleDetection {

// Index-to-index mapping used as the `matches` output of
// JDETracker::linear_assignment (presumably row index -> column index of the
// cost matrix; confirm in tracker.cc).
using Match = std::map<int, int>;
// Mutable iterator over a Match.
using MatchIterator = Match::iterator;

// One tracked object, as reported through the `tracks` output of
// JDETracker::update().
// Kept as a plain aggregate (no default member initializers) so that
// brace-initialization by callers keeps working under C++11.
struct Track {
  // Trajectory identifier (presumably stable across frames; confirm in
  // tracker.cc).
  int id;
  // Confidence score attached to this track.
  float score;
  // Bounding box; the name suggests (left, top, right, bottom) order —
  // confirm against the producer in tracker.cc.
  cv::Vec4f ltrb;
};

// Multi-object tracker exposed as a process-wide singleton: the constructor
// and destructor are private, so callers obtain the object only through
// JDETracker::instance().
class JDETracker {
 public:
  // Returns the shared tracker instance held in `me` (presumably created
  // lazily on first call — confirm in tracker.cc).
  static JDETracker *instance(void);

  // Singleton: a copy would duplicate the trajectory pools and other tracker
  // state, so copying is forbidden.
  JDETracker(const JDETracker &) = delete;
  JDETracker &operator=(const JDETracker &) = delete;

  // Processes one frame of detections and reports the resulting tracks.
  //   dets:   detection results for the frame (layout not visible here;
  //           presumably one row per detection — confirm).
  //   emb:    appearance embeddings for the detections (presumably
  //           row-aligned with `dets` — confirm).
  //   tracks: output vector receiving the tracks for this frame.
  // Returns a status flag; its exact meaning is defined in tracker.cc.
  virtual bool update(const cv::Mat &dets,
                      const cv::Mat &emb,
                      std::vector<Track> *tracks);

 private:
  JDETracker(void);
  virtual ~JDETracker(void) {}

  // Builds a distance matrix between the trajectories in `a` and `b`
  // (presumably one row per element of `a`, one column per element of `b`).
  cv::Mat motion_distance(const TrajectoryPtrPool &a, const TrajectoryPool &b);

  // Solves an assignment problem over `cost`. Matched pairs are written to
  // `matches`, and unmatched row / column indices to `mismatch_row` /
  // `mismatch_col`. `cost_limit` presumably caps the cost of an acceptable
  // match — confirm in tracker.cc.
  void linear_assignment(const cv::Mat &cost,
                         float cost_limit,
                         Match *matches,
                         std::vector<int> *mismatch_row,
                         std::vector<int> *mismatch_col);

  // Resolves duplicate trajectories between pools `a` and `b`, presumably
  // those whose boxes overlap with IoU above `iou_thresh` (default 0.15).
  void remove_duplicate_trajectory(TrajectoryPool *a,
                                   TrajectoryPool *b,
                                   float iou_thresh = 0.15f);

 private:
  // The singleton instance returned by instance().
  static JDETracker *me;
  // Presumably the current frame counter — confirm in tracker.cc.
  int timestamp;
  // Trajectory pools, partitioned by state as their names indicate.
  TrajectoryPool tracked_trajectories;
  TrajectoryPool lost_trajectories;
  TrajectoryPool removed_trajectories;
  // Presumably how long a lost trajectory is retained — confirm units.
  int max_lost_time;
  // Presumably a weight between cost terms during association — confirm.
  float lambda;
  // Presumably the minimum detection score considered — confirm.
  float det_thresh;
};

}  // namespace PaddleDetection