config.py 7.7 KB
Newer Older
Y
ying 已提交
1
class TrainTaskConfig(object):
    """Settings for the training task."""

    # Device selection; only GPU is supported at the moment.
    use_gpu = True

    # Number of passes (epochs) over the training data.
    pass_num = 30
    # Number of sequences in one mini-batch.
    # Deprecated: set batch_size in args instead.
    batch_size = 32

    # Adam optimizer settings. learning_rate is a static factor that is
    # multiplied with the rate derived by the LearningRateScheduler to obtain
    # the final learning rate.
    learning_rate = 1
    beta1, beta2, eps = 0.9, 0.98, 1e-9
    # Learning rate scheduling parameter.
    warmup_steps = 4000

    # Label smoothing: weight used to mix the fixed uniform distribution into
    # the ground-truth distribution during training. Set to zero to disable
    # label smoothing.
    label_smooth_eps = 0.1

    # Directories where trained models and checkpoints are saved.
    model_dir = "trained_models"
    ckpt_dir = "trained_ckpts"
    # Checkpoint to load; when provided, training continues from it.
    ckpt_path = None
    # Initial value for the learning rate scheduler's step. Checkpoints do not
    # currently store the training step counter, so this should be provided
    # when training resumes from a checkpoint.
    start_step = 0


class InferTaskConfig(object):
    """Settings for the inference (sequence generation) task."""

    use_gpu = True
    # Number of examples processed in one run of sequence generation.
    batch_size = 10

    # Beam search settings.
    beam_size = 5
    max_out_len = 256
    # How many decoded sentences to output.
    n_best = 1

    # Whether the special tokens are kept in the output.
    output_bos = output_eos = False
    output_unk = True

    # Location the trained model is loaded from.
    model_path = "trained_models/pass_1.infer.model"

class ModelHyperParams(object):
    """Hyper-parameters of the encoder-decoder model."""

    # The five vocabulary-related settings below (two vocabulary sizes and
    # three special-token indices) are set automatically from the passed
    # vocabulary path and special tokens.
    src_vocab_size = 10000  # size of the source word dictionary
    trg_vocab_size = 10000  # size of the target word dictionary
    # Indices of the <bos>, <eos> and <unk> tokens.
    bos_idx, eos_idx, unk_idx = 0, 1, 2

    # Maximum sequence length; it decides the size of the position encoding
    # table. Counting starts from 1 and includes the start and end tokens.
    max_length = 256

    # Word-embedding dimension. It is also the last dimension of the inputs
    # and outputs of multi-head attention, the position-wise feed-forward
    # networks, the encoder and the decoder.
    d_model = 512
    # Hidden-layer size of the position-wise feed-forward networks.
    d_inner_hid = 2048
    # Dimensions that keys and values are projected to for dot-product
    # attention.
    d_key = d_value = 64
    # Number of heads in multi-head attention.
    n_head = 8
    # Number of sub-layers stacked in the encoder and in the decoder.
    n_layer = 6
    # Dropout rate shared by all dropout layers.
    dropout = 0.1
    # Random seed used in dropout for CE.
    dropout_seed = None
    # Whether the embedding and softmax weights are shared; this requires the
    # source and target vocabularies to be the same.
    weight_sharing = True


91 92 93 94 95 96 97 98 99 100
def merge_cfg_from_list(cfg_list, g_cfgs):
    """
    Override attributes of the global config objects from a flat key/value list.

    Args:
        cfg_list: A flat sequence ``[key1, value1, key2, value2, ...]``.
            Each value is parsed as a Python literal (int, float, bool, None,
            str/tuple/list/dict literal) when possible and otherwise kept as
            the raw value (e.g. a file path). Non-literal expressions such as
            ``"2*3"`` or names such as ``"len"`` are NOT evaluated; they are
            kept as plain strings.
        g_cfgs: The config objects to update, in priority order. Each key is
            written to the first object in ``g_cfgs`` that already has an
            attribute of that name. Keys no object has are ignored with a
            warning.

    Raises:
        ValueError: If ``cfg_list`` does not contain an even number of items.
    """
    import ast
    import warnings

    # An explicit check (unlike `assert`) still runs under `python -O`.
    if len(cfg_list) % 2 != 0:
        raise ValueError("cfg_list must hold key/value pairs, got %d items" %
                         len(cfg_list))
    for key, raw_value in zip(cfg_list[0::2], cfg_list[1::2]):
        target = next((cfg for cfg in g_cfgs if hasattr(cfg, key)), None)
        if target is None:
            # Unknown keys were previously dropped silently; keep ignoring
            # them, but make a mistyped key visible.
            warnings.warn("merge_cfg_from_list: unknown config key %r ignored"
                          % (key, ))
            continue
        try:
            # literal_eval only accepts Python literals. Unlike eval(), it
            # cannot run code or resolve names (e.g. a value "seq_len" or
            # "len" would otherwise become a module global or builtin).
            value = ast.literal_eval(raw_value)
        except (ValueError, SyntaxError, TypeError):
            value = raw_value  # not a literal, e.g. a file path
        setattr(target, key, value)


