未验证 提交 56d22694 编写于 作者: F Feng Ni 提交者: GitHub

[MOT] unify mot and det output format (#5320)

上级 d8508359
...@@ -90,13 +90,13 @@ class DeepSORTTracker(object): ...@@ -90,13 +90,13 @@ class DeepSORTTracker(object):
Perform measurement update and track management. Perform measurement update and track management.
Args: Args:
pred_dets (np.array): Detection results of the image, the shape is pred_dets (np.array): Detection results of the image, the shape is
[N, 6], means 'x0, y0, x1, y1, score, cls_id'. [N, 6], means 'cls_id, score, x0, y0, x1, y1'.
pred_embs (np.array): Embedding results of the image, the shape is pred_embs (np.array): Embedding results of the image, the shape is
[N, 128], usually pred_embs.shape[1] is a multiple of 128. [N, 128], usually pred_embs.shape[1] is a multiple of 128.
""" """
pred_tlwhs = pred_dets[:, :4] pred_cls_ids = pred_dets[:, 0:1]
pred_scores = pred_dets[:, 4:5] pred_scores = pred_dets[:, 1:2]
pred_cls_ids = pred_dets[:, 5:] pred_tlwhs = pred_dets[:, 2:6]
detections = [ detections = [
Detection(tlwh, score, feat, cls_id) Detection(tlwh, score, feat, cls_id)
......
...@@ -100,7 +100,7 @@ class JDETracker(object): ...@@ -100,7 +100,7 @@ class JDETracker(object):
Args: Args:
pred_dets (np.array): Detection results of the image, the shape is pred_dets (np.array): Detection results of the image, the shape is
[N, 6], means 'x0, y0, x1, y1, score, cls_id'. [N, 6], means 'cls_id, score, x0, y0, x1, y1'.
pred_embs (np.array): Embedding results of the image, the shape is pred_embs (np.array): Embedding results of the image, the shape is
[N, 128] or [N, 512]. [N, 128] or [N, 512].
...@@ -122,7 +122,7 @@ class JDETracker(object): ...@@ -122,7 +122,7 @@ class JDETracker(object):
# unify single and multi classes detection and embedding results # unify single and multi classes detection and embedding results
for cls_id in range(self.num_classes): for cls_id in range(self.num_classes):
cls_idx = (pred_dets[:, 5:] == cls_id).squeeze(-1) cls_idx = (pred_dets[:, 0:1] == cls_id).squeeze(-1)
pred_dets_dict[cls_id] = pred_dets[cls_idx] pred_dets_dict[cls_id] = pred_dets[cls_idx]
if pred_embs is not None: if pred_embs is not None:
pred_embs_dict[cls_id] = pred_embs[cls_idx] pred_embs_dict[cls_id] = pred_embs[cls_idx]
...@@ -133,21 +133,26 @@ class JDETracker(object): ...@@ -133,21 +133,26 @@ class JDETracker(object):
""" Step 1: Get detections by class""" """ Step 1: Get detections by class"""
pred_dets_cls = pred_dets_dict[cls_id] pred_dets_cls = pred_dets_dict[cls_id]
pred_embs_cls = pred_embs_dict[cls_id] pred_embs_cls = pred_embs_dict[cls_id]
remain_inds = (pred_dets_cls[:, 4:5] > self.conf_thres).squeeze(-1) remain_inds = (pred_dets_cls[:, 1:2] > self.conf_thres).squeeze(-1)
if remain_inds.sum() > 0: if remain_inds.sum() > 0:
pred_dets_cls = pred_dets_cls[remain_inds] pred_dets_cls = pred_dets_cls[remain_inds]
if self.use_byte: if self.use_byte:
detections = [ detections = [
STrack( STrack(
STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], cls_id, 30, temp_feat=None) STrack.tlbr_to_tlwh(tlbrs[2:6]),
for tlbrs in pred_dets_cls tlbrs[1],
cls_id,
30,
temp_feat=None) for tlbrs in pred_dets_cls
] ]
else: else:
pred_embs_cls = pred_embs_cls[remain_inds] pred_embs_cls = pred_embs_cls[remain_inds]
detections = [ detections = [
STrack( STrack(
STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], cls_id, 30, temp_feat) STrack.tlbr_to_tlwh(tlbrs[2:6]), tlbrs[1], cls_id,
for (tlbrs, temp_feat) in zip(pred_dets_cls, pred_embs_cls) 30, temp_feat)
for (tlbrs, temp_feat
) in zip(pred_dets_cls, pred_embs_cls)
] ]
else: else:
detections = [] detections = []
...@@ -171,14 +176,17 @@ class JDETracker(object): ...@@ -171,14 +176,17 @@ class JDETracker(object):
STrack.multi_predict(track_pool_dict[cls_id], self.motion) STrack.multi_predict(track_pool_dict[cls_id], self.motion)
if self.use_byte: if self.use_byte:
dists = matching.iou_distance(track_pool_dict[cls_id], detections) dists = matching.iou_distance(track_pool_dict[cls_id],
detections)
matches, u_track, u_detection = matching.linear_assignment( matches, u_track, u_detection = matching.linear_assignment(
dists, thresh=self.match_thres) # not self.tracked_thresh dists, thresh=self.match_thres) # not self.tracked_thresh
else: else:
dists = matching.embedding_distance( dists = matching.embedding_distance(
track_pool_dict[cls_id], detections, metric=self.metric_type) track_pool_dict[cls_id],
dists = matching.fuse_motion(self.motion, dists, detections,
track_pool_dict[cls_id], detections) metric=self.metric_type)
dists = matching.fuse_motion(
self.motion, dists, track_pool_dict[cls_id], detections)
matches, u_track, u_detection = matching.linear_assignment( matches, u_track, u_detection = matching.linear_assignment(
dists, thresh=self.tracked_thresh) dists, thresh=self.tracked_thresh)
...@@ -199,15 +207,20 @@ class JDETracker(object): ...@@ -199,15 +207,20 @@ class JDETracker(object):
# None of the steps below happen if there are no undetected tracks. # None of the steps below happen if there are no undetected tracks.
""" Step 3: Second association, with IOU""" """ Step 3: Second association, with IOU"""
if self.use_byte: if self.use_byte:
inds_low = pred_dets_dict[cls_id][:, 4:5] > self.low_conf_thres inds_low = pred_dets_dict[cls_id][:, 1:2] > self.low_conf_thres
inds_high = pred_dets_dict[cls_id][:, 4:5] < self.conf_thres inds_high = pred_dets_dict[cls_id][:, 1:2] < self.conf_thres
inds_second = np.logical_and(inds_low, inds_high).squeeze(-1) inds_second = np.logical_and(inds_low, inds_high).squeeze(-1)
pred_dets_cls_second = pred_dets_dict[cls_id][inds_second] pred_dets_cls_second = pred_dets_dict[cls_id][inds_second]
# association the untrack to the low score detections # association the untrack to the low score detections
if len(pred_dets_cls_second) > 0: if len(pred_dets_cls_second) > 0:
detections_second = [ detections_second = [
STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], cls_id, 30, temp_feat=None) STrack(
STrack.tlbr_to_tlwh(tlbrs[:4]),
tlbrs[4],
cls_id,
30,
temp_feat=None)
for tlbrs in pred_dets_cls_second[:, :5] for tlbrs in pred_dets_cls_second[:, :5]
] ]
else: else:
...@@ -216,9 +229,10 @@ class JDETracker(object): ...@@ -216,9 +229,10 @@ class JDETracker(object):
track_pool_dict[cls_id][i] for i in u_track track_pool_dict[cls_id][i] for i in u_track
if track_pool_dict[cls_id][i].state == TrackState.Tracked if track_pool_dict[cls_id][i].state == TrackState.Tracked
] ]
dists = matching.iou_distance(r_tracked_stracks, detections_second) dists = matching.iou_distance(r_tracked_stracks,
detections_second)
matches, u_track, u_detection_second = matching.linear_assignment( matches, u_track, u_detection_second = matching.linear_assignment(
dists, thresh=0.4) # not r_tracked_thresh dists, thresh=0.4) # not r_tracked_thresh
else: else:
detections = [detections[i] for i in u_detection] detections = [detections[i] for i in u_detection]
r_tracked_stracks = [] r_tracked_stracks = []
...@@ -232,7 +246,8 @@ class JDETracker(object): ...@@ -232,7 +246,8 @@ class JDETracker(object):
for i_tracked, idet in matches: for i_tracked, idet in matches:
track = r_tracked_stracks[i_tracked] track = r_tracked_stracks[i_tracked]
det = detections[idet] if not self.use_byte else detections_second[idet] det = detections[
idet] if not self.use_byte else detections_second[idet]
if track.state == TrackState.Tracked: if track.state == TrackState.Tracked:
track.update(det, self.frame_id) track.update(det, self.frame_id)
activated_tracks_dict[cls_id].append(track) activated_tracks_dict[cls_id].append(track)
......
...@@ -115,7 +115,7 @@ class JDE_Detector(Detector): ...@@ -115,7 +115,7 @@ class JDE_Detector(Detector):
return result return result
def tracking(self, det_results): def tracking(self, det_results):
pred_dets = det_results['pred_dets'] pred_dets = det_results['pred_dets'] # 'cls_id, score, x0, y0, x1, y1'
pred_embs = det_results['pred_embs'] pred_embs = det_results['pred_embs']
online_targets_dict = self.tracker.update(pred_dets, pred_embs) online_targets_dict = self.tracker.update(pred_dets, pred_embs)
...@@ -143,7 +143,7 @@ class JDE_Detector(Detector): ...@@ -143,7 +143,7 @@ class JDE_Detector(Detector):
repeats (int): repeats number for prediction repeats (int): repeats number for prediction
Returns: Returns:
result (dict): include 'pred_dets': np.ndarray: shape:[N,6], N: number of box, result (dict): include 'pred_dets': np.ndarray: shape:[N,6], N: number of box,
matix element:[x_min, y_min, x_max, y_max, score, class] matix element:[class, score, x_min, y_min, x_max, y_max]
FairMOT(JDE)'s result include 'pred_embs': np.ndarray: FairMOT(JDE)'s result include 'pred_embs': np.ndarray:
shape: [N, 128] shape: [N, 128]
''' '''
......
...@@ -111,11 +111,8 @@ class SDE_Detector(Detector): ...@@ -111,11 +111,8 @@ class SDE_Detector(Detector):
low_conf_thres=low_conf_thres) low_conf_thres=low_conf_thres)
def tracking(self, det_results): def tracking(self, det_results):
pred_dets = det_results['boxes'] pred_dets = det_results['boxes'] # 'cls_id, score, x0, y0, x1, y1'
pred_embs = None pred_embs = None
pred_dets = np.concatenate(
(pred_dets[:, 2:], pred_dets[:, 1:2], pred_dets[:, 0:1]), 1)
# pred_dets should be 'x0, y0, x1, y1, score, cls_id'
online_targets_dict = self.tracker.update(pred_dets, pred_embs) online_targets_dict = self.tracker.update(pred_dets, pred_embs)
online_tlwhs = defaultdict(list) online_tlwhs = defaultdict(list)
......
...@@ -282,14 +282,14 @@ class Tracker(object): ...@@ -282,14 +282,14 @@ class Tracker(object):
# thus will not inference reid model # thus will not inference reid model
continue continue
pred_scores = pred_scores[keep_idx[0]]
pred_cls_ids = pred_cls_ids[keep_idx[0]] pred_cls_ids = pred_cls_ids[keep_idx[0]]
pred_scores = pred_scores[keep_idx[0]]
pred_tlwhs = np.concatenate( pred_tlwhs = np.concatenate(
(pred_xyxys[:, 0:2], (pred_xyxys[:, 0:2],
pred_xyxys[:, 2:4] - pred_xyxys[:, 0:2] + 1), pred_xyxys[:, 2:4] - pred_xyxys[:, 0:2] + 1),
axis=1) axis=1)
pred_dets = np.concatenate( pred_dets = np.concatenate(
(pred_tlwhs, pred_scores, pred_cls_ids), axis=1) (pred_cls_ids, pred_scores, pred_tlwhs), axis=1)
tracker = self.model.tracker tracker = self.model.tracker
crops = get_crops( crops = get_crops(
......
...@@ -96,13 +96,13 @@ class DeepSORTTracker(object): ...@@ -96,13 +96,13 @@ class DeepSORTTracker(object):
Perform measurement update and track management. Perform measurement update and track management.
Args: Args:
pred_dets (np.array): Detection results of the image, the shape is pred_dets (np.array): Detection results of the image, the shape is
[N, 6], means 'x0, y0, x1, y1, score, cls_id'. [N, 6], means 'cls_id, score, x0, y0, x1, y1'.
pred_embs (np.array): Embedding results of the image, the shape is pred_embs (np.array): Embedding results of the image, the shape is
[N, 128], usually pred_embs.shape[1] is a multiple of 128. [N, 128], usually pred_embs.shape[1] is a multiple of 128.
""" """
pred_tlwhs = pred_dets[:, :4] pred_cls_ids = pred_dets[:, 0:1]
pred_scores = pred_dets[:, 4:5] pred_scores = pred_dets[:, 1:2]
pred_cls_ids = pred_dets[:, 5:] pred_tlwhs = pred_dets[:, 2:6]
detections = [ detections = [
Detection(tlwh, score, feat, cls_id) Detection(tlwh, score, feat, cls_id)
......
...@@ -106,7 +106,7 @@ class JDETracker(object): ...@@ -106,7 +106,7 @@ class JDETracker(object):
Args: Args:
pred_dets (np.array): Detection results of the image, the shape is pred_dets (np.array): Detection results of the image, the shape is
[N, 6], means 'x0, y0, x1, y1, score, cls_id'. [N, 6], means 'cls_id, score, x0, y0, x1, y1'.
pred_embs (np.array): Embedding results of the image, the shape is pred_embs (np.array): Embedding results of the image, the shape is
[N, 128] or [N, 512]. [N, 128] or [N, 512].
...@@ -128,7 +128,7 @@ class JDETracker(object): ...@@ -128,7 +128,7 @@ class JDETracker(object):
# unify single and multi classes detection and embedding results # unify single and multi classes detection and embedding results
for cls_id in range(self.num_classes): for cls_id in range(self.num_classes):
cls_idx = (pred_dets[:, 5:] == cls_id).squeeze(-1) cls_idx = (pred_dets[:, 0:1] == cls_id).squeeze(-1)
pred_dets_dict[cls_id] = pred_dets[cls_idx] pred_dets_dict[cls_id] = pred_dets[cls_idx]
if pred_embs is not None: if pred_embs is not None:
pred_embs_dict[cls_id] = pred_embs[cls_idx] pred_embs_dict[cls_id] = pred_embs[cls_idx]
...@@ -139,21 +139,26 @@ class JDETracker(object): ...@@ -139,21 +139,26 @@ class JDETracker(object):
""" Step 1: Get detections by class""" """ Step 1: Get detections by class"""
pred_dets_cls = pred_dets_dict[cls_id] pred_dets_cls = pred_dets_dict[cls_id]
pred_embs_cls = pred_embs_dict[cls_id] pred_embs_cls = pred_embs_dict[cls_id]
remain_inds = (pred_dets_cls[:, 4:5] > self.conf_thres).squeeze(-1) remain_inds = (pred_dets_cls[:, 1:2] > self.conf_thres).squeeze(-1)
if remain_inds.sum() > 0: if remain_inds.sum() > 0:
pred_dets_cls = pred_dets_cls[remain_inds] pred_dets_cls = pred_dets_cls[remain_inds]
if self.use_byte: if self.use_byte:
detections = [ detections = [
STrack( STrack(
STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], cls_id, 30, temp_feat=None) STrack.tlbr_to_tlwh(tlbrs[2:6]),
for tlbrs in pred_dets_cls tlbrs[1],
cls_id,
30,
temp_feat=None) for tlbrs in pred_dets_cls
] ]
else: else:
pred_embs_cls = pred_embs_cls[remain_inds] pred_embs_cls = pred_embs_cls[remain_inds]
detections = [ detections = [
STrack( STrack(
STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], cls_id, 30, temp_feat) STrack.tlbr_to_tlwh(tlbrs[2:6]), tlbrs[1], cls_id,
for (tlbrs, temp_feat) in zip(pred_dets_cls, pred_embs_cls) 30, temp_feat)
for (tlbrs, temp_feat
) in zip(pred_dets_cls, pred_embs_cls)
] ]
else: else:
detections = [] detections = []
...@@ -177,14 +182,17 @@ class JDETracker(object): ...@@ -177,14 +182,17 @@ class JDETracker(object):
STrack.multi_predict(track_pool_dict[cls_id], self.motion) STrack.multi_predict(track_pool_dict[cls_id], self.motion)
if self.use_byte: if self.use_byte:
dists = matching.iou_distance(track_pool_dict[cls_id], detections) dists = matching.iou_distance(track_pool_dict[cls_id],
detections)
matches, u_track, u_detection = matching.linear_assignment( matches, u_track, u_detection = matching.linear_assignment(
dists, thresh=self.match_thres) # dists, thresh=self.match_thres) # not self.tracked_thresh
else: else:
dists = matching.embedding_distance( dists = matching.embedding_distance(
track_pool_dict[cls_id], detections, metric=self.metric_type) track_pool_dict[cls_id],
dists = matching.fuse_motion(self.motion, dists, detections,
track_pool_dict[cls_id], detections) metric=self.metric_type)
dists = matching.fuse_motion(
self.motion, dists, track_pool_dict[cls_id], detections)
matches, u_track, u_detection = matching.linear_assignment( matches, u_track, u_detection = matching.linear_assignment(
dists, thresh=self.tracked_thresh) dists, thresh=self.tracked_thresh)
...@@ -205,15 +213,20 @@ class JDETracker(object): ...@@ -205,15 +213,20 @@ class JDETracker(object):
# None of the steps below happen if there are no undetected tracks. # None of the steps below happen if there are no undetected tracks.
""" Step 3: Second association, with IOU""" """ Step 3: Second association, with IOU"""
if self.use_byte: if self.use_byte:
inds_low = pred_dets_dict[cls_id][:, 4:5] > self.low_conf_thres inds_low = pred_dets_dict[cls_id][:, 1:2] > self.low_conf_thres
inds_high = pred_dets_dict[cls_id][:, 4:5] < self.conf_thres inds_high = pred_dets_dict[cls_id][:, 1:2] < self.conf_thres
inds_second = np.logical_and(inds_low, inds_high).squeeze(-1) inds_second = np.logical_and(inds_low, inds_high).squeeze(-1)
pred_dets_cls_second = pred_dets_dict[cls_id][inds_second] pred_dets_cls_second = pred_dets_dict[cls_id][inds_second]
# association the untrack to the low score detections # association the untrack to the low score detections
if len(pred_dets_cls_second) > 0: if len(pred_dets_cls_second) > 0:
detections_second = [ detections_second = [
STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], cls_id, 30, temp_feat=None) STrack(
STrack.tlbr_to_tlwh(tlbrs[:4]),
tlbrs[4],
cls_id,
30,
temp_feat=None)
for tlbrs in pred_dets_cls_second[:, :5] for tlbrs in pred_dets_cls_second[:, :5]
] ]
else: else:
...@@ -222,9 +235,10 @@ class JDETracker(object): ...@@ -222,9 +235,10 @@ class JDETracker(object):
track_pool_dict[cls_id][i] for i in u_track track_pool_dict[cls_id][i] for i in u_track
if track_pool_dict[cls_id][i].state == TrackState.Tracked if track_pool_dict[cls_id][i].state == TrackState.Tracked
] ]
dists = matching.iou_distance(r_tracked_stracks, detections_second) dists = matching.iou_distance(r_tracked_stracks,
detections_second)
matches, u_track, u_detection_second = matching.linear_assignment( matches, u_track, u_detection_second = matching.linear_assignment(
dists, thresh=0.4) # not r_tracked_thresh dists, thresh=0.4) # not r_tracked_thresh
else: else:
detections = [detections[i] for i in u_detection] detections = [detections[i] for i in u_detection]
r_tracked_stracks = [] r_tracked_stracks = []
...@@ -238,7 +252,8 @@ class JDETracker(object): ...@@ -238,7 +252,8 @@ class JDETracker(object):
for i_tracked, idet in matches: for i_tracked, idet in matches:
track = r_tracked_stracks[i_tracked] track = r_tracked_stracks[i_tracked]
det = detections[idet] if not self.use_byte else detections_second[idet] det = detections[
idet] if not self.use_byte else detections_second[idet]
if track.state == TrackState.Tracked: if track.state == TrackState.Tracked:
track.update(det, self.frame_id) track.update(det, self.frame_id)
activated_tracks_dict[cls_id].append(track) activated_tracks_dict[cls_id].append(track)
......
...@@ -504,11 +504,10 @@ class CenterNetPostProcess(TTFBox): ...@@ -504,11 +504,10 @@ class CenterNetPostProcess(TTFBox):
boxes_shape = bboxes.shape[:] boxes_shape = bboxes.shape[:]
scale_expand = paddle.expand(scale_expand, shape=boxes_shape) scale_expand = paddle.expand(scale_expand, shape=boxes_shape)
bboxes = paddle.divide(bboxes, scale_expand) bboxes = paddle.divide(bboxes, scale_expand)
results = paddle.concat([clses, scores, bboxes], axis=1)
if self.for_mot: if self.for_mot:
results = paddle.concat([bboxes, scores, clses], axis=1)
return results, inds, topk_clses return results, inds, topk_clses
else: else:
results = paddle.concat([clses, scores, bboxes], axis=1)
return results, paddle.shape(results)[0:1], topk_clses return results, paddle.shape(results)[0:1], topk_clses
......
...@@ -152,9 +152,8 @@ class JDEEmbeddingHead(nn.Layer): ...@@ -152,9 +152,8 @@ class JDEEmbeddingHead(nn.Layer):
scale_factor = targets['scale_factor'][0].numpy() scale_factor = targets['scale_factor'][0].numpy()
bboxes[:, 2:] = self.scale_coords(bboxes[:, 2:], input_shape, bboxes[:, 2:] = self.scale_coords(bboxes[:, 2:], input_shape,
im_shape, scale_factor) im_shape, scale_factor)
# tlwhs, scores, cls_ids # cls_ids, scores, tlwhs
pred_dets = paddle.concat( pred_dets = bboxes
(bboxes[:, 2:], bboxes[:, 1:2], bboxes[:, 0:1]), axis=1)
return pred_dets, pred_embs return pred_dets, pred_embs
def scale_coords(self, coords, input_shape, im_shape, scale_factor): def scale_coords(self, coords, input_shape, im_shape, scale_factor):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册