diff --git a/ppdet/core/workspace.py b/ppdet/core/workspace.py index e633746ed804e29ca9cc53c9b6cf39c1a8a168a6..e56feb31be4d02e81abcdfb6a33fbfc111abb1cc 100644 --- a/ppdet/core/workspace.py +++ b/ppdet/core/workspace.py @@ -210,9 +210,17 @@ def create(cls_or_name, **kwargs): assert type(cls_or_name) in [type, str ], "should be a class or name of a class" name = type(cls_or_name) == str and cls_or_name or cls_or_name.__name__ - assert name in global_config and \ - isinstance(global_config[name], SchemaDict), \ - "the module {} is not registered".format(name) + if name in global_config: + if isinstance(global_config[name], SchemaDict): + pass + elif hasattr(global_config[name], "__dict__"): + # support instance return directly + return global_config[name] + else: + raise ValueError("The module {} is not registered".format(name)) + else: + raise ValueError("The module {} is not registered".format(name)) + config = global_config[name] cls = getattr(config.pymodule, name) cls_kwargs = {} diff --git a/ppdet/data/source/__init__.py b/ppdet/data/source/__init__.py index 3854d3d2530b032b3c84d1ab5f2e01ea963c5c70..e3abb16b606de5501886f1a615fd25a7cd114e61 100644 --- a/ppdet/data/source/__init__.py +++ b/ppdet/data/source/__init__.py @@ -27,3 +27,4 @@ from .category import * from .keypoint_coco import * from .mot import * from .sniper_coco import SniperCOCODataSet +from .dataset import ImageFolder diff --git a/ppdet/data/source/dataset.py b/ppdet/data/source/dataset.py index 1bef548e696764964608ade67b373a1c19c84a96..2ac94e2a45338c1c94280e679a0b1d6c15b9a5b1 100644 --- a/ppdet/data/source/dataset.py +++ b/ppdet/data/source/dataset.py @@ -23,6 +23,7 @@ from paddle.io import Dataset from ppdet.core.workspace import register, serializable from ppdet.utils.download import get_dataset_path import copy +import ppdet.data.source as source @serializable @@ -60,6 +61,9 @@ class DetDataset(Dataset): def __len__(self, ): return len(self.roidbs) + def __call__(self, *args, **kwargs): + return self + def __getitem__(self, idx): # data batch roidb = copy.deepcopy(self.roidbs[idx]) @@ -195,3 +199,40 @@ class ImageFolder(DetDataset): def set_images(self, images): self.image_dir = images self.roidbs = self._load_images() + + +@register +class CommonDataset(object): + def __init__(self, **dataset_args): + super(CommonDataset, self).__init__() + dataset_args = copy.deepcopy(dataset_args) + type = dataset_args.pop("name") + self.dataset = getattr(source, type)(**dataset_args) + + def __call__(self): + return self.dataset + + +@register +class TrainDataset(CommonDataset): + pass + + +@register +class EvalMOTDataset(CommonDataset): + pass + + +@register +class TestMOTDataset(CommonDataset): + pass + + +@register +class EvalDataset(CommonDataset): + pass + + +@register +class TestDataset(CommonDataset): + pass diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 671d373bb13c33d7aa4d4734ad606d62035d7dd8..5aaabcacb4683ad529760d2f60aabaf01133c90a 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -63,17 +63,20 @@ class Trainer(object): self.is_loaded_weights = False # build data loader + capital_mode = self.mode.capitalize() if cfg.architecture in MOT_ARCH and self.mode in ['eval', 'test']: - self.dataset = cfg['{}MOTDataset'.format(self.mode.capitalize())] + self.dataset = self.cfg['{}MOTDataset'.format( + capital_mode)] = create('{}MOTDataset'.format(capital_mode))() else: - self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())] + self.dataset = self.cfg['{}Dataset'.format(capital_mode)] = create( + '{}Dataset'.format(capital_mode))() if cfg.architecture == 'DeepSORT' and self.mode == 'train': logger.error('DeepSORT has no need of training on mot dataset.') sys.exit(1) if self.mode == 'train': - self.loader = create('{}Reader'.format(self.mode.capitalize()))( + self.loader = create('{}Reader'.format(capital_mode))( self.dataset, cfg.worker_num) if cfg.architecture == 'JDE' and self.mode == 'train': @@ -335,6 +338,8 @@ class Trainer(object): def train(self, validate=False): assert self.mode == 'train', "Model not in 'train' mode" Init_mark = False + if validate: + self.cfg.EvalDataset = create("EvalDataset")() model = self.model if self.cfg.get('fleet', False): diff --git a/ppdet/optimizer.py b/ppdet/optimizer.py index fcdcbd8d64475949e211fa09f3c4baffc16abf9a..360daa660157d10197d5b844d042fb262d693df6 100644 --- a/ppdet/optimizer.py +++ b/ppdet/optimizer.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys import math import paddle import paddle.nn as nn @@ -24,6 +25,7 @@ import paddle.optimizer as optimizer import paddle.regularizer as regularizer from ppdet.core.workspace import register, serializable +import copy __all__ = ['LearningRate', 'OptimizerBuilder'] @@ -188,7 +190,18 @@ class LearningRate(object): schedulers=[PiecewiseDecay(), LinearWarmup()]): super(LearningRate, self).__init__() self.base_lr = base_lr - self.schedulers = schedulers + self.schedulers = [] + + schedulers = copy.deepcopy(schedulers) + for sched in schedulers: + if isinstance(sched, dict): + # support dict sched instantiate + module = sys.modules[__name__] + type = sched.pop("name") + scheduler = getattr(module, type)(**sched) + self.schedulers.append(scheduler) + else: + self.schedulers.append(sched) def __call__(self, step_per_epoch): assert len(self.schedulers) >= 1