import argparse
import ast
import os
import time
from functools import reduce

import numpy as np

import paddle
import paddle.fluid as fluid

from model import transformer, position_encoding_init
from optim import LearningRateScheduler
from config import *
import reader


def parse_args():
    """
    Parse command-line arguments and merge configuration overrides.

    After parsing the flags below, the source and target vocabularies are
    loaded and their sizes, together with the ids of the special tokens, are
    appended to the trailing positional `opts`. The combined list is merged
    into `TrainTaskConfig` and `ModelHyperParams` via `merge_cfg_from_list`.

    Returns:
        The parsed `argparse.Namespace`.
    """
    parser = argparse.ArgumentParser("Training for Transformer.")
    parser.add_argument(
        "--src_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of source language.")
    parser.add_argument(
        "--trg_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of target language.")
    parser.add_argument(
        "--train_file_pattern",
        type=str,
        required=True,
        help="The pattern to match training data files.")
    parser.add_argument(
        "--val_file_pattern",
        type=str,
        help="The pattern to match validation data files.")
    parser.add_argument(
        "--use_token_batch",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to "
        "produce batch data according to token number.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=2048,
        help="The number of sequences contained in a mini-batch, or the maximum "
        "number of tokens (include paddings) contained in a mini-batch. Note "
        "that this represents the number on single device and the actual batch "
        "size for multi-devices will multiply the device number.")
    parser.add_argument(
        "--pool_size",
        type=int,
        default=10000,
        help="The buffer size to pool data.")
    parser.add_argument(
        "--sort_type",
        default="pool",
        choices=("global", "pool", "none"),
        help="The grain to sort by length: global for all instances; pool for "
        "instances in pool; none for no sort.")
    parser.add_argument(
        "--shuffle",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to shuffle instances in each pass.")
    parser.add_argument(
        "--shuffle_batch",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to shuffle the data batches.")
    parser.add_argument(
        "--special_token",
        type=str,
        default=["<s>", "<e>", "<unk>"],
        nargs=3,
        help="The <bos>, <eos> and <unk> tokens in the dictionary.")
    parser.add_argument(
        'opts',
        help='See config.py for all options',
        default=None,
        nargs=argparse.REMAINDER)
    args = parser.parse_args()

    # Append args related to dict.
    src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
    trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
    # NOTE(review): all special-token ids are looked up in the *source*
    # dictionary, which assumes both vocabularies share these ids -- confirm.
    bos_token, eos_token, unk_token = args.special_token
    dict_args = [
        "src_vocab_size", str(len(src_dict)),
        "trg_vocab_size", str(len(trg_dict)),
        "bos_idx", str(src_dict[bos_token]),
        "eos_idx", str(src_dict[eos_token]),
        "unk_idx", str(src_dict[unk_token]),
    ]
    merge_cfg_from_list(args.opts + dict_args,
                        [TrainTaskConfig, ModelHyperParams])
    return args


def pad_batch_data(insts,
                   pad_idx,
                   n_head,
                   is_target=False,
                   is_label=False,
                   return_attn_bias=True,
                   return_max_len=True,
                   return_num_token=False):
    """
    Pad the instances to the max sequence length in batch, and generate the
    corresponding position data and attention bias.

    Args:
        insts: A sequence of token-id sequences, one per instance.
        pad_idx: The id written into positions past the end of an instance.
        n_head: The number of attention heads; the attention bias is tiled
            along an axis of this size.
        is_target: If True, the self-attention bias masks every position after
            the query position (causal mask). Otherwise it masks paddings.
        is_label: If True, per-token weights (1. for real tokens, 0. for
            paddings) are produced instead of position ids.
        return_attn_bias: Whether to append the self-attention bias.
        return_max_len: Whether to append the padded length.
        return_num_token: Whether to append the number of real tokens.

    Returns:
        A list in the order: padded ids (int64, shape [batch * max_len, 1]),
        weights (float32) or positions (int64) of the same shape, attention
        bias (float32, shape [batch, n_head, max_len, max_len]), max_len,
        num_token -- disabled entries omitted. A single remaining entry is
        returned on its own instead of in a list.
    """
    return_list = []
    lengths = [len(inst) for inst in insts]
    max_len = max(lengths)
    # `sum` rather than `reduce`: `reduce` is not a builtin on Python 3.
    num_token = sum(lengths) if return_num_token else 0
    # Any token included in dict can be used to pad, since the paddings' loss
    # will be masked out by weights and make no effect on parameter gradients.
    # `list(inst)` lets tuples (or other sequences) be padded as well.
    inst_data = np.array(
        [list(inst) + [pad_idx] * (max_len - len(inst)) for inst in insts])
    return_list += [inst_data.astype("int64").reshape([-1, 1])]
    if is_label:  # label weight
        inst_weight = np.array(
            [[1.] * len(inst) + [0.] * (max_len - len(inst)) for inst in insts])
        return_list += [inst_weight.astype("float32").reshape([-1, 1])]
    else:  # position data: 1-based positions, 0 at paddings
        # `list(range(...))` keeps the list concatenation valid on Python 3.
        inst_pos = np.array([
            list(range(1, len(inst) + 1)) + [0] * (max_len - len(inst))
            for inst in insts
        ])
        return_list += [inst_pos.astype("int64").reshape([-1, 1])]
    if return_attn_bias:
        if is_target:
            # This is used to avoid attention on paddings and subsequent
            # words: key j is masked for query i whenever j > i.
            slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, max_len))
            slf_attn_bias_data = np.triu(slf_attn_bias_data,
                                         1).reshape([-1, 1, max_len, max_len])
            slf_attn_bias_data = np.tile(slf_attn_bias_data,
                                         [1, n_head, 1, 1]) * [-1e9]
        else:
            # This is used to avoid attention on paddings.
            slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
                                           (max_len - len(inst))
                                           for inst in insts])
            slf_attn_bias_data = np.tile(
                slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
                [1, n_head, max_len, 1])
        return_list += [slf_attn_bias_data.astype("float32")]
    if return_max_len:
        return_list += [max_len]
    if return_num_token:
        return_list += [num_token]
    return return_list if len(return_list) > 1 else return_list[0]


