未验证 提交 08370fcc 编写于 作者: F Feng Ni 提交者: GitHub

[MOT] add bytetracker (#4910)

上级 b3e0bd3a
...@@ -81,6 +81,24 @@ PP-tracking provides an AI studio public project tutorial. Please refer to this ...@@ -81,6 +81,24 @@ PP-tracking provides an AI studio public project tutorial. Please refer to this
**Notes:** **Notes:**
- FairMOT HRNetV2-W18 used 8 GPUs for training and mini-batch size as 4 on each GPU, and trained for 30 epoches. Only ImageNet pre-train model is used, and the optimizer adopts Momentum. The crowdhuman dataset is added to the train-set during training. - FairMOT HRNetV2-W18 used 8 GPUs for training and mini-batch size as 4 on each GPU, and trained for 30 epoches. Only ImageNet pre-train model is used, and the optimizer adopts Momentum. The crowdhuman dataset is added to the train-set during training.
### FairMOT + BYTETracker
### Results on MOT-17 Half Set
| backbone | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config |
| :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: |
| DLA-34 | 1088x608 | 69.1 | 72.8 | 299 | 1957 | 14412 | - |[model](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_bytetracker.pdparams) | [config](./fairmot_dla34_30e_1088x608.yml) |
| DLA-34 + BYTETracker| 1088x608 | 70.3 | 73.2 | 234 | 2176 | 13598 | - |[model](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_bytetracker.pdparams) | [config](./fairmot_dla34_30e_1088x608_bytetracker.yml) |
**Notes:**
- FairMOT here is for ablation study, the training dataset is the 5 datasets of MIX(Caltech,CUHKSYSU,PRW,Cityscapes,ETHZ) and the first half of MOT17 Train, and the pretrain weights is CenterNet COCO model, the evaluation is on the second half of MOT17 Train.
- BYTETracker adapt to other FairMOT models of PaddleDetection, you can modify the tracker of the config like this:
```
JDETracker:
use_byte: True
match_thres: 0.8
conf_thres: 0.4
low_conf_thres: 0.2
```
## Getting Start ## Getting Start
......
...@@ -77,6 +77,25 @@ PP-Tracking 提供了AI Studio公开项目案例,教程请参考[PP-Tracking ...@@ -77,6 +77,25 @@ PP-Tracking 提供了AI Studio公开项目案例,教程请参考[PP-Tracking
**注意:** **注意:**
- FairMOT HRNetV2-W18均使用8个GPU进行训练,每个GPU上batch size为4,训练30个epoch,使用的ImageNet预训练,优化器策略采用的是Momentum,并且训练集中加入了crowdhuman数据集一起参与训练。 - FairMOT HRNetV2-W18均使用8个GPU进行训练,每个GPU上batch size为4,训练30个epoch,使用的ImageNet预训练,优化器策略采用的是Momentum,并且训练集中加入了crowdhuman数据集一起参与训练。
### FairMOT + BYTETracker
### 在MOT-17 Half上结果
| 骨干网络 | 输入尺寸 | MOTA | IDF1 | IDS | FP | FN | FPS | 下载链接 | 配置文件 |
| :--------------| :------- | :----: | :----: | :----: | :----: | :----: | :------: | :----: |:-----: |
| DLA-34 | 1088x608 | 69.1 | 72.8 | 299 | 1957 | 14412 | - |[下载链接](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_bytetracker.pdparams) | [配置文件](./fairmot_dla34_30e_1088x608.yml) |
| DLA-34 + BYTETracker| 1088x608 | 70.3 | 73.2 | 234 | 2176 | 13598 | - |[下载链接](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_bytetracker.pdparams) | [配置文件](./fairmot_dla34_30e_1088x608_bytetracker.yml) |
**注意:**
- FairMOT模型此处是ablation study的配置,使用的训练集是原先MIX的5个数据集(Caltech,CUHKSYSU,PRW,Cityscapes,ETHZ)加上MOT17 Train的前一半,且使用是预训练权重是CenterNet的COCO预训练权重,验证是在MOT17 Train的后一半上测的。
- BYTETracker应用到PaddleDetection的其他FairMOT模型,只需要更改对应的config文件里的tracker部分为如下所示:
```
JDETracker:
use_byte: True
match_thres: 0.8
conf_thres: 0.4
low_conf_thres: 0.2
```
## 快速开始 ## 快速开始
......
_BASE_: [
'../../datasets/mot.yml',
'../../runtime.yml',
'_base_/optimizer_30e.yml',
'_base_/fairmot_dla34.yml',
'_base_/fairmot_reader_1088x608.yml',
]
weights: output/fairmot_dla34_30e_1088x608_bytetracker/model_final
# for ablation study, MIX + MOT17-half
TrainDataset:
!MOTDataSet
dataset_dir: dataset/mot
image_lists: ['mot17.half', 'caltech.all', 'cuhksysu.train', 'prw.train', 'citypersons.train', 'eth.train']
data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide']
JDETracker:
use_byte: True
match_thres: 0.8
conf_thres: 0.4
low_conf_thres: 0.2
...@@ -98,28 +98,28 @@ class STrack(BaseTrack): ...@@ -98,28 +98,28 @@ class STrack(BaseTrack):
def __init__(self, def __init__(self,
tlwh, tlwh,
score, score,
temp_feat,
num_classes,
cls_id, cls_id,
buff_size=30): buff_size=30,
# object class id temp_feat=None):
self.cls_id = cls_id
# wait activate # wait activate
self._tlwh = np.asarray(tlwh, dtype=np.float) self._tlwh = np.asarray(tlwh, dtype=np.float)
self.score = score
self.cls_id = cls_id
self.track_len = 0
self.kalman_filter = None self.kalman_filter = None
self.mean, self.covariance = None, None self.mean, self.covariance = None, None
self.is_activated = False self.is_activated = False
self.score = score self.use_reid = True if temp_feat is not None else False
self.track_len = 0 if self.use_reid:
self.smooth_feat = None
self.smooth_feat = None self.update_features(temp_feat)
self.update_features(temp_feat) self.features = deque([], maxlen=buff_size)
self.features = deque([], maxlen=buff_size) self.alpha = 0.9
self.alpha = 0.9
def update_features(self, feat): def update_features(self, feat):
# L2 normalizing # L2 normalizing, this function has no use for BYTETracker
feat /= np.linalg.norm(feat) feat /= np.linalg.norm(feat)
self.curr_feat = feat self.curr_feat = feat
if self.smooth_feat is None: if self.smooth_feat is None:
...@@ -175,7 +175,8 @@ class STrack(BaseTrack): ...@@ -175,7 +175,8 @@ class STrack(BaseTrack):
def re_activate(self, new_track, frame_id, new_id=False): def re_activate(self, new_track, frame_id, new_id=False):
self.mean, self.covariance = self.kalman_filter.update( self.mean, self.covariance = self.kalman_filter.update(
self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)) self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh))
self.update_features(new_track.curr_feat) if self.use_reid:
self.update_features(new_track.curr_feat)
self.track_len = 0 self.track_len = 0
self.state = TrackState.Tracked self.state = TrackState.Tracked
self.is_activated = True self.is_activated = True
...@@ -194,7 +195,7 @@ class STrack(BaseTrack): ...@@ -194,7 +195,7 @@ class STrack(BaseTrack):
self.is_activated = True # set flag 'activated' self.is_activated = True # set flag 'activated'
self.score = new_track.score self.score = new_track.score
if update_feature: if update_feature and self.use_reid:
self.update_features(new_track.curr_feat) self.update_features(new_track.curr_feat)
@property @property
......
...@@ -52,6 +52,7 @@ class JDETracker(object): ...@@ -52,6 +52,7 @@ class JDETracker(object):
""" """
def __init__(self, def __init__(self,
use_byte=False,
num_classes=1, num_classes=1,
det_thresh=0.3, det_thresh=0.3,
track_buffer=30, track_buffer=30,
...@@ -60,11 +61,14 @@ class JDETracker(object): ...@@ -60,11 +61,14 @@ class JDETracker(object):
tracked_thresh=0.7, tracked_thresh=0.7,
r_tracked_thresh=0.5, r_tracked_thresh=0.5,
unconfirmed_thresh=0.7, unconfirmed_thresh=0.7,
motion='KalmanFilter',
conf_thres=0, conf_thres=0,
match_thres=0.8,
low_conf_thres=0.2,
motion='KalmanFilter',
metric_type='euclidean'): metric_type='euclidean'):
self.use_byte = use_byte
self.num_classes = num_classes self.num_classes = num_classes
self.det_thresh = det_thresh self.det_thresh = det_thresh if not use_byte else conf_thres + 0.1
self.track_buffer = track_buffer self.track_buffer = track_buffer
self.min_box_area = min_box_area self.min_box_area = min_box_area
self.vertical_ratio = vertical_ratio self.vertical_ratio = vertical_ratio
...@@ -72,9 +76,12 @@ class JDETracker(object): ...@@ -72,9 +76,12 @@ class JDETracker(object):
self.tracked_thresh = tracked_thresh self.tracked_thresh = tracked_thresh
self.r_tracked_thresh = r_tracked_thresh self.r_tracked_thresh = r_tracked_thresh
self.unconfirmed_thresh = unconfirmed_thresh self.unconfirmed_thresh = unconfirmed_thresh
self.conf_thres = conf_thres
self.match_thres = match_thres
self.low_conf_thres = low_conf_thres
if motion == 'KalmanFilter': if motion == 'KalmanFilter':
self.motion = KalmanFilter() self.motion = KalmanFilter()
self.conf_thres = conf_thres
self.metric_type = metric_type self.metric_type = metric_type
self.frame_id = 0 self.frame_id = 0
...@@ -85,7 +92,7 @@ class JDETracker(object): ...@@ -85,7 +92,7 @@ class JDETracker(object):
self.max_time_lost = 0 self.max_time_lost = 0
# max_time_lost will be calculated: int(frame_rate / 30.0 * track_buffer) # max_time_lost will be calculated: int(frame_rate / 30.0 * track_buffer)
def update(self, pred_dets, pred_embs): def update(self, pred_dets, pred_embs=None):
""" """
Processes the image frame and finds bounding box(detections). Processes the image frame and finds bounding box(detections).
Associates the detection with corresponding tracklets and also handles Associates the detection with corresponding tracklets and also handles
...@@ -117,7 +124,10 @@ class JDETracker(object): ...@@ -117,7 +124,10 @@ class JDETracker(object):
for cls_id in range(self.num_classes): for cls_id in range(self.num_classes):
cls_idx = (pred_dets[:, 5:] == cls_id).squeeze(-1) cls_idx = (pred_dets[:, 5:] == cls_id).squeeze(-1)
pred_dets_dict[cls_id] = pred_dets[cls_idx] pred_dets_dict[cls_id] = pred_dets[cls_idx]
pred_embs_dict[cls_id] = pred_embs[cls_idx] if pred_embs is not None:
pred_embs_dict[cls_id] = pred_embs[cls_idx]
else:
pred_embs_dict[cls_id] = None
for cls_id in range(self.num_classes): for cls_id in range(self.num_classes):
""" Step 1: Get detections by class""" """ Step 1: Get detections by class"""
...@@ -126,13 +136,19 @@ class JDETracker(object): ...@@ -126,13 +136,19 @@ class JDETracker(object):
remain_inds = (pred_dets_cls[:, 4:5] > self.conf_thres).squeeze(-1) remain_inds = (pred_dets_cls[:, 4:5] > self.conf_thres).squeeze(-1)
if remain_inds.sum() > 0: if remain_inds.sum() > 0:
pred_dets_cls = pred_dets_cls[remain_inds] pred_dets_cls = pred_dets_cls[remain_inds]
pred_embs_cls = pred_embs_cls[remain_inds] if self.use_byte:
detections = [ detections = [
STrack( STrack(
STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], cls_id, 30, temp_feat=None)
self.num_classes, cls_id, 30) for tlbrs in pred_dets_cls
for (tlbrs, f) in zip(pred_dets_cls, pred_embs_cls) ]
] else:
pred_embs_cls = pred_embs_cls[remain_inds]
detections = [
STrack(
STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], cls_id, 30, temp_feat)
for (tlbrs, temp_feat) in zip(pred_dets_cls, pred_embs_cls)
]
else: else:
detections = [] detections = []
''' Add newly detected tracklets to tracked_stracks''' ''' Add newly detected tracklets to tracked_stracks'''
...@@ -154,12 +170,17 @@ class JDETracker(object): ...@@ -154,12 +170,17 @@ class JDETracker(object):
# Predict the current location with KalmanFilter # Predict the current location with KalmanFilter
STrack.multi_predict(track_pool_dict[cls_id], self.motion) STrack.multi_predict(track_pool_dict[cls_id], self.motion)
dists = matching.embedding_distance( if self.use_byte:
track_pool_dict[cls_id], detections, metric=self.metric_type) dists = matching.iou_distance(track_pool_dict[cls_id], detections)
dists = matching.fuse_motion(self.motion, dists, matches, u_track, u_detection = matching.linear_assignment(
track_pool_dict[cls_id], detections) dists, thresh=self.match_thres) # not self.tracked_thresh
matches, u_track, u_detection = matching.linear_assignment( else:
dists, thresh=self.tracked_thresh) dists = matching.embedding_distance(
track_pool_dict[cls_id], detections, metric=self.metric_type)
dists = matching.fuse_motion(self.motion, dists,
track_pool_dict[cls_id], detections)
matches, u_track, u_detection = matching.linear_assignment(
dists, thresh=self.tracked_thresh)
for i_tracked, idet in matches: for i_tracked, idet in matches:
# i_tracked is the id of the track and idet is the detection # i_tracked is the id of the track and idet is the detection
...@@ -177,19 +198,41 @@ class JDETracker(object): ...@@ -177,19 +198,41 @@ class JDETracker(object):
# None of the steps below happen if there are no undetected tracks. # None of the steps below happen if there are no undetected tracks.
""" Step 3: Second association, with IOU""" """ Step 3: Second association, with IOU"""
detections = [detections[i] for i in u_detection] if self.use_byte:
r_tracked_stracks = [] inds_low = pred_dets_dict[cls_id][:, 4:5] > self.low_conf_thres
for i in u_track: inds_high = pred_dets_dict[cls_id][:, 4:5] < self.conf_thres
if track_pool_dict[cls_id][i].state == TrackState.Tracked: inds_second = np.logical_and(inds_low, inds_high).squeeze(-1)
r_tracked_stracks.append(track_pool_dict[cls_id][i]) pred_dets_cls_second = pred_dets_dict[cls_id][inds_second]
dists = matching.iou_distance(r_tracked_stracks, detections) # association the untrack to the low score detections
matches, u_track, u_detection = matching.linear_assignment( if len(pred_dets_cls_second) > 0:
dists, thresh=self.r_tracked_thresh) detections_second = [
STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], cls_id, 30, temp_feat=None)
for tlbrs in pred_dets_cls_second[:, :5]
]
else:
detections_second = []
r_tracked_stracks = [
track_pool_dict[cls_id][i] for i in u_track
if track_pool_dict[cls_id][i].state == TrackState.Tracked
]
dists = matching.iou_distance(r_tracked_stracks, detections_second)
matches, u_track, u_detection_second = matching.linear_assignment(
dists, thresh=0.4) # not r_tracked_thresh
else:
detections = [detections[i] for i in u_detection]
r_tracked_stracks = []
for i in u_track:
if track_pool_dict[cls_id][i].state == TrackState.Tracked:
r_tracked_stracks.append(track_pool_dict[cls_id][i])
dists = matching.iou_distance(r_tracked_stracks, detections)
matches, u_track, u_detection = matching.linear_assignment(
dists, thresh=self.r_tracked_thresh)
for i_tracked, idet in matches: for i_tracked, idet in matches:
track = r_tracked_stracks[i_tracked] track = r_tracked_stracks[i_tracked]
det = detections[idet] det = detections[idet] if not self.use_byte else detections_second[idet]
if track.state == TrackState.Tracked: if track.state == TrackState.Tracked:
track.update(det, self.frame_id) track.update(det, self.frame_id)
activated_tracks_dict[cls_id].append(track) activated_tracks_dict[cls_id].append(track)
......
...@@ -105,28 +105,28 @@ class STrack(BaseTrack): ...@@ -105,28 +105,28 @@ class STrack(BaseTrack):
def __init__(self, def __init__(self,
tlwh, tlwh,
score, score,
temp_feat,
num_classes,
cls_id, cls_id,
buff_size=30): buff_size=30,
# object class id temp_feat=None):
self.cls_id = cls_id
# wait activate # wait activate
self._tlwh = np.asarray(tlwh, dtype=np.float) self._tlwh = np.asarray(tlwh, dtype=np.float)
self.score = score
self.cls_id = cls_id
self.track_len = 0
self.kalman_filter = None self.kalman_filter = None
self.mean, self.covariance = None, None self.mean, self.covariance = None, None
self.is_activated = False self.is_activated = False
self.score = score self.use_reid = True if temp_feat is not None else False
self.track_len = 0 if self.use_reid:
self.smooth_feat = None
self.smooth_feat = None self.update_features(temp_feat)
self.update_features(temp_feat) self.features = deque([], maxlen=buff_size)
self.features = deque([], maxlen=buff_size) self.alpha = 0.9
self.alpha = 0.9
def update_features(self, feat): def update_features(self, feat):
# L2 normalizing # L2 normalizing, this function has no use for BYTETracker
feat /= np.linalg.norm(feat) feat /= np.linalg.norm(feat)
self.curr_feat = feat self.curr_feat = feat
if self.smooth_feat is None: if self.smooth_feat is None:
...@@ -182,7 +182,8 @@ class STrack(BaseTrack): ...@@ -182,7 +182,8 @@ class STrack(BaseTrack):
def re_activate(self, new_track, frame_id, new_id=False): def re_activate(self, new_track, frame_id, new_id=False):
self.mean, self.covariance = self.kalman_filter.update( self.mean, self.covariance = self.kalman_filter.update(
self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)) self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh))
self.update_features(new_track.curr_feat) if self.use_reid:
self.update_features(new_track.curr_feat)
self.track_len = 0 self.track_len = 0
self.state = TrackState.Tracked self.state = TrackState.Tracked
self.is_activated = True self.is_activated = True
...@@ -201,7 +202,7 @@ class STrack(BaseTrack): ...@@ -201,7 +202,7 @@ class STrack(BaseTrack):
self.is_activated = True # set flag 'activated' self.is_activated = True # set flag 'activated'
self.score = new_track.score self.score = new_track.score
if update_feature: if update_feature and self.use_reid:
self.update_features(new_track.curr_feat) self.update_features(new_track.curr_feat)
@property @property
......
...@@ -58,6 +58,7 @@ class JDETracker(object): ...@@ -58,6 +58,7 @@ class JDETracker(object):
""" """
def __init__(self, def __init__(self,
use_byte=False,
num_classes=1, num_classes=1,
det_thresh=0.3, det_thresh=0.3,
track_buffer=30, track_buffer=30,
...@@ -66,11 +67,14 @@ class JDETracker(object): ...@@ -66,11 +67,14 @@ class JDETracker(object):
tracked_thresh=0.7, tracked_thresh=0.7,
r_tracked_thresh=0.5, r_tracked_thresh=0.5,
unconfirmed_thresh=0.7, unconfirmed_thresh=0.7,
motion='KalmanFilter',
conf_thres=0, conf_thres=0,
match_thres=0.8,
low_conf_thres=0.2,
motion='KalmanFilter',
metric_type='euclidean'): metric_type='euclidean'):
self.use_byte = use_byte
self.num_classes = num_classes self.num_classes = num_classes
self.det_thresh = det_thresh self.det_thresh = det_thresh if not use_byte else conf_thres + 0.1
self.track_buffer = track_buffer self.track_buffer = track_buffer
self.min_box_area = min_box_area self.min_box_area = min_box_area
self.vertical_ratio = vertical_ratio self.vertical_ratio = vertical_ratio
...@@ -78,9 +82,12 @@ class JDETracker(object): ...@@ -78,9 +82,12 @@ class JDETracker(object):
self.tracked_thresh = tracked_thresh self.tracked_thresh = tracked_thresh
self.r_tracked_thresh = r_tracked_thresh self.r_tracked_thresh = r_tracked_thresh
self.unconfirmed_thresh = unconfirmed_thresh self.unconfirmed_thresh = unconfirmed_thresh
self.conf_thres = conf_thres
self.match_thres = match_thres
self.low_conf_thres = low_conf_thres
if motion == 'KalmanFilter': if motion == 'KalmanFilter':
self.motion = KalmanFilter() self.motion = KalmanFilter()
self.conf_thres = conf_thres
self.metric_type = metric_type self.metric_type = metric_type
self.frame_id = 0 self.frame_id = 0
...@@ -91,7 +98,7 @@ class JDETracker(object): ...@@ -91,7 +98,7 @@ class JDETracker(object):
self.max_time_lost = 0 self.max_time_lost = 0
# max_time_lost will be calculated: int(frame_rate / 30.0 * track_buffer) # max_time_lost will be calculated: int(frame_rate / 30.0 * track_buffer)
def update(self, pred_dets, pred_embs): def update(self, pred_dets, pred_embs=None):
""" """
Processes the image frame and finds bounding box(detections). Processes the image frame and finds bounding box(detections).
Associates the detection with corresponding tracklets and also handles Associates the detection with corresponding tracklets and also handles
...@@ -123,7 +130,10 @@ class JDETracker(object): ...@@ -123,7 +130,10 @@ class JDETracker(object):
for cls_id in range(self.num_classes): for cls_id in range(self.num_classes):
cls_idx = (pred_dets[:, 5:] == cls_id).squeeze(-1) cls_idx = (pred_dets[:, 5:] == cls_id).squeeze(-1)
pred_dets_dict[cls_id] = pred_dets[cls_idx] pred_dets_dict[cls_id] = pred_dets[cls_idx]
pred_embs_dict[cls_id] = pred_embs[cls_idx] if pred_embs is not None:
pred_embs_dict[cls_id] = pred_embs[cls_idx]
else:
pred_embs_dict[cls_id] = None
for cls_id in range(self.num_classes): for cls_id in range(self.num_classes):
""" Step 1: Get detections by class""" """ Step 1: Get detections by class"""
...@@ -132,13 +142,19 @@ class JDETracker(object): ...@@ -132,13 +142,19 @@ class JDETracker(object):
remain_inds = (pred_dets_cls[:, 4:5] > self.conf_thres).squeeze(-1) remain_inds = (pred_dets_cls[:, 4:5] > self.conf_thres).squeeze(-1)
if remain_inds.sum() > 0: if remain_inds.sum() > 0:
pred_dets_cls = pred_dets_cls[remain_inds] pred_dets_cls = pred_dets_cls[remain_inds]
pred_embs_cls = pred_embs_cls[remain_inds] if self.use_byte:
detections = [ detections = [
STrack( STrack(
STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], cls_id, 30, temp_feat=None)
self.num_classes, cls_id, 30) for tlbrs in pred_dets_cls
for (tlbrs, f) in zip(pred_dets_cls, pred_embs_cls) ]
] else:
pred_embs_cls = pred_embs_cls[remain_inds]
detections = [
STrack(
STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], cls_id, 30, temp_feat)
for (tlbrs, temp_feat) in zip(pred_dets_cls, pred_embs_cls)
]
else: else:
detections = [] detections = []
''' Add newly detected tracklets to tracked_stracks''' ''' Add newly detected tracklets to tracked_stracks'''
...@@ -160,12 +176,17 @@ class JDETracker(object): ...@@ -160,12 +176,17 @@ class JDETracker(object):
# Predict the current location with KalmanFilter # Predict the current location with KalmanFilter
STrack.multi_predict(track_pool_dict[cls_id], self.motion) STrack.multi_predict(track_pool_dict[cls_id], self.motion)
dists = matching.embedding_distance( if self.use_byte:
track_pool_dict[cls_id], detections, metric=self.metric_type) dists = matching.iou_distance(track_pool_dict[cls_id], detections)
dists = matching.fuse_motion(self.motion, dists, matches, u_track, u_detection = matching.linear_assignment(
track_pool_dict[cls_id], detections) dists, thresh=self.match_thres) #
matches, u_track, u_detection = matching.linear_assignment( else:
dists, thresh=self.tracked_thresh) dists = matching.embedding_distance(
track_pool_dict[cls_id], detections, metric=self.metric_type)
dists = matching.fuse_motion(self.motion, dists,
track_pool_dict[cls_id], detections)
matches, u_track, u_detection = matching.linear_assignment(
dists, thresh=self.tracked_thresh)
for i_tracked, idet in matches: for i_tracked, idet in matches:
# i_tracked is the id of the track and idet is the detection # i_tracked is the id of the track and idet is the detection
...@@ -183,19 +204,41 @@ class JDETracker(object): ...@@ -183,19 +204,41 @@ class JDETracker(object):
# None of the steps below happen if there are no undetected tracks. # None of the steps below happen if there are no undetected tracks.
""" Step 3: Second association, with IOU""" """ Step 3: Second association, with IOU"""
detections = [detections[i] for i in u_detection] if self.use_byte:
r_tracked_stracks = [] inds_low = pred_dets_dict[cls_id][:, 4:5] > self.low_conf_thres
for i in u_track: inds_high = pred_dets_dict[cls_id][:, 4:5] < self.conf_thres
if track_pool_dict[cls_id][i].state == TrackState.Tracked: inds_second = np.logical_and(inds_low, inds_high).squeeze(-1)
r_tracked_stracks.append(track_pool_dict[cls_id][i]) pred_dets_cls_second = pred_dets_dict[cls_id][inds_second]
dists = matching.iou_distance(r_tracked_stracks, detections) # association the untrack to the low score detections
matches, u_track, u_detection = matching.linear_assignment( if len(pred_dets_cls_second) > 0:
dists, thresh=self.r_tracked_thresh) detections_second = [
STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], cls_id, 30, temp_feat=None)
for tlbrs in pred_dets_cls_second[:, :5]
]
else:
detections_second = []
r_tracked_stracks = [
track_pool_dict[cls_id][i] for i in u_track
if track_pool_dict[cls_id][i].state == TrackState.Tracked
]
dists = matching.iou_distance(r_tracked_stracks, detections_second)
matches, u_track, u_detection_second = matching.linear_assignment(
dists, thresh=0.4) # not r_tracked_thresh
else:
detections = [detections[i] for i in u_detection]
r_tracked_stracks = []
for i in u_track:
if track_pool_dict[cls_id][i].state == TrackState.Tracked:
r_tracked_stracks.append(track_pool_dict[cls_id][i])
dists = matching.iou_distance(r_tracked_stracks, detections)
matches, u_track, u_detection = matching.linear_assignment(
dists, thresh=self.r_tracked_thresh)
for i_tracked, idet in matches: for i_tracked, idet in matches:
track = r_tracked_stracks[i_tracked] track = r_tracked_stracks[i_tracked]
det = detections[idet] det = detections[idet] if not self.use_byte else detections_second[idet]
if track.state == TrackState.Tracked: if track.state == TrackState.Tracked:
track.update(det, self.frame_id) track.update(det, self.frame_id)
activated_tracks_dict[cls_id].append(track) activated_tracks_dict[cls_id].append(track)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册