未验证 提交 25d6c326 编写于 作者: W wangguanzhong 提交者: GitHub

fix unittest (#7807)

上级 e5e20169
......@@ -3,19 +3,19 @@ map_type: 11point
num_classes: 20
TrainDataset:
!VOCDataSet
dataset_dir: dataset/voc
anno_path: trainval.txt
label_list: label_list.txt
data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']
name: VOCDataSet
dataset_dir: dataset/voc
anno_path: trainval.txt
label_list: label_list.txt
data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']
EvalDataset:
!VOCDataSet
dataset_dir: dataset/voc
anno_path: test.txt
label_list: label_list.txt
data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']
name: VOCDataSet
dataset_dir: dataset/voc
anno_path: test.txt
label_list: label_list.txt
data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']
TestDataset:
!ImageFolder
anno_path: dataset/voc/label_list.txt
name: ImageFolder
anno_path: dataset/voc/label_list.txt
use_gpu: true
use_xpu: false
use_mlu: false
use_npu: false
log_iter: 20
save_dir: output
snapshot_epoch: 1
......
......@@ -67,6 +67,15 @@ class AttrDict(dict):
return self[key]
raise AttributeError("object has no attribute '{}'".format(key))
def __setattr__(self, key, value):
self[key] = value
def copy(self):
new_dict = AttrDict()
for k, v in self.items():
new_dict.update({k: v})
return new_dict
global_config = AttrDict()
......@@ -280,4 +289,4 @@ def create(cls_or_name, **kwargs):
# prevent modification of global config values of reference types
# (e.g., list, dict) from within the created module instances
#kwargs = copy.deepcopy(kwargs)
return cls(**cls_kwargs)
\ No newline at end of file
return cls(**cls_kwargs)
......@@ -62,7 +62,7 @@ MOT_ARCH = ['JDE', 'FairMOT', 'DeepSORT', 'ByteTrack', 'CenterTrack']
class Trainer(object):
def __init__(self, cfg, mode='train'):
self.cfg = cfg
self.cfg = cfg.copy()
assert mode.lower() in ['train', 'eval', 'test'], \
"mode should be 'train', 'eval' or 'test'"
self.mode = mode.lower()
......@@ -99,12 +99,12 @@ class Trainer(object):
self.dataset, cfg.worker_num)
if cfg.architecture == 'JDE' and self.mode == 'train':
cfg['JDEEmbeddingHead'][
self.cfg['JDEEmbeddingHead'][
'num_identities'] = self.dataset.num_identities_dict[0]
# JDE only support single class MOT now.
if cfg.architecture == 'FairMOT' and self.mode == 'train':
cfg['FairMOTEmbeddingHead'][
self.cfg['FairMOTEmbeddingHead'][
'num_identities_dict'] = self.dataset.num_identities_dict
# FairMOT support single class and multi-class MOT now.
......@@ -149,7 +149,7 @@ class Trainer(object):
reader_name = '{}Reader'.format(self.mode.capitalize())
# If metric is VOC, need to be set collate_batch=False.
if cfg.metric == 'VOC':
cfg[reader_name]['collate_batch'] = False
self.cfg[reader_name]['collate_batch'] = False
self.loader = create(reader_name)(self.dataset, cfg.worker_num,
self._eval_batch_sampler)
# TestDataset build after user set images, skip loader creation here
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册