未验证 提交 b78d756e 编写于 作者: F Feng Ni 提交者: GitHub

[MOT] fix kitti results, fix mcmot ids2names (#4536)

上级 c44543ae
...@@ -183,7 +183,7 @@ def predict_image(detector, image_list): ...@@ -183,7 +183,7 @@ def predict_image(detector, image_list):
[frame], FLAGS.threshold) [frame], FLAGS.threshold)
online_im = plot_tracking_dict(frame, num_classes, online_tlwhs, online_im = plot_tracking_dict(frame, num_classes, online_tlwhs,
online_ids, online_scores, frame_id, online_ids, online_scores, frame_id,
ids2names) ids2names=ids2names)
if FLAGS.save_images: if FLAGS.save_images:
if not os.path.exists(FLAGS.output_dir): if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir) os.makedirs(FLAGS.output_dir)
......
...@@ -367,10 +367,10 @@ class MCMOTDataSet(DetDataset): ...@@ -367,10 +367,10 @@ class MCMOTDataSet(DetDataset):
logger.info('Image start index: {}'.format(self.img_start_index)) logger.info('Image start index: {}'.format(self.img_start_index))
logger.info('Total identities of each category: ') logger.info('Total identities of each category: ')
self.num_identities_dict = sorted( num_identities_dict = sorted(
self.num_identities_dict.items(), key=lambda x: x[0]) self.num_identities_dict.items(), key=lambda x: x[0])
total_IDs_all_cats = 0 total_IDs_all_cats = 0
for (k, v) in self.num_identities_dict: for (k, v) in num_identities_dict:
logger.info('Category {} [{}] has {} IDs.'.format(k, cid2cname[k], logger.info('Category {} [{}] has {} IDs.'.format(k, cid2cname[k],
v)) v))
total_IDs_all_cats += v total_IDs_all_cats += v
......
...@@ -121,21 +121,22 @@ def write_mot_results(filename, results, data_type='mot', num_classes=1): ...@@ -121,21 +121,22 @@ def write_mot_results(filename, results, data_type='mot', num_classes=1):
f = open(filename, 'w') f = open(filename, 'w')
for cls_id in range(num_classes): for cls_id in range(num_classes):
for frame_id, tlwhs, tscores, track_ids in results[cls_id]: for frame_id, tlwhs, tscores, track_ids in results[cls_id]:
if data_type == 'kitti':
frame_id -= 1
for tlwh, score, track_id in zip(tlwhs, tscores, track_ids): for tlwh, score, track_id in zip(tlwhs, tscores, track_ids):
if track_id < 0: continue if track_id < 0: continue
if data_type == 'kitti': if data_type == 'mot':
frame_id -= 1
elif data_type == 'mot':
cls_id = -1 cls_id = -1
elif data_type == 'mcmot':
cls_id = cls_id
x1, y1, w, h = tlwh x1, y1, w, h = tlwh
x2, y2 = x1 + w, y1 + h
line = save_format.format( line = save_format.format(
frame=frame_id, frame=frame_id,
id=track_id, id=track_id,
x1=x1, x1=x1,
y1=y1, y1=y1,
x2=x2,
y2=y2,
w=w, w=w,
h=h, h=h,
score=score, score=score,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册