From 25d6c326938cc1c64c143464fba7cfa2a8849979 Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Wed, 22 Feb 2023 16:00:25 +0800 Subject: [PATCH] fix unittest (#7807) --- configs/datasets/voc.yml | 24 ++++++++++++------------ configs/runtime.yml | 1 + ppdet/core/workspace.py | 11 ++++++++++- ppdet/engine/trainer.py | 8 ++++---- 4 files changed, 27 insertions(+), 17 deletions(-) diff --git a/configs/datasets/voc.yml b/configs/datasets/voc.yml index 9fb492f03..72182bed9 100644 --- a/configs/datasets/voc.yml +++ b/configs/datasets/voc.yml @@ -3,19 +3,19 @@ map_type: 11point num_classes: 20 TrainDataset: - !VOCDataSet - dataset_dir: dataset/voc - anno_path: trainval.txt - label_list: label_list.txt - data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult'] + name: VOCDataSet + dataset_dir: dataset/voc + anno_path: trainval.txt + label_list: label_list.txt + data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult'] EvalDataset: - !VOCDataSet - dataset_dir: dataset/voc - anno_path: test.txt - label_list: label_list.txt - data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult'] + name: VOCDataSet + dataset_dir: dataset/voc + anno_path: test.txt + label_list: label_list.txt + data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult'] TestDataset: - !ImageFolder - anno_path: dataset/voc/label_list.txt + name: ImageFolder + anno_path: dataset/voc/label_list.txt diff --git a/configs/runtime.yml b/configs/runtime.yml index c0920da55..a58b171ce 100644 --- a/configs/runtime.yml +++ b/configs/runtime.yml @@ -1,6 +1,7 @@ use_gpu: true use_xpu: false use_mlu: false +use_npu: false log_iter: 20 save_dir: output snapshot_epoch: 1 diff --git a/ppdet/core/workspace.py b/ppdet/core/workspace.py index b3c932c0a..6735bcfc2 100644 --- a/ppdet/core/workspace.py +++ b/ppdet/core/workspace.py @@ -67,6 +67,15 @@ class AttrDict(dict): return self[key] raise AttributeError("object has no attribute '{}'".format(key)) + def __setattr__(self, key, value): + self[key] = value + + def copy(self): + new_dict = AttrDict() + for k, v in self.items(): + new_dict.update({k: v}) + return new_dict + global_config = AttrDict() @@ -280,4 +289,4 @@ def create(cls_or_name, **kwargs): # prevent modification of global config values of reference types # (e.g., list, dict) from within the created module instances #kwargs = copy.deepcopy(kwargs) - return cls(**cls_kwargs) \ No newline at end of file + return cls(**cls_kwargs) diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index ae0e21d8e..0378e00ec 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -62,7 +62,7 @@ MOT_ARCH = ['JDE', 'FairMOT', 'DeepSORT', 'ByteTrack', 'CenterTrack'] class Trainer(object): def __init__(self, cfg, mode='train'): - self.cfg = cfg + self.cfg = cfg.copy() assert mode.lower() in ['train', 'eval', 'test'], \ "mode should be 'train', 'eval' or 'test'" self.mode = mode.lower() @@ -99,12 +99,12 @@ class Trainer(object): self.dataset, cfg.worker_num) if cfg.architecture == 'JDE' and self.mode == 'train': - cfg['JDEEmbeddingHead'][ + self.cfg['JDEEmbeddingHead'][ 'num_identities'] = self.dataset.num_identities_dict[0] # JDE only support single class MOT now. if cfg.architecture == 'FairMOT' and self.mode == 'train': - cfg['FairMOTEmbeddingHead'][ + self.cfg['FairMOTEmbeddingHead'][ 'num_identities_dict'] = self.dataset.num_identities_dict # FairMOT support single class and multi-class MOT now. @@ -149,7 +149,7 @@ class Trainer(object): reader_name = '{}Reader'.format(self.mode.capitalize()) # If metric is VOC, need to be set collate_batch=False. if cfg.metric == 'VOC': - cfg[reader_name]['collate_batch'] = False + self.cfg[reader_name]['collate_batch'] = False self.loader = create(reader_name)(self.dataset, cfg.worker_num, self._eval_batch_sampler) # TestDataset build after user set images, skip loader creation here -- GitLab