# config.py: training/inference task settings, model hyper-parameters and the
# compile-time input descriptions.
class TrainTaskConfig(object):
    """Settings of the training task: device, epochs, optimizer, LR schedule,
    label smoothing and where models/checkpoints are saved or loaded."""
    # support both CPU and GPU now.
    use_gpu = True
    # the number of epochs to train.
    pass_num = 30
    # the number of sequences contained in a mini-batch.
    # deprecated, set batch_size in args.
    batch_size = 32
    # the hyper parameters for Adam optimizer.
    # This static learning_rate will be multiplied by the LearningRateScheduler
    # derived learning rate to get the final learning rate.
    learning_rate = 2.0
    beta1 = 0.9
    beta2 = 0.997
    eps = 1e-9
    # the parameters for learning rate scheduling.
    warmup_steps = 8000
    # the weight used to mix up the ground-truth distribution and the fixed
    # uniform distribution in label smoothing when training.
    # Set this as zero if label smoothing is not wanted.
    label_smooth_eps = 0.1
    # the directory for saving trained models.
    model_dir = "trained_models"
    # the directory for saving checkpoints.
    ckpt_dir = "trained_ckpts"
    # the directory for loading checkpoint.
    # If provided, continue training from the checkpoint.
    ckpt_path = None
    # the parameter to initialize the learning rate scheduler.
    # It should be provided if checkpoints are used, since the checkpoint
    # doesn't include the training step counter currently.
    start_step = 0
    # the frequency to save trained models.
    save_freq = 10000


class InferTaskConfig(object):
    """Settings of the inference (sequence generation) task."""
    use_gpu = True
    # the number of examples in one run for sequence generation.
    batch_size = 10
    # the parameters for beam search.
    beam_size = 5
    max_out_len = 256
    # the number of decoded sentences to output.
    n_best = 1
    # the flags indicating whether to output the special tokens.
    output_bos = False
    output_eos = False
    output_unk = True
    # the path for loading the trained model.
    model_path = "trained_models/pass_1.infer.model"

class ModelHyperParams(object):
    """Model hyper-parameters: vocabularies, special-token indices, network
    sizes, dropout and pre/post-processing of each sub-layer."""
    # The following five vocabulary-related settings will be set
    # automatically according to the passed vocabulary path and special tokens.
    # size of source word dictionary.
    src_vocab_size = 10000
    # size of target word dictionary.
    trg_vocab_size = 10000
    # index for <bos> token
    bos_idx = 0
    # index for <eos> token
    eos_idx = 1
    # index for <unk> token
    unk_idx = 2
    # max length of sequences deciding the size of position encoding table.
    max_length = 256
    # the dimension for word embeddings, which is also the last dimension of
    # the input and output of multi-head attention, position-wise feed-forward
    # networks, encoder and decoder.
    d_model = 512
    # size of the hidden layer in position-wise feed-forward networks.
    d_inner_hid = 2048
    # the dimension that keys are projected to for dot-product attention.
    d_key = 64
    # the dimension that values are projected to for dot-product attention.
    d_value = 64
    # number of heads used in multi-head attention.
    n_head = 8
    # number of sub-layers to be stacked in the encoder and decoder.
    n_layer = 6
    # dropout rates of different modules.
    prepostprocess_dropout = 0.1
    attention_dropout = 0.1
    relu_dropout = 0.1
    # to process before each sub-layer
    preprocess_cmd = "n"  # layer normalization
    # to process after each sub-layer
    postprocess_cmd = "da"  # dropout + residual connection
    # random seed used in dropout for CE.
    dropout_seed = None
    # the flag indicating whether to share embedding and softmax weights.
    # vocabularies in source and target should be the same for weight sharing.
    weight_sharing = True


def merge_cfg_from_list(cfg_list, g_cfgs):
    """
    Set attributes of the given config classes from a flat key/value list.

    Args:
        cfg_list: A flat sequence alternating keys and values, e.g.
            ["learning_rate", "1.0", "model_dir", "my_models"]. Each value is
            passed through ``eval`` so Python literals (numbers, booleans,
            None, lists, ...) get their native types. A value that fails to
            evaluate (e.g. a file path) is kept unchanged. ``None`` is treated
            as an empty list.
        g_cfgs: Config classes/objects to update. Each key is set on the first
            of them that already has an attribute with that name. A key that
            matches none of them is skipped with a warning.

    Raises:
        ValueError: If cfg_list has an odd number of items.
    """
    import warnings

    if cfg_list is None:
        cfg_list = []
    # Explicit check instead of `assert`, which is stripped under `python -O`
    # and would let zip() silently drop a dangling key.
    if len(cfg_list) % 2 != 0:
        raise ValueError(
            "cfg_list must contain key/value pairs, got an odd number of "
            "items: {}".format(list(cfg_list)))
    for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
        for g_cfg in g_cfgs:
            if hasattr(g_cfg, key):
                # NOTE(review): eval executes arbitrary code. It is only safe
                # if cfg_list comes from a trusted source (presumably the
                # user's own command line -- confirm). It also resolves names
                # in this module's globals and builtins, so a value such as
                # "max" becomes that object instead of staying a string.
                # Consider ast.literal_eval.
                try:
                    value = eval(value)
                except Exception:  # for file path
                    # Deliberate best-effort: anything that is not a valid
                    # expression is kept as the raw string.
                    pass
                setattr(g_cfg, key, value)
                break
        else:
            # No config defines this key; warn so a misspelled option is not
            # silently lost.
            warnings.warn(
                "Config key '{}' is not defined in any of the given configs; "
                "ignored.".format(key),
                stacklevel=2)


