train.py 30.7 KB
Newer Older
1 2
import argparse
import ast
G
guosheng 已提交
3 4
import copy
import logging
5
import multiprocessing
Y
Yu Yang 已提交
6
import os
G
guosheng 已提交
7
import six
G
guosheng 已提交
8
import sys
Y
Yibing Liu 已提交
9
sys.path.append("../../models/neural_machine_translation/transformer/")
Y
Yu Yang 已提交
10
import time
Y
ying 已提交
11

Y
Yu Yang 已提交
12
import numpy as np
L
Luo Tao 已提交
13
import paddle.fluid as fluid
Y
ying 已提交
14

Y
Yu Yang 已提交
15 16
import reader
from config import *
Y
Yibing Liu 已提交
17
from desc import *
18
from model import transformer, position_encoding_init
19 20 21
import dist_utils

# Number of trainer processes, taken from the PADDLE_TRAINERS_NUM environment
# variable; 1 when the variable is not set.
num_trainers = int(os.getenv('PADDLE_TRAINERS_NUM', 1))


def parse_args():
    """Parse the command line and merge vocabulary-derived values into config.

    After parsing, the source and target vocabularies are loaded with
    ``reader.DataReader.load_dict``. Their sizes and the ids of the three
    ``--special_token`` entries (all looked up in the source vocabulary) are
    appended after the free-form ``opts`` overrides. The combined list is
    merged into ``TrainTaskConfig`` and ``ModelHyperParams`` through
    ``merge_cfg_from_list``.

    Returns:
        The ``argparse.Namespace`` with the parsed flags.
    """
    arg_parser = argparse.ArgumentParser("Training for Transformer.")
    add_arg = arg_parser.add_argument

    # Data files and vocabularies.
    add_arg("--src_vocab_fpath", type=str, required=True,
            help="The path of vocabulary file of source language.")
    add_arg("--trg_vocab_fpath", type=str, required=True,
            help="The path of vocabulary file of target language.")
    add_arg("--train_file_pattern", type=str, required=True,
            help="The pattern to match training data files.")
    add_arg("--val_file_pattern", type=str,
            help="The pattern to match validation data files.")

    # Batching, pooling and shuffling.
    add_arg("--use_token_batch", type=ast.literal_eval, default=True,
            help="The flag indicating whether to "
            "produce batch data according to token number.")
    add_arg("--batch_size", type=int, default=4096,
            help="The number of sequences contained in a mini-batch, or the maximum "
            "number of tokens (include paddings) contained in a mini-batch. Note "
            "that this represents the number on single device and the actual batch "
            "size for multi-devices will multiply the device number.")
    add_arg("--pool_size", type=int, default=200000,
            help="The buffer size to pool data.")
    add_arg("--sort_type", default="pool", choices=("global", "pool", "none"),
            help="The grain to sort by length: global for all instances; pool for "
            "instances in pool; none for no sort.")
    add_arg("--shuffle", type=ast.literal_eval, default=True,
            help="The flag indicating whether to shuffle instances in each pass.")
    add_arg("--shuffle_batch", type=ast.literal_eval, default=True,
            help="The flag indicating whether to shuffle the data batches.")

    # Tokenization.
    add_arg("--special_token", type=str, default=["<s>", "<e>", "<unk>"],
            nargs=3,
            help="The <bos>, <eos> and <unk> tokens in the dictionary.")
    add_arg("--token_delimiter",
            # Decode backslash escapes (e.g. "\t") typed on the command line.
            type=lambda x: str(x.encode().decode("unicode-escape")),
            default=" ",
            help="The delimiter used to split tokens in source or target sentences. "
            "For EN-DE BPE data we provided, use spaces as token delimiter. ")

    # Trailing "KEY VALUE ..." pairs overriding entries defined in config.py.
    add_arg('opts', help='See config.py for all options', default=None,
            nargs=argparse.REMAINDER)

    # Runtime and distributed settings.
    add_arg('--local', type=ast.literal_eval, default=True,
            help='Whether to run as local mode.')
    add_arg('--device', type=str, default='GPU', choices=['CPU', 'GPU'],
            help="The device type.")
    add_arg('--update_method', choices=("pserver", "nccl2"), default="pserver",
            help='Update method.')
    add_arg('--sync', type=ast.literal_eval, default=True, help="sync mode.")
    add_arg("--enable_ce", type=ast.literal_eval, default=False,
            help="The flag indicating whether to run the task "
            "for continuous evaluation.")
    add_arg("--use_mem_opt", type=ast.literal_eval, default=True,
            help="The flag indicating whether to use memory optimization.")
    add_arg("--use_py_reader", type=ast.literal_eval, default=True,
            help="The flag indicating whether to use py_reader.")
    add_arg("--fetch_steps", type=int, default=100,
            help="The frequency to fetch and print output.")

    args = arg_parser.parse_args()

    src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
    trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
    bos_token, eos_token, unk_token = args.special_token
    vocab_overrides = [
        ("src_vocab_size", len(src_dict)),
        ("trg_vocab_size", len(trg_dict)),
        # All special-token ids come from the source vocabulary.
        ("bos_idx", src_dict[bos_token]),
        ("eos_idx", src_dict[eos_token]),
        ("unk_idx", src_dict[unk_token]),
    ]
    dict_args = []
    for key, value in vocab_overrides:
        dict_args += [key, str(value)]
    merge_cfg_from_list(args.opts + dict_args,
                        [TrainTaskConfig, ModelHyperParams])
    return args


