未验证 提交 d53f6412 编写于 作者: C chenjian 提交者: GitHub

Fix mot version imcompatible (#1709)

上级 9c8d8959
......@@ -5,7 +5,7 @@ FairMOT:
detector: CenterNet
reid: FairMOTEmbeddingHead
loss: FairMOTLoss
tracker: JDETracker
tracker: FrozenJDETracker
CenterNet:
backbone: DLA
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import matching
from . import tracker
from . import motion
from . import visualization
from . import utils
from .matching import *
from .tracker import *
from .motion import *
from .visualization import *
from .utils import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import jde_matching
from . import deepsort_matching
from .jde_matching import *
from .deepsort_matching import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/nwojke/deep_sort/tree/master/deep_sort
"""
import numpy as np
from scipy.optimize import linear_sum_assignment
from ..motion import kalman_filter
INFTY_COST = 1e+5
__all__ = [
'iou_1toN',
'iou_cost',
'_nn_euclidean_distance',
'_nn_cosine_distance',
'NearestNeighborDistanceMetric',
'min_cost_matching',
'matching_cascade',
'gate_cost_matrix',
]
def iou_1toN(bbox, candidates):
"""
Computer intersection over union (IoU) by one box to N candidates.
Args:
bbox (ndarray): A bounding box in format `(top left x, top left y, width, height)`.
candidates (ndarray): A matrix of candidate bounding boxes (one per row) in the
same format as `bbox`.
Returns:
ious (ndarray): The intersection over union in [0, 1] between the `bbox`
and each candidate. A higher score means a larger fraction of the
`bbox` is occluded by the candidate.
"""
bbox_tl = bbox[:2]
bbox_br = bbox[:2] + bbox[2:]
candidates_tl = candidates[:, :2]
candidates_br = candidates[:, :2] + candidates[:, 2:]
tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis],
np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]]
br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis],
np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]]
wh = np.maximum(0., br - tl)
area_intersection = wh.prod(axis=1)
area_bbox = bbox[2:].prod()
area_candidates = candidates[:, 2:].prod(axis=1)
ious = area_intersection / (area_bbox + area_candidates - area_intersection)
return ious
def iou_cost(tracks, detections, track_indices=None, detection_indices=None):
"""
IoU distance metric.
Args:
tracks (list[Track]): A list of tracks.
detections (list[Detection]): A list of detections.
track_indices (Optional[list[int]]): A list of indices to tracks that
should be matched. Defaults to all `tracks`.
detection_indices (Optional[list[int]]): A list of indices to detections
that should be matched. Defaults to all `detections`.
Returns:
cost_matrix (ndarray): A cost matrix of shape len(track_indices),
len(detection_indices) where entry (i, j) is
`1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`.
"""
if track_indices is None:
track_indices = np.arange(len(tracks))
if detection_indices is None:
detection_indices = np.arange(len(detections))
cost_matrix = np.zeros((len(track_indices), len(detection_indices)))
for row, track_idx in enumerate(track_indices):
if tracks[track_idx].time_since_update > 1:
cost_matrix[row, :] = 1e+5
continue
bbox = tracks[track_idx].to_tlwh()
candidates = np.asarray([detections[i].tlwh for i in detection_indices])
cost_matrix[row, :] = 1. - iou_1toN(bbox, candidates)
return cost_matrix
def _nn_euclidean_distance(s, q):
"""
Compute pair-wise squared (Euclidean) distance between points in `s` and `q`.
Args:
s (ndarray): Sample points: an NxM matrix of N samples of dimensionality M.
q (ndarray): Query points: an LxM matrix of L samples of dimensionality M.
Returns:
distances (ndarray): A vector of length M that contains for each entry in `q` the
smallest Euclidean distance to a sample in `s`.
"""
s, q = np.asarray(s), np.asarray(q)
if len(s) == 0 or len(q) == 0:
return np.zeros((len(s), len(q)))
s2, q2 = np.square(s).sum(axis=1), np.square(q).sum(axis=1)
distances = -2. * np.dot(s, q.T) + s2[:, None] + q2[None, :]
distances = np.clip(distances, 0., float(np.inf))
return np.maximum(0.0, distances.min(axis=0))
def _nn_cosine_distance(s, q):
"""
Compute pair-wise cosine distance between points in `s` and `q`.
Args:
s (ndarray): Sample points: an NxM matrix of N samples of dimensionality M.
q (ndarray): Query points: an LxM matrix of L samples of dimensionality M.
Returns:
distances (ndarray): A vector of length M that contains for each entry in `q` the
smallest Euclidean distance to a sample in `s`.
"""
s = np.asarray(s) / np.linalg.norm(s, axis=1, keepdims=True)
q = np.asarray(q) / np.linalg.norm(q, axis=1, keepdims=True)
distances = 1. - np.dot(s, q.T)
return distances.min(axis=0)
class NearestNeighborDistanceMetric(object):
"""
A nearest neighbor distance metric that, for each target, returns
the closest distance to any sample that has been observed so far.
Args:
metric (str): Either "euclidean" or "cosine".
matching_threshold (float): The matching threshold. Samples with larger
distance are considered an invalid match.
budget (Optional[int]): If not None, fix samples per class to at most
this number. Removes the oldest samples when the budget is reached.
Attributes:
samples (Dict[int -> List[ndarray]]): A dictionary that maps from target
identities to the list of samples that have been observed so far.
"""
def __init__(self, metric, matching_threshold, budget=None):
if metric == "euclidean":
self._metric = _nn_euclidean_distance
elif metric == "cosine":
self._metric = _nn_cosine_distance
else:
raise ValueError("Invalid metric; must be either 'euclidean' or 'cosine'")
self.matching_threshold = matching_threshold
self.budget = budget
self.samples = {}
def partial_fit(self, features, targets, active_targets):
"""
Update the distance metric with new data.
Args:
features (ndarray): An NxM matrix of N features of dimensionality M.
targets (ndarray): An integer array of associated target identities.
active_targets (List[int]): A list of targets that are currently
present in the scene.
"""
for feature, target in zip(features, targets):
self.samples.setdefault(target, []).append(feature)
if self.budget is not None:
self.samples[target] = self.samples[target][-self.budget:]
self.samples = {k: self.samples[k] for k in active_targets}
def distance(self, features, targets):
"""
Compute distance between features and targets.
Args:
features (ndarray): An NxM matrix of N features of dimensionality M.
targets (list[int]): A list of targets to match the given `features` against.
Returns:
cost_matrix (ndarray): a cost matrix of shape len(targets), len(features),
where element (i, j) contains the closest squared distance between
`targets[i]` and `features[j]`.
"""
cost_matrix = np.zeros((len(targets), len(features)))
for i, target in enumerate(targets):
cost_matrix[i, :] = self._metric(self.samples[target], features)
return cost_matrix
def min_cost_matching(distance_metric, max_distance, tracks, detections, track_indices=None, detection_indices=None):
"""
Solve linear assignment problem.
Args:
distance_metric :
Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
The distance metric is given a list of tracks and detections as
well as a list of N track indices and M detection indices. The
metric should return the NxM dimensional cost matrix, where element
(i, j) is the association cost between the i-th track in the given
track indices and the j-th detection in the given detection_indices.
max_distance (float): Gating threshold. Associations with cost larger
than this value are disregarded.
tracks (list[Track]): A list of predicted tracks at the current time
step.
detections (list[Detection]): A list of detections at the current time
step.
track_indices (list[int]): List of track indices that maps rows in
`cost_matrix` to tracks in `tracks`.
detection_indices (List[int]): List of detection indices that maps
columns in `cost_matrix` to detections in `detections`.
Returns:
A tuple (List[(int, int)], List[int], List[int]) with the following
three entries:
* A list of matched track and detection indices.
* A list of unmatched track indices.
* A list of unmatched detection indices.
"""
if track_indices is None:
track_indices = np.arange(len(tracks))
if detection_indices is None:
detection_indices = np.arange(len(detections))
if len(detection_indices) == 0 or len(track_indices) == 0:
return [], track_indices, detection_indices # Nothing to match.
cost_matrix = distance_metric(tracks, detections, track_indices, detection_indices)
cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5
indices = linear_sum_assignment(cost_matrix)
matches, unmatched_tracks, unmatched_detections = [], [], []
for col, detection_idx in enumerate(detection_indices):
if col not in indices[1]:
unmatched_detections.append(detection_idx)
for row, track_idx in enumerate(track_indices):
if row not in indices[0]:
unmatched_tracks.append(track_idx)
for row, col in zip(indices[0], indices[1]):
track_idx = track_indices[row]
detection_idx = detection_indices[col]
if cost_matrix[row, col] > max_distance:
unmatched_tracks.append(track_idx)
unmatched_detections.append(detection_idx)
else:
matches.append((track_idx, detection_idx))
return matches, unmatched_tracks, unmatched_detections
def matching_cascade(distance_metric,
max_distance,
cascade_depth,
tracks,
detections,
track_indices=None,
detection_indices=None):
"""
Run matching cascade.
Args:
distance_metric :
Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
The distance metric is given a list of tracks and detections as
well as a list of N track indices and M detection indices. The
metric should return the NxM dimensional cost matrix, where element
(i, j) is the association cost between the i-th track in the given
track indices and the j-th detection in the given detection_indices.
max_distance (float): Gating threshold. Associations with cost larger
than this value are disregarded.
cascade_depth (int): The cascade depth, should be se to the maximum
track age.
tracks (list[Track]): A list of predicted tracks at the current time
step.
detections (list[Detection]): A list of detections at the current time
step.
track_indices (list[int]): List of track indices that maps rows in
`cost_matrix` to tracks in `tracks`.
detection_indices (List[int]): List of detection indices that maps
columns in `cost_matrix` to detections in `detections`.
Returns:
A tuple (List[(int, int)], List[int], List[int]) with the following
three entries:
* A list of matched track and detection indices.
* A list of unmatched track indices.
* A list of unmatched detection indices.
"""
if track_indices is None:
track_indices = list(range(len(tracks)))
if detection_indices is None:
detection_indices = list(range(len(detections)))
unmatched_detections = detection_indices
matches = []
for level in range(cascade_depth):
if len(unmatched_detections) == 0: # No detections left
break
track_indices_l = [k for k in track_indices if tracks[k].time_since_update == 1 + level]
if len(track_indices_l) == 0: # Nothing to match at this level
continue
matches_l, _, unmatched_detections = \
min_cost_matching(
distance_metric, max_distance, tracks, detections,
track_indices_l, unmatched_detections)
matches += matches_l
unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches))
return matches, unmatched_tracks, unmatched_detections
def gate_cost_matrix(kf,
cost_matrix,
tracks,
detections,
track_indices,
detection_indices,
gated_cost=INFTY_COST,
only_position=False):
"""
Invalidate infeasible entries in cost matrix based on the state
distributions obtained by Kalman filtering.
Args:
kf (object): The Kalman filter.
cost_matrix (ndarray): The NxM dimensional cost matrix, where N is the
number of track indices and M is the number of detection indices,
such that entry (i, j) is the association cost between
`tracks[track_indices[i]]` and `detections[detection_indices[j]]`.
tracks (list[Track]): A list of predicted tracks at the current time
step.
detections (list[Detection]): A list of detections at the current time
step.
track_indices (List[int]): List of track indices that maps rows in
`cost_matrix` to tracks in `tracks`.
detection_indices (List[int]): List of detection indices that maps
columns in `cost_matrix` to detections in `detections`.
gated_cost (Optional[float]): Entries in the cost matrix corresponding
to infeasible associations are set this value. Defaults to a very
large value.
only_position (Optional[bool]): If True, only the x, y position of the
state distribution is considered during gating. Default False.
"""
gating_dim = 2 if only_position else 4
gating_threshold = kalman_filter.chi2inv95[gating_dim]
measurements = np.asarray([detections[i].to_xyah() for i in detection_indices])
for row, track_idx in enumerate(track_indices):
track = tracks[track_idx]
gating_distance = kf.gating_distance(track.mean, track.covariance, measurements, only_position)
cost_matrix[row, gating_distance > gating_threshold] = gated_cost
return cost_matrix
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/matching.py
"""
import lap
import scipy
import numpy as np
from scipy.spatial.distance import cdist
from ..motion import kalman_filter
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = [
'merge_matches',
'linear_assignment',
'cython_bbox_ious',
'iou_distance',
'embedding_distance',
'fuse_motion',
]
def merge_matches(m1, m2, shape):
O, P, Q = shape
m1 = np.asarray(m1)
m2 = np.asarray(m2)
M1 = scipy.sparse.coo_matrix((np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P))
M2 = scipy.sparse.coo_matrix((np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q))
mask = M1 * M2
match = mask.nonzero()
match = list(zip(match[0], match[1]))
unmatched_O = tuple(set(range(O)) - set([i for i, j in match]))
unmatched_Q = tuple(set(range(Q)) - set([j for i, j in match]))
return match, unmatched_O, unmatched_Q
def linear_assignment(cost_matrix, thresh):
if cost_matrix.size == 0:
return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
matches, unmatched_a, unmatched_b = [], [], []
cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
for ix, mx in enumerate(x):
if mx >= 0:
matches.append([ix, mx])
unmatched_a = np.where(x < 0)[0]
unmatched_b = np.where(y < 0)[0]
matches = np.asarray(matches)
return matches, unmatched_a, unmatched_b
def cython_bbox_ious(atlbrs, btlbrs):
ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float)
if ious.size == 0:
return ious
try:
import cython_bbox
except Exception as e:
logger.error('cython_bbox not found, please install cython_bbox.' 'for example: `pip install cython_bbox`.')
raise e
ious = cython_bbox.bbox_overlaps(
np.ascontiguousarray(atlbrs, dtype=np.float), np.ascontiguousarray(btlbrs, dtype=np.float))
return ious
def iou_distance(atracks, btracks):
"""
Compute cost based on IoU between two list[STrack].
"""
if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0
and isinstance(btracks[0], np.ndarray)):
atlbrs = atracks
btlbrs = btracks
else:
atlbrs = [track.tlbr for track in atracks]
btlbrs = [track.tlbr for track in btracks]
_ious = cython_bbox_ious(atlbrs, btlbrs)
cost_matrix = 1 - _ious
return cost_matrix
def embedding_distance(tracks, detections, metric='euclidean'):
"""
Compute cost based on features between two list[STrack].
"""
cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float)
if cost_matrix.size == 0:
return cost_matrix
det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float)
track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float)
cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric)) # Nomalized features
return cost_matrix
def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98):
if cost_matrix.size == 0:
return cost_matrix
gating_dim = 2 if only_position else 4
gating_threshold = kalman_filter.chi2inv95[gating_dim]
measurements = np.asarray([det.to_xyah() for det in detections])
for row, track in enumerate(tracks):
gating_distance = kf.gating_distance(track.mean, track.covariance, measurements, only_position, metric='maha')
cost_matrix[row, gating_distance > gating_threshold] = np.inf
cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance
return cost_matrix
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import kalman_filter
from .kalman_filter import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/nwojke/deep_sort/blob/master/deep_sort/kalman_filter.py
"""
import numpy as np
import scipy.linalg
__all__ = ['KalmanFilter']
"""
Table for the 0.95 quantile of the chi-square distribution with N degrees of
freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
function and used as Mahalanobis gating threshold.
"""
chi2inv95 = {1: 3.8415, 2: 5.9915, 3: 7.8147, 4: 9.4877, 5: 11.070, 6: 12.592, 7: 14.067, 8: 15.507, 9: 16.919}
class KalmanFilter(object):
"""
A simple Kalman filter for tracking bounding boxes in image space.
The 8-dimensional state space
x, y, a, h, vx, vy, va, vh
contains the bounding box center position (x, y), aspect ratio a, height h,
and their respective velocities.
Object motion follows a constant velocity model. The bounding box location
(x, y, a, h) is taken as direct observation of the state space (linear
observation model).
"""
def __init__(self):
ndim, dt = 4, 1.
# Create Kalman filter model matrices.
self._motion_mat = np.eye(2 * ndim, 2 * ndim)
for i in range(ndim):
self._motion_mat[i, ndim + i] = dt
self._update_mat = np.eye(ndim, 2 * ndim)
# Motion and observation uncertainty are chosen relative to the current
# state estimate. These weights control the amount of uncertainty in
# the model. This is a bit hacky.
self._std_weight_position = 1. / 20
self._std_weight_velocity = 1. / 160
def initiate(self, measurement):
"""
Create track from unassociated measurement.
Args:
measurement (ndarray): Bounding box coordinates (x, y, a, h) with
center position (x, y), aspect ratio a, and height h.
Returns:
The mean vector (8 dimensional) and covariance matrix (8x8
dimensional) of the new track. Unobserved velocities are
initialized to 0 mean.
"""
mean_pos = measurement
mean_vel = np.zeros_like(mean_pos)
mean = np.r_[mean_pos, mean_vel]
std = [
2 * self._std_weight_position * measurement[3], 2 * self._std_weight_position * measurement[3], 1e-2,
2 * self._std_weight_position * measurement[3], 10 * self._std_weight_velocity * measurement[3],
10 * self._std_weight_velocity * measurement[3], 1e-5, 10 * self._std_weight_velocity * measurement[3]
]
covariance = np.diag(np.square(std))
return mean, covariance
def predict(self, mean, covariance):
"""
Run Kalman filter prediction step.
Args:
mean (ndarray): The 8 dimensional mean vector of the object state
at the previous time step.
covariance (ndarray): The 8x8 dimensional covariance matrix of the
object state at the previous time step.
Returns:
The mean vector and covariance matrix of the predicted state.
Unobserved velocities are initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-2,
self._std_weight_position * mean[3]
]
std_vel = [
self._std_weight_velocity * mean[3], self._std_weight_velocity * mean[3], 1e-5,
self._std_weight_velocity * mean[3]
]
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
#mean = np.dot(self._motion_mat, mean)
mean = np.dot(mean, self._motion_mat.T)
covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
return mean, covariance
def project(self, mean, covariance):
"""
Project state distribution to measurement space.
Args
mean (ndarray): The state's mean vector (8 dimensional array).
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
Returns:
The projected mean and covariance matrix of the given state estimate.
"""
std = [
self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-1,
self._std_weight_position * mean[3]
]
innovation_cov = np.diag(np.square(std))
mean = np.dot(self._update_mat, mean)
covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
return mean, covariance + innovation_cov
def multi_predict(self, mean, covariance):
"""
Run Kalman filter prediction step (Vectorized version).
Args:
mean (ndarray): The Nx8 dimensional mean matrix of the object states
at the previous time step.
covariance (ndarray): The Nx8x8 dimensional covariance matrics of the
object states at the previous time step.
Returns:
The mean vector and covariance matrix of the predicted state.
Unobserved velocities are initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[:, 3], self._std_weight_position * mean[:, 3],
1e-2 * np.ones_like(mean[:, 3]), self._std_weight_position * mean[:, 3]
]
std_vel = [
self._std_weight_velocity * mean[:, 3], self._std_weight_velocity * mean[:, 3],
1e-5 * np.ones_like(mean[:, 3]), self._std_weight_velocity * mean[:, 3]
]
sqr = np.square(np.r_[std_pos, std_vel]).T
motion_cov = []
for i in range(len(mean)):
motion_cov.append(np.diag(sqr[i]))
motion_cov = np.asarray(motion_cov)
mean = np.dot(mean, self._motion_mat.T)
left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
covariance = np.dot(left, self._motion_mat.T) + motion_cov
return mean, covariance
def update(self, mean, covariance, measurement):
"""
Run Kalman filter correction step.
Args:
mean (ndarray): The predicted state's mean vector (8 dimensional).
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
measurement (ndarray): The 4 dimensional measurement vector
(x, y, a, h), where (x, y) is the center position, a the aspect
ratio, and h the height of the bounding box.
Returns:
The measurement-corrected state distribution.
"""
projected_mean, projected_cov = self.project(mean, covariance)
chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False)
kalman_gain = scipy.linalg.cho_solve((chol_factor, lower),
np.dot(covariance, self._update_mat.T).T,
check_finite=False).T
innovation = measurement - projected_mean
new_mean = mean + np.dot(innovation, kalman_gain.T)
new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T))
return new_mean, new_covariance
def gating_distance(self, mean, covariance, measurements, only_position=False, metric='maha'):
"""
Compute gating distance between state distribution and measurements.
A suitable distance threshold can be obtained from `chi2inv95`. If
`only_position` is False, the chi-square distribution has 4 degrees of
freedom, otherwise 2.
Args:
mean (ndarray): Mean vector over the state distribution (8
dimensional).
covariance (ndarray): Covariance of the state distribution (8x8
dimensional).
measurements (ndarray): An Nx4 dimensional matrix of N measurements,
each in format (x, y, a, h) where (x, y) is the bounding box center
position, a the aspect ratio, and h the height.
only_position (Optional[bool]): If True, distance computation is
done with respect to the bounding box center position only.
metric (str): Metric type, 'gaussian' or 'maha'.
Returns
An array of length N, where the i-th element contains the squared
Mahalanobis distance between (mean, covariance) and `measurements[i]`.
"""
mean, covariance = self.project(mean, covariance)
if only_position:
mean, covariance = mean[:2], covariance[:2, :2]
measurements = measurements[:, :2]
d = measurements - mean
if metric == 'gaussian':
return np.sum(d * d, axis=1)
elif metric == 'maha':
cholesky_factor = np.linalg.cholesky(covariance)
z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True)
squared_maha = np.sum(z * z, axis=0)
return squared_maha
else:
raise ValueError('invalid distance metric')
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import base_jde_tracker
from . import base_sde_tracker
from . import jde_tracker
from .base_jde_tracker import *
from .base_sde_tracker import *
from .jde_tracker import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py
"""
import numpy as np
from collections import deque, OrderedDict
from ..matching import jde_matching as matching
from ppdet.core.workspace import register, serializable
__all__ = [
'TrackState',
'BaseTrack',
'STrack',
'joint_stracks',
'sub_stracks',
'remove_duplicate_stracks',
]
class TrackState(object):
New = 0
Tracked = 1
Lost = 2
Removed = 3
class BaseTrack(object):
_count = 0
track_id = 0
is_activated = False
state = TrackState.New
history = OrderedDict()
features = []
curr_feature = None
score = 0
start_frame = 0
frame_id = 0
time_since_update = 0
# multi-camera
location = (np.inf, np.inf)
@property
def end_frame(self):
return self.frame_id
@staticmethod
def next_id():
BaseTrack._count += 1
return BaseTrack._count
def activate(self, *args):
raise NotImplementedError
def predict(self):
raise NotImplementedError
def update(self, *args, **kwargs):
raise NotImplementedError
def mark_lost(self):
self.state = TrackState.Lost
def mark_removed(self):
self.state = TrackState.Removed
class STrack(BaseTrack):
def __init__(self, tlwh, score, temp_feat, buffer_size=30):
# wait activate
self._tlwh = np.asarray(tlwh, dtype=np.float)
self.kalman_filter = None
self.mean, self.covariance = None, None
self.is_activated = False
self.score = score
self.tracklet_len = 0
self.smooth_feat = None
self.update_features(temp_feat)
self.features = deque([], maxlen=buffer_size)
self.alpha = 0.9
def update_features(self, feat):
feat /= np.linalg.norm(feat)
self.curr_feat = feat
if self.smooth_feat is None:
self.smooth_feat = feat
else:
self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
self.features.append(feat)
self.smooth_feat /= np.linalg.norm(self.smooth_feat)
def predict(self):
mean_state = self.mean.copy()
if self.state != TrackState.Tracked:
mean_state[7] = 0
self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
@staticmethod
def multi_predict(stracks, kalman_filter):
if len(stracks) > 0:
multi_mean = np.asarray([st.mean.copy() for st in stracks])
multi_covariance = np.asarray([st.covariance for st in stracks])
for i, st in enumerate(stracks):
if st.state != TrackState.Tracked:
multi_mean[i][7] = 0
multi_mean, multi_covariance = kalman_filter.multi_predict(multi_mean, multi_covariance)
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
stracks[i].mean = mean
stracks[i].covariance = cov
def activate(self, kalman_filter, frame_id):
"""Start a new tracklet"""
self.kalman_filter = kalman_filter
self.track_id = self.next_id()
self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh))
self.tracklet_len = 0
self.state = TrackState.Tracked
if frame_id == 1:
self.is_activated = True
self.frame_id = frame_id
self.start_frame = frame_id
def re_activate(self, new_track, frame_id, new_id=False):
self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance,
self.tlwh_to_xyah(new_track.tlwh))
self.update_features(new_track.curr_feat)
self.tracklet_len = 0
self.state = TrackState.Tracked
self.is_activated = True
self.frame_id = frame_id
if new_id:
self.track_id = self.next_id()
def update(self, new_track, frame_id, update_feature=True):
self.frame_id = frame_id
self.tracklet_len += 1
new_tlwh = new_track.tlwh
self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
self.state = TrackState.Tracked
self.is_activated = True
self.score = new_track.score
if update_feature:
self.update_features(new_track.curr_feat)
@property
def tlwh(self):
"""
Get current position in bounding box format `(top left x, top left y,
width, height)`.
"""
if self.mean is None:
return self._tlwh.copy()
ret = self.mean[:4].copy()
ret[2] *= ret[3]
ret[:2] -= ret[2:] / 2
return ret
@property
def tlbr(self):
"""
Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
`(top left, bottom right)`.
"""
ret = self.tlwh.copy()
ret[2:] += ret[:2]
return ret
@staticmethod
def tlwh_to_xyah(tlwh):
"""
Convert bounding box to format `(center x, center y, aspect ratio,
height)`, where the aspect ratio is `width / height`.
"""
ret = np.asarray(tlwh).copy()
ret[:2] += ret[2:] / 2
ret[2] /= ret[3]
return ret
def to_xyah(self):
return self.tlwh_to_xyah(self.tlwh)
@staticmethod
def tlbr_to_tlwh(tlbr):
ret = np.asarray(tlbr).copy()
ret[2:] -= ret[:2]
return ret
@staticmethod
def tlwh_to_tlbr(tlwh):
ret = np.asarray(tlwh).copy()
ret[2:] += ret[:2]
return ret
def __repr__(self):
return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
def joint_stracks(tlista, tlistb):
exists = {}
res = []
for t in tlista:
exists[t.track_id] = 1
res.append(t)
for t in tlistb:
tid = t.track_id
if not exists.get(tid, 0):
exists[tid] = 1
res.append(t)
return res
def sub_stracks(tlista, tlistb):
stracks = {}
for t in tlista:
stracks[t.track_id] = t
for t in tlistb:
tid = t.track_id
if stracks.get(tid, 0):
del stracks[tid]
return list(stracks.values())
def remove_duplicate_stracks(stracksa, stracksb):
pdist = matching.iou_distance(stracksa, stracksb)
pairs = np.where(pdist < 0.15)
dupa, dupb = list(), list()
for p, q in zip(*pairs):
timep = stracksa[p].frame_id - stracksa[p].start_frame
timeq = stracksb[q].frame_id - stracksb[q].start_frame
if timep > timeq:
dupb.append(q)
else:
dupa.append(p)
resa = [t for i, t in enumerate(stracksa) if not i in dupa]
resb = [t for i, t in enumerate(stracksb) if not i in dupb]
return resa, resb
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/nwojke/deep_sort/blob/master/deep_sort/track.py
"""
from ppdet.core.workspace import register, serializable
__all__ = ['TrackState', 'Track']
class TrackState(object):
"""
Enumeration type for the single target track state. Newly created tracks are
classified as `tentative` until enough evidence has been collected. Then,
the track state is changed to `confirmed`. Tracks that are no longer alive
are classified as `deleted` to mark them for removal from the set of active
tracks.
"""
Tentative = 1
Confirmed = 2
Deleted = 3
class Track(object):
"""
A single target track with state space `(x, y, a, h)` and associated
velocities, where `(x, y)` is the center of the bounding box, `a` is the
aspect ratio and `h` is the height.
Args:
mean (ndarray): Mean vector of the initial state distribution.
covariance (ndarray): Covariance matrix of the initial state distribution.
track_id (int): A unique track identifier.
n_init (int): Number of consecutive detections before the track is confirmed.
The track state is set to `Deleted` if a miss occurs within the first
`n_init` frames.
max_age (int): The maximum number of consecutive misses before the track
state is set to `Deleted`.
feature (Optional[ndarray]): Feature vector of the detection this track
originates from. If not None, this feature is added to the `features` cache.
Attributes:
hits (int): Total number of measurement updates.
age (int): Total number of frames since first occurance.
time_since_update (int): Total number of frames since last measurement
update.
state (TrackState): The current track state.
features (List[ndarray]): A cache of features. On each measurement update,
the associated feature vector is added to this list.
"""
def __init__(self, mean, covariance, track_id, n_init, max_age, feature=None):
self.mean = mean
self.covariance = covariance
self.track_id = track_id
self.hits = 1
self.age = 1
self.time_since_update = 0
self.state = TrackState.Tentative
self.features = []
if feature is not None:
self.features.append(feature)
self._n_init = n_init
self._max_age = max_age
def to_tlwh(self):
"""Get position in format `(top left x, top left y, width, height)`."""
ret = self.mean[:4].copy()
ret[2] *= ret[3]
ret[:2] -= ret[2:] / 2
return ret
def to_tlbr(self):
"""Get position in bounding box format `(min x, miny, max x, max y)`."""
ret = self.to_tlwh()
ret[2:] = ret[:2] + ret[2:]
return ret
def predict(self, kalman_filter):
"""
Propagate the state distribution to the current time step using a Kalman
filter prediction step.
"""
self.mean, self.covariance = kalman_filter.predict(self.mean, self.covariance)
self.age += 1
self.time_since_update += 1
def update(self, kalman_filter, detection):
"""
Perform Kalman filter measurement update step and update the associated
detection feature cache.
"""
self.mean, self.covariance = kalman_filter.update(self.mean, self.covariance, detection.to_xyah())
self.features.append(detection.feature)
self.hits += 1
self.time_since_update = 0
if self.state == TrackState.Tentative and self.hits >= self._n_init:
self.state = TrackState.Confirmed
def mark_missed(self):
"""Mark this track as missed (no association at the current time step).
"""
if self.state == TrackState.Tentative:
self.state = TrackState.Deleted
elif self.time_since_update > self._max_age:
self.state = TrackState.Deleted
def is_tentative(self):
"""Returns True if this track is tentative (unconfirmed)."""
return self.state == TrackState.Tentative
def is_confirmed(self):
"""Returns True if this track is confirmed."""
return self.state == TrackState.Confirmed
def is_deleted(self):
"""Returns True if this track is dead and should be deleted."""
return self.state == TrackState.Deleted
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py
"""
import paddle
from ..matching import jde_matching as matching
from .base_jde_tracker import TrackState, BaseTrack, STrack
from .base_jde_tracker import joint_stracks, sub_stracks, remove_duplicate_stracks
from ppdet.core.workspace import register, serializable
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = ['FrozenJDETracker']
@register
@serializable
class FrozenJDETracker(object):
__inject__ = ['motion']
"""
JDE tracker
Args:
det_thresh (float): threshold of detection score
track_buffer (int): buffer for tracker
min_box_area (int): min box area to filter out low quality boxes
vertical_ratio (float): w/h, the vertical ratio of the bbox to filter
bad results, set 1.6 default for pedestrian tracking. If set -1
means no need to filter bboxes.
tracked_thresh (float): linear assignment threshold of tracked
stracks and detections
r_tracked_thresh (float): linear assignment threshold of
tracked stracks and unmatched detections
unconfirmed_thresh (float): linear assignment threshold of
unconfirmed stracks and unmatched detections
motion (object): KalmanFilter instance
conf_thres (float): confidence threshold for tracking
metric_type (str): either "euclidean" or "cosine", the distance metric
used for measurement to track association.
"""
def __init__(self,
det_thresh=0.3,
track_buffer=30,
min_box_area=200,
vertical_ratio=1.6,
tracked_thresh=0.7,
r_tracked_thresh=0.5,
unconfirmed_thresh=0.7,
motion='KalmanFilter',
conf_thres=0,
metric_type='euclidean'):
self.det_thresh = det_thresh
self.track_buffer = track_buffer
self.min_box_area = min_box_area
self.vertical_ratio = vertical_ratio
self.tracked_thresh = tracked_thresh
self.r_tracked_thresh = r_tracked_thresh
self.unconfirmed_thresh = unconfirmed_thresh
self.motion = motion
self.conf_thres = conf_thres
self.metric_type = metric_type
self.frame_id = 0
self.tracked_stracks = []
self.lost_stracks = []
self.removed_stracks = []
self.max_time_lost = 0
# max_time_lost will be calculated: int(frame_rate / 30.0 * track_buffer)
def update(self, pred_dets, pred_embs):
"""
Processes the image frame and finds bounding box(detections).
Associates the detection with corresponding tracklets and also handles
lost, removed, refound and active tracklets.
Args:
pred_dets (Tensor): Detection results of the image, shape is [N, 5].
pred_embs (Tensor): Embedding results of the image, shape is [N, 512].
Return:
output_stracks (list): The list contains information regarding the
online_tracklets for the recieved image tensor.
"""
self.frame_id += 1
activated_starcks = []
# for storing active tracks, for the current frame
refind_stracks = []
# Lost Tracks whose detections are obtained in the current frame
lost_stracks = []
# The tracks which are not obtained in the current frame but are not
# removed. (Lost for some time lesser than the threshold for removing)
removed_stracks = []
remain_inds = paddle.nonzero(pred_dets[:, 4] > self.conf_thres)
if remain_inds.shape[0] == 0:
pred_dets = paddle.zeros([0, 1])
pred_embs = paddle.zeros([0, 1])
else:
pred_dets = paddle.gather(pred_dets, remain_inds)
pred_embs = paddle.gather(pred_embs, remain_inds)
# Filter out the image with box_num = 0. pred_dets = [[0.0, 0.0, 0.0 ,0.0]]
empty_pred = True if len(pred_dets) == 1 and paddle.sum(pred_dets) == 0.0 else False
""" Step 1: Network forward, get detections & embeddings"""
if len(pred_dets) > 0 and not empty_pred:
pred_dets = pred_dets.numpy()
pred_embs = pred_embs.numpy()
detections = [
STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30) for (tlbrs, f) in zip(pred_dets, pred_embs)
]
else:
detections = []
''' Add newly detected tracklets to tracked_stracks'''
unconfirmed = []
tracked_stracks = [] # type: list[STrack]
for track in self.tracked_stracks:
if not track.is_activated:
# previous tracks which are not active in the current frame are added in unconfirmed list
unconfirmed.append(track)
else:
# Active tracks are added to the local list 'tracked_stracks'
tracked_stracks.append(track)
""" Step 2: First association, with embedding"""
# Combining currently tracked_stracks and lost_stracks
strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
# Predict the current location with KF
STrack.multi_predict(strack_pool, self.motion)
dists = matching.embedding_distance(strack_pool, detections, metric=self.metric_type)
dists = matching.fuse_motion(self.motion, dists, strack_pool, detections)
# The dists is the list of distances of the detection with the tracks in strack_pool
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.tracked_thresh)
# The matches is the array for corresponding matches of the detection with the corresponding strack_pool
for itracked, idet in matches:
# itracked is the id of the track and idet is the detection
track = strack_pool[itracked]
det = detections[idet]
if track.state == TrackState.Tracked:
# If the track is active, add the detection to the track
track.update(detections[idet], self.frame_id)
activated_starcks.append(track)
else:
# We have obtained a detection from a track which is not active,
# hence put the track in refind_stracks list
track.re_activate(det, self.frame_id, new_id=False)
refind_stracks.append(track)
# None of the steps below happen if there are no undetected tracks.
""" Step 3: Second association, with IOU"""
detections = [detections[i] for i in u_detection]
# detections is now a list of the unmatched detections
r_tracked_stracks = []
# This is container for stracks which were tracked till the previous
# frame but no detection was found for it in the current frame.
for i in u_track:
if strack_pool[i].state == TrackState.Tracked:
r_tracked_stracks.append(strack_pool[i])
dists = matching.iou_distance(r_tracked_stracks, detections)
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.r_tracked_thresh)
# matches is the list of detections which matched with corresponding
# tracks by IOU distance method.
for itracked, idet in matches:
track = r_tracked_stracks[itracked]
det = detections[idet]
if track.state == TrackState.Tracked:
track.update(det, self.frame_id)
activated_starcks.append(track)
else:
track.re_activate(det, self.frame_id, new_id=False)
refind_stracks.append(track)
# Same process done for some unmatched detections, but now considering IOU_distance as measure
for it in u_track:
track = r_tracked_stracks[it]
if not track.state == TrackState.Lost:
track.mark_lost()
lost_stracks.append(track)
# If no detections are obtained for tracks (u_track), the tracks are added to lost_tracks list and are marked lost
'''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
detections = [detections[i] for i in u_detection]
dists = matching.iou_distance(unconfirmed, detections)
matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=self.unconfirmed_thresh)
for itracked, idet in matches:
unconfirmed[itracked].update(detections[idet], self.frame_id)
activated_starcks.append(unconfirmed[itracked])
# The tracks which are yet not matched
for it in u_unconfirmed:
track = unconfirmed[it]
track.mark_removed()
removed_stracks.append(track)
# after all these confirmation steps, if a new detection is found, it is initialized for a new track
""" Step 4: Init new stracks"""
for inew in u_detection:
track = detections[inew]
if track.score < self.det_thresh:
continue
track.activate(self.motion, self.frame_id)
activated_starcks.append(track)
""" Step 5: Update state"""
# If the tracks are lost for more frames than the threshold number, the tracks are removed.
for track in self.lost_stracks:
if self.frame_id - track.end_frame > self.max_time_lost:
track.mark_removed()
removed_stracks.append(track)
# Update the self.tracked_stracks and self.lost_stracks using the updates in this step.
self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
self.lost_stracks.extend(lost_stracks)
self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
self.removed_stracks.extend(removed_stracks)
self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
# get scores of lost tracks
output_stracks = [track for track in self.tracked_stracks if track.is_activated]
logger.debug('===========Frame {}=========='.format(self.frame_id))
logger.debug('Activated: {}'.format([track.track_id for track in activated_starcks]))
logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
return output_stracks
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import time
import paddle
import numpy as np
__all__ = [
'Timer',
'Detection',
'load_det_results',
'preprocess_reid',
'get_crops',
'clip_box',
'scale_coords',
]
class Timer(object):
"""
This class used to compute and print the current FPS while evaling.
"""
def __init__(self):
self.total_time = 0.
self.calls = 0
self.start_time = 0.
self.diff = 0.
self.average_time = 0.
self.duration = 0.
def tic(self):
# using time.time instead of time.clock because time time.clock
# does not normalize for multithreading
self.start_time = time.time()
def toc(self, average=True):
self.diff = time.time() - self.start_time
self.total_time += self.diff
self.calls += 1
self.average_time = self.total_time / self.calls
if average:
self.duration = self.average_time
else:
self.duration = self.diff
return self.duration
def clear(self):
self.total_time = 0.
self.calls = 0
self.start_time = 0.
self.diff = 0.
self.average_time = 0.
self.duration = 0.
class Detection(object):
"""
This class represents a bounding box detection in a single image.
Args:
tlwh (ndarray): Bounding box in format `(top left x, top left y,
width, height)`.
confidence (ndarray): Detector confidence score.
feature (Tensor): A feature vector that describes the object
contained in this image.
"""
def __init__(self, tlwh, confidence, feature):
self.tlwh = np.asarray(tlwh, dtype=np.float32)
self.confidence = np.asarray(confidence, dtype=np.float32)
self.feature = feature
def to_tlbr(self):
"""
Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
`(top left, bottom right)`.
"""
ret = self.tlwh.copy()
ret[2:] += ret[:2]
return ret
def to_xyah(self):
"""
Convert bounding box to format `(center x, center y, aspect ratio,
height)`, where the aspect ratio is `width / height`.
"""
ret = self.tlwh.copy()
ret[:2] += ret[2:] / 2
ret[2] /= ret[3]
return ret
def load_det_results(det_file, num_frames):
assert os.path.exists(det_file) and os.path.isfile(det_file), \
'Error: det_file: {} not exist or not a file.'.format(det_file)
labels = np.loadtxt(det_file, dtype='float32', delimiter=',')
results_list = []
for frame_i in range(0, num_frames):
results = {'bbox': [], 'score': []}
lables_with_frame = labels[labels[:, 0] == frame_i + 1]
for l in lables_with_frame:
results['bbox'].append(l[1:5])
results['score'].append(l[5])
results_list.append(results)
return results_list
def scale_coords(coords, input_shape, im_shape, scale_factor):
im_shape = im_shape.numpy()[0]
ratio = scale_factor[0][0]
pad_w = (input_shape[1] - int(im_shape[1])) / 2
pad_h = (input_shape[0] - int(im_shape[0])) / 2
coords = paddle.cast(coords, 'float32')
coords[:, 0::2] -= pad_w
coords[:, 1::2] -= pad_h
coords[:, 0:4] /= ratio
coords[:, :4] = paddle.clip(coords[:, :4], min=0, max=coords[:, :4].max())
return coords.round()
def clip_box(xyxy, input_shape, im_shape, scale_factor):
im_shape = im_shape.numpy()[0]
ratio = scale_factor.numpy()[0][0]
img0_shape = [int(im_shape[0] / ratio), int(im_shape[1] / ratio)]
xyxy[:, 0::2] = paddle.clip(xyxy[:, 0::2], min=0, max=img0_shape[1])
xyxy[:, 1::2] = paddle.clip(xyxy[:, 1::2], min=0, max=img0_shape[0])
return xyxy
def get_crops(xyxy, ori_img, pred_scores, w, h):
crops = []
keep_scores = []
xyxy = xyxy.numpy().astype(np.int64)
ori_img = ori_img.numpy()
ori_img = np.squeeze(ori_img, axis=0).transpose(1, 0, 2)
pred_scores = pred_scores.numpy()
for i, bbox in enumerate(xyxy):
if bbox[2] <= bbox[0] or bbox[3] <= bbox[1]:
continue
crop = ori_img[bbox[0]:bbox[2], bbox[1]:bbox[3], :]
crops.append(crop)
keep_scores.append(pred_scores[i])
if len(crops) == 0:
return [], []
crops = preprocess_reid(crops, w, h)
return crops, keep_scores
def preprocess_reid(imgs, w=64, h=192, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
im_batch = []
for img in imgs:
img = cv2.resize(img, (w, h))
img = img[:, :, ::-1].astype('float32').transpose((2, 0, 1)) / 255
img_mean = np.array(mean).reshape((3, 1, 1))
img_std = np.array(std).reshape((3, 1, 1))
img -= img_mean
img /= img_std
img = np.expand_dims(img, axis=0)
im_batch.append(img)
im_batch = np.concatenate(im_batch, 0)
return im_batch
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import numpy as np
def tlwhs_to_tlbrs(tlwhs):
tlbrs = np.copy(tlwhs)
if len(tlbrs) == 0:
return tlbrs
tlbrs[:, 2] += tlwhs[:, 0]
tlbrs[:, 3] += tlwhs[:, 1]
return tlbrs
def get_color(idx):
idx = idx * 3
color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)
return color
def resize_image(image, max_size=800):
if max(image.shape[:2]) > max_size:
scale = float(max_size) / max(image.shape[:2])
image = cv2.resize(image, None, fx=scale, fy=scale)
return image
def plot_tracking(image, tlwhs, obj_ids, scores=None, frame_id=0, fps=0., ids2=None):
im = np.ascontiguousarray(np.copy(image))
im_h, im_w = im.shape[:2]
top_view = np.zeros([im_w, im_w, 3], dtype=np.uint8) + 255
text_scale = max(1, image.shape[1] / 1600.)
text_thickness = 2
line_thickness = max(1, int(image.shape[1] / 500.))
radius = max(5, int(im_w / 140.))
cv2.putText(
im,
'frame: %d fps: %.2f num: %d' % (frame_id, fps, len(tlwhs)), (0, int(15 * text_scale)),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 0, 255),
thickness=2)
for i, tlwh in enumerate(tlwhs):
x1, y1, w, h = tlwh
intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
obj_id = int(obj_ids[i])
id_text = '{}'.format(int(obj_id))
if ids2 is not None:
id_text = id_text + ', {}'.format(int(ids2[i]))
_line_thickness = 1 if obj_id <= 0 else line_thickness
color = get_color(abs(obj_id))
cv2.rectangle(im, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness)
cv2.putText(
im,
id_text, (intbox[0], intbox[1] + 10),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 0, 255),
thickness=text_thickness)
if scores is not None:
text = '{:.2f}'.format(float(scores[i]))
cv2.putText(
im,
text, (intbox[0], intbox[1] - 10),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 255, 255),
thickness=text_thickness)
return im
def plot_trajectory(image, tlwhs, track_ids):
image = image.copy()
for one_tlwhs, track_id in zip(tlwhs, track_ids):
color = get_color(int(track_id))
for tlwh in one_tlwhs:
x1, y1, w, h = tuple(map(int, tlwh))
cv2.circle(image, (int(x1 + 0.5 * w), int(y1 + h)), 2, color, thickness=2)
return image
def plot_detections(image, tlbrs, scores=None, color=(255, 0, 0), ids=None):
im = np.copy(image)
text_scale = max(1, image.shape[1] / 800.)
thickness = 2 if text_scale > 1.3 else 1
for i, det in enumerate(tlbrs):
x1, y1, x2, y2 = np.asarray(det[:4], dtype=np.int)
if len(det) >= 7:
label = 'det' if det[5] > 0 else 'trk'
if ids is not None:
text = '{}# {:.2f}: {:d}'.format(label, det[6], ids[i])
cv2.putText(
im, text, (x1, y1 + 30), cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 255, 255), thickness=thickness)
else:
text = '{}# {:.2f}'.format(label, det[6])
if scores is not None:
text = '{:.2f}'.format(scores[i])
cv2.putText(im, text, (x1, y1 + 30), cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 255, 255), thickness=thickness)
cv2.rectangle(im, (x1, y1), (x2, y2), color, 2)
return im
......@@ -16,18 +16,19 @@ import cv2
import glob
import paddle
import numpy as np
import collections
from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from ppdet.modeling.mot.utils import Detection, get_crops, scale_coords, clip_box
from ppdet.modeling.mot.utils import Timer, load_det_results
from ppdet.modeling.mot import visualization as mot_vis
from ppdet.metrics import Metric, MOTMetric, KITTIMOTMetric
import ppdet.utils.stats as stats
from ppdet.engine.callbacks import Callback, ComposeCallback
from ppdet.utils.logger import setup_logger
from .dataset import MOTVideoStream, MOTImageStream
from .utils import Timer
from .modeling.mot.utils import Detection, get_crops, scale_coords, clip_box
from .modeling.mot import visualization as mot_vis
logger = setup_logger(__name__)
......@@ -71,7 +72,6 @@ class StreamTracker(object):
timer.tic()
pred_dets, pred_embs = self.model(data)
online_targets = self.model.tracker.update(pred_dets, pred_embs)
online_tlwhs, online_ids = [], []
online_scores = []
for t in online_targets:
......@@ -109,7 +109,6 @@ class StreamTracker(object):
timer.tic()
pred_dets, pred_embs = self.model(data)
online_targets = self.model.tracker.update(pred_dets, pred_embs)
online_tlwhs, online_ids = [], []
online_scores = []
for t in online_targets:
......
import time
class Timer(object):
"""
This class used to compute and print the current FPS while evaling.
"""
def __init__(self):
self.total_time = 0.
self.calls = 0
self.start_time = 0.
self.diff = 0.
self.average_time = 0.
self.duration = 0.
def tic(self):
# using time.time instead of time.clock because time time.clock
# does not normalize for multithreading
self.start_time = time.time()
def toc(self, average=True):
self.diff = time.time() - self.start_time
self.total_time += self.diff
self.calls += 1
self.average_time = self.total_time / self.calls
if average:
self.duration = self.average_time
else:
self.duration = self.diff
return self.duration
def clear(self):
self.total_time = 0.
self.calls = 0
self.start_time = 0.
self.diff = 0.
self.average_time = 0.
self.duration = 0.
......@@ -5,7 +5,7 @@ find_unused_parameters: True
JDE:
detector: YOLOv3
reid: JDEEmbeddingHead
tracker: JDETracker
tracker: FrozenJDETracker
YOLOv3:
backbone: DarkNet
......
......@@ -9,7 +9,7 @@ _BASE_: [
JDE:
detector: YOLOv3
reid: JDEEmbeddingHead
tracker: JDETracker
tracker: FrozenJDETracker
YOLOv3:
backbone: DarkNet
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import matching
from . import tracker
from . import motion
from . import visualization
from . import utils
from .matching import *
from .tracker import *
from .motion import *
from .visualization import *
from .utils import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import jde_matching
from . import deepsort_matching
from .jde_matching import *
from .deepsort_matching import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/nwojke/deep_sort/tree/master/deep_sort
"""
import numpy as np
from scipy.optimize import linear_sum_assignment
from ..motion import kalman_filter
INFTY_COST = 1e+5
__all__ = [
'iou_1toN',
'iou_cost',
'_nn_euclidean_distance',
'_nn_cosine_distance',
'NearestNeighborDistanceMetric',
'min_cost_matching',
'matching_cascade',
'gate_cost_matrix',
]
def iou_1toN(bbox, candidates):
"""
Computer intersection over union (IoU) by one box to N candidates.
Args:
bbox (ndarray): A bounding box in format `(top left x, top left y, width, height)`.
candidates (ndarray): A matrix of candidate bounding boxes (one per row) in the
same format as `bbox`.
Returns:
ious (ndarray): The intersection over union in [0, 1] between the `bbox`
and each candidate. A higher score means a larger fraction of the
`bbox` is occluded by the candidate.
"""
bbox_tl = bbox[:2]
bbox_br = bbox[:2] + bbox[2:]
candidates_tl = candidates[:, :2]
candidates_br = candidates[:, :2] + candidates[:, 2:]
tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis],
np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]]
br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis],
np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]]
wh = np.maximum(0., br - tl)
area_intersection = wh.prod(axis=1)
area_bbox = bbox[2:].prod()
area_candidates = candidates[:, 2:].prod(axis=1)
ious = area_intersection / (area_bbox + area_candidates - area_intersection)
return ious
def iou_cost(tracks, detections, track_indices=None, detection_indices=None):
"""
IoU distance metric.
Args:
tracks (list[Track]): A list of tracks.
detections (list[Detection]): A list of detections.
track_indices (Optional[list[int]]): A list of indices to tracks that
should be matched. Defaults to all `tracks`.
detection_indices (Optional[list[int]]): A list of indices to detections
that should be matched. Defaults to all `detections`.
Returns:
cost_matrix (ndarray): A cost matrix of shape len(track_indices),
len(detection_indices) where entry (i, j) is
`1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`.
"""
if track_indices is None:
track_indices = np.arange(len(tracks))
if detection_indices is None:
detection_indices = np.arange(len(detections))
cost_matrix = np.zeros((len(track_indices), len(detection_indices)))
for row, track_idx in enumerate(track_indices):
if tracks[track_idx].time_since_update > 1:
cost_matrix[row, :] = 1e+5
continue
bbox = tracks[track_idx].to_tlwh()
candidates = np.asarray([detections[i].tlwh for i in detection_indices])
cost_matrix[row, :] = 1. - iou_1toN(bbox, candidates)
return cost_matrix
def _nn_euclidean_distance(s, q):
"""
Compute pair-wise squared (Euclidean) distance between points in `s` and `q`.
Args:
s (ndarray): Sample points: an NxM matrix of N samples of dimensionality M.
q (ndarray): Query points: an LxM matrix of L samples of dimensionality M.
Returns:
distances (ndarray): A vector of length M that contains for each entry in `q` the
smallest Euclidean distance to a sample in `s`.
"""
s, q = np.asarray(s), np.asarray(q)
if len(s) == 0 or len(q) == 0:
return np.zeros((len(s), len(q)))
s2, q2 = np.square(s).sum(axis=1), np.square(q).sum(axis=1)
distances = -2. * np.dot(s, q.T) + s2[:, None] + q2[None, :]
distances = np.clip(distances, 0., float(np.inf))
return np.maximum(0.0, distances.min(axis=0))
def _nn_cosine_distance(s, q):
"""
Compute pair-wise cosine distance between points in `s` and `q`.
Args:
s (ndarray): Sample points: an NxM matrix of N samples of dimensionality M.
q (ndarray): Query points: an LxM matrix of L samples of dimensionality M.
Returns:
distances (ndarray): A vector of length M that contains for each entry in `q` the
smallest Euclidean distance to a sample in `s`.
"""
s = np.asarray(s) / np.linalg.norm(s, axis=1, keepdims=True)
q = np.asarray(q) / np.linalg.norm(q, axis=1, keepdims=True)
distances = 1. - np.dot(s, q.T)
return distances.min(axis=0)
class NearestNeighborDistanceMetric(object):
"""
A nearest neighbor distance metric that, for each target, returns
the closest distance to any sample that has been observed so far.
Args:
metric (str): Either "euclidean" or "cosine".
matching_threshold (float): The matching threshold. Samples with larger
distance are considered an invalid match.
budget (Optional[int]): If not None, fix samples per class to at most
this number. Removes the oldest samples when the budget is reached.
Attributes:
samples (Dict[int -> List[ndarray]]): A dictionary that maps from target
identities to the list of samples that have been observed so far.
"""
def __init__(self, metric, matching_threshold, budget=None):
if metric == "euclidean":
self._metric = _nn_euclidean_distance
elif metric == "cosine":
self._metric = _nn_cosine_distance
else:
raise ValueError("Invalid metric; must be either 'euclidean' or 'cosine'")
self.matching_threshold = matching_threshold
self.budget = budget
self.samples = {}
def partial_fit(self, features, targets, active_targets):
"""
Update the distance metric with new data.
Args:
features (ndarray): An NxM matrix of N features of dimensionality M.
targets (ndarray): An integer array of associated target identities.
active_targets (List[int]): A list of targets that are currently
present in the scene.
"""
for feature, target in zip(features, targets):
self.samples.setdefault(target, []).append(feature)
if self.budget is not None:
self.samples[target] = self.samples[target][-self.budget:]
self.samples = {k: self.samples[k] for k in active_targets}
def distance(self, features, targets):
"""
Compute distance between features and targets.
Args:
features (ndarray): An NxM matrix of N features of dimensionality M.
targets (list[int]): A list of targets to match the given `features` against.
Returns:
cost_matrix (ndarray): a cost matrix of shape len(targets), len(features),
where element (i, j) contains the closest squared distance between
`targets[i]` and `features[j]`.
"""
cost_matrix = np.zeros((len(targets), len(features)))
for i, target in enumerate(targets):
cost_matrix[i, :] = self._metric(self.samples[target], features)
return cost_matrix
def min_cost_matching(distance_metric, max_distance, tracks, detections, track_indices=None, detection_indices=None):
"""
Solve linear assignment problem.
Args:
distance_metric :
Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
The distance metric is given a list of tracks and detections as
well as a list of N track indices and M detection indices. The
metric should return the NxM dimensional cost matrix, where element
(i, j) is the association cost between the i-th track in the given
track indices and the j-th detection in the given detection_indices.
max_distance (float): Gating threshold. Associations with cost larger
than this value are disregarded.
tracks (list[Track]): A list of predicted tracks at the current time
step.
detections (list[Detection]): A list of detections at the current time
step.
track_indices (list[int]): List of track indices that maps rows in
`cost_matrix` to tracks in `tracks`.
detection_indices (List[int]): List of detection indices that maps
columns in `cost_matrix` to detections in `detections`.
Returns:
A tuple (List[(int, int)], List[int], List[int]) with the following
three entries:
* A list of matched track and detection indices.
* A list of unmatched track indices.
* A list of unmatched detection indices.
"""
if track_indices is None:
track_indices = np.arange(len(tracks))
if detection_indices is None:
detection_indices = np.arange(len(detections))
if len(detection_indices) == 0 or len(track_indices) == 0:
return [], track_indices, detection_indices # Nothing to match.
cost_matrix = distance_metric(tracks, detections, track_indices, detection_indices)
cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5
indices = linear_sum_assignment(cost_matrix)
matches, unmatched_tracks, unmatched_detections = [], [], []
for col, detection_idx in enumerate(detection_indices):
if col not in indices[1]:
unmatched_detections.append(detection_idx)
for row, track_idx in enumerate(track_indices):
if row not in indices[0]:
unmatched_tracks.append(track_idx)
for row, col in zip(indices[0], indices[1]):
track_idx = track_indices[row]
detection_idx = detection_indices[col]
if cost_matrix[row, col] > max_distance:
unmatched_tracks.append(track_idx)
unmatched_detections.append(detection_idx)
else:
matches.append((track_idx, detection_idx))
return matches, unmatched_tracks, unmatched_detections
def matching_cascade(distance_metric,
max_distance,
cascade_depth,
tracks,
detections,
track_indices=None,
detection_indices=None):
"""
Run matching cascade.
Args:
distance_metric :
Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
The distance metric is given a list of tracks and detections as
well as a list of N track indices and M detection indices. The
metric should return the NxM dimensional cost matrix, where element
(i, j) is the association cost between the i-th track in the given
track indices and the j-th detection in the given detection_indices.
max_distance (float): Gating threshold. Associations with cost larger
than this value are disregarded.
cascade_depth (int): The cascade depth, should be se to the maximum
track age.
tracks (list[Track]): A list of predicted tracks at the current time
step.
detections (list[Detection]): A list of detections at the current time
step.
track_indices (list[int]): List of track indices that maps rows in
`cost_matrix` to tracks in `tracks`.
detection_indices (List[int]): List of detection indices that maps
columns in `cost_matrix` to detections in `detections`.
Returns:
A tuple (List[(int, int)], List[int], List[int]) with the following
three entries:
* A list of matched track and detection indices.
* A list of unmatched track indices.
* A list of unmatched detection indices.
"""
if track_indices is None:
track_indices = list(range(len(tracks)))
if detection_indices is None:
detection_indices = list(range(len(detections)))
unmatched_detections = detection_indices
matches = []
for level in range(cascade_depth):
if len(unmatched_detections) == 0: # No detections left
break
track_indices_l = [k for k in track_indices if tracks[k].time_since_update == 1 + level]
if len(track_indices_l) == 0: # Nothing to match at this level
continue
matches_l, _, unmatched_detections = \
min_cost_matching(
distance_metric, max_distance, tracks, detections,
track_indices_l, unmatched_detections)
matches += matches_l
unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches))
return matches, unmatched_tracks, unmatched_detections
def gate_cost_matrix(kf,
cost_matrix,
tracks,
detections,
track_indices,
detection_indices,
gated_cost=INFTY_COST,
only_position=False):
"""
Invalidate infeasible entries in cost matrix based on the state
distributions obtained by Kalman filtering.
Args:
kf (object): The Kalman filter.
cost_matrix (ndarray): The NxM dimensional cost matrix, where N is the
number of track indices and M is the number of detection indices,
such that entry (i, j) is the association cost between
`tracks[track_indices[i]]` and `detections[detection_indices[j]]`.
tracks (list[Track]): A list of predicted tracks at the current time
step.
detections (list[Detection]): A list of detections at the current time
step.
track_indices (List[int]): List of track indices that maps rows in
`cost_matrix` to tracks in `tracks`.
detection_indices (List[int]): List of detection indices that maps
columns in `cost_matrix` to detections in `detections`.
gated_cost (Optional[float]): Entries in the cost matrix corresponding
to infeasible associations are set this value. Defaults to a very
large value.
only_position (Optional[bool]): If True, only the x, y position of the
state distribution is considered during gating. Default False.
"""
gating_dim = 2 if only_position else 4
gating_threshold = kalman_filter.chi2inv95[gating_dim]
measurements = np.asarray([detections[i].to_xyah() for i in detection_indices])
for row, track_idx in enumerate(track_indices):
track = tracks[track_idx]
gating_distance = kf.gating_distance(track.mean, track.covariance, measurements, only_position)
cost_matrix[row, gating_distance > gating_threshold] = gated_cost
return cost_matrix
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/matching.py
"""
import lap
import scipy
import numpy as np
from scipy.spatial.distance import cdist
from ..motion import kalman_filter
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = [
'merge_matches',
'linear_assignment',
'cython_bbox_ious',
'iou_distance',
'embedding_distance',
'fuse_motion',
]
def merge_matches(m1, m2, shape):
O, P, Q = shape
m1 = np.asarray(m1)
m2 = np.asarray(m2)
M1 = scipy.sparse.coo_matrix((np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P))
M2 = scipy.sparse.coo_matrix((np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q))
mask = M1 * M2
match = mask.nonzero()
match = list(zip(match[0], match[1]))
unmatched_O = tuple(set(range(O)) - set([i for i, j in match]))
unmatched_Q = tuple(set(range(Q)) - set([j for i, j in match]))
return match, unmatched_O, unmatched_Q
def linear_assignment(cost_matrix, thresh):
if cost_matrix.size == 0:
return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
matches, unmatched_a, unmatched_b = [], [], []
cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
for ix, mx in enumerate(x):
if mx >= 0:
matches.append([ix, mx])
unmatched_a = np.where(x < 0)[0]
unmatched_b = np.where(y < 0)[0]
matches = np.asarray(matches)
return matches, unmatched_a, unmatched_b
def cython_bbox_ious(atlbrs, btlbrs):
ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float)
if ious.size == 0:
return ious
try:
import cython_bbox
except Exception as e:
logger.error('cython_bbox not found, please install cython_bbox.' 'for example: `pip install cython_bbox`.')
raise e
ious = cython_bbox.bbox_overlaps(
np.ascontiguousarray(atlbrs, dtype=np.float), np.ascontiguousarray(btlbrs, dtype=np.float))
return ious
def iou_distance(atracks, btracks):
"""
Compute cost based on IoU between two list[STrack].
"""
if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0
and isinstance(btracks[0], np.ndarray)):
atlbrs = atracks
btlbrs = btracks
else:
atlbrs = [track.tlbr for track in atracks]
btlbrs = [track.tlbr for track in btracks]
_ious = cython_bbox_ious(atlbrs, btlbrs)
cost_matrix = 1 - _ious
return cost_matrix
def embedding_distance(tracks, detections, metric='euclidean'):
"""
Compute cost based on features between two list[STrack].
"""
cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float)
if cost_matrix.size == 0:
return cost_matrix
det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float)
track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float)
cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric)) # Nomalized features
return cost_matrix
def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98):
if cost_matrix.size == 0:
return cost_matrix
gating_dim = 2 if only_position else 4
gating_threshold = kalman_filter.chi2inv95[gating_dim]
measurements = np.asarray([det.to_xyah() for det in detections])
for row, track in enumerate(tracks):
gating_distance = kf.gating_distance(track.mean, track.covariance, measurements, only_position, metric='maha')
cost_matrix[row, gating_distance > gating_threshold] = np.inf
cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance
return cost_matrix
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import kalman_filter
from .kalman_filter import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/nwojke/deep_sort/blob/master/deep_sort/kalman_filter.py
"""
import numpy as np
import scipy.linalg
__all__ = ['KalmanFilter']
"""
Table for the 0.95 quantile of the chi-square distribution with N degrees of
freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
function and used as Mahalanobis gating threshold.
"""
chi2inv95 = {1: 3.8415, 2: 5.9915, 3: 7.8147, 4: 9.4877, 5: 11.070, 6: 12.592, 7: 14.067, 8: 15.507, 9: 16.919}
class KalmanFilter(object):
"""
A simple Kalman filter for tracking bounding boxes in image space.
The 8-dimensional state space
x, y, a, h, vx, vy, va, vh
contains the bounding box center position (x, y), aspect ratio a, height h,
and their respective velocities.
Object motion follows a constant velocity model. The bounding box location
(x, y, a, h) is taken as direct observation of the state space (linear
observation model).
"""
def __init__(self):
ndim, dt = 4, 1.
# Create Kalman filter model matrices.
self._motion_mat = np.eye(2 * ndim, 2 * ndim)
for i in range(ndim):
self._motion_mat[i, ndim + i] = dt
self._update_mat = np.eye(ndim, 2 * ndim)
# Motion and observation uncertainty are chosen relative to the current
# state estimate. These weights control the amount of uncertainty in
# the model. This is a bit hacky.
self._std_weight_position = 1. / 20
self._std_weight_velocity = 1. / 160
def initiate(self, measurement):
"""
Create track from unassociated measurement.
Args:
measurement (ndarray): Bounding box coordinates (x, y, a, h) with
center position (x, y), aspect ratio a, and height h.
Returns:
The mean vector (8 dimensional) and covariance matrix (8x8
dimensional) of the new track. Unobserved velocities are
initialized to 0 mean.
"""
mean_pos = measurement
mean_vel = np.zeros_like(mean_pos)
mean = np.r_[mean_pos, mean_vel]
std = [
2 * self._std_weight_position * measurement[3], 2 * self._std_weight_position * measurement[3], 1e-2,
2 * self._std_weight_position * measurement[3], 10 * self._std_weight_velocity * measurement[3],
10 * self._std_weight_velocity * measurement[3], 1e-5, 10 * self._std_weight_velocity * measurement[3]
]
covariance = np.diag(np.square(std))
return mean, covariance
def predict(self, mean, covariance):
"""
Run Kalman filter prediction step.
Args:
mean (ndarray): The 8 dimensional mean vector of the object state
at the previous time step.
covariance (ndarray): The 8x8 dimensional covariance matrix of the
object state at the previous time step.
Returns:
The mean vector and covariance matrix of the predicted state.
Unobserved velocities are initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-2,
self._std_weight_position * mean[3]
]
std_vel = [
self._std_weight_velocity * mean[3], self._std_weight_velocity * mean[3], 1e-5,
self._std_weight_velocity * mean[3]
]
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
#mean = np.dot(self._motion_mat, mean)
mean = np.dot(mean, self._motion_mat.T)
covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
return mean, covariance
def project(self, mean, covariance):
"""
Project state distribution to measurement space.
Args
mean (ndarray): The state's mean vector (8 dimensional array).
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
Returns:
The projected mean and covariance matrix of the given state estimate.
"""
std = [
self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-1,
self._std_weight_position * mean[3]
]
innovation_cov = np.diag(np.square(std))
mean = np.dot(self._update_mat, mean)
covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
return mean, covariance + innovation_cov
def multi_predict(self, mean, covariance):
"""
Run Kalman filter prediction step (Vectorized version).
Args:
mean (ndarray): The Nx8 dimensional mean matrix of the object states
at the previous time step.
covariance (ndarray): The Nx8x8 dimensional covariance matrics of the
object states at the previous time step.
Returns:
The mean vector and covariance matrix of the predicted state.
Unobserved velocities are initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[:, 3], self._std_weight_position * mean[:, 3],
1e-2 * np.ones_like(mean[:, 3]), self._std_weight_position * mean[:, 3]
]
std_vel = [
self._std_weight_velocity * mean[:, 3], self._std_weight_velocity * mean[:, 3],
1e-5 * np.ones_like(mean[:, 3]), self._std_weight_velocity * mean[:, 3]
]
sqr = np.square(np.r_[std_pos, std_vel]).T
motion_cov = []
for i in range(len(mean)):
motion_cov.append(np.diag(sqr[i]))
motion_cov = np.asarray(motion_cov)
mean = np.dot(mean, self._motion_mat.T)
left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
covariance = np.dot(left, self._motion_mat.T) + motion_cov
return mean, covariance
def update(self, mean, covariance, measurement):
"""
Run Kalman filter correction step.
Args:
mean (ndarray): The predicted state's mean vector (8 dimensional).
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
measurement (ndarray): The 4 dimensional measurement vector
(x, y, a, h), where (x, y) is the center position, a the aspect
ratio, and h the height of the bounding box.
Returns:
The measurement-corrected state distribution.
"""
projected_mean, projected_cov = self.project(mean, covariance)
chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False)
kalman_gain = scipy.linalg.cho_solve((chol_factor, lower),
np.dot(covariance, self._update_mat.T).T,
check_finite=False).T
innovation = measurement - projected_mean
new_mean = mean + np.dot(innovation, kalman_gain.T)
new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T))
return new_mean, new_covariance
def gating_distance(self, mean, covariance, measurements, only_position=False, metric='maha'):
"""
Compute gating distance between state distribution and measurements.
A suitable distance threshold can be obtained from `chi2inv95`. If
`only_position` is False, the chi-square distribution has 4 degrees of
freedom, otherwise 2.
Args:
mean (ndarray): Mean vector over the state distribution (8
dimensional).
covariance (ndarray): Covariance of the state distribution (8x8
dimensional).
measurements (ndarray): An Nx4 dimensional matrix of N measurements,
each in format (x, y, a, h) where (x, y) is the bounding box center
position, a the aspect ratio, and h the height.
only_position (Optional[bool]): If True, distance computation is
done with respect to the bounding box center position only.
metric (str): Metric type, 'gaussian' or 'maha'.
Returns
An array of length N, where the i-th element contains the squared
Mahalanobis distance between (mean, covariance) and `measurements[i]`.
"""
mean, covariance = self.project(mean, covariance)
if only_position:
mean, covariance = mean[:2], covariance[:2, :2]
measurements = measurements[:, :2]
d = measurements - mean
if metric == 'gaussian':
return np.sum(d * d, axis=1)
elif metric == 'maha':
cholesky_factor = np.linalg.cholesky(covariance)
z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True)
squared_maha = np.sum(z * z, axis=0)
return squared_maha
else:
raise ValueError('invalid distance metric')
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import base_jde_tracker
from . import base_sde_tracker
from . import jde_tracker
from .base_jde_tracker import *
from .base_sde_tracker import *
from .jde_tracker import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py
"""
import numpy as np
from collections import deque, OrderedDict
from ..matching import jde_matching as matching
from ppdet.core.workspace import register, serializable
__all__ = [
'TrackState',
'BaseTrack',
'STrack',
'joint_stracks',
'sub_stracks',
'remove_duplicate_stracks',
]
class TrackState(object):
New = 0
Tracked = 1
Lost = 2
Removed = 3
class BaseTrack(object):
_count = 0
track_id = 0
is_activated = False
state = TrackState.New
history = OrderedDict()
features = []
curr_feature = None
score = 0
start_frame = 0
frame_id = 0
time_since_update = 0
# multi-camera
location = (np.inf, np.inf)
@property
def end_frame(self):
return self.frame_id
@staticmethod
def next_id():
BaseTrack._count += 1
return BaseTrack._count
def activate(self, *args):
raise NotImplementedError
def predict(self):
raise NotImplementedError
def update(self, *args, **kwargs):
raise NotImplementedError
def mark_lost(self):
self.state = TrackState.Lost
def mark_removed(self):
self.state = TrackState.Removed
class STrack(BaseTrack):
def __init__(self, tlwh, score, temp_feat, buffer_size=30):
# wait activate
self._tlwh = np.asarray(tlwh, dtype=np.float)
self.kalman_filter = None
self.mean, self.covariance = None, None
self.is_activated = False
self.score = score
self.tracklet_len = 0
self.smooth_feat = None
self.update_features(temp_feat)
self.features = deque([], maxlen=buffer_size)
self.alpha = 0.9
def update_features(self, feat):
feat /= np.linalg.norm(feat)
self.curr_feat = feat
if self.smooth_feat is None:
self.smooth_feat = feat
else:
self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
self.features.append(feat)
self.smooth_feat /= np.linalg.norm(self.smooth_feat)
def predict(self):
mean_state = self.mean.copy()
if self.state != TrackState.Tracked:
mean_state[7] = 0
self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
@staticmethod
def multi_predict(stracks, kalman_filter):
if len(stracks) > 0:
multi_mean = np.asarray([st.mean.copy() for st in stracks])
multi_covariance = np.asarray([st.covariance for st in stracks])
for i, st in enumerate(stracks):
if st.state != TrackState.Tracked:
multi_mean[i][7] = 0
multi_mean, multi_covariance = kalman_filter.multi_predict(multi_mean, multi_covariance)
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
stracks[i].mean = mean
stracks[i].covariance = cov
def activate(self, kalman_filter, frame_id):
"""Start a new tracklet"""
self.kalman_filter = kalman_filter
self.track_id = self.next_id()
self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh))
self.tracklet_len = 0
self.state = TrackState.Tracked
if frame_id == 1:
self.is_activated = True
self.frame_id = frame_id
self.start_frame = frame_id
def re_activate(self, new_track, frame_id, new_id=False):
self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance,
self.tlwh_to_xyah(new_track.tlwh))
self.update_features(new_track.curr_feat)
self.tracklet_len = 0
self.state = TrackState.Tracked
self.is_activated = True
self.frame_id = frame_id
if new_id:
self.track_id = self.next_id()
def update(self, new_track, frame_id, update_feature=True):
self.frame_id = frame_id
self.tracklet_len += 1
new_tlwh = new_track.tlwh
self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
self.state = TrackState.Tracked
self.is_activated = True
self.score = new_track.score
if update_feature:
self.update_features(new_track.curr_feat)
@property
def tlwh(self):
"""
Get current position in bounding box format `(top left x, top left y,
width, height)`.
"""
if self.mean is None:
return self._tlwh.copy()
ret = self.mean[:4].copy()
ret[2] *= ret[3]
ret[:2] -= ret[2:] / 2
return ret
@property
def tlbr(self):
"""
Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
`(top left, bottom right)`.
"""
ret = self.tlwh.copy()
ret[2:] += ret[:2]
return ret
@staticmethod
def tlwh_to_xyah(tlwh):
"""
Convert bounding box to format `(center x, center y, aspect ratio,
height)`, where the aspect ratio is `width / height`.
"""
ret = np.asarray(tlwh).copy()
ret[:2] += ret[2:] / 2
ret[2] /= ret[3]
return ret
def to_xyah(self):
return self.tlwh_to_xyah(self.tlwh)
@staticmethod
def tlbr_to_tlwh(tlbr):
ret = np.asarray(tlbr).copy()
ret[2:] -= ret[:2]
return ret
@staticmethod
def tlwh_to_tlbr(tlwh):
ret = np.asarray(tlwh).copy()
ret[2:] += ret[:2]
return ret
def __repr__(self):
return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
def joint_stracks(tlista, tlistb):
exists = {}
res = []
for t in tlista:
exists[t.track_id] = 1
res.append(t)
for t in tlistb:
tid = t.track_id
if not exists.get(tid, 0):
exists[tid] = 1
res.append(t)
return res
def sub_stracks(tlista, tlistb):
stracks = {}
for t in tlista:
stracks[t.track_id] = t
for t in tlistb:
tid = t.track_id
if stracks.get(tid, 0):
del stracks[tid]
return list(stracks.values())
def remove_duplicate_stracks(stracksa, stracksb):
pdist = matching.iou_distance(stracksa, stracksb)
pairs = np.where(pdist < 0.15)
dupa, dupb = list(), list()
for p, q in zip(*pairs):
timep = stracksa[p].frame_id - stracksa[p].start_frame
timeq = stracksb[q].frame_id - stracksb[q].start_frame
if timep > timeq:
dupb.append(q)
else:
dupa.append(p)
resa = [t for i, t in enumerate(stracksa) if not i in dupa]
resb = [t for i, t in enumerate(stracksb) if not i in dupb]
return resa, resb
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/nwojke/deep_sort/blob/master/deep_sort/track.py
"""
from ppdet.core.workspace import register, serializable
__all__ = ['TrackState', 'Track']
class TrackState(object):
"""
Enumeration type for the single target track state. Newly created tracks are
classified as `tentative` until enough evidence has been collected. Then,
the track state is changed to `confirmed`. Tracks that are no longer alive
are classified as `deleted` to mark them for removal from the set of active
tracks.
"""
Tentative = 1
Confirmed = 2
Deleted = 3
class Track(object):
"""
A single target track with state space `(x, y, a, h)` and associated
velocities, where `(x, y)` is the center of the bounding box, `a` is the
aspect ratio and `h` is the height.
Args:
mean (ndarray): Mean vector of the initial state distribution.
covariance (ndarray): Covariance matrix of the initial state distribution.
track_id (int): A unique track identifier.
n_init (int): Number of consecutive detections before the track is confirmed.
The track state is set to `Deleted` if a miss occurs within the first
`n_init` frames.
max_age (int): The maximum number of consecutive misses before the track
state is set to `Deleted`.
feature (Optional[ndarray]): Feature vector of the detection this track
originates from. If not None, this feature is added to the `features` cache.
Attributes:
hits (int): Total number of measurement updates.
age (int): Total number of frames since first occurance.
time_since_update (int): Total number of frames since last measurement
update.
state (TrackState): The current track state.
features (List[ndarray]): A cache of features. On each measurement update,
the associated feature vector is added to this list.
"""
def __init__(self, mean, covariance, track_id, n_init, max_age, feature=None):
self.mean = mean
self.covariance = covariance
self.track_id = track_id
self.hits = 1
self.age = 1
self.time_since_update = 0
self.state = TrackState.Tentative
self.features = []
if feature is not None:
self.features.append(feature)
self._n_init = n_init
self._max_age = max_age
def to_tlwh(self):
"""Get position in format `(top left x, top left y, width, height)`."""
ret = self.mean[:4].copy()
ret[2] *= ret[3]
ret[:2] -= ret[2:] / 2
return ret
def to_tlbr(self):
"""Get position in bounding box format `(min x, miny, max x, max y)`."""
ret = self.to_tlwh()
ret[2:] = ret[:2] + ret[2:]
return ret
def predict(self, kalman_filter):
"""
Propagate the state distribution to the current time step using a Kalman
filter prediction step.
"""
self.mean, self.covariance = kalman_filter.predict(self.mean, self.covariance)
self.age += 1
self.time_since_update += 1
def update(self, kalman_filter, detection):
"""
Perform Kalman filter measurement update step and update the associated
detection feature cache.
"""
self.mean, self.covariance = kalman_filter.update(self.mean, self.covariance, detection.to_xyah())
self.features.append(detection.feature)
self.hits += 1
self.time_since_update = 0
if self.state == TrackState.Tentative and self.hits >= self._n_init:
self.state = TrackState.Confirmed
def mark_missed(self):
"""Mark this track as missed (no association at the current time step).
"""
if self.state == TrackState.Tentative:
self.state = TrackState.Deleted
elif self.time_since_update > self._max_age:
self.state = TrackState.Deleted
def is_tentative(self):
"""Returns True if this track is tentative (unconfirmed)."""
return self.state == TrackState.Tentative
def is_confirmed(self):
"""Returns True if this track is confirmed."""
return self.state == TrackState.Confirmed
def is_deleted(self):
"""Returns True if this track is dead and should be deleted."""
return self.state == TrackState.Deleted
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is borrow from https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py
"""
import paddle
from ..matching import jde_matching as matching
from .base_jde_tracker import TrackState, BaseTrack, STrack
from .base_jde_tracker import joint_stracks, sub_stracks, remove_duplicate_stracks
from ppdet.core.workspace import register, serializable
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = ['FrozenJDETracker']
@register
@serializable
class FrozenJDETracker(object):
__inject__ = ['motion']
"""
JDE tracker
Args:
det_thresh (float): threshold of detection score
track_buffer (int): buffer for tracker
min_box_area (int): min box area to filter out low quality boxes
vertical_ratio (float): w/h, the vertical ratio of the bbox to filter
bad results, set 1.6 default for pedestrian tracking. If set -1
means no need to filter bboxes.
tracked_thresh (float): linear assignment threshold of tracked
stracks and detections
r_tracked_thresh (float): linear assignment threshold of
tracked stracks and unmatched detections
unconfirmed_thresh (float): linear assignment threshold of
unconfirmed stracks and unmatched detections
motion (object): KalmanFilter instance
conf_thres (float): confidence threshold for tracking
metric_type (str): either "euclidean" or "cosine", the distance metric
used for measurement to track association.
"""
def __init__(self,
det_thresh=0.3,
track_buffer=30,
min_box_area=200,
vertical_ratio=1.6,
tracked_thresh=0.7,
r_tracked_thresh=0.5,
unconfirmed_thresh=0.7,
motion='KalmanFilter',
conf_thres=0,
metric_type='euclidean'):
self.det_thresh = det_thresh
self.track_buffer = track_buffer
self.min_box_area = min_box_area
self.vertical_ratio = vertical_ratio
self.tracked_thresh = tracked_thresh
self.r_tracked_thresh = r_tracked_thresh
self.unconfirmed_thresh = unconfirmed_thresh
self.motion = motion
self.conf_thres = conf_thres
self.metric_type = metric_type
self.frame_id = 0
self.tracked_stracks = []
self.lost_stracks = []
self.removed_stracks = []
self.max_time_lost = 0
# max_time_lost will be calculated: int(frame_rate / 30.0 * track_buffer)
def update(self, pred_dets, pred_embs):
"""
Processes the image frame and finds bounding box(detections).
Associates the detection with corresponding tracklets and also handles
lost, removed, refound and active tracklets.
Args:
pred_dets (Tensor): Detection results of the image, shape is [N, 5].
pred_embs (Tensor): Embedding results of the image, shape is [N, 512].
Return:
output_stracks (list): The list contains information regarding the
online_tracklets for the recieved image tensor.
"""
self.frame_id += 1
activated_starcks = []
# for storing active tracks, for the current frame
refind_stracks = []
# Lost Tracks whose detections are obtained in the current frame
lost_stracks = []
# The tracks which are not obtained in the current frame but are not
# removed. (Lost for some time lesser than the threshold for removing)
removed_stracks = []
remain_inds = paddle.nonzero(pred_dets[:, 4] > self.conf_thres)
if remain_inds.shape[0] == 0:
pred_dets = paddle.zeros([0, 1])
pred_embs = paddle.zeros([0, 1])
else:
pred_dets = paddle.gather(pred_dets, remain_inds)
pred_embs = paddle.gather(pred_embs, remain_inds)
# Filter out the image with box_num = 0. pred_dets = [[0.0, 0.0, 0.0 ,0.0]]
empty_pred = True if len(pred_dets) == 1 and paddle.sum(pred_dets) == 0.0 else False
""" Step 1: Network forward, get detections & embeddings"""
if len(pred_dets) > 0 and not empty_pred:
pred_dets = pred_dets.numpy()
pred_embs = pred_embs.numpy()
detections = [
STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30) for (tlbrs, f) in zip(pred_dets, pred_embs)
]
else:
detections = []
''' Add newly detected tracklets to tracked_stracks'''
unconfirmed = []
tracked_stracks = [] # type: list[STrack]
for track in self.tracked_stracks:
if not track.is_activated:
# previous tracks which are not active in the current frame are added in unconfirmed list
unconfirmed.append(track)
else:
# Active tracks are added to the local list 'tracked_stracks'
tracked_stracks.append(track)
""" Step 2: First association, with embedding"""
# Combining currently tracked_stracks and lost_stracks
strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
# Predict the current location with KF
STrack.multi_predict(strack_pool, self.motion)
dists = matching.embedding_distance(strack_pool, detections, metric=self.metric_type)
dists = matching.fuse_motion(self.motion, dists, strack_pool, detections)
# The dists is the list of distances of the detection with the tracks in strack_pool
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.tracked_thresh)
# The matches is the array for corresponding matches of the detection with the corresponding strack_pool
for itracked, idet in matches:
# itracked is the id of the track and idet is the detection
track = strack_pool[itracked]
det = detections[idet]
if track.state == TrackState.Tracked:
# If the track is active, add the detection to the track
track.update(detections[idet], self.frame_id)
activated_starcks.append(track)
else:
# We have obtained a detection from a track which is not active,
# hence put the track in refind_stracks list
track.re_activate(det, self.frame_id, new_id=False)
refind_stracks.append(track)
# None of the steps below happen if there are no undetected tracks.
""" Step 3: Second association, with IOU"""
detections = [detections[i] for i in u_detection]
# detections is now a list of the unmatched detections
r_tracked_stracks = []
# This is container for stracks which were tracked till the previous
# frame but no detection was found for it in the current frame.
for i in u_track:
if strack_pool[i].state == TrackState.Tracked:
r_tracked_stracks.append(strack_pool[i])
dists = matching.iou_distance(r_tracked_stracks, detections)
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.r_tracked_thresh)
# matches is the list of detections which matched with corresponding
# tracks by IOU distance method.
for itracked, idet in matches:
track = r_tracked_stracks[itracked]
det = detections[idet]
if track.state == TrackState.Tracked:
track.update(det, self.frame_id)
activated_starcks.append(track)
else:
track.re_activate(det, self.frame_id, new_id=False)
refind_stracks.append(track)
# Same process done for some unmatched detections, but now considering IOU_distance as measure
for it in u_track:
track = r_tracked_stracks[it]
if not track.state == TrackState.Lost:
track.mark_lost()
lost_stracks.append(track)
# If no detections are obtained for tracks (u_track), the tracks are added to lost_tracks list and are marked lost
'''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
detections = [detections[i] for i in u_detection]
dists = matching.iou_distance(unconfirmed, detections)
matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=self.unconfirmed_thresh)
for itracked, idet in matches:
unconfirmed[itracked].update(detections[idet], self.frame_id)
activated_starcks.append(unconfirmed[itracked])
# The tracks which are yet not matched
for it in u_unconfirmed:
track = unconfirmed[it]
track.mark_removed()
removed_stracks.append(track)
# after all these confirmation steps, if a new detection is found, it is initialized for a new track
""" Step 4: Init new stracks"""
for inew in u_detection:
track = detections[inew]
if track.score < self.det_thresh:
continue
track.activate(self.motion, self.frame_id)
activated_starcks.append(track)
""" Step 5: Update state"""
# If the tracks are lost for more frames than the threshold number, the tracks are removed.
for track in self.lost_stracks:
if self.frame_id - track.end_frame > self.max_time_lost:
track.mark_removed()
removed_stracks.append(track)
# Update the self.tracked_stracks and self.lost_stracks using the updates in this step.
self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
self.lost_stracks.extend(lost_stracks)
self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
self.removed_stracks.extend(removed_stracks)
self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
# get scores of lost tracks
output_stracks = [track for track in self.tracked_stracks if track.is_activated]
logger.debug('===========Frame {}=========='.format(self.frame_id))
logger.debug('Activated: {}'.format([track.track_id for track in activated_starcks]))
logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
return output_stracks
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import time
import paddle
import numpy as np
__all__ = [
'Timer',
'Detection',
'load_det_results',
'preprocess_reid',
'get_crops',
'clip_box',
'scale_coords',
]
class Timer(object):
"""
This class used to compute and print the current FPS while evaling.
"""
def __init__(self):
self.total_time = 0.
self.calls = 0
self.start_time = 0.
self.diff = 0.
self.average_time = 0.
self.duration = 0.
def tic(self):
# using time.time instead of time.clock because time time.clock
# does not normalize for multithreading
self.start_time = time.time()
def toc(self, average=True):
self.diff = time.time() - self.start_time
self.total_time += self.diff
self.calls += 1
self.average_time = self.total_time / self.calls
if average:
self.duration = self.average_time
else:
self.duration = self.diff
return self.duration
def clear(self):
self.total_time = 0.
self.calls = 0
self.start_time = 0.
self.diff = 0.
self.average_time = 0.
self.duration = 0.
class Detection(object):
"""
This class represents a bounding box detection in a single image.
Args:
tlwh (ndarray): Bounding box in format `(top left x, top left y,
width, height)`.
confidence (ndarray): Detector confidence score.
feature (Tensor): A feature vector that describes the object
contained in this image.
"""
def __init__(self, tlwh, confidence, feature):
self.tlwh = np.asarray(tlwh, dtype=np.float32)
self.confidence = np.asarray(confidence, dtype=np.float32)
self.feature = feature
def to_tlbr(self):
"""
Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
`(top left, bottom right)`.
"""
ret = self.tlwh.copy()
ret[2:] += ret[:2]
return ret
def to_xyah(self):
"""
Convert bounding box to format `(center x, center y, aspect ratio,
height)`, where the aspect ratio is `width / height`.
"""
ret = self.tlwh.copy()
ret[:2] += ret[2:] / 2
ret[2] /= ret[3]
return ret
def load_det_results(det_file, num_frames):
assert os.path.exists(det_file) and os.path.isfile(det_file), \
'Error: det_file: {} not exist or not a file.'.format(det_file)
labels = np.loadtxt(det_file, dtype='float32', delimiter=',')
results_list = []
for frame_i in range(0, num_frames):
results = {'bbox': [], 'score': []}
lables_with_frame = labels[labels[:, 0] == frame_i + 1]
for l in lables_with_frame:
results['bbox'].append(l[1:5])
results['score'].append(l[5])
results_list.append(results)
return results_list
def scale_coords(coords, input_shape, im_shape, scale_factor):
im_shape = im_shape.numpy()[0]
ratio = scale_factor[0][0]
pad_w = (input_shape[1] - int(im_shape[1])) / 2
pad_h = (input_shape[0] - int(im_shape[0])) / 2
coords = paddle.cast(coords, 'float32')
coords[:, 0::2] -= pad_w
coords[:, 1::2] -= pad_h
coords[:, 0:4] /= ratio
coords[:, :4] = paddle.clip(coords[:, :4], min=0, max=coords[:, :4].max())
return coords.round()
def clip_box(xyxy, input_shape, im_shape, scale_factor):
im_shape = im_shape.numpy()[0]
ratio = scale_factor.numpy()[0][0]
img0_shape = [int(im_shape[0] / ratio), int(im_shape[1] / ratio)]
xyxy[:, 0::2] = paddle.clip(xyxy[:, 0::2], min=0, max=img0_shape[1])
xyxy[:, 1::2] = paddle.clip(xyxy[:, 1::2], min=0, max=img0_shape[0])
return xyxy
def get_crops(xyxy, ori_img, pred_scores, w, h):
crops = []
keep_scores = []
xyxy = xyxy.numpy().astype(np.int64)
ori_img = ori_img.numpy()
ori_img = np.squeeze(ori_img, axis=0).transpose(1, 0, 2)
pred_scores = pred_scores.numpy()
for i, bbox in enumerate(xyxy):
if bbox[2] <= bbox[0] or bbox[3] <= bbox[1]:
continue
crop = ori_img[bbox[0]:bbox[2], bbox[1]:bbox[3], :]
crops.append(crop)
keep_scores.append(pred_scores[i])
if len(crops) == 0:
return [], []
crops = preprocess_reid(crops, w, h)
return crops, keep_scores
def preprocess_reid(imgs, w=64, h=192, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
im_batch = []
for img in imgs:
img = cv2.resize(img, (w, h))
img = img[:, :, ::-1].astype('float32').transpose((2, 0, 1)) / 255
img_mean = np.array(mean).reshape((3, 1, 1))
img_std = np.array(std).reshape((3, 1, 1))
img -= img_mean
img /= img_std
img = np.expand_dims(img, axis=0)
im_batch.append(img)
im_batch = np.concatenate(im_batch, 0)
return im_batch
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import numpy as np
def tlwhs_to_tlbrs(tlwhs):
tlbrs = np.copy(tlwhs)
if len(tlbrs) == 0:
return tlbrs
tlbrs[:, 2] += tlwhs[:, 0]
tlbrs[:, 3] += tlwhs[:, 1]
return tlbrs
def get_color(idx):
idx = idx * 3
color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)
return color
def resize_image(image, max_size=800):
if max(image.shape[:2]) > max_size:
scale = float(max_size) / max(image.shape[:2])
image = cv2.resize(image, None, fx=scale, fy=scale)
return image
def plot_tracking(image, tlwhs, obj_ids, scores=None, frame_id=0, fps=0., ids2=None):
im = np.ascontiguousarray(np.copy(image))
im_h, im_w = im.shape[:2]
top_view = np.zeros([im_w, im_w, 3], dtype=np.uint8) + 255
text_scale = max(1, image.shape[1] / 1600.)
text_thickness = 2
line_thickness = max(1, int(image.shape[1] / 500.))
radius = max(5, int(im_w / 140.))
cv2.putText(
im,
'frame: %d fps: %.2f num: %d' % (frame_id, fps, len(tlwhs)), (0, int(15 * text_scale)),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 0, 255),
thickness=2)
for i, tlwh in enumerate(tlwhs):
x1, y1, w, h = tlwh
intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
obj_id = int(obj_ids[i])
id_text = '{}'.format(int(obj_id))
if ids2 is not None:
id_text = id_text + ', {}'.format(int(ids2[i]))
_line_thickness = 1 if obj_id <= 0 else line_thickness
color = get_color(abs(obj_id))
cv2.rectangle(im, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness)
cv2.putText(
im,
id_text, (intbox[0], intbox[1] + 10),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 0, 255),
thickness=text_thickness)
if scores is not None:
text = '{:.2f}'.format(float(scores[i]))
cv2.putText(
im,
text, (intbox[0], intbox[1] - 10),
cv2.FONT_HERSHEY_PLAIN,
text_scale, (0, 255, 255),
thickness=text_thickness)
return im
def plot_trajectory(image, tlwhs, track_ids):
image = image.copy()
for one_tlwhs, track_id in zip(tlwhs, track_ids):
color = get_color(int(track_id))
for tlwh in one_tlwhs:
x1, y1, w, h = tuple(map(int, tlwh))
cv2.circle(image, (int(x1 + 0.5 * w), int(y1 + h)), 2, color, thickness=2)
return image
def plot_detections(image, tlbrs, scores=None, color=(255, 0, 0), ids=None):
im = np.copy(image)
text_scale = max(1, image.shape[1] / 800.)
thickness = 2 if text_scale > 1.3 else 1
for i, det in enumerate(tlbrs):
x1, y1, x2, y2 = np.asarray(det[:4], dtype=np.int)
if len(det) >= 7:
label = 'det' if det[5] > 0 else 'trk'
if ids is not None:
text = '{}# {:.2f}: {:d}'.format(label, det[6], ids[i])
cv2.putText(
im, text, (x1, y1 + 30), cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 255, 255), thickness=thickness)
else:
text = '{}# {:.2f}'.format(label, det[6])
if scores is not None:
text = '{:.2f}'.format(scores[i])
cv2.putText(im, text, (x1, y1 + 30), cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 255, 255), thickness=thickness)
cv2.rectangle(im, (x1, y1), (x2, y2), color, 2)
return im
......@@ -16,18 +16,19 @@ import cv2
import glob
import paddle
import numpy as np
import collections
from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from ppdet.modeling.mot.utils import Detection, get_crops, scale_coords, clip_box
from ppdet.modeling.mot.utils import Timer, load_det_results
from ppdet.modeling.mot import visualization as mot_vis
from ppdet.metrics import Metric, MOTMetric, KITTIMOTMetric
import ppdet.utils.stats as stats
from ppdet.engine.callbacks import Callback, ComposeCallback
from ppdet.core.workspace import create
from ppdet.utils.logger import setup_logger
from .dataset import MOTVideoStream, MOTImageStream
from .modeling.mot.utils import Detection, get_crops, scale_coords, clip_box
from .modeling.mot import visualization as mot_vis
from .utils import Timer
logger = setup_logger(__name__)
......@@ -70,7 +71,6 @@ class StreamTracker(object):
timer.tic()
pred_dets, pred_embs = self.model(data)
online_targets = self.model.tracker.update(pred_dets, pred_embs)
online_tlwhs, online_ids = [], []
online_scores = []
for t in online_targets:
......@@ -109,7 +109,6 @@ class StreamTracker(object):
with paddle.no_grad():
pred_dets, pred_embs = self.model(data)
online_targets = self.model.tracker.update(pred_dets, pred_embs)
online_tlwhs, online_ids = [], []
online_scores = []
for t in online_targets:
......
import time
class Timer(object):
"""
This class used to compute and print the current FPS while evaling.
"""
def __init__(self):
self.total_time = 0.
self.calls = 0
self.start_time = 0.
self.diff = 0.
self.average_time = 0.
self.duration = 0.
def tic(self):
# using time.time instead of time.clock because time time.clock
# does not normalize for multithreading
self.start_time = time.time()
def toc(self, average=True):
self.diff = time.time() - self.start_time
self.total_time += self.diff
self.calls += 1
self.average_time = self.total_time / self.calls
if average:
self.duration = self.average_time
else:
self.duration = self.diff
return self.duration
def clear(self):
self.total_time = 0.
self.calls = 0
self.start_time = 0.
self.diff = 0.
self.average_time = 0.
self.duration = 0.
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册