From b78d756e9b07b5770586b7f79c6ba3f001404b6e Mon Sep 17 00:00:00 2001 From: Feng Ni Date: Thu, 11 Nov 2021 10:28:26 +0800 Subject: [PATCH] [MOT] fix kitti results, fix mcmot ids2names (#4536) --- deploy/python/mot_jde_infer.py | 2 +- ppdet/data/source/mot.py | 4 ++-- ppdet/modeling/mot/utils.py | 11 ++++++----- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/deploy/python/mot_jde_infer.py b/deploy/python/mot_jde_infer.py index 1070c0f41..c7006a7cd 100644 --- a/deploy/python/mot_jde_infer.py +++ b/deploy/python/mot_jde_infer.py @@ -183,7 +183,7 @@ def predict_image(detector, image_list): [frame], FLAGS.threshold) online_im = plot_tracking_dict(frame, num_classes, online_tlwhs, online_ids, online_scores, frame_id, - ids2names) + ids2names=ids2names) if FLAGS.save_images: if not os.path.exists(FLAGS.output_dir): os.makedirs(FLAGS.output_dir) diff --git a/ppdet/data/source/mot.py b/ppdet/data/source/mot.py index 7464dbfeb..5662b2c4b 100644 --- a/ppdet/data/source/mot.py +++ b/ppdet/data/source/mot.py @@ -367,10 +367,10 @@ class MCMOTDataSet(DetDataset): logger.info('Image start index: {}'.format(self.img_start_index)) logger.info('Total identities of each category: ') - self.num_identities_dict = sorted( + num_identities_dict = sorted( self.num_identities_dict.items(), key=lambda x: x[0]) total_IDs_all_cats = 0 - for (k, v) in self.num_identities_dict: + for (k, v) in num_identities_dict: logger.info('Category {} [{}] has {} IDs.'.format(k, cid2cname[k], v)) total_IDs_all_cats += v diff --git a/ppdet/modeling/mot/utils.py b/ppdet/modeling/mot/utils.py index 2b00f9d5f..1a39713fa 100644 --- a/ppdet/modeling/mot/utils.py +++ b/ppdet/modeling/mot/utils.py @@ -121,21 +121,22 @@ def write_mot_results(filename, results, data_type='mot', num_classes=1): f = open(filename, 'w') for cls_id in range(num_classes): for frame_id, tlwhs, tscores, track_ids in results[cls_id]: + if data_type == 'kitti': + frame_id -= 1 for tlwh, score, track_id in zip(tlwhs, tscores, track_ids): if track_id < 0: continue - if data_type == 'kitti': - frame_id -= 1 - elif data_type == 'mot': + if data_type == 'mot': cls_id = -1 - elif data_type == 'mcmot': - cls_id = cls_id x1, y1, w, h = tlwh + x2, y2 = x1 + w, y1 + h line = save_format.format( frame=frame_id, id=track_id, x1=x1, y1=y1, + x2=x2, + y2=y2, w=w, h=h, score=score, -- GitLab