# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
15
This code is based on https://github.com/nwojke/deep_sort/blob/master/deep_sort/tracker.py
16 17 18 19
"""

import numpy as np

from ..motion import KalmanFilter
from ..matching.deepsort_matching import NearestNeighborDistanceMetric
from ..matching.deepsort_matching import iou_cost, min_cost_matching, matching_cascade, gate_cost_matrix
from .base_sde_tracker import Track
from ..utils import Detection

from ppdet.core.workspace import register, serializable
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)

__all__ = ['DeepSORTTracker']


@register
@serializable
class DeepSORTTracker(object):
    """
    DeepSORT tracker

    Args:
        input_size (list): input feature map size to reid model, [h, w] format,
            [64, 192] as default. Stored on the tracker; not used by the
            methods of this class.
        min_box_area (int): min box area to filter out low quality boxes.
            Stored on the tracker; not used by the methods of this class.
        vertical_ratio (float): w/h, the vertical ratio of the bbox to filter
            bad results, set 1.6 default for pedestrian tracking. If set <=0
            means no need to filter bboxes. Stored on the tracker; not used
            by the methods of this class.
        budget (int): If not None, fix samples per class to at most this number.
            Removes the oldest samples when the budget is reached.
        max_age (int): maximum number of consecutive misses before a track is
            deleted.
        n_init (int): Number of frames that a track remains in initialization
            phase. Number of consecutive detections before the track is
            confirmed. The track state is set to `Deleted` if a miss occurs
            within the first `n_init` frames.
        metric_type (str): either "euclidean" or "cosine", the distance metric
            used for measurement to track association.
        matching_threshold (float): samples with larger distance are
            considered an invalid match.
        max_iou_distance (float): max iou distance threshold
        motion (str|object): the string 'KalmanFilter' to build a default
            KalmanFilter, or an already constructed motion-model instance
            that is used as given.

    Raises:
        ValueError: if `motion` is a string other than 'KalmanFilter'.
    """

    def __init__(self,
                 input_size=[64, 192],
                 min_box_area=0,
                 vertical_ratio=-1,
                 budget=100,
                 max_age=70,
                 n_init=3,
                 metric_type='cosine',
                 matching_threshold=0.2,
                 max_iou_distance=0.9,
                 motion='KalmanFilter'):
        # The default in the signature is kept as a list literal because the
        # signature is part of the public interface; copy list inputs so that
        # instances never share (or leak mutations into) the default object.
        self.input_size = list(input_size) if isinstance(
            input_size, list) else input_size
        self.min_box_area = min_box_area
        self.vertical_ratio = vertical_ratio
        self.max_age = max_age
        self.n_init = n_init
        self.metric = NearestNeighborDistanceMetric(metric_type,
                                                    matching_threshold, budget)
        self.max_iou_distance = max_iou_distance

        # Always leave `self.motion` defined: previously any value other than
        # 'KalmanFilter' left the attribute missing and the tracker failed
        # later with an AttributeError inside predict()/_initiate_track().
        if isinstance(motion, str):
            if motion != 'KalmanFilter':
                raise ValueError(
                    "Unsupported motion model '{}', only 'KalmanFilter' or a "
                    "motion-model instance is accepted.".format(motion))
            self.motion = KalmanFilter()
        else:
            self.motion = motion

        self.tracks = []
        self._next_id = 1

    def predict(self):
        """
        Propagate track state distributions one time step forward.
        This function should be called once every time step, before `update`.
        """
        for track in self.tracks:
            track.predict(self.motion)

    def update(self, pred_dets, pred_embs):
        """
        Perform measurement update and track management.

        Args:
            pred_dets (np.array): Detection results of the image, the shape is
                [N, 6]. Column 0 is read as cls_id, column 1 as score and
                columns 2:6 as the box.
                NOTE(review): the box columns are forwarded to `Detection`
                unchanged as its `tlwh` argument, while this docstring
                previously described them as 'x0, y0, x1, y1' -- confirm
                which layout callers actually supply.
            pred_embs (np.array): Embedding results of the image, the shape is
                [N, 128], usually pred_embs.shape[1] is a multiple of 128.

        Returns:
            list: the tracker's current track list (`self.tracks`) after
                deleted tracks have been dropped. Unconfirmed tracks are
                included.
        """
        # Keep 2-D slices (0:1, 1:2) so each per-row item stays an array.
        pred_cls_ids = pred_dets[:, 0:1]
        pred_scores = pred_dets[:, 1:2]
        pred_tlwhs = pred_dets[:, 2:6]

        detections = [
            Detection(tlwh, score, feat, cls_id)
            for tlwh, score, feat, cls_id in zip(pred_tlwhs, pred_scores,
                                                 pred_embs, pred_cls_ids)
        ]

        # Run matching cascade.
        matches, unmatched_tracks, unmatched_detections = \
            self._match(detections)

        # Update track set.
        for track_idx, detection_idx in matches:
            self.tracks[track_idx].update(self.motion,
                                          detections[detection_idx])
        for track_idx in unmatched_tracks:
            self.tracks[track_idx].mark_missed()
        for detection_idx in unmatched_detections:
            self._initiate_track(detections[detection_idx])
        self.tracks = [t for t in self.tracks if not t.is_deleted()]

        # Update distance metric with the features collected by confirmed
        # tracks, then clear those per-track buffers so that every feature is
        # handed to the metric only once.
        confirmed = [t for t in self.tracks if t.is_confirmed()]
        active_targets = [t.track_id for t in confirmed]
        features, targets = [], []
        for track in confirmed:
            features += track.features
            targets += [track.track_id] * len(track.features)
            track.features = []
        self.metric.partial_fit(
            np.asarray(features), np.asarray(targets), active_targets)
        return self.tracks

    def _match(self, detections):
        """
        Associate the current tracks with `detections` in two stages.

        Confirmed tracks are matched first through `matching_cascade` using
        the appearance metric; unconfirmed tracks, together with confirmed
        tracks that stayed unmatched and have `time_since_update == 1`, are
        then matched through `min_cost_matching` with `iou_cost`.

        Args:
            detections (list): Detection objects of the current frame.

        Returns:
            tuple: (matches, unmatched_tracks, unmatched_detections), where
                `matches` holds (track index, detection index) pairs and the
                other two hold indices into `self.tracks` and `detections`.
        """

        def gated_metric(tracks, dets, track_indices, detection_indices):
            # Appearance cost between the selected detections and tracks,
            # post-processed by gate_cost_matrix with the motion model
            # (presumably to rule out motion-infeasible pairs -- see
            # gate_cost_matrix for the exact rule).
            features = np.array([dets[i].feature for i in detection_indices])
            targets = np.array([tracks[i].track_id for i in track_indices])
            cost_matrix = self.metric.distance(features, targets)
            cost_matrix = gate_cost_matrix(self.motion, cost_matrix, tracks,
                                           dets, track_indices,
                                           detection_indices)
            return cost_matrix

        # Split track set into confirmed and unconfirmed tracks.
        confirmed_tracks = []
        unconfirmed_tracks = []
        for i, t in enumerate(self.tracks):
            if t.is_confirmed():
                confirmed_tracks.append(i)
            else:
                unconfirmed_tracks.append(i)

        # Associate confirmed tracks using appearance features.
        matches_a, unmatched_tracks_a, unmatched_detections = \
            matching_cascade(
                gated_metric, self.metric.matching_threshold, self.max_age,
                self.tracks, detections, confirmed_tracks)

        # Associate remaining tracks together with unconfirmed tracks using IOU.
        # Only confirmed tracks with time_since_update == 1 join the IoU
        # stage; the rest stay unmatched.
        iou_track_candidates = unconfirmed_tracks + [
            k for k in unmatched_tracks_a
            if self.tracks[k].time_since_update == 1
        ]
        unmatched_tracks_a = [
            k for k in unmatched_tracks_a
            if self.tracks[k].time_since_update != 1
        ]
        matches_b, unmatched_tracks_b, unmatched_detections = \
            min_cost_matching(
                iou_cost, self.max_iou_distance, self.tracks,
                detections, iou_track_candidates, unmatched_detections)

        matches = matches_a + matches_b
        unmatched_tracks = list(set(unmatched_tracks_a + unmatched_tracks_b))
        return matches, unmatched_tracks, unmatched_detections

    def _initiate_track(self, detection):
        """
        Create a new track from `detection` and append it to `self.tracks`.

        The motion model is initialised from `detection.to_xyah()`; the new
        track takes the next free id, after which the id counter advances.
        """
        mean, covariance = self.motion.initiate(detection.to_xyah())
        self.tracks.append(
            Track(mean, covariance, self._next_id, self.n_init, self.max_age,
                  detection.cls_id, detection.score, detection.feature))
        self._next_id += 1