train.py 28.4 KB
Newer Older
1 2
import argparse
import ast
G
guosheng 已提交
3 4
import copy
import logging
5
import multiprocessing
Y
Yu Yang 已提交
6
import os
G
guosheng 已提交
7
import six
G
guosheng 已提交
8
import sys
Y
Yu Yang 已提交
9
import time
Y
ying 已提交
10

Y
Yu Yang 已提交
11
import numpy as np
L
Luo Tao 已提交
12
import paddle.fluid as fluid
G
guosheng 已提交
13
from paddle.fluid.transpiler.details import program_to_code
Y
ying 已提交
14

Y
Yu Yang 已提交
15 16
import reader
from config import *
17
from model import transformer, position_encoding_init
18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49


def parse_args():
    """Parse command-line flags and merge them into the global configs.

    Trailing positional arguments (``opts``, see config.py) are merged into
    ``TrainTaskConfig`` and ``ModelHyperParams`` via ``merge_cfg_from_list``,
    followed by entries derived from the vocabulary files: the source and
    target vocabulary sizes, and the indices of the three ``--special_token``
    values, all of which are looked up in the *source* dictionary.

    Returns:
        argparse.Namespace with the parsed flags.
    """
    parser = argparse.ArgumentParser("Training for Transformer.")
    add = parser.add_argument

    # ---- data files and vocabularies ----
    add("--src_vocab_fpath", type=str, required=True,
        help="The path of vocabulary file of source language.")
    add("--trg_vocab_fpath", type=str, required=True,
        help="The path of vocabulary file of target language.")
    add("--train_file_pattern", type=str, required=True,
        help="The pattern to match training data files.")
    add("--val_file_pattern", type=str,
        help="The pattern to match validation data files.")

    # ---- batching, pooling and shuffling ----
    add("--use_token_batch", type=ast.literal_eval, default=True,
        help="The flag indicating whether to "
        "produce batch data according to token number.")
    add("--batch_size", type=int, default=4096,
        help="The number of sequences contained in a mini-batch, or the maximum "
        "number of tokens (include paddings) contained in a mini-batch. Note "
        "that this represents the number on single device and the actual batch "
        "size for multi-devices will multiply the device number.")
    add("--pool_size", type=int, default=200000,
        help="The buffer size to pool data.")
    add("--sort_type", default="pool", choices=("global", "pool", "none"),
        help="The grain to sort by length: global for all instances; pool for "
        "instances in pool; none for no sort.")
    add("--shuffle", type=ast.literal_eval, default=True,
        help="The flag indicating whether to shuffle instances in each pass.")
    add("--shuffle_batch", type=ast.literal_eval, default=True,
        help="The flag indicating whether to shuffle the data batches.")

    # ---- tokenization ----
    add("--special_token", type=str, default=["<s>", "<e>", "<unk>"], nargs=3,
        help="The <bos>, <eos> and <unk> tokens in the dictionary.")
    # Escape sequences typed on the command line (e.g. "\t") are decoded.
    add("--token_delimiter",
        type=lambda x: str(x.encode().decode("unicode-escape")),
        default=" ",
        help="The delimiter used to split tokens in source or target sentences. "
        "For EN-DE BPE data we provided, use spaces as token delimiter. ")

    # ---- free-form config overrides (presumably KEY VALUE pairs, like the
    # derived entries appended below) ----
    add('opts', help='See config.py for all options', default=None,
        nargs=argparse.REMAINDER)

    # ---- execution mode ----
    add('--local', type=ast.literal_eval, default=True,
        help='Whether to run as local mode.')
    add('--device', type=str, default='GPU', choices=['CPU', 'GPU'],
        help="The device type.")
    add('--update_method', choices=("pserver", "nccl2"), default="pserver",
        help='Update method.')
    add('--sync', type=ast.literal_eval, default=True, help="sync mode.")
    add("--enable_ce", type=ast.literal_eval, default=False,
        help="The flag indicating whether to run the task "
        "for continuous evaluation.")
    add("--use_mem_opt", type=ast.literal_eval, default=True,
        help="The flag indicating whether to use memory optimization.")
    add("--use_py_reader", type=ast.literal_eval, default=True,
        help="The flag indicating whether to use py_reader.")
    add("--fetch_steps", type=int, default=100,
        help="The frequency to fetch and print output.")

    args = parser.parse_args()

    # Settings that depend on the vocabularies are appended after the user's
    # own overrides.
    src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
    trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
    bos_token, eos_token, unk_token = args.special_token
    derived = [
        ("src_vocab_size", len(src_dict)),
        ("trg_vocab_size", len(trg_dict)),
        ("bos_idx", src_dict[bos_token]),
        ("eos_idx", src_dict[eos_token]),
        ("unk_idx", src_dict[unk_token]),
    ]
    dict_args = [text for key, value in derived for text in (key, str(value))]
    merge_cfg_from_list(args.opts + dict_args,
                        [TrainTaskConfig, ModelHyperParams])
    return args