def get_device_num():
    """Return the number of devices this process should use.

    With multi-process training (``num_trainers > 1``) each process drives a
    single GPU, so 1 is returned. Otherwise the count is taken from
    ``CUDA_VISIBLE_DEVICES`` when that variable is set and non-empty.
    Failing that, it is the number of lines printed by ``nvidia-smi -L``
    (assumed to be one line per GPU).
    """
    # NOTE(zcd): for multi-process training, each process uses one GPU card.
    if num_trainers > 1:
        return 1
    visible_device = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    if visible_device:
        return len(visible_device.split(','))
    # `subprocess` is not among the module-level imports, so import it here;
    # without this the nvidia-smi fallback raises NameError.
    import subprocess
    return subprocess.check_output(['nvidia-smi', '-L']).decode().count('\n')


def append_nccl2_prepare(startup_prog, trainer_id, worker_endpoints,
                         current_endpoint):
    """Append a ``gen_nccl_id`` op writing a ``NCCLID`` var to ``startup_prog``.

    Args:
        startup_prog: program whose global block receives the var and op.
        trainer_id: id of this trainer; must be non-negative.
        worker_endpoints: endpoints of all trainers; must hold more than one
            entry and include ``current_endpoint``.
        current_endpoint: this trainer's own endpoint.

    Returns:
        The persistable RAW ``NCCLID`` variable that the op outputs.
    """
    assert (trainer_id >= 0 and len(worker_endpoints) > 1 and
            current_endpoint in worker_endpoints)

    # Every endpoint except this trainer's own; passed as "endpoint_list".
    other_endpoints = copy.deepcopy(worker_endpoints)
    other_endpoints.remove(current_endpoint)

    global_blk = startup_prog.global_block()
    nccl_id = global_blk.create_var(
        name="NCCLID", persistable=True, type=fluid.core.VarDesc.VarType.RAW)
    op_attrs = {
        "endpoint": current_endpoint,
        "endpoint_list": other_endpoints,
        "trainer_id": trainer_id,
    }
    global_blk.append_op(
        type="gen_nccl_id",
        inputs={},
        outputs={"NCCLID": nccl_id},
        attrs=op_attrs)
    return nccl_id

def pad_batch_data(insts,
                   pad_idx,
                   n_head,
                   is_target=False,
                   is_label=False,
                   return_attn_bias=True,
                   return_max_len=True,
                   return_num_token=False):
    """
    Pad every instance in ``insts`` to the longest length in the batch.

    Args:
        insts: list of token-id lists (padding uses list concatenation).
        pad_idx: id appended to the shorter instances.
        n_head: number of attention heads the bias is tiled across.
        is_target: build a bias that blocks later positions (strict upper
            triangle) instead of the padding-only bias.
        is_label: emit per-token float weights (1 real, 0 pad) instead of
            position ids.
        return_attn_bias, return_max_len, return_num_token: toggle the
            optional trailing outputs.

    Returns:
        A list in this order:
          * token ids, int64, shape ``[len(insts) * max_len, 1]``;
          * label weights (float32) if ``is_label`` else position ids
            (int64), same shape;
          * optionally the attention bias, float32,
            shape ``[len(insts), n_head, max_len, max_len]``;
          * optionally ``max_len``;
          * optionally the total number of unpadded tokens.
        A single-item result is returned unwrapped.
    """
    seq_lens = [len(inst) for inst in insts]
    max_len = max(seq_lens)
    pad_counts = [max_len - n for n in seq_lens]
    outputs = []

    # Any in-vocabulary id can serve as padding: padded positions are masked
    # out by the weights / bias, so they do not influence the gradients.
    word_ids = np.array(
        [inst + [pad_idx] * k for inst, k in zip(insts, pad_counts)])
    outputs.append(word_ids.astype("int64").reshape([-1, 1]))

    if is_label:
        weights = np.array(
            [[1.] * n + [0.] * k for n, k in zip(seq_lens, pad_counts)])
        outputs.append(weights.astype("float32").reshape([-1, 1]))
    else:
        positions = np.array(
            [list(range(n)) + [0] * k for n, k in zip(seq_lens, pad_counts)])
        outputs.append(positions.astype("int64").reshape([-1, 1]))

    if return_attn_bias:
        if is_target:
            # Ones strictly above the diagonal, scaled by -1e9, so each
            # position is blocked from attending to later positions.
            future = np.triu(np.ones((word_ids.shape[0], max_len, max_len)), 1)
            future = future.reshape([-1, 1, max_len, max_len])
            attn_bias = np.tile(future, [1, n_head, 1, 1]) * [-1e9]
        else:
            # -1e9 on padded key positions, repeated for every head and query.
            key_bias = np.array(
                [[0] * n + [-1e9] * k for n, k in zip(seq_lens, pad_counts)])
            attn_bias = np.tile(key_bias.reshape([-1, 1, 1, max_len]),
                                [1, n_head, max_len, 1])
        outputs.append(attn_bias.astype("float32"))

    if return_max_len:
        outputs.append(max_len)
    if return_num_token:
        outputs.append(sum(seq_lens))
    return outputs if len(outputs) > 1 else outputs[0]


