未验证 提交 4e33a499 编写于 作者: W wangguanzhong 提交者: GitHub

Cherry pick yaml constructor (#5917)

* yaml support no constructor

* fix yaml constructor in eval & export

* fix typo

* revert config
上级 6ef164f6
......@@ -210,9 +210,17 @@ def create(cls_or_name, **kwargs):
assert type(cls_or_name) in [type, str
], "should be a class or name of a class"
name = type(cls_or_name) == str and cls_or_name or cls_or_name.__name__
assert name in global_config and \
isinstance(global_config[name], SchemaDict), \
"the module {} is not registered".format(name)
if name in global_config:
if isinstance(global_config[name], SchemaDict):
pass
elif hasattr(global_config[name], "__dict__"):
# support instance return directly
return global_config[name]
else:
raise ValueError("The module {} is not registered".format(name))
else:
raise ValueError("The module {} is not registered".format(name))
config = global_config[name]
cls = getattr(config.pymodule, name)
cls_kwargs = {}
......
......@@ -27,3 +27,4 @@ from .category import *
from .keypoint_coco import *
from .mot import *
from .sniper_coco import SniperCOCODataSet
from .dataset import ImageFolder
......@@ -23,6 +23,7 @@ from paddle.io import Dataset
from ppdet.core.workspace import register, serializable
from ppdet.utils.download import get_dataset_path
import copy
import ppdet.data.source as source
@serializable
......@@ -60,6 +61,9 @@ class DetDataset(Dataset):
def __len__(self, ):
return len(self.roidbs)
def __call__(self, *args, **kwargs):
return self
def __getitem__(self, idx):
# data batch
roidb = copy.deepcopy(self.roidbs[idx])
......@@ -198,3 +202,40 @@ class ImageFolder(DetDataset):
def set_images(self, images):
self.image_dir = images
self.roidbs = self._load_images()
@register
class CommonDataset(object):
def __init__(self, **dataset_args):
super(CommonDataset, self).__init__()
dataset_args = copy.deepcopy(dataset_args)
type = dataset_args.pop("name")
self.dataset = getattr(source, type)(**dataset_args)
def __call__(self):
return self.dataset
@register
class TrainDataset(CommonDataset):
pass
@register
class EvalMOTDataset(CommonDataset):
pass
@register
class TestMOTDataset(CommonDataset):
pass
@register
class EvalDataset(CommonDataset):
pass
@register
class TestDataset(CommonDataset):
pass
......@@ -67,10 +67,13 @@ class Trainer(object):
self.is_loaded_weights = False
# build data loader
capital_mode = self.mode.capitalize()
if cfg.architecture in MOT_ARCH and self.mode in ['eval', 'test']:
self.dataset = cfg['{}MOTDataset'.format(self.mode.capitalize())]
self.dataset = self.cfg['{}MOTDataset'.format(
capital_mode)] = create('{}MOTDataset'.format(capital_mode))()
else:
self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())]
self.dataset = self.cfg['{}Dataset'.format(capital_mode)] = create(
'{}Dataset'.format(capital_mode))()
if cfg.architecture == 'DeepSORT' and self.mode == 'train':
logger.error('DeepSORT has no need of training on mot dataset.')
......@@ -81,7 +84,7 @@ class Trainer(object):
self.dataset.set_images(images)
if self.mode == 'train':
self.loader = create('{}Reader'.format(self.mode.capitalize()))(
self.loader = create('{}Reader'.format(capital_mode))(
self.dataset, cfg.worker_num)
if cfg.architecture == 'JDE' and self.mode == 'train':
......@@ -370,6 +373,8 @@ class Trainer(object):
def train(self, validate=False):
assert self.mode == 'train', "Model not in 'train' mode"
Init_mark = False
if validate:
self.cfg.EvalDataset = create("EvalDataset")()
sync_bn = (getattr(self.cfg, 'norm_type', None) == 'sync_bn' and
self.cfg.use_gpu and self._nranks > 1)
......
......@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import math
import weakref
import paddle
......@@ -25,6 +26,7 @@ import paddle.optimizer as optimizer
import paddle.regularizer as regularizer
from ppdet.core.workspace import register, serializable
import copy
__all__ = ['LearningRate', 'OptimizerBuilder']
......@@ -252,7 +254,18 @@ class LearningRate(object):
schedulers=[PiecewiseDecay(), LinearWarmup()]):
super(LearningRate, self).__init__()
self.base_lr = base_lr
self.schedulers = schedulers
self.schedulers = []
schedulers = copy.deepcopy(schedulers)
for sched in schedulers:
if isinstance(sched, dict):
# support dict sched instantiate
module = sys.modules[__name__]
type = sched.pop("name")
scheduler = getattr(module, type)(**sched)
self.schedulers.append(scheduler)
else:
self.schedulers.append(sched)
def __call__(self, step_per_epoch):
assert len(self.schedulers) >= 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册