import numpy as np

import paddle.v2 as paddle
import paddle.fluid as fluid

from model import transformer, position_encoding_init
from optim import LearningRateScheduler
from config import TrainTaskConfig, ModelHyperParams, \
        pos_enc_param_names, input_data_names


def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
                        max_length, n_head, place):
    """
    Pad the instances to the max sequence length in batch, and generate the
    corresponding position data and attention bias. Then, convert the numpy
    data to tensors and return a dict mapping names to tensors.
    """
    input_dict = {}

    def __pad_batch_data(insts,
                         pad_idx,
                         is_target=False,
                         return_pos=True,
                         return_attn_bias=True,
                         return_max_len=True):
        """
        Pad the instances to the max sequence length in batch, and generate the
        corresponding position data and attention bias.
        """
        return_list = []
        max_len = max(len(inst) for inst in insts)
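        # Pad every instance to max_len with pad_idx and flatten the word ids
        # into an int64 column of shape [batch_size * max_len, 1].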
        inst_data = np.array(
            [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
        return_list += [inst_data.astype("int64").reshape([-1, 1])]
        if return_pos:
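            # Positions are 1-based; padding tokens are assigned position 0.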
            inst_pos = np.array([[
                pos_i + 1 if w_i != pad_idx else 0
                for pos_i, w_i in enumerate(inst)
            ] for inst in inst_data])

            return_list += [inst_pos.astype("int64").reshape([-1, 1])]
        if return_attn_bias:
            if is_target:
                # This is used to avoid attention on paddings and subsequent
                # words.
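                # An upper-triangular matrix (offset 1) marks the disallowed
                # future positions; scaling by -1e9 drives their softmax
                # weights to effectively zero.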
                slf_attn_bias_data = np.ones((inst_data.shape[0], max_len,
                                              max_len))
                slf_attn_bias_data = np.triu(slf_attn_bias_data, 1).reshape(
                    [-1, 1, max_len, max_len])
                slf_attn_bias_data = np.tile(slf_attn_bias_data,
                                             [1, n_head, 1, 1]) * [-1e9]
            else:
                # This is used to avoid attention on paddings.
                slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
                                               (max_len - len(inst))
                                               for inst in insts])
                slf_attn_bias_data = np.tile(
                    slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
                    [1, n_head, max_len, 1])
            return_list += [slf_attn_bias_data.astype("float32")]
        if return_max_len:
            return_list += [max_len]
        return return_list if len(return_list) > 1 else return_list[0]

    def data_to_tensor(data_list, name_list, input_dict, place):
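        # Convert each numpy array into a fluid LoDTensor on the target place
        # and register it in input_dict under the matching input name.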
        assert len(data_list) == len(name_list)
        for i in range(len(name_list)):
            tensor = fluid.LoDTensor()
            tensor.set(data_list[i], place)
            input_dict[name_list[i]] = tensor

    src_word, src_pos, src_slf_attn_bias, src_max_len = __pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, is_target=False)
    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = __pad_batch_data(
        [inst[1] for inst in insts], trg_pad_idx, is_target=True)
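    # The encoder-decoder attention bias takes one row of the source
    # self-attention (padding) bias and tiles it over the target length, so
    # decoder positions cannot attend to source paddings.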
    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                [1, 1, trg_max_len, 1]).astype("float32")
    lbl_word = __pad_batch_data([inst[2] for inst in insts], trg_pad_idx, False,
                                False, False, False)
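    # Zero out the loss weight for padding positions in the label sequence.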
    lbl_weight = (lbl_word != trg_pad_idx).astype("float32").reshape([-1, 1])

    data_to_tensor([
        src_word, src_pos, trg_word, trg_pos, src_slf_attn_bias,
        trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
    ], input_data_names, input_dict, place)

    return input_dict


def main():
    place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

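    # Build the Transformer training network; `transformer` returns the cost
    # to be minimized.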
    cost = transformer(
        ModelHyperParams.src_vocab_size + 1,
        ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.max_length + 1,
        ModelHyperParams.n_layer, ModelHyperParams.n_head,
        ModelHyperParams.d_key, ModelHyperParams.d_value,
        ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
        ModelHyperParams.dropout, ModelHyperParams.src_pad_idx,
        ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx)

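    # The learning rate is computed outside the network by
    # LearningRateScheduler (configured with d_model and warmup_steps) and fed
    # into each training step via update_learning_rate below.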
    lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
                                         TrainTaskConfig.warmup_steps, place,
                                         TrainTaskConfig.learning_rate)
    optimizer = fluid.optimizer.Adam(
        learning_rate=lr_scheduler.learning_rate,
        beta1=TrainTaskConfig.beta1,
        beta2=TrainTaskConfig.beta2,
        epsilon=TrainTaskConfig.eps)
    optimizer.minimize(cost)

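    # Shuffled WMT16 reader batched into fixed-size mini-batches.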
    train_data = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.wmt16.train(ModelHyperParams.src_vocab_size,
                                       ModelHyperParams.trg_vocab_size),
            buf_size=51200),
        batch_size=TrainTaskConfig.batch_size)

    # Initialize the parameters.
    exe.run(fluid.framework.default_startup_program())
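    # After the startup program runs, overwrite the position-encoding
    # parameters with the fixed table produced by position_encoding_init.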
    for pos_enc_param_name in pos_enc_param_names:
        pos_enc_param = fluid.global_scope().find_var(
            pos_enc_param_name).get_tensor()
        pos_enc_param.set(
            position_encoding_init(ModelHyperParams.max_length + 1,
                                   ModelHyperParams.d_model), place)

    for pass_id in xrange(TrainTaskConfig.pass_num):
        for batch_id, data in enumerate(train_data()):
            # The current program desc is coupled with batch_size, so every
            # mini-batch must contain exactly batch_size instances; smaller
            # remainder batches are skipped.
            if len(data) != TrainTaskConfig.batch_size:
                continue
            data_input = prepare_batch_input(
                data, input_data_names, ModelHyperParams.src_pad_idx,
                ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
                ModelHyperParams.n_head, place)
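            # Update the externally computed learning rate before running this
            # training step.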
            lr_scheduler.update_learning_rate(data_input)
            outs = exe.run(fluid.framework.default_main_program(),
                           feed=data_input,
                           fetch_list=[cost])
            cost_val = np.array(outs[0])
            print("pass_id = " + str(pass_id) + " batch = " + str(batch_id) +
                  " avg_cost = " + str(cost_val))


if __name__ == "__main__":
    main()