146 147


G
fix  
gongweibao 已提交
148
def append_nccl2_prepare(trainer_id, worker_endpoints, current_endpoint,
                         startup_prog=None):
    """Append the op that generates the NCCL id for multi-node NCCL2 training.

    Creates a persistable RAW variable named "NCCLID" in ``startup_prog`` (the
    current default startup program when omitted) and appends a
    ``gen_nccl_id`` op writing to it, configured with this trainer's endpoint,
    the endpoints of all *other* trainers, and ``trainer_id``.

    Args:
        trainer_id: id of this trainer; must be >= 0.
        worker_endpoints: "ip:port" strings of all trainers; must contain
            more than one entry, including ``current_endpoint``.
        current_endpoint: "ip:port" of this trainer.
        startup_prog: optional program to modify instead of the default
            startup program.

    Returns:
        The created "NCCLID" variable.

    Raises:
        ValueError: if the arguments violate the constraints above.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not (trainer_id >= 0 and len(worker_endpoints) > 1 and
            current_endpoint in worker_endpoints):
        raise ValueError(
            "invalid nccl2 setup: trainer_id=%s, worker_endpoints=%s, "
            "current_endpoint=%s" % (trainer_id, worker_endpoints,
                                     current_endpoint))
    if startup_prog is None:
        startup_prog = fluid.default_startup_program()
    block = startup_prog.global_block()

    # Only the other trainers' endpoints are passed to the op.
    eps = list(worker_endpoints)
    eps.remove(current_endpoint)

    nccl_id_var = block.create_var(
        name="NCCLID", persistable=True, type=fluid.core.VarDesc.VarType.RAW)
    block.append_op(
        type="gen_nccl_id",
        inputs={},
        outputs={"NCCLID": nccl_id_var},
        attrs={
            "endpoint": current_endpoint,
            "endpoint_list": eps,
            "trainer_id": trainer_id
        })
    return nccl_id_var
165

166

167 168 169 170
def pad_batch_data(insts,
                   pad_idx,
                   n_head,
                   is_target=False,
                   is_label=False,
                   return_attn_bias=True,
                   return_max_len=True,
                   return_num_token=False):
    """
    Pad the instances to the max sequence length in batch, and generate the
    corresponding position data (or label weights) and attention bias.

    Args:
        insts: non-empty list of token-id sequences.
        pad_idx: id appended to sequences shorter than the longest one.
        n_head: number of attention heads; the bias is tiled over this axis.
        is_target: build a causal (upper-triangular) self-attention bias
            instead of a key-padding bias.
        is_label: return label weights (1. for tokens, 0. for paddings)
            instead of position ids.
        return_attn_bias: also return the attention bias.
        return_max_len: also return the padded length.
        return_num_token: also return the number of real (unpadded) tokens.

    Returns:
        A list, in this order: padded ids (int64, shape [batch*max_len, 1]);
        label weights (float32) or positions (int64), same shape; optionally
        the float32 attention bias of shape [batch, n_head, max_len, max_len];
        optionally ``max_len`` (int); optionally the token count (int). A
        single array is returned instead of a one-element list.

    Raises:
        ValueError: if ``insts`` is empty.
    """
    if len(insts) == 0:
        raise ValueError("pad_batch_data expects at least one instance.")

    batch_size = len(insts)
    lengths = np.array([len(inst) for inst in insts], dtype="int64")
    max_len = int(lengths.max())
    # True at real-token positions, False at paddings: shape [batch, max_len].
    token_mask = np.arange(max_len)[np.newaxis, :] < lengths[:, np.newaxis]

    # Any token included in dict can be used to pad, since the paddings' loss
    # will be masked out by weights and make no effect on parameter gradients.
    inst_data = np.array(
        [list(inst) + [pad_idx] * (max_len - len(inst)) for inst in insts])
    return_list = [inst_data.astype("int64").reshape([-1, 1])]

    if is_label:  # label weight
        inst_weight = token_mask.astype("float32")
        return_list.append(inst_weight.reshape([-1, 1]))
    else:  # position data; paddings get position 0
        inst_pos = np.where(token_mask, np.arange(max_len), 0)
        return_list.append(inst_pos.astype("int64").reshape([-1, 1]))

    if return_attn_bias:
        if is_target:
            # Causal bias: query i must not attend to keys j > i. Paddings need
            # no extra mask here: they always follow the real tokens, so for
            # every real query position they are "future" positions already.
            causal = np.triu(
                np.ones((max_len, max_len), dtype="float32"), 1) * -1e9
            slf_attn_bias_data = np.tile(
                causal[np.newaxis, np.newaxis, :, :],
                [batch_size, n_head, 1, 1])
        else:
            # Key-padding bias: no query may attend to padded key positions.
            key_bias = np.where(token_mask, 0., -1e9).astype("float32")
            slf_attn_bias_data = np.tile(
                key_bias[:, np.newaxis, np.newaxis, :],
                [1, n_head, max_len, 1])
        return_list.append(slf_attn_bias_data.astype("float32"))

    if return_max_len:
        return_list.append(max_len)
    if return_num_token:
        return_list.append(int(lengths.sum()))
    return return_list if len(return_list) > 1 else return_list[0]


224 225
def prepare_batch_input(insts, data_input_names, src_pad_idx, trg_pad_idx,
                        n_head, d_model):
    """
    Pad one batch of (source, target, label) instances and pack model inputs.

    Each element of ``insts`` is indexed as ``[0]`` source ids, ``[1]`` target
    ids and ``[2]`` label ids. ``d_model`` is accepted but not used here.

    Returns:
        A tuple ``(feed_dict, num_token)``: ``feed_dict`` maps
        ``data_input_names`` positionally to src_word, src_pos,
        src_slf_attn_bias, trg_word, trg_pos, trg_slf_attn_bias,
        trg_src_attn_bias, lbl_word, lbl_weight; ``num_token`` is a float32
        array holding the number of label tokens.
    """
    # Source side: word/position ids reshaped to [batch, src_len, 1].
    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
    src_word = src_word.reshape([-1, src_max_len, 1])
    src_pos = src_pos.reshape([-1, src_max_len, 1])

    # Target side: same reshaping, with a causal self-attention bias.
    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
        [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
    trg_word = trg_word.reshape([-1, trg_max_len, 1])
    trg_pos = trg_pos.reshape([-1, trg_max_len, 1])

    # Encoder-decoder bias: the stride-src_max_len slice keeps only the first
    # query row of the source key-padding bias, which is then repeated once
    # per target position.
    first_query_row = src_slf_attn_bias[:, :, ::src_max_len, :]
    trg_src_attn_bias = np.tile(first_query_row,
                                [1, 1, trg_max_len, 1]).astype("float32")

    lbl_word, lbl_weight, num_token = pad_batch_data(
        [inst[2] for inst in insts],
        trg_pad_idx,
        n_head,
        is_target=False,
        is_label=True,
        return_attn_bias=False,
        return_max_len=False,
        return_num_token=True)

    feed_values = [
        src_word, src_pos, src_slf_attn_bias,
        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
        lbl_word, lbl_weight,
    ]
    data_input_dict = dict(zip(data_input_names, feed_values))
    token_count = np.asarray([num_token], dtype="float32")
    return data_input_dict, token_count
258 259


260
def prepare_data_generator(args, is_test, count, pyreader):
    """
    Data generator wrapper for DataReader. If use py_reader, set the data
    provider for py_reader.

    Args:
        args: parsed command-line arguments.
        is_test: read ``args.val_file_pattern`` instead of the training files.
        count: number of devices the data is distributed over.
        pyreader: the py_reader created with the network; only used when
            ``args.use_py_reader`` is set.

    Returns:
        None when ``args.use_py_reader`` (data is then pushed through
        ``pyreader``); otherwise a callable returning a generator that yields
        lists of ``count`` per-device batches.
    """
    data_reader = reader.DataReader(
        fpattern=args.val_file_pattern if is_test else args.train_file_pattern,
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        token_delimiter=args.token_delimiter,
        use_token_batch=args.use_token_batch,
        batch_size=args.batch_size * (1 if args.use_token_batch else count),
        pool_size=args.pool_size,
        sort_type=args.sort_type,
        shuffle=args.shuffle,
        shuffle_batch=args.shuffle_batch,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2],
        # count start and end tokens out
        max_length=ModelHyperParams.max_length - 2,
        clip_last_batch=False).batch_generator

    def stack(data_reader, count, clip_last=True):
        """Group every ``count`` consecutive batches into one list.

        A trailing group of fewer than ``count`` batches is dropped when
        ``clip_last``; otherwise its instances are pooled and re-split evenly
        into ``count`` parts (the remainder is dropped), provided there are at
        least ``count`` instances.
        """

        def __impl__():
            res = []
            for item in data_reader():
                res.append(item)
                if len(res) == count:
                    yield res
                    res = []
            # Here `res` always holds fewer than `count` batches.
            if res and not clip_last:
                data = []
                for item in res:
                    data += item
                inst_num_per_part = len(data) // count
                if inst_num_per_part > 0:
                    yield [
                        data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
                        for i in range(count)
                    ]

        return __impl__

    def split(data_reader, count):
        """Split every batch into ``count`` equal, consecutive parts.

        Instances beyond ``count * (len(batch) // count)`` are dropped. A batch
        with fewer than ``count`` instances is skipped as a whole: it would
        otherwise produce empty parts, which cannot be padded.
        """

        def __impl__():
            for item in data_reader():
                inst_num_per_part = len(item) // count
                if inst_num_per_part == 0:
                    continue
                for i in range(count):
                    yield item[inst_num_per_part * i:inst_num_per_part * (i + 1)]

        return __impl__

    if not args.use_token_batch:
        # to make data on each device have similar token number
        data_reader = split(data_reader, count)
    if args.use_py_reader:
        pyreader.decorate_tensor_provider(
            py_reader_provider_wrapper(data_reader))
        data_reader = None
    else:  # Data generator for multi-devices
        data_reader = stack(data_reader, count)
    return data_reader


def prepare_feed_dict_list(data_generator, init_flag, count):
    """
    Prepare the list of feed dicts for multi-devices.

    Args:
        data_generator: generator yielding lists of per-device batches, or
            None when data is supplied through py_reader.
        init_flag: also feed the position-encoding tables (done on the first
            step only).
        count: number of devices.

    Returns:
        A list of ``count`` feed dicts, or None when the assembled list does
        not hold exactly ``count`` entries (e.g. in py_reader mode after the
        first step).

    Raises:
        StopIteration: when ``data_generator`` is exhausted.
    """
    feed_dict_list = []
    if data_generator is not None:  # use_py_reader == False
        data_input_names = encoder_data_input_fields + \
                    decoder_data_input_fields[:-1] + label_data_input_fields
        data = next(data_generator)
        for data_buffer in data:
            data_input_dict, _ = prepare_batch_input(
                data_buffer, data_input_names, ModelHyperParams.eos_idx,
                ModelHyperParams.eos_idx, ModelHyperParams.n_head,
                ModelHyperParams.d_model)
            feed_dict_list.append(data_input_dict)
    if init_flag:
        # position_encoding_init gets the same arguments for every device, so
        # build the tables once instead of once per device.
        pos_enc_tables = dict()
        for pos_enc_param_name in pos_enc_param_names:
            pos_enc_tables[pos_enc_param_name] = position_encoding_init(
                ModelHyperParams.max_length + 1, ModelHyperParams.d_model)
        for idx in range(count):
            merged = dict(pos_enc_tables)
            if len(feed_dict_list) <= idx:
                feed_dict_list.append(merged)
            else:
                # Data entries take precedence on a (not expected) key clash.
                merged.update(feed_dict_list[idx])
                feed_dict_list[idx] = merged

    return feed_dict_list if len(feed_dict_list) == count else None


def py_reader_provider_wrapper(data_reader):
    """
    Data provider needed by fluid.layers.py_reader.

    Wraps ``data_reader`` (a callable returning an iterable of batches) into a
    generator function. For every batch, that generator yields the padded
    input arrays as a list ordered like the encoder fields, the decoder fields
    without their last entry, and the label fields.
    """

    def py_reader_provider():
        input_names = (encoder_data_input_fields +
                       decoder_data_input_fields[:-1] + label_data_input_fields)
        for batch in data_reader():
            feed, _ = prepare_batch_input(
                batch, input_names, ModelHyperParams.eos_idx,
                ModelHyperParams.eos_idx, ModelHyperParams.n_head,
                ModelHyperParams.d_model)
            yield [feed[name] for name in input_names]

    return py_reader_provider


def test_context(exe, train_exe, dev_count):
    """Build the validation program and return a function that evaluates it.

    The validation network is built under ``fluid.unique_name.guard()``, so
    its variable names match the training network's, and its
    ParallelExecutor shares variables with ``train_exe``. Reads the
    module-level ``args``.

    Args:
        exe: executor used to run the validation startup program.
        train_exe: training ParallelExecutor whose variables are shared.
        dev_count: number of devices.

    Returns:
        A function ``test()`` that runs one pass over the validation data and
        returns ``(avg_cost, ppl)``, where ``ppl`` is the 1-element array
        ``exp([min(avg_cost, 100)])``.
    """
    # Context to do validation.
    test_prog = fluid.Program()
    startup_prog = fluid.Program()
    if args.enable_ce:
        test_prog.random_seed = 1000
        startup_prog.random_seed = 1000
    with fluid.program_guard(test_prog, startup_prog):
        with fluid.unique_name.guard():
            sum_cost, avg_cost, predict, token_num, pyreader = transformer(
                ModelHyperParams.src_vocab_size,
                ModelHyperParams.trg_vocab_size,
                ModelHyperParams.max_length + 1,
                ModelHyperParams.n_layer,
                ModelHyperParams.n_head,
                ModelHyperParams.d_key,
                ModelHyperParams.d_value,
                ModelHyperParams.d_model,
                ModelHyperParams.d_inner_hid,
                ModelHyperParams.prepostprocess_dropout,
                ModelHyperParams.attention_dropout,
                ModelHyperParams.relu_dropout,
                ModelHyperParams.preprocess_cmd,
                ModelHyperParams.postprocess_cmd,
                ModelHyperParams.weight_sharing,
                TrainTaskConfig.label_smooth_eps,
                use_py_reader=args.use_py_reader,
                is_test=True)

    test_data = prepare_data_generator(
        args, is_test=True, count=dev_count, pyreader=pyreader)

    # Because the parameter names match the training ones, running this
    # startup program also re-initializes the shared parameters in the scope.
    # Reload the checkpoint afterwards so it is not silently discarded.
    # NOTE(review): without a checkpoint the parameters are re-initialized
    # after train_exe was created; confirm the per-device copies stay in sync.
    exe.run(startup_prog)
    if TrainTaskConfig.ckpt_path:
        fluid.io.load_persistables(
            exe, TrainTaskConfig.ckpt_path, main_program=test_prog)

    test_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        main_program=test_prog,
        share_vars_from=train_exe)

    def test(exe=test_exe, pyreader=pyreader):
        """Run one full pass over the validation data."""
        test_total_cost = 0
        test_total_token = 0

        if args.use_py_reader:
            pyreader.start()
            data_generator = None
        else:
            data_generator = test_data()
        while True:
            try:
                feed_dict_list = prepare_feed_dict_list(data_generator, False,
                                                        dev_count)
                outs = exe.run(fetch_list=[sum_cost.name, token_num.name],
                               feed=feed_dict_list)
            except (StopIteration, fluid.core.EOFException):
                # The current pass is over.
                if args.use_py_reader:
                    pyreader.reset()
                break
            sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
            test_total_cost += sum_cost_val.sum()
            test_total_token += token_num_val.sum()
        test_avg_cost = test_total_cost / test_total_token
        test_ppl = np.exp([min(test_avg_cost, 100)])
        return test_avg_cost, test_ppl

    return test


446 447 448 449 450 451 452 453 454 455 456
def train_loop(exe,
               train_prog,
               startup_prog,
               dev_count,
               sum_cost,
               avg_cost,
               token_num,
               predict,
               pyreader,
               nccl2_num_trainers=1,
               nccl2_trainer_id=0):
    """Run training for ``TrainTaskConfig.pass_num`` passes.

    Initializes the variables of ``startup_prog`` and, when
    ``TrainTaskConfig.ckpt_path`` is set, restores the persistables of
    ``train_prog`` from it. Builds a ParallelExecutor over ``train_prog`` and
    logs the loss every ``args.fetch_steps`` steps. Saves checkpoints every
    ``TrainTaskConfig.save_freq`` steps and, unless ``args.enable_ce``, after
    every pass. Validates after every pass when ``args.val_file_pattern`` is
    set. Reads the module-level ``args``.

    Args:
        exe: executor used for initialization and checkpoint I/O.
        train_prog: the training program.
        startup_prog: the startup program of ``train_prog``.
        dev_count: number of devices.
        sum_cost, avg_cost, token_num, predict: variables of the network.
        pyreader: the network's py_reader (used when ``args.use_py_reader``).
        nccl2_num_trainers: number of NCCL2 trainers (1 for local training).
        nccl2_trainer_id: id of this NCCL2 trainer.
    """
    # Initialize the parameters. The startup program always runs first, so
    # any non-checkpointed state is initialized too. A checkpoint then
    # overwrites the persistables. `main_program` must be given: this is
    # called outside `program_guard`, so the default main program is not
    # `train_prog` and the checkpoint would otherwise not reach it.
    logging.info("init fluid.framework.default_startup_program")
    exe.run(startup_prog)
    if TrainTaskConfig.ckpt_path:
        logging.info("load checkpoint from {}".format(
            TrainTaskConfig.ckpt_path))
        fluid.io.load_persistables(
            exe, TrainTaskConfig.ckpt_path, main_program=train_prog)

    logging.info("begin reader")
    train_data = prepare_data_generator(
        args, is_test=False, count=dev_count, pyreader=pyreader)

    # For faster executor
    exec_strategy = fluid.ExecutionStrategy()
    exec_strategy.use_experimental_executor = True
    build_strategy = fluid.BuildStrategy()
    # NOTE: the token number differs among devices; a token-average cost over
    # all devices would need a customized gradient scale (`1 / token_number`,
    # fluid.BuildStrategy.GradientScaleStrategy.Customized), which is
    # currently left disabled.

    logging.info("begin executor")
    train_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        loss_name=avg_cost.name,
        main_program=train_prog,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy,
        num_trainers=nccl2_num_trainers,
        trainer_id=nccl2_trainer_id)

    if args.val_file_pattern is not None:
        test = test_context(exe, train_exe, dev_count)

    # the best cross-entropy value with label smoothing
    smooth_eps = TrainTaskConfig.label_smooth_eps
    loss_normalizer = -((1. - smooth_eps) * np.log((1. - smooth_eps)) +
                        smooth_eps * np.log(smooth_eps / (
                            ModelHyperParams.trg_vocab_size - 1) + 1e-20))

    step_idx = 0
    init_flag = True

    logging.info("begin train")
    for pass_id in six.moves.xrange(TrainTaskConfig.pass_num):
        pass_start_time = time.time()

        if args.use_py_reader:
            pyreader.start()
            data_generator = None
        else:
            data_generator = train_data()

        batch_id = 0
        while True:
            try:
                feed_dict_list = prepare_feed_dict_list(data_generator,
                                                        init_flag, dev_count)
                outs = train_exe.run(
                    fetch_list=[sum_cost.name, token_num.name]
                    if step_idx % args.fetch_steps == 0 else [],
                    feed=feed_dict_list)

                if step_idx % args.fetch_steps == 0:
                    sum_cost_val, token_num_val = np.array(outs[0]), np.array(
                        outs[1])
                    # sum the cost from multi-devices
                    total_sum_cost = sum_cost_val.sum()
                    total_token_num = token_num_val.sum()
                    total_avg_cost = total_sum_cost / total_token_num
                    # Scalar perplexity, capped to avoid overflow in exp.
                    ppl = np.exp(min(total_avg_cost, 100))

                    if step_idx == 0:
                        logging.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer, ppl))
                        avg_batch_time = time.time()
                    else:
                        logging.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f, speed: %.2f step/s" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer, ppl,
                             args.fetch_steps / (time.time() - avg_batch_time)))
                        avg_batch_time = time.time()

                if step_idx % TrainTaskConfig.save_freq == 0 and step_idx > 0:
                    fluid.io.save_persistables(
                        exe,
                        os.path.join(TrainTaskConfig.ckpt_dir,
                                     "latest.checkpoint"), train_prog)
                    fluid.io.save_params(
                        exe,
                        os.path.join(TrainTaskConfig.model_dir,
                                     "iter_" + str(step_idx) + ".infer.model"),
                        train_prog)

                init_flag = False
                batch_id += 1
                step_idx += 1
            except (StopIteration, fluid.core.EOFException):
                # The current pass is over.
                if args.use_py_reader:
                    pyreader.reset()
                break

        time_consumed = time.time() - pass_start_time
        # Validate and save the persistable.
        if args.val_file_pattern is not None:
            val_avg_cost, val_ppl = test()
            logging.info(
                "epoch: %d, val avg loss: %f, val normalized loss: %f, val ppl: %f,"
                " consumed %fs" % (pass_id, val_avg_cost,
                                   val_avg_cost - loss_normalizer, val_ppl,
                                   time_consumed))
        else:
            logging.info("epoch: %d, consumed %fs" % (pass_id, time_consumed))
        if not args.enable_ce:
            fluid.io.save_persistables(
                exe,
                os.path.join(TrainTaskConfig.ckpt_dir,
                             "pass_" + str(pass_id) + ".checkpoint"),
                train_prog)

    if args.enable_ce:  # For CE
        print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost))
        if args.val_file_pattern is not None:
            print("kpis\ttest_cost_card%d\t%f" % (dev_count, val_avg_cost))
        print("kpis\ttrain_duration_card%d\t%f" % (dev_count, time_consumed))
Q
Qiao Longfei 已提交
590 591


592 593 594 595 596
def train(args):
    """Set up devices and programs, then run local or distributed training.

    Priority of settings: environment variables > command-line args > config.
    Local mode calls ``train_loop`` directly. Distributed mode (``args.local``
    False, or env PADDLE_IS_LOCAL=0) reads the cluster layout from
    environment variables. It then either runs an NCCL2 trainer
    (``--update_method nccl2``) or transpiles the program for parameter
    servers and runs in the role given by TRAINING_ROLE.
    """

    def _require_env(name):
        # Exit with a clear message instead of failing later with a
        # TypeError/AttributeError on None.
        value = os.getenv(name)
        if not value:
            logging.critical("need env %s" % name)
            sys.exit(1)
        return value

    # priority: ENV > args > config
    is_local = os.getenv("PADDLE_IS_LOCAL", "1")
    if is_local == '0':
        args.local = False
    logging.info(args)

    if args.device == 'CPU':
        TrainTaskConfig.use_gpu = False

    training_role = os.getenv("TRAINING_ROLE", "TRAINER")

    if training_role == "PSERVER" or (not TrainTaskConfig.use_gpu):
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
    else:
        place = fluid.CUDAPlace(0)
        dev_count = fluid.core.get_cuda_device_count()

    exe = fluid.Executor(place)

    train_prog = fluid.Program()
    startup_prog = fluid.Program()

    if args.enable_ce:
        train_prog.random_seed = 1000
        startup_prog.random_seed = 1000

    with fluid.program_guard(train_prog, startup_prog):
        with fluid.unique_name.guard():
            sum_cost, avg_cost, predict, token_num, pyreader = transformer(
                ModelHyperParams.src_vocab_size,
                ModelHyperParams.trg_vocab_size,
                ModelHyperParams.max_length + 1,
                ModelHyperParams.n_layer,
                ModelHyperParams.n_head,
                ModelHyperParams.d_key,
                ModelHyperParams.d_value,
                ModelHyperParams.d_model,
                ModelHyperParams.d_inner_hid,
                ModelHyperParams.prepostprocess_dropout,
                ModelHyperParams.attention_dropout,
                ModelHyperParams.relu_dropout,
                ModelHyperParams.preprocess_cmd,
                ModelHyperParams.postprocess_cmd,
                ModelHyperParams.weight_sharing,
                TrainTaskConfig.label_smooth_eps,
                use_py_reader=args.use_py_reader,
                is_test=False)

            if args.sync:
                lr_decay = fluid.layers.learning_rate_scheduler.noam_decay(
                    ModelHyperParams.d_model, TrainTaskConfig.warmup_steps)
                logging.info("before adam")

                with fluid.default_main_program()._lr_schedule_guard():
                    learning_rate = lr_decay * TrainTaskConfig.learning_rate

                optimizer = fluid.optimizer.Adam(
                    learning_rate=learning_rate,
                    beta1=TrainTaskConfig.beta1,
                    beta2=TrainTaskConfig.beta2,
                    epsilon=TrainTaskConfig.eps)
            else:
                optimizer = fluid.optimizer.SGD(0.003)
            optimizer.minimize(avg_cost)

    if args.use_mem_opt:
        fluid.memory_optimize(train_prog)

    if args.local:
        logging.info("local start_up:")
        train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
                   token_num, predict, pyreader)
        return

    if args.update_method == "nccl2":
        trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        port = _require_env("PADDLE_PORT")
        worker_ips = _require_env("PADDLE_TRAINERS")
        worker_endpoints = [':'.join([ip, port]) for ip in worker_ips.split(",")]
        trainers_num = len(worker_endpoints)
        current_endpoint = _require_env("POD_IP") + ":" + port
        if trainer_id == 0:
            logging.info("train_id == 0, sleep 60s")
            time.sleep(60)
        logging.info("trainers_num:{}".format(trainers_num))
        logging.info("worker_endpoints:{}".format(worker_endpoints))
        logging.info("current_endpoint:{}".format(current_endpoint))
        # Append gen_nccl_id to the startup program that train_loop actually
        # runs; the global default startup program is not used by this script.
        with fluid.program_guard(train_prog, startup_prog):
            append_nccl2_prepare(trainer_id, worker_endpoints, current_endpoint)
        train_loop(exe, train_prog, startup_prog, dev_count, sum_cost,
                   avg_cost, token_num, predict, pyreader, trainers_num,
                   trainer_id)
        return

    port = os.getenv("PADDLE_PORT", "6174")
    pserver_ips = _require_env("PADDLE_PSERVERS")  # ip,ip...
    eplist = [':'.join([ip, port]) for ip in pserver_ips.split(",")]
    pserver_endpoints = ",".join(eplist)  # ip:port,ip:port...
    trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
    current_endpoint = _require_env("POD_IP") + ":" + port
    trainer_id = int(_require_env("PADDLE_TRAINER_ID"))

    logging.info("pserver_endpoints:{}".format(pserver_endpoints))
    logging.info("current_endpoint:{}".format(current_endpoint))
    logging.info("trainer_id:{}".format(trainer_id))
    logging.info("pserver_ips:{}".format(pserver_ips))
    logging.info("port:{}".format(port))

    t = fluid.DistributeTranspiler()
    t.transpile(
        trainer_id,
        pservers=pserver_endpoints,
        trainers=trainers,
        program=train_prog,
        startup_program=startup_prog)

    if training_role == "PSERVER":
        logging.info("distributed: pserver started")
        # current_endpoint is built from the same `port` (including its
        # default) as the transpiled endpoint list, so the two always match.
        pserver_prog = t.get_pserver_program(current_endpoint)
        pserver_startup = t.get_startup_program(current_endpoint,
                                                pserver_prog)

        exe.run(pserver_startup)
        exe.run(pserver_prog)
    elif training_role == "TRAINER":
        logging.info("distributed: trainer started")
        trainer_prog = t.get_trainer_program()
        train_loop(exe, trainer_prog, startup_prog, dev_count, sum_cost,
                   avg_cost, token_num, predict, pyreader)
    else:
        logging.critical(
            "environment var TRAINING_ROLE should be TRAINER or PSERVER")
        sys.exit(1)
736 737 738


if __name__ == "__main__":
    # Send every record (DEBUG and above) to stdout, tagged with time, level
    # and call site.
    LOG_FORMAT = "[%(asctime)s %(levelname)s %(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=LOG_FORMAT, level=logging.DEBUG,
                        stream=sys.stdout)

    # `args` must stay a module-level name: other functions in this file read
    # it as a global.
    args = parse_args()
    train(args)