未验证 提交 9535ca87 编写于 作者: W wangguanzhong 提交者: GitHub

yaml support no constructor (#5847)

上级 ea4cacde
......@@ -3,19 +3,19 @@ map_type: integral
num_classes: 4
TrainDataset:
!VOCDataSet
dataset_dir: dataset/roadsign_voc
anno_path: train.txt
label_list: label_list.txt
data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']
name: VOCDataSet
dataset_dir: dataset/roadsign_voc
anno_path: train.txt
label_list: label_list.txt
data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']
EvalDataset:
!VOCDataSet
dataset_dir: dataset/roadsign_voc
anno_path: valid.txt
label_list: label_list.txt
data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']
name: VOCDataSet
dataset_dir: dataset/roadsign_voc
anno_path: valid.txt
label_list: label_list.txt
data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']
TestDataset:
!ImageFolder
anno_path: dataset/roadsign_voc/label_list.txt
name: ImageFolder
anno_path: dataset/roadsign_voc/label_list.txt
......@@ -3,12 +3,12 @@ epoch: 40
LearningRate:
base_lr: 0.0001
schedulers:
- !PiecewiseDecay
- name: PiecewiseDecay
gamma: 0.1
milestones:
- 32
- 36
- !LinearWarmup
- name: LinearWarmup
start_factor: 0.3333333333333333
steps: 100
......
......@@ -210,9 +210,17 @@ def create(cls_or_name, **kwargs):
assert type(cls_or_name) in [type, str
], "should be a class or name of a class"
name = type(cls_or_name) == str and cls_or_name or cls_or_name.__name__
assert name in global_config and \
isinstance(global_config[name], SchemaDict), \
"the module {} is not registered".format(name)
if name in global_config:
if isinstance(global_config[name], SchemaDict):
pass
elif hasattr(global_config[name], "__dict__"):
# support instance return directly
return global_config[name]
else:
raise ValueError("The module {} is not registered".format(name))
else:
raise ValueError("The module {} is not registered".format(name))
config = global_config[name]
cls = getattr(config.pymodule, name)
cls_kwargs = {}
......
......@@ -23,6 +23,7 @@ from paddle.io import Dataset
from ppdet.core.workspace import register, serializable
from ppdet.utils.download import get_dataset_path
import copy
import ppdet.data.source as source
@serializable
......@@ -60,6 +61,9 @@ class DetDataset(Dataset):
def __len__(self, ):
return len(self.roidbs)
def __call__(self, *args, **kwargs):
return self
def __getitem__(self, idx):
# data batch
roidb = copy.deepcopy(self.roidbs[idx])
......@@ -198,3 +202,40 @@ class ImageFolder(DetDataset):
def set_images(self, images):
self.image_dir = images
self.roidbs = self._load_images()
@register
class CommonDataset(object):
def __init__(self, **dataset_args):
super(CommonDataset, self).__init__()
dataset_args = copy.deepcopy(dataset_args)
type = dataset_args.pop("name")
self.dataset = getattr(source, type)(**dataset_args)
def __call__(self):
return self.dataset
@register
class TrainDataset(CommonDataset):
pass
@register
class EvalMOTDataset(CommonDataset):
pass
@register
class TestMOTDataset(CommonDataset):
pass
@register
class EvalDataset(CommonDataset):
pass
@register
class TestDataset(CommonDataset):
pass
......@@ -68,9 +68,10 @@ class Trainer(object):
# build data loader
if cfg.architecture in MOT_ARCH and self.mode in ['eval', 'test']:
self.dataset = cfg['{}MOTDataset'.format(self.mode.capitalize())]
self.dataset = create('{}MOTDataset'.format(self.mode.capitalize(
)))()
else:
self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())]
self.dataset = create('{}Dataset'.format(self.mode.capitalize()))()
if cfg.architecture == 'DeepSORT' and self.mode == 'train':
logger.error('DeepSORT has no need of training on mot dataset.')
......
......@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import math
import weakref
import paddle
......@@ -25,6 +26,7 @@ import paddle.optimizer as optimizer
import paddle.regularizer as regularizer
from ppdet.core.workspace import register, serializable
import copy
__all__ = ['LearningRate', 'OptimizerBuilder']
......@@ -252,7 +254,18 @@ class LearningRate(object):
schedulers=[PiecewiseDecay(), LinearWarmup()]):
super(LearningRate, self).__init__()
self.base_lr = base_lr
self.schedulers = schedulers
self.schedulers = []
schedulers = copy.deepcopy(schedulers)
for sched in schedulers:
if isinstance(sched, dict):
# support dict sched instantiate
module = sys.modules[__name__]
type = sched.pop("name")
scheduler = getattr(module, type)(**sched)
self.schedulers.append(scheduler)
else:
self.schedulers.append(sched)
def __call__(self, step_per_epoch):
assert len(self.schedulers) >= 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册