提交 9c61fdd9 编写于 作者: S shippingwang

fix

上级 8a85f447
...@@ -14,13 +14,12 @@ ...@@ -14,13 +14,12 @@
import os import os
import yaml import yaml
from ppcls.utils import check from ppcls.utils import check
from ppcls.utils import logger from ppcls.utils import logger
__all__ = ['get_config'] __all__ = ['get_config']
CONFIG_SECS = ['ARCHITECTURE', 'TRAIN', 'VALID', 'OPTIMIZER', 'LEARNING_RATE']
class AttrDict(dict): class AttrDict(dict):
def __getattr__(self, key): def __getattr__(self, key):
...@@ -47,13 +46,12 @@ def create_attr_dict(yaml_config): ...@@ -47,13 +46,12 @@ def create_attr_dict(yaml_config):
create_attr_dict(yaml_config[key]) create_attr_dict(yaml_config[key])
else: else:
yaml_config[key] = value yaml_config[key] = value
return
def parse_config(cfg_file): def parse_config(cfg_file):
"""Load a config file into AttrDict""" """Load a config file into AttrDict"""
with open(cfg_file, 'r') as fopen: with open(cfg_file, 'r') as fopen:
yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.FullLoader)) yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.SafeLoader))
create_attr_dict(yaml_config) create_attr_dict(yaml_config)
return yaml_config return yaml_config
...@@ -63,14 +61,8 @@ def print_dict(d, delimiter=0): ...@@ -63,14 +61,8 @@ def print_dict(d, delimiter=0):
Recursively visualize a dict and Recursively visualize a dict and
indenting acrrording by the relationship of keys. indenting acrrording by the relationship of keys.
""" """
placeholder = "-" * 60
dk = [] for k, v in sorted(d.items()):
dv = []
for k, v in d.items():
if k in CONFIG_SECS:
logger.info("-" * 60)
if isinstance(v, dict): if isinstance(v, dict):
logger.info("{}{} : ".format(delimiter * " ", k)) logger.info("{}{} : ".format(delimiter * " ", k))
print_dict(v, delimiter + 4) print_dict(v, delimiter + 4)
...@@ -79,16 +71,11 @@ def print_dict(d, delimiter=0): ...@@ -79,16 +71,11 @@ def print_dict(d, delimiter=0):
for value in v: for value in v:
print_dict(value, delimiter + 4) print_dict(value, delimiter + 4)
else: else:
dk.append(k) logger.info("{}{} : {}".format(delimiter * " ", k, v))
dv.append(v)
if k in CONFIG_SECS:
logger.info("-" * 60)
for ki,vi in zip(dk,dv):
logger.info("{}{} : {}".format(delimiter * " ", ki, vi)) if k.isupper():
logger.info(placeholder)
def print_config(config): def print_config(config):
""" """
...@@ -97,18 +84,22 @@ def print_config(config): ...@@ -97,18 +84,22 @@ def print_config(config):
Arguments: Arguments:
config: configs config: configs
""" """
copyright = "PaddleClas is powered by PaddlePaddle !"
copyright = "PaddleClas is powered by PaddlePaddle" info = "For more info please go to the following website."
ad = "https://github.com/PaddlePaddle/PaddleClas" website = "https://github.com/PaddlePaddle/PaddleClas"
AD_LEN = 55
logger.info("\n" * 2)
logger.info(copyright) logger.info("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
logger.info(ad) "=" * (AD_LEN + 4),
"=={}==".format(copyright.center(AD_LEN)),
"=" * (AD_LEN + 4),
"=={}==".format(' ' * AD_LEN),
"=={}==".format(info.center(AD_LEN)),
"=={}==".format(' ' * AD_LEN),
"=={}==".format(website.center(AD_LEN)),
"=" * (AD_LEN + 4), ))
print_dict(config) print_dict(config)
logger.info("-" * 60)
def check_config(config): def check_config(config):
""" """
...@@ -166,7 +157,7 @@ def override(dl, ks, v): ...@@ -166,7 +157,7 @@ def override(dl, ks, v):
override(dl[ks[0]], ks[1:], v) override(dl[ks[0]], ks[1:], v)
def override_config(config, options=[]): def override_config(config, options=None):
""" """
Recursively override the config Recursively override the config
...@@ -181,32 +172,31 @@ def override_config(config, options=[]): ...@@ -181,32 +172,31 @@ def override_config(config, options=[]):
Returns: Returns:
config(dict): replaced config config(dict): replaced config
""" """
for opt in options: if options is not None:
assert isinstance(opt, str), \ for opt in options:
("option({}) should be a str".format(opt)) assert isinstance(opt, str), (
assert "=" in opt, ("option({}) should contain " \ "option({}) should be a str".format(opt))
"a = to distinguish between key and value".format(opt)) assert "=" in opt, (
pair = opt.split('=') "option({}) should contain a ="
assert len(pair) == 2, ("there can be only a = in the option") "to distinguish between key and value".format(opt))
key, value = pair pair = opt.split('=')
keys = key.split('.') assert len(pair) == 2, ("there can be only a = in the option")
override(config, keys, value) key, value = pair
keys = key.split('.')
override(config, keys, value)
return config return config
def get_config(fname, overrides=[], show=True): def get_config(fname, overrides=None, show=True):
""" """
Read config from file Read config from file
""" """
assert os.path.exists(fname), \ assert os.path.exists(fname), (
('config file({}) is not exist'.format(fname)) 'config file({}) is not exist'.format(fname))
config = parse_config(fname) config = parse_config(fname)
override_config(config, overrides)
if show: if show:
print_config(config) print_config(config)
if len(overrides) > 0:
override_config(config, overrides)
if show:
print_config(config)
check_config(config) check_config(config)
return config return config
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册