# The placeholder for batch_size in compile time. Must be -1 currently to be
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in beamsearch decoder.
batch_size = -1
# The placeholder for sequence length in compile time.
seq_len = ModelHyperParams.max_length
# Data shapes and data types of all inputs, keyed by input name. Each entry is
# [shape, dtype] or [shape, dtype, lod_level]. The shapes act as placeholders
# (built from batch_size and seq_len) and are only set to pass the infer-shape
# in compile time.
input_descs = {
    # The actual data shape of src_word is:
    # [batch_size, max_src_len_in_batch, 1]
    "src_word": [(batch_size, seq_len, 1), "int64", 2],
    # The actual data shape of src_pos is:
    # [batch_size, max_src_len_in_batch, 1]
    "src_pos": [(batch_size, seq_len, 1), "int64"],
    # This input is used to remove attention weights on paddings in the
    # encoder.
    # The actual data shape of src_slf_attn_bias is:
    # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
    "src_slf_attn_bias": [
        (batch_size, ModelHyperParams.n_head, seq_len, seq_len), "float32"
    ],
    # The actual data shape of trg_word is:
    # [batch_size, max_trg_len_in_batch, 1]
    # lod_level is only used in fast decoder.
    "trg_word": [(batch_size, seq_len, 1), "int64", 2],
    # The actual data shape of trg_pos is:
    # [batch_size, max_trg_len_in_batch, 1]
    "trg_pos": [(batch_size, seq_len, 1), "int64"],
    # This input is used to remove attention weights on paddings and
    # subsequent words in the decoder.
    # The actual data shape of trg_slf_attn_bias is:
    # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
    "trg_slf_attn_bias": [
        (batch_size, ModelHyperParams.n_head, seq_len, seq_len), "float32"
    ],
    # This input is used to remove attention weights on paddings of the source
    # input in the encoder-decoder attention.
    # The actual data shape of trg_src_attn_bias is:
    # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
    "trg_src_attn_bias": [
        (batch_size, ModelHyperParams.n_head, seq_len, seq_len), "float32"
    ],
    # This input is used in independent decoder program for inference.
    # The actual data shape of enc_output is:
    # [batch_size, max_src_len_in_batch, d_model]
    "enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"],
    # The actual data shape of lbl_word is:
    # [batch_size * max_trg_len_in_batch, 1]
    "lbl_word": [(batch_size * seq_len, 1), "int64"],
    # This input is used to mask out the loss of padding tokens.
    # The actual data shape of lbl_weight is:
    # [batch_size * max_trg_len_in_batch, 1]
    "lbl_weight": [(batch_size * seq_len, 1), "float32"],
    # This input is used in beam-search decoder.
    "init_score": [(batch_size, 1), "float32", 2],
    # This input is used in beam-search decoder for the first gather
    # (cell states update)
    "init_idx": [(batch_size,), "int32"],
}

# Names of the word embedding tables, which might be reused for weight sharing.
word_emb_param_names = (
    "src_word_emb_table",
    "trg_word_emb_table",
)
# Names of the position encoding tables, which will be initialized externally.
pos_enc_param_names = (
    "src_pos_enc_table",
    "trg_pos_enc_table",
)
# Input names separated by usage. Every name here is a key of input_descs.
encoder_data_input_fields = (
    "src_word",
    "src_pos",
    "src_slf_attn_bias",
)
decoder_data_input_fields = (
    "trg_word",
    "trg_pos",
    "trg_slf_attn_bias",
    "trg_src_attn_bias",
    "enc_output",
)
label_data_input_fields = (
    "lbl_word",
    "lbl_weight",
)
# In fast decoder, trg_pos (only containing the current time step) is generated
# by ops and trg_slf_attn_bias is not needed.
fast_decoder_data_input_fields = (
    "trg_word",
    "init_score",
    "init_idx",
    "trg_src_attn_bias",
)