# -*- coding: utf-8 -*-
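"""Configuration utilities: a JSON config reader (JsonConfig) and grouped
command-line argument helpers (ArgumentGroup, ArgConfig)."""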
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import six
import logging
import json

logging_only_message = "%(message)s"
logging_details = "%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s"

class JsonConfig(object):
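    """Dict-style wrapper around a JSON config file (e.g. a BERT model config)."""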
    def __init__(self, config_path):
        self._config_dict = self._parse(config_path)

    def _parse(self, config_path):
        try:
            with open(config_path) as json_file:
                config_dict = json.load(json_file)
        except Exception:
            raise IOError("Error in parsing BERT model config file '%s'" %
                          config_path)
        else:
            return config_dict

    def __getitem__(self, key):
        return self._config_dict[key]

    def print_config(self):
        for arg, value in sorted(six.iteritems(self._config_dict)):
            print('%s: %s' % (arg, value))
        print('------------------------------------------------')


class ArgumentGroup(object):
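    """Registers related command-line arguments under a named argparse group."""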
    def __init__(self, parser, title, des):
        self._group = parser.add_argument_group(title=title, description=des)

    def add_arg(self, name, type, default, help, **kwargs):
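        # argparse's built-in bool() treats any non-empty string as True, so
        # boolean options are routed through str2bool instead.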
        type = str2bool if type == bool else type
        self._group.add_argument(
            "--" + name,
            default=default,
            type=type,
            help=help + ' Default: %(default)s.',
            **kwargs)

class ArgConfig(object):
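    """Builds the default argparse parser: training, logging, and run-type
    options, plus a "customize" group for caller-defined arguments.

    Typical usage (illustrative sketch; "max_seq_len" is not a predefined option):

        conf = ArgConfig()
        conf.add_arg("max_seq_len", int, 512, "Maximum sequence length.")
        args = conf.build_conf()
        print_arguments(args)
    """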
    
    def __init__(self):
        parser = argparse.ArgumentParser()

        train_g = ArgumentGroup(parser, "training", "training options.")
        train_g.add_arg("epoch",             int,    3,      "Number of epoches for fine-tuning.")
        train_g.add_arg("learning_rate",     float,  5e-5,   "Learning rate used to train with warmup.")
        train_g.add_arg("lr_scheduler",      str,    "linear_warmup_decay",
                        "scheduler of learning rate.", choices=['linear_warmup_decay', 'noam_decay'])
        train_g.add_arg("weight_decay",      float,  0.01,   "Weight decay rate for L2 regularizer.")
        train_g.add_arg("warmup_proportion", float,  0.1,
                        "Proportion of training steps to perform linear learning rate warmup for.")
        train_g.add_arg("save_steps",        int,    1000,   "The steps interval to save checkpoints.")
        train_g.add_arg("use_fp16",          bool,   False,  "Whether to use fp16 mixed precision training.")
        train_g.add_arg("loss_scaling",      float,  1.0,
                        "Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled.")
        train_g.add_arg("pred_dir",   str,    None,   "Path to save the prediction results")

        log_g = ArgumentGroup(parser, "logging", "logging related.")
        log_g.add_arg("skip_steps",          int,    10,    "The steps interval to print loss.")
        log_g.add_arg("verbose",             bool,   False, "Whether to output verbose log.")

        run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
        run_type_g.add_arg("use_cuda",                     bool,   True,  "If set, use GPU for training.")
        run_type_g.add_arg("use_fast_executor",            bool,   False, "If set, use fast parallel executor (in experiment).")
        run_type_g.add_arg("num_iteration_per_drop_scope", int,    1,     "Ihe iteration intervals to clean up temporary variables.")
        run_type_g.add_arg("do_train",                     bool,   True,  "Whether to perform training.")
        run_type_g.add_arg("do_predict",                   bool,   True,  "Whether to perform prediction.")

        # Group reserved for arguments the caller adds via ArgConfig.add_arg().
        custom_g = ArgumentGroup(parser, "customize", "customized options.")

        self.custom_g = custom_g
        self.parser = parser

    def add_arg(self, name, dtype, default, descrip):
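        """Register a caller-defined argument in the "customize" group."""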
        self.custom_g.add_arg(name, dtype, default, descrip)

    def build_conf(self):
        return self.parser.parse_args()


def str2bool(v):
    # argparse cannot parse strings such as "True"/"False" into Python booleans
    # directly, so boolean options are converted with this helper.
    return v.lower() in ("true", "t", "1")


def print_arguments(args, log=None):
    printer = log.info if log else print
    printer('-----------  Configuration Arguments -----------')
    for arg, value in sorted(six.iteritems(vars(args))):
        printer('%s: %s' % (arg, value))
    printer('------------------------------------------------')


if __name__ == "__main__":

    arg_conf = ArgConfig()
    args = arg_conf.build_conf()

    # using print()
    print_arguments(args)

    logging.basicConfig(
        level=logging.INFO,
        format=logging_details,
        datefmt='%Y-%m-%d %H:%M:%S')

    # using logging
    print_arguments(args, logging)

    # Sanity check: load a local pretrained BERT config shipped with the repo.
    json_conf = JsonConfig("../../data/pretrained_models/uncased_L-12_H-768_A-12/bert_config.json")
    json_conf.print_config()