def prepare_batch_input(insts, data_input_names, src_pad_idx, trg_pad_idx,
                        n_head, d_model):
    """
    Pad a batch and map the resulting arrays onto the model's input names.

    Args:
        insts: sequence of instances; ``inst[0]``, ``inst[1]`` and ``inst[2]``
            are the source, target and label token-id lists.
        data_input_names: names zipped, in order, with: source words,
            source positions, source self-attention bias, target words,
            target positions, target self-attention bias, target-to-source
            attention bias, label words and label weights.
        src_pad_idx: pad id for the source side.
        trg_pad_idx: pad id for the target and label sides.
        n_head: number of attention heads.
        d_model: not used by this function.

    Returns:
        ``(feed_dict, np.asarray([num_label_tokens], dtype="float32"))``.
    """
    src_seqs = [inst[0] for inst in insts]
    trg_seqs = [inst[1] for inst in insts]
    lbl_seqs = [inst[2] for inst in insts]

    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        src_seqs, src_pad_idx, n_head, is_target=False)
    src_word = src_word.reshape(-1, src_max_len, 1)
    src_pos = src_pos.reshape(-1, src_max_len, 1)

    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
        trg_seqs, trg_pad_idx, n_head, is_target=True)
    trg_word = trg_word.reshape(-1, trg_max_len, 1)
    trg_pos = trg_pos.reshape(-1, trg_max_len, 1)

    # Take the first query row of the source padding bias and repeat it for
    # every target position.
    first_row_bias = src_slf_attn_bias[:, :, ::src_max_len, :]
    trg_src_attn_bias = np.tile(first_row_bias,
                                [1, 1, trg_max_len, 1]).astype("float32")

    lbl_word, lbl_weight, num_token = pad_batch_data(
        lbl_seqs,
        trg_pad_idx,
        n_head,
        is_target=False,
        is_label=True,
        return_attn_bias=False,
        return_max_len=False,
        return_num_token=True)

    feed_values = [
        src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
        trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
    ]
    data_input_dict = dict(zip(data_input_names, feed_values))
    return data_input_dict, np.asarray([num_token], dtype="float32")


