# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py
"""

import numpy as np
from collections import defaultdict
from collections import deque, OrderedDict
from ..matching import jde_matching as matching
from ppdet.core.workspace import register, serializable
import warnings
warnings.filterwarnings("ignore")

__all__ = [
    'TrackState',
    'BaseTrack',
    'STrack',
    'joint_stracks',
    'sub_stracks',
    'remove_duplicate_stracks',
]


class TrackState(object):
    """Integer codes for the lifecycle state of a track."""
    New = 0
    Tracked = 1
    Lost = 2
    Removed = 3


@register
@serializable
class BaseTrack(object):
    """Base class holding per-track defaults and the per-class id counters.

    The attributes below are class-level defaults; subclasses are expected to
    rebind them on the instance. `history` and `features` are mutable objects
    shared by every instance that has not rebound them, so they must never be
    mutated through an instance that still sees the class-level object.
    """
    # Last issued track id, keyed by object class id
    # (supports single class and multi classes).
    _count_dict = defaultdict(int)

    track_id = 0
    is_activated = False
    state = TrackState.New

    history = OrderedDict()
    features = []
    curr_feat = None
    score = 0
    start_frame = 0
    frame_id = 0
    time_since_update = 0

    # multi-camera
    location = (np.inf, np.inf)

    @property
    def end_frame(self):
        """Frame index of the most recent update (alias of `frame_id`)."""
        return self.frame_id

    @staticmethod
    def next_id(cls_id):
        """Increment and return the track id counter of class `cls_id`."""
        BaseTrack._count_dict[cls_id] += 1
        return BaseTrack._count_dict[cls_id]

    # @even: reset track id
    @staticmethod
    def init_count(num_classes):
        """
        Initiate _count for all object classes
        :param num_classes: number of object classes; the counters of class
            ids `0 .. num_classes - 1` are set to 0.
        """
        for cls_id in range(num_classes):
            BaseTrack._count_dict[cls_id] = 0

    @staticmethod
    def reset_track_count(cls_id):
        """Reset the track id counter of class `cls_id` to 0."""
        BaseTrack._count_dict[cls_id] = 0

    def activate(self, *args):
        """Start the track; must be implemented by subclasses."""
        raise NotImplementedError

    def predict(self):
        """Propagate the track state; must be implemented by subclasses."""
        raise NotImplementedError

    def update(self, *args, **kwargs):
        """Update the track with a new observation; must be implemented by
        subclasses."""
        raise NotImplementedError

    def mark_lost(self):
        """Set the track state to `TrackState.Lost`."""
        self.state = TrackState.Lost

    def mark_removed(self):
        """Set the track state to `TrackState.Removed`."""
        self.state = TrackState.Removed


@register
@serializable
class STrack(BaseTrack):
    """Single-object track with a Kalman-filter state and optional ReID
    embedding smoothing.

    Args:
        tlwh: bounding box as `(top left x, top left y, width, height)`.
        score: detection score stored on the track.
        cls_id: object class id, used to pick the track id counter.
        buff_size (int): maximum number of embeddings kept in `features`.
        temp_feat: embedding of the detection, or None to disable ReID
            feature handling for this track.
    """

    def __init__(self, tlwh, score, cls_id, buff_size=30, temp_feat=None):
        # wait activate
        self._tlwh = np.asarray(tlwh, dtype=np.float32)
        self.score = score
        self.cls_id = cls_id
        self.track_len = 0

        self.kalman_filter = None
        self.mean, self.covariance = None, None
        self.is_activated = False

        # Bind the feature buffer and the smoothing factor on the instance
        # BEFORE the first `update_features` call. Otherwise that call would
        # append to the class-level `BaseTrack.features` list, which is shared
        # by all tracks and never trimmed, and the per-track deque would miss
        # the first embedding.
        self.features = deque([], maxlen=buff_size)
        self.alpha = 0.9

        self.use_reid = temp_feat is not None
        if self.use_reid:
            self.smooth_feat = None
            self.update_features(temp_feat)

    def update_features(self, feat):
        """L2-normalize `feat` and fold it into the smoothed embedding.

        NOTE: `feat` is normalized in place, so the caller's array is
        modified. `smooth_feat` is an exponential moving average with weight
        `alpha` on the previous value, re-normalized after every update.
        """
        # L2 normalizing, this function has no use for BYTETracker
        feat /= np.linalg.norm(feat)
        self.curr_feat = feat
        if self.smooth_feat is None:
            self.smooth_feat = feat
        else:
            self.smooth_feat = (self.alpha * self.smooth_feat +
                                (1.0 - self.alpha) * feat)
        self.features.append(feat)
        self.smooth_feat /= np.linalg.norm(self.smooth_feat)

    def predict(self):
        """Run one Kalman prediction step for this track."""
        mean_state = self.mean.copy()
        if self.state != TrackState.Tracked:
            # NOTE(review): index 7 is presumably the height-velocity term of
            # an 8-dim state; confirm against the Kalman filter in use.
            mean_state[7] = 0
        self.mean, self.covariance = self.kalman_filter.predict(
            mean_state, self.covariance)

    @staticmethod
    def multi_predict(tracks, kalman_filter):
        """Run one Kalman prediction step for all `tracks` in a single call.

        Does nothing when `tracks` is empty. The predicted mean and
        covariance are written back to each track.
        """
        if len(tracks) == 0:
            return

        multi_mean = np.asarray([track.mean.copy() for track in tracks])
        multi_covariance = np.asarray([track.covariance for track in tracks])
        for i, st in enumerate(tracks):
            if st.state != TrackState.Tracked:
                # Same zeroing as in `predict` for tracks not being tracked.
                multi_mean[i][7] = 0
        multi_mean, multi_covariance = kalman_filter.multi_predict(
            multi_mean, multi_covariance)
        for track, mean, cov in zip(tracks, multi_mean, multi_covariance):
            track.mean = mean
            track.covariance = cov

    def reset_track_id(self):
        """Reset the track id counter of this track's object class."""
        self.reset_track_count(self.cls_id)

    def activate(self, kalman_filter, frame_id):
        """Start a new track"""
        self.kalman_filter = kalman_filter
        # update track id for the object class
        self.track_id = self.next_id(self.cls_id)
        self.mean, self.covariance = self.kalman_filter.initiate(
            self.tlwh_to_xyah(self._tlwh))

        self.track_len = 0
        self.state = TrackState.Tracked  # set flag 'tracked'

        if frame_id == 1:  # to record the first frame's detection result
            self.is_activated = True

        self.frame_id = frame_id
        self.start_frame = frame_id

    def re_activate(self, new_track, frame_id, new_id=False):
        """Resume this track from the detection `new_track`.

        Args:
            new_track: track-like object providing `tlwh` and `curr_feat`.
            frame_id: index of the current frame.
            new_id (bool): when True, draw a fresh track id for the class.
        """
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh))
        if self.use_reid:
            self.update_features(new_track.curr_feat)
        self.track_len = 0
        self.state = TrackState.Tracked
        self.is_activated = True
        self.frame_id = frame_id
        if new_id:  # update track id for the object class
            self.track_id = self.next_id(self.cls_id)

    def update(self, new_track, frame_id, update_feature=True):
        """Update this matched track with the detection `new_track`.

        Args:
            new_track: track-like object providing `tlwh`, `score` and
                `curr_feat`.
            frame_id: index of the current frame.
            update_feature (bool): whether to fold the new embedding into the
                smoothed feature (only applies when ReID is in use).
        """
        self.frame_id = frame_id
        self.track_len += 1

        new_tlwh = new_track.tlwh
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
        self.state = TrackState.Tracked  # set flag 'tracked'
        self.is_activated = True  # set flag 'activated'

        self.score = new_track.score
        if update_feature and self.use_reid:
            self.update_features(new_track.curr_feat)

    @property
    def tlwh(self):
        """Get current position in bounding box format `(top left x, top left y,
        width, height)`.

        Falls back to the initial box while no Kalman state exists yet.
        """
        if self.mean is None:
            return self._tlwh.copy()

        # mean[:4] is (center x, center y, aspect ratio, height).
        ret = self.mean[:4].copy()
        ret[2] *= ret[3]
        ret[:2] -= ret[2:] / 2
        return ret

    @property
    def tlbr(self):
        """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
        `(top left, bottom right)`.
        """
        ret = self.tlwh.copy()
        ret[2:] += ret[:2]
        return ret

    @staticmethod
    def tlwh_to_xyah(tlwh):
        """Convert bounding box to format `(center x, center y, aspect ratio,
        height)`, where the aspect ratio is `width / height`.
        """
        ret = np.asarray(tlwh).copy()
        ret[:2] += ret[2:] / 2
        ret[2] /= ret[3]
        return ret

    def to_xyah(self):
        """Return the current box as `(center x, center y, aspect ratio,
        height)`."""
        return self.tlwh_to_xyah(self.tlwh)

    @staticmethod
    def tlbr_to_tlwh(tlbr):
        """Convert `(min x, min y, max x, max y)` to `(top left x, top left y,
        width, height)`."""
        ret = np.asarray(tlbr).copy()
        ret[2:] -= ret[:2]
        return ret

    @staticmethod
    def tlwh_to_tlbr(tlwh):
        """Convert `(top left x, top left y, width, height)` to `(min x, min y,
        max x, max y)`."""
        ret = np.asarray(tlwh).copy()
        ret[2:] += ret[:2]
        return ret

    def __repr__(self):
        return 'OT_({}-{})_({}-{})'.format(self.cls_id, self.track_id,
                                           self.start_frame, self.end_frame)


