class TrainTaskConfig(object):
    # only support GPU currently
    use_gpu = True
    # the epoch number to train.
    pass_num = 30
    # the number of sequences contained in a mini-batch.
    # deprecated, set batch_size in args.
    batch_size = 32
    # the hyper parameters for Adam optimizer.
    # This static learning_rate will be multiplied by the LearningRateScheduler
    # derived learning rate to get the final learning rate.
    learning_rate = 1
    beta1 = 0.9
    beta2 = 0.98
    eps = 1e-9
    # the parameters for learning rate scheduling.
    warmup_steps = 4000
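    # For reference, these values are typically plugged into the schedule from
    # the original Transformer paper (a sketch only; the actual scheduler is
    # implemented in the training code, not in this file):
    #   lr = learning_rate * d_model**-0.5 * min(step**-0.5,
    #                                            step * warmup_steps**-1.5)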
    # the weight used to mix up the ground-truth distribution and the fixed
    # uniform distribution in label smoothing when training.
    # Set this to zero if label smoothing is not wanted.
    label_smooth_eps = 0.1
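    # As a sketch of the usual label-smoothing formulation (the actual op is
    # applied in the model code): with a one-hot target y and target
    # vocabulary size V, the smoothed target is roughly
    #   (1 - label_smooth_eps) * y + label_smooth_eps / V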
    # the directory for saving trained models.
    model_dir = "trained_models"
    # the directory for saving checkpoints.
    ckpt_dir = "trained_ckpts"
    # the directory for loading checkpoint.
    # If provided, continue training from the checkpoint.
    ckpt_path = None
    # the parameter to initialize the learning rate scheduler.
    # It should be provided if using checkpoints, since the checkpoint doesn't
    # include the training step counter currently.
    start_step = 0


class InferTaskConfig(object):
    use_gpu = True
    # the number of examples in one run for sequence generation.
    batch_size = 10
    # the parameters for beam search.
    beam_size = 5
    max_out_len = 256
    # the number of decoded sentences to output.
    n_best = 1
    # the flags indicating whether to output the special tokens.
    output_bos = False
    output_eos = False
    output_unk = True
    # the directory for loading the trained model.
    model_path = "trained_models/pass_1.infer.model"


class ModelHyperParams(object):
    # The following five vocabulary-related configurations will be set
    # automatically according to the passed vocabulary path and special tokens.
    # size of source word dictionary.
    src_vocab_size = 10000
    # size of target word dictionary.
    trg_vocab_size = 10000
    # index for <bos> token
    bos_idx = 0
    # index for <eos> token
    eos_idx = 1
    # index for <unk> token
    unk_idx = 2
    # max length of sequences, deciding the size of the position encoding
    # table. Positions start from 1 and count the start and end tokens in.
    max_length = 256
    # the dimension for word embeddings, which is also the last dimension of
    # the input and output of multi-head attention, position-wise feed-forward
    # networks, encoder and decoder.
    d_model = 512
    # size of the hidden layer in position-wise feed-forward networks.
    d_inner_hid = 2048
    # the dimension that keys are projected to for dot-product attention.
    d_key = 64
    # the dimension that values are projected to for dot-product attention.
    d_value = 64
    # number of heads used in multi-head attention.
    n_head = 8
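    # Note that with these defaults d_key = d_value = d_model / n_head
    # = 512 / 8 = 64, i.e. the base configuration of the Transformer paper.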
    # number of sub-layers to be stacked in the encoder and decoder.
    n_layer = 6
    # dropout rate used by all dropout layers.
    dropout = 0.1
    # the flag indicating whether to share embedding and softmax weights.
    # vocabularies in source and target should be the same for weight sharing.
    weight_sharing = True
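    # Roughly speaking, when enabled the same embedding matrix E is used for
    # the source and target embeddings and, transposed, as the pre-softmax
    # projection (logits = decoder_output * E^T); the exact wiring lives in
    # the model code.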


def merge_cfg_from_list(cfg_list, g_cfgs):
    """
    Set the above global configurations using the cfg_list. 
    """
    assert len(cfg_list) % 2 == 0
    for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
        for g_cfg in g_cfgs:
            if hasattr(g_cfg, key):
                try:
                    value = eval(value)
                except Exception:  # for file path
                    pass
                setattr(g_cfg, key, value)
                break
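
# A hypothetical usage sketch (these particular key/value strings are made up;
# in practice they come from command-line arguments):
#   merge_cfg_from_list(["batch_size", "64", "use_gpu", "False"],
#                       [TrainTaskConfig, InferTaskConfig, ModelHyperParams])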


# The placeholder for batch_size at compile time. It must currently be -1 to be
# consistent with some ops' infer-shape output at compile time, such as the
# sequence_expand op used in the beam-search decoder.
batch_size = -1
# The placeholder for sequence length at compile time.
seq_len = ModelHyperParams.max_length
# Listed here are the data shapes and data types of all inputs.
# The shapes here act as placeholders and are set to pass the infer-shape
# check at compile time.
input_descs = {
    # The actual data shape of src_word is:
    # [batch_size * max_src_len_in_batch, 1]
    "src_word": [(batch_size * seq_len, 1L), "int64", 2],
    # The actual data shape of src_pos is:
    # [batch_size * max_src_len_in_batch, 1]
    "src_pos": [(batch_size * seq_len, 1L), "int64"],
    # This input is used to remove attention weights on paddings in the
    # encoder.
    # The actual data shape of src_slf_attn_bias is:
    # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
    "src_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
                           seq_len), "float32"],
    # This shape input is used to reshape the output of embedding layer.
    "src_data_shape": [(3L, ), "int32"],
    # This shape input is used to reshape before softmax in self attention.
    "src_slf_attn_pre_softmax_shape": [(2L, ), "int32"],
    # This shape input is used to reshape after softmax in self attention.
    "src_slf_attn_post_softmax_shape": [(4L, ), "int32"],
    # The actual data shape of trg_word is:
    # [batch_size * max_trg_len_in_batch, 1]
    "trg_word": [(batch_size * seq_len, 1L), "int64",
                 2],  # lod_level is only used in fast decoder.
    # The actual data shape of trg_pos is:
    # [batch_size * max_trg_len_in_batch, 1]
    "trg_pos": [(batch_size * seq_len, 1L), "int64"],
    # This input is used to remove attention weights on paddings and
    # subsequent words in the decoder.
    # The actual data shape of trg_slf_attn_bias is:
    # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
    "trg_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
                           seq_len), "float32"],
    # This input is used to remove attention weights on paddings of the source
    # input in the encoder-decoder attention.
    # The actual data shape of trg_src_attn_bias is:
    # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
    "trg_src_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
                           seq_len), "float32"],
    # This shape input is used to reshape the output of embedding layer.
    "trg_data_shape": [(3L, ), "int32"],
    # This shape input is used to reshape before softmax in self attention.
    "trg_slf_attn_pre_softmax_shape": [(2L, ), "int32"],
    # This shape input is used to reshape after softmax in self attention.
    "trg_slf_attn_post_softmax_shape": [(4L, ), "int32"],
    # This shape input is used to reshape before softmax in encoder-decoder
    # attention.
    "trg_src_attn_pre_softmax_shape": [(2L, ), "int32"],
    # This shape input is used to reshape after softmax in encoder-decoder
    # attention.
    "trg_src_attn_post_softmax_shape": [(4L, ), "int32"],
    # This input is used in independent decoder program for inference.
    # The actual data shape of enc_output is:
    # [batch_size, max_src_len_in_batch, d_model]
    "enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"],
    # The actual data shape of label_word is:
    # [batch_size * max_trg_len_in_batch, 1]
    "lbl_word": [(batch_size * seq_len, 1L), "int64"],
    # This input is used to mask out the loss of padding tokens.
    # The actual data shape of label_weight is:
    # [batch_size * max_trg_len_in_batch, 1]
    "lbl_weight": [(batch_size * seq_len, 1L), "float32"],
    # These inputs are used to change the shape tensor in beam-search decoder.
    "trg_slf_attn_pre_softmax_shape_delta": [(2L, ), "int32"],
    "trg_slf_attn_post_softmax_shape_delta": [(4L, ), "int32"],
    "init_score": [(batch_size, 1L), "float32"],
}
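
# A sketch (assuming the model code builds paddle.fluid data layers from these
# descriptors; the variable names below are illustrative only) of how one
# entry could be consumed:
#   import paddle.fluid as fluid
#   desc = input_descs["src_word"]
#   lod_level = desc[2] if len(desc) > 2 else 0
#   src_word = fluid.layers.data(
#       name="src_word", shape=list(desc[0]), dtype=desc[1],
#       lod_level=lod_level, append_batch_size=False)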

# Names of the word embedding tables which might be reused for weight sharing.
word_emb_param_names = (
    "src_word_emb_table",
    "trg_word_emb_table", )
# Names of the position encoding tables which will be initialized externally.
pos_enc_param_names = (
    "src_pos_enc_table",
    "trg_pos_enc_table", )
# separated inputs for different usages.
encoder_data_input_fields = (
    "src_word",
    "src_pos",
    "src_slf_attn_bias", )
encoder_util_input_fields = (
    "src_data_shape",
    "src_slf_attn_pre_softmax_shape",
    "src_slf_attn_post_softmax_shape", )
decoder_data_input_fields = (
    "trg_word",
    "trg_pos",
    "trg_slf_attn_bias",
    "trg_src_attn_bias",
    "enc_output", )
decoder_util_input_fields = (
    "trg_data_shape",
    "trg_slf_attn_pre_softmax_shape",
    "trg_slf_attn_post_softmax_shape",
    "trg_src_attn_pre_softmax_shape",
    "trg_src_attn_post_softmax_shape", )
label_data_input_fields = (
    "lbl_word",
    "lbl_weight", )
# In the fast decoder, trg_pos (containing only the current time step) is
# generated by ops and trg_slf_attn_bias is not needed.
fast_decoder_data_input_fields = (
    "trg_word",
    "init_score",
    "trg_src_attn_bias", )
fast_decoder_util_input_fields = decoder_util_input_fields + (
    "trg_slf_attn_pre_softmax_shape_delta",
    "trg_slf_attn_post_softmax_shape_delta", )