Commit 2e0bc41c authored by: W wuzewu

Print config

Parent 40ed988d
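This commit removes the pprint-based logging of the parsed config and instead keeps the parsed dict on the instance as `self.dic`, exposes it through `__str__` via `yaml.dump`, and prints it from the train and val entry points. A minimal sketch of the kind of output `print(cfg)` now produces (the sample keys and values below are assumptions, not taken from a real config file):

```python
# Sketch only: yaml.dump renders the parsed config dict the same way
# Config.__str__ does after this commit. The sample values are made up.
import yaml

sample_dic = {
    'batch_size': 2,
    'iters': 10000,
    'learning_rate': {'value': 0.01, 'decay': {'type': 'poly', 'power': 0.9}},
    'optimizer': {'type': 'sgd'},
}

print('---------------Config Information---------------')
print(yaml.dump(sample_dic))
print('------------------------------------------------')
```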
@@ -15,7 +15,6 @@
import codecs
import os
from typing import Any, Callable
import pprint
import yaml
import paddle
@@ -37,10 +36,10 @@ class Config(object):
if not os.path.exists(path):
raise FileNotFoundError('File {} does not exist'.format(path))
self._model = None
self._losses = None
if path.endswith('yml') or path.endswith('yaml'):
dic = self._parse_from_yaml(path)
logger.info('\n' + pprint.pformat(dic))
self._build(dic)
self.dic = self._parse_from_yaml(path)
else:
raise RuntimeError('Config file should be in yaml format!')
@@ -61,6 +60,7 @@ class Config(object):
'''Parse a yaml file and build config'''
with codecs.open(path, 'r', 'utf-8') as file:
dic = yaml.load(file, Loader=yaml.FullLoader)
if '_base_' in dic:
cfg_dir = os.path.dirname(path)
base_path = dic.pop('_base_')
@@ -69,111 +69,85 @@ class Config(object):
dic = self._update_dic(dic, base_dic)
return dic
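The `_base_` handling above lets a config file inherit from another YAML file: the base file is loaded and the derived dict is merged over it via `_update_dic`, which is not shown in this diff. A rough, self-contained approximation of that behaviour (the file contents and the exact merge semantics are assumptions):

```python
# Rough approximation of `_base_` resolution. In the real code the value of
# `_base_` is a path resolved relative to the config's directory and loaded
# from disk; here both files are inlined as strings to keep the sketch runnable.
# The merge below assumes derived keys simply override base keys; the actual
# semantics live in Config._update_dic, which is outside this diff.
import yaml

base_yaml = """
batch_size: 2
iters: 10000
"""

derived_yaml = """
_base_: base.yml
batch_size: 4
"""

base_dic = yaml.load(base_yaml, Loader=yaml.FullLoader)
dic = yaml.load(derived_yaml, Loader=yaml.FullLoader)
dic.pop('_base_')

merged = {**base_dic, **dic}
print(merged)  # {'batch_size': 4, 'iters': 10000}
```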
def _build(self, dic: dict):
'''Build config from dictionary'''
dic = dic.copy()
self._batch_size = dic.get('batch_size', 1)
self._iters = dic.get('iters')
if 'model' not in dic:
raise RuntimeError()
self._model_cfg = dic['model']
self._model = None
self._train_dataset = dic.get('train_dataset')
self._val_dataset = dic.get('val_dataset')
self._learning_rate_cfg = dic.get('learning_rate', {})
self._learning_rate = self._learning_rate_cfg.get('value')
self._decay = self._learning_rate_cfg.get('decay', {
'type': 'poly',
'power': 0.9
})
self._loss_cfg = dic.get('loss', {})
self._losses = None
self._optimizer_cfg = dic.get('optimizer', {})
def update(self,
learning_rate: float = None,
batch_size: int = None,
iters: int = None):
'''Update config'''
if learning_rate:
self._learning_rate = learning_rate
self.dic['learning_rate']['value'] = learning_rate
if batch_size:
self._batch_size = batch_size
self.dic['batch_size'] = batch_size
if iters:
self._iters = iters
self.dic['iters'] = iters
@property
def batch_size(self) -> int:
return self._batch_size
return self.dic.get('batch_size', 1)
@property
def iters(self) -> int:
if not self._iters:
iters = self.dic.get('iters')
if not iters:
raise RuntimeError('No iters specified in the configuration file.')
return self._iters
return iters
@property
def learning_rate(self) -> float:
if not self._learning_rate:
_learning_rate = self.dic.get('learning_rate', {}).get('value')
if not _learning_rate:
raise RuntimeError(
'No learning rate specified in the configuration file.')
if self.decay_type == 'poly':
lr = self._learning_rate
args = self.decay_args
args.setdefault('decay_steps', self.iters)
args.setdefault('end_lr', 0)
args = self.decay_args
decay_type = args.pop('type')
if decay_type == 'poly':
lr = _learning_rate
return paddle.optimizer.PolynomialLR(lr, **args)
else:
raise RuntimeError('Only poly decay is supported.')
@property
def optimizer(self) -> paddle.optimizer.Optimizer:
if self.optimizer_type == 'sgd':
args = self.optimizer_args
optimizer_type = args.pop('type')
if optimizer_type == 'sgd':
lr = self.learning_rate
args = self.optimizer_args
args.setdefault('momentum', 0.9)
return paddle.optimizer.Momentum(
lr, parameters=self.model.parameters(), **args)
else:
raise RuntimeError('Only the sgd optimizer is supported.')
@property
def optimizer_type(self) -> str:
otype = self._optimizer_cfg.get('type')
if not otype:
raise RuntimeError(
'No optimizer type specified in the configuration file.')
return otype
@property
def optimizer_args(self) -> dict:
args = self._optimizer_cfg.copy()
args.pop('type')
return args
args = self.dic.get('optimizer', {}).copy()
if args['type'] == 'sgd':
args.setdefault('momentum', 0.9)
@property
def decay_type(self) -> str:
return self._decay['type']
return args
@property
def decay_args(self) -> dict:
args = self._decay.copy()
args.pop('type')
args = self.dic.get('learning_rate', {}).get('decay', {
'type': 'poly',
'power': 0.9
}).copy()
if args['type'] == 'poly':
args.setdefault('decay_steps', self.iters)
args.setdefault('end_lr', 0)
return args
@property
def loss(self) -> list:
args = self.dic.get('loss', {}).copy()
if not self._losses:
args = self._loss_cfg.copy()
self._losses = dict()
for key, val in args.items():
if key == 'types':
@@ -191,21 +165,26 @@ class Config(object):
@property
def model(self) -> Callable:
model_cfg = self.dic.get('model', {}).copy()
if not model_cfg:
raise RuntimeError('No model specified in the configuration file.')
if not self._model:
self._model = self._load_object(self._model_cfg)
self._model = self._load_object(model_cfg)
return self._model
@property
def train_dataset(self) -> Any:
if not self._train_dataset:
_train_dataset = self.dic.get('train_dataset', {}).copy()
if not _train_dataset:
return None
return self._load_object(self._train_dataset)
return self._load_object(_train_dataset)
@property
def val_dataset(self) -> Any:
if not self._val_dataset:
_val_dataset = self.dic.get('val_dataset', {}).copy()
if not _val_dataset:
return None
return self._load_object(self._val_dataset)
return self._load_object(_val_dataset)
def _load_component(self, com_name: str) -> Any:
com_list = [
@@ -243,3 +222,6 @@ class Config(object):
def _is_meta_type(self, item: Any) -> bool:
return isinstance(item, dict) and 'type' in item
def __str__(self) -> str:
return yaml.dump(self.dic)
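With these changes the properties read lazily from `self.dic`, `update()` writes command-line overrides back into it, and `__str__` dumps it with `yaml.dump`. A usage sketch under assumed conditions (the import path `paddleseg.cvlibs.Config` and the file name `example.yml` are assumptions; running it requires the PaddleSeg sources and paddle installed):

```python
# Hypothetical usage sketch; the import path and file name are assumptions.
import yaml

from paddleseg.cvlibs import Config  # assumed location of this class

# Write a tiny config so the example is self-contained.
with open('example.yml', 'w') as f:
    yaml.dump({
        'batch_size': 2,
        'iters': 10000,
        'learning_rate': {'value': 0.01},
    }, f)

cfg = Config('example.yml')
cfg.update(learning_rate=0.001, batch_size=4)  # overrides land in cfg.dic

print('---------------Config Information---------------')
print(cfg)             # __str__ -> yaml.dump(self.dic)
print('------------------------------------------------')
print(cfg.batch_size)  # 4, read back from cfg.dic
print(cfg.iters)       # 10000
```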
@@ -100,15 +100,22 @@ def main(args):
raise RuntimeError('No configuration file specified.')
cfg = Config(args.cfg)
cfg.update(
learning_rate=args.learning_rate,
iters=args.iters,
batch_size=args.batch_size)
train_dataset = cfg.train_dataset
if not train_dataset:
raise RuntimeError(
'The training dataset is not specified in the configuration file.')
val_dataset = cfg.val_dataset if args.do_eval else None
losses = cfg.loss
print('---------------Config Information---------------')
print(cfg)
print('------------------------------------------------')
train(
cfg.model,
train_dataset,
......
@@ -55,6 +55,11 @@ def main(args):
raise RuntimeError(
'The verification dataset is not specified in the configuration file.'
)
print('---------------Config Information---------------')
print(cfg)
print('------------------------------------------------')
evaluate(
cfg.model,
val_dataset,
......