提交 2e0bc41c 编写于 作者: W wuzewu

Print config

上级 40ed988d
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
import codecs import codecs
import os import os
from typing import Any, Callable from typing import Any, Callable
import pprint
import yaml import yaml
import paddle import paddle
...@@ -37,10 +36,10 @@ class Config(object): ...@@ -37,10 +36,10 @@ class Config(object):
if not os.path.exists(path): if not os.path.exists(path):
raise FileNotFoundError('File {} does not exist'.format(path)) raise FileNotFoundError('File {} does not exist'.format(path))
self._model = None
self._losses = None
if path.endswith('yml') or path.endswith('yaml'): if path.endswith('yml') or path.endswith('yaml'):
dic = self._parse_from_yaml(path) self.dic = self._parse_from_yaml(path)
logger.info('\n' + pprint.pformat(dic))
self._build(dic)
else: else:
raise RuntimeError('Config file should in yaml format!') raise RuntimeError('Config file should in yaml format!')
...@@ -61,6 +60,7 @@ class Config(object): ...@@ -61,6 +60,7 @@ class Config(object):
'''Parse a yaml file and build config''' '''Parse a yaml file and build config'''
with codecs.open(path, 'r', 'utf-8') as file: with codecs.open(path, 'r', 'utf-8') as file:
dic = yaml.load(file, Loader=yaml.FullLoader) dic = yaml.load(file, Loader=yaml.FullLoader)
if '_base_' in dic: if '_base_' in dic:
cfg_dir = os.path.dirname(path) cfg_dir = os.path.dirname(path)
base_path = dic.pop('_base_') base_path = dic.pop('_base_')
...@@ -69,111 +69,85 @@ class Config(object): ...@@ -69,111 +69,85 @@ class Config(object):
dic = self._update_dic(dic, base_dic) dic = self._update_dic(dic, base_dic)
return dic return dic
def _build(self, dic: dict):
'''Build config from dictionary'''
dic = dic.copy()
self._batch_size = dic.get('batch_size', 1)
self._iters = dic.get('iters')
if 'model' not in dic:
raise RuntimeError()
self._model_cfg = dic['model']
self._model = None
self._train_dataset = dic.get('train_dataset')
self._val_dataset = dic.get('val_dataset')
self._learning_rate_cfg = dic.get('learning_rate', {})
self._learning_rate = self._learning_rate_cfg.get('value')
self._decay = self._learning_rate_cfg.get('decay', {
'type': 'poly',
'power': 0.9
})
self._loss_cfg = dic.get('loss', {})
self._losses = None
self._optimizer_cfg = dic.get('optimizer', {})
def update(self, def update(self,
learning_rate: float = None, learning_rate: float = None,
batch_size: int = None, batch_size: int = None,
iters: int = None): iters: int = None):
'''Update config''' '''Update config'''
if learning_rate: if learning_rate:
self._learning_rate = learning_rate self.dic['learning_rate']['value'] = learning_rate
if batch_size: if batch_size:
self._batch_size = batch_size self.dic['batch_size'] = batch_size
if iters: if iters:
self._iters = iters self.dic['iters'] = iters
@property @property
def batch_size(self) -> int: def batch_size(self) -> int:
return self._batch_size return self.dic.get('batch_size', 1)
@property @property
def iters(self) -> int: def iters(self) -> int:
if not self._iters: iters = self.dic.get('iters')
if not iters:
raise RuntimeError('No iters specified in the configuration file.') raise RuntimeError('No iters specified in the configuration file.')
return self._iters return iters
@property @property
def learning_rate(self) -> float: def learning_rate(self) -> float:
if not self._learning_rate: _learning_rate = self.dic.get('learning_rate', {}).get('value')
if not _learning_rate:
raise RuntimeError( raise RuntimeError(
'No learning rate specified in the configuration file.') 'No learning rate specified in the configuration file.')
if self.decay_type == 'poly': args = self.decay_args
lr = self._learning_rate decay_type = args.pop('type')
args = self.decay_args
args.setdefault('decay_steps', self.iters) if decay_type == 'poly':
args.setdefault('end_lr', 0) lr = _learning_rate
return paddle.optimizer.PolynomialLR(lr, **args) return paddle.optimizer.PolynomialLR(lr, **args)
else: else:
raise RuntimeError('Only poly decay support.') raise RuntimeError('Only poly decay support.')
@property @property
def optimizer(self) -> paddle.optimizer.Optimizer: def optimizer(self) -> paddle.optimizer.Optimizer:
if self.optimizer_type == 'sgd': args = self.optimizer_args
optimizer_type = args.pop('type')
if optimizer_type == 'sgd':
lr = self.learning_rate lr = self.learning_rate
args = self.optimizer_args
args.setdefault('momentum', 0.9)
return paddle.optimizer.Momentum( return paddle.optimizer.Momentum(
lr, parameters=self.model.parameters(), **args) lr, parameters=self.model.parameters(), **args)
else: else:
raise RuntimeError('Only sgd optimizer support.') raise RuntimeError('Only sgd optimizer support.')
@property
def optimizer_type(self) -> str:
otype = self._optimizer_cfg.get('type')
if not otype:
raise RuntimeError(
'No optimizer type specified in the configuration file.')
return otype
@property @property
def optimizer_args(self) -> dict: def optimizer_args(self) -> dict:
args = self._optimizer_cfg.copy() args = self.dic.get('optimizer', {}).copy()
args.pop('type') if args['type'] == 'sgd':
return args args.setdefault('momentum', 0.9)
@property return args
def decay_type(self) -> str:
return self._decay['type']
@property @property
def decay_args(self) -> dict: def decay_args(self) -> dict:
args = self._decay.copy() args = self.dic.get('learning_rate', {}).get('decay', {
args.pop('type') 'type': 'poly',
'power': 0.9
}).copy()
if args['type'] == 'poly':
args.setdefault('decay_steps', self.iters)
args.setdefault('end_lr', 0)
return args return args
@property @property
def loss(self) -> list: def loss(self) -> list:
args = self.dic.get('loss', {}).copy()
if not self._losses: if not self._losses:
args = self._loss_cfg.copy()
self._losses = dict() self._losses = dict()
for key, val in args.items(): for key, val in args.items():
if key == 'types': if key == 'types':
...@@ -191,21 +165,26 @@ class Config(object): ...@@ -191,21 +165,26 @@ class Config(object):
@property @property
def model(self) -> Callable: def model(self) -> Callable:
model_cfg = self.dic.get('model').copy()
if not model_cfg:
raise RuntimeError('No model specified in the configuration file.')
if not self._model: if not self._model:
self._model = self._load_object(self._model_cfg) self._model = self._load_object(model_cfg)
return self._model return self._model
@property @property
def train_dataset(self) -> Any: def train_dataset(self) -> Any:
if not self._train_dataset: _train_dataset = self.dic.get('train_dataset').copy()
if not _train_dataset:
return None return None
return self._load_object(self._train_dataset) return self._load_object(_train_dataset)
@property @property
def val_dataset(self) -> Any: def val_dataset(self) -> Any:
if not self._val_dataset: _val_dataset = self.dic.get('val_dataset').copy()
if not _val_dataset:
return None return None
return self._load_object(self._val_dataset) return self._load_object(_val_dataset)
def _load_component(self, com_name: str) -> Any: def _load_component(self, com_name: str) -> Any:
com_list = [ com_list = [
...@@ -243,3 +222,6 @@ class Config(object): ...@@ -243,3 +222,6 @@ class Config(object):
def _is_meta_type(self, item: Any) -> bool: def _is_meta_type(self, item: Any) -> bool:
return isinstance(item, dict) and 'type' in item return isinstance(item, dict) and 'type' in item
def __str__(self) -> str:
return yaml.dump(self.dic)
...@@ -100,15 +100,22 @@ def main(args): ...@@ -100,15 +100,22 @@ def main(args):
raise RuntimeError('No configuration file specified.') raise RuntimeError('No configuration file specified.')
cfg = Config(args.cfg) cfg = Config(args.cfg)
cfg.update(
learning_rate=args.learning_rate,
iters=args.iters,
batch_size=args.batch_size)
train_dataset = cfg.train_dataset train_dataset = cfg.train_dataset
if not train_dataset: if not train_dataset:
raise RuntimeError( raise RuntimeError(
'The training dataset is not specified in the configuration file.') 'The training dataset is not specified in the configuration file.')
val_dataset = cfg.val_dataset if args.do_eval else None val_dataset = cfg.val_dataset if args.do_eval else None
losses = cfg.loss losses = cfg.loss
print('---------------Config Information---------------')
print(cfg)
print('------------------------------------------------')
train( train(
cfg.model, cfg.model,
train_dataset, train_dataset,
......
...@@ -55,6 +55,11 @@ def main(args): ...@@ -55,6 +55,11 @@ def main(args):
raise RuntimeError( raise RuntimeError(
'The verification dataset is not specified in the configuration file.' 'The verification dataset is not specified in the configuration file.'
) )
print('---------------Config Information---------------')
print(cfg)
print('------------------------------------------------')
evaluate( evaluate(
cfg.model, cfg.model,
val_dataset, val_dataset,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册