class TrainTaskConfig(object):
    # support both CPU and GPU now.
    use_gpu = True
    # the number of epochs to train.
    pass_num = 30
    # the number of sequences contained in a mini-batch.
    # deprecated, set batch_size in args.
    batch_size = 32
    # the hyperparameters for the Adam optimizer.
    # This static learning_rate will be multiplied by the learning rate
    # derived from the LearningRateScheduler to get the final learning rate.
    learning_rate = 2.0
    beta1 = 0.9
    beta2 = 0.997
    eps = 1e-9
    # the parameters for learning rate scheduling.
    warmup_steps = 8000
    # the weight used to mix the ground-truth distribution and the fixed
    # uniform distribution in label smoothing when training.
    # Set this to zero if label smoothing is not wanted.
    label_smooth_eps = 0.1
    # the directory for saving trained models.
    model_dir = "trained_models"
    # the directory for saving checkpoints.
    ckpt_dir = "trained_ckpts"
    # the path of the checkpoint to load.
    # If provided, training continues from this checkpoint.
    ckpt_path = None
    # the parameter to initialize the learning rate scheduler.
    # It should be provided when resuming from a checkpoint, since the
    # checkpoint doesn't include the training step counter currently.
    start_step = 0
    # the frequency to save trained models.
    save_freq = 10000
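

# Illustrative sketch only (not part of the original config): an assumption of
# how the static learning_rate and warmup_steps above are typically combined
# with the Noam schedule used for Transformer training. The actual scheduler
# lives in the training code, not here; d_model=512 mirrors the value set in
# ModelHyperParams below.
def _example_noam_lr(step,
                     d_model=512,
                     warmup_steps=TrainTaskConfig.warmup_steps,
                     static_lr=TrainTaskConfig.learning_rate):
    # scheduler-derived rate: warm up linearly, then decay with 1/sqrt(step).
    scheduled = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
    # the static learning_rate is multiplied by the scheduled rate.
    return static_lr * scheduled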


class InferTaskConfig(object):
    use_gpu = True
    # the number of examples in one run for sequence generation.
    batch_size = 10
    # the parameters for beam search.
    beam_size = 5
    # the maximum length of output sequences.
    max_out_len = 256
    # the number of decoded sentences to output.
    n_best = 1
    # the flags indicating whether to output the special tokens.
    output_bos = False
    output_eos = False
    output_unk = True
    # the path of the trained model to be loaded for inference.
    model_path = "trained_models/pass_1.infer.model"


class ModelHyperParams(object):
    # The following five vocabulary-related configurations will be set
    # automatically according to the passed vocabulary path and special tokens.
    # size of source word dictionary.
    src_vocab_size = 10000
    # size of target word dictionary.
    trg_vocab_size = 10000
    # index for <bos> token
    bos_idx = 0
    # index for <eos> token
    eos_idx = 1
    # index for <unk> token
    unk_idx = 2
    # max length of sequences; it decides the size of the position encoding table.
    max_length = 256
    # the dimension for word embeddings, which is also the last dimension of
    # the input and output of multi-head attention, position-wise feed-forward
    # networks, encoder and decoder.
    d_model = 512
    # size of the hidden layer in position-wise feed-forward networks.
    d_inner_hid = 2048
    # the dimension that keys are projected to for dot-product attention.
    d_key = 64
    # the dimension that values are projected to for dot-product attention.
    d_value = 64
    # number of heads used in multi-head attention.
    n_head = 8
    # number of layers to be stacked in the encoder and decoder.
    n_layer = 6
    # dropout rates of different modules.
    prepostprocess_dropout = 0.1
    attention_dropout = 0.1
    relu_dropout = 0.1
    # the operations applied before each sub-layer.
    preprocess_cmd = "n"  # layer normalization
    # the operations applied after each sub-layer.
    postprocess_cmd = "da"  # dropout + residual connection
    # random seed used in dropout for CE.
    dropout_seed = None
    # the flag indicating whether to share embedding and softmax weights.
    # Vocabularies of source and target should be the same for weight sharing.
    weight_sharing = True
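

# Illustrative sketch only, assuming a plain-text vocabulary file with one
# token per line: roughly how the five vocabulary-related fields above could
# be filled in from the vocabulary path and special tokens. The repository's
# actual loader may differ; the token strings used here are assumptions.
def _example_set_vocab_cfg(vocab_path, bos="<s>", eos="<e>", unk="<unk>"):
    with open(vocab_path) as f:
        words = [line.strip() for line in f]
    # with weight_sharing enabled, source and target share one vocabulary.
    ModelHyperParams.src_vocab_size = len(words)
    ModelHyperParams.trg_vocab_size = len(words)
    ModelHyperParams.bos_idx = words.index(bos)
    ModelHyperParams.eos_idx = words.index(eos)
    ModelHyperParams.unk_idx = words.index(unk)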


def merge_cfg_from_list(cfg_list, g_cfgs):
    """
    Set the above global configurations using cfg_list, a flat list of
    alternating key/value strings.
    """
    assert len(cfg_list) % 2 == 0
    for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
        for g_cfg in g_cfgs:
            if hasattr(g_cfg, key):
                try:
                    value = eval(value)
                except Exception:  # keep the raw string, e.g. a file path
                    pass
                setattr(g_cfg, key, value)
                break
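

# Example usage (a minimal sketch; the command-line plumbing is an assumption,
# not the repository's actual entry point):
#
#     extra_cfg = ["batch_size", "64", "use_gpu", "False",
#                  "model_dir", "my_models"]
#     merge_cfg_from_list(extra_cfg,
#                         [TrainTaskConfig, ModelHyperParams])
#
# Numeric/boolean strings are eval'd to Python values; strings that fail eval
# (e.g. file paths such as "my_models") are kept as-is.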