107 108 109
# Placeholder for batch_size at compile time. It must currently be -1 to stay
# consistent with the compile-time infer-shape output of some ops, such as the
# sequence_expand op used in the beam-search decoder.
batch_size = -1
# Placeholder for the sequence length at compile time.
seq_len = ModelHyperParams.max_length
# Data shapes and data types of all inputs. Each value is
# [shape, dtype] or [shape, dtype, lod_level]. The shapes act only as
# placeholders that pass compile-time infer-shape.
# Plain int literals (rather than Python 2 `1L` longs) keep this module
# importable under both Python 2 and Python 3.
# NOTE(review): batch_size * seq_len evaluates to a negative placeholder
# (-max_length); presumably any negative dim is treated as unknown — confirm.
input_descs = {
    # The actual data shape of src_word is:
    # [batch_size * max_src_len_in_batch, 1]
    "src_word": [(batch_size, seq_len, 1), "int64", 2],
    # The actual data shape of src_pos is:
    # [batch_size * max_src_len_in_batch, 1]
    "src_pos": [(batch_size, seq_len, 1), "int64"],
    # Used to remove attention weights on paddings in the encoder.
    # The actual data shape of src_slf_attn_bias is:
    # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
    "src_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
                           seq_len), "float32"],
    # The actual data shape of trg_word is:
    # [batch_size * max_trg_len_in_batch, 1]
    "trg_word": [(batch_size, seq_len, 1), "int64",
                 2],  # lod_level is only used in fast decoder.
    # The actual data shape of trg_pos is:
    # [batch_size * max_trg_len_in_batch, 1]
    "trg_pos": [(batch_size, seq_len, 1), "int64"],
    # Used to remove attention weights on paddings and subsequent words in
    # the decoder.
    # The actual data shape of trg_slf_attn_bias is:
    # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
    "trg_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
                           seq_len), "float32"],
    # Used to remove attention weights on paddings of the source input in the
    # encoder-decoder attention.
    # The actual data shape of trg_src_attn_bias is:
    # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
    "trg_src_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
                           seq_len), "float32"],
    # Used in the independent decoder program for inference.
    # The actual data shape of enc_output is:
    # [batch_size, max_src_len_in_batch, d_model]
    "enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"],
    # The actual data shape of label_word is:
    # [batch_size * max_trg_len_in_batch, 1]
    "lbl_word": [(batch_size * seq_len, 1), "int64"],
    # Used to mask out the loss of padding tokens.
    # The actual data shape of label_weight is:
    # [batch_size * max_trg_len_in_batch, 1]
    "lbl_weight": [(batch_size * seq_len, 1), "float32"],
    # These inputs are used to change the shape tensor in beam-search decoder.
    "trg_slf_attn_pre_softmax_shape_delta": [(2, ), "int32"],
    "trg_slf_attn_post_softmax_shape_delta": [(4, ), "int32"],
    "init_score": [(batch_size, 1), "float32"],
}

G
guosheng 已提交
165 166 167 168
# Word-embedding table names; these tables may be reused when weights are
# shared.
word_emb_param_names = ("src_word_emb_table", "trg_word_emb_table")
# Position-encoding table names; these tables are initialized externally.
pos_enc_param_names = ("src_pos_enc_table", "trg_pos_enc_table")

# Input names, separated by the part of the model that consumes them.
encoder_data_input_fields = ("src_word", "src_pos", "src_slf_attn_bias")
decoder_data_input_fields = (
    "trg_word",
    "trg_pos",
    "trg_slf_attn_bias",
    "trg_src_attn_bias",
    "enc_output",
)
label_data_input_fields = ("lbl_word", "lbl_weight")
# The fast decoder generates trg_pos (holding only the current time step) with
# ops and does not need trg_slf_attn_bias, so both are absent here.
fast_decoder_data_input_fields = ("trg_word", "init_score",
                                  "trg_src_attn_bias")
# NOTE: currently disabled; kept for reference.
# fast_decoder_util_input_fields = (
#     "trg_slf_attn_pre_softmax_shape_delta",
#     "trg_slf_attn_post_softmax_shape_delta", )