base_sde_tracker.py 5.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrowed from https://github.com/nwojke/deep_sort/blob/master/deep_sort/track.py
"""

import numpy as np
from ppdet.core.workspace import register, serializable

__all__ = ['TrackState', 'Track']


class TrackState(object):
    """
    Integer codes for the lifecycle state of a single target track.

    A track starts out as `Tentative` and stays that way until enough
    evidence has been gathered, at which point it becomes `Confirmed`. A
    track that is no longer alive is marked `Deleted` so that it can be
    dropped from the set of active tracks.
    """
    Tentative, Confirmed, Deleted = 1, 2, 3


@register
@serializable
class Track(object):
    """
    State of one tracked target, propagated and corrected by a Kalman filter.

    The state space is `(x, y, a, h)` plus the associated velocities, where
    `(x, y)` is the bounding box center, `a` is the aspect ratio
    (width / height) and `h` is the height.

    Args:
        mean (ndarray): Mean vector of the initial state distribution.
        covariance (ndarray): Covariance matrix of the initial state
            distribution.
        track_id (int): A unique track identifier.
        n_init (int): Number of hits required before a tentative track is
            promoted to `Confirmed`. A miss while the track is still
            tentative sets its state to `Deleted`.
        max_age (int): Number of frames without a measurement update that a
            non-tentative track may exceed before a miss sets its state to
            `Deleted`.
        feature (Optional[ndarray]): Feature vector of the detection this
            track originates from. When not None it becomes the first entry
            of the `features` cache.

    Attributes:
        hits (int): Number of measurement updates; starts at 1 for the
            detection that created the track.
        age (int): Number of frames since the first occurrence; starts at 1.
        time_since_update (int): Number of frames since the last measurement
            update.
        state (TrackState): The current track state.
        features (List[ndarray]): Cache of feature vectors. Every measurement
            update appends the feature of the associated detection.
    """

    def __init__(self,
                 mean,
                 covariance,
                 track_id,
                 n_init,
                 max_age,
                 feature=None):
        # Lifecycle thresholds.
        self._n_init = n_init
        self._max_age = max_age

        # Identity and Kalman state.
        self.track_id = track_id
        self.mean = mean
        self.covariance = covariance

        # Bookkeeping counters; the creating detection counts as one hit.
        self.hits = 1
        self.age = 1
        self.time_since_update = 0

        # Every track starts tentative; seed the cache if a feature is given.
        self.state = TrackState.Tentative
        self.features = [] if feature is None else [feature]

    def to_tlwh(self):
        """Return the box as `(top left x, top left y, width, height)`."""
        box = self.mean[:4].copy()
        # (cx, cy, a, h) -> (cx, cy, w, h): width is aspect ratio * height.
        box[2] *= box[3]
        # Shift the center by half the size to reach the top-left corner.
        half_size = box[2:] / 2
        box[:2] -= half_size
        return box

    def to_tlbr(self):
        """Return the box as `(min x, min y, max x, max y)`."""
        box = self.to_tlwh()
        # Bottom-right corner is top-left corner plus (width, height).
        box[2:] += box[:2]
        return box

    def predict(self, kalman_filter):
        """
        Advance the state distribution to the current time step with a
        Kalman filter prediction step, then bump the frame counters.
        """
        predicted = kalman_filter.predict(self.mean, self.covariance)
        self.mean, self.covariance = predicted
        self.time_since_update += 1
        self.age += 1

    def update(self, kalman_filter, detection):
        """
        Run the Kalman filter measurement update with `detection` and append
        its feature to the cache. A tentative track is promoted to
        `Confirmed` once `hits` reaches `n_init`.
        """
        corrected = kalman_filter.update(self.mean, self.covariance,
                                         detection.to_xyah())
        self.mean, self.covariance = corrected
        self.features.append(detection.feature)

        self.time_since_update = 0
        self.hits += 1
        if self.state == TrackState.Tentative:
            if self.hits >= self._n_init:
                self.state = TrackState.Confirmed

    def mark_missed(self):
        """
        Record that no detection was associated at the current time step.

        A tentative track is deleted immediately; any other track is deleted
        once `time_since_update` exceeds `max_age`.
        """
        if (self.state == TrackState.Tentative or
                self.time_since_update > self._max_age):
            self.state = TrackState.Deleted

    def is_tentative(self):
        """Return True while the track is still tentative (unconfirmed)."""
        return self.state == TrackState.Tentative

    def is_confirmed(self):
        """Return True once the track has been confirmed."""
        return self.state == TrackState.Confirmed

    def is_deleted(self):
        """Return True if the track is dead and should be removed."""
        return self.state == TrackState.Deleted