diff --git a/modules/video/multiple_object_tracking/fairmot_dla34/config/_base_/fairmot_dla34.yml b/modules/video/multiple_object_tracking/fairmot_dla34/config/_base_/fairmot_dla34.yml index c5f07de702fbeb594c9eeda60d709c0c40af8b1b..e2ca32a2b6c31d66a1b8f5fa42d278d0609dbdca 100644 --- a/modules/video/multiple_object_tracking/fairmot_dla34/config/_base_/fairmot_dla34.yml +++ b/modules/video/multiple_object_tracking/fairmot_dla34/config/_base_/fairmot_dla34.yml @@ -5,7 +5,7 @@ FairMOT: detector: CenterNet reid: FairMOTEmbeddingHead loss: FairMOTLoss - tracker: JDETracker + tracker: FrozenJDETracker CenterNet: backbone: DLA diff --git a/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/__init__.py b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..258e4c9010832936f098e6febe777ac556f0668f --- /dev/null +++ b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import matching +from . import tracker +from . import motion +from . import visualization +from . import utils + +from .matching import * +from .tracker import * +from .motion import * +from .visualization import * +from .utils import * diff --git a/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/matching/__init__.py b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/matching/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..54c6680f79f16247c562a9da1024dd3e1de4c57f --- /dev/null +++ b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/matching/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import jde_matching +from . import deepsort_matching + +from .jde_matching import * +from .deepsort_matching import * diff --git a/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/matching/deepsort_matching.py b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/matching/deepsort_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..c55aa8876cc128f512aa4e2e4e48a935a3f8dd77 --- /dev/null +++ b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/matching/deepsort_matching.py @@ -0,0 +1,368 @@ +# Copyright (c) 2021 PaddlePaddle Authors. 
All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is borrowed from https://github.com/nwojke/deep_sort/tree/master/deep_sort
+"""
+
+import numpy as np
+from scipy.optimize import linear_sum_assignment
+from ..motion import kalman_filter
+
+INFTY_COST = 1e+5
+
+__all__ = [
+    'iou_1toN',
+    'iou_cost',
+    '_nn_euclidean_distance',
+    '_nn_cosine_distance',
+    'NearestNeighborDistanceMetric',
+    'min_cost_matching',
+    'matching_cascade',
+    'gate_cost_matrix',
+]
+
+
+def iou_1toN(bbox, candidates):
+    """
+    Compute intersection over union (IoU) between one box and N candidates.
+
+    Args:
+        bbox (ndarray): A bounding box in format `(top left x, top left y, width, height)`.
+        candidates (ndarray): A matrix of candidate bounding boxes (one per row) in the
+            same format as `bbox`.
+
+    Returns:
+        ious (ndarray): The intersection over union in [0, 1] between the `bbox`
+            and each candidate. A higher score means a larger fraction of the
+            `bbox` is occluded by the candidate.
+    """
+    bbox_tl = bbox[:2]
+    bbox_br = bbox[:2] + bbox[2:]
+    candidates_tl = candidates[:, :2]
+    candidates_br = candidates[:, :2] + candidates[:, 2:]
+
+    tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis],
+               np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]]
+    br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis],
+               np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]]
+    wh = np.maximum(0., br - tl)
+
+    area_intersection = wh.prod(axis=1)
+    area_bbox = bbox[2:].prod()
+    area_candidates = candidates[:, 2:].prod(axis=1)
+    ious = area_intersection / (area_bbox + area_candidates - area_intersection)
+    return ious
+
+
+def iou_cost(tracks, detections, track_indices=None, detection_indices=None):
+    """
+    IoU distance metric.
+
+    Args:
+        tracks (list[Track]): A list of tracks.
+        detections (list[Detection]): A list of detections.
+        track_indices (Optional[list[int]]): A list of indices to tracks that
+            should be matched. Defaults to all `tracks`.
+        detection_indices (Optional[list[int]]): A list of indices to detections
+            that should be matched. Defaults to all `detections`.
+
+    Returns:
+        cost_matrix (ndarray): A cost matrix of shape len(track_indices),
+            len(detection_indices) where entry (i, j) is
+            `1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`.
+    """
+    if track_indices is None:
+        track_indices = np.arange(len(tracks))
+    if detection_indices is None:
+        detection_indices = np.arange(len(detections))
+
+    cost_matrix = np.zeros((len(track_indices), len(detection_indices)))
+    for row, track_idx in enumerate(track_indices):
+        if tracks[track_idx].time_since_update > 1:
+            cost_matrix[row, :] = 1e+5
+            continue
+
+        bbox = tracks[track_idx].to_tlwh()
+        candidates = np.asarray([detections[i].tlwh for i in detection_indices])
+        cost_matrix[row, :] = 1. - iou_1toN(bbox, candidates)
+    return cost_matrix
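+
+
+# Illustrative sanity check for `iou_1toN` (hypothetical numbers): a candidate
+# identical to `bbox` gives IoU 1.0, while one shifted by half its size
+# overlaps in a 5x5 region, giving 25 / (100 + 100 - 25) ~= 0.143.
+#
+#   bbox = np.array([0., 0., 10., 10.])
+#   candidates = np.array([[0., 0., 10., 10.], [5., 5., 10., 10.]])
+#   iou_1toN(bbox, candidates)  # -> array([1.0, 0.14285714])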
+
+
+def _nn_euclidean_distance(s, q):
+    """
+    Compute pair-wise squared (Euclidean) distances between points in `s` and `q`.
+
+    Args:
+        s (ndarray): Sample points: an NxM matrix of N samples of dimensionality M.
+        q (ndarray): Query points: an LxM matrix of L samples of dimensionality M.
+
+    Returns:
+        distances (ndarray): A vector of length L that contains for each entry in `q` the
+            smallest squared Euclidean distance to a sample in `s`.
+    """
+    s, q = np.asarray(s), np.asarray(q)
+    if len(s) == 0 or len(q) == 0:
+        return np.zeros((len(s), len(q)))
+    s2, q2 = np.square(s).sum(axis=1), np.square(q).sum(axis=1)
+    distances = -2. * np.dot(s, q.T) + s2[:, None] + q2[None, :]
+    distances = np.clip(distances, 0., float(np.inf))
+
+    return np.maximum(0.0, distances.min(axis=0))
+
+
+def _nn_cosine_distance(s, q):
+    """
+    Compute pair-wise cosine distances between points in `s` and `q`.
+
+    Args:
+        s (ndarray): Sample points: an NxM matrix of N samples of dimensionality M.
+        q (ndarray): Query points: an LxM matrix of L samples of dimensionality M.
+
+    Returns:
+        distances (ndarray): A vector of length L that contains for each entry in `q` the
+            smallest cosine distance to a sample in `s`.
+    """
+    s = np.asarray(s) / np.linalg.norm(s, axis=1, keepdims=True)
+    q = np.asarray(q) / np.linalg.norm(q, axis=1, keepdims=True)
+    distances = 1. - np.dot(s, q.T)
+
+    return distances.min(axis=0)
+
+
+class NearestNeighborDistanceMetric(object):
+    """
+    A nearest neighbor distance metric that, for each target, returns
+    the closest distance to any sample that has been observed so far.
+
+    Args:
+        metric (str): Either "euclidean" or "cosine".
+        matching_threshold (float): The matching threshold. Samples with larger
+            distance are considered an invalid match.
+        budget (Optional[int]): If not None, fix samples per class to at most
+            this number. Removes the oldest samples when the budget is reached.
+
+    Attributes:
+        samples (Dict[int -> List[ndarray]]): A dictionary that maps from target
+            identities to the list of samples that have been observed so far.
+    """
+
+    def __init__(self, metric, matching_threshold, budget=None):
+        if metric == "euclidean":
+            self._metric = _nn_euclidean_distance
+        elif metric == "cosine":
+            self._metric = _nn_cosine_distance
+        else:
+            raise ValueError("Invalid metric; must be either 'euclidean' or 'cosine'")
+        self.matching_threshold = matching_threshold
+        self.budget = budget
+        self.samples = {}
+
+    def partial_fit(self, features, targets, active_targets):
+        """
+        Update the distance metric with new data.
+
+        Args:
+            features (ndarray): An NxM matrix of N features of dimensionality M.
+            targets (ndarray): An integer array of associated target identities.
+            active_targets (List[int]): A list of targets that are currently
+                present in the scene.
+        """
+        for feature, target in zip(features, targets):
+            self.samples.setdefault(target, []).append(feature)
+            if self.budget is not None:
+                self.samples[target] = self.samples[target][-self.budget:]
+        self.samples = {k: self.samples[k] for k in active_targets}
+
+    def distance(self, features, targets):
+        """
+        Compute distance between features and targets.
+
+        Args:
+            features (ndarray): An NxM matrix of N features of dimensionality M.
+            targets (list[int]): A list of targets to match the given `features` against.
+
+        Returns:
+            cost_matrix (ndarray): a cost matrix of shape len(targets), len(features),
+                where element (i, j) contains the closest squared distance between
+                `targets[i]` and `features[j]`.
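+
+        Example (illustrative, with hypothetical feature arrays):
+            metric = NearestNeighborDistanceMetric("cosine", matching_threshold=0.2)
+            metric.partial_fit(features, targets, active_targets)
+            cost = metric.distance(query_features, targets=[1, 2])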
+ """ + cost_matrix = np.zeros((len(targets), len(features))) + for i, target in enumerate(targets): + cost_matrix[i, :] = self._metric(self.samples[target], features) + return cost_matrix + + +def min_cost_matching(distance_metric, max_distance, tracks, detections, track_indices=None, detection_indices=None): + """ + Solve linear assignment problem. + + Args: + distance_metric : + Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray + The distance metric is given a list of tracks and detections as + well as a list of N track indices and M detection indices. The + metric should return the NxM dimensional cost matrix, where element + (i, j) is the association cost between the i-th track in the given + track indices and the j-th detection in the given detection_indices. + max_distance (float): Gating threshold. Associations with cost larger + than this value are disregarded. + tracks (list[Track]): A list of predicted tracks at the current time + step. + detections (list[Detection]): A list of detections at the current time + step. + track_indices (list[int]): List of track indices that maps rows in + `cost_matrix` to tracks in `tracks`. + detection_indices (List[int]): List of detection indices that maps + columns in `cost_matrix` to detections in `detections`. + + Returns: + A tuple (List[(int, int)], List[int], List[int]) with the following + three entries: + * A list of matched track and detection indices. + * A list of unmatched track indices. + * A list of unmatched detection indices. + """ + if track_indices is None: + track_indices = np.arange(len(tracks)) + if detection_indices is None: + detection_indices = np.arange(len(detections)) + + if len(detection_indices) == 0 or len(track_indices) == 0: + return [], track_indices, detection_indices # Nothing to match. + + cost_matrix = distance_metric(tracks, detections, track_indices, detection_indices) + + cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5 + indices = linear_sum_assignment(cost_matrix) + + matches, unmatched_tracks, unmatched_detections = [], [], [] + for col, detection_idx in enumerate(detection_indices): + if col not in indices[1]: + unmatched_detections.append(detection_idx) + for row, track_idx in enumerate(track_indices): + if row not in indices[0]: + unmatched_tracks.append(track_idx) + for row, col in zip(indices[0], indices[1]): + track_idx = track_indices[row] + detection_idx = detection_indices[col] + if cost_matrix[row, col] > max_distance: + unmatched_tracks.append(track_idx) + unmatched_detections.append(detection_idx) + else: + matches.append((track_idx, detection_idx)) + return matches, unmatched_tracks, unmatched_detections + + +def matching_cascade(distance_metric, + max_distance, + cascade_depth, + tracks, + detections, + track_indices=None, + detection_indices=None): + """ + Run matching cascade. + + Args: + distance_metric : + Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray + The distance metric is given a list of tracks and detections as + well as a list of N track indices and M detection indices. The + metric should return the NxM dimensional cost matrix, where element + (i, j) is the association cost between the i-th track in the given + track indices and the j-th detection in the given detection_indices. + max_distance (float): Gating threshold. Associations with cost larger + than this value are disregarded. + cascade_depth (int): The cascade depth, should be se to the maximum + track age. 
+ tracks (list[Track]): A list of predicted tracks at the current time + step. + detections (list[Detection]): A list of detections at the current time + step. + track_indices (list[int]): List of track indices that maps rows in + `cost_matrix` to tracks in `tracks`. + detection_indices (List[int]): List of detection indices that maps + columns in `cost_matrix` to detections in `detections`. + + Returns: + A tuple (List[(int, int)], List[int], List[int]) with the following + three entries: + * A list of matched track and detection indices. + * A list of unmatched track indices. + * A list of unmatched detection indices. + """ + if track_indices is None: + track_indices = list(range(len(tracks))) + if detection_indices is None: + detection_indices = list(range(len(detections))) + + unmatched_detections = detection_indices + matches = [] + for level in range(cascade_depth): + if len(unmatched_detections) == 0: # No detections left + break + + track_indices_l = [k for k in track_indices if tracks[k].time_since_update == 1 + level] + if len(track_indices_l) == 0: # Nothing to match at this level + continue + + matches_l, _, unmatched_detections = \ + min_cost_matching( + distance_metric, max_distance, tracks, detections, + track_indices_l, unmatched_detections) + matches += matches_l + unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches)) + return matches, unmatched_tracks, unmatched_detections + + +def gate_cost_matrix(kf, + cost_matrix, + tracks, + detections, + track_indices, + detection_indices, + gated_cost=INFTY_COST, + only_position=False): + """ + Invalidate infeasible entries in cost matrix based on the state + distributions obtained by Kalman filtering. + + Args: + kf (object): The Kalman filter. + cost_matrix (ndarray): The NxM dimensional cost matrix, where N is the + number of track indices and M is the number of detection indices, + such that entry (i, j) is the association cost between + `tracks[track_indices[i]]` and `detections[detection_indices[j]]`. + tracks (list[Track]): A list of predicted tracks at the current time + step. + detections (list[Detection]): A list of detections at the current time + step. + track_indices (List[int]): List of track indices that maps rows in + `cost_matrix` to tracks in `tracks`. + detection_indices (List[int]): List of detection indices that maps + columns in `cost_matrix` to detections in `detections`. + gated_cost (Optional[float]): Entries in the cost matrix corresponding + to infeasible associations are set this value. Defaults to a very + large value. + only_position (Optional[bool]): If True, only the x, y position of the + state distribution is considered during gating. Default False. 
+ """ + gating_dim = 2 if only_position else 4 + gating_threshold = kalman_filter.chi2inv95[gating_dim] + measurements = np.asarray([detections[i].to_xyah() for i in detection_indices]) + for row, track_idx in enumerate(track_indices): + track = tracks[track_idx] + gating_distance = kf.gating_distance(track.mean, track.covariance, measurements, only_position) + cost_matrix[row, gating_distance > gating_threshold] = gated_cost + return cost_matrix diff --git a/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/matching/jde_matching.py b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/matching/jde_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..bf2e891c391c98ed8944f88377f62c9722fa5155 --- /dev/null +++ b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/matching/jde_matching.py @@ -0,0 +1,123 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is borrow from https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/matching.py +""" + +import lap +import scipy +import numpy as np +from scipy.spatial.distance import cdist +from ..motion import kalman_filter + +from ppdet.utils.logger import setup_logger +logger = setup_logger(__name__) + +__all__ = [ + 'merge_matches', + 'linear_assignment', + 'cython_bbox_ious', + 'iou_distance', + 'embedding_distance', + 'fuse_motion', +] + + +def merge_matches(m1, m2, shape): + O, P, Q = shape + m1 = np.asarray(m1) + m2 = np.asarray(m2) + + M1 = scipy.sparse.coo_matrix((np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P)) + M2 = scipy.sparse.coo_matrix((np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q)) + + mask = M1 * M2 + match = mask.nonzero() + match = list(zip(match[0], match[1])) + unmatched_O = tuple(set(range(O)) - set([i for i, j in match])) + unmatched_Q = tuple(set(range(Q)) - set([j for i, j in match])) + + return match, unmatched_O, unmatched_Q + + +def linear_assignment(cost_matrix, thresh): + if cost_matrix.size == 0: + return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1])) + matches, unmatched_a, unmatched_b = [], [], [] + cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh) + for ix, mx in enumerate(x): + if mx >= 0: + matches.append([ix, mx]) + unmatched_a = np.where(x < 0)[0] + unmatched_b = np.where(y < 0)[0] + matches = np.asarray(matches) + return matches, unmatched_a, unmatched_b + + +def cython_bbox_ious(atlbrs, btlbrs): + ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float) + if ious.size == 0: + return ious + try: + import cython_bbox + except Exception as e: + logger.error('cython_bbox not found, please install cython_bbox.' 
+
+
+def cython_bbox_ious(atlbrs, btlbrs):
+    ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float64)
+    if ious.size == 0:
+        return ious
+    try:
+        import cython_bbox
+    except Exception as e:
+        logger.error('cython_bbox not found, please install cython_bbox, '
+                     'for example: `pip install cython_bbox`.')
+        raise e
+
+    ious = cython_bbox.bbox_overlaps(
+        np.ascontiguousarray(atlbrs, dtype=np.float64), np.ascontiguousarray(btlbrs, dtype=np.float64))
+    return ious
+
+
+def iou_distance(atracks, btracks):
+    """
+    Compute cost based on IoU between two lists of tracks (list[STrack]).
+    """
+    if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0
+                                                                     and isinstance(btracks[0], np.ndarray)):
+        atlbrs = atracks
+        btlbrs = btracks
+    else:
+        atlbrs = [track.tlbr for track in atracks]
+        btlbrs = [track.tlbr for track in btracks]
+    _ious = cython_bbox_ious(atlbrs, btlbrs)
+    cost_matrix = 1 - _ious
+
+    return cost_matrix
+
+
+def embedding_distance(tracks, detections, metric='euclidean'):
+    """
+    Compute cost based on embedding features between two lists of tracks (list[STrack]).
+    """
+    cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float64)
+    if cost_matrix.size == 0:
+        return cost_matrix
+    det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float64)
+    track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float64)
+    cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric))  # Normalized features
+    return cost_matrix
+
+
+def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98):
+    if cost_matrix.size == 0:
+        return cost_matrix
+    gating_dim = 2 if only_position else 4
+    gating_threshold = kalman_filter.chi2inv95[gating_dim]
+    measurements = np.asarray([det.to_xyah() for det in detections])
+    for row, track in enumerate(tracks):
+        gating_distance = kf.gating_distance(track.mean, track.covariance, measurements, only_position, metric='maha')
+        cost_matrix[row, gating_distance > gating_threshold] = np.inf
+        cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance
+    return cost_matrix
diff --git a/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/motion/__init__.py b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/motion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e42dd0b019d66d6ea07bec1ad90cf9a8d53d8172
--- /dev/null
+++ b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/motion/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import kalman_filter
+
+from .kalman_filter import *
diff --git a/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/motion/kalman_filter.py b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/motion/kalman_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cc182e4c5e76e0688688c883b2a24fa30df9c74
--- /dev/null
+++ b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/motion/kalman_filter.py
@@ -0,0 +1,237 @@
+# Copyright (c) 2021 PaddlePaddle Authors.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is borrow from https://github.com/nwojke/deep_sort/blob/master/deep_sort/kalman_filter.py +""" + +import numpy as np +import scipy.linalg + +__all__ = ['KalmanFilter'] +""" +Table for the 0.95 quantile of the chi-square distribution with N degrees of +freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv +function and used as Mahalanobis gating threshold. +""" + +chi2inv95 = {1: 3.8415, 2: 5.9915, 3: 7.8147, 4: 9.4877, 5: 11.070, 6: 12.592, 7: 14.067, 8: 15.507, 9: 16.919} + + +class KalmanFilter(object): + """ + A simple Kalman filter for tracking bounding boxes in image space. + + The 8-dimensional state space + + x, y, a, h, vx, vy, va, vh + + contains the bounding box center position (x, y), aspect ratio a, height h, + and their respective velocities. + + Object motion follows a constant velocity model. The bounding box location + (x, y, a, h) is taken as direct observation of the state space (linear + observation model). + + """ + + def __init__(self): + ndim, dt = 4, 1. + + # Create Kalman filter model matrices. + self._motion_mat = np.eye(2 * ndim, 2 * ndim) + for i in range(ndim): + self._motion_mat[i, ndim + i] = dt + self._update_mat = np.eye(ndim, 2 * ndim) + + # Motion and observation uncertainty are chosen relative to the current + # state estimate. These weights control the amount of uncertainty in + # the model. This is a bit hacky. + self._std_weight_position = 1. / 20 + self._std_weight_velocity = 1. / 160 + + def initiate(self, measurement): + """ + Create track from unassociated measurement. + + Args: + measurement (ndarray): Bounding box coordinates (x, y, a, h) with + center position (x, y), aspect ratio a, and height h. + + Returns: + The mean vector (8 dimensional) and covariance matrix (8x8 + dimensional) of the new track. Unobserved velocities are + initialized to 0 mean. + """ + mean_pos = measurement + mean_vel = np.zeros_like(mean_pos) + mean = np.r_[mean_pos, mean_vel] + + std = [ + 2 * self._std_weight_position * measurement[3], 2 * self._std_weight_position * measurement[3], 1e-2, + 2 * self._std_weight_position * measurement[3], 10 * self._std_weight_velocity * measurement[3], + 10 * self._std_weight_velocity * measurement[3], 1e-5, 10 * self._std_weight_velocity * measurement[3] + ] + covariance = np.diag(np.square(std)) + return mean, covariance + + def predict(self, mean, covariance): + """ + Run Kalman filter prediction step. + + Args: + mean (ndarray): The 8 dimensional mean vector of the object state + at the previous time step. + covariance (ndarray): The 8x8 dimensional covariance matrix of the + object state at the previous time step. + + Returns: + The mean vector and covariance matrix of the predicted state. + Unobserved velocities are initialized to 0 mean. 
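+
+        Example (illustrative): a track is created from a first measurement in
+        `(x, y, a, h)` format and then propagated once per frame:
+            kf = KalmanFilter()
+            mean, cov = kf.initiate(np.array([320., 240., 0.5, 80.]))
+            mean, cov = kf.predict(mean, cov)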
+ """ + std_pos = [ + self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-2, + self._std_weight_position * mean[3] + ] + std_vel = [ + self._std_weight_velocity * mean[3], self._std_weight_velocity * mean[3], 1e-5, + self._std_weight_velocity * mean[3] + ] + motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) + + #mean = np.dot(self._motion_mat, mean) + mean = np.dot(mean, self._motion_mat.T) + covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov + + return mean, covariance + + def project(self, mean, covariance): + """ + Project state distribution to measurement space. + + Args + mean (ndarray): The state's mean vector (8 dimensional array). + covariance (ndarray): The state's covariance matrix (8x8 dimensional). + + Returns: + The projected mean and covariance matrix of the given state estimate. + """ + std = [ + self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-1, + self._std_weight_position * mean[3] + ] + innovation_cov = np.diag(np.square(std)) + + mean = np.dot(self._update_mat, mean) + covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T)) + return mean, covariance + innovation_cov + + def multi_predict(self, mean, covariance): + """ + Run Kalman filter prediction step (Vectorized version). + + Args: + mean (ndarray): The Nx8 dimensional mean matrix of the object states + at the previous time step. + covariance (ndarray): The Nx8x8 dimensional covariance matrics of the + object states at the previous time step. + + Returns: + The mean vector and covariance matrix of the predicted state. + Unobserved velocities are initialized to 0 mean. + """ + std_pos = [ + self._std_weight_position * mean[:, 3], self._std_weight_position * mean[:, 3], + 1e-2 * np.ones_like(mean[:, 3]), self._std_weight_position * mean[:, 3] + ] + std_vel = [ + self._std_weight_velocity * mean[:, 3], self._std_weight_velocity * mean[:, 3], + 1e-5 * np.ones_like(mean[:, 3]), self._std_weight_velocity * mean[:, 3] + ] + sqr = np.square(np.r_[std_pos, std_vel]).T + + motion_cov = [] + for i in range(len(mean)): + motion_cov.append(np.diag(sqr[i])) + motion_cov = np.asarray(motion_cov) + + mean = np.dot(mean, self._motion_mat.T) + left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2)) + covariance = np.dot(left, self._motion_mat.T) + motion_cov + + return mean, covariance + + def update(self, mean, covariance, measurement): + """ + Run Kalman filter correction step. + + Args: + mean (ndarray): The predicted state's mean vector (8 dimensional). + covariance (ndarray): The state's covariance matrix (8x8 dimensional). + measurement (ndarray): The 4 dimensional measurement vector + (x, y, a, h), where (x, y) is the center position, a the aspect + ratio, and h the height of the bounding box. + + Returns: + The measurement-corrected state distribution. 
+ """ + projected_mean, projected_cov = self.project(mean, covariance) + + chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False) + kalman_gain = scipy.linalg.cho_solve((chol_factor, lower), + np.dot(covariance, self._update_mat.T).T, + check_finite=False).T + innovation = measurement - projected_mean + + new_mean = mean + np.dot(innovation, kalman_gain.T) + new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T)) + return new_mean, new_covariance + + def gating_distance(self, mean, covariance, measurements, only_position=False, metric='maha'): + """ + Compute gating distance between state distribution and measurements. + A suitable distance threshold can be obtained from `chi2inv95`. If + `only_position` is False, the chi-square distribution has 4 degrees of + freedom, otherwise 2. + + Args: + mean (ndarray): Mean vector over the state distribution (8 + dimensional). + covariance (ndarray): Covariance of the state distribution (8x8 + dimensional). + measurements (ndarray): An Nx4 dimensional matrix of N measurements, + each in format (x, y, a, h) where (x, y) is the bounding box center + position, a the aspect ratio, and h the height. + only_position (Optional[bool]): If True, distance computation is + done with respect to the bounding box center position only. + metric (str): Metric type, 'gaussian' or 'maha'. + + Returns + An array of length N, where the i-th element contains the squared + Mahalanobis distance between (mean, covariance) and `measurements[i]`. + """ + mean, covariance = self.project(mean, covariance) + if only_position: + mean, covariance = mean[:2], covariance[:2, :2] + measurements = measurements[:, :2] + + d = measurements - mean + if metric == 'gaussian': + return np.sum(d * d, axis=1) + elif metric == 'maha': + cholesky_factor = np.linalg.cholesky(covariance) + z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True) + squared_maha = np.sum(z * z, axis=0) + return squared_maha + else: + raise ValueError('invalid distance metric') diff --git a/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/tracker/__init__.py b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/tracker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..904822119661be61141715c638388db9d045fee1 --- /dev/null +++ b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/tracker/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import base_jde_tracker +from . import base_sde_tracker +from . 
import jde_tracker
+
+from .base_jde_tracker import *
+from .base_sde_tracker import *
+from .jde_tracker import *
diff --git a/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/tracker/base_jde_tracker.py b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/tracker/base_jde_tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..9505a709ee573acecf4b5dd7e02a06cee9d44284
--- /dev/null
+++ b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/tracker/base_jde_tracker.py
@@ -0,0 +1,257 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is borrowed from https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py
+"""
+
+import numpy as np
+from collections import deque, OrderedDict
+from ..matching import jde_matching as matching
+from ppdet.core.workspace import register, serializable
+
+__all__ = [
+    'TrackState',
+    'BaseTrack',
+    'STrack',
+    'joint_stracks',
+    'sub_stracks',
+    'remove_duplicate_stracks',
+]
+
+
+class TrackState(object):
+    New = 0
+    Tracked = 1
+    Lost = 2
+    Removed = 3
+
+
+class BaseTrack(object):
+    _count = 0
+
+    track_id = 0
+    is_activated = False
+    state = TrackState.New
+
+    history = OrderedDict()
+    features = []
+    curr_feature = None
+    score = 0
+    start_frame = 0
+    frame_id = 0
+    time_since_update = 0
+
+    # multi-camera
+    location = (np.inf, np.inf)
+
+    @property
+    def end_frame(self):
+        return self.frame_id
+
+    @staticmethod
+    def next_id():
+        BaseTrack._count += 1
+        return BaseTrack._count
+
+    def activate(self, *args):
+        raise NotImplementedError
+
+    def predict(self):
+        raise NotImplementedError
+
+    def update(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def mark_lost(self):
+        self.state = TrackState.Lost
+
+    def mark_removed(self):
+        self.state = TrackState.Removed
+
+
+class STrack(BaseTrack):
+    def __init__(self, tlwh, score, temp_feat, buffer_size=30):
+        # Wait to be activated.
+        self._tlwh = np.asarray(tlwh, dtype=np.float64)
+        self.kalman_filter = None
+        self.mean, self.covariance = None, None
+        self.is_activated = False
+
+        self.score = score
+        self.tracklet_len = 0
+
+        self.smooth_feat = None
+        self.update_features(temp_feat)
+        self.features = deque([], maxlen=buffer_size)
+        self.alpha = 0.9
+
+    def update_features(self, feat):
+        feat /= np.linalg.norm(feat)
+        self.curr_feat = feat
+        if self.smooth_feat is None:
+            self.smooth_feat = feat
+        else:
+            self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
+        self.features.append(feat)
+        self.smooth_feat /= np.linalg.norm(self.smooth_feat)
+
+    def predict(self):
+        mean_state = self.mean.copy()
+        if self.state != TrackState.Tracked:
+            mean_state[7] = 0
+        self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
+
+    @staticmethod
+    def multi_predict(stracks, kalman_filter):
+        if len(stracks) > 0:
+            multi_mean = np.asarray([st.mean.copy() for st in stracks])
+            multi_covariance =
np.asarray([st.covariance for st in stracks]) + for i, st in enumerate(stracks): + if st.state != TrackState.Tracked: + multi_mean[i][7] = 0 + multi_mean, multi_covariance = kalman_filter.multi_predict(multi_mean, multi_covariance) + for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)): + stracks[i].mean = mean + stracks[i].covariance = cov + + def activate(self, kalman_filter, frame_id): + """Start a new tracklet""" + self.kalman_filter = kalman_filter + self.track_id = self.next_id() + self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh)) + + self.tracklet_len = 0 + self.state = TrackState.Tracked + if frame_id == 1: + self.is_activated = True + self.frame_id = frame_id + self.start_frame = frame_id + + def re_activate(self, new_track, frame_id, new_id=False): + self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance, + self.tlwh_to_xyah(new_track.tlwh)) + + self.update_features(new_track.curr_feat) + self.tracklet_len = 0 + self.state = TrackState.Tracked + self.is_activated = True + self.frame_id = frame_id + if new_id: + self.track_id = self.next_id() + + def update(self, new_track, frame_id, update_feature=True): + self.frame_id = frame_id + self.tracklet_len += 1 + + new_tlwh = new_track.tlwh + self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh)) + self.state = TrackState.Tracked + self.is_activated = True + + self.score = new_track.score + if update_feature: + self.update_features(new_track.curr_feat) + + @property + def tlwh(self): + """ + Get current position in bounding box format `(top left x, top left y, + width, height)`. + """ + if self.mean is None: + return self._tlwh.copy() + ret = self.mean[:4].copy() + ret[2] *= ret[3] + ret[:2] -= ret[2:] / 2 + return ret + + @property + def tlbr(self): + """ + Convert bounding box to format `(min x, min y, max x, max y)`, i.e., + `(top left, bottom right)`. + """ + ret = self.tlwh.copy() + ret[2:] += ret[:2] + return ret + + @staticmethod + def tlwh_to_xyah(tlwh): + """ + Convert bounding box to format `(center x, center y, aspect ratio, + height)`, where the aspect ratio is `width / height`. 
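+
+        Example (illustrative): `tlwh_to_xyah([0., 0., 50., 100.])` returns
+        `[25., 50., 0.5, 100.]`.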
+ """ + ret = np.asarray(tlwh).copy() + ret[:2] += ret[2:] / 2 + ret[2] /= ret[3] + return ret + + def to_xyah(self): + return self.tlwh_to_xyah(self.tlwh) + + @staticmethod + def tlbr_to_tlwh(tlbr): + ret = np.asarray(tlbr).copy() + ret[2:] -= ret[:2] + return ret + + @staticmethod + def tlwh_to_tlbr(tlwh): + ret = np.asarray(tlwh).copy() + ret[2:] += ret[:2] + return ret + + def __repr__(self): + return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame) + + +def joint_stracks(tlista, tlistb): + exists = {} + res = [] + for t in tlista: + exists[t.track_id] = 1 + res.append(t) + for t in tlistb: + tid = t.track_id + if not exists.get(tid, 0): + exists[tid] = 1 + res.append(t) + return res + + +def sub_stracks(tlista, tlistb): + stracks = {} + for t in tlista: + stracks[t.track_id] = t + for t in tlistb: + tid = t.track_id + if stracks.get(tid, 0): + del stracks[tid] + return list(stracks.values()) + + +def remove_duplicate_stracks(stracksa, stracksb): + pdist = matching.iou_distance(stracksa, stracksb) + pairs = np.where(pdist < 0.15) + dupa, dupb = list(), list() + for p, q in zip(*pairs): + timep = stracksa[p].frame_id - stracksa[p].start_frame + timeq = stracksb[q].frame_id - stracksb[q].start_frame + if timep > timeq: + dupb.append(q) + else: + dupa.append(p) + resa = [t for i, t in enumerate(stracksa) if not i in dupa] + resb = [t for i, t in enumerate(stracksb) if not i in dupb] + return resa, resb diff --git a/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/tracker/base_sde_tracker.py b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/tracker/base_sde_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..2e811e536a42ff781f60872b448b251de0301f61 --- /dev/null +++ b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/tracker/base_sde_tracker.py @@ -0,0 +1,133 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is borrow from https://github.com/nwojke/deep_sort/blob/master/deep_sort/track.py +""" + +from ppdet.core.workspace import register, serializable + +__all__ = ['TrackState', 'Track'] + + +class TrackState(object): + """ + Enumeration type for the single target track state. Newly created tracks are + classified as `tentative` until enough evidence has been collected. Then, + the track state is changed to `confirmed`. Tracks that are no longer alive + are classified as `deleted` to mark them for removal from the set of active + tracks. + """ + Tentative = 1 + Confirmed = 2 + Deleted = 3 + + +class Track(object): + """ + A single target track with state space `(x, y, a, h)` and associated + velocities, where `(x, y)` is the center of the bounding box, `a` is the + aspect ratio and `h` is the height. + + Args: + mean (ndarray): Mean vector of the initial state distribution. + covariance (ndarray): Covariance matrix of the initial state distribution. + track_id (int): A unique track identifier. 
+        n_init (int): Number of consecutive detections before the track is confirmed.
+            The track state is set to `Deleted` if a miss occurs within the first
+            `n_init` frames.
+        max_age (int): The maximum number of consecutive misses before the track
+            state is set to `Deleted`.
+        feature (Optional[ndarray]): Feature vector of the detection this track
+            originates from. If not None, this feature is added to the `features` cache.
+
+    Attributes:
+        hits (int): Total number of measurement updates.
+        age (int): Total number of frames since the first occurrence.
+        time_since_update (int): Total number of frames since the last measurement
+            update.
+        state (TrackState): The current track state.
+        features (List[ndarray]): A cache of features. On each measurement update,
+            the associated feature vector is added to this list.
+    """
+
+    def __init__(self, mean, covariance, track_id, n_init, max_age, feature=None):
+        self.mean = mean
+        self.covariance = covariance
+        self.track_id = track_id
+        self.hits = 1
+        self.age = 1
+        self.time_since_update = 0
+
+        self.state = TrackState.Tentative
+        self.features = []
+        if feature is not None:
+            self.features.append(feature)
+
+        self._n_init = n_init
+        self._max_age = max_age
+
+    def to_tlwh(self):
+        """Get position in format `(top left x, top left y, width, height)`."""
+        ret = self.mean[:4].copy()
+        ret[2] *= ret[3]
+        ret[:2] -= ret[2:] / 2
+        return ret
+
+    def to_tlbr(self):
+        """Get position in bounding box format `(min x, min y, max x, max y)`."""
+        ret = self.to_tlwh()
+        ret[2:] = ret[:2] + ret[2:]
+        return ret
+
+    def predict(self, kalman_filter):
+        """
+        Propagate the state distribution to the current time step using a Kalman
+        filter prediction step.
+        """
+        self.mean, self.covariance = kalman_filter.predict(self.mean, self.covariance)
+        self.age += 1
+        self.time_since_update += 1
+
+    def update(self, kalman_filter, detection):
+        """
+        Perform Kalman filter measurement update step and update the associated
+        detection feature cache.
+        """
+        self.mean, self.covariance = kalman_filter.update(self.mean, self.covariance, detection.to_xyah())
+        self.features.append(detection.feature)
+
+        self.hits += 1
+        self.time_since_update = 0
+        if self.state == TrackState.Tentative and self.hits >= self._n_init:
+            self.state = TrackState.Confirmed
+
+    def mark_missed(self):
+        """Mark this track as missed (no association at the current time step)."""
+        if self.state == TrackState.Tentative:
+            self.state = TrackState.Deleted
+        elif self.time_since_update > self._max_age:
+            self.state = TrackState.Deleted
+
+    def is_tentative(self):
+        """Returns True if this track is tentative (unconfirmed)."""
+        return self.state == TrackState.Tentative
+
+    def is_confirmed(self):
+        """Returns True if this track is confirmed."""
+        return self.state == TrackState.Confirmed
+
+    def is_deleted(self):
+        """Returns True if this track is dead and should be deleted."""
+        return self.state == TrackState.Deleted
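+
+
+# A minimal lifecycle sketch for `Track` (illustrative; `kf` is a Kalman
+# filter with the `predict`/`update` interface used above, and `detection`
+# provides `to_xyah()` and `feature` -- both hypothetical here):
+#
+#   track = Track(mean, covariance, track_id=1, n_init=3, max_age=30)
+#   track.predict(kf)            # once per frame
+#   track.update(kf, detection)  # when a detection is associated
+#   track.mark_missed()          # when no detection is associated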
diff --git a/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/tracker/jde_tracker.py b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/tracker/jde_tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e1cafb345b7687e563fc6d9c2c1769cb39d690c
--- /dev/null
+++ b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/tracker/jde_tracker.py
@@ -0,0 +1,248 @@
+# Copyright (c) 2021 PaddlePaddle Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is borrowed from https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py
+"""
+
+import paddle
+
+from ..matching import jde_matching as matching
+from .base_jde_tracker import TrackState, BaseTrack, STrack
+from .base_jde_tracker import joint_stracks, sub_stracks, remove_duplicate_stracks
+
+from ppdet.core.workspace import register, serializable
+from ppdet.utils.logger import setup_logger
+logger = setup_logger(__name__)
+
+__all__ = ['FrozenJDETracker']
+
+
+@register
+@serializable
+class FrozenJDETracker(object):
+    __inject__ = ['motion']
+    """
+    JDE tracker.
+
+    Args:
+        det_thresh (float): threshold of detection score
+        track_buffer (int): buffer for tracker
+        min_box_area (int): min box area to filter out low quality boxes
+        vertical_ratio (float): w/h, the vertical ratio of the bbox used to
+            filter out bad results. Set to 1.6 by default for pedestrian
+            tracking; set to -1 to disable this filter.
+        tracked_thresh (float): linear assignment threshold of tracked
+            stracks and detections
+        r_tracked_thresh (float): linear assignment threshold of
+            tracked stracks and unmatched detections
+        unconfirmed_thresh (float): linear assignment threshold of
+            unconfirmed stracks and unmatched detections
+        motion (object): KalmanFilter instance
+        conf_thres (float): confidence threshold for tracking
+        metric_type (str): either "euclidean" or "cosine", the distance metric
+            used for measurement to track association.
+    """
+
+    def __init__(self,
+                 det_thresh=0.3,
+                 track_buffer=30,
+                 min_box_area=200,
+                 vertical_ratio=1.6,
+                 tracked_thresh=0.7,
+                 r_tracked_thresh=0.5,
+                 unconfirmed_thresh=0.7,
+                 motion='KalmanFilter',
+                 conf_thres=0,
+                 metric_type='euclidean'):
+        self.det_thresh = det_thresh
+        self.track_buffer = track_buffer
+        self.min_box_area = min_box_area
+        self.vertical_ratio = vertical_ratio
+
+        self.tracked_thresh = tracked_thresh
+        self.r_tracked_thresh = r_tracked_thresh
+        self.unconfirmed_thresh = unconfirmed_thresh
+        self.motion = motion
+        self.conf_thres = conf_thres
+        self.metric_type = metric_type
+
+        self.frame_id = 0
+        self.tracked_stracks = []
+        self.lost_stracks = []
+        self.removed_stracks = []
+
+        self.max_time_lost = 0
+        # max_time_lost will be calculated: int(frame_rate / 30.0 * track_buffer)
+
+    def update(self, pred_dets, pred_embs):
+        """
+        Processes the image frame and finds bounding boxes (detections).
+        Associates the detections with corresponding tracklets and also handles
+        lost, removed, refound and active tracklets.
+
+        Args:
+            pred_dets (Tensor): Detection results of the image, shape is [N, 5].
+            pred_embs (Tensor): Embedding results of the image, shape is [N, 512].
+
+        Returns:
+            output_stracks (list): The list contains information regarding the
+                online tracklets for the received image tensor.
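+
+        Example (illustrative, with hypothetical input tensors):
+            tracker = FrozenJDETracker(motion=KalmanFilter())
+            online_targets = tracker.update(pred_dets, pred_embs)
+            for t in online_targets:
+                tlwh, track_id = t.tlwh, t.track_id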
+ """ + self.frame_id += 1 + activated_starcks = [] + # for storing active tracks, for the current frame + refind_stracks = [] + # Lost Tracks whose detections are obtained in the current frame + lost_stracks = [] + # The tracks which are not obtained in the current frame but are not + # removed. (Lost for some time lesser than the threshold for removing) + removed_stracks = [] + + remain_inds = paddle.nonzero(pred_dets[:, 4] > self.conf_thres) + if remain_inds.shape[0] == 0: + pred_dets = paddle.zeros([0, 1]) + pred_embs = paddle.zeros([0, 1]) + else: + pred_dets = paddle.gather(pred_dets, remain_inds) + pred_embs = paddle.gather(pred_embs, remain_inds) + + # Filter out the image with box_num = 0. pred_dets = [[0.0, 0.0, 0.0 ,0.0]] + empty_pred = True if len(pred_dets) == 1 and paddle.sum(pred_dets) == 0.0 else False + """ Step 1: Network forward, get detections & embeddings""" + if len(pred_dets) > 0 and not empty_pred: + pred_dets = pred_dets.numpy() + pred_embs = pred_embs.numpy() + detections = [ + STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30) for (tlbrs, f) in zip(pred_dets, pred_embs) + ] + else: + detections = [] + ''' Add newly detected tracklets to tracked_stracks''' + unconfirmed = [] + tracked_stracks = [] # type: list[STrack] + for track in self.tracked_stracks: + if not track.is_activated: + # previous tracks which are not active in the current frame are added in unconfirmed list + unconfirmed.append(track) + else: + # Active tracks are added to the local list 'tracked_stracks' + tracked_stracks.append(track) + """ Step 2: First association, with embedding""" + # Combining currently tracked_stracks and lost_stracks + strack_pool = joint_stracks(tracked_stracks, self.lost_stracks) + # Predict the current location with KF + STrack.multi_predict(strack_pool, self.motion) + + dists = matching.embedding_distance(strack_pool, detections, metric=self.metric_type) + dists = matching.fuse_motion(self.motion, dists, strack_pool, detections) + # The dists is the list of distances of the detection with the tracks in strack_pool + matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.tracked_thresh) + # The matches is the array for corresponding matches of the detection with the corresponding strack_pool + + for itracked, idet in matches: + # itracked is the id of the track and idet is the detection + track = strack_pool[itracked] + det = detections[idet] + if track.state == TrackState.Tracked: + # If the track is active, add the detection to the track + track.update(detections[idet], self.frame_id) + activated_starcks.append(track) + else: + # We have obtained a detection from a track which is not active, + # hence put the track in refind_stracks list + track.re_activate(det, self.frame_id, new_id=False) + refind_stracks.append(track) + + # None of the steps below happen if there are no undetected tracks. + """ Step 3: Second association, with IOU""" + detections = [detections[i] for i in u_detection] + # detections is now a list of the unmatched detections + r_tracked_stracks = [] + # This is container for stracks which were tracked till the previous + # frame but no detection was found for it in the current frame. 
+ + for i in u_track: + if strack_pool[i].state == TrackState.Tracked: + r_tracked_stracks.append(strack_pool[i]) + dists = matching.iou_distance(r_tracked_stracks, detections) + matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.r_tracked_thresh) + # matches is the list of detections which matched with corresponding + # tracks by IOU distance method. + + for itracked, idet in matches: + track = r_tracked_stracks[itracked] + det = detections[idet] + if track.state == TrackState.Tracked: + track.update(det, self.frame_id) + activated_starcks.append(track) + else: + track.re_activate(det, self.frame_id, new_id=False) + refind_stracks.append(track) + # Same process done for some unmatched detections, but now considering IOU_distance as measure + + for it in u_track: + track = r_tracked_stracks[it] + if not track.state == TrackState.Lost: + track.mark_lost() + lost_stracks.append(track) + # If no detections are obtained for tracks (u_track), the tracks are added to lost_tracks list and are marked lost + '''Deal with unconfirmed tracks, usually tracks with only one beginning frame''' + detections = [detections[i] for i in u_detection] + dists = matching.iou_distance(unconfirmed, detections) + matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=self.unconfirmed_thresh) + for itracked, idet in matches: + unconfirmed[itracked].update(detections[idet], self.frame_id) + activated_starcks.append(unconfirmed[itracked]) + + # The tracks which are yet not matched + for it in u_unconfirmed: + track = unconfirmed[it] + track.mark_removed() + removed_stracks.append(track) + + # after all these confirmation steps, if a new detection is found, it is initialized for a new track + """ Step 4: Init new stracks""" + for inew in u_detection: + track = detections[inew] + if track.score < self.det_thresh: + continue + track.activate(self.motion, self.frame_id) + activated_starcks.append(track) + """ Step 5: Update state""" + # If the tracks are lost for more frames than the threshold number, the tracks are removed. + for track in self.lost_stracks: + if self.frame_id - track.end_frame > self.max_time_lost: + track.mark_removed() + removed_stracks.append(track) + + # Update the self.tracked_stracks and self.lost_stracks using the updates in this step. 
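+        # The bookkeeping below keeps the pools consistent: tracks that are
+        # tracked again are removed from `lost_stracks`, previously removed ids
+        # are dropped from it as well, and duplicates between the tracked and
+        # lost pools are resolved by `remove_duplicate_stracks`.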
+        self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
+        self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
+        self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
+
+        self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
+        self.lost_stracks.extend(lost_stracks)
+        self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
+        self.removed_stracks.extend(removed_stracks)
+        self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
+        # get scores of lost tracks
+        output_stracks = [track for track in self.tracked_stracks if track.is_activated]
+
+        logger.debug('===========Frame {}=========='.format(self.frame_id))
+        logger.debug('Activated: {}'.format([track.track_id for track in activated_starcks]))
+        logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
+        logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
+        logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
+
+        return output_stracks
diff --git a/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/utils.py b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..12c61686a1715a965407822dcf19fd1081f292d7
--- /dev/null
+++ b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/utils.py
@@ -0,0 +1,176 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import cv2
+import time
+import paddle
+import numpy as np
+
+__all__ = [
+    'Timer',
+    'Detection',
+    'load_det_results',
+    'preprocess_reid',
+    'get_crops',
+    'clip_box',
+    'scale_coords',
+]
+
+
+class Timer(object):
+    """
+    This class is used to compute and print the current FPS during evaluation.
+    """
+
+    def __init__(self):
+        self.total_time = 0.
+        self.calls = 0
+        self.start_time = 0.
+        self.diff = 0.
+        self.average_time = 0.
+        self.duration = 0.
+
+    def tic(self):
+        # Use time.time instead of time.clock because time.clock
+        # does not normalize for multithreading.
+        self.start_time = time.time()
+
+    def toc(self, average=True):
+        self.diff = time.time() - self.start_time
+        self.total_time += self.diff
+        self.calls += 1
+        self.average_time = self.total_time / self.calls
+        if average:
+            self.duration = self.average_time
+        else:
+            self.duration = self.diff
+        return self.duration
+
+    def clear(self):
+        self.total_time = 0.
+        self.calls = 0
+        self.start_time = 0.
+        self.diff = 0.
+        self.average_time = 0.
+        self.duration = 0.
+
+
+class Detection(object):
+    """
+    This class represents a bounding box detection in a single image.
+
+    Args:
+        tlwh (ndarray): Bounding box in format `(top left x, top left y,
+            width, height)`.
+        confidence (ndarray): Detector confidence score.
+        feature (Tensor): A feature vector that describes the object
+            contained in this image.
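+
+    Example (illustrative): `Detection([10., 20., 30., 60.], 0.9, feat)`
+        describes a 30x60 box whose top-left corner is (10, 20); `to_xyah()`
+        then returns `[25., 50., 0.5, 60.]`.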
+    """
+
+    def __init__(self, tlwh, confidence, feature):
+        self.tlwh = np.asarray(tlwh, dtype=np.float32)
+        self.confidence = np.asarray(confidence, dtype=np.float32)
+        self.feature = feature
+
+    def to_tlbr(self):
+        """
+        Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
+        `(top left, bottom right)`.
+        """
+        ret = self.tlwh.copy()
+        ret[2:] += ret[:2]
+        return ret
+
+    def to_xyah(self):
+        """
+        Convert bounding box to format `(center x, center y, aspect ratio,
+        height)`, where the aspect ratio is `width / height`.
+        """
+        ret = self.tlwh.copy()
+        ret[:2] += ret[2:] / 2
+        ret[2] /= ret[3]
+        return ret
+
+
+def load_det_results(det_file, num_frames):
+    assert os.path.exists(det_file) and os.path.isfile(det_file), \
+        'Error: det_file: {} does not exist or is not a file.'.format(det_file)
+    labels = np.loadtxt(det_file, dtype='float32', delimiter=',')
+    results_list = []
+    for frame_i in range(0, num_frames):
+        results = {'bbox': [], 'score': []}
+        labels_with_frame = labels[labels[:, 0] == frame_i + 1]
+        for l in labels_with_frame:
+            results['bbox'].append(l[1:5])
+            results['score'].append(l[5])
+        results_list.append(results)
+    return results_list
+
+
+def scale_coords(coords, input_shape, im_shape, scale_factor):
+    im_shape = im_shape.numpy()[0]
+    ratio = scale_factor[0][0]
+    pad_w = (input_shape[1] - int(im_shape[1])) / 2
+    pad_h = (input_shape[0] - int(im_shape[0])) / 2
+    coords = paddle.cast(coords, 'float32')
+    coords[:, 0::2] -= pad_w
+    coords[:, 1::2] -= pad_h
+    coords[:, 0:4] /= ratio
+    coords[:, :4] = paddle.clip(coords[:, :4], min=0, max=coords[:, :4].max())
+    return coords.round()
+
+
+def clip_box(xyxy, input_shape, im_shape, scale_factor):
+    im_shape = im_shape.numpy()[0]
+    ratio = scale_factor.numpy()[0][0]
+    img0_shape = [int(im_shape[0] / ratio), int(im_shape[1] / ratio)]
+
+    xyxy[:, 0::2] = paddle.clip(xyxy[:, 0::2], min=0, max=img0_shape[1])
+    xyxy[:, 1::2] = paddle.clip(xyxy[:, 1::2], min=0, max=img0_shape[0])
+    return xyxy
+
+
+def get_crops(xyxy, ori_img, pred_scores, w, h):
+    crops = []
+    keep_scores = []
+    xyxy = xyxy.numpy().astype(np.int64)
+    ori_img = ori_img.numpy()
+    ori_img = np.squeeze(ori_img, axis=0).transpose(1, 0, 2)
+    pred_scores = pred_scores.numpy()
+    for i, bbox in enumerate(xyxy):
+        if bbox[2] <= bbox[0] or bbox[3] <= bbox[1]:
+            continue
+        crop = ori_img[bbox[0]:bbox[2], bbox[1]:bbox[3], :]
+        crops.append(crop)
+        keep_scores.append(pred_scores[i])
+    if len(crops) == 0:
+        return [], []
+    crops = preprocess_reid(crops, w, h)
+    return crops, keep_scores
+
+
+def preprocess_reid(imgs, w=64, h=192, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
+    im_batch = []
+    for img in imgs:
+        img = cv2.resize(img, (w, h))
+        img = img[:, :, ::-1].astype('float32').transpose((2, 0, 1)) / 255
+        img_mean = np.array(mean).reshape((3, 1, 1))
+        img_std = np.array(std).reshape((3, 1, 1))
+        img -= img_mean
+        img /= img_std
+        img = np.expand_dims(img, axis=0)
+        im_batch.append(img)
+    im_batch = np.concatenate(im_batch, 0)
+    return im_batch
diff --git a/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/visualization.py b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd9c5b15e15f677b7955dd4eba40798e985315a1
--- /dev/null
+++ b/modules/video/multiple_object_tracking/fairmot_dla34/modeling/mot/visualization.py
@@ -0,0 +1,117 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +import numpy as np + + +def tlwhs_to_tlbrs(tlwhs): + tlbrs = np.copy(tlwhs) + if len(tlbrs) == 0: + return tlbrs + tlbrs[:, 2] += tlwhs[:, 0] + tlbrs[:, 3] += tlwhs[:, 1] + return tlbrs + + +def get_color(idx): + idx = idx * 3 + color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255) + return color + + +def resize_image(image, max_size=800): + if max(image.shape[:2]) > max_size: + scale = float(max_size) / max(image.shape[:2]) + image = cv2.resize(image, None, fx=scale, fy=scale) + return image + + +def plot_tracking(image, tlwhs, obj_ids, scores=None, frame_id=0, fps=0., ids2=None): + im = np.ascontiguousarray(np.copy(image)) + im_h, im_w = im.shape[:2] + + top_view = np.zeros([im_w, im_w, 3], dtype=np.uint8) + 255 + + text_scale = max(1, image.shape[1] / 1600.) + text_thickness = 2 + line_thickness = max(1, int(image.shape[1] / 500.)) + + radius = max(5, int(im_w / 140.)) + cv2.putText( + im, + 'frame: %d fps: %.2f num: %d' % (frame_id, fps, len(tlwhs)), (0, int(15 * text_scale)), + cv2.FONT_HERSHEY_PLAIN, + text_scale, (0, 0, 255), + thickness=2) + + for i, tlwh in enumerate(tlwhs): + x1, y1, w, h = tlwh + intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h))) + obj_id = int(obj_ids[i]) + id_text = '{}'.format(int(obj_id)) + if ids2 is not None: + id_text = id_text + ', {}'.format(int(ids2[i])) + _line_thickness = 1 if obj_id <= 0 else line_thickness + color = get_color(abs(obj_id)) + cv2.rectangle(im, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness) + cv2.putText( + im, + id_text, (intbox[0], intbox[1] + 10), + cv2.FONT_HERSHEY_PLAIN, + text_scale, (0, 0, 255), + thickness=text_thickness) + + if scores is not None: + text = '{:.2f}'.format(float(scores[i])) + cv2.putText( + im, + text, (intbox[0], intbox[1] - 10), + cv2.FONT_HERSHEY_PLAIN, + text_scale, (0, 255, 255), + thickness=text_thickness) + return im + + +def plot_trajectory(image, tlwhs, track_ids): + image = image.copy() + for one_tlwhs, track_id in zip(tlwhs, track_ids): + color = get_color(int(track_id)) + for tlwh in one_tlwhs: + x1, y1, w, h = tuple(map(int, tlwh)) + cv2.circle(image, (int(x1 + 0.5 * w), int(y1 + h)), 2, color, thickness=2) + return image + + +def plot_detections(image, tlbrs, scores=None, color=(255, 0, 0), ids=None): + im = np.copy(image) + text_scale = max(1, image.shape[1] / 800.) 
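+    # Layout of each `det` row, as inferred from the indexing below (this diff
+    # does not document it): det[:4] = (x1, y1, x2, y2); rows with 7+ entries
+    # also carry det[5] as a detection-vs-track flag and det[6] as a score.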
+    thickness = 2 if text_scale > 1.3 else 1
+    for i, det in enumerate(tlbrs):
+        x1, y1, x2, y2 = np.asarray(det[:4], dtype=int)
+        if len(det) >= 7:
+            label = 'det' if det[5] > 0 else 'trk'
+            if ids is not None:
+                text = '{}# {:.2f}: {:d}'.format(label, det[6], ids[i])
+                cv2.putText(
+                    im, text, (x1, y1 + 30), cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 255, 255), thickness=thickness)
+            else:
+                text = '{}# {:.2f}'.format(label, det[6])
+
+        if scores is not None:
+            text = '{:.2f}'.format(scores[i])
+            cv2.putText(im, text, (x1, y1 + 30), cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 255, 255), thickness=thickness)
+
+        cv2.rectangle(im, (x1, y1), (x2, y2), color, 2)
+    return im
diff --git a/modules/video/multiple_object_tracking/fairmot_dla34/tracker.py b/modules/video/multiple_object_tracking/fairmot_dla34/tracker.py
index f641527ce94c8014db1afc0c5418bf6a278c352e..016f1e5878b12418ebb29344287bcfc6af830a8e 100644
--- a/modules/video/multiple_object_tracking/fairmot_dla34/tracker.py
+++ b/modules/video/multiple_object_tracking/fairmot_dla34/tracker.py
@@ -16,18 +16,19 @@ import cv2
 import glob
 import paddle
 import numpy as np
+import collections
 from ppdet.core.workspace import create
 from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
-from ppdet.modeling.mot.utils import Detection, get_crops, scale_coords, clip_box
-from ppdet.modeling.mot.utils import Timer, load_det_results
-from ppdet.modeling.mot import visualization as mot_vis
 from ppdet.metrics import Metric, MOTMetric, KITTIMOTMetric
 import ppdet.utils.stats as stats
 from ppdet.engine.callbacks import Callback, ComposeCallback
 from ppdet.utils.logger import setup_logger
 
 from .dataset import MOTVideoStream, MOTImageStream
+from .utils import Timer
+from .modeling.mot.utils import Detection, get_crops, scale_coords, clip_box
+from .modeling.mot import visualization as mot_vis
 
 logger = setup_logger(__name__)
 
@@ -71,7 +72,6 @@ class StreamTracker(object):
             timer.tic()
             pred_dets, pred_embs = self.model(data)
             online_targets = self.model.tracker.update(pred_dets, pred_embs)
-
             online_tlwhs, online_ids = [], []
             online_scores = []
             for t in online_targets:
@@ -109,7 +109,6 @@ class StreamTracker(object):
             timer.tic()
             pred_dets, pred_embs = self.model(data)
             online_targets = self.model.tracker.update(pred_dets, pred_embs)
-
             online_tlwhs, online_ids = [], []
             online_scores = []
             for t in online_targets:
diff --git a/modules/video/multiple_object_tracking/fairmot_dla34/utils.py b/modules/video/multiple_object_tracking/fairmot_dla34/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4426f217f9f5fb5c7afa6593c2b83ce4b67236f9
--- /dev/null
+++ b/modules/video/multiple_object_tracking/fairmot_dla34/utils.py
@@ -0,0 +1,39 @@
+import time
+
+
+class Timer(object):
+    """
+    This class is used to compute and print the current FPS during evaluation.
+    """
+
+    def __init__(self):
+        self.total_time = 0.
+        self.calls = 0
+        self.start_time = 0.
+        self.diff = 0.
+        self.average_time = 0.
+        self.duration = 0.
+
+    def tic(self):
+        # use time.time instead of time.clock because time.clock
+        # does not normalize for multithreading
+        self.start_time = time.time()
+
+    def toc(self, average=True):
+        self.diff = time.time() - self.start_time
+        self.total_time += self.diff
+        self.calls += 1
+        self.average_time = self.total_time / self.calls
+        if average:
+            self.duration = self.average_time
+        else:
+            self.duration = self.diff
+        return self.duration
+
+    def clear(self):
+        self.total_time = 0.
+        self.calls = 0
+        self.start_time = 0.
+        self.diff = 0.
+ self.average_time = 0. + self.duration = 0. diff --git a/modules/video/multiple_object_tracking/jde_darknet53/config/_base_/jde_darknet53.yml b/modules/video/multiple_object_tracking/jde_darknet53/config/_base_/jde_darknet53.yml index 73faa52f662e7db24ef40c25c029561225d1a3b8..dcc67ac4276c3e8a3abd81950d970f3643d05551 100644 --- a/modules/video/multiple_object_tracking/jde_darknet53/config/_base_/jde_darknet53.yml +++ b/modules/video/multiple_object_tracking/jde_darknet53/config/_base_/jde_darknet53.yml @@ -5,7 +5,7 @@ find_unused_parameters: True JDE: detector: YOLOv3 reid: JDEEmbeddingHead - tracker: JDETracker + tracker: FrozenJDETracker YOLOv3: backbone: DarkNet diff --git a/modules/video/multiple_object_tracking/jde_darknet53/config/jde_darknet53_30e_1088x608.yml b/modules/video/multiple_object_tracking/jde_darknet53/config/jde_darknet53_30e_1088x608.yml index d2ac3aee460aaa378dcef11c3a3fce9aa4c29f05..33fa547afe9f95f5dfe7ea321c3e9be1c3634e1d 100644 --- a/modules/video/multiple_object_tracking/jde_darknet53/config/jde_darknet53_30e_1088x608.yml +++ b/modules/video/multiple_object_tracking/jde_darknet53/config/jde_darknet53_30e_1088x608.yml @@ -9,7 +9,7 @@ _BASE_: [ JDE: detector: YOLOv3 reid: JDEEmbeddingHead - tracker: JDETracker + tracker: FrozenJDETracker YOLOv3: backbone: DarkNet diff --git a/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/__init__.py b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..258e4c9010832936f098e6febe777ac556f0668f --- /dev/null +++ b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import matching +from . import tracker +from . import motion +from . import visualization +from . import utils + +from .matching import * +from .tracker import * +from .motion import * +from .visualization import * +from .utils import * diff --git a/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/matching/__init__.py b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/matching/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..54c6680f79f16247c562a9da1024dd3e1de4c57f --- /dev/null +++ b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/matching/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from . import jde_matching +from . import deepsort_matching + +from .jde_matching import * +from .deepsort_matching import * diff --git a/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/matching/deepsort_matching.py b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/matching/deepsort_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..c55aa8876cc128f512aa4e2e4e48a935a3f8dd77 --- /dev/null +++ b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/matching/deepsort_matching.py @@ -0,0 +1,368 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is borrow from https://github.com/nwojke/deep_sort/tree/master/deep_sort +""" + +import numpy as np +from scipy.optimize import linear_sum_assignment +from ..motion import kalman_filter + +INFTY_COST = 1e+5 + +__all__ = [ + 'iou_1toN', + 'iou_cost', + '_nn_euclidean_distance', + '_nn_cosine_distance', + 'NearestNeighborDistanceMetric', + 'min_cost_matching', + 'matching_cascade', + 'gate_cost_matrix', +] + + +def iou_1toN(bbox, candidates): + """ + Computer intersection over union (IoU) by one box to N candidates. + + Args: + bbox (ndarray): A bounding box in format `(top left x, top left y, width, height)`. + candidates (ndarray): A matrix of candidate bounding boxes (one per row) in the + same format as `bbox`. + + Returns: + ious (ndarray): The intersection over union in [0, 1] between the `bbox` + and each candidate. A higher score means a larger fraction of the + `bbox` is occluded by the candidate. + """ + bbox_tl = bbox[:2] + bbox_br = bbox[:2] + bbox[2:] + candidates_tl = candidates[:, :2] + candidates_br = candidates[:, :2] + candidates[:, 2:] + + tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis], + np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]] + br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis], + np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]] + wh = np.maximum(0., br - tl) + + area_intersection = wh.prod(axis=1) + area_bbox = bbox[2:].prod() + area_candidates = candidates[:, 2:].prod(axis=1) + ious = area_intersection / (area_bbox + area_candidates - area_intersection) + return ious + + +def iou_cost(tracks, detections, track_indices=None, detection_indices=None): + """ + IoU distance metric. + + Args: + tracks (list[Track]): A list of tracks. + detections (list[Detection]): A list of detections. + track_indices (Optional[list[int]]): A list of indices to tracks that + should be matched. Defaults to all `tracks`. + detection_indices (Optional[list[int]]): A list of indices to detections + that should be matched. Defaults to all `detections`. 
+
+    Returns:
+        cost_matrix (ndarray): A cost matrix of shape len(track_indices),
+            len(detection_indices) where entry (i, j) is
+            `1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`.
+    """
+    if track_indices is None:
+        track_indices = np.arange(len(tracks))
+    if detection_indices is None:
+        detection_indices = np.arange(len(detections))
+
+    cost_matrix = np.zeros((len(track_indices), len(detection_indices)))
+    for row, track_idx in enumerate(track_indices):
+        if tracks[track_idx].time_since_update > 1:
+            cost_matrix[row, :] = 1e+5
+            continue
+
+        bbox = tracks[track_idx].to_tlwh()
+        candidates = np.asarray([detections[i].tlwh for i in detection_indices])
+        cost_matrix[row, :] = 1. - iou_1toN(bbox, candidates)
+    return cost_matrix
+
+
+def _nn_euclidean_distance(s, q):
+    """
+    Compute pair-wise squared (Euclidean) distance between points in `s` and `q`.
+
+    Args:
+        s (ndarray): Sample points: an NxM matrix of N samples of dimensionality M.
+        q (ndarray): Query points: an LxM matrix of L samples of dimensionality M.
+
+    Returns:
+        distances (ndarray): A vector of length L that contains for each entry in `q` the
+            smallest squared Euclidean distance to a sample in `s`.
+    """
+    s, q = np.asarray(s), np.asarray(q)
+    if len(s) == 0 or len(q) == 0:
+        return np.zeros((len(s), len(q)))
+    s2, q2 = np.square(s).sum(axis=1), np.square(q).sum(axis=1)
+    distances = -2. * np.dot(s, q.T) + s2[:, None] + q2[None, :]
+    distances = np.clip(distances, 0., float(np.inf))
+
+    return np.maximum(0.0, distances.min(axis=0))
+
+
+def _nn_cosine_distance(s, q):
+    """
+    Compute pair-wise cosine distance between points in `s` and `q`.
+
+    Args:
+        s (ndarray): Sample points: an NxM matrix of N samples of dimensionality M.
+        q (ndarray): Query points: an LxM matrix of L samples of dimensionality M.
+
+    Returns:
+        distances (ndarray): A vector of length L that contains for each entry in `q` the
+            smallest cosine distance to a sample in `s`.
+    """
+    s = np.asarray(s) / np.linalg.norm(s, axis=1, keepdims=True)
+    q = np.asarray(q) / np.linalg.norm(q, axis=1, keepdims=True)
+    distances = 1. - np.dot(s, q.T)
+
+    return distances.min(axis=0)
+
+
+class NearestNeighborDistanceMetric(object):
+    """
+    A nearest neighbor distance metric that, for each target, returns
+    the closest distance to any sample that has been observed so far.
+
+    Args:
+        metric (str): Either "euclidean" or "cosine".
+        matching_threshold (float): The matching threshold. Samples with larger
+            distance are considered an invalid match.
+        budget (Optional[int]): If not None, fix samples per class to at most
+            this number. Removes the oldest samples when the budget is reached.
+
+    Attributes:
+        samples (Dict[int -> List[ndarray]]): A dictionary that maps from target
+            identities to the list of samples that have been observed so far.
+    """
+
+    def __init__(self, metric, matching_threshold, budget=None):
+        if metric == "euclidean":
+            self._metric = _nn_euclidean_distance
+        elif metric == "cosine":
+            self._metric = _nn_cosine_distance
+        else:
+            raise ValueError("Invalid metric; must be either 'euclidean' or 'cosine'")
+        self.matching_threshold = matching_threshold
+        self.budget = budget
+        self.samples = {}
+
+    def partial_fit(self, features, targets, active_targets):
+        """
+        Update the distance metric with new data.
+
+        Args:
+            features (ndarray): An NxM matrix of N features of dimensionality M.
+            targets (ndarray): An integer array of associated target identities.
+ active_targets (List[int]): A list of targets that are currently + present in the scene. + """ + for feature, target in zip(features, targets): + self.samples.setdefault(target, []).append(feature) + if self.budget is not None: + self.samples[target] = self.samples[target][-self.budget:] + self.samples = {k: self.samples[k] for k in active_targets} + + def distance(self, features, targets): + """ + Compute distance between features and targets. + + Args: + features (ndarray): An NxM matrix of N features of dimensionality M. + targets (list[int]): A list of targets to match the given `features` against. + + Returns: + cost_matrix (ndarray): a cost matrix of shape len(targets), len(features), + where element (i, j) contains the closest squared distance between + `targets[i]` and `features[j]`. + """ + cost_matrix = np.zeros((len(targets), len(features))) + for i, target in enumerate(targets): + cost_matrix[i, :] = self._metric(self.samples[target], features) + return cost_matrix + + +def min_cost_matching(distance_metric, max_distance, tracks, detections, track_indices=None, detection_indices=None): + """ + Solve linear assignment problem. + + Args: + distance_metric : + Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray + The distance metric is given a list of tracks and detections as + well as a list of N track indices and M detection indices. The + metric should return the NxM dimensional cost matrix, where element + (i, j) is the association cost between the i-th track in the given + track indices and the j-th detection in the given detection_indices. + max_distance (float): Gating threshold. Associations with cost larger + than this value are disregarded. + tracks (list[Track]): A list of predicted tracks at the current time + step. + detections (list[Detection]): A list of detections at the current time + step. + track_indices (list[int]): List of track indices that maps rows in + `cost_matrix` to tracks in `tracks`. + detection_indices (List[int]): List of detection indices that maps + columns in `cost_matrix` to detections in `detections`. + + Returns: + A tuple (List[(int, int)], List[int], List[int]) with the following + three entries: + * A list of matched track and detection indices. + * A list of unmatched track indices. + * A list of unmatched detection indices. + """ + if track_indices is None: + track_indices = np.arange(len(tracks)) + if detection_indices is None: + detection_indices = np.arange(len(detections)) + + if len(detection_indices) == 0 or len(track_indices) == 0: + return [], track_indices, detection_indices # Nothing to match. 
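+    # Note on the gating trick used below (standard for Hungarian-style solvers,
+    # illustrative numbers only): costs above max_distance are clamped to
+    # max_distance + 1e-5 so linear_sum_assignment can still produce a complete
+    # assignment, and over-threshold pairs are discarded afterwards. E.g. with
+    #   cost = [[0.2, 0.9], [0.8, 0.3]] and max_distance = 0.5,
+    # clamping gives [[0.2, 0.50001], [0.50001, 0.3]]; the solver picks (0, 0)
+    # and (1, 1), and both survive the final cost check.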
+ + cost_matrix = distance_metric(tracks, detections, track_indices, detection_indices) + + cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5 + indices = linear_sum_assignment(cost_matrix) + + matches, unmatched_tracks, unmatched_detections = [], [], [] + for col, detection_idx in enumerate(detection_indices): + if col not in indices[1]: + unmatched_detections.append(detection_idx) + for row, track_idx in enumerate(track_indices): + if row not in indices[0]: + unmatched_tracks.append(track_idx) + for row, col in zip(indices[0], indices[1]): + track_idx = track_indices[row] + detection_idx = detection_indices[col] + if cost_matrix[row, col] > max_distance: + unmatched_tracks.append(track_idx) + unmatched_detections.append(detection_idx) + else: + matches.append((track_idx, detection_idx)) + return matches, unmatched_tracks, unmatched_detections + + +def matching_cascade(distance_metric, + max_distance, + cascade_depth, + tracks, + detections, + track_indices=None, + detection_indices=None): + """ + Run matching cascade. + + Args: + distance_metric : + Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray + The distance metric is given a list of tracks and detections as + well as a list of N track indices and M detection indices. The + metric should return the NxM dimensional cost matrix, where element + (i, j) is the association cost between the i-th track in the given + track indices and the j-th detection in the given detection_indices. + max_distance (float): Gating threshold. Associations with cost larger + than this value are disregarded. + cascade_depth (int): The cascade depth, should be se to the maximum + track age. + tracks (list[Track]): A list of predicted tracks at the current time + step. + detections (list[Detection]): A list of detections at the current time + step. + track_indices (list[int]): List of track indices that maps rows in + `cost_matrix` to tracks in `tracks`. + detection_indices (List[int]): List of detection indices that maps + columns in `cost_matrix` to detections in `detections`. + + Returns: + A tuple (List[(int, int)], List[int], List[int]) with the following + three entries: + * A list of matched track and detection indices. + * A list of unmatched track indices. + * A list of unmatched detection indices. + """ + if track_indices is None: + track_indices = list(range(len(tracks))) + if detection_indices is None: + detection_indices = list(range(len(detections))) + + unmatched_detections = detection_indices + matches = [] + for level in range(cascade_depth): + if len(unmatched_detections) == 0: # No detections left + break + + track_indices_l = [k for k in track_indices if tracks[k].time_since_update == 1 + level] + if len(track_indices_l) == 0: # Nothing to match at this level + continue + + matches_l, _, unmatched_detections = \ + min_cost_matching( + distance_metric, max_distance, tracks, detections, + track_indices_l, unmatched_detections) + matches += matches_l + unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches)) + return matches, unmatched_tracks, unmatched_detections + + +def gate_cost_matrix(kf, + cost_matrix, + tracks, + detections, + track_indices, + detection_indices, + gated_cost=INFTY_COST, + only_position=False): + """ + Invalidate infeasible entries in cost matrix based on the state + distributions obtained by Kalman filtering. + + Args: + kf (object): The Kalman filter. 
+ cost_matrix (ndarray): The NxM dimensional cost matrix, where N is the + number of track indices and M is the number of detection indices, + such that entry (i, j) is the association cost between + `tracks[track_indices[i]]` and `detections[detection_indices[j]]`. + tracks (list[Track]): A list of predicted tracks at the current time + step. + detections (list[Detection]): A list of detections at the current time + step. + track_indices (List[int]): List of track indices that maps rows in + `cost_matrix` to tracks in `tracks`. + detection_indices (List[int]): List of detection indices that maps + columns in `cost_matrix` to detections in `detections`. + gated_cost (Optional[float]): Entries in the cost matrix corresponding + to infeasible associations are set this value. Defaults to a very + large value. + only_position (Optional[bool]): If True, only the x, y position of the + state distribution is considered during gating. Default False. + """ + gating_dim = 2 if only_position else 4 + gating_threshold = kalman_filter.chi2inv95[gating_dim] + measurements = np.asarray([detections[i].to_xyah() for i in detection_indices]) + for row, track_idx in enumerate(track_indices): + track = tracks[track_idx] + gating_distance = kf.gating_distance(track.mean, track.covariance, measurements, only_position) + cost_matrix[row, gating_distance > gating_threshold] = gated_cost + return cost_matrix diff --git a/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/matching/jde_matching.py b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/matching/jde_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..bf2e891c391c98ed8944f88377f62c9722fa5155 --- /dev/null +++ b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/matching/jde_matching.py @@ -0,0 +1,123 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
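+#
+# A quick sketch of the lap.lapjv contract the linear_assignment helper below
+# relies on (illustrative values, not from this patch):
+#   cost = np.array([[0.1, 0.9], [0.8, 0.2]])
+#   _, x, y = lap.lapjv(cost, extend_cost=True, cost_limit=0.5)
+#   # x[i] is the column matched to row i (-1 if unmatched);
+#   # y[j] is the row matched to column j (-1 if unmatched).
+#   # Here x == [0, 1] and y == [0, 1].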
+""" +This code is borrow from https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/matching.py +""" + +import lap +import scipy +import numpy as np +from scipy.spatial.distance import cdist +from ..motion import kalman_filter + +from ppdet.utils.logger import setup_logger +logger = setup_logger(__name__) + +__all__ = [ + 'merge_matches', + 'linear_assignment', + 'cython_bbox_ious', + 'iou_distance', + 'embedding_distance', + 'fuse_motion', +] + + +def merge_matches(m1, m2, shape): + O, P, Q = shape + m1 = np.asarray(m1) + m2 = np.asarray(m2) + + M1 = scipy.sparse.coo_matrix((np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P)) + M2 = scipy.sparse.coo_matrix((np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q)) + + mask = M1 * M2 + match = mask.nonzero() + match = list(zip(match[0], match[1])) + unmatched_O = tuple(set(range(O)) - set([i for i, j in match])) + unmatched_Q = tuple(set(range(Q)) - set([j for i, j in match])) + + return match, unmatched_O, unmatched_Q + + +def linear_assignment(cost_matrix, thresh): + if cost_matrix.size == 0: + return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1])) + matches, unmatched_a, unmatched_b = [], [], [] + cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh) + for ix, mx in enumerate(x): + if mx >= 0: + matches.append([ix, mx]) + unmatched_a = np.where(x < 0)[0] + unmatched_b = np.where(y < 0)[0] + matches = np.asarray(matches) + return matches, unmatched_a, unmatched_b + + +def cython_bbox_ious(atlbrs, btlbrs): + ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float) + if ious.size == 0: + return ious + try: + import cython_bbox + except Exception as e: + logger.error('cython_bbox not found, please install cython_bbox.' 'for example: `pip install cython_bbox`.') + raise e + + ious = cython_bbox.bbox_overlaps( + np.ascontiguousarray(atlbrs, dtype=np.float), np.ascontiguousarray(btlbrs, dtype=np.float)) + return ious + + +def iou_distance(atracks, btracks): + """ + Compute cost based on IoU between two list[STrack]. + """ + if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) or (len(btracks) > 0 + and isinstance(btracks[0], np.ndarray)): + atlbrs = atracks + btlbrs = btracks + else: + atlbrs = [track.tlbr for track in atracks] + btlbrs = [track.tlbr for track in btracks] + _ious = cython_bbox_ious(atlbrs, btlbrs) + cost_matrix = 1 - _ious + + return cost_matrix + + +def embedding_distance(tracks, detections, metric='euclidean'): + """ + Compute cost based on features between two list[STrack]. 
+    """
+    cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float64)
+    if cost_matrix.size == 0:
+        return cost_matrix
+    det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float64)
+    track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float64)
+    cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric))  # features are assumed normalized
+    return cost_matrix
+
+
+def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98):
+    if cost_matrix.size == 0:
+        return cost_matrix
+    gating_dim = 2 if only_position else 4
+    gating_threshold = kalman_filter.chi2inv95[gating_dim]
+    measurements = np.asarray([det.to_xyah() for det in detections])
+    for row, track in enumerate(tracks):
+        gating_distance = kf.gating_distance(track.mean, track.covariance, measurements, only_position, metric='maha')
+        cost_matrix[row, gating_distance > gating_threshold] = np.inf
+        cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance
+    return cost_matrix
diff --git a/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/motion/__init__.py b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/motion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e42dd0b019d66d6ea07bec1ad90cf9a8d53d8172
--- /dev/null
+++ b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/motion/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import kalman_filter
+
+from .kalman_filter import *
diff --git a/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/motion/kalman_filter.py b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/motion/kalman_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cc182e4c5e76e0688688c883b2a24fa30df9c74
--- /dev/null
+++ b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/motion/kalman_filter.py
@@ -0,0 +1,237 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is borrowed from https://github.com/nwojke/deep_sort/blob/master/deep_sort/kalman_filter.py
+"""
+
+import numpy as np
+import scipy.linalg
+
+__all__ = ['KalmanFilter']
+"""
+Table for the 0.95 quantile of the chi-square distribution with N degrees of
+freedom (contains values for N=1, ..., 9).
Taken from MATLAB/Octave's chi2inv +function and used as Mahalanobis gating threshold. +""" + +chi2inv95 = {1: 3.8415, 2: 5.9915, 3: 7.8147, 4: 9.4877, 5: 11.070, 6: 12.592, 7: 14.067, 8: 15.507, 9: 16.919} + + +class KalmanFilter(object): + """ + A simple Kalman filter for tracking bounding boxes in image space. + + The 8-dimensional state space + + x, y, a, h, vx, vy, va, vh + + contains the bounding box center position (x, y), aspect ratio a, height h, + and their respective velocities. + + Object motion follows a constant velocity model. The bounding box location + (x, y, a, h) is taken as direct observation of the state space (linear + observation model). + + """ + + def __init__(self): + ndim, dt = 4, 1. + + # Create Kalman filter model matrices. + self._motion_mat = np.eye(2 * ndim, 2 * ndim) + for i in range(ndim): + self._motion_mat[i, ndim + i] = dt + self._update_mat = np.eye(ndim, 2 * ndim) + + # Motion and observation uncertainty are chosen relative to the current + # state estimate. These weights control the amount of uncertainty in + # the model. This is a bit hacky. + self._std_weight_position = 1. / 20 + self._std_weight_velocity = 1. / 160 + + def initiate(self, measurement): + """ + Create track from unassociated measurement. + + Args: + measurement (ndarray): Bounding box coordinates (x, y, a, h) with + center position (x, y), aspect ratio a, and height h. + + Returns: + The mean vector (8 dimensional) and covariance matrix (8x8 + dimensional) of the new track. Unobserved velocities are + initialized to 0 mean. + """ + mean_pos = measurement + mean_vel = np.zeros_like(mean_pos) + mean = np.r_[mean_pos, mean_vel] + + std = [ + 2 * self._std_weight_position * measurement[3], 2 * self._std_weight_position * measurement[3], 1e-2, + 2 * self._std_weight_position * measurement[3], 10 * self._std_weight_velocity * measurement[3], + 10 * self._std_weight_velocity * measurement[3], 1e-5, 10 * self._std_weight_velocity * measurement[3] + ] + covariance = np.diag(np.square(std)) + return mean, covariance + + def predict(self, mean, covariance): + """ + Run Kalman filter prediction step. + + Args: + mean (ndarray): The 8 dimensional mean vector of the object state + at the previous time step. + covariance (ndarray): The 8x8 dimensional covariance matrix of the + object state at the previous time step. + + Returns: + The mean vector and covariance matrix of the predicted state. + Unobserved velocities are initialized to 0 mean. + """ + std_pos = [ + self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-2, + self._std_weight_position * mean[3] + ] + std_vel = [ + self._std_weight_velocity * mean[3], self._std_weight_velocity * mean[3], 1e-5, + self._std_weight_velocity * mean[3] + ] + motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) + + #mean = np.dot(self._motion_mat, mean) + mean = np.dot(mean, self._motion_mat.T) + covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov + + return mean, covariance + + def project(self, mean, covariance): + """ + Project state distribution to measurement space. + + Args + mean (ndarray): The state's mean vector (8 dimensional array). + covariance (ndarray): The state's covariance matrix (8x8 dimensional). + + Returns: + The projected mean and covariance matrix of the given state estimate. 
+ """ + std = [ + self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-1, + self._std_weight_position * mean[3] + ] + innovation_cov = np.diag(np.square(std)) + + mean = np.dot(self._update_mat, mean) + covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T)) + return mean, covariance + innovation_cov + + def multi_predict(self, mean, covariance): + """ + Run Kalman filter prediction step (Vectorized version). + + Args: + mean (ndarray): The Nx8 dimensional mean matrix of the object states + at the previous time step. + covariance (ndarray): The Nx8x8 dimensional covariance matrics of the + object states at the previous time step. + + Returns: + The mean vector and covariance matrix of the predicted state. + Unobserved velocities are initialized to 0 mean. + """ + std_pos = [ + self._std_weight_position * mean[:, 3], self._std_weight_position * mean[:, 3], + 1e-2 * np.ones_like(mean[:, 3]), self._std_weight_position * mean[:, 3] + ] + std_vel = [ + self._std_weight_velocity * mean[:, 3], self._std_weight_velocity * mean[:, 3], + 1e-5 * np.ones_like(mean[:, 3]), self._std_weight_velocity * mean[:, 3] + ] + sqr = np.square(np.r_[std_pos, std_vel]).T + + motion_cov = [] + for i in range(len(mean)): + motion_cov.append(np.diag(sqr[i])) + motion_cov = np.asarray(motion_cov) + + mean = np.dot(mean, self._motion_mat.T) + left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2)) + covariance = np.dot(left, self._motion_mat.T) + motion_cov + + return mean, covariance + + def update(self, mean, covariance, measurement): + """ + Run Kalman filter correction step. + + Args: + mean (ndarray): The predicted state's mean vector (8 dimensional). + covariance (ndarray): The state's covariance matrix (8x8 dimensional). + measurement (ndarray): The 4 dimensional measurement vector + (x, y, a, h), where (x, y) is the center position, a the aspect + ratio, and h the height of the bounding box. + + Returns: + The measurement-corrected state distribution. + """ + projected_mean, projected_cov = self.project(mean, covariance) + + chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False) + kalman_gain = scipy.linalg.cho_solve((chol_factor, lower), + np.dot(covariance, self._update_mat.T).T, + check_finite=False).T + innovation = measurement - projected_mean + + new_mean = mean + np.dot(innovation, kalman_gain.T) + new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T)) + return new_mean, new_covariance + + def gating_distance(self, mean, covariance, measurements, only_position=False, metric='maha'): + """ + Compute gating distance between state distribution and measurements. + A suitable distance threshold can be obtained from `chi2inv95`. If + `only_position` is False, the chi-square distribution has 4 degrees of + freedom, otherwise 2. + + Args: + mean (ndarray): Mean vector over the state distribution (8 + dimensional). + covariance (ndarray): Covariance of the state distribution (8x8 + dimensional). + measurements (ndarray): An Nx4 dimensional matrix of N measurements, + each in format (x, y, a, h) where (x, y) is the bounding box center + position, a the aspect ratio, and h the height. + only_position (Optional[bool]): If True, distance computation is + done with respect to the bounding box center position only. + metric (str): Metric type, 'gaussian' or 'maha'. 
+ + Returns + An array of length N, where the i-th element contains the squared + Mahalanobis distance between (mean, covariance) and `measurements[i]`. + """ + mean, covariance = self.project(mean, covariance) + if only_position: + mean, covariance = mean[:2], covariance[:2, :2] + measurements = measurements[:, :2] + + d = measurements - mean + if metric == 'gaussian': + return np.sum(d * d, axis=1) + elif metric == 'maha': + cholesky_factor = np.linalg.cholesky(covariance) + z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True) + squared_maha = np.sum(z * z, axis=0) + return squared_maha + else: + raise ValueError('invalid distance metric') diff --git a/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/tracker/__init__.py b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/tracker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..904822119661be61141715c638388db9d045fee1 --- /dev/null +++ b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/tracker/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import base_jde_tracker +from . import base_sde_tracker +from . import jde_tracker + +from .base_jde_tracker import * +from .base_sde_tracker import * +from .jde_tracker import * diff --git a/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/tracker/base_jde_tracker.py b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/tracker/base_jde_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..9505a709ee573acecf4b5dd7e02a06cee9d44284 --- /dev/null +++ b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/tracker/base_jde_tracker.py @@ -0,0 +1,257 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
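+#
+# Informal sketch of the track lifecycle implemented below (as read from this
+# file; state names are from the TrackState class):
+#   New -> Tracked    STrack.activate() on a fresh detection
+#   Tracked -> Lost   mark_lost() when no detection is associated
+#   Lost -> Tracked   re_activate() when a matching detection reappears
+#   * -> Removed      mark_removed() once a track is stale or rejected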
+""" +This code is borrow from https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py +""" + +import numpy as np +from collections import deque, OrderedDict +from ..matching import jde_matching as matching +from ppdet.core.workspace import register, serializable + +__all__ = [ + 'TrackState', + 'BaseTrack', + 'STrack', + 'joint_stracks', + 'sub_stracks', + 'remove_duplicate_stracks', +] + + +class TrackState(object): + New = 0 + Tracked = 1 + Lost = 2 + Removed = 3 + + +class BaseTrack(object): + _count = 0 + + track_id = 0 + is_activated = False + state = TrackState.New + + history = OrderedDict() + features = [] + curr_feature = None + score = 0 + start_frame = 0 + frame_id = 0 + time_since_update = 0 + + # multi-camera + location = (np.inf, np.inf) + + @property + def end_frame(self): + return self.frame_id + + @staticmethod + def next_id(): + BaseTrack._count += 1 + return BaseTrack._count + + def activate(self, *args): + raise NotImplementedError + + def predict(self): + raise NotImplementedError + + def update(self, *args, **kwargs): + raise NotImplementedError + + def mark_lost(self): + self.state = TrackState.Lost + + def mark_removed(self): + self.state = TrackState.Removed + + +class STrack(BaseTrack): + def __init__(self, tlwh, score, temp_feat, buffer_size=30): + # wait activate + self._tlwh = np.asarray(tlwh, dtype=np.float) + self.kalman_filter = None + self.mean, self.covariance = None, None + self.is_activated = False + + self.score = score + self.tracklet_len = 0 + + self.smooth_feat = None + self.update_features(temp_feat) + self.features = deque([], maxlen=buffer_size) + self.alpha = 0.9 + + def update_features(self, feat): + feat /= np.linalg.norm(feat) + self.curr_feat = feat + if self.smooth_feat is None: + self.smooth_feat = feat + else: + self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat + self.features.append(feat) + self.smooth_feat /= np.linalg.norm(self.smooth_feat) + + def predict(self): + mean_state = self.mean.copy() + if self.state != TrackState.Tracked: + mean_state[7] = 0 + self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance) + + @staticmethod + def multi_predict(stracks, kalman_filter): + if len(stracks) > 0: + multi_mean = np.asarray([st.mean.copy() for st in stracks]) + multi_covariance = np.asarray([st.covariance for st in stracks]) + for i, st in enumerate(stracks): + if st.state != TrackState.Tracked: + multi_mean[i][7] = 0 + multi_mean, multi_covariance = kalman_filter.multi_predict(multi_mean, multi_covariance) + for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)): + stracks[i].mean = mean + stracks[i].covariance = cov + + def activate(self, kalman_filter, frame_id): + """Start a new tracklet""" + self.kalman_filter = kalman_filter + self.track_id = self.next_id() + self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh)) + + self.tracklet_len = 0 + self.state = TrackState.Tracked + if frame_id == 1: + self.is_activated = True + self.frame_id = frame_id + self.start_frame = frame_id + + def re_activate(self, new_track, frame_id, new_id=False): + self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance, + self.tlwh_to_xyah(new_track.tlwh)) + + self.update_features(new_track.curr_feat) + self.tracklet_len = 0 + self.state = TrackState.Tracked + self.is_activated = True + self.frame_id = frame_id + if new_id: + self.track_id = self.next_id() + + def update(self, new_track, frame_id, 
update_feature=True): + self.frame_id = frame_id + self.tracklet_len += 1 + + new_tlwh = new_track.tlwh + self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh)) + self.state = TrackState.Tracked + self.is_activated = True + + self.score = new_track.score + if update_feature: + self.update_features(new_track.curr_feat) + + @property + def tlwh(self): + """ + Get current position in bounding box format `(top left x, top left y, + width, height)`. + """ + if self.mean is None: + return self._tlwh.copy() + ret = self.mean[:4].copy() + ret[2] *= ret[3] + ret[:2] -= ret[2:] / 2 + return ret + + @property + def tlbr(self): + """ + Convert bounding box to format `(min x, min y, max x, max y)`, i.e., + `(top left, bottom right)`. + """ + ret = self.tlwh.copy() + ret[2:] += ret[:2] + return ret + + @staticmethod + def tlwh_to_xyah(tlwh): + """ + Convert bounding box to format `(center x, center y, aspect ratio, + height)`, where the aspect ratio is `width / height`. + """ + ret = np.asarray(tlwh).copy() + ret[:2] += ret[2:] / 2 + ret[2] /= ret[3] + return ret + + def to_xyah(self): + return self.tlwh_to_xyah(self.tlwh) + + @staticmethod + def tlbr_to_tlwh(tlbr): + ret = np.asarray(tlbr).copy() + ret[2:] -= ret[:2] + return ret + + @staticmethod + def tlwh_to_tlbr(tlwh): + ret = np.asarray(tlwh).copy() + ret[2:] += ret[:2] + return ret + + def __repr__(self): + return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame) + + +def joint_stracks(tlista, tlistb): + exists = {} + res = [] + for t in tlista: + exists[t.track_id] = 1 + res.append(t) + for t in tlistb: + tid = t.track_id + if not exists.get(tid, 0): + exists[tid] = 1 + res.append(t) + return res + + +def sub_stracks(tlista, tlistb): + stracks = {} + for t in tlista: + stracks[t.track_id] = t + for t in tlistb: + tid = t.track_id + if stracks.get(tid, 0): + del stracks[tid] + return list(stracks.values()) + + +def remove_duplicate_stracks(stracksa, stracksb): + pdist = matching.iou_distance(stracksa, stracksb) + pairs = np.where(pdist < 0.15) + dupa, dupb = list(), list() + for p, q in zip(*pairs): + timep = stracksa[p].frame_id - stracksa[p].start_frame + timeq = stracksb[q].frame_id - stracksb[q].start_frame + if timep > timeq: + dupb.append(q) + else: + dupa.append(p) + resa = [t for i, t in enumerate(stracksa) if not i in dupa] + resb = [t for i, t in enumerate(stracksb) if not i in dupb] + return resa, resb diff --git a/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/tracker/base_sde_tracker.py b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/tracker/base_sde_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..2e811e536a42ff781f60872b448b251de0301f61 --- /dev/null +++ b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/tracker/base_sde_tracker.py @@ -0,0 +1,133 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
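+#
+# Confirmation logic sketch (informal, from the Track class below): a track
+# starts Tentative, turns Confirmed after n_init consecutive hits, and is
+# Deleted either on a miss while still Tentative or once time_since_update
+# exceeds max_age.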
+""" +This code is borrow from https://github.com/nwojke/deep_sort/blob/master/deep_sort/track.py +""" + +from ppdet.core.workspace import register, serializable + +__all__ = ['TrackState', 'Track'] + + +class TrackState(object): + """ + Enumeration type for the single target track state. Newly created tracks are + classified as `tentative` until enough evidence has been collected. Then, + the track state is changed to `confirmed`. Tracks that are no longer alive + are classified as `deleted` to mark them for removal from the set of active + tracks. + """ + Tentative = 1 + Confirmed = 2 + Deleted = 3 + + +class Track(object): + """ + A single target track with state space `(x, y, a, h)` and associated + velocities, where `(x, y)` is the center of the bounding box, `a` is the + aspect ratio and `h` is the height. + + Args: + mean (ndarray): Mean vector of the initial state distribution. + covariance (ndarray): Covariance matrix of the initial state distribution. + track_id (int): A unique track identifier. + n_init (int): Number of consecutive detections before the track is confirmed. + The track state is set to `Deleted` if a miss occurs within the first + `n_init` frames. + max_age (int): The maximum number of consecutive misses before the track + state is set to `Deleted`. + feature (Optional[ndarray]): Feature vector of the detection this track + originates from. If not None, this feature is added to the `features` cache. + + Attributes: + hits (int): Total number of measurement updates. + age (int): Total number of frames since first occurance. + time_since_update (int): Total number of frames since last measurement + update. + state (TrackState): The current track state. + features (List[ndarray]): A cache of features. On each measurement update, + the associated feature vector is added to this list. + """ + + def __init__(self, mean, covariance, track_id, n_init, max_age, feature=None): + self.mean = mean + self.covariance = covariance + self.track_id = track_id + self.hits = 1 + self.age = 1 + self.time_since_update = 0 + + self.state = TrackState.Tentative + self.features = [] + if feature is not None: + self.features.append(feature) + + self._n_init = n_init + self._max_age = max_age + + def to_tlwh(self): + """Get position in format `(top left x, top left y, width, height)`.""" + ret = self.mean[:4].copy() + ret[2] *= ret[3] + ret[:2] -= ret[2:] / 2 + return ret + + def to_tlbr(self): + """Get position in bounding box format `(min x, miny, max x, max y)`.""" + ret = self.to_tlwh() + ret[2:] = ret[:2] + ret[2:] + return ret + + def predict(self, kalman_filter): + """ + Propagate the state distribution to the current time step using a Kalman + filter prediction step. + """ + self.mean, self.covariance = kalman_filter.predict(self.mean, self.covariance) + self.age += 1 + self.time_since_update += 1 + + def update(self, kalman_filter, detection): + """ + Perform Kalman filter measurement update step and update the associated + detection feature cache. + """ + self.mean, self.covariance = kalman_filter.update(self.mean, self.covariance, detection.to_xyah()) + self.features.append(detection.feature) + + self.hits += 1 + self.time_since_update = 0 + if self.state == TrackState.Tentative and self.hits >= self._n_init: + self.state = TrackState.Confirmed + + def mark_missed(self): + """Mark this track as missed (no association at the current time step). 
+ """ + if self.state == TrackState.Tentative: + self.state = TrackState.Deleted + elif self.time_since_update > self._max_age: + self.state = TrackState.Deleted + + def is_tentative(self): + """Returns True if this track is tentative (unconfirmed).""" + return self.state == TrackState.Tentative + + def is_confirmed(self): + """Returns True if this track is confirmed.""" + return self.state == TrackState.Confirmed + + def is_deleted(self): + """Returns True if this track is dead and should be deleted.""" + return self.state == TrackState.Deleted diff --git a/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/tracker/jde_tracker.py b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/tracker/jde_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..2e1cafb345b7687e563fc6d9c2c1769cb39d690c --- /dev/null +++ b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/tracker/jde_tracker.py @@ -0,0 +1,248 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is borrow from https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py +""" + +import paddle + +from ..matching import jde_matching as matching +from .base_jde_tracker import TrackState, BaseTrack, STrack +from .base_jde_tracker import joint_stracks, sub_stracks, remove_duplicate_stracks + +from ppdet.core.workspace import register, serializable +from ppdet.utils.logger import setup_logger +logger = setup_logger(__name__) + +__all__ = ['FrozenJDETracker'] + + +@register +@serializable +class FrozenJDETracker(object): + __inject__ = ['motion'] + """ + JDE tracker + + Args: + det_thresh (float): threshold of detection score + track_buffer (int): buffer for tracker + min_box_area (int): min box area to filter out low quality boxes + vertical_ratio (float): w/h, the vertical ratio of the bbox to filter + bad results, set 1.6 default for pedestrian tracking. If set -1 + means no need to filter bboxes. + tracked_thresh (float): linear assignment threshold of tracked + stracks and detections + r_tracked_thresh (float): linear assignment threshold of + tracked stracks and unmatched detections + unconfirmed_thresh (float): linear assignment threshold of + unconfirmed stracks and unmatched detections + motion (object): KalmanFilter instance + conf_thres (float): confidence threshold for tracking + metric_type (str): either "euclidean" or "cosine", the distance metric + used for measurement to track association. 
+ """ + + def __init__(self, + det_thresh=0.3, + track_buffer=30, + min_box_area=200, + vertical_ratio=1.6, + tracked_thresh=0.7, + r_tracked_thresh=0.5, + unconfirmed_thresh=0.7, + motion='KalmanFilter', + conf_thres=0, + metric_type='euclidean'): + self.det_thresh = det_thresh + self.track_buffer = track_buffer + self.min_box_area = min_box_area + self.vertical_ratio = vertical_ratio + + self.tracked_thresh = tracked_thresh + self.r_tracked_thresh = r_tracked_thresh + self.unconfirmed_thresh = unconfirmed_thresh + self.motion = motion + self.conf_thres = conf_thres + self.metric_type = metric_type + + self.frame_id = 0 + self.tracked_stracks = [] + self.lost_stracks = [] + self.removed_stracks = [] + + self.max_time_lost = 0 + # max_time_lost will be calculated: int(frame_rate / 30.0 * track_buffer) + + def update(self, pred_dets, pred_embs): + """ + Processes the image frame and finds bounding box(detections). + Associates the detection with corresponding tracklets and also handles + lost, removed, refound and active tracklets. + + Args: + pred_dets (Tensor): Detection results of the image, shape is [N, 5]. + pred_embs (Tensor): Embedding results of the image, shape is [N, 512]. + + Return: + output_stracks (list): The list contains information regarding the + online_tracklets for the recieved image tensor. + """ + self.frame_id += 1 + activated_starcks = [] + # for storing active tracks, for the current frame + refind_stracks = [] + # Lost Tracks whose detections are obtained in the current frame + lost_stracks = [] + # The tracks which are not obtained in the current frame but are not + # removed. (Lost for some time lesser than the threshold for removing) + removed_stracks = [] + + remain_inds = paddle.nonzero(pred_dets[:, 4] > self.conf_thres) + if remain_inds.shape[0] == 0: + pred_dets = paddle.zeros([0, 1]) + pred_embs = paddle.zeros([0, 1]) + else: + pred_dets = paddle.gather(pred_dets, remain_inds) + pred_embs = paddle.gather(pred_embs, remain_inds) + + # Filter out the image with box_num = 0. 
+        # Filter out the image with box_num = 0: pred_dets = [[0.0, 0.0, 0.0, 0.0]]
+        empty_pred = True if len(pred_dets) == 1 and paddle.sum(pred_dets) == 0.0 else False
+        """ Step 1: Network forward, get detections & embeddings"""
+        if len(pred_dets) > 0 and not empty_pred:
+            pred_dets = pred_dets.numpy()
+            pred_embs = pred_embs.numpy()
+            detections = [
+                STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30) for (tlbrs, f) in zip(pred_dets, pred_embs)
+            ]
+        else:
+            detections = []
+        ''' Add newly detected tracklets to tracked_stracks'''
+        unconfirmed = []
+        tracked_stracks = []  # type: list[STrack]
+        for track in self.tracked_stracks:
+            if not track.is_activated:
+                # previous tracks which are not active in the current frame are added to the unconfirmed list
+                unconfirmed.append(track)
+            else:
+                # Active tracks are added to the local list 'tracked_stracks'
+                tracked_stracks.append(track)
+        """ Step 2: First association, with embedding"""
+        # Combining currently tracked_stracks and lost_stracks
+        strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
+        # Predict the current location with KF
+        STrack.multi_predict(strack_pool, self.motion)
+
+        dists = matching.embedding_distance(strack_pool, detections, metric=self.metric_type)
+        dists = matching.fuse_motion(self.motion, dists, strack_pool, detections)
+        # dists is the matrix of distances between the detections and the tracks in strack_pool
+        matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.tracked_thresh)
+        # matches is the array of corresponding matches between the detections and strack_pool
+
+        for itracked, idet in matches:
+            # itracked is the index of the track and idet is the index of the detection
+            track = strack_pool[itracked]
+            det = detections[idet]
+            if track.state == TrackState.Tracked:
+                # If the track is active, add the detection to the track
+                track.update(detections[idet], self.frame_id)
+                activated_stracks.append(track)
+            else:
+                # We have obtained a detection from a track which is not active,
+                # hence put the track in the refind_stracks list
+                track.re_activate(det, self.frame_id, new_id=False)
+                refind_stracks.append(track)
+
+        # None of the steps below happen if there are no undetected tracks.
+        """ Step 3: Second association, with IOU"""
+        detections = [detections[i] for i in u_detection]
+        # detections is now a list of the unmatched detections
+        r_tracked_stracks = []
+        # This is a container for stracks which were tracked till the previous
+        # frame but no detection was found for them in the current frame.
+
+        for i in u_track:
+            if strack_pool[i].state == TrackState.Tracked:
+                r_tracked_stracks.append(strack_pool[i])
+        dists = matching.iou_distance(r_tracked_stracks, detections)
+        matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.r_tracked_thresh)
+        # matches is the list of detections which matched with corresponding
+        # tracks by the IOU distance method.
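+        # Illustrative shapes (not part of the original file): with 2 stracks
+        # and 3 detections, linear_assignment may return, e.g.,
+        #   matches = [[0, 1]]       # (strack index, detection index) pairs
+        #   u_track = [1]            # strack indices left unmatched
+        #   u_detection = [0, 2]     # detection indices left unmatched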
+
+        for itracked, idet in matches:
+            track = r_tracked_stracks[itracked]
+            det = detections[idet]
+            if track.state == TrackState.Tracked:
+                track.update(det, self.frame_id)
+                activated_stracks.append(track)
+            else:
+                track.re_activate(det, self.frame_id, new_id=False)
+                refind_stracks.append(track)
+        # The same process is applied to the unmatched detections, but now using the IOU distance as the measure
+
+        for it in u_track:
+            track = r_tracked_stracks[it]
+            if not track.state == TrackState.Lost:
+                track.mark_lost()
+                lost_stracks.append(track)
+        # If no detection is obtained for a track (u_track), the track is added to the lost_stracks list and marked lost
+        '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
+        detections = [detections[i] for i in u_detection]
+        dists = matching.iou_distance(unconfirmed, detections)
+        matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=self.unconfirmed_thresh)
+        for itracked, idet in matches:
+            unconfirmed[itracked].update(detections[idet], self.frame_id)
+            activated_stracks.append(unconfirmed[itracked])
+
+        # The tracks which are still not matched
+        for it in u_unconfirmed:
+            track = unconfirmed[it]
+            track.mark_removed()
+            removed_stracks.append(track)
+
+        # After all these association steps, a detection that is still unmatched is initialized as a new track
+        """ Step 4: Init new stracks"""
+        for inew in u_detection:
+            track = detections[inew]
+            if track.score < self.det_thresh:
+                continue
+            track.activate(self.motion, self.frame_id)
+            activated_stracks.append(track)
+        """ Step 5: Update state"""
+        # If the tracks are lost for more frames than the threshold number, the tracks are removed.
+        for track in self.lost_stracks:
+            if self.frame_id - track.end_frame > self.max_time_lost:
+                track.mark_removed()
+                removed_stracks.append(track)
+
+        # Update self.tracked_stracks and self.lost_stracks using the updates in this step.
+        self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
+        self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_stracks)
+        self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
+
+        self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
+        self.lost_stracks.extend(lost_stracks)
+        self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
+        self.removed_stracks.extend(removed_stracks)
+        self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
+        # Collect the activated tracks as this frame's output
+        output_stracks = [track for track in self.tracked_stracks if track.is_activated]
+
+        logger.debug('===========Frame {}=========='.format(self.frame_id))
+        logger.debug('Activated: {}'.format([track.track_id for track in activated_stracks]))
+        logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
+        logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
+        logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
+
+        return output_stracks
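+
+
+# Illustrative consumption of update()'s output (not part of the original
+# file; assumes the STrack attributes tlwh, track_id and score read by this
+# module's tracker.py):
+#
+#   online_targets = tracker.update(pred_dets, pred_embs)
+#   for t in online_targets:
+#       tlwh, tid, score = t.tlwh, t.track_id, t.score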
diff --git a/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/utils.py b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..12c61686a1715a965407822dcf19fd1081f292d7
--- /dev/null
+++ b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/utils.py
@@ -0,0 +1,176 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import cv2
+import time
+import paddle
+import numpy as np
+
+__all__ = [
+    'Timer',
+    'Detection',
+    'load_det_results',
+    'preprocess_reid',
+    'get_crops',
+    'clip_box',
+    'scale_coords',
+]
+
+
+class Timer(object):
+    """
+    This class is used to compute and print the current FPS during evaluation.
+    """
+
+    def __init__(self):
+        self.total_time = 0.
+        self.calls = 0
+        self.start_time = 0.
+        self.diff = 0.
+        self.average_time = 0.
+        self.duration = 0.
+
+    def tic(self):
+        # using time.time instead of time.clock because time.clock
+        # does not normalize for multithreading
+        self.start_time = time.time()
+
+    def toc(self, average=True):
+        self.diff = time.time() - self.start_time
+        self.total_time += self.diff
+        self.calls += 1
+        self.average_time = self.total_time / self.calls
+        if average:
+            self.duration = self.average_time
+        else:
+            self.duration = self.diff
+        return self.duration
+
+    def clear(self):
+        self.total_time = 0.
+        self.calls = 0
+        self.start_time = 0.
+        self.diff = 0.
+        self.average_time = 0.
+        self.duration = 0.
+
+
+class Detection(object):
+    """
+    This class represents a bounding box detection in a single image.
+
+    Args:
+        tlwh (ndarray): Bounding box in format `(top left x, top left y,
+            width, height)`.
+        confidence (ndarray): Detector confidence score.
+        feature (Tensor): A feature vector that describes the object
+            contained in this image.
+    """
+
+    def __init__(self, tlwh, confidence, feature):
+        self.tlwh = np.asarray(tlwh, dtype=np.float32)
+        self.confidence = np.asarray(confidence, dtype=np.float32)
+        self.feature = feature
+
+    def to_tlbr(self):
+        """
+        Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
+        `(top left, bottom right)`.
+        """
+        ret = self.tlwh.copy()
+        ret[2:] += ret[:2]
+        return ret
+
+    def to_xyah(self):
+        """
+        Convert bounding box to format `(center x, center y, aspect ratio,
+        height)`, where the aspect ratio is `width / height`.
+        """
+        ret = self.tlwh.copy()
+        ret[:2] += ret[2:] / 2
+        ret[2] /= ret[3]
+        return ret
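+
+
+# Illustrative worked example (not part of the original file):
+#
+#   det = Detection(tlwh=[10., 20., 30., 60.], confidence=0.9, feature=None)
+#   det.to_tlbr()   # -> [10., 20., 40., 80.]
+#   det.to_xyah()   # -> [25., 50., 0.5, 60.]   (center x, center y, w/h, h)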
+ """ + ret = self.tlwh.copy() + ret[:2] += ret[2:] / 2 + ret[2] /= ret[3] + return ret + + +def load_det_results(det_file, num_frames): + assert os.path.exists(det_file) and os.path.isfile(det_file), \ + 'Error: det_file: {} not exist or not a file.'.format(det_file) + labels = np.loadtxt(det_file, dtype='float32', delimiter=',') + results_list = [] + for frame_i in range(0, num_frames): + results = {'bbox': [], 'score': []} + lables_with_frame = labels[labels[:, 0] == frame_i + 1] + for l in lables_with_frame: + results['bbox'].append(l[1:5]) + results['score'].append(l[5]) + results_list.append(results) + return results_list + + +def scale_coords(coords, input_shape, im_shape, scale_factor): + im_shape = im_shape.numpy()[0] + ratio = scale_factor[0][0] + pad_w = (input_shape[1] - int(im_shape[1])) / 2 + pad_h = (input_shape[0] - int(im_shape[0])) / 2 + coords = paddle.cast(coords, 'float32') + coords[:, 0::2] -= pad_w + coords[:, 1::2] -= pad_h + coords[:, 0:4] /= ratio + coords[:, :4] = paddle.clip(coords[:, :4], min=0, max=coords[:, :4].max()) + return coords.round() + + +def clip_box(xyxy, input_shape, im_shape, scale_factor): + im_shape = im_shape.numpy()[0] + ratio = scale_factor.numpy()[0][0] + img0_shape = [int(im_shape[0] / ratio), int(im_shape[1] / ratio)] + + xyxy[:, 0::2] = paddle.clip(xyxy[:, 0::2], min=0, max=img0_shape[1]) + xyxy[:, 1::2] = paddle.clip(xyxy[:, 1::2], min=0, max=img0_shape[0]) + return xyxy + + +def get_crops(xyxy, ori_img, pred_scores, w, h): + crops = [] + keep_scores = [] + xyxy = xyxy.numpy().astype(np.int64) + ori_img = ori_img.numpy() + ori_img = np.squeeze(ori_img, axis=0).transpose(1, 0, 2) + pred_scores = pred_scores.numpy() + for i, bbox in enumerate(xyxy): + if bbox[2] <= bbox[0] or bbox[3] <= bbox[1]: + continue + crop = ori_img[bbox[0]:bbox[2], bbox[1]:bbox[3], :] + crops.append(crop) + keep_scores.append(pred_scores[i]) + if len(crops) == 0: + return [], [] + crops = preprocess_reid(crops, w, h) + return crops, keep_scores + + +def preprocess_reid(imgs, w=64, h=192, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): + im_batch = [] + for img in imgs: + img = cv2.resize(img, (w, h)) + img = img[:, :, ::-1].astype('float32').transpose((2, 0, 1)) / 255 + img_mean = np.array(mean).reshape((3, 1, 1)) + img_std = np.array(std).reshape((3, 1, 1)) + img -= img_mean + img /= img_std + img = np.expand_dims(img, axis=0) + im_batch.append(img) + im_batch = np.concatenate(im_batch, 0) + return im_batch diff --git a/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/visualization.py b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..cd9c5b15e15f677b7955dd4eba40798e985315a1 --- /dev/null +++ b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/visualization.py @@ -0,0 +1,117 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/visualization.py b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd9c5b15e15f677b7955dd4eba40798e985315a1
--- /dev/null
+++ b/modules/video/multiple_object_tracking/jde_darknet53/modeling/mot/visualization.py
@@ -0,0 +1,117 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import numpy as np
+
+
+def tlwhs_to_tlbrs(tlwhs):
+    tlbrs = np.copy(tlwhs)
+    if len(tlbrs) == 0:
+        return tlbrs
+    tlbrs[:, 2] += tlwhs[:, 0]
+    tlbrs[:, 3] += tlwhs[:, 1]
+    return tlbrs
+
+
+def get_color(idx):
+    # deterministic pseudo-random color per id
+    idx = idx * 3
+    color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)
+    return color
+
+
+def resize_image(image, max_size=800):
+    if max(image.shape[:2]) > max_size:
+        scale = float(max_size) / max(image.shape[:2])
+        image = cv2.resize(image, None, fx=scale, fy=scale)
+    return image
+
+
+def plot_tracking(image, tlwhs, obj_ids, scores=None, frame_id=0, fps=0., ids2=None):
+    im = np.ascontiguousarray(np.copy(image))
+
+    text_scale = max(1, image.shape[1] / 1600.)
+    text_thickness = 2
+    line_thickness = max(1, int(image.shape[1] / 500.))
+
+    cv2.putText(
+        im,
+        'frame: %d fps: %.2f num: %d' % (frame_id, fps, len(tlwhs)), (0, int(15 * text_scale)),
+        cv2.FONT_HERSHEY_PLAIN,
+        text_scale, (0, 0, 255),
+        thickness=2)
+
+    for i, tlwh in enumerate(tlwhs):
+        x1, y1, w, h = tlwh
+        intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
+        obj_id = int(obj_ids[i])
+        id_text = '{}'.format(int(obj_id))
+        if ids2 is not None:
+            id_text = id_text + ', {}'.format(int(ids2[i]))
+        # draw thinner boxes for non-positive (invalid) ids
+        _line_thickness = 1 if obj_id <= 0 else line_thickness
+        color = get_color(abs(obj_id))
+        cv2.rectangle(im, intbox[0:2], intbox[2:4], color=color, thickness=_line_thickness)
+        cv2.putText(
+            im,
+            id_text, (intbox[0], intbox[1] + 10),
+            cv2.FONT_HERSHEY_PLAIN,
+            text_scale, (0, 0, 255),
+            thickness=text_thickness)
+
+        if scores is not None:
+            text = '{:.2f}'.format(float(scores[i]))
+            cv2.putText(
+                im,
+                text, (intbox[0], intbox[1] - 10),
+                cv2.FONT_HERSHEY_PLAIN,
+                text_scale, (0, 255, 255),
+                thickness=text_thickness)
+    return im
+
+
+def plot_trajectory(image, tlwhs, track_ids):
+    image = image.copy()
+    for one_tlwhs, track_id in zip(tlwhs, track_ids):
+        color = get_color(int(track_id))
+        for tlwh in one_tlwhs:
+            x1, y1, w, h = tuple(map(int, tlwh))
+            # mark the bottom-center of each box along the trajectory
+            cv2.circle(image, (int(x1 + 0.5 * w), int(y1 + h)), 2, color, thickness=2)
+    return image
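+
+
+# Illustrative call (not part of the original file; assumes a BGR frame and
+# the per-frame lists collected by tracker.py):
+#
+#   vis = plot_tracking(frame, online_tlwhs, online_ids, scores=online_scores,
+#                       frame_id=frame_id, fps=1. / timer.average_time)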
+
+
+def plot_detections(image, tlbrs, scores=None, color=(255, 0, 0), ids=None):
+    im = np.copy(image)
+    text_scale = max(1, image.shape[1] / 800.)
+    thickness = 2 if text_scale > 1.3 else 1
+    for i, det in enumerate(tlbrs):
+        x1, y1, x2, y2 = np.asarray(det[:4], dtype=int)
+        if len(det) >= 7:
+            label = 'det' if det[5] > 0 else 'trk'
+            if ids is not None:
+                text = '{}# {:.2f}: {:d}'.format(label, det[6], ids[i])
+                cv2.putText(
+                    im, text, (x1, y1 + 30), cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 255, 255), thickness=thickness)
+            else:
+                text = '{}# {:.2f}'.format(label, det[6])
+
+        if scores is not None:
+            text = '{:.2f}'.format(scores[i])
+            cv2.putText(im, text, (x1, y1 + 30), cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 255, 255), thickness=thickness)
+
+        cv2.rectangle(im, (x1, y1), (x2, y2), color, 2)
+    return im
diff --git a/modules/video/multiple_object_tracking/jde_darknet53/tracker.py b/modules/video/multiple_object_tracking/jde_darknet53/tracker.py
index a4488125b11e09d7fb6e4328252ad61e8e844aac..1e4ab7d0b3a996775407eb1334c6183db26129d7 100644
--- a/modules/video/multiple_object_tracking/jde_darknet53/tracker.py
+++ b/modules/video/multiple_object_tracking/jde_darknet53/tracker.py
@@ -16,18 +16,19 @@ import cv2
 import glob
 import paddle
 import numpy as np
+import collections
 
-from ppdet.core.workspace import create
 from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
-from ppdet.modeling.mot.utils import Detection, get_crops, scale_coords, clip_box
-from ppdet.modeling.mot.utils import Timer, load_det_results
-from ppdet.modeling.mot import visualization as mot_vis
 from ppdet.metrics import Metric, MOTMetric, KITTIMOTMetric
 import ppdet.utils.stats as stats
 from ppdet.engine.callbacks import Callback, ComposeCallback
+from ppdet.core.workspace import create
 from ppdet.utils.logger import setup_logger
 from .dataset import MOTVideoStream, MOTImageStream
+from .modeling.mot.utils import Detection, get_crops, scale_coords, clip_box
+from .modeling.mot import visualization as mot_vis
+from .utils import Timer
 
 logger = setup_logger(__name__)
 
@@ -70,7 +71,6 @@ class StreamTracker(object):
             timer.tic()
             pred_dets, pred_embs = self.model(data)
             online_targets = self.model.tracker.update(pred_dets, pred_embs)
-
             online_tlwhs, online_ids = [], []
             online_scores = []
             for t in online_targets:
@@ -109,7 +109,6 @@ class StreamTracker(object):
         with paddle.no_grad():
             pred_dets, pred_embs = self.model(data)
             online_targets = self.model.tracker.update(pred_dets, pred_embs)
-
             online_tlwhs, online_ids = [], []
             online_scores = []
             for t in online_targets:
diff --git a/modules/video/multiple_object_tracking/jde_darknet53/utils.py b/modules/video/multiple_object_tracking/jde_darknet53/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4426f217f9f5fb5c7afa6593c2b83ce4b67236f9
--- /dev/null
+++ b/modules/video/multiple_object_tracking/jde_darknet53/utils.py
@@ -0,0 +1,39 @@
+import time
+
+
+class Timer(object):
+    """
+    This class is used to compute and print the current FPS during evaluation.
+    """
+
+    def __init__(self):
+        self.total_time = 0.
+        self.calls = 0
+        self.start_time = 0.
+        self.diff = 0.
+        self.average_time = 0.
+        self.duration = 0.
+
+    def tic(self):
+        # using time.time instead of time.clock because time.clock
+        # does not normalize for multithreading
+        self.start_time = time.time()
+
+    def toc(self, average=True):
+        self.diff = time.time() - self.start_time
+        self.total_time += self.diff
+        self.calls += 1
+        self.average_time = self.total_time / self.calls
+        if average:
+            self.duration = self.average_time
+        else:
+            self.duration = self.diff
+        return self.duration
+
+    def clear(self):
+        self.total_time = 0.
+ self.calls = 0 + self.start_time = 0. + self.diff = 0. + self.average_time = 0. + self.duration = 0.
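+
+
+# Illustrative usage (not part of the original file): time per-frame work and
+# read FPS off the running average:
+#
+#   timer = Timer()
+#   timer.tic()
+#   # ... process one frame ...
+#   fps = 1. / timer.toc(average=True)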