diff --git a/ppdet/data/source/category.py b/ppdet/data/source/category.py index 56ca0f64d7eacaf41f86580f1ee289d3811c5417..4f85f5260ea0b836339009ef2e1e55841630e1cb 100644 --- a/ppdet/data/source/category.py +++ b/ppdet/data/source/category.py @@ -90,16 +90,19 @@ def get_categories(metric_type, anno_file=None, arch=None): elif metric_type.lower() in ['mot', 'motdet', 'reid']: return _mot_category() + elif metric_type.lower() in ['kitti', 'bdd100k']: + return _mot_category(category='car') + else: raise ValueError("unknown metric type {}".format(metric_type)) -def _mot_category(): +def _mot_category(category='person'): """ Get class id to category id map and category id to category name map of mot dataset """ - label_map = {'person': 0} + label_map = {category: 0} label_map = sorted(label_map.items(), key=lambda x: x[1]) cats = [l[0] for l in label_map]