未验证 提交 9535ca87 编写于 作者: W wangguanzhong 提交者: GitHub

yaml support no constructor (#5847)

上级 ea4cacde
...@@ -3,19 +3,19 @@ map_type: integral ...@@ -3,19 +3,19 @@ map_type: integral
num_classes: 4 num_classes: 4
TrainDataset: TrainDataset:
!VOCDataSet name: VOCDataSet
dataset_dir: dataset/roadsign_voc dataset_dir: dataset/roadsign_voc
anno_path: train.txt anno_path: train.txt
label_list: label_list.txt label_list: label_list.txt
data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult'] data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']
EvalDataset: EvalDataset:
!VOCDataSet name: VOCDataSet
dataset_dir: dataset/roadsign_voc dataset_dir: dataset/roadsign_voc
anno_path: valid.txt anno_path: valid.txt
label_list: label_list.txt label_list: label_list.txt
data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult'] data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']
TestDataset: TestDataset:
!ImageFolder name: ImageFolder
anno_path: dataset/roadsign_voc/label_list.txt anno_path: dataset/roadsign_voc/label_list.txt
...@@ -3,12 +3,12 @@ epoch: 40 ...@@ -3,12 +3,12 @@ epoch: 40
LearningRate: LearningRate:
base_lr: 0.0001 base_lr: 0.0001
schedulers: schedulers:
- !PiecewiseDecay - name: PiecewiseDecay
gamma: 0.1 gamma: 0.1
milestones: milestones:
- 32 - 32
- 36 - 36
- !LinearWarmup - name: LinearWarmup
start_factor: 0.3333333333333333 start_factor: 0.3333333333333333
steps: 100 steps: 100
......
...@@ -210,9 +210,17 @@ def create(cls_or_name, **kwargs): ...@@ -210,9 +210,17 @@ def create(cls_or_name, **kwargs):
assert type(cls_or_name) in [type, str assert type(cls_or_name) in [type, str
], "should be a class or name of a class" ], "should be a class or name of a class"
name = type(cls_or_name) == str and cls_or_name or cls_or_name.__name__ name = type(cls_or_name) == str and cls_or_name or cls_or_name.__name__
assert name in global_config and \ if name in global_config:
isinstance(global_config[name], SchemaDict), \ if isinstance(global_config[name], SchemaDict):
"the module {} is not registered".format(name) pass
elif hasattr(global_config[name], "__dict__"):
# support instance return directly
return global_config[name]
else:
raise ValueError("The module {} is not registered".format(name))
else:
raise ValueError("The module {} is not registered".format(name))
config = global_config[name] config = global_config[name]
cls = getattr(config.pymodule, name) cls = getattr(config.pymodule, name)
cls_kwargs = {} cls_kwargs = {}
......
...@@ -23,6 +23,7 @@ from paddle.io import Dataset ...@@ -23,6 +23,7 @@ from paddle.io import Dataset
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
from ppdet.utils.download import get_dataset_path from ppdet.utils.download import get_dataset_path
import copy import copy
import ppdet.data.source as source
@serializable @serializable
...@@ -60,6 +61,9 @@ class DetDataset(Dataset): ...@@ -60,6 +61,9 @@ class DetDataset(Dataset):
def __len__(self, ): def __len__(self, ):
return len(self.roidbs) return len(self.roidbs)
def __call__(self, *args, **kwargs):
return self
def __getitem__(self, idx): def __getitem__(self, idx):
# data batch # data batch
roidb = copy.deepcopy(self.roidbs[idx]) roidb = copy.deepcopy(self.roidbs[idx])
...@@ -198,3 +202,40 @@ class ImageFolder(DetDataset): ...@@ -198,3 +202,40 @@ class ImageFolder(DetDataset):
def set_images(self, images): def set_images(self, images):
self.image_dir = images self.image_dir = images
self.roidbs = self._load_images() self.roidbs = self._load_images()
@register
class CommonDataset(object):
def __init__(self, **dataset_args):
super(CommonDataset, self).__init__()
dataset_args = copy.deepcopy(dataset_args)
type = dataset_args.pop("name")
self.dataset = getattr(source, type)(**dataset_args)
def __call__(self):
return self.dataset
@register
class TrainDataset(CommonDataset):
pass
@register
class EvalMOTDataset(CommonDataset):
pass
@register
class TestMOTDataset(CommonDataset):
pass
@register
class EvalDataset(CommonDataset):
pass
@register
class TestDataset(CommonDataset):
pass
...@@ -68,9 +68,10 @@ class Trainer(object): ...@@ -68,9 +68,10 @@ class Trainer(object):
# build data loader # build data loader
if cfg.architecture in MOT_ARCH and self.mode in ['eval', 'test']: if cfg.architecture in MOT_ARCH and self.mode in ['eval', 'test']:
self.dataset = cfg['{}MOTDataset'.format(self.mode.capitalize())] self.dataset = create('{}MOTDataset'.format(self.mode.capitalize(
)))()
else: else:
self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())] self.dataset = create('{}Dataset'.format(self.mode.capitalize()))()
if cfg.architecture == 'DeepSORT' and self.mode == 'train': if cfg.architecture == 'DeepSORT' and self.mode == 'train':
logger.error('DeepSORT has no need of training on mot dataset.') logger.error('DeepSORT has no need of training on mot dataset.')
......
...@@ -16,6 +16,7 @@ from __future__ import absolute_import ...@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import sys
import math import math
import weakref import weakref
import paddle import paddle
...@@ -25,6 +26,7 @@ import paddle.optimizer as optimizer ...@@ -25,6 +26,7 @@ import paddle.optimizer as optimizer
import paddle.regularizer as regularizer import paddle.regularizer as regularizer
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
import copy
__all__ = ['LearningRate', 'OptimizerBuilder'] __all__ = ['LearningRate', 'OptimizerBuilder']
...@@ -252,7 +254,18 @@ class LearningRate(object): ...@@ -252,7 +254,18 @@ class LearningRate(object):
schedulers=[PiecewiseDecay(), LinearWarmup()]): schedulers=[PiecewiseDecay(), LinearWarmup()]):
super(LearningRate, self).__init__() super(LearningRate, self).__init__()
self.base_lr = base_lr self.base_lr = base_lr
self.schedulers = schedulers self.schedulers = []
schedulers = copy.deepcopy(schedulers)
for sched in schedulers:
if isinstance(sched, dict):
# support dict sched instantiate
module = sys.modules[__name__]
type = sched.pop("name")
scheduler = getattr(module, type)(**sched)
self.schedulers.append(scheduler)
else:
self.schedulers.append(sched)
def __call__(self, step_per_epoch): def __call__(self, step_per_epoch):
assert len(self.schedulers) >= 1 assert len(self.schedulers) >= 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册