提交 2a594b9d 编写于 作者: W WuHaobo

polish config

上级 2811d342
...@@ -14,13 +14,12 @@ ...@@ -14,13 +14,12 @@
import os import os
import yaml import yaml
from ppcls.utils import check from ppcls.utils import check
from ppcls.utils import logger from ppcls.utils import logger
__all__ = ['get_config'] __all__ = ['get_config']
CONFIG_SECS = ['ARCHITECTURE', 'TRAIN', 'VALID', 'OPTIMIZER', 'LEARNING_RATE']
class AttrDict(dict): class AttrDict(dict):
def __getattr__(self, key): def __getattr__(self, key):
...@@ -63,9 +62,8 @@ def print_dict(d, delimiter=0): ...@@ -63,9 +62,8 @@ def print_dict(d, delimiter=0):
Recursively visualize a dict and Recursively visualize a dict and
indenting acrrording by the relationship of keys. indenting acrrording by the relationship of keys.
""" """
for k, v in d.items(): for k, v in sorted(d.items()):
if k in CONFIG_SECS: if k.istitle(): logger.info("-" * 60)
logger.info("-" * 60)
if isinstance(v, dict): if isinstance(v, dict):
logger.info("{}{} : ".format(delimiter * " ", k)) logger.info("{}{} : ".format(delimiter * " ", k))
...@@ -77,27 +75,23 @@ def print_dict(d, delimiter=0): ...@@ -77,27 +75,23 @@ def print_dict(d, delimiter=0):
else: else:
logger.info("{}{} : {}".format(delimiter * " ", k, v)) logger.info("{}{} : {}".format(delimiter * " ", k, v))
if k in CONFIG_SECS: if k.istitle(): logger.info("-" * 60)
logger.info("-" * 60)
def print_config(config): def print_config(config, show=True):
""" """
visualize configs visualize configs
Arguments: Arguments:
config: configs config: configs
""" """
if not show: return
copyright = "PaddleClas is powered by PaddlePaddle" copyright = "PaddleClas is powered by PaddlePaddle"
ad = "https://github.com/PaddlePaddle/PaddleClas" ad = "https://github.com/PaddlePaddle/PaddleClas"
logger.info("\n" * 2) logger.info("\n\n{}\n{}".format(copyright, ad))
logger.info(copyright)
logger.info(ad)
print_dict(config) print_dict(config)
logger.info("-" * 60) logger.info("-" * 60)
...@@ -193,11 +187,9 @@ def get_config(fname, overrides=[], show=True): ...@@ -193,11 +187,9 @@ def get_config(fname, overrides=[], show=True):
assert os.path.exists(fname), \ assert os.path.exists(fname), \
('config file({}) is not exist'.format(fname)) ('config file({}) is not exist'.format(fname))
config = parse_config(fname) config = parse_config(fname)
if show: print_config(config, show)
print_config(config)
if len(overrides) > 0: if len(overrides) > 0:
override_config(config, overrides) override_config(config, overrides)
if show: print_config(config, show)
print_config(config)
check_config(config) check_config(config)
return config return config
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册