import argparse
import ast
import copy
import logging
import multiprocessing
import os
import six
import sys
# NOTE: this path tweak must stay ahead of the project-local imports below;
# presumably `reader`, `config`, `desc` and `model` are resolved through it
# (confirm against the repository layout).
sys.path.append("../../models/neural_machine_translation/transformer/")
import time

import numpy as np
import paddle.fluid as fluid

import reader
from config import *
from desc import *
from model import transformer, position_encoding_init


def parse_args():
    """
    Parse the command line and merge vocabulary-derived settings into the
    global configuration.

    Besides the declared options, every remaining positional token is
    collected into ``args.opts`` and forwarded, together with the vocabulary
    sizes and the indices of the special tokens (looked up in the source
    dictionary), to ``merge_cfg_from_list`` for ``TrainTaskConfig`` and
    ``ModelHyperParams``.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser("Training for Transformer.")
    parser.add_argument(
        "--src_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of source language.")
    parser.add_argument(
        "--trg_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of target language.")
    parser.add_argument(
        "--train_file_pattern",
        type=str,
        required=True,
        help="The pattern to match training data files.")
    parser.add_argument(
        "--val_file_pattern",
        type=str,
        help="The pattern to match validation data files.")
    parser.add_argument(
        "--use_token_batch",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to "
        "produce batch data according to token number.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=4096,
        help="The number of sequences contained in a mini-batch, or the maximum "
        "number of tokens (include paddings) contained in a mini-batch. Note "
        "that this represents the number on single device and the actual batch "
        "size for multi-devices will multiply the device number.")
    parser.add_argument(
        "--pool_size",
        type=int,
        default=200000,
        help="The buffer size to pool data.")
    parser.add_argument(
        "--sort_type",
        default="pool",
        choices=("global", "pool", "none"),
        help="The grain to sort by length: global for all instances; pool for "
        "instances in pool; none for no sort.")
    parser.add_argument(
        "--shuffle",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to shuffle instances in each pass.")
    parser.add_argument(
        "--shuffle_batch",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to shuffle the data batches.")
    parser.add_argument(
        "--special_token",
        type=str,
        default=["<s>", "<e>", "<unk>"],
        nargs=3,
        help="The <bos>, <eos> and <unk> tokens in the dictionary.")
    parser.add_argument(
        "--token_delimiter",
        # Lets escape sequences such as "\t" or "\x01" be given on the
        # command line and turned into the actual character.
        type=lambda x: str(x.encode().decode("unicode-escape")),
        default=" ",
        help="The delimiter used to split tokens in source or target sentences. "
        "For EN-DE BPE data we provided, use spaces as token delimiter. ")
    parser.add_argument(
        'opts',
        help='See config.py for all options',
        default=None,
        nargs=argparse.REMAINDER)
    parser.add_argument(
        '--local',
        type=ast.literal_eval,
        default=True,
        help='Whether to run as local mode.')
    parser.add_argument(
        '--device',
        type=str,
        default='GPU',
        choices=['CPU', 'GPU'],
        help="The device type.")
    parser.add_argument(
        '--update_method',
        choices=("pserver", "nccl2"),
        default="pserver",
        help='Update method.')
    parser.add_argument(
        '--sync', type=ast.literal_eval, default=True, help="sync mode.")
    parser.add_argument(
        "--enable_ce",
        type=ast.literal_eval,
        default=False,
        help="The flag indicating whether to run the task "
        "for continuous evaluation.")
    parser.add_argument(
        "--use_mem_opt",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to use memory optimization.")
    parser.add_argument(
        "--use_py_reader",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to use py_reader.")
    parser.add_argument(
        "--fetch_steps",
        type=int,
        default=100,
        help="The frequency to fetch and print output.")

    args = parser.parse_args()
    # Append args related to dict
    src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
    trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
    # Note: all three special-token indices are taken from the SOURCE
    # dictionary, as in the original code.
    dict_args = [
        "src_vocab_size", str(len(src_dict)), "trg_vocab_size",
        str(len(trg_dict)), "bos_idx", str(src_dict[args.special_token[0]]),
        "eos_idx", str(src_dict[args.special_token[1]]), "unk_idx",
        str(src_dict[args.special_token[2]])
    ]
    merge_cfg_from_list(args.opts + dict_args,
                        [TrainTaskConfig, ModelHyperParams])
    return args


def append_nccl2_prepare(startup_prog, trainer_id, worker_endpoints,
                         current_endpoint):
    """
    Append a ``gen_nccl_id`` op to the startup program.

    Args:
        startup_prog: program whose global block receives the new variable
            and op.
        trainer_id: non-negative index of this trainer.
        worker_endpoints: list of all trainer endpoints; must contain more
            than one entry, including ``current_endpoint``.
        current_endpoint: endpoint of this trainer.

    Returns:
        The persistable RAW variable named "NCCLID" created in the global
        block.
    """
    assert (trainer_id >= 0 and len(worker_endpoints) > 1 and
            current_endpoint in worker_endpoints)
    # Work on a copy so the caller's endpoint list is left untouched; the op
    # receives every endpoint except our own.
    eps = copy.deepcopy(worker_endpoints)
    eps.remove(current_endpoint)
    global_block = startup_prog.global_block()
    nccl_id_var = global_block.create_var(
        name="NCCLID", persistable=True, type=fluid.core.VarDesc.VarType.RAW)
    global_block.append_op(
        type="gen_nccl_id",
        inputs={},
        outputs={"NCCLID": nccl_id_var},
        attrs={
            "endpoint": current_endpoint,
            "endpoint_list": eps,
            "trainer_id": trainer_id
        })
    return nccl_id_var


def pad_batch_data(insts,
                   pad_idx,
                   n_head,
                   is_target=False,
                   is_label=False,
                   return_attn_bias=True,
                   return_max_len=True,
                   return_num_token=False):
    """
    Pad the instances to the max sequence length in batch, and generate the
    corresponding position data and attention bias.

    Args:
        insts: non-empty iterable of token-id sequences.
        pad_idx: id used to fill the padded positions.
        n_head: number of attention heads; the bias is tiled along this axis.
        is_target: if True the bias masks future positions (strictly upper
            triangle); otherwise it masks padded positions.
        is_label: if True the second output is a float32 weight (1 for real
            tokens, 0 for paddings) instead of int64 position ids.
        return_attn_bias: whether to append the float32 attention bias of
            shape [batch, n_head, max_len, max_len].
        return_max_len: whether to append the max sequence length.
        return_num_token: whether to append the number of non-padding tokens.

    Returns:
        list: padded ids as int64 of shape [batch * max_len, 1], then the
        position ids or label weights in the same layout, followed by the
        optional outputs in the order described above.

    Raises:
        ValueError: if ``insts`` is empty.
    """
    # Materialize once so tuples and one-shot iterables are accepted and the
    # lengths are computed a single time.
    insts = [list(inst) for inst in insts]
    if not insts:
        raise ValueError("pad_batch_data() requires at least one instance")
    lengths = [len(inst) for inst in insts]
    max_len = max(lengths)

    return_list = []
    # Any token included in dict can be used to pad, since the paddings' loss
    # will be masked out by weights and make no effect on parameter gradients.
    inst_data = np.array(
        [inst + [pad_idx] * (max_len - n) for inst, n in zip(insts, lengths)])
    return_list.append(inst_data.astype("int64").reshape([-1, 1]))
    if is_label:  # label weight
        inst_weight = np.array(
            [[1.] * n + [0.] * (max_len - n) for n in lengths])
        return_list.append(inst_weight.astype("float32").reshape([-1, 1]))
    else:  # position data
        inst_pos = np.array(
            [list(range(n)) + [0] * (max_len - n) for n in lengths])
        return_list.append(inst_pos.astype("int64").reshape([-1, 1]))
    if return_attn_bias:
        if is_target:
            # Only the future-position (strictly upper-triangular) mask is
            # built here.
            future_mask = np.triu(np.ones((len(insts), max_len, max_len)), 1)
            slf_attn_bias_data = np.tile(
                future_mask.reshape([-1, 1, max_len, max_len]),
                [1, n_head, 1, 1]) * [-1e9]
        else:
            # This is used to avoid attention on paddings.
            pad_mask = np.array(
                [[0] * n + [-1e9] * (max_len - n) for n in lengths])
            slf_attn_bias_data = np.tile(
                pad_mask.reshape([-1, 1, 1, max_len]),
                [1, n_head, max_len, 1])
        return_list.append(slf_attn_bias_data.astype("float32"))
    if return_max_len:
        return_list.append(max_len)
    if return_num_token:
        return_list.append(sum(lengths))
    return return_list if len(return_list) > 1 else return_list[0]


def prepare_batch_input(insts, data_input_names, src_pad_idx, trg_pad_idx,
                        n_head, d_model):
    """
    Put all padded data needed by training into a dict.

    Args:
        insts: batch of instances; element 0, 1 and 2 of each instance are
            padded as source, target and label sequences respectively.
        data_input_names: names zipped, in order, with the nine arrays
            produced here.
        src_pad_idx: padding id for the source sequences.
        trg_pad_idx: padding id for the target and label sequences.
        n_head: number of attention heads.
        d_model: not used by this function; kept for interface compatibility.

    Returns:
        tuple: ``(data_input_dict, num_token)`` where ``num_token`` is a
        float32 array of shape [1] holding the label token count.
    """
    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
    src_word = src_word.reshape(-1, src_max_len, 1)
    src_pos = src_pos.reshape(-1, src_max_len, 1)

    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
        [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
    trg_word = trg_word.reshape(-1, trg_max_len, 1)
    trg_pos = trg_pos.reshape(-1, trg_max_len, 1)

    # A step of `src_max_len` over an axis of that length keeps only row 0 of
    # the source bias; that row is then repeated for every target position.
    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                [1, 1, trg_max_len, 1]).astype("float32")

    lbl_word, lbl_weight, num_token = pad_batch_data(
        [inst[2] for inst in insts],
        trg_pad_idx,
        n_head,
        is_target=False,
        is_label=True,
        return_attn_bias=False,
        return_max_len=False,
        return_num_token=True)

    data_input_dict = dict(
        zip(data_input_names, [
            src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
            trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
        ]))

    return data_input_dict, np.asarray([num_token], dtype="float32")


def prepare_data_generator(args,
                           is_test,
                           count,
                           pyreader,
                           py_reader_provider_wrapper,
                           place=None):
    """
    Data generator wrapper for DataReader. If use py_reader, set the data
    provider for py_reader

    Args:
        args: parsed command-line arguments.
        is_test: read ``args.val_file_pattern`` when True, otherwise
            ``args.train_file_pattern``.
        count: number of devices the data is prepared for.
        pyreader: py_reader object decorated when ``args.use_py_reader``.
        py_reader_provider_wrapper: callable turning ``(data_reader, place)``
            into a tensor provider.
        place: forwarded unchanged to ``py_reader_provider_wrapper``.

    Returns:
        None when py_reader is used; otherwise a callable returning a
        generator that yields lists of ``count`` batches.
    """
    data_reader = reader.DataReader(
        fpattern=args.val_file_pattern if is_test else args.train_file_pattern,
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        token_delimiter=args.token_delimiter,
        use_token_batch=args.use_token_batch,
        # With sequence batching one reader batch is later split across the
        # devices, so it has to hold `count` times the per-device size.
        batch_size=args.batch_size * (1 if args.use_token_batch else count),
        pool_size=args.pool_size,
        sort_type=args.sort_type,
        shuffle=args.shuffle,
        shuffle_batch=args.shuffle_batch,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2],
        # count start and end tokens out
        max_length=ModelHyperParams.max_length - 2,
        clip_last_batch=False).batch_generator

    def stack(data_reader, count, clip_last=True):
        # Group consecutive reader items into lists of `count` items. An
        # incomplete trailing group is dropped unless `clip_last` is False,
        # in which case its instances are re-divided evenly into `count`
        # parts (only when there are more than `count` instances).
        def __impl__():
            res = []
            for item in data_reader():
                res.append(item)
                if len(res) == count:
                    yield res
                    res = []
            if len(res) == count:
                yield res
            elif not clip_last:
                data = []
                for item in res:
                    data += item
                if len(data) > count:
                    inst_num_per_part = len(data) // count
                    yield [
                        data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
                        for i in range(count)
                    ]

        return __impl__

    def split(data_reader, count):
        # Cut every reader item into `count` equally sized parts; instances
        # beyond `count * (len(item) // count)` are discarded.
        def __impl__():
            for item in data_reader():
                inst_num_per_part = len(item) // count
                for i in range(count):
                    yield item[inst_num_per_part * i:inst_num_per_part *
                               (i + 1)]

        return __impl__

    if not args.use_token_batch:
        # to make data on each device have similar token number
        data_reader = split(data_reader, count)
    if args.use_py_reader:
        pyreader.decorate_tensor_provider(
            py_reader_provider_wrapper(data_reader, place))
        data_reader = None
    else:  # Data generator for multi-devices
        data_reader = stack(data_reader, count)
    return data_reader


def prepare_feed_dict_list(data_generator, init_flag, count):
    """
    Build one feed dict per device.

    When ``data_generator`` is given (py_reader not in use), the next item is
    pulled from it and each of its per-device batches is padded into a feed
    dict. When ``init_flag`` is set, freshly initialized position encoding
    tables are added to every device's dict; batch entries win on a name
    clash.

    Returns the list of feed dicts if it holds exactly ``count`` entries,
    otherwise None.
    """
    feeds = []
    if data_generator is not None:  # use_py_reader == False
        field_names = (encoder_data_input_fields +
                       decoder_data_input_fields[:-1] +
                       label_data_input_fields)
        for device_batch in next(data_generator):
            batch_inputs, _ = prepare_batch_input(
                device_batch, field_names, ModelHyperParams.eos_idx,
                ModelHyperParams.eos_idx, ModelHyperParams.n_head,
                ModelHyperParams.d_model)
            feeds.append(batch_inputs)

    if init_flag:
        for dev_idx in range(count):
            tables = {
                name: position_encoding_init(ModelHyperParams.max_length + 1,
                                             ModelHyperParams.d_model)
                for name in pos_enc_param_names
            }
            if dev_idx >= len(feeds):
                feeds.append(tables)
            else:
                # Tables first, then the batch data on top of them.
                merged = dict(tables)
                merged.update(feeds[dev_idx])
                feeds[dev_idx] = merged

    if len(feeds) != count:
        return None
    return feeds


def py_reader_provider_wrapper(data_reader, place):
    """
    Data provider needed by fluid.layers.py_reader.

    Args:
        data_reader: callable returning an iterable of batches.
        place: not used by this function; kept for interface compatibility.

    Returns:
        A generator function that yields, for every batch, the padded input
        arrays as a list ordered like the data input field names.
    """

    def py_reader_provider():
        data_input_names = encoder_data_input_fields + \
                    decoder_data_input_fields[:-1] + label_data_input_fields
        for data in data_reader():
            data_input_dict, _ = prepare_batch_input(
                data, data_input_names, ModelHyperParams.eos_idx,
                ModelHyperParams.eos_idx, ModelHyperParams.n_head,
                ModelHyperParams.d_model)
            yield [data_input_dict[item] for item in data_input_names]

    return py_reader_provider


def test_context(exe, train_exe, dev_count):
    """
    Build the validation program and return a function that runs one
    validation pass.

    Reads the module-level ``args`` (not passed as a parameter).

    Args:
        exe: executor used to run the startup program and load the checkpoint.
        train_exe: training ParallelExecutor whose variables are shared with
            the test executor.
        dev_count: number of devices.

    Returns:
        A callable ``test()`` returning ``(test_avg_cost, test_ppl)``.
    """
    # Context to do validation.
    test_prog = fluid.Program()
    startup_prog = fluid.Program()
    if args.enable_ce:
        test_prog.random_seed = 1000
        startup_prog.random_seed = 1000
    with fluid.program_guard(test_prog, startup_prog):
        with fluid.unique_name.guard():
            sum_cost, avg_cost, predict, token_num, pyreader = transformer(
                ModelHyperParams.src_vocab_size,
                ModelHyperParams.trg_vocab_size,
                ModelHyperParams.max_length + 1,
                ModelHyperParams.n_layer,
                ModelHyperParams.n_head,
                ModelHyperParams.d_key,
                ModelHyperParams.d_value,
                ModelHyperParams.d_model,
                ModelHyperParams.d_inner_hid,
                ModelHyperParams.prepostprocess_dropout,
                ModelHyperParams.attention_dropout,
                ModelHyperParams.relu_dropout,
                ModelHyperParams.preprocess_cmd,
                ModelHyperParams.postprocess_cmd,
                ModelHyperParams.weight_sharing,
                TrainTaskConfig.label_smooth_eps,
                # Passed exactly as in train() so both programs are built
                # with the same begin-of-sentence index.
                ModelHyperParams.bos_idx,
                use_py_reader=args.use_py_reader,
                is_test=True)
    test_prog = test_prog.clone(for_test=True)
    test_data = prepare_data_generator(
        args,
        is_test=True,
        count=dev_count,
        pyreader=pyreader,
        py_reader_provider_wrapper=py_reader_provider_wrapper)

    exe.run(startup_prog)  # to init pyreader for testing
    if TrainTaskConfig.ckpt_path:
        fluid.io.load_persistables(
            exe, TrainTaskConfig.ckpt_path, main_program=test_prog)

    build_strategy = fluid.BuildStrategy()
    test_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        main_program=test_prog,
        build_strategy=build_strategy,
        share_vars_from=train_exe)

    def test(exe=test_exe, pyreader=pyreader):
        # Runs the whole validation set once and accumulates the summed cost
        # and token count over all batches and devices.
        test_total_cost = 0
        test_total_token = 0

        if args.use_py_reader:
            pyreader.start()
            data_generator = None
        else:
            data_generator = test_data()
        while True:
            try:
                feed_dict_list = prepare_feed_dict_list(data_generator, False,
                                                        dev_count)
                outs = exe.run(fetch_list=[sum_cost.name, token_num.name],
                               feed=feed_dict_list)
            except (StopIteration, fluid.core.EOFException):
                # The current pass is over.
                if args.use_py_reader:
                    pyreader.reset()
                break
            sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
            test_total_cost += sum_cost_val.sum()
            test_total_token += token_num_val.sum()
        test_avg_cost = test_total_cost / test_total_token
        # Cap the exponent so the perplexity does not overflow.
        test_ppl = np.exp([min(test_avg_cost, 100)])
        return test_avg_cost, test_ppl

    return test


def train_loop(exe,
               train_prog,
               startup_prog,
               dev_count,
               sum_cost,
               avg_cost,
               token_num,
               predict,
               pyreader,
               nccl2_num_trainers=1,
               nccl2_trainer_id=0):
    """
    Run the training passes, with periodic logging, checkpointing and
    optional per-pass validation.

    Reads the module-level ``args`` (not passed as a parameter).

    Args:
        exe: executor used for the startup program and for saving/loading.
        train_prog: the training program.
        startup_prog: the startup program.
        dev_count: number of devices.
        sum_cost, avg_cost, token_num: variables of the training program;
            ``avg_cost`` is the loss, the other two are fetched for logging.
        predict: not used by this function; kept for interface compatibility.
        pyreader: py_reader object, started/reset per pass when
            ``args.use_py_reader`` is set.
        nccl2_num_trainers: forwarded to ParallelExecutor as ``num_trainers``.
        nccl2_trainer_id: forwarded to ParallelExecutor as ``trainer_id``.
    """
    # Initialize the parameters.
    if TrainTaskConfig.ckpt_path:
        exe.run(startup_prog)  # to init pyreader for training
        logging.info("load checkpoint from {}".format(
            TrainTaskConfig.ckpt_path))
        fluid.io.load_persistables(
            exe, TrainTaskConfig.ckpt_path, main_program=train_prog)
    else:
        logging.info("init fluid.framework.default_startup_program")
        exe.run(startup_prog)

    logging.info("begin reader")
    train_data = prepare_data_generator(
        args,
        is_test=False,
        count=dev_count,
        pyreader=pyreader,
        py_reader_provider_wrapper=py_reader_provider_wrapper)

    # For faster executor
    exec_strategy = fluid.ExecutionStrategy()
    exec_strategy.num_iteration_per_drop_scope = int(args.fetch_steps)
    build_strategy = fluid.BuildStrategy()
    build_strategy.memory_optimize = False
    build_strategy.enable_inplace = True

    sum_cost.persistable = True
    token_num.persistable = True
    # Since the token number differs among devices, customize gradient scale to
    # use token average cost among multi-devices. and the gradient scale is
    # `1 / token_number` for average cost.
    # (The customized strategy itself is currently disabled:)
    # build_strategy.gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized
    build_strategy.fuse_all_optimizer_ops = True

    logging.info("begin executor")
    train_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        loss_name=avg_cost.name,
        main_program=train_prog,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy,
        num_trainers=nccl2_num_trainers,
        trainer_id=nccl2_trainer_id)

    if args.val_file_pattern is not None:
        test = test_context(exe, train_exe, dev_count)

    # the best cross-entropy value with label smoothing
    eps = TrainTaskConfig.label_smooth_eps
    loss_normalizer = -((1. - eps) * np.log((1. - eps)) + eps * np.log(
        eps / (ModelHyperParams.trg_vocab_size - 1) + 1e-20))

    step_idx = 0
    init_flag = True

    # Defined up front so the continuous-evaluation report at the end cannot
    # hit an unbound name when no pass (or no fetch) has run.
    total_avg_cost = float("nan")
    val_avg_cost = float("nan")
    time_consumed = float("nan")

    logging.info("begin train")
    for pass_id in six.moves.xrange(TrainTaskConfig.pass_num):
        pass_start_time = time.time()

        if args.use_py_reader:
            pyreader.start()
            data_generator = None
        else:
            data_generator = train_data()

        batch_id = 0
        while True:
            try:
                feed_dict_list = prepare_feed_dict_list(data_generator,
                                                        init_flag, dev_count)
                # Fetching forces a sync, so it is only done every
                # `fetch_steps` steps.
                should_fetch = step_idx % args.fetch_steps == 0
                outs = train_exe.run(
                    fetch_list=[sum_cost.name, token_num.name]
                    if should_fetch else [],
                    feed=feed_dict_list)

                if should_fetch:
                    sum_cost_val, token_num_val = np.array(outs[0]), np.array(
                        outs[1])
                    # sum the cost from multi-devices
                    total_sum_cost = sum_cost_val.sum()
                    total_token_num = token_num_val.sum()
                    total_avg_cost = total_sum_cost / total_token_num

                    if step_idx == 0:
                        logging.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer,
                             np.exp([min(total_avg_cost, 100)])))
                        avg_batch_time = time.time()
                    else:
                        logging.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f, speed: %.2f step/s" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer,
                             np.exp([min(total_avg_cost, 100)]),
                             args.fetch_steps / (time.time() - avg_batch_time)))
                        avg_batch_time = time.time()

                if step_idx % TrainTaskConfig.save_freq == 0 and step_idx > 0:
                    fluid.io.save_persistables(
                        exe,
                        os.path.join(TrainTaskConfig.ckpt_dir,
                                     "latest.checkpoint"), train_prog)
                    fluid.io.save_params(
                        exe,
                        os.path.join(TrainTaskConfig.model_dir,
                                     "iter_" + str(step_idx) + ".infer.model"),
                        train_prog)

                # The position encoding tables only need to be fed once.
                init_flag = False
                batch_id += 1
                step_idx += 1
            except (StopIteration, fluid.core.EOFException):
                # The current pass is over.
                if args.use_py_reader:
                    pyreader.reset()
                break

        time_consumed = time.time() - pass_start_time
        # Validate and save the persistable.
        if args.val_file_pattern is not None:
            val_avg_cost, val_ppl = test()
            logging.info(
                "epoch: %d, val avg loss: %f, val normalized loss: %f, val ppl: %f,"
                " consumed %fs" % (pass_id, val_avg_cost,
                                   val_avg_cost - loss_normalizer, val_ppl,
                                   time_consumed))
        else:
            logging.info("epoch: %d, consumed %fs" % (pass_id, time_consumed))
        if not args.enable_ce:
            fluid.io.save_persistables(
                exe,
                os.path.join(TrainTaskConfig.ckpt_dir,
                             "pass_" + str(pass_id) + ".checkpoint"),
                train_prog)

    if args.enable_ce:  # For CE
        print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost))
        if args.val_file_pattern is not None:
            print("kpis\ttest_cost_card%d\t%f" % (dev_count, val_avg_cost))
        print("kpis\ttrain_duration_card%d\t%f" % (dev_count, time_consumed))


def _require_env(name):
    """Return environment variable `name`; log and exit if unset or empty."""
    value = os.getenv(name)
    if not value:
        logging.critical("need env {}".format(name))
        sys.exit(1)
    return value


def train(args):
    """
    Build the Transformer training program and run it locally, with NCCL2
    multi-trainer mode, or in parameter-server mode.

    The mode is selected by ``args.local`` (overridden to False when the
    environment variable PADDLE_IS_LOCAL is "0"), ``args.update_method`` and
    the environment variable TRAINING_ROLE.

    Args:
        args: parsed command-line arguments (see ``parse_args``).
    """
    # priority: ENV > args > config
    is_local = os.getenv("PADDLE_IS_LOCAL", "1")
    if is_local == '0':
        args.local = False
    logging.info(args)

    if args.device == 'CPU':
        TrainTaskConfig.use_gpu = False

    training_role = os.getenv("TRAINING_ROLE", "TRAINER")

    if training_role == "PSERVER" or (not TrainTaskConfig.use_gpu):
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
    else:
        gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
        place = fluid.CUDAPlace(gpu_id)
        dev_count = fluid.core.get_cuda_device_count()

    exe = fluid.Executor(place)

    train_prog = fluid.Program()
    startup_prog = fluid.Program()

    if args.enable_ce:
        train_prog.random_seed = 1000
        startup_prog.random_seed = 1000

    with fluid.program_guard(train_prog, startup_prog):
        with fluid.unique_name.guard():
            sum_cost, avg_cost, predict, token_num, pyreader = transformer(
                ModelHyperParams.src_vocab_size,
                ModelHyperParams.trg_vocab_size,
                ModelHyperParams.max_length + 1,
                ModelHyperParams.n_layer,
                ModelHyperParams.n_head,
                ModelHyperParams.d_key,
                ModelHyperParams.d_value,
                ModelHyperParams.d_model,
                ModelHyperParams.d_inner_hid,
                ModelHyperParams.prepostprocess_dropout,
                ModelHyperParams.attention_dropout,
                ModelHyperParams.relu_dropout,
                ModelHyperParams.preprocess_cmd,
                ModelHyperParams.postprocess_cmd,
                ModelHyperParams.weight_sharing,
                TrainTaskConfig.label_smooth_eps,
                ModelHyperParams.bos_idx,
                use_py_reader=args.use_py_reader,
                is_test=False)

            optimizer = None
            if args.sync:
                lr_decay = fluid.layers.learning_rate_scheduler.noam_decay(
                    ModelHyperParams.d_model, TrainTaskConfig.warmup_steps)
                logging.info("before adam")

                with fluid.default_main_program()._lr_schedule_guard():
                    learning_rate = lr_decay * TrainTaskConfig.learning_rate

                optimizer = fluid.optimizer.Adam(
                    learning_rate=learning_rate,
                    beta1=TrainTaskConfig.beta1,
                    beta2=TrainTaskConfig.beta2,
                    epsilon=TrainTaskConfig.eps)
            else:
                optimizer = fluid.optimizer.SGD(0.003)
            optimizer.minimize(avg_cost)

    if args.use_mem_opt:
        fluid.memory_optimize(train_prog)

    if args.local:
        logging.info("local start_up:")
        train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
                   token_num, predict, pyreader)
        return

    if args.update_method == "nccl2":
        trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        port = _require_env("PADDLE_PORT")
        worker_ips = _require_env("PADDLE_TRAINERS")
        worker_endpoints = [
            ':'.join([ip, port]) for ip in worker_ips.split(",")
        ]
        trainers_num = len(worker_endpoints)
        current_endpoint = _require_env("POD_IP") + ":" + port
        if trainer_id == 0:
            logging.info("train_id == 0, sleep 60s")
            time.sleep(60)
        logging.info("trainers_num:{}".format(trainers_num))
        logging.info("worker_endpoints:{}".format(worker_endpoints))
        logging.info("current_endpoint:{}".format(current_endpoint))
        append_nccl2_prepare(startup_prog, trainer_id, worker_endpoints,
                             current_endpoint)
        train_loop(exe, train_prog, startup_prog, dev_count, sum_cost,
                   avg_cost, token_num, predict, pyreader, trainers_num,
                   trainer_id)
        return

    # Parameter-server mode.
    port = os.getenv("PADDLE_PORT", "6174")
    pserver_ips = _require_env("PADDLE_PSERVERS")  # ip,ip...
    eplist = [':'.join([ip, port]) for ip in pserver_ips.split(",")]
    pserver_endpoints = ",".join(eplist)  # ip:port,ip:port...
    trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
    current_endpoint = _require_env("POD_IP") + ":" + port
    trainer_id = int(_require_env("PADDLE_TRAINER_ID"))

    logging.info("pserver_endpoints:{}".format(pserver_endpoints))
    logging.info("current_endpoint:{}".format(current_endpoint))
    logging.info("trainer_id:{}".format(trainer_id))
    logging.info("pserver_ips:{}".format(pserver_ips))
    logging.info("port:{}".format(port))

    t = fluid.DistributeTranspiler()
    t.transpile(
        trainer_id,
        pservers=pserver_endpoints,
        trainers=trainers,
        program=train_prog,
        startup_program=startup_prog)

    if training_role == "PSERVER":
        logging.info("distributed: pserver started")
        # `current_endpoint` above already combines POD_IP with the port
        # (including its default), so it is reused here.
        pserver_prog = t.get_pserver_program(current_endpoint)
        pserver_startup = t.get_startup_program(current_endpoint,
                                                pserver_prog)

        exe.run(pserver_startup)
        exe.run(pserver_prog)
    elif training_role == "TRAINER":
        logging.info("distributed: trainer started")
        # NOTE(review): the returned program is not what gets passed to
        # train_loop (train_prog is); the call is kept because it may adjust
        # the transpiled program in place — confirm.
        t.get_trainer_program()

        train_loop(exe, train_prog, startup_prog, dev_count, sum_cost,
                   avg_cost, token_num, predict, pyreader)
    else:
        logging.critical(
            "environment var TRAINING_ROLE should be TRAINER or PSERVER")
        sys.exit(1)


if __name__ == "__main__":
    LOG_FORMAT = "[%(asctime)s %(levelname)s %(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(
        stream=sys.stdout, level=logging.DEBUG, format=LOG_FORMAT)
    logging.getLogger().setLevel(logging.INFO)

    # `args` must stay a module-level name: test_context() and train_loop()
    # read it as a global rather than receiving it as a parameter.
    args = parse_args()
    train(args)