未验证 提交 1a8c009c 编写于 作者: F Feng Ni 提交者: GitHub

add warnings (#5512)

上级 b9e096ef
...@@ -17,3 +17,4 @@ EvalDataset: ...@@ -17,3 +17,4 @@ EvalDataset:
TestDataset: TestDataset:
!ImageFolder !ImageFolder
anno_path: annotations/instances_val2017.json anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco
...@@ -17,3 +17,4 @@ EvalDataset: ...@@ -17,3 +17,4 @@ EvalDataset:
TestDataset: TestDataset:
!ImageFolder !ImageFolder
anno_path: annotations/instances_val2017.json anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco
...@@ -17,3 +17,4 @@ EvalDataset: ...@@ -17,3 +17,4 @@ EvalDataset:
TestDataset: TestDataset:
!ImageFolder !ImageFolder
anno_path: trainval_split/s2anet_trainval_paddle_coco.json anno_path: trainval_split/s2anet_trainval_paddle_coco.json
dataset_dir: dataset/DOTA_1024_s2anet/
...@@ -17,6 +17,7 @@ EvalDataset: ...@@ -17,6 +17,7 @@ EvalDataset:
TestDataset: TestDataset:
!ImageFolder !ImageFolder
dataset_dir: dataset/mot/MOT17
anno_path: annotations/val_half.json anno_path: annotations/val_half.json
......
...@@ -17,6 +17,7 @@ EvalDataset: ...@@ -17,6 +17,7 @@ EvalDataset:
TestDataset: TestDataset:
!ImageFolder !ImageFolder
dataset_dir: dataset/mot/MOT17
anno_path: annotations/val_half.json anno_path: annotations/val_half.json
......
...@@ -39,6 +39,11 @@ def get_categories(metric_type, anno_file=None, arch=None): ...@@ -39,6 +39,11 @@ def get_categories(metric_type, anno_file=None, arch=None):
if arch == 'keypoint_arch': if arch == 'keypoint_arch':
return (None, {'id': 'keypoint'}) return (None, {'id': 'keypoint'})
if anno_file == None or (not os.path.isfile(anno_file)):
logger.warning("anno_file '{}' is None or not set or not exist, "
"please recheck TrainDataset/EvalDataset/TestDataset.anno_path, "
"otherwise the default categories will be used by metric_type.".format(anno_file))
if metric_type.lower() == 'coco' or metric_type.lower( if metric_type.lower() == 'coco' or metric_type.lower(
) == 'rbox' or metric_type.lower() == 'snipercoco': ) == 'rbox' or metric_type.lower() == 'snipercoco':
if anno_file and os.path.isfile(anno_file): if anno_file and os.path.isfile(anno_file):
...@@ -55,8 +60,9 @@ def get_categories(metric_type, anno_file=None, arch=None): ...@@ -55,8 +60,9 @@ def get_categories(metric_type, anno_file=None, arch=None):
# anno file not exist, load default categories of COCO17 # anno file not exist, load default categories of COCO17
else: else:
if metric_type.lower() == 'rbox': if metric_type.lower() == 'rbox':
logger.warning("metric_type: {}, load default categories of DOTA.".format(metric_type))
return _dota_category() return _dota_category()
logger.warning("metric_type: {}, load default categories of COCO.".format(metric_type))
return _coco17_category() return _coco17_category()
elif metric_type.lower() == 'voc': elif metric_type.lower() == 'voc':
...@@ -77,6 +83,7 @@ def get_categories(metric_type, anno_file=None, arch=None): ...@@ -77,6 +83,7 @@ def get_categories(metric_type, anno_file=None, arch=None):
# anno file not exist, load default categories of # anno file not exist, load default categories of
# VOC all 20 categories # VOC all 20 categories
else: else:
logger.warning("metric_type: {}, load default categories of VOC.".format(metric_type))
return _vocall_category() return _vocall_category()
elif metric_type.lower() == 'oid': elif metric_type.lower() == 'oid':
...@@ -104,6 +111,7 @@ def get_categories(metric_type, anno_file=None, arch=None): ...@@ -104,6 +111,7 @@ def get_categories(metric_type, anno_file=None, arch=None):
return clsid2catid, catid2name return clsid2catid, catid2name
# anno file not exist, load default category 'pedestrian'. # anno file not exist, load default category 'pedestrian'.
else: else:
logger.warning("metric_type: {}, load default categories of pedestrian MOT.".format(metric_type))
return _mot_category(category='pedestrian') return _mot_category(category='pedestrian')
elif metric_type.lower() in ['kitti', 'bdd100kmot']: elif metric_type.lower() in ['kitti', 'bdd100kmot']:
...@@ -122,6 +130,7 @@ def get_categories(metric_type, anno_file=None, arch=None): ...@@ -122,6 +130,7 @@ def get_categories(metric_type, anno_file=None, arch=None):
return clsid2catid, catid2name return clsid2catid, catid2name
# anno file not exist, load default categories of visdrone all 10 categories # anno file not exist, load default categories of visdrone all 10 categories
else: else:
logger.warning("metric_type: {}, load default categories of VisDrone.".format(metric_type))
return _visdrone_category() return _visdrone_category()
else: else:
......
...@@ -26,8 +26,6 @@ from motmetrics.math_util import quiet_divide ...@@ -26,8 +26,6 @@ from motmetrics.math_util import quiet_divide
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import paddle
import paddle.nn.functional as F
from .metrics import Metric from .metrics import Metric
import motmetrics as mm import motmetrics as mm
import openpyxl import openpyxl
...@@ -311,7 +309,9 @@ class MCMOTEvaluator(object): ...@@ -311,7 +309,9 @@ class MCMOTEvaluator(object):
self.gt_filename = os.path.join(self.data_root, '../', self.gt_filename = os.path.join(self.data_root, '../',
'sequences', 'sequences',
'{}.txt'.format(self.seq_name)) '{}.txt'.format(self.seq_name))
if not os.path.exists(self.gt_filename):
logger.warning("gt_filename '{}' of MCMOTEvaluator is not exist, so the MOTA will be -inf.")
def reset_accumulator(self): def reset_accumulator(self):
import motmetrics as mm import motmetrics as mm
mm.lap.default_solver = 'lap' mm.lap.default_solver = 'lap'
......
...@@ -22,8 +22,7 @@ import sys ...@@ -22,8 +22,7 @@ import sys
import math import math
from collections import defaultdict from collections import defaultdict
import numpy as np import numpy as np
import paddle
import paddle.nn.functional as F
from ppdet.modeling.bbox_utils import bbox_iou_np_expand from ppdet.modeling.bbox_utils import bbox_iou_np_expand
from .map_utils import ap_per_class from .map_utils import ap_per_class
from .metrics import Metric from .metrics import Metric
...@@ -36,8 +35,10 @@ __all__ = ['MOTEvaluator', 'MOTMetric', 'JDEDetMetric', 'KITTIMOTMetric'] ...@@ -36,8 +35,10 @@ __all__ = ['MOTEvaluator', 'MOTMetric', 'JDEDetMetric', 'KITTIMOTMetric']
def read_mot_results(filename, is_gt=False, is_ignore=False): def read_mot_results(filename, is_gt=False, is_ignore=False):
valid_labels = {1} valid_label = [1]
ignore_labels = {2, 7, 8, 12} # only in motchallenge datasets like 'MOT16' ignore_labels = [2, 7, 8, 12] # only in motchallenge datasets like 'MOT16'
logger.info("In MOT16/17 dataset the valid_label of ground truth is '{}', "
"in other dataset it should be '0' for single classs MOT.".format(valid_label[0]))
results_dict = dict() results_dict = dict()
if os.path.isfile(filename): if os.path.isfile(filename):
with open(filename, 'r') as f: with open(filename, 'r') as f:
...@@ -50,12 +51,10 @@ def read_mot_results(filename, is_gt=False, is_ignore=False): ...@@ -50,12 +51,10 @@ def read_mot_results(filename, is_gt=False, is_ignore=False):
continue continue
results_dict.setdefault(fid, list()) results_dict.setdefault(fid, list())
box_size = float(linelist[4]) * float(linelist[5])
if is_gt: if is_gt:
label = int(float(linelist[7])) label = int(float(linelist[7]))
mark = int(float(linelist[6])) mark = int(float(linelist[6]))
if mark == 0 or label not in valid_labels: if mark == 0 or label not in valid_label:
continue continue
score = 1 score = 1
elif is_ignore: elif is_ignore:
...@@ -118,6 +117,8 @@ class MOTEvaluator(object): ...@@ -118,6 +117,8 @@ class MOTEvaluator(object):
assert self.data_type == 'mot' assert self.data_type == 'mot'
gt_filename = os.path.join(self.data_root, self.seq_name, 'gt', gt_filename = os.path.join(self.data_root, self.seq_name, 'gt',
'gt.txt') 'gt.txt')
if not os.path.exists(gt_filename):
logger.warning("gt_filename '{}' of MOTEvaluator is not exist, so the MOTA will be -inf.")
self.gt_frame_dict = read_mot_results(gt_filename, is_gt=True) self.gt_frame_dict = read_mot_results(gt_filename, is_gt=True)
self.gt_ignore_frame_dict = read_mot_results( self.gt_ignore_frame_dict = read_mot_results(
gt_filename, is_ignore=True) gt_filename, is_ignore=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册