def joint_stracks(tlista, tlistb):
    """Return all tracks of `tlista` followed by the tracks of `tlistb` whose
    `track_id` has not been seen yet.

    Every element of `tlista` is kept as is; de-duplication by `track_id`
    only applies to elements taken from `tlistb`.
    """
    res = list(tlista)
    seen_ids = {t.track_id for t in tlista}
    for t in tlistb:
        tid = t.track_id
        if tid not in seen_ids:
            seen_ids.add(tid)
            res.append(t)
    return res


def sub_stracks(tlista, tlistb):
    """Return the tracks of `tlista` whose `track_id` does not occur in
    `tlistb`.

    Tracks are keyed by `track_id`, so if `tlista` holds several tracks with
    the same id only the last one is considered.
    """
    stracks = {t.track_id: t for t in tlista}
    for t in tlistb:
        stracks.pop(t.track_id, None)
    return list(stracks.values())


def remove_duplicate_stracks(stracksa, stracksb):
    """Drop near-duplicate tracks between two lists.

    For every pair `(a, b)` whose `matching.iou_distance` is below 0.15
    (presumably `1 - IoU`, i.e. strongly overlapping boxes; confirm against
    the matching module), the track that has lived for fewer frames
    (`frame_id - start_frame`) is removed; on a tie the track from
    `stracksa` is removed.

    Returns:
        tuple: `(resa, resb)`, the filtered versions of the two inputs.
    """
    pdist = matching.iou_distance(stracksa, stracksb)
    pairs = np.where(pdist < 0.15)
    # Sets give O(1) membership tests when filtering below.
    dupa, dupb = set(), set()
    for p, q in zip(*pairs):
        timep = stracksa[p].frame_id - stracksa[p].start_frame
        timeq = stracksb[q].frame_id - stracksb[q].start_frame
        if timep > timeq:
            dupb.add(int(q))
        else:
            dupa.add(int(p))
    resa = [t for i, t in enumerate(stracksa) if i not in dupa]
    resb = [t for i, t in enumerate(stracksb) if i not in dupb]
    return resa, resb