# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import logging
from ast import literal_eval

import yaml

logger = logging.getLogger(__name__)

# Section names accepted by merge_configs(); each one is upper-cased to find
# the matching top-level entry of the config (e.g. 'train' -> cfg.TRAIN).
CONFIG_SECS = [
    'train',
    'valid',
    'test',
    'infer',
]


class AttrDict(dict):
    """A dict whose items can also be read and written as attributes.

    ``cfg.TRAIN`` is equivalent to ``cfg['TRAIN']`` and ``cfg.x = 1`` is
    equivalent to ``cfg['x'] = 1``.
    """

    def __getattr__(self, key):
        """Return ``self[key]``; only called when normal lookup fails.

        Raises:
            AttributeError: if ``key`` is not an item of the dict. Raising
                AttributeError (not the KeyError that ``self[key]`` would
                produce) is what lets ``hasattr()``, ``getattr()`` with a
                default, and protocol probes such as the one done by
                ``copy.deepcopy`` work on instances.
        """
        if key in self:
            return self[key]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, key))

    def __setattr__(self, key, value):
        """Store ``value`` as a dict item unless ``key`` is a real attribute.

        Names already present in the instance ``__dict__`` are updated there;
        every other name becomes (or overwrites) a dict item.
        """
        if key in self.__dict__:
            self.__dict__[key] = value
        else:
            self[key] = value


def parse_config(cfg_file):
    """Load a YAML config file into a (recursively converted) AttrDict.

    Args:
        cfg_file: path of the YAML file to read.

    Returns:
        An AttrDict built from the file's top-level mapping, post-processed
        by create_attr_dict(). An empty file yields an empty AttrDict.
    """
    with open(cfg_file, 'r') as fopen:
        # SECURITY NOTE: yaml.Loader can construct arbitrary Python objects.
        # Only load config files from trusted sources; consider
        # yaml.SafeLoader if no config relies on Python-specific tags.
        loaded = yaml.load(fopen, Loader=yaml.Loader)
    if loaded is None:
        # yaml.load() returns None for an empty document; treat it as an
        # empty config instead of failing inside AttrDict(None).
        loaded = {}
    yaml_config = AttrDict(loaded)
    create_attr_dict(yaml_config)
    return yaml_config


def create_attr_dict(yaml_config):
    """Convert nested dicts to AttrDict and parse literal strings, in place.

    For every item of ``yaml_config``:
      * a plain ``dict`` value is replaced by an AttrDict and processed
        recursively (an existing AttrDict value is processed recursively
        as well);
      * a ``str`` value is passed to ``ast.literal_eval`` and replaced by
        the result when it is a valid Python literal (``"0.1"`` -> ``0.1``,
        ``"[1, 2]"`` -> ``[1, 2]``); any other string is kept unchanged.

    Args:
        yaml_config: the mapping to convert; it is modified in place.

    Returns:
        None.
    """
    # Only values of existing keys are reassigned below, so the dict does not
    # change size while it is being iterated.
    for key, value in yaml_config.items():
        # Exact type check on purpose: AttrDict instances (and other dict
        # subclasses) must not be re-wrapped/copied.
        if type(value) is dict:
            value = AttrDict(value)
        elif isinstance(value, str):
            try:
                value = literal_eval(value)
            except (ValueError, TypeError, SyntaxError, MemoryError,
                    RecursionError):
                # Not a Python literal: keep the raw string.
                pass
        yaml_config[key] = value
        if isinstance(value, AttrDict):
            create_attr_dict(value)
    return


def merge_configs(cfg, sec, args_dict):
    """Override existing options of one config section with given values.

    Args:
        cfg: the config object; the section is read from the attribute
            named ``sec.upper()``.
        sec: section name, one of CONFIG_SECS.
        args_dict: mapping of option name -> new value. Entries whose value
            is None are skipped, and so are names that are not already
            attributes of the section.

    Returns:
        ``cfg`` itself (the section is updated in place).
    """
    assert sec in CONFIG_SECS, "invalid config section {}".format(sec)
    sec_dict = getattr(cfg, sec.upper())
    for k, v in args_dict.items():
        if v is None:
            continue
        try:
            if hasattr(sec_dict, k):
                setattr(sec_dict, k, v)
        except Exception as err:
            # Overrides are best-effort: an option that cannot be probed or
            # set is skipped, but the reason is recorded rather than lost.
            logger.debug("skip override of %s.%s: %s", sec.upper(), k, err)
    return cfg


def print_configs(cfg, mode):
    """Log every section of ``cfg`` and its options at INFO level.

    Args:
        cfg: mapping of section name -> mapping of option name -> value.
        mode: label shown in the header line.
    """
    logger.info("---------------- {:>5} Arguments ----------------".format(
        mode))
    for sec, sec_items in cfg.items():
        logger.info("{}:".format(sec))
        for k, v in sec_items.items():
            logger.info(" {}:{}".format(k, v))
    logger.info("-------------------------------------------------")