import os
import time
import numpy as np

import paddle
import paddle.fluid as fluid

from model import transformer, position_encoding_init
from optim import LearningRateScheduler
from config import TrainTaskConfig, ModelHyperParams, pos_enc_param_names, \
        encoder_input_data_names, decoder_input_data_names, label_data_names


def pad_batch_data(insts,
                   pad_idx,
                   n_head,
                   is_target=False,
                   is_label=False,
                   return_attn_bias=True,
                   return_max_len=True):
    """
    Pad the instances to the max sequence length in batch, and generate the
    corresponding position data and attention bias.
    """
    return_list = []
    max_len = max(len(inst) for inst in insts)
    # Any token included in dict can be used to pad, since the paddings' loss
    # will be masked out by weights and make no effect on parameter gradients.
    inst_data = np.array(
        [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
    return_list += [inst_data.astype("int64").reshape([-1, 1])]
    if is_label:  # label weight
        inst_weight = np.array(
            [[1.] * len(inst) + [0.] * (max_len - len(inst)) for inst in insts])
        return_list += [inst_weight.astype("float32").reshape([-1, 1])]
    else:  # position data
        inst_pos = np.array([
            list(range(1, len(inst) + 1)) + [0] * (max_len - len(inst))
            for inst in insts
        ])
        return_list += [inst_pos.astype("int64").reshape([-1, 1])]
    if return_attn_bias:
        if is_target:
            # This is used to avoid attention on paddings and subsequent
            # words.
            slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, max_len))
            slf_attn_bias_data = np.triu(slf_attn_bias_data, 1).reshape(
                [-1, 1, max_len, max_len])
            slf_attn_bias_data = np.tile(slf_attn_bias_data,
                                         [1, n_head, 1, 1]) * [-1e9]
        else:
            # This is used to avoid attention on paddings.
            slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
                                           (max_len - len(inst))
                                           for inst in insts])
            slf_attn_bias_data = np.tile(
                slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
                [1, n_head, max_len, 1])
        return_list += [slf_attn_bias_data.astype("float32")]
    if return_max_len:
        return_list += [max_len]
    return return_list if len(return_list) > 1 else return_list[0]


def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
                        n_head, d_model):
    """
    Put all padded data needed by training into a dict.
    """
    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
        [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
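    # The target-to-source attention bias reuses the source padding mask:
    # the step of src_max_len picks only row 0 of the source self-attention
    # bias, which is then tiled over the target length.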
    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                [1, 1, trg_max_len, 1]).astype("float32")

    # These shape tensors are used in reshape_op.
    src_data_shape = np.array([len(insts), src_max_len, d_model], dtype="int32")
    trg_data_shape = np.array([len(insts), trg_max_len, d_model], dtype="int32")
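    # The pre-softmax shapes flatten the attention weights to 2-D for the
    # softmax op; the post-softmax shapes restore the original 4-D layout.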
    src_slf_attn_pre_softmax_shape = np.array(
        [-1, src_slf_attn_bias.shape[-1]], dtype="int32")
    src_slf_attn_post_softmax_shape = np.array(
        src_slf_attn_bias.shape, dtype="int32")
    trg_slf_attn_pre_softmax_shape = np.array(
        [-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
    trg_slf_attn_post_softmax_shape = np.array(
        trg_slf_attn_bias.shape, dtype="int32")
    trg_src_attn_pre_softmax_shape = np.array(
        [-1, trg_src_attn_bias.shape[-1]], dtype="int32")
    trg_src_attn_post_softmax_shape = np.array(
        trg_src_attn_bias.shape, dtype="int32")

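    # Pad the labels and build per-token weights so that the loss on padding
    # positions is masked out.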
    lbl_word, lbl_weight = pad_batch_data(
        [inst[2] for inst in insts],
        trg_pad_idx,
        n_head,
        is_target=False,
        is_label=True,
        return_attn_bias=False,
        return_max_len=False)

    input_dict = dict(
        zip(input_data_names, [
            src_word, src_pos, src_slf_attn_bias, src_data_shape,
            src_slf_attn_pre_softmax_shape, src_slf_attn_post_softmax_shape,
            trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
            trg_data_shape, trg_slf_attn_pre_softmax_shape,
            trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape,
            trg_src_attn_post_softmax_shape, lbl_word, lbl_weight
        ]))
    return input_dict


def main():
    place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    sum_cost, avg_cost, predict, token_num = transformer(
        ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
        ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
        ModelHyperParams.n_head, ModelHyperParams.d_key,
        ModelHyperParams.d_value, ModelHyperParams.d_model,
        ModelHyperParams.d_inner_hid, ModelHyperParams.dropout)

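    # The scheduler produces the warm-up learning rate strategy from the
    # Transformer paper, parameterized by d_model and warmup_steps.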
    lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
                                         TrainTaskConfig.warmup_steps, place,
                                         TrainTaskConfig.learning_rate)
    optimizer = fluid.optimizer.Adam(
        learning_rate=lr_scheduler.learning_rate,
        beta1=TrainTaskConfig.beta1,
        beta2=TrainTaskConfig.beta2,
        epsilon=TrainTaskConfig.eps)
    optimizer.minimize(avg_cost if TrainTaskConfig.use_avg_cost else sum_cost)

    train_data = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.wmt16.train(ModelHyperParams.src_vocab_size,
                                       ModelHyperParams.trg_vocab_size),
            buf_size=100000),
        batch_size=TrainTaskConfig.batch_size)

    # Program to do validation.
    test_program = fluid.default_main_program().clone()
    with fluid.program_guard(test_program):
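        # Prune the cloned program to only the ops needed to compute avg_cost.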
        test_program = fluid.io.get_inference_program([avg_cost])
    val_data = paddle.batch(
        paddle.dataset.wmt16.validation(ModelHyperParams.src_vocab_size,
                                        ModelHyperParams.trg_vocab_size),
        batch_size=TrainTaskConfig.batch_size)

    def test(exe):
        test_total_cost = 0
        test_total_token = 0
        for batch_id, data in enumerate(val_data()):
            data_input = prepare_batch_input(
                data, encoder_input_data_names + decoder_input_data_names[:-1] +
                label_data_names, ModelHyperParams.eos_idx,
                ModelHyperParams.eos_idx, ModelHyperParams.n_head,
                ModelHyperParams.d_model)
            test_sum_cost, test_token_num = exe.run(
                test_program,
                feed=data_input,
                fetch_list=[sum_cost, token_num],
                use_program_cache=True)
            test_total_cost += test_sum_cost
            test_total_token += test_token_num
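        # Perplexity is the exponential of the per-token cost; the cost is
        # clipped at 100 to avoid overflow in np.exp.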
        test_avg_cost = test_total_cost / test_total_token
        test_ppl = np.exp([min(test_avg_cost, 100)])
        return test_avg_cost, test_ppl

    # Initialize the parameters.
    exe.run(fluid.framework.default_startup_program())
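    # Overwrite the position encoding parameters with the fixed sinusoid
    # table computed by position_encoding_init.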
    for pos_enc_param_name in pos_enc_param_names:
        pos_enc_param = fluid.global_scope().find_var(
            pos_enc_param_name).get_tensor()
        pos_enc_param.set(
            position_encoding_init(ModelHyperParams.max_length + 1,
                                   ModelHyperParams.d_model), place)

    for pass_id in range(TrainTaskConfig.pass_num):
        pass_start_time = time.time()
        for batch_id, data in enumerate(train_data()):
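            # Skip the last batch if it is smaller than the configured
            # batch size.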
            if len(data) != TrainTaskConfig.batch_size:
                continue
            data_input = prepare_batch_input(
                data, encoder_input_data_names + decoder_input_data_names[:-1] +
                label_data_names, ModelHyperParams.eos_idx,
                ModelHyperParams.eos_idx, ModelHyperParams.n_head,
                ModelHyperParams.d_model)
            lr_scheduler.update_learning_rate(data_input)
            outs = exe.run(fluid.framework.default_main_program(),
                           feed=data_input,
                           fetch_list=[sum_cost, avg_cost],
                           use_program_cache=True)
            sum_cost_val, avg_cost_val = np.array(outs[0]), np.array(outs[1])
            print("epoch: %d, batch: %d, sum loss: %f, avg loss: %f, ppl: %f" %
                  (pass_id, batch_id, sum_cost_val, avg_cost_val,
                   np.exp([min(avg_cost_val[0], 100)])))
        # Validate and save the model for inference.
        val_avg_cost, val_ppl = test(exe)
        pass_end_time = time.time()
        time_consumed = pass_end_time - pass_start_time
        print("epoch: %d, val avg loss: %f, val ppl: %f, "
              "consumed %fs" % (pass_id, val_avg_cost, val_ppl, time_consumed))
        fluid.io.save_inference_model(
            os.path.join(TrainTaskConfig.model_dir,
                         "pass_" + str(pass_id) + ".infer.model"),
            encoder_input_data_names + decoder_input_data_names[:-1],
            [predict], exe)


if __name__ == "__main__":
    main()