diff --git a/configs/datasets/roadsign_voc.yml b/configs/datasets/roadsign_voc.yml index ddbfc7889e0027d85971c6ab11f3f33adfe8be71..9a081611aa8dafef5d5c6f1af1476cc038db5702 100644 --- a/configs/datasets/roadsign_voc.yml +++ b/configs/datasets/roadsign_voc.yml @@ -3,19 +3,19 @@ map_type: integral num_classes: 4 TrainDataset: - !VOCDataSet - dataset_dir: dataset/roadsign_voc - anno_path: train.txt - label_list: label_list.txt - data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult'] + name: VOCDataSet + dataset_dir: dataset/roadsign_voc + anno_path: train.txt + label_list: label_list.txt + data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult'] EvalDataset: - !VOCDataSet - dataset_dir: dataset/roadsign_voc - anno_path: valid.txt - label_list: label_list.txt - data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult'] + name: VOCDataSet + dataset_dir: dataset/roadsign_voc + anno_path: valid.txt + label_list: label_list.txt + data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult'] TestDataset: - !ImageFolder - anno_path: dataset/roadsign_voc/label_list.txt + name: ImageFolder + anno_path: dataset/roadsign_voc/label_list.txt diff --git a/configs/yolov3/_base_/optimizer_40e.yml b/configs/yolov3/_base_/optimizer_40e.yml index 0f858df59921e20398e34d019277e39c10abd583..7cf676d7119162d55dc0a2566c0590457344cfd3 100644 --- a/configs/yolov3/_base_/optimizer_40e.yml +++ b/configs/yolov3/_base_/optimizer_40e.yml @@ -3,12 +3,12 @@ epoch: 40 LearningRate: base_lr: 0.0001 schedulers: - - !PiecewiseDecay + - name: PiecewiseDecay gamma: 0.1 milestones: - 32 - 36 - - !LinearWarmup + - name: LinearWarmup start_factor: 0.3333333333333333 steps: 100 diff --git a/ppdet/core/workspace.py b/ppdet/core/workspace.py index e633746ed804e29ca9cc53c9b6cf39c1a8a168a6..e56feb31be4d02e81abcdfb6a33fbfc111abb1cc 100644 --- a/ppdet/core/workspace.py +++ b/ppdet/core/workspace.py @@ -210,9 +210,17 @@ def create(cls_or_name, **kwargs): assert type(cls_or_name) in [type, str ], "should be a class or name of a class" name = type(cls_or_name) == str and cls_or_name or cls_or_name.__name__ - assert name in global_config and \ - isinstance(global_config[name], SchemaDict), \ - "the module {} is not registered".format(name) + if name in global_config: + if isinstance(global_config[name], SchemaDict): + pass + elif hasattr(global_config[name], "__dict__"): + # support instance return directly + return global_config[name] + else: + raise ValueError("The module {} is not registered".format(name)) + else: + raise ValueError("The module {} is not registered".format(name)) + config = global_config[name] cls = getattr(config.pymodule, name) cls_kwargs = {} diff --git a/ppdet/data/source/dataset.py b/ppdet/data/source/dataset.py index b8193c1c92eb47f26db27e9fd601c1c657e8dd63..51de675ef09e0ac7485a68a674bb997fa0c7696c 100644 --- a/ppdet/data/source/dataset.py +++ b/ppdet/data/source/dataset.py @@ -23,6 +23,7 @@ from paddle.io import Dataset from ppdet.core.workspace import register, serializable from ppdet.utils.download import get_dataset_path import copy +import ppdet.data.source as source @serializable @@ -60,6 +61,9 @@ class DetDataset(Dataset): def __len__(self, ): return len(self.roidbs) + def __call__(self, *args, **kwargs): + return self + def __getitem__(self, idx): # data batch roidb = copy.deepcopy(self.roidbs[idx]) @@ -198,3 +202,40 @@ class ImageFolder(DetDataset): def set_images(self, images): self.image_dir = images self.roidbs = self._load_images() + + +@register +class CommonDataset(object): + def __init__(self, **dataset_args): + super(CommonDataset, self).__init__() + dataset_args = copy.deepcopy(dataset_args) + type = dataset_args.pop("name") + self.dataset = getattr(source, type)(**dataset_args) + + def __call__(self): + return self.dataset + + +@register +class TrainDataset(CommonDataset): + pass + + +@register +class EvalMOTDataset(CommonDataset): + pass + + +@register +class TestMOTDataset(CommonDataset): + pass + + +@register +class EvalDataset(CommonDataset): + pass + + +@register +class TestDataset(CommonDataset): + pass diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 6697d3c4111d1d6a370a198791aac60c8cf17231..b08eb50de4bfcbd6cd0545cf8c3c7b9a41678341 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -68,9 +68,10 @@ class Trainer(object): # build data loader if cfg.architecture in MOT_ARCH and self.mode in ['eval', 'test']: - self.dataset = cfg['{}MOTDataset'.format(self.mode.capitalize())] + self.dataset = create('{}MOTDataset'.format(self.mode.capitalize( + )))() else: - self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())] + self.dataset = create('{}Dataset'.format(self.mode.capitalize()))() if cfg.architecture == 'DeepSORT' and self.mode == 'train': logger.error('DeepSORT has no need of training on mot dataset.') diff --git a/ppdet/optimizer.py b/ppdet/optimizer.py index 98a5e5f9caa4bbf436009b9f3860222648bf716b..b591062b3e51ddd0ca77fc832957eb8ba8ea2b1f 100644 --- a/ppdet/optimizer.py +++ b/ppdet/optimizer.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys import math import weakref import paddle @@ -25,6 +26,7 @@ import paddle.optimizer as optimizer import paddle.regularizer as regularizer from ppdet.core.workspace import register, serializable +import copy __all__ = ['LearningRate', 'OptimizerBuilder'] @@ -252,7 +254,18 @@ class LearningRate(object): schedulers=[PiecewiseDecay(), LinearWarmup()]): super(LearningRate, self).__init__() self.base_lr = base_lr - self.schedulers = schedulers + self.schedulers = [] + + schedulers = copy.deepcopy(schedulers) + for sched in schedulers: + if isinstance(sched, dict): + # support dict sched instantiate + module = sys.modules[__name__] + type = sched.pop("name") + scheduler = getattr(module, type)(**sched) + self.schedulers.append(scheduler) + else: + self.schedulers.append(sched) def __call__(self, step_per_epoch): assert len(self.schedulers) >= 1