import os
import time
import numpy as np

import paddle
import paddle.fluid as fluid

from model import transformer, position_encoding_init
from optim import LearningRateScheduler
from config import *


def pad_batch_data(insts,
                   pad_idx,
                   n_head,
                   is_target=False,
                   is_label=False,
                   return_attn_bias=True,
                   return_max_len=True):
    """
    Pad the instances to the max sequence length in batch, and generate the
    corresponding position data and attention bias.
    """
    return_list = []
    max_len = max(len(inst) for inst in insts)
    # Any token included in dict can be used to pad, since the paddings' loss
    # will be masked out by weights and make no effect on parameter gradients.
    inst_data = np.array(
        [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
    return_list += [inst_data.astype("int64").reshape([-1, 1])]
    if is_label:  # label weight
        inst_weight = np.array(
            [[1.] * len(inst) + [0.] * (max_len - len(inst)) for inst in insts])
        return_list += [inst_weight.astype("float32").reshape([-1, 1])]
    else:  # position data
        inst_pos = np.array([
            range(1, len(inst) + 1) + [0] * (max_len - len(inst))
            for inst in insts
        ])
        return_list += [inst_pos.astype("int64").reshape([-1, 1])]
    if return_attn_bias:
        if is_target:
            # This is used to avoid attention on paddings and subsequent
            # words.
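            # E.g. for max_len = 3, np.triu(ones, 1) yields
            # [[0, 1, 1], [0, 0, 1], [0, 0, 0]]; scaled by -1e9, position i
            # cannot attend to any position j > i.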
            slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, max_len))
            slf_attn_bias_data = np.triu(slf_attn_bias_data, 1).reshape(
                [-1, 1, max_len, max_len])
            slf_attn_bias_data = np.tile(slf_attn_bias_data,
                                         [1, n_head, 1, 1]) * [-1e9]
        else:
            # This is used to avoid attention on paddings.
            slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
                                           (max_len - len(inst))
                                           for inst in insts])
            slf_attn_bias_data = np.tile(
                slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
                [1, n_head, max_len, 1])
        return_list += [slf_attn_bias_data.astype("float32")]
    if return_max_len:
        return_list += [max_len]
    return return_list if len(return_list) > 1 else return_list[0]


def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
                        trg_pad_idx, n_head, d_model):
    """
    Put all padded data needed by training into a dict.
    """
    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
        [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
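    # Every query row of src_slf_attn_bias holds the same padding mask, so a
    # stride of src_max_len slices out a single row, which is then tiled to
    # trg_max_len rows for the encoder-decoder attention.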
    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                [1, 1, trg_max_len, 1]).astype("float32")

    # These shape tensors are used in reshape_op.
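    # The *_pre_softmax shapes flatten the 4-D attention weights to 2-D so
    # softmax runs over the last dimension; the *_post_softmax shapes restore
    # the [batch, n_head, seq_len, seq_len] layout afterwards.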
    src_data_shape = np.array([-1, src_max_len, d_model], dtype="int32")
    trg_data_shape = np.array([-1, trg_max_len, d_model], dtype="int32")
    src_slf_attn_pre_softmax_shape = np.array(
        [-1, src_slf_attn_bias.shape[-1]], dtype="int32")
    src_slf_attn_post_softmax_shape = np.array(
        [-1] + list(src_slf_attn_bias.shape[1:]), dtype="int32")
    trg_slf_attn_pre_softmax_shape = np.array(
        [-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
    trg_slf_attn_post_softmax_shape = np.array(
        [-1] + list(trg_slf_attn_bias.shape[1:]), dtype="int32")
    trg_src_attn_pre_softmax_shape = np.array(
        [-1, trg_src_attn_bias.shape[-1]], dtype="int32")
    trg_src_attn_post_softmax_shape = np.array(
        [-1] + list(trg_src_attn_bias.shape[1:]), dtype="int32")

    lbl_word, lbl_weight = pad_batch_data(
        [inst[2] for inst in insts],
        trg_pad_idx,
        n_head,
        is_target=False,
        is_label=True,
        return_attn_bias=False,
        return_max_len=False)

    data_input_dict = dict(
        zip(data_input_names, [
            src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
            trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
        ]))
    util_input_dict = dict(
        zip(util_input_names, [
            src_data_shape, src_slf_attn_pre_softmax_shape,
            src_slf_attn_post_softmax_shape, trg_data_shape,
            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape,
            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape
        ]))
    return data_input_dict, util_input_dict


def read_multiple(reader, count):
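    """
    Stack `count` consecutive batches from `reader` into one list, so that
    each device fed by the ParallelExecutor receives its own batch.
    """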
    def __impl__():
        res = []
        for item in reader():
            res.append(item)
            if len(res) == count:
                yield res
                res = []

        # Trailing batches that cannot fill every device are dropped.

    return __impl__


def main():
    place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    sum_cost, avg_cost, predict, token_num = transformer(
        ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
        ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
        ModelHyperParams.n_head, ModelHyperParams.d_key,
        ModelHyperParams.d_value, ModelHyperParams.d_model,
        ModelHyperParams.d_inner_hid, ModelHyperParams.dropout)

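    # The scheduler (see optim.py) presumably implements the warm-up rule from
    # "Attention Is All You Need": d_model**-0.5 * min(step**-0.5,
    # step * warmup_steps**-1.5), scaled by TrainTaskConfig.learning_rate.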
    lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
                                         TrainTaskConfig.warmup_steps,
                                         TrainTaskConfig.learning_rate)
    optimizer = fluid.optimizer.Adam(
        learning_rate=lr_scheduler.learning_rate,
        beta1=TrainTaskConfig.beta1,
        beta2=TrainTaskConfig.beta2,
        epsilon=TrainTaskConfig.eps)
    optimizer.minimize(avg_cost if TrainTaskConfig.use_avg_cost else sum_cost)

    train_data = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.wmt16.train(ModelHyperParams.src_vocab_size,
                                       ModelHyperParams.trg_vocab_size),
            buf_size=100000),
        batch_size=TrainTaskConfig.batch_size)

    # Program to do validation.
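    # get_inference_program prunes the cloned program down to the ops needed
    # to compute avg_cost, so validation runs only the forward pass.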
    test_program = fluid.default_main_program().clone()
    with fluid.program_guard(test_program):
        test_program = fluid.io.get_inference_program([avg_cost])
    val_data = paddle.batch(
        paddle.dataset.wmt16.validation(ModelHyperParams.src_vocab_size,
                                        ModelHyperParams.trg_vocab_size),
        batch_size=TrainTaskConfig.batch_size)

    def test(exe):
        test_total_cost = 0
        test_total_token = 0
        for batch_id, data in enumerate(val_data()):
            data_input_dict, util_input_dict = prepare_batch_input(
                data, data_input_names, util_input_names,
                ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
                ModelHyperParams.n_head, ModelHyperParams.d_model)
            test_sum_cost, test_token_num = exe.run(
                test_program,
                feed=dict(data_input_dict.items() + util_input_dict.items()),
                fetch_list=[sum_cost, token_num],
                use_program_cache=True)
            test_total_cost += test_sum_cost
            test_total_token += test_token_num
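        # Clamp the average cost before exponentiating so a poorly fitted
        # model cannot overflow the perplexity.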
        test_avg_cost = test_total_cost / test_total_token
        test_ppl = np.exp([min(test_avg_cost, 100)])
        return test_avg_cost, test_ppl

    # Initialize the parameters.
    exe.run(fluid.framework.default_startup_program())

    data_input_names = (encoder_data_input_fields +
                        decoder_data_input_fields[:-1] + label_data_names)
    util_input_names = encoder_util_input_fields + decoder_util_input_fields

    train_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        loss_name=avg_cost.name
        if TrainTaskConfig.use_avg_cost else sum_cost.name)

    dev_count = fluid.core.get_cuda_device_count()

    for pos_enc_param_name in pos_enc_param_names:
        tensor = position_encoding_init(ModelHyperParams.max_length + 1,
                                        ModelHyperParams.d_model)
        for place_id in xrange(dev_count):
            local_scope = train_exe.executor.local_scope(place_id)
            local_scope.find_var(pos_enc_param_name).get_tensor().set(
                tensor, fluid.CUDAPlace(place_id))

    train_data = read_multiple(reader=train_data, count=dev_count)
    for pass_id in xrange(TrainTaskConfig.pass_num):
        pass_start_time = time.time()
        for batch_id, data in enumerate(train_data()):
            # Advance the warm-up schedule once per step, not once per device.
            lr_rate = lr_scheduler.update_learning_rate()
            for place_id, data_buffer in enumerate(data):
                data_input_dict, util_input_dict = prepare_batch_input(
                    data_buffer, data_input_names, util_input_names,
                    ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
                    ModelHyperParams.n_head, ModelHyperParams.d_model)

                local_scope = train_exe.executor.local_scope(place_id)

                # Feed the current learning rate into this device's scope.
                local_scope.find_var(
                    lr_scheduler.learning_rate.name).get_tensor().set(
                        lr_rate, fluid.CUDAPlace(place_id))

                for var_name in data_input_dict:
                    local_scope.find_var(var_name).get_tensor().set(
                        data_input_dict[var_name], fluid.CUDAPlace(place_id))

                for var_name in util_input_dict:
                    local_scope.find_var(var_name).get_tensor().set(
                        util_input_dict[var_name], fluid.CUDAPlace(place_id))

            outs = train_exe.run(fetch_list=[sum_cost.name, token_num.name])
            sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
            # Sum the costs and token counts gathered from all devices.
            total_sum_cost = sum_cost_val.sum()
            total_token_num = token_num_val.sum()
            total_avg_cost = total_sum_cost / total_token_num
            print("epoch: %d, batch: %d, sum loss: %f, avg loss: %f, ppl: %f" %
                  (pass_id, batch_id, total_sum_cost, total_avg_cost,
                   np.exp([min(total_avg_cost, 100)])))
        # Validate and save the model for inference.
        val_avg_cost, val_ppl = test(exe)
        pass_end_time = time.time()
        time_consumed = pass_end_time - pass_start_time
        print("epoch: %d, val avg loss: %f, val ppl: %f, "
              "consumed %fs" % (pass_id, val_avg_cost, val_ppl, time_consumed))
        fluid.io.save_inference_model(
            os.path.join(TrainTaskConfig.model_dir,
                         "pass_" + str(pass_id) + ".infer.model"),
            encoder_data_input_fields + decoder_data_input_fields[:-1],
            [predict], exe)


if __name__ == "__main__":
    main()