import os
import time
import numpy as np

import paddle
import paddle.fluid as fluid

from model import transformer, position_encoding_init
from optim import LearningRateScheduler
from config import *


def pad_batch_data(insts,
                   pad_idx,
                   n_head,
                   is_target=False,
                   is_label=False,
                   return_attn_bias=True,
                   return_max_len=True):
    """
    Pad the instances to the max sequence length in the batch, and generate the
    corresponding position data and attention bias.
    """
    return_list = []
    max_len = max(len(inst) for inst in insts)
    # Any token in the dictionary can be used for padding, since the loss at
    # padded positions is masked out by the label weights and has no effect on
    # parameter gradients.
    inst_data = np.array(
        [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
    return_list += [inst_data.astype("int64").reshape([-1, 1])]
    if is_label:  # label weight
        inst_weight = np.array(
            [[1.] * len(inst) + [0.] * (max_len - len(inst)) for inst in insts])
        return_list += [inst_weight.astype("float32").reshape([-1, 1])]
    else:  # position data
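        # Positions are 1-based so that 0 can be reserved for padded slots.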
        inst_pos = np.array([
            range(1, len(inst) + 1) + [0] * (max_len - len(inst))
            for inst in insts
        ])
        return_list += [inst_pos.astype("int64").reshape([-1, 1])]
    if return_attn_bias:
        if is_target:
            # This is used to avoid attention on paddings and subsequent
            # words.
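            # np.triu(x, 1) keeps the strictly upper triangle; for max_len = 3:
            #   [[0., 1., 1.],
            #    [0., 0., 1.],
            #    [0., 0., 0.]]
            # Scaling the 1s to -1e9 and adding them to the attention logits
            # drives the masked weights to zero after softmax, so position i
            # can only attend to positions <= i.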
            slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, max_len))
            slf_attn_bias_data = np.triu(slf_attn_bias_data, 1).reshape(
                [-1, 1, max_len, max_len])
            slf_attn_bias_data = np.tile(slf_attn_bias_data,
                                         [1, n_head, 1, 1]) * [-1e9]
        else:
            # This is used to avoid attention on paddings.
            slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
                                           (max_len - len(inst))
                                           for inst in insts])
            slf_attn_bias_data = np.tile(
                slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
                [1, n_head, max_len, 1])
        return_list += [slf_attn_bias_data.astype("float32")]
    if return_max_len:
        return_list += [max_len]
    return return_list if len(return_list) > 1 else return_list[0]


def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
                        trg_pad_idx, n_head, d_model):
    """
    Put all the padded data needed for training into dicts keyed by input name.
    """
    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
        [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
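    # All rows of the source padding mask are identical along the query axis,
    # so the step-src_max_len slice keeps one row per instance, which is then
    # tiled to trg_max_len rows to mask source paddings for every target query.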
    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                [1, 1, trg_max_len, 1]).astype("float32")

    # These shape tensors are used in reshape_op.
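    # Batch-dependent dimensions such as max_len change from batch to batch,
    # so the concrete shapes are fed at run time as int32 tensors.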
    src_data_shape = np.array([-1, src_max_len, d_model], dtype="int32")
    trg_data_shape = np.array([-1, trg_max_len, d_model], dtype="int32")
    src_slf_attn_pre_softmax_shape = np.array(
        [-1, src_slf_attn_bias.shape[-1]], dtype="int32")
    src_slf_attn_post_softmax_shape = np.array(
        [-1] + list(src_slf_attn_bias.shape[1:]), dtype="int32")
    trg_slf_attn_pre_softmax_shape = np.array(
        [-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
    trg_slf_attn_post_softmax_shape = np.array(
        [-1] + list(trg_slf_attn_bias.shape[1:]), dtype="int32")
    trg_src_attn_pre_softmax_shape = np.array(
        [-1, trg_src_attn_bias.shape[-1]], dtype="int32")
    trg_src_attn_post_softmax_shape = np.array(
        [-1] + list(trg_src_attn_bias.shape[1:]), dtype="int32")

    lbl_word, lbl_weight = pad_batch_data(
        [inst[2] for inst in insts],
        trg_pad_idx,
        n_head,
        is_target=False,
        is_label=True,
        return_attn_bias=False,
        return_max_len=False)
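    # lbl_word holds the gold next tokens flattened to [batch * max_len, 1];
    # lbl_weight zeroes padded positions out of the loss.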

    data_input_dict = dict(
        zip(data_input_names, [
            src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
            trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
        ]))
    util_input_dict = dict(
        zip(util_input_names, [
            src_data_shape, src_slf_attn_pre_softmax_shape,
            src_slf_attn_post_softmax_shape, trg_data_shape,
            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape,
            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape
        ]))
    return data_input_dict, util_input_dict


def main():
    place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    sum_cost, avg_cost, predict, token_num = transformer(
        ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
        ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
        ModelHyperParams.n_head, ModelHyperParams.d_key,
        ModelHyperParams.d_value, ModelHyperParams.d_model,
        ModelHyperParams.d_inner_hid, ModelHyperParams.dropout)

    lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
                                         TrainTaskConfig.warmup_steps,
                                         TrainTaskConfig.learning_rate)
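    # The scheduler follows the warmup schedule from "Attention Is All You
    # Need": lr = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5),
    # scaled by TrainTaskConfig.learning_rate.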
    optimizer = fluid.optimizer.Adam(
        learning_rate=lr_scheduler.learning_rate,
        beta1=TrainTaskConfig.beta1,
        beta2=TrainTaskConfig.beta2,
        epsilon=TrainTaskConfig.eps)
    optimizer.minimize(avg_cost if TrainTaskConfig.use_avg_cost else sum_cost)

    train_data = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.wmt16.train(ModelHyperParams.src_vocab_size,
                                       ModelHyperParams.trg_vocab_size),
            buf_size=100000),
        batch_size=TrainTaskConfig.batch_size)

    # Program for validation.
    test_program = fluid.default_main_program().clone()
    with fluid.program_guard(test_program):
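        # Prune backward and optimizer ops from the cloned program, keeping
        # only what is needed to compute the validation cost.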
        test_program = fluid.io.get_inference_program([avg_cost])
    val_data = paddle.batch(
        paddle.dataset.wmt16.validation(ModelHyperParams.src_vocab_size,
                                        ModelHyperParams.trg_vocab_size),
        batch_size=TrainTaskConfig.batch_size)

    def test(exe):
        test_total_cost = 0
        test_total_token = 0
        for batch_id, data in enumerate(val_data()):
            data_input_dict, util_input_dict = prepare_batch_input(
                data, data_input_names, util_input_names,
                ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
                ModelHyperParams.n_head, ModelHyperParams.d_model)
            test_sum_cost, test_token_num = exe.run(
                test_program,
                feed=dict(data_input_dict.items() + util_input_dict.items()),
                fetch_list=[sum_cost, token_num],
                use_program_cache=True)
            test_total_cost += test_sum_cost
            test_total_token += test_token_num
        test_avg_cost = test_total_cost / test_total_token
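        # Perplexity is the exponential of the per-token cost; the cost is
        # capped at 100 to keep np.exp from overflowing.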
        test_ppl = np.exp([min(test_avg_cost, 100)])
        return test_avg_cost, test_ppl

    def set_util_input(input_name_value):
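        # Write a numpy value directly into a variable's tensor in the global
        # scope; used for the util inputs (shape tensors, learning rate,
        # position encodings) that are not passed through the feed dict.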
        tensor = fluid.global_scope().find_var(input_name_value[0]).get_tensor()
        tensor.set(input_name_value[1], place)

    # Initialize the parameters.
    exe.run(fluid.framework.default_startup_program())
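    # Overwrite the randomly initialized position-embedding tables with the
    # fixed encodings produced by position_encoding_init.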
    for pos_enc_param_name in pos_enc_param_names:
        set_util_input((pos_enc_param_name, position_encoding_init(
            ModelHyperParams.max_length + 1, ModelHyperParams.d_model)))

    data_input_names = encoder_data_input_fields + \
        decoder_data_input_fields[:-1] + label_data_names
    util_input_names = encoder_util_input_fields + decoder_util_input_fields

    train_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        loss_name=avg_cost.name
        if TrainTaskConfig.use_avg_cost else sum_cost.name)
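    # ParallelExecutor replicates the program across all available devices;
    # the fetched cost and token-number arrays come back with one entry per
    # device and are summed below.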

    for pass_id in xrange(TrainTaskConfig.pass_num):
        pass_start_time = time.time()
        for batch_id, data in enumerate(train_data()):
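            # eos_idx doubles as the padding index on both the source and
            # target sides; padded positions are masked out downstream.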
            data_input_dict, util_input_dict = prepare_batch_input(
                data, data_input_names, util_input_names,
                ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
                ModelHyperParams.n_head, ModelHyperParams.d_model)
            map(set_util_input,
                zip(util_input_dict.keys() + [lr_scheduler.learning_rate.name],
                    util_input_dict.values() +
                    [lr_scheduler.update_learning_rate()]))
            outs = train_exe.run(feed_dict=data_input_dict,
                                 fetch_list=[sum_cost.name, token_num.name])
            sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
            total_sum_cost = sum_cost_val.sum()  # sum the costs from all devices
            total_token_num = token_num_val.sum()
            total_avg_cost = total_sum_cost / total_token_num
            print("epoch: %d, batch: %d, sum loss: %f, avg loss: %f, ppl: %f" %
                  (pass_id, batch_id, total_sum_cost, total_avg_cost,
                   np.exp([min(total_avg_cost, 100)])))
        # Validate and save the model for inference.
        val_avg_cost, val_ppl = test(exe)
        pass_end_time = time.time()
        time_consumed = pass_end_time - pass_start_time
        print("epoch: %d, val avg loss: %f, val ppl: %f, "
              "consumed %fs" % (pass_id, val_avg_cost, val_ppl, time_consumed))
        fluid.io.save_inference_model(
            os.path.join(TrainTaskConfig.model_dir,
                         "pass_" + str(pass_id) + ".infer.model"),
            encoder_data_input_fields + decoder_data_input_fields[:-1],
            [predict], exe)


if __name__ == "__main__":
    main()