def prepare_data_generator(args,
                           is_test,
                           count,
                           pyreader,
                           py_reader_provider_wrapper,
                           place=None):
    """
    Data generator wrapper for DataReader. If use py_reader, set the data
    provider for py_reader.

    Args:
        args: parsed command-line flags (data files, batching, shuffling,
            special tokens, ``use_py_reader``).
        is_test: read ``args.val_file_pattern`` instead of
            ``args.train_file_pattern``.
        count: number of devices the data is distributed over.
        pyreader: py_reader that receives the tensor provider when
            ``args.use_py_reader`` is set.
        py_reader_provider_wrapper: callable ``(data_reader, place)`` turning
            a batch reader into a py_reader tensor provider.
        place: forwarded to ``py_reader_provider_wrapper``.

    Returns:
        ``None`` when ``args.use_py_reader`` is set (data flows through
        ``pyreader``); otherwise a reader function whose generator yields
        lists of ``count`` batches, one per device.
    """
    # NOTE: If num_trainers > 1, the shuffle_seed must be set, because
    # the order of batch data generated by reader
    # must be the same in the respective processes.
    shuffle_seed = 1 if num_trainers > 1 else None
    data_reader = reader.DataReader(
        fpattern=args.val_file_pattern if is_test else args.train_file_pattern,
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        token_delimiter=args.token_delimiter,
        use_token_batch=args.use_token_batch,
        batch_size=args.batch_size * (1 if args.use_token_batch else count),
        pool_size=args.pool_size,
        sort_type=args.sort_type,
        shuffle=args.shuffle,
        shuffle_seed=shuffle_seed,
        shuffle_batch=args.shuffle_batch,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2],
        # count start and end tokens out
        max_length=ModelHyperParams.max_length - 2,
        clip_last_batch=False).batch_generator

    def stack(data_reader, count, clip_last=True):
        """Group consecutive batches into lists of ``count`` (one per device).

        A trailing group with fewer than ``count`` batches is dropped when
        ``clip_last`` is true. Otherwise its instances are pooled and re-cut
        into ``count`` equal parts (remainder dropped), provided there are
        more than ``count`` of them.
        """

        def __impl__():
            res = []
            for item in data_reader():
                res.append(item)
                if len(res) == count:
                    yield res
                    res = []
            if len(res) == count:
                yield res
            elif not clip_last:
                data = []
                for item in res:
                    data += item
                if len(data) > count:
                    inst_num_per_part = len(data) // count
                    yield [
                        data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
                        for i in range(count)
                    ]

        return __impl__

    def split(data_reader, count):
        """Cut every batch into ``count`` equal, contiguous parts.

        Instances past ``count * (len(batch) // count)`` are dropped. A batch
        with fewer than ``count`` instances is skipped entirely, because
        slicing it would produce empty parts and ``pad_batch_data`` cannot pad
        an empty part (``max()`` of an empty sequence raises).
        """

        def __impl__():
            for item in data_reader():
                inst_num_per_part = len(item) // count
                if inst_num_per_part == 0:
                    continue
                for i in range(count):
                    yield item[inst_num_per_part * i:inst_num_per_part * (i + 1)]

        return __impl__

    if not args.use_token_batch:
        # to make data on each device have similar token number
        data_reader = split(data_reader, count)
    if args.use_py_reader:
        train_reader = py_reader_provider_wrapper(data_reader, place)
        if num_trainers > 1:
            assert shuffle_seed is not None
            train_reader = fluid.contrib.reader.distributed_batch_reader(
                train_reader)
        pyreader.decorate_tensor_provider(train_reader)
        data_reader = None
    else:  # Data generator for multi-devices
        data_reader = stack(data_reader, count)
    return data_reader


def prepare_feed_dict_list(data_generator, init_flag, count):
    """
    Build the list of per-device feed dicts.

    Args:
        data_generator: generator whose next item is iterated, each element
            being one device's instances for ``prepare_batch_input``; ``None``
            when py_reader supplies the data.
        init_flag: when true, also feed a freshly computed position-encoding
            table for every name in ``pos_enc_param_names``. Batch data wins
            on a name clash.
        count: number of devices.

    Returns:
        A list of ``count`` feed dicts, or ``None`` when a different number
        of dicts was produced (e.g. nothing to feed).

    Raises:
        StopIteration: from ``next(data_generator)`` once it is exhausted.
    """
    feeds = []
    if data_generator is not None:  # use_py_reader == False
        input_names = (encoder_data_input_fields +
                       decoder_data_input_fields[:-1] +
                       label_data_input_fields)
        for per_device_insts in next(data_generator):
            feed, _ = prepare_batch_input(
                per_device_insts, input_names, ModelHyperParams.eos_idx,
                ModelHyperParams.eos_idx, ModelHyperParams.n_head,
                ModelHyperParams.d_model)
            feeds.append(feed)
    if init_flag:
        for dev_idx in range(count):
            pos_enc_tables = {}
            for param_name in pos_enc_param_names:
                pos_enc_tables[param_name] = position_encoding_init(
                    ModelHyperParams.max_length + 1, ModelHyperParams.d_model)
            if dev_idx < len(feeds):
                merged = dict(pos_enc_tables)
                merged.update(feeds[dev_idx])
                feeds[dev_idx] = merged
            else:
                feeds.append(pos_enc_tables)
    return feeds if len(feeds) == count else None


def py_reader_provider_wrapper(data_reader, place):
    """
    Turn a batch reader into the tensor provider consumed by py_reader.

    Args:
        data_reader: callable returning an iterable of batches (lists of
            instances accepted by ``prepare_batch_input``).
        place: not used by this function.

    Returns:
        A generator function. Each item it yields is the list of padded
        arrays for one batch, ordered like the encoder fields, the decoder
        fields without the last one, and the label fields.
    """

    def py_reader_provider():
        field_names = (encoder_data_input_fields +
                       decoder_data_input_fields[:-1] +
                       label_data_input_fields)
        for batch in data_reader():
            feed, _ = prepare_batch_input(
                batch, field_names, ModelHyperParams.eos_idx,
                ModelHyperParams.eos_idx, ModelHyperParams.n_head,
                ModelHyperParams.d_model)
            yield [feed[name] for name in field_names]

    return py_reader_provider


