From 0a78d4d201ab83afa43a641c27daff34300d584a Mon Sep 17 00:00:00 2001 From: Feng Ni Date: Mon, 15 Nov 2021 19:54:54 +0800 Subject: [PATCH] [cherry-pick][MOT] fix mot bug (#4590) --- ppdet/data/source/mot.py | 4 ++-- ppdet/modeling/mot/utils.py | 11 ++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/ppdet/data/source/mot.py b/ppdet/data/source/mot.py index 7464dbfeb..5662b2c4b 100644 --- a/ppdet/data/source/mot.py +++ b/ppdet/data/source/mot.py @@ -367,10 +367,10 @@ class MCMOTDataSet(DetDataset): logger.info('Image start index: {}'.format(self.img_start_index)) logger.info('Total identities of each category: ') - self.num_identities_dict = sorted( + num_identities_dict = sorted( self.num_identities_dict.items(), key=lambda x: x[0]) total_IDs_all_cats = 0 - for (k, v) in self.num_identities_dict: + for (k, v) in num_identities_dict: logger.info('Category {} [{}] has {} IDs.'.format(k, cid2cname[k], v)) total_IDs_all_cats += v diff --git a/ppdet/modeling/mot/utils.py b/ppdet/modeling/mot/utils.py index 2b00f9d5f..1a39713fa 100644 --- a/ppdet/modeling/mot/utils.py +++ b/ppdet/modeling/mot/utils.py @@ -121,21 +121,22 @@ def write_mot_results(filename, results, data_type='mot', num_classes=1): f = open(filename, 'w') for cls_id in range(num_classes): for frame_id, tlwhs, tscores, track_ids in results[cls_id]: + if data_type == 'kitti': + frame_id -= 1 for tlwh, score, track_id in zip(tlwhs, tscores, track_ids): if track_id < 0: continue - if data_type == 'kitti': - frame_id -= 1 - elif data_type == 'mot': + if data_type == 'mot': cls_id = -1 - elif data_type == 'mcmot': - cls_id = cls_id x1, y1, w, h = tlwh + x2, y2 = x1 + w, y1 + h line = save_format.format( frame=frame_id, id=track_id, x1=x1, y1=y1, + x2=x2, + y2=y2, w=w, h=h, score=score, -- GitLab