diff --git a/ppcls/utils/config.py b/ppcls/utils/config.py index 17f7aab30a71b80cd05cabbcb55dcba54d5a8041..524d9a6e3b5b72ebbd8f3068c0de1ea4b306c689 100644 --- a/ppcls/utils/config.py +++ b/ppcls/utils/config.py @@ -14,13 +14,12 @@ import os import yaml + from ppcls.utils import check from ppcls.utils import logger __all__ = ['get_config'] -CONFIG_SECS = ['ARCHITECTURE', 'TRAIN', 'VALID', 'OPTIMIZER', 'LEARNING_RATE'] - class AttrDict(dict): def __getattr__(self, key): @@ -47,13 +46,12 @@ def create_attr_dict(yaml_config): create_attr_dict(yaml_config[key]) else: yaml_config[key] = value - return def parse_config(cfg_file): """Load a config file into AttrDict""" with open(cfg_file, 'r') as fopen: - yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.FullLoader)) + yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.SafeLoader)) create_attr_dict(yaml_config) return yaml_config @@ -63,14 +61,8 @@ def print_dict(d, delimiter=0): Recursively visualize a dict and indenting acrrording by the relationship of keys. """ - - dk = [] - dv = [] - for k, v in d.items(): - - if k in CONFIG_SECS: - logger.info("-" * 60) - + placeholder = "-" * 60 + for k, v in sorted(d.items()): if isinstance(v, dict): logger.info("{}{} : ".format(delimiter * " ", k)) print_dict(v, delimiter + 4) @@ -79,16 +71,11 @@ def print_dict(d, delimiter=0): for value in v: print_dict(value, delimiter + 4) else: - dk.append(k) - dv.append(v) - if k in CONFIG_SECS: - logger.info("-" * 60) - - for ki,vi in zip(dk,dv): + logger.info("{}{} : {}".format(delimiter * " ", k, v)) - logger.info("{}{} : {}".format(delimiter * " ", ki, vi)) + if k.isupper(): + logger.info(placeholder) - def print_config(config): """ @@ -97,18 +84,22 @@ def print_config(config): Arguments: config: configs """ - - copyright = "PaddleClas is powered by PaddlePaddle" - ad = "https://github.com/PaddlePaddle/PaddleClas" - - logger.info("\n" * 2) - logger.info(copyright) - logger.info(ad) - + copyright = "PaddleClas is powered by PaddlePaddle !" + info = "For more info please go to the following website." + website = "https://github.com/PaddlePaddle/PaddleClas" + AD_LEN = 55 + + logger.info("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format( + "=" * (AD_LEN + 4), + "=={}==".format(copyright.center(AD_LEN)), + "=" * (AD_LEN + 4), + "=={}==".format(' ' * AD_LEN), + "=={}==".format(info.center(AD_LEN)), + "=={}==".format(' ' * AD_LEN), + "=={}==".format(website.center(AD_LEN)), + "=" * (AD_LEN + 4), )) print_dict(config) - logger.info("-" * 60) - def check_config(config): """ @@ -166,7 +157,7 @@ def override(dl, ks, v): override(dl[ks[0]], ks[1:], v) -def override_config(config, options=[]): +def override_config(config, options=None): """ Recursively override the config @@ -181,32 +172,31 @@ def override_config(config, options=[]): Returns: config(dict): replaced config """ - for opt in options: - assert isinstance(opt, str), \ - ("option({}) should be a str".format(opt)) - assert "=" in opt, ("option({}) should contain " \ - "a = to distinguish between key and value".format(opt)) - pair = opt.split('=') - assert len(pair) == 2, ("there can be only a = in the option") - key, value = pair - keys = key.split('.') - override(config, keys, value) + if options is not None: + for opt in options: + assert isinstance(opt, str), ( + "option({}) should be a str".format(opt)) + assert "=" in opt, ( + "option({}) should contain a =" + "to distinguish between key and value".format(opt)) + pair = opt.split('=') + assert len(pair) == 2, ("there can be only a = in the option") + key, value = pair + keys = key.split('.') + override(config, keys, value) return config -def get_config(fname, overrides=[], show=True): +def get_config(fname, overrides=None, show=True): """ Read config from file """ - assert os.path.exists(fname), \ - ('config file({}) is not exist'.format(fname)) + assert os.path.exists(fname), ( + 'config file({}) is not exist'.format(fname)) config = parse_config(fname) + override_config(config, overrides) if show: print_config(config) - if len(overrides) > 0: - override_config(config, overrides) - if show: - print_config(config) check_config(config) return config