def test_context(exe, train_exe, dev_count):
    """
    Build the validation program and return a closure that evaluates it.

    Reads the module-level ``args``. The model is rebuilt with
    ``is_test=True`` in a fresh program, cloned for testing, and run by a
    ``ParallelExecutor`` that shares variables with ``train_exe``.

    Args:
        exe: executor that runs the test startup program and, when
            ``TrainTaskConfig.ckpt_path`` is set, loads persistables.
        train_exe: training ``ParallelExecutor`` to share variables with.
        dev_count: number of devices the validation batches are split over.

    Returns:
        ``test(exe=test_exe, pyreader=pyreader)``, which runs one pass over
        the validation data and returns ``(avg_cost, ppl)``.
    """
    # Context to do validation.
    test_prog = fluid.Program()
    startup_prog = fluid.Program()
    if args.enable_ce:
        test_prog.random_seed = 1000
        startup_prog.random_seed = 1000
    with fluid.program_guard(test_prog, startup_prog):
        with fluid.unique_name.guard():
            sum_cost, avg_cost, predict, token_num, pyreader = transformer(
                ModelHyperParams.src_vocab_size,
                ModelHyperParams.trg_vocab_size,
                ModelHyperParams.max_length + 1,
                ModelHyperParams.n_layer,
                ModelHyperParams.n_head,
                ModelHyperParams.d_key,
                ModelHyperParams.d_value,
                ModelHyperParams.d_model,
                ModelHyperParams.d_inner_hid,
                ModelHyperParams.prepostprocess_dropout,
                ModelHyperParams.attention_dropout,
                ModelHyperParams.relu_dropout,
                ModelHyperParams.preprocess_cmd,
                ModelHyperParams.postprocess_cmd,
                ModelHyperParams.weight_sharing,
                TrainTaskConfig.label_smooth_eps,
                use_py_reader=args.use_py_reader,
                is_test=True)
    test_prog = test_prog.clone(for_test=True)
    test_data = prepare_data_generator(
        args,
        is_test=True,
        count=dev_count,
        pyreader=pyreader,
        py_reader_provider_wrapper=py_reader_provider_wrapper)

    exe.run(startup_prog)  # to init pyreader for testing
    if TrainTaskConfig.ckpt_path:
        fluid.io.load_persistables(
            exe, TrainTaskConfig.ckpt_path, main_program=test_prog)

    build_strategy = fluid.BuildStrategy()
    test_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        main_program=test_prog,
        build_strategy=build_strategy,
        share_vars_from=train_exe)

    def test(exe=test_exe, pyreader=pyreader):
        """Run one validation pass with ``exe``; return ``(avg_cost, ppl)``.

        ``avg_cost`` is the summed cost divided by the summed token count.
        ``ppl`` is ``exp`` of it, with the exponent capped at 100 (presumably
        to avoid overflow), returned as a one-element array.
        """
        test_total_cost = 0
        test_total_token = 0

        if args.use_py_reader:
            pyreader.start()
            data_generator = None
        else:
            data_generator = test_data()
        while True:
            try:
                feed_dict_list = prepare_feed_dict_list(data_generator, False,
                                                        dev_count)
                # Run through the `exe` argument (defaults to `test_exe`) so
                # callers can supply a different executor.
                outs = exe.run(fetch_list=[sum_cost.name, token_num.name],
                               feed=feed_dict_list)
            except (StopIteration, fluid.core.EOFException):
                # The current pass is over.
                if args.use_py_reader:
                    pyreader.reset()
                break
            sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
            test_total_cost += sum_cost_val.sum()
            test_total_token += token_num_val.sum()
        test_avg_cost = test_total_cost / test_total_token
        test_ppl = np.exp([min(test_avg_cost, 100)])
        return test_avg_cost, test_ppl

    return test


