# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from argparse import ArgumentParser, RawDescriptionHelpFormatter

import yaml

# Module-level logger used by `override` to report newly created fields.
# (Previously `override` referenced an undefined `logger`, raising NameError.)
logger = logging.getLogger(__name__)


def override(dl, ks, v):
    """
    Recursively replace a value inside a nested dict or list.

    Follows the path ``ks`` through ``dl`` and assigns ``v`` (converted with
    ``str2num``) to the final element, modifying ``dl`` in place. A new key
    may only be created at the last level of a dict; doing so logs a warning.

    Args:
        dl(dict or list): dict or list to be modified in place
        ks(list): list of keys forming the path; list indices are given as
            strings and converted with ``str2num``
        v(str): raw value to assign; converted with ``str2num`` first

    Raises:
        AssertionError: if ``dl`` is not a list/dict, ``ks`` is empty, a
            list index is out of range, or an intermediate dict key is
            missing.
    """

    def str2num(v):
        # Turn strings such as "2", "0.5", "True" or "[1, 2]" into Python
        # objects; anything eval() cannot handle is returned unchanged.
        # NOTE(review): eval() executes arbitrary Python taken from the
        # command line (-o key=value). Only acceptable for trusted input;
        # consider ast.literal_eval if options can come from elsewhere.
        try:
            return eval(v)
        except Exception:
            return v

    assert isinstance(dl, (list, dict)), (
        "{} should be a list or a dict".format(dl))
    assert len(ks) > 0, ('length of keys should be larger than 0')
    if isinstance(dl, list):
        # List indices arrive as strings (e.g. "1") and must be converted.
        k = str2num(ks[0])
        if len(ks) == 1:
            assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
            dl[k] = str2num(v)
        else:
            override(dl[k], ks[1:], v)
    else:
        if len(ks) == 1:
            # Creating a new leaf key is allowed, but worth flagging since it
            # is often a typo in the option name.
            if ks[0] not in dl:
                logger.warning('A new field ({}) detected!'.format(ks[0]))
            dl[ks[0]] = str2num(v)
        else:
            # Intermediate keys must already exist: new nested dicts are
            # never created implicitly.
            assert ks[0] in dl, (
                '({}) doesn\'t exist in {}, a new dict field is invalid'.
                format(ks[0], dl))
            override(dl[ks[0]], ks[1:], v)


def override_config(config, options=None):
    """
    Recursively override the config.

    Args:
        config(dict): dict to be modified in place
        options(list): list of strings of the form
            ``key0.key1.idx.key2=value``, such as:
            [
                'topk=2',
                'VALID.transforms.1.ResizeImage.resize_short=300'
            ]
            Only the first '=' separates key from value, so the value may
            itself contain '='.

    Returns:
        config(dict): the same (now overridden) config object
    """
    if options is not None:
        for opt in options:
            assert isinstance(opt, str), (
                "option({}) should be a str".format(opt))
            assert "=" in opt, (
                "option({}) should contain a = "
                "to distinguish between key and value".format(opt))
            # Split on the first '=' only; the remainder is the raw value.
            key, value = opt.split('=', 1)
            keys = key.split('.')
            override(config, keys, value)
    return config


class ArgsParser(ArgumentParser):
    """
    Command-line parser for ``-c/--config``, ``-t/--tag`` and repeatable
    ``-o/--override key=value`` options.
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-t", "--tag", default="0", help="tag for marking worker")
        self.add_argument(
            '-o',
            '--override',
            action='append',
            default=[],
            help='config options to be overridden')

    def parse_args(self, argv=None):
        """Parse ``argv`` and require that ``--config`` was given."""
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        return args


def load_config(file_path):
    """
    Load config from yml/yaml file.

    Args:
        file_path (str): Path of the config file to be loaded.

    Returns:
        config: the parsed YAML document

    Raises:
        AssertionError: if the file extension is not .yml or .yaml.
    """
    ext = os.path.splitext(file_path)[1]
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    with open(file_path, 'rb') as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects
        # from python/* tags; only load trusted config files, or consider
        # yaml.safe_load if no custom tags are needed.
        config = yaml.load(f, Loader=yaml.Loader)

    return config


def gen_config(output_path="config.yml"):
    """
    Write a sample SRNet data-augmentation training config as YAML.

    Args:
        output_path (str): destination file; defaults to "config.yml" in the
            current working directory (the previous hard-coded behavior).
    """
    base_config = {
        "Global": {
            "algorithm": "SRNet",
            "use_gpu": True,
            "start_epoch": 1,
            "stage1_epoch_num": 100,
            "stage2_epoch_num": 100,
            "log_smooth_window": 20,
            "print_batch_step": 2,
            "save_model_dir": "./output/SRNet",
            "use_visualdl": False,
            "save_epoch_step": 10,
            "vgg_pretrain": "./pretrained/VGG19_pretrained",
            "vgg_load_static_pretrain": True
        },
        "Architecture": {
            "model_type": "data_aug",
            "algorithm": "SRNet",
            "net_g": {
                "name": "srnet_net_g",
                "encode_dim": 64,
                "norm": "batch",
                "use_dropout": False,
                "init_type": "xavier",
                "init_gain": 0.02,
                "use_dilation": 1
            },
            # input_nc, ndf, netD,
            # n_layers_D=3, norm='instance', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_id='cuda:0'
            "bg_discriminator": {
                "name": "srnet_bg_discriminator",
                "input_nc": 6,
                "ndf": 64,
                "netD": "basic",
                "norm": "none",
                "init_type": "xavier",
            },
            "fusion_discriminator": {
                "name": "srnet_fusion_discriminator",
                "input_nc": 6,
                "ndf": 64,
                "netD": "basic",
                "norm": "none",
                "init_type": "xavier",
            }
        },
        "Loss": {
            "lamb": 10,
            "perceptual_lamb": 1,
            "muvar_lamb": 50,
            "style_lamb": 500
        },
        "Optimizer": {
            "name": "Adam",
            "learning_rate": {
                "name": "lambda",
                "lr": 0.0002,
                "lr_decay_iters": 50
            },
            "beta1": 0.5,
            "beta2": 0.999,
        },
        "Train": {
            "batch_size_per_card": 8,
            "num_workers_per_card": 4,
            "dataset": {
                "delimiter": "\t",
                "data_dir": "/",
                "label_file": "tmp/label.txt",
                "transforms": [{
                    "DecodeImage": {
                        "to_rgb": True,
                        "to_np": False,
                        "channel_first": False
                    }
                }, {
                    "NormalizeImage": {
                        "scale": 1. / 255.,
                        "mean": [0.485, 0.456, 0.406],
                        "std": [0.229, 0.224, 0.225],
                        "order": None
                    }
                }, {
                    "ToCHWImage": None
                }]
            }
        }
    }
    with open(output_path, "w", encoding="utf-8") as f:
        yaml.dump(base_config, f)


if __name__ == '__main__':
    gen_config()