# Training entry point for the Transformer NMT model (PaddlePaddle Fluid).
import argparse
import ast
import copy
import logging
import multiprocessing
import os
import subprocess

# Set before paddle.fluid is imported below, and only when the caller has not
# exported a value already. Presumably the flag is read when paddle
# initializes (TODO confirm).
if os.environ.get('FLAGS_eager_delete_tensor_gb', None) is None:
    os.environ['FLAGS_eager_delete_tensor_gb'] = '0'

import six
import sys
# Python 2 only: switch the interpreter's default string encoding to UTF-8.
if sys.version[0] == '2':
    reload(sys)
    sys.setdefaultencoding("utf-8")
# Presumably makes the shared `models` package and the transformer model
# sources importable (TODO confirm directory layout).
sys.path.append("../../")
sys.path.append("../../models/neural_machine_translation/transformer/")
import time

import numpy as np
import paddle.fluid as fluid

from models.model_check import check_cuda
import reader
from config import *
from desc import *
from model import transformer, position_encoding_init
import dist_utils

# Number of trainer processes, read from PADDLE_TRAINERS_NUM (default 1).
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))


def parse_args():
    """Parse command-line arguments and merge config overrides.

    Besides parsing, this loads the source and target vocabularies and merges
    their sizes and the bos/eos/unk indices, together with any trailing
    ``opts`` overrides, into ``TrainTaskConfig`` and ``ModelHyperParams`` via
    ``merge_cfg_from_list``.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser("Training for Transformer.")
    parser.add_argument(
        "--src_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of source language.")
    parser.add_argument(
        "--trg_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of target language.")
    parser.add_argument(
        "--train_file_pattern",
        type=str,
        required=True,
        help="The pattern to match training data files.")
    parser.add_argument(
        "--val_file_pattern",
        type=str,
        help="The pattern to match validation data files.")
    parser.add_argument(
        "--use_token_batch",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to "
        "produce batch data according to token number.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=4096,
        help="The number of sequences contained in a mini-batch, or the maximum "
        "number of tokens (include paddings) contained in a mini-batch. Note "
        "that this represents the number on single device and the actual batch "
        "size for multi-devices will multiply the device number.")
    parser.add_argument(
        "--pool_size",
        type=int,
        default=200000,
        help="The buffer size to pool data.")
    parser.add_argument(
        "--sort_type",
        default="pool",
        choices=("global", "pool", "none"),
        help="The grain to sort by length: global for all instances; pool for "
        "instances in pool; none for no sort.")
    parser.add_argument(
        "--shuffle",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to shuffle instances in each pass.")
    parser.add_argument(
        "--shuffle_batch",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to shuffle the data batches.")
    # Tokens are encoded to bytes so they match the (bytes) default values.
    parser.add_argument(
        "--special_token",
        type=lambda x: x.encode(),
        default=[b"<s>", b"<e>", b"<unk>"],
        nargs=3,
        help="The <bos>, <eos> and <unk> tokens in the dictionary.")
    parser.add_argument(
        "--token_delimiter",
        type=lambda x: x.encode(),
        default=b" ",
        help="The delimiter used to split tokens in source or target sentences. "
        "For EN-DE BPE data we provided, use spaces as token delimiter. ")
    # Everything after the first positional token is collected here and later
    # merged into the config objects as key/value overrides.
    parser.add_argument(
        'opts',
        help='See config.py for all options',
        default=None,
        nargs=argparse.REMAINDER)
    parser.add_argument(
        '--local',
        type=ast.literal_eval,
        default=True,
        help='Whether to run as local mode.')
    parser.add_argument(
        '--device',
        type=str,
        default='GPU',
        choices=['CPU', 'GPU'],
        help="The device type.")
    parser.add_argument(
        '--update_method',
        choices=("pserver", "nccl2"),
        default="pserver",
        help='Update method.')
    parser.add_argument(
        '--sync', type=ast.literal_eval, default=True, help="sync mode.")
    parser.add_argument(
        "--enable_ce",
        type=ast.literal_eval,
        default=False,
        help="The flag indicating whether to run the task "
        "for continuous evaluation.")
    parser.add_argument(
        "--use_mem_opt",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to use memory optimization.")
    parser.add_argument(
        "--use_py_reader",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to use py_reader.")
    parser.add_argument(
        "--fetch_steps",
        type=int,
        default=100,
        help="The frequency to fetch and print output.")

    args = parser.parse_args()

    # Append args related to dict. Note that the bos/eos/unk indices are all
    # looked up in the *source* vocabulary.
    src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
    trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
    dict_args = [
        "src_vocab_size", str(len(src_dict)), "trg_vocab_size",
        str(len(trg_dict)), "bos_idx", str(src_dict[args.special_token[0]]),
        "eos_idx", str(src_dict[args.special_token[1]]), "unk_idx",
        str(src_dict[args.special_token[2]])
    ]
    # `opts` is declared with default=None; never concatenate None to a list.
    merge_cfg_from_list(
        list(args.opts or []) + dict_args, [TrainTaskConfig, ModelHyperParams])
    return args


def get_device_num():
    """Return how many devices this process should drive.

    In multi-process training (``num_trainers > 1``) each process owns a
    single GPU card, so the answer is 1. Otherwise the count is taken from the
    comma-separated ``CUDA_VISIBLE_DEVICES`` list when it is set and
    non-empty. If it is not, the count is the number of lines printed by
    ``nvidia-smi -L``.
    """
    # NOTE(zcd): for multi-process training, each process uses one GPU card.
    if num_trainers > 1:
        return 1
    visible = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    if not visible:
        listing = subprocess.check_output(['nvidia-smi', '-L']).decode()
        return listing.count('\n')
    return len(visible.split(','))


def append_nccl2_prepare(startup_prog, trainer_id, worker_endpoints,
                         current_endpoint):
    """Append a ``gen_nccl_id`` op to ``startup_prog`` for NCCL2 training.

    Creates a persistable RAW variable named ``NCCLID`` in the startup
    program's global block. Then appends a ``gen_nccl_id`` op that writes to
    it, configured with this trainer's endpoint, the other trainers'
    endpoints and ``trainer_id``.

    Args:
        startup_prog: the fluid startup ``Program``; modified in place.
        trainer_id: non-negative id of this trainer.
        worker_endpoints: all trainer endpoints. It must have more than one
            entry and contain ``current_endpoint``.
        current_endpoint: this trainer's endpoint.

    Returns:
        The created ``NCCLID`` variable.

    Raises:
        AssertionError: if the preconditions above do not hold.
    """
    assert (trainer_id >= 0 and len(worker_endpoints) > 1 and
            current_endpoint in worker_endpoints)
    # Endpoints of the other trainers; work on a copy so the caller's list is
    # left untouched.
    eps = copy.deepcopy(worker_endpoints)
    eps.remove(current_endpoint)
    nccl_id_var = startup_prog.global_block().create_var(
        name="NCCLID", persistable=True, type=fluid.core.VarDesc.VarType.RAW)
    startup_prog.global_block().append_op(
        type="gen_nccl_id",
        inputs={},
        outputs={"NCCLID": nccl_id_var},
        attrs={
            "endpoint": current_endpoint,
            "endpoint_list": eps,
            "trainer_id": trainer_id
        })
    return nccl_id_var


def pad_batch_data(insts,
                   pad_idx,
                   n_head,
                   is_target=False,
                   is_label=False,
                   return_attn_bias=True,
                   return_max_len=True,
                   return_num_token=False):
    """
    Pad the instances to the max sequence length in batch, and generate the
    corresponding position data and attention bias.

    Args:
        insts: list of token-id lists, one per instance.
        pad_idx: id appended to instances shorter than the longest one.
        n_head: number of attention heads the bias is tiled over.
        is_target: build a causal bias that also blocks attention to later
            positions, instead of a padding-only bias.
        is_label: emit per-token weights (1. for real tokens, 0. for padding)
            instead of position ids.
        return_attn_bias, return_max_len, return_num_token: toggles for the
            corresponding optional outputs.

    Returns:
        A list, in this order:
        - padded ids, ``int64`` with shape ``[batch * max_len, 1]``;
        - weights (``float32``) or positions (``int64``) of the same shape;
        - optionally the ``float32`` attention bias with shape
          ``[batch, n_head, max_len, max_len]`` (0 = attend, -1e9 = masked);
        - optionally ``max_len``;
        - optionally the number of real (non-padding) tokens.
        A list with a single element would be returned unwrapped.
    """
    return_list = []
    max_len = max(len(inst) for inst in insts)
    # Any token included in dict can be used to pad, since the paddings' loss
    # will be masked out by weights and make no effect on parameter gradients.
    inst_data = np.array(
        [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
    return_list.append(inst_data.astype("int64").reshape([-1, 1]))
    if is_label:  # label weight
        inst_weight = np.array(
            [[1.] * len(inst) + [0.] * (max_len - len(inst)) for inst in insts])
        return_list.append(inst_weight.astype("float32").reshape([-1, 1]))
    else:  # position data
        inst_pos = np.array([
            list(range(0, len(inst))) + [0] * (max_len - len(inst))
            for inst in insts
        ])
        return_list.append(inst_pos.astype("int64").reshape([-1, 1]))
    if return_attn_bias:
        if is_target:
            # Mask the strict upper triangle (subsequent words). Paddings sit
            # after every real token, so real queries never see them either.
            slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, max_len))
            slf_attn_bias_data = np.triu(slf_attn_bias_data,
                                         1).reshape([-1, 1, max_len, max_len])
            slf_attn_bias_data = np.tile(slf_attn_bias_data,
                                         [1, n_head, 1, 1]) * [-1e9]
        else:
            # This is used to avoid attention on paddings; every query row
            # carries the same key mask.
            slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
                                           (max_len - len(inst))
                                           for inst in insts])
            slf_attn_bias_data = np.tile(
                slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
                [1, n_head, max_len, 1])
        return_list.append(slf_attn_bias_data.astype("float32"))
    if return_max_len:
        return_list.append(max_len)
    if return_num_token:
        return_list.append(sum(len(inst) for inst in insts))
    return return_list if len(return_list) > 1 else return_list[0]


def prepare_batch_input(insts, data_input_names, src_pad_idx, trg_pad_idx,
                        n_head, d_model):
    """
    Put all padded data needed by training into a dict.

    Args:
        insts: list of instances; each is indexed as
            ``(src_ids, trg_ids, label_ids)``.
        data_input_names: names zipped, in order, with src_word, src_pos,
            src_slf_attn_bias, trg_word, trg_pos, trg_slf_attn_bias,
            trg_src_attn_bias, lbl_word and lbl_weight.
        src_pad_idx: padding id for source sequences.
        trg_pad_idx: padding id for target and label sequences.
        n_head: number of attention heads.
        d_model: not used here; kept for signature compatibility.

    Returns:
        ``(data_input_dict, num_token)``. ``num_token`` is a float32 array of
        shape ``[1]`` holding the number of real label tokens.
    """
    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
    src_word = src_word.reshape(-1, src_max_len, 1)
    src_pos = src_pos.reshape(-1, src_max_len, 1)
    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
        [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
    trg_word = trg_word.reshape(-1, trg_max_len, 1)
    trg_pos = trg_pos.reshape(-1, trg_max_len, 1)

    # Encoder-decoder attention only has to mask source paddings. Take the
    # first query row of the source bias (stride src_max_len over that axis)
    # and repeat it for every target position.
    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                [1, 1, trg_max_len, 1]).astype("float32")

    lbl_word, lbl_weight, num_token = pad_batch_data(
        [inst[2] for inst in insts],
        trg_pad_idx,
        n_head,
        is_target=False,
        is_label=True,
        return_attn_bias=False,
        return_max_len=False,
        return_num_token=True)

    data_input_dict = dict(
        zip(data_input_names, [
            src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
            trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
        ]))

    return data_input_dict, np.asarray([num_token], dtype="float32")


def prepare_data_generator(args,
                           is_test,
                           count,
                           pyreader,
                           py_reader_provider_wrapper,
                           place=None):
    """
    Data generator wrapper for DataReader. If use py_reader, set the data
    provider for py_reader.

    Args:
        args: parsed command-line arguments.
        is_test: read ``args.val_file_pattern`` instead of
            ``args.train_file_pattern``.
        count: number of devices the batches are spread over.
        pyreader: the model's py_reader; only used when
            ``args.use_py_reader`` is set.
        py_reader_provider_wrapper: callable ``(data_reader, place)`` that
            returns the tensor provider given to ``pyreader``.
        place: forwarded to ``py_reader_provider_wrapper``.

    Returns:
        ``None`` when py_reader is used, because data then flows through
        ``pyreader``. Otherwise a reader callable that yields lists of
        ``count`` batches, one per device.
    """
    # NOTE: If num_trainers > 1, the shuffle_seed must be set, because
    # the order of batch data generated by reader
    # must be the same in the respective processes.
    shuffle_seed = 1 if num_trainers > 1 else None

    data_reader = reader.DataReader(
        fpattern=args.val_file_pattern if is_test else args.train_file_pattern,
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        token_delimiter=args.token_delimiter,
        use_token_batch=args.use_token_batch,
        batch_size=args.batch_size * (1 if args.use_token_batch else count),
        pool_size=args.pool_size,
        sort_type=args.sort_type,
        shuffle=args.shuffle,
        shuffle_seed=shuffle_seed,
        shuffle_batch=args.shuffle_batch,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2],
        # count start and end tokens out
        max_length=ModelHyperParams.max_length - 2,
        clip_last_batch=False).batch_generator

    def stack(data_reader, count, clip_last=True):
        """Group consecutive batches into lists of ``count`` batches.

        With ``clip_last=False``, a trailing incomplete group that holds more
        than ``count`` instances is flattened and re-split into ``count``
        equal parts (remainder dropped).
        """

        def __impl__():
            res = []
            for item in data_reader():
                res.append(item)
                if len(res) == count:
                    yield res
                    res = []
            if len(res) == count:
                yield res
            elif not clip_last:
                data = []
                for item in res:
                    data += item
                if len(data) > count:
                    inst_num_per_part = len(data) // count
                    yield [
                        data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
                        for i in range(count)
                    ]

        return __impl__

    def split(data_reader, count):
        """Split every batch into ``count`` equal parts (remainder dropped)."""

        def __impl__():
            for item in data_reader():
                inst_num_per_part = len(item) // count
                if inst_num_per_part == 0:
                    # Fewer instances than devices (e.g. the short last batch
                    # kept by clip_last_batch=False): every part would be
                    # empty, and empty batches cannot be padded.
                    continue
                for i in range(count):
                    start = inst_num_per_part * i
                    yield item[start:start + inst_num_per_part]

        return __impl__

    if not args.use_token_batch:
        # to make data on each device have similar token number
        data_reader = split(data_reader, count)
    if args.use_py_reader:
        train_reader = py_reader_provider_wrapper(data_reader, place)
        if num_trainers > 1:
            assert shuffle_seed is not None
            train_reader = fluid.contrib.reader.distributed_batch_reader(
                train_reader)
        pyreader.decorate_tensor_provider(train_reader)
        data_reader = None
    else:  # Data generator for multi-devices
        data_reader = stack(data_reader, count)
    return data_reader


def prepare_feed_dict_list(data_generator, init_flag, count):
    """
    Assemble one feed dict per device for the executor.

    Args:
        data_generator: iterator whose next item is a list of per-device
            batches, or ``None`` when data arrives through py_reader.
        init_flag: if true, each of the ``count`` devices also receives the
            position-encoding tables named in ``pos_enc_param_names``.
        count: number of devices.

    Returns:
        The list of ``count`` feed dicts, or ``None`` if a different number
        of dicts was produced. For example, py_reader mode without
        ``init_flag`` produces none, because nothing has to be fed.

    Raises:
        StopIteration: from ``next(data_generator)`` when the data runs out.
    """
    feeds = []
    if data_generator is not None:  # use_py_reader == False
        input_names = (encoder_data_input_fields +
                       decoder_data_input_fields[:-1] +
                       label_data_input_fields)
        per_device_batches = next(data_generator)
        for device_batch in per_device_batches:
            batch_feed, _ = prepare_batch_input(
                device_batch, input_names, ModelHyperParams.eos_idx,
                ModelHyperParams.eos_idx, ModelHyperParams.n_head,
                ModelHyperParams.d_model)
            feeds.append(batch_feed)

    if init_flag:
        for dev_idx in range(count):
            tables = {
                name: position_encoding_init(ModelHyperParams.max_length + 1,
                                             ModelHyperParams.d_model)
                for name in pos_enc_param_names
            }
            if dev_idx >= len(feeds):
                feeds.append(tables)
                continue
            # Start from the tables, then let the batch's own entries win on
            # any shared key.
            merged = dict(tables)
            merged.update(feeds[dev_idx])
            feeds[dev_idx] = merged

    return None if len(feeds) != count else feeds


def py_reader_provider_wrapper(data_reader, place):
    """
    Data provider needed by fluid.layers.py_reader.

    Args:
        data_reader: callable returning an iterable of batches, each a list
            of instances accepted by ``prepare_batch_input``.
        place: not used here; accepted for interface compatibility.

    Returns:
        A generator function. For each batch it yields the padded arrays in
        the order ``encoder_data_input_fields + decoder_data_input_fields[:-1]
        + label_data_input_fields``.
    """

    def py_reader_provider():
        data_input_names = encoder_data_input_fields + \
                    decoder_data_input_fields[:-1] + label_data_input_fields
        for data in data_reader():
            data_input_dict, _ = prepare_batch_input(
                data, data_input_names, ModelHyperParams.eos_idx,
                ModelHyperParams.eos_idx, ModelHyperParams.n_head,
                ModelHyperParams.d_model)
            yield [data_input_dict[item] for item in data_input_names]

    return py_reader_provider


def test_context(exe, train_exe, dev_count):
    """Build the validation program and return a closure that evaluates it.

    Builds an ``is_test=True`` transformer in a fresh program pair and runs
    its startup program. If ``TrainTaskConfig.ckpt_path`` is set, it loads
    the persistables from there. It then wraps the test program in a
    ``ParallelExecutor`` that shares variables with ``train_exe``.

    Relies on a module-level ``args`` that is not a parameter here
    (presumably the parsed command-line arguments; TODO confirm).

    Args:
        exe: executor used to run the test startup program.
        train_exe: the training ``ParallelExecutor`` to share variables from.
        dev_count: number of devices the validation batches are spread over.

    Returns:
        A function ``test()``. Each call runs one full pass over the
        validation data and returns ``(avg_cost, ppl)``, where ``ppl`` is a
        one-element array holding ``exp(min(avg_cost, 100))``.
    """
    # Context to do validation.
    test_prog = fluid.Program()
    startup_prog = fluid.Program()
    if args.enable_ce:
        test_prog.random_seed = 1000
        startup_prog.random_seed = 1000
    with fluid.program_guard(test_prog, startup_prog):
        with fluid.unique_name.guard():
            sum_cost, avg_cost, predict, token_num, pyreader = transformer(
                ModelHyperParams.src_vocab_size,
                ModelHyperParams.trg_vocab_size,
                ModelHyperParams.max_length + 1,
                ModelHyperParams.n_layer,
                ModelHyperParams.n_head,
                ModelHyperParams.d_key,
                ModelHyperParams.d_value,
                ModelHyperParams.d_model,
                ModelHyperParams.d_inner_hid,
                ModelHyperParams.prepostprocess_dropout,
                ModelHyperParams.attention_dropout,
                ModelHyperParams.relu_dropout,
                ModelHyperParams.preprocess_cmd,
                ModelHyperParams.postprocess_cmd,
                ModelHyperParams.weight_sharing,
                TrainTaskConfig.label_smooth_eps,
                use_py_reader=args.use_py_reader,
                is_test=True)
    test_prog = test_prog.clone(for_test=True)
    test_data = prepare_data_generator(
        args,
        is_test=True,
        count=dev_count,
        pyreader=pyreader,
        py_reader_provider_wrapper=py_reader_provider_wrapper)

    exe.run(startup_prog)  # to init pyreader for testing
    if TrainTaskConfig.ckpt_path:
        fluid.io.load_persistables(
            exe, TrainTaskConfig.ckpt_path, main_program=test_prog)

    build_strategy = fluid.BuildStrategy()
    test_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        main_program=test_prog,
        build_strategy=build_strategy,
        share_vars_from=train_exe)

    def test(exe=test_exe, pyreader=pyreader):
        """Run one pass over the validation data; return (avg_cost, ppl)."""
        test_total_cost = 0
        test_total_token = 0

        if args.use_py_reader:
            pyreader.start()
            data_generator = None
        else:
            data_generator = test_data()
        while True:
            try:
                feed_dict_list = prepare_feed_dict_list(data_generator, False,
                                                        dev_count)
                # Honor the executor argument (defaults to test_exe).
                outs = exe.run(fetch_list=[sum_cost.name, token_num.name],
                               feed=feed_dict_list)
            except (StopIteration, fluid.core.EOFException):
                # The current pass is over.
                if args.use_py_reader:
                    pyreader.reset()
                break
            sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
            test_total_cost += sum_cost_val.sum()
            test_total_token += token_num_val.sum()
        test_avg_cost = test_total_cost / test_total_token
        test_ppl = np.exp([min(test_avg_cost, 100)])
        return test_avg_cost, test_ppl

    return test


def train_loop(exe,
               train_prog,
               startup_prog,
               dev_count,
               sum_cost,
               avg_cost,
               token_num,
               predict,
               pyreader,
               nccl2_num_trainers=1,
               nccl2_trainer_id=0):
    """Run the training passes, logging progress and saving checkpoints.

    Relies on a module-level ``args`` that is not a parameter here
    (presumably the parsed command-line arguments; TODO confirm).

    Args:
        exe: executor used for the startup program and checkpoint I/O.
        train_prog: the training main program.
        startup_prog: the training startup program.
        dev_count: number of devices the batches are spread over.
        sum_cost: summed cost output, fetched for logging.
        avg_cost: the loss passed to ``ParallelExecutor`` as ``loss_name``.
        token_num: token-count output, fetched for logging.
        predict: not used here.
        pyreader: the model's py_reader (used when ``args.use_py_reader``).
        nccl2_num_trainers: forwarded to ``ParallelExecutor`` as
            ``num_trainers``.
        nccl2_trainer_id: forwarded to ``ParallelExecutor`` as
            ``trainer_id``.
    """
    # Initialize the parameters.
    if TrainTaskConfig.ckpt_path:
        exe.run(startup_prog)  # to init pyreader for training
        logging.info("load checkpoint from {}".format(
            TrainTaskConfig.ckpt_path))
        fluid.io.load_persistables(
            exe, TrainTaskConfig.ckpt_path, main_program=train_prog)
    else:
        logging.info("init fluid.framework.default_startup_program")
        exe.run(startup_prog)

    logging.info("begin reader")
    train_data = prepare_data_generator(
        args,
        is_test=False,
        count=dev_count,
        pyreader=pyreader,
        py_reader_provider_wrapper=py_reader_provider_wrapper)

    # For faster executor
    exec_strategy = fluid.ExecutionStrategy()
    exec_strategy.num_iteration_per_drop_scope = int(args.fetch_steps)
    build_strategy = fluid.BuildStrategy()
    build_strategy.memory_optimize = False
    build_strategy.enable_inplace = True

    # Presumably keeps these values available on steps where they are not
    # fetched (TODO confirm).
    sum_cost.persistable = True
    token_num.persistable = True
    # Since the token number differs among devices, customize gradient scale to
    # use token average cost among multi-devices. and the gradient scale is
    # `1 / token_number` for average cost.
    # build_strategy.gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized
    build_strategy.fuse_all_optimizer_ops = True

    if num_trainers > 1 and args.use_py_reader and TrainTaskConfig.use_gpu:
        dist_utils.prepare_for_multi_process(exe, build_strategy, train_prog)
        exec_strategy.num_threads = 1

    logging.info("begin executor")
    train_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        loss_name=avg_cost.name,
        main_program=train_prog,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy,
        num_trainers=nccl2_num_trainers,
        trainer_id=nccl2_trainer_id)

    if args.val_file_pattern is not None:
        test = test_context(exe, train_exe, dev_count)

    # The best cross-entropy value with label smoothing: -[(1-eps)log(1-eps)
    # + eps*log(eps/(V-1))], with 1e-20 guarding log(0) when eps == 0.
    eps = TrainTaskConfig.label_smooth_eps
    loss_normalizer = -((1. - eps) * np.log(1. - eps) + eps * np.log(
        eps / (ModelHyperParams.trg_vocab_size - 1) + 1e-20))

    step_idx = 0
    init_flag = True
    logging.info("begin train")
    for pass_id in six.moves.xrange(TrainTaskConfig.pass_num):
        pass_start_time = time.time()

        if args.use_py_reader:
            pyreader.start()
            data_generator = None
        else:
            data_generator = train_data()

        batch_id = 0
        while True:
            try:
                feed_dict_list = prepare_feed_dict_list(data_generator,
                                                        init_flag, dev_count)
                outs = train_exe.run(
                    fetch_list=[sum_cost.name, token_num.name]
                    if step_idx % args.fetch_steps == 0 else [],
                    feed=feed_dict_list)

                if step_idx % args.fetch_steps == 0:
                    sum_cost_val, token_num_val = np.array(outs[0]), np.array(
                        outs[1])
                    # sum the cost from multi-devices
                    total_sum_cost = sum_cost_val.sum()
                    total_token_num = token_num_val.sum()
                    total_avg_cost = total_sum_cost / total_token_num
                    # A scalar, so "%f" does not rely on the (deprecated)
                    # conversion of a 1-element array to float.
                    ppl = np.exp(min(total_avg_cost, 100))

                    if step_idx == 0:
                        logging.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer, ppl))
                        avg_batch_time = time.time()
                    else:
                        logging.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f, speed: %.2f step/s" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer, ppl,
                             args.fetch_steps / (time.time() - avg_batch_time)))
                        avg_batch_time = time.time()

                if step_idx % TrainTaskConfig.save_freq == 0 and step_idx > 0:
                    fluid.io.save_persistables(
                        exe,
                        os.path.join(TrainTaskConfig.ckpt_dir,
                                     "latest.checkpoint"), train_prog)
                    fluid.io.save_params(
                        exe,
                        os.path.join(TrainTaskConfig.model_dir,
                                     "iter_" + str(step_idx) + ".infer.model"),
                        train_prog)

                init_flag = False
                batch_id += 1
                step_idx += 1
            except (StopIteration, fluid.core.EOFException):
                # The current pass is over.
                if args.use_py_reader:
                    pyreader.reset()
                break

        time_consumed = time.time() - pass_start_time
        # Validate and save the persistable.
        if args.val_file_pattern is not None:
            val_avg_cost, val_ppl = test()
            logging.info(
                "epoch: %d, val avg loss: %f, val normalized loss: %f, val ppl: %f,"
                " consumed %fs" % (pass_id, val_avg_cost,
                                   val_avg_cost - loss_normalizer,
                                   np.asarray(val_ppl).item(), time_consumed))
        else:
            logging.info("epoch: %d, consumed %fs" % (pass_id, time_consumed))

        if not args.enable_ce:
            fluid.io.save_persistables(
                exe,
                os.path.join(TrainTaskConfig.ckpt_dir,
                             "pass_" + str(pass_id) + ".checkpoint"),
                train_prog)

    if args.enable_ce:  # For CE
        print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost))
        if args.val_file_pattern is not None:
            print("kpis\ttest_cost_card%d\t%f" % (dev_count, val_avg_cost))
        print("kpis\ttrain_duration_card%d\t%f" % (dev_count, time_consumed))


def train(args):
    """Build the Transformer training program and run it locally or distributed.

    Configuration priority is ENV > args > config module:

    * ``PADDLE_IS_LOCAL == "0"`` forces distributed mode (``args.local = False``).
    * ``TRAINING_ROLE`` (default "TRAINER", or "PSERVER") selects the role.
    * In distributed mode, ``args.update_method == "nccl2"`` selects NCCL2
      collective training; otherwise parameter-server training through
      ``fluid.DistributeTranspiler`` is used.

    Args:
        args: parsed command-line namespace. Fields read here: ``local``,
            ``device``, ``enable_ce``, ``use_py_reader``, ``sync`` and
            ``update_method``.

    Exits the process with status 1 when a required distributed-mode
    environment variable is missing, or when ``TRAINING_ROLE`` is neither
    "TRAINER" nor "PSERVER" in parameter-server mode.
    """
    # priority: ENV > args > config
    is_local = os.getenv("PADDLE_IS_LOCAL", "1")
    if is_local == '0':
        args.local = False
    logging.info(args)

    if args.device == 'CPU':
        TrainTaskConfig.use_gpu = False

    training_role = os.getenv("TRAINING_ROLE", "TRAINER")

    # Parameter servers always run on CPU.
    if training_role == "PSERVER" or (not TrainTaskConfig.use_gpu):
        place = fluid.CPUPlace()
        # the default setting of CPU_NUM in paddle framework is 1
        dev_count = int(os.environ.get('CPU_NUM', 1))
    else:
        check_cuda(TrainTaskConfig.use_gpu)
        gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
        place = fluid.CUDAPlace(gpu_id)
        dev_count = get_device_num()

    exe = fluid.Executor(place)

    train_prog = fluid.Program()
    startup_prog = fluid.Program()

    if args.enable_ce:
        # Fixed seeds so continuous-evaluation (CE) runs are reproducible.
        train_prog.random_seed = 1000
        startup_prog.random_seed = 1000

    with fluid.program_guard(train_prog, startup_prog):
        with fluid.unique_name.guard():
            sum_cost, avg_cost, predict, token_num, pyreader = transformer(
                ModelHyperParams.src_vocab_size,
                ModelHyperParams.trg_vocab_size,
                ModelHyperParams.max_length + 1,
                ModelHyperParams.n_layer,
                ModelHyperParams.n_head,
                ModelHyperParams.d_key,
                ModelHyperParams.d_value,
                ModelHyperParams.d_model,
                ModelHyperParams.d_inner_hid,
                ModelHyperParams.prepostprocess_dropout,
                ModelHyperParams.attention_dropout,
                ModelHyperParams.relu_dropout,
                ModelHyperParams.preprocess_cmd,
                ModelHyperParams.postprocess_cmd,
                ModelHyperParams.weight_sharing,
                TrainTaskConfig.label_smooth_eps,
                ModelHyperParams.bos_idx,
                use_py_reader=args.use_py_reader,
                is_test=False)

            if args.sync:
                # Noam warmup/decay schedule scaled by the configured base LR.
                lr_decay = fluid.layers.learning_rate_scheduler.noam_decay(
                    ModelHyperParams.d_model, TrainTaskConfig.warmup_steps)
                logging.info("before adam")

                # The scaling op is created under the program's LR-schedule
                # guard (a private fluid API), like the noam_decay ops.
                with fluid.default_main_program()._lr_schedule_guard():
                    learning_rate = lr_decay * TrainTaskConfig.learning_rate

                optimizer = fluid.optimizer.Adam(
                    learning_rate=learning_rate,
                    beta1=TrainTaskConfig.beta1,
                    beta2=TrainTaskConfig.beta2,
                    epsilon=TrainTaskConfig.eps)
            else:
                optimizer = fluid.optimizer.SGD(0.003)
            optimizer.minimize(avg_cost)

    if args.local:
        logging.info("local start_up:")
        train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
                   token_num, predict, pyreader)
        return

    if args.update_method == "nccl2":
        trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        port = os.getenv("PADDLE_PORT")
        worker_ips = os.getenv("PADDLE_TRAINERS")
        pod_ip = os.getenv("POD_IP")
        # Fail with a clear message instead of a TypeError on None below.
        if not (port and worker_ips and pod_ip):
            logging.critical(
                "need env PADDLE_PORT, PADDLE_TRAINERS and POD_IP")
            sys.exit(1)
        worker_endpoints = [':'.join([ip, port]) for ip in worker_ips.split(",")]
        trainers_num = len(worker_endpoints)
        current_endpoint = pod_ip + ":" + port
        if trainer_id == 0:
            # NOTE(review): presumably gives the other workers time to come
            # up before trainer 0 starts; confirm the intent.
            logging.info("train_id == 0, sleep 60s")
            time.sleep(60)
        logging.info("trainers_num:{}".format(trainers_num))
        logging.info("worker_endpoints:{}".format(worker_endpoints))
        logging.info("current_endpoint:{}".format(current_endpoint))
        append_nccl2_prepare(startup_prog, trainer_id, worker_endpoints,
                             current_endpoint)
        train_loop(exe, train_prog, startup_prog, dev_count, sum_cost,
                   avg_cost, token_num, predict, pyreader, trainers_num,
                   trainer_id)
        return

    # Parameter-server mode.
    port = os.getenv("PADDLE_PORT", "6174")
    pserver_ips = os.getenv("PADDLE_PSERVERS")  # ip,ip...
    pod_ip = os.getenv("POD_IP")
    if not (pserver_ips and pod_ip):
        logging.critical("need env PADDLE_PSERVERS and POD_IP")
        sys.exit(1)
    eplist = [':'.join([ip, port]) for ip in pserver_ips.split(",")]
    pserver_endpoints = ",".join(eplist)  # ip:port,ip:port...
    trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
    # Built with the same port as every entry of pserver_endpoints, so a
    # pserver's own endpoint matches one listed there.
    current_endpoint = pod_ip + ":" + port
    trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))

    logging.info("pserver_endpoints:{}".format(pserver_endpoints))
    logging.info("current_endpoint:{}".format(current_endpoint))
    logging.info("trainer_id:{}".format(trainer_id))
    logging.info("pserver_ips:{}".format(pserver_ips))
    logging.info("port:{}".format(port))

    t = fluid.DistributeTranspiler()
    t.transpile(
        trainer_id,
        pservers=pserver_endpoints,
        trainers=trainers,
        program=train_prog,
        startup_program=startup_prog)

    if training_role == "PSERVER":
        logging.info("distributed: pserver started")
        pserver_prog = t.get_pserver_program(current_endpoint)
        pserver_startup = t.get_startup_program(current_endpoint,
                                                pserver_prog)

        exe.run(pserver_startup)
        exe.run(pserver_prog)
    elif training_role == "TRAINER":
        logging.info("distributed: trainer started")
        # The returned program was never used by the original code either;
        # train_prog (the program passed to transpile) is what gets trained.
        # NOTE(review): confirm get_trainer_program() rewrites train_prog in
        # place for this Paddle version.
        t.get_trainer_program()

        train_loop(exe, train_prog, startup_prog, dev_count, sum_cost,
                   avg_cost, token_num, predict, pyreader)
    else:
        logging.critical(
            "environment var TRAINING_ROLE should be TRAINER or PSERVER")
        sys.exit(1)
806 807 808


if __name__ == "__main__":
    # Send every log record to stdout, prefixed with time, level and call site.
    LOG_FORMAT = ("[%(asctime)s %(levelname)s %(filename)s:%(lineno)d] "
                  "%(message)s")
    logging.basicConfig(format=LOG_FORMAT,
                        level=logging.DEBUG,
                        stream=sys.stdout)
    # basicConfig installed the stdout handler at DEBUG; raise the root
    # logger's threshold to INFO afterwards.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    # `args` stays a module-level global: train_loop reads it directly.
    args = parse_args()
    train(args)