def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
                        trg_pad_idx, n_head, d_model):
    """
    Put all padded data needed by training into a dict.

    Args:
        insts: A list of instances where `inst[0]` holds the source ids,
            `inst[1]` the target (decoder input) ids and `inst[2]` the label
            ids.
        data_input_names: Feed names zipped, in order, with src_word, src_pos,
            src_slf_attn_bias, trg_word, trg_pos, trg_slf_attn_bias,
            trg_src_attn_bias, lbl_word, lbl_weight.
        util_input_names: Feed names zipped, in order, with the shape tensors
            built below (src data shape, src self-attention pre/post softmax
            shapes, trg data shape, trg self-attention pre/post softmax
            shapes, trg-src attention pre/post softmax shapes).
        src_pad_idx: The id used to pad source sequences.
        trg_pad_idx: The id used to pad target and label sequences.
        n_head: The number of attention heads.
        d_model: The model dimension used in the data shape tensors.

    Returns:
        A tuple (data_input_dict, util_input_dict, num_token) where num_token
        is a float32 array of shape [1] holding the number of label tokens.
    """
    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
        [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
    # The source padding mask does not depend on the query position, so take
    # a single query row of the source self-attention bias and repeat it for
    # every target position.
    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, :1, :],
                                [1, 1, trg_max_len, 1]).astype("float32")

    # These shape tensors are used in reshape_op.
    src_data_shape = np.array([-1, src_max_len, d_model], dtype="int32")
    trg_data_shape = np.array([-1, trg_max_len, d_model], dtype="int32")
    src_slf_attn_pre_softmax_shape = np.array(
        [-1, src_slf_attn_bias.shape[-1]], dtype="int32")
    src_slf_attn_post_softmax_shape = np.array(
        [-1] + list(src_slf_attn_bias.shape[1:]), dtype="int32")
    trg_slf_attn_pre_softmax_shape = np.array(
        [-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
    trg_slf_attn_post_softmax_shape = np.array(
        [-1] + list(trg_slf_attn_bias.shape[1:]), dtype="int32")
    trg_src_attn_pre_softmax_shape = np.array(
        [-1, trg_src_attn_bias.shape[-1]], dtype="int32")
    trg_src_attn_post_softmax_shape = np.array(
        [-1] + list(trg_src_attn_bias.shape[1:]), dtype="int32")

    lbl_word, lbl_weight, num_token = pad_batch_data(
        [inst[2] for inst in insts],
        trg_pad_idx,
        n_head,
        is_target=False,
        is_label=True,
        return_attn_bias=False,
        return_max_len=False,
        return_num_token=True)

    data_input_dict = dict(
        zip(data_input_names, [
            src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
            trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
        ]))
    util_input_dict = dict(
        zip(util_input_names, [
            src_data_shape, src_slf_attn_pre_softmax_shape,
            src_slf_attn_post_softmax_shape, trg_data_shape,
            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape,
            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape
        ]))
    return data_input_dict, util_input_dict, np.asarray(
        [num_token], dtype="float32")


def read_multiple(reader, count, clip_last=True):
    """
    Stack data from reader for multi-devices.

    Groups every `count` consecutive items produced by `reader()` into one
    list, so that each yielded list holds one item per device.

    Args:
        reader: A callable returning an iterable of items (lists of
            instances).
        count: The number of items per group, i.e. the device count.
        clip_last: If True, a trailing group with fewer than `count` items is
            dropped. If False, the instances of the trailing items are merged
            and split into `count` parts whose sizes differ by at most one.
            This requires at least `count` instances, otherwise nothing is
            yielded for them.

    Returns:
        A generator function yielding lists of `count` items.
    """

    def __impl__():
        res = []
        for item in reader():
            res.append(item)
            if len(res) == count:
                yield res
                res = []
        if clip_last or not res:
            return
        data = [inst for item in res for inst in item]
        # Every device needs at least one instance.
        if len(data) < count:
            return
        # Give the first `remainder` parts one extra instance so that no
        # instance of the leftover is dropped.
        part_size, remainder = divmod(len(data), count)
        parts = []
        start = 0
        for i in range(count):
            end = start + part_size + (1 if i < remainder else 0)
            parts.append(data[start:end])
            start = end
        yield parts

    return __impl__


def split_data(data, num_part):
    """
    Split data for each device.

    Args:
        data: Either a list already holding `num_part` per-device items, in
            which case it is returned unchanged, or a list whose first element
            is a single batch of instances to be split.
        num_part: The number of devices.

    Returns:
        A list of `num_part` parts. When splitting, part sizes differ by at
        most one and no instance is dropped. Parts are empty if the batch has
        fewer than `num_part` instances.
    """
    if len(data) == num_part:
        return data
    insts = data[0]
    # The first `remainder` parts take one extra instance, so the remainder
    # of an uneven batch is kept instead of being silently discarded.
    part_size, remainder = divmod(len(insts), num_part)
    parts = []
    start = 0
    for i in range(num_part):
        end = start + part_size + (1 if i < remainder else 0)
        parts.append(insts[start:end])
        start = end
    return parts


def train(args):
    """
    Build the Transformer training program and run the training loop.

    The model is built by `transformer` from `ModelHyperParams` and optimized
    by Adam with the learning rate produced by `LearningRateScheduler`. It is
    run through a `fluid.ParallelExecutor`. For every pass, each training
    batch is fed and its loss is printed. The model is then evaluated on
    `args.val_file_pattern` when given. Finally the persistable variables and
    an inference model are saved for that pass.

    Args:
        args: The namespace returned by `parse_args`.
    """
    # NOTE(review): the CUDA device count is used to split data even when
    # `TrainTaskConfig.use_gpu` is False -- confirm the CPU path.
    dev_count = fluid.core.get_cuda_device_count()

    sum_cost, avg_cost, predict, token_num = transformer(
        ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
        ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
        ModelHyperParams.n_head, ModelHyperParams.d_key,
        ModelHyperParams.d_value, ModelHyperParams.d_model,
        ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
        ModelHyperParams.weight_sharing, TrainTaskConfig.label_smooth_eps)

    # Take the validation program before `minimize` appends the backward and
    # optimizer ops. `clone(for_test=True)` does not prune operators, so a
    # clone taken afterwards would also run gradient and parameter-update
    # ops during validation.
    test_program = fluid.default_main_program().clone(for_test=True)

    lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
                                         TrainTaskConfig.warmup_steps,
                                         TrainTaskConfig.learning_rate)
    optimizer = fluid.optimizer.Adam(
        learning_rate=lr_scheduler.learning_rate,
        beta1=TrainTaskConfig.beta1,
        beta2=TrainTaskConfig.beta2,
        epsilon=TrainTaskConfig.eps)
    optimizer.minimize(sum_cost)

    place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    # Initialize the parameters, from a checkpoint if one is configured.
    if TrainTaskConfig.ckpt_path:
        fluid.io.load_persistables(exe, TrainTaskConfig.ckpt_path)
        lr_scheduler.current_steps = TrainTaskConfig.start_step
    else:
        exe.run(fluid.framework.default_startup_program())

    # Names of the fed variables. The last decoder data field is excluded
    # here (presumably an inference-only input -- confirm in config.py).
    data_input_names = (encoder_data_input_fields +
                        decoder_data_input_fields[:-1] +
                        label_data_input_fields)
    util_input_names = encoder_util_input_fields + decoder_util_input_fields

    train_data = reader.DataReader(
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        fpattern=args.train_file_pattern,
        use_token_batch=args.use_token_batch,
        batch_size=args.batch_size * (1 if args.use_token_batch else dev_count),
        pool_size=args.pool_size,
        sort_type=args.sort_type,
        shuffle=args.shuffle,
        shuffle_batch=args.shuffle_batch,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2],
        max_length=ModelHyperParams.max_length,
        clip_last_batch=False)
    train_data = read_multiple(
        reader=train_data.batch_generator,
        count=dev_count if args.use_token_batch else 1)

    build_strategy = fluid.BuildStrategy()
    # Since the token number differs among devices, customize gradient scale to
    # use token average cost among multi-devices. and the gradient scale is
    # `1 / token_number` for average cost.
    build_strategy.gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized
    train_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        loss_name=sum_cost.name,
        build_strategy=build_strategy)

    def test_context():
        # Context to do validation.
        test_exe = fluid.ParallelExecutor(
            use_cuda=TrainTaskConfig.use_gpu,
            main_program=test_program,
            share_vars_from=train_exe)

        val_data = reader.DataReader(
            src_vocab_fpath=args.src_vocab_fpath,
            trg_vocab_fpath=args.trg_vocab_fpath,
            fpattern=args.val_file_pattern,
            use_token_batch=args.use_token_batch,
            batch_size=args.batch_size *
            (1 if args.use_token_batch else dev_count),
            pool_size=args.pool_size,
            sort_type=args.sort_type,
            start_mark=args.special_token[0],
            end_mark=args.special_token[1],
            unk_mark=args.special_token[2],
            max_length=ModelHyperParams.max_length,
            clip_last_batch=False,
            shuffle=False,
            shuffle_batch=False)

        def test(exe=test_exe):
            """Return the token-averaged loss and perplexity on val data."""
            test_total_cost = 0
            test_total_token = 0
            test_data = read_multiple(
                reader=val_data.batch_generator,
                count=dev_count if args.use_token_batch else 1)
            for data in test_data():
                feed_list = []
                for data_buffer in split_data(data, num_part=dev_count):
                    data_input_dict, util_input_dict, _ = prepare_batch_input(
                        data_buffer, data_input_names, util_input_names,
                        ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
                        ModelHyperParams.n_head, ModelHyperParams.d_model)
                    # Merge without `items() + items()` (a list concat that
                    # only works on Python 2).
                    feed_dict = dict(data_input_dict)
                    feed_dict.update(util_input_dict)
                    feed_list.append(feed_dict)

                outs = exe.run(feed=feed_list,
                               fetch_list=[sum_cost.name, token_num.name])
                sum_cost_val = np.array(outs[0])
                token_num_val = np.array(outs[1])
                test_total_cost += sum_cost_val.sum()
                test_total_token += token_num_val.sum()
            test_avg_cost = test_total_cost / test_total_token
            # Clip the exponent so the perplexity stays finite.
            test_ppl = np.exp(min(test_avg_cost, 100))
            return test_avg_cost, test_ppl

        return test

    if args.val_file_pattern is not None:
        test = test_context()

    init = False
    for pass_id in range(TrainTaskConfig.pass_num):
        pass_start_time = time.time()
        for batch_id, data in enumerate(train_data()):
            feed_list = []
            total_num_token = 0
            lr_rate = lr_scheduler.update_learning_rate()
            for data_buffer in split_data(data, num_part=dev_count):
                data_input_dict, util_input_dict, num_token = prepare_batch_input(
                    data_buffer, data_input_names, util_input_names,
                    ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
                    ModelHyperParams.n_head, ModelHyperParams.d_model)
                total_num_token += num_token
                feed_dict = dict(data_input_dict)
                feed_dict.update(util_input_dict)
                feed_dict[lr_scheduler.learning_rate.name] = lr_rate
                if not init:  # init the position encoding table
                    for pos_enc_param_name in pos_enc_param_names:
                        feed_dict[pos_enc_param_name] = position_encoding_init(
                            ModelHyperParams.max_length + 1,
                            ModelHyperParams.d_model)
                feed_list.append(feed_dict)
            # With the customized gradient scale strategy the loss gradient is
            # fed explicitly: `1 / token_number` yields the token average.
            for feed_dict in feed_list:
                feed_dict[sum_cost.name + "@GRAD"] = 1. / total_num_token
            outs = train_exe.run(fetch_list=[sum_cost.name, token_num.name],
                                 feed=feed_list)
            sum_cost_val = np.array(outs[0])
            token_num_val = np.array(outs[1])
            # sum the cost from multi-devices
            total_sum_cost = sum_cost_val.sum()
            total_token_num = token_num_val.sum()
            total_avg_cost = total_sum_cost / total_token_num
            print("epoch: %d, batch: %d, sum loss: %f, avg loss: %f, ppl: %f" %
                  (pass_id, batch_id, total_sum_cost, total_avg_cost,
                   np.exp(min(total_avg_cost, 100))))
            init = True
        # Validate and save the model for inference.
        print("epoch: %d, " % pass_id + (
            "val avg loss: %f, val ppl: %f, " % test()
            if args.val_file_pattern is not None else "") + "consumed %fs" % (
                time.time() - pass_start_time))
        fluid.io.save_persistables(
            exe,
            os.path.join(TrainTaskConfig.ckpt_dir,
                         "pass_" + str(pass_id) + ".checkpoint"))
        fluid.io.save_inference_model(
            os.path.join(TrainTaskConfig.model_dir,
                         "pass_" + str(pass_id) + ".infer.model"),
            data_input_names[:-2] + util_input_names, [predict], exe)


if __name__ == "__main__":
    # Parse flags (merging config overrides), then run training.
    args = parse_args()
    train(args)