def train_loop(exe,
               train_prog,
               startup_prog,
               dev_count,
               sum_cost,
               avg_cost,
               token_num,
               predict,
               pyreader,
               nccl2_num_trainers=1,
               nccl2_trainer_id=0):
    """
    Initialize parameters, build the parallel executor, and run training.

    Reads the module-level ``args``. Each pass reads the training data,
    either through ``pyreader`` (``args.use_py_reader``) or through the
    generator from ``prepare_data_generator``. Loss and perplexity are
    logged every ``args.fetch_steps`` steps. A checkpoint and an inference
    model are saved every ``TrainTaskConfig.save_freq`` steps, and a pass
    checkpoint is saved after each pass unless ``args.enable_ce`` is set.
    When ``args.val_file_pattern`` is set, validation runs after each pass.

    Args:
        exe: executor used for startup, checkpoint loading and saving.
        train_prog: training main program.
        startup_prog: training startup program.
        dev_count: number of devices (one feed dict per device).
        sum_cost, token_num: fetched to compute the logged average loss.
        avg_cost: passed to the executor as ``loss_name``.
        predict: not used by this function.
        pyreader: the py_reader built together with the model.
        nccl2_num_trainers, nccl2_trainer_id: forwarded to
            ``fluid.ParallelExecutor``.
    """
    # Initialize the parameters.
    if TrainTaskConfig.ckpt_path:
        exe.run(startup_prog)  # to init pyreader for training
        logging.info("load checkpoint from {}".format(
            TrainTaskConfig.ckpt_path))
        fluid.io.load_persistables(
            exe, TrainTaskConfig.ckpt_path, main_program=train_prog)
    else:
        logging.info("init fluid.framework.default_startup_program")
        exe.run(startup_prog)

    logging.info("begin reader")
    train_data = prepare_data_generator(
        args,
        is_test=False,
        count=dev_count,
        pyreader=pyreader,
        py_reader_provider_wrapper=py_reader_provider_wrapper)

    # For faster executor
    exec_strategy = fluid.ExecutionStrategy()
    exec_strategy.num_iteration_per_drop_scope = int(args.fetch_steps)
    build_strategy = fluid.BuildStrategy()
    build_strategy.memory_optimize = False
    build_strategy.enable_inplace = True

    # NOTE(review): presumably made persistable so the values survive the
    # periodic scope drops configured above; confirm.
    sum_cost.persistable = True
    token_num.persistable = True
    # Token counts differ among devices. A Customized gradient scale
    # (`1 / token_number`) would give a token-averaged cost across devices.
    # It is currently disabled:
    # build_strategy.gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized
    build_strategy.fuse_all_optimizer_ops = True

    if num_trainers > 1 and args.use_py_reader and TrainTaskConfig.use_gpu:
        dist_utils.prepare_for_multi_process(exe, build_strategy, train_prog)
        exec_strategy.num_threads = 1

    logging.info("begin executor")
    train_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        loss_name=avg_cost.name,
        main_program=train_prog,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy,
        num_trainers=nccl2_num_trainers,
        trainer_id=nccl2_trainer_id)

    if args.val_file_pattern is not None:
        test = test_context(exe, train_exe, dev_count)

    # The best achievable cross-entropy under label smoothing; it is
    # subtracted from the loss to report a "normalized loss".
    eps = TrainTaskConfig.label_smooth_eps
    loss_normalizer = -((1. - eps) * np.log(1. - eps) + eps * np.log(
        eps / (ModelHyperParams.trg_vocab_size - 1) + 1e-20))

    step_idx = 0
    init_flag = True
    # Values reported for continuous evaluation. They stay None when no
    # step is fetched, no pass runs, or no validation happens.
    total_avg_cost = None
    val_avg_cost = None
    time_consumed = None

    logging.info("begin train")
    for pass_id in six.moves.xrange(TrainTaskConfig.pass_num):
        pass_start_time = time.time()

        if args.use_py_reader:
            pyreader.start()
            data_generator = None
        else:
            data_generator = train_data()

        batch_id = 0
        while True:
            fetch_this_step = step_idx % args.fetch_steps == 0
            try:
                feed_dict_list = prepare_feed_dict_list(data_generator,
                                                        init_flag, dev_count)
                outs = train_exe.run(
                    fetch_list=[sum_cost.name, token_num.name]
                    if fetch_this_step else [],
                    feed=feed_dict_list)
            except (StopIteration, fluid.core.EOFException):
                # The current pass is over.
                if args.use_py_reader:
                    pyreader.reset()
                break

            if fetch_this_step:
                sum_cost_val = np.array(outs[0])
                token_num_val = np.array(outs[1])
                # sum the cost from multi-devices
                total_sum_cost = sum_cost_val.sum()
                total_token_num = token_num_val.sum()
                total_avg_cost = total_sum_cost / total_token_num
                # Scalar (not a 1-element array) so "%f" formatting does not
                # rely on NumPy's deprecated array-to-scalar conversion.
                ppl = np.exp(min(total_avg_cost, 100))

                if step_idx == 0:
                    logging.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f" %
                        (step_idx, pass_id, batch_id, total_avg_cost,
                         total_avg_cost - loss_normalizer, ppl))
                else:
                    logging.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f, speed: %.2f step/s" %
                        (step_idx, pass_id, batch_id, total_avg_cost,
                         total_avg_cost - loss_normalizer, ppl,
                         args.fetch_steps / (time.time() - avg_batch_time)))
                avg_batch_time = time.time()

            if step_idx % TrainTaskConfig.save_freq == 0 and step_idx > 0:
                fluid.io.save_persistables(
                    exe,
                    os.path.join(TrainTaskConfig.ckpt_dir,
                                 "latest.checkpoint"), train_prog)
                fluid.io.save_params(
                    exe,
                    os.path.join(TrainTaskConfig.model_dir,
                                 "iter_" + str(step_idx) + ".infer.model"),
                    train_prog)

            init_flag = False
            batch_id += 1
            step_idx += 1

        time_consumed = time.time() - pass_start_time
        # Validate and save the persistable.
        if args.val_file_pattern is not None:
            val_avg_cost, val_ppl = test()
            logging.info(
                "epoch: %d, val avg loss: %f, val normalized loss: %f, val ppl: %f,"
                " consumed %fs" % (pass_id, val_avg_cost,
                                   val_avg_cost - loss_normalizer,
                                   np.asarray(val_ppl).item(), time_consumed))
        else:
            logging.info("epoch: %d, consumed %fs" % (pass_id, time_consumed))

        if not args.enable_ce:
            fluid.io.save_persistables(
                exe,
                os.path.join(TrainTaskConfig.ckpt_dir,
                             "pass_" + str(pass_id) + ".checkpoint"),
                train_prog)

    if args.enable_ce:  # For CE
        # Only report values that were actually produced; printing them
        # unguarded raised NameError when no pass/fetch ran.
        if total_avg_cost is not None:
            print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost))
        if val_avg_cost is not None:
            print("kpis\ttest_cost_card%d\t%f" % (dev_count, val_avg_cost))
        if time_consumed is not None:
            print("kpis\ttrain_duration_card%d\t%f" % (dev_count, time_consumed))
def _require_env(name):
    """Return the value of environment variable ``name`` or abort.

    Missing (or empty) variables are reported with ``logging.critical`` and
    the process exits with status 1, instead of failing later with an opaque
    ``TypeError``/``AttributeError`` when ``None`` is concatenated or split.
    """
    value = os.getenv(name)
    if not value:
        logging.critical("need env %s", name)
        sys.exit(1)
    return value


def train(args):
    """Build the Transformer training program and run it.

    The run mode is chosen from ``args`` and the environment:

    * local (``args.local`` true): run ``train_loop`` in this process only;
    * NCCL2 (``args.update_method == "nccl2"``): collective training; the
      trainers are read from ``PADDLE_TRAINERS`` / ``PADDLE_PORT`` /
      ``POD_IP`` / ``PADDLE_TRAINER_ID``;
    * parameter server: the program is transpiled with
      ``fluid.DistributeTranspiler`` and this process acts as a pserver or
      a trainer according to ``TRAINING_ROLE``.

    Args:
        args: parsed command-line arguments (see ``parse_args``). Note that
            ``args.local`` is overwritten with False when the environment
            variable ``PADDLE_IS_LOCAL`` is "0".

    Exits the process with status 1 when a required environment variable
    is missing or ``TRAINING_ROLE`` is neither "TRAINER" nor "PSERVER".
    """
    # priority: ENV > args > config
    if os.getenv("PADDLE_IS_LOCAL", "1") == '0':
        args.local = False
    logging.info(args)

    if args.device == 'CPU':
        TrainTaskConfig.use_gpu = False

    training_role = os.getenv("TRAINING_ROLE", "TRAINER")

    # Parameter servers always run on CPU.
    if training_role == "PSERVER" or (not TrainTaskConfig.use_gpu):
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
    else:
        gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
        place = fluid.CUDAPlace(gpu_id)
        dev_count = get_device_num()

    exe = fluid.Executor(place)

    train_prog = fluid.Program()
    startup_prog = fluid.Program()

    if args.enable_ce:
        # Fixed seeds so continuous-evaluation (CE) runs are reproducible.
        train_prog.random_seed = 1000
        startup_prog.random_seed = 1000

    with fluid.program_guard(train_prog, startup_prog):
        with fluid.unique_name.guard():
            sum_cost, avg_cost, predict, token_num, pyreader = transformer(
                ModelHyperParams.src_vocab_size,
                ModelHyperParams.trg_vocab_size,
                ModelHyperParams.max_length + 1,
                ModelHyperParams.n_layer,
                ModelHyperParams.n_head,
                ModelHyperParams.d_key,
                ModelHyperParams.d_value,
                ModelHyperParams.d_model,
                ModelHyperParams.d_inner_hid,
                ModelHyperParams.prepostprocess_dropout,
                ModelHyperParams.attention_dropout,
                ModelHyperParams.relu_dropout,
                ModelHyperParams.preprocess_cmd,
                ModelHyperParams.postprocess_cmd,
                ModelHyperParams.weight_sharing,
                TrainTaskConfig.label_smooth_eps,
                ModelHyperParams.bos_idx,
                use_py_reader=args.use_py_reader,
                is_test=False)

            if args.sync:
                # Adam with a Noam-decayed learning rate scaled by the
                # configured base learning rate.
                lr_decay = fluid.layers.learning_rate_scheduler.noam_decay(
                    ModelHyperParams.d_model, TrainTaskConfig.warmup_steps)
                logging.info("before adam")

                with fluid.default_main_program()._lr_schedule_guard():
                    learning_rate = lr_decay * TrainTaskConfig.learning_rate

                optimizer = fluid.optimizer.Adam(
                    learning_rate=learning_rate,
                    beta1=TrainTaskConfig.beta1,
                    beta2=TrainTaskConfig.beta2,
                    epsilon=TrainTaskConfig.eps)
            else:
                optimizer = fluid.optimizer.SGD(0.003)
            optimizer.minimize(avg_cost)

    if args.use_mem_opt:
        fluid.memory_optimize(train_prog)

    if args.local:
        logging.info("local start_up:")
        train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
                   token_num, predict, pyreader)
        return

    if args.update_method == "nccl2":
        trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        port = _require_env("PADDLE_PORT")
        worker_ips = _require_env("PADDLE_TRAINERS")  # ip,ip...
        worker_endpoints = [
            ':'.join([ip, port]) for ip in worker_ips.split(",")
        ]
        trainers_num = len(worker_endpoints)
        current_endpoint = _require_env("POD_IP") + ":" + port
        if trainer_id == 0:
            logging.info("train_id == 0, sleep 60s")
            time.sleep(60)
        logging.info("trainers_num:{}".format(trainers_num))
        logging.info("worker_endpoints:{}".format(worker_endpoints))
        logging.info("current_endpoint:{}".format(current_endpoint))
        append_nccl2_prepare(startup_prog, trainer_id, worker_endpoints,
                             current_endpoint)
        train_loop(exe, train_prog, startup_prog, dev_count, sum_cost,
                   avg_cost, token_num, predict, pyreader, trainers_num,
                   trainer_id)
        return

    # Parameter-server mode.
    port = os.getenv("PADDLE_PORT", "6174")
    pserver_ips = _require_env("PADDLE_PSERVERS")  # ip,ip...
    pserver_endpoints = ",".join(
        ':'.join([ip, port]) for ip in pserver_ips.split(","))  # ip:port,...
    trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
    current_endpoint = _require_env("POD_IP") + ":" + port
    trainer_id = int(_require_env("PADDLE_TRAINER_ID"))

    logging.info("pserver_endpoints:{}".format(pserver_endpoints))
    logging.info("current_endpoint:{}".format(current_endpoint))
    logging.info("trainer_id:{}".format(trainer_id))
    logging.info("pserver_ips:{}".format(pserver_ips))
    logging.info("port:{}".format(port))

    t = fluid.DistributeTranspiler()
    t.transpile(
        trainer_id,
        pservers=pserver_endpoints,
        trainers=trainers,
        program=train_prog,
        startup_program=startup_prog)

    if training_role == "PSERVER":
        logging.info("distributed: pserver started")
        # Reuse the endpoint built above from the same POD_IP and port, so
        # this pserver uses exactly the address registered in
        # pserver_endpoints, including the default port when PADDLE_PORT
        # is unset.
        pserver_prog = t.get_pserver_program(current_endpoint)
        pserver_startup = t.get_startup_program(current_endpoint,
                                                pserver_prog)

        exe.run(pserver_startup)
        exe.run(pserver_prog)
    elif training_role == "TRAINER":
        logging.info("distributed: trainer started")
        # Run the program the transpiler hands back for trainers rather than
        # relying on train_prog having been rewritten in place.
        trainer_prog = t.get_trainer_program()

        train_loop(exe, trainer_prog, startup_prog, dev_count, sum_cost,
                   avg_cost, token_num, predict, pyreader)
    else:
        logging.critical(
            "environment var TRAINING_ROLE should be TRAINER or PSERVER")
        sys.exit(1)
797 798 799


if __name__ == "__main__":
    # Timestamped records tagged with level and source location.
    LOG_FORMAT = "[%(asctime)s %(levelname)s %(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(
        format=LOG_FORMAT, level=logging.DEBUG, stream=sys.stdout)

    # basicConfig does nothing if the root logger already has handlers, so
    # the effective INFO threshold is applied explicitly on the root logger.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    # `args` must stay a module-level global: train_loop reads it directly.
    args = parse_args()
    train(args)