未验证 提交 25d6c326 编写于 作者: W wangguanzhong 提交者: GitHub

fix unittest (#7807)

上级 e5e20169
...@@ -3,19 +3,19 @@ map_type: 11point ...@@ -3,19 +3,19 @@ map_type: 11point
num_classes: 20 num_classes: 20
TrainDataset: TrainDataset:
!VOCDataSet name: VOCDataSet
dataset_dir: dataset/voc dataset_dir: dataset/voc
anno_path: trainval.txt anno_path: trainval.txt
label_list: label_list.txt label_list: label_list.txt
data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult'] data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']
EvalDataset: EvalDataset:
!VOCDataSet name: VOCDataSet
dataset_dir: dataset/voc dataset_dir: dataset/voc
anno_path: test.txt anno_path: test.txt
label_list: label_list.txt label_list: label_list.txt
data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult'] data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']
TestDataset: TestDataset:
!ImageFolder name: ImageFolder
anno_path: dataset/voc/label_list.txt anno_path: dataset/voc/label_list.txt
use_gpu: true use_gpu: true
use_xpu: false use_xpu: false
use_mlu: false use_mlu: false
use_npu: false
log_iter: 20 log_iter: 20
save_dir: output save_dir: output
snapshot_epoch: 1 snapshot_epoch: 1
......
...@@ -67,6 +67,15 @@ class AttrDict(dict): ...@@ -67,6 +67,15 @@ class AttrDict(dict):
return self[key] return self[key]
raise AttributeError("object has no attribute '{}'".format(key)) raise AttributeError("object has no attribute '{}'".format(key))
def __setattr__(self, key, value):
self[key] = value
def copy(self):
new_dict = AttrDict()
for k, v in self.items():
new_dict.update({k: v})
return new_dict
global_config = AttrDict() global_config = AttrDict()
...@@ -280,4 +289,4 @@ def create(cls_or_name, **kwargs): ...@@ -280,4 +289,4 @@ def create(cls_or_name, **kwargs):
# prevent modification of global config values of reference types # prevent modification of global config values of reference types
# (e.g., list, dict) from within the created module instances # (e.g., list, dict) from within the created module instances
#kwargs = copy.deepcopy(kwargs) #kwargs = copy.deepcopy(kwargs)
return cls(**cls_kwargs) return cls(**cls_kwargs)
\ No newline at end of file
...@@ -62,7 +62,7 @@ MOT_ARCH = ['JDE', 'FairMOT', 'DeepSORT', 'ByteTrack', 'CenterTrack'] ...@@ -62,7 +62,7 @@ MOT_ARCH = ['JDE', 'FairMOT', 'DeepSORT', 'ByteTrack', 'CenterTrack']
class Trainer(object): class Trainer(object):
def __init__(self, cfg, mode='train'): def __init__(self, cfg, mode='train'):
self.cfg = cfg self.cfg = cfg.copy()
assert mode.lower() in ['train', 'eval', 'test'], \ assert mode.lower() in ['train', 'eval', 'test'], \
"mode should be 'train', 'eval' or 'test'" "mode should be 'train', 'eval' or 'test'"
self.mode = mode.lower() self.mode = mode.lower()
...@@ -99,12 +99,12 @@ class Trainer(object): ...@@ -99,12 +99,12 @@ class Trainer(object):
self.dataset, cfg.worker_num) self.dataset, cfg.worker_num)
if cfg.architecture == 'JDE' and self.mode == 'train': if cfg.architecture == 'JDE' and self.mode == 'train':
cfg['JDEEmbeddingHead'][ self.cfg['JDEEmbeddingHead'][
'num_identities'] = self.dataset.num_identities_dict[0] 'num_identities'] = self.dataset.num_identities_dict[0]
# JDE only support single class MOT now. # JDE only support single class MOT now.
if cfg.architecture == 'FairMOT' and self.mode == 'train': if cfg.architecture == 'FairMOT' and self.mode == 'train':
cfg['FairMOTEmbeddingHead'][ self.cfg['FairMOTEmbeddingHead'][
'num_identities_dict'] = self.dataset.num_identities_dict 'num_identities_dict'] = self.dataset.num_identities_dict
# FairMOT support single class and multi-class MOT now. # FairMOT support single class and multi-class MOT now.
...@@ -149,7 +149,7 @@ class Trainer(object): ...@@ -149,7 +149,7 @@ class Trainer(object):
reader_name = '{}Reader'.format(self.mode.capitalize()) reader_name = '{}Reader'.format(self.mode.capitalize())
# If metric is VOC, need to be set collate_batch=False. # If metric is VOC, need to be set collate_batch=False.
if cfg.metric == 'VOC': if cfg.metric == 'VOC':
cfg[reader_name]['collate_batch'] = False self.cfg[reader_name]['collate_batch'] = False
self.loader = create(reader_name)(self.dataset, cfg.worker_num, self.loader = create(reader_name)(self.dataset, cfg.worker_num,
self._eval_batch_sampler) self._eval_batch_sampler)
# TestDataset build after user set images, skip loader creation here # TestDataset build after user set images, skip loader creation here
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册