diff --git a/ppdet/data/source/mot.py b/ppdet/data/source/mot.py index 7464dbfeb3f07f9e9ebd53655dd6c0b27a8e3886..5662b2c4bef04929958d7d345e02973491e680bd 100644 --- a/ppdet/data/source/mot.py +++ b/ppdet/data/source/mot.py @@ -367,10 +367,10 @@ class MCMOTDataSet(DetDataset): logger.info('Image start index: {}'.format(self.img_start_index)) logger.info('Total identities of each category: ') - self.num_identities_dict = sorted( + num_identities_dict = sorted( self.num_identities_dict.items(), key=lambda x: x[0]) total_IDs_all_cats = 0 - for (k, v) in self.num_identities_dict: + for (k, v) in num_identities_dict: logger.info('Category {} [{}] has {} IDs.'.format(k, cid2cname[k], v)) total_IDs_all_cats += v diff --git a/ppdet/modeling/mot/utils.py b/ppdet/modeling/mot/utils.py index 2b00f9d5f9b70d013de18db55f55e9250bced6b0..1a39713fa250688c19f474c26c67b01dc36f80e7 100644 --- a/ppdet/modeling/mot/utils.py +++ b/ppdet/modeling/mot/utils.py @@ -121,21 +121,22 @@ def write_mot_results(filename, results, data_type='mot', num_classes=1): f = open(filename, 'w') for cls_id in range(num_classes): for frame_id, tlwhs, tscores, track_ids in results[cls_id]: + if data_type == 'kitti': + frame_id -= 1 for tlwh, score, track_id in zip(tlwhs, tscores, track_ids): if track_id < 0: continue - if data_type == 'kitti': - frame_id -= 1 - elif data_type == 'mot': + if data_type == 'mot': cls_id = -1 - elif data_type == 'mcmot': - cls_id = cls_id x1, y1, w, h = tlwh + x2, y2 = x1 + w, y1 + h line = save_format.format( frame=frame_id, id=track_id, x1=x1, y1=y1, + x2=x2, + y2=y2, w=w, h=h, score=score,