# train.py: training script for the Transformer model.
import argparse
import ast
import copy
import logging
import multiprocessing
import os
import subprocess
import sys
import time

import six

# Default for a Paddle FLAGS_* environment variable. It is set here, before
# `paddle.fluid` is imported below, presumably so that Paddle sees it when it
# initializes; an explicit value in the environment is left untouched.
if os.environ.get('FLAGS_eager_delete_tensor_gb', None) is None:
    os.environ['FLAGS_eager_delete_tensor_gb'] = '0'

# Extend the import path (entries are relative to the working directory) so
# that the repo-level `models` package and the transformer modules imported
# below can be found.
sys.path.append("../../")
sys.path.append("../../models/neural_machine_translation/transformer/")

import numpy as np
import paddle.fluid as fluid

from models.model_check import check_cuda
import reader
from config import *
from desc import *
from model import transformer, position_encoding_init
import dist_utils

# Number of trainer processes, taken from PADDLE_TRAINERS_NUM (1 if unset).
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))


def parse_args():
    """Parse command-line flags and merge them into the global config.

    Trailing positional `opts` tokens (alternating key/value, see config.py)
    are merged into `TrainTaskConfig` and `ModelHyperParams` with
    `merge_cfg_from_list`. The vocabulary sizes and the bos/eos/unk indices
    read from the vocabulary files are appended to that list as well.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser("Training for Transformer.")
    parser.add_argument(
        "--src_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of source language.")
    parser.add_argument(
        "--trg_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of target language.")
    parser.add_argument(
        "--train_file_pattern",
        type=str,
        required=True,
        help="The pattern to match training data files.")
    parser.add_argument(
        "--val_file_pattern",
        type=str,
        help="The pattern to match validation data files.")
    parser.add_argument(
        "--use_token_batch",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to "
        "produce batch data according to token number.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=4096,
        help="The number of sequences contained in a mini-batch, or the maximum "
        "number of tokens (include paddings) contained in a mini-batch. Note "
        "that this represents the number on single device and the actual batch "
        "size for multi-devices will multiply the device number.")
    parser.add_argument(
        "--pool_size",
        type=int,
        default=200000,
        help="The buffer size to pool data.")
    parser.add_argument(
        "--sort_type",
        default="pool",
        choices=("global", "pool", "none"),
        help="The grain to sort by length: global for all instances; pool for "
        "instances in pool; none for no sort.")
    parser.add_argument(
        "--shuffle",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to shuffle instances in each pass.")
    parser.add_argument(
        "--shuffle_batch",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to shuffle the data batches.")
    parser.add_argument(
        "--special_token",
        type=lambda x: x.encode(),
        default=[b"<s>", b"<e>", b"<unk>"],
        nargs=3,
        help="The <bos>, <eos> and <unk> tokens in the dictionary.")
    parser.add_argument(
        "--token_delimiter",
        type=lambda x: x.encode(),
        default=b" ",
        help="The delimiter used to split tokens in source or target sentences. "
        "For EN-DE BPE data we provided, use spaces as token delimiter. ")
    parser.add_argument(
        'opts',
        help='See config.py for all options',
        default=None,
        nargs=argparse.REMAINDER)
    parser.add_argument(
        '--local',
        type=ast.literal_eval,
        default=True,
        help='Whether to run as local mode.')
    parser.add_argument(
        '--device',
        type=str,
        default='GPU',
        choices=['CPU', 'GPU'],
        help="The device type.")
    parser.add_argument(
        '--update_method',
        choices=("pserver", "nccl2"),
        default="pserver",
        help='Update method.')
    parser.add_argument(
        '--sync', type=ast.literal_eval, default=True, help="sync mode.")
    parser.add_argument(
        "--enable_ce",
        type=ast.literal_eval,
        default=False,
        help="The flag indicating whether to run the task "
        "for continuous evaluation.")
    parser.add_argument(
        "--use_mem_opt",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to use memory optimization.")
    parser.add_argument(
        "--use_py_reader",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to use py_reader.")
    parser.add_argument(
        "--fetch_steps",
        type=int,
        default=100,
        help="The frequency to fetch and print output.")

    args = parser.parse_args()

    # Append args related to dict.
    # NOTE(review): the bos/eos/unk indices are looked up in the *source*
    # vocabulary only; this presumably relies on both vocabularies sharing
    # these ids -- verify when using separate vocabularies.
    src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
    trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
    dict_args = [
        "src_vocab_size", str(len(src_dict)), "trg_vocab_size",
        str(len(trg_dict)), "bos_idx", str(src_dict[args.special_token[0]]),
        "eos_idx", str(src_dict[args.special_token[1]]), "unk_idx",
        str(src_dict[args.special_token[2]])
    ]
    # `opts` is declared with default=None; guard so the concatenation
    # cannot fail on a None value.
    merge_cfg_from_list((args.opts or []) + dict_args,
                        [TrainTaskConfig, ModelHyperParams])
    return args


def get_device_num():
    """Return the number of devices this process trains on.

    NOTE(zcd): for multi-process training, each process uses one GPU card,
    so the result is 1 whenever more than one trainer is configured.
    Otherwise the count is the number of comma-separated entries in
    CUDA_VISIBLE_DEVICES when that variable is set and non-empty, or else
    the number of lines printed by `nvidia-smi -L`.
    """
    if num_trainers > 1:
        return 1
    visible = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    if not visible:
        listing = subprocess.check_output(['nvidia-smi', '-L']).decode()
        return listing.count('\n')
    return len(visible.split(','))


def append_nccl2_prepare(startup_prog, trainer_id, worker_endpoints,
                         current_endpoint):
    """Append a `gen_nccl_id` op to `startup_prog` for NCCL2 training.

    A persistable RAW variable named "NCCLID" is created in the startup
    program's global block and used as the op's output. The op receives this
    trainer's endpoint, the endpoints of all *other* workers and the trainer
    id as attributes.

    Args:
        startup_prog: the fluid.Program to modify in place.
        trainer_id: id of this trainer (must be >= 0).
        worker_endpoints: all worker endpoints; must contain more than one
            entry, including `current_endpoint`. It is not modified.
        current_endpoint: endpoint of this trainer.

    Returns:
        The created "NCCLID" variable.

    Raises:
        AssertionError: if the arguments violate the conditions above.
    """
    assert (trainer_id >= 0 and len(worker_endpoints) > 1 and
            current_endpoint in worker_endpoints)
    # Work on a copy so the caller's endpoint list is left untouched.
    eps = copy.deepcopy(worker_endpoints)
    eps.remove(current_endpoint)
    nccl_id_var = startup_prog.global_block().create_var(
        name="NCCLID", persistable=True, type=fluid.core.VarDesc.VarType.RAW)
    startup_prog.global_block().append_op(
        type="gen_nccl_id",
        inputs={},
        outputs={"NCCLID": nccl_id_var},
        attrs={
            "endpoint": current_endpoint,
            "endpoint_list": eps,
            "trainer_id": trainer_id
        })
    return nccl_id_var

def pad_batch_data(insts,
                   pad_idx,
                   n_head,
                   is_target=False,
                   is_label=False,
                   return_attn_bias=True,
                   return_max_len=True,
                   return_num_token=False):
    """
    Pad the instances to the max sequence length in batch, and generate the
    corresponding position data and attention bias.

    Args:
        insts: list of token-id lists, one per instance (must be non-empty,
            since the maximum length is taken over it).
        pad_idx: id appended to instances shorter than the longest one.
        n_head: number of attention heads; the bias is tiled over this axis.
        is_target: build a causal (upper-triangular) self-attention bias
            instead of a padding-only bias.
        is_label: produce per-token weights (1. for real tokens, 0. for
            padding) instead of position ids.
        return_attn_bias: include the self-attention bias.
        return_max_len: include the padded length.
        return_num_token: include the number of real (non-pad) tokens.

    Returns:
        A list, in this order:
        - the padded ids, int64, shape [batch * max_len, 1];
        - the label weights (float32) or the position ids (int64), same shape;
        - if requested, the attention bias, float32, shape
          [batch, n_head, max_len, max_len], with -1e9 on masked entries;
        - if requested, `max_len`;
        - if requested, the real-token count.
        A one-element result is returned unwrapped.
    """
    return_list = []
    max_len = max(len(inst) for inst in insts)
    # Any token included in dict can be used to pad, since the paddings' loss
    # will be masked out by weights and make no effect on parameter gradients.
    inst_data = np.array(
        [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
    return_list += [inst_data.astype("int64").reshape([-1, 1])]
    if is_label:  # label weight
        inst_weight = np.array(
            [[1.] * len(inst) + [0.] * (max_len - len(inst)) for inst in insts])
        return_list += [inst_weight.astype("float32").reshape([-1, 1])]
    else:  # position data
        inst_pos = np.array([
            list(range(0, len(inst))) + [0] * (max_len - len(inst))
            for inst in insts
        ])
        return_list += [inst_pos.astype("int64").reshape([-1, 1])]
    if return_attn_bias:
        if is_target:
            # Causal mask: query i may not attend to keys j > i. Padding sits
            # at the end of each instance, so for every real query the padded
            # keys are "subsequent" words and are masked as well.
            slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, max_len))
            slf_attn_bias_data = np.triu(slf_attn_bias_data,
                                         1).reshape([-1, 1, max_len, max_len])
            slf_attn_bias_data = np.tile(slf_attn_bias_data,
                                         [1, n_head, 1, 1]) * [-1e9]
        else:
            # This is used to avoid attention on paddings: every query row
            # gets the same per-key mask.
            slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
                                           (max_len - len(inst))
                                           for inst in insts])
            slf_attn_bias_data = np.tile(
                slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
                [1, n_head, max_len, 1])
        return_list += [slf_attn_bias_data.astype("float32")]
    if return_max_len:
        return_list += [max_len]
    if return_num_token:
        # Only real tokens are counted; padding is excluded.
        return_list += [sum(len(inst) for inst in insts)]
    return return_list if len(return_list) > 1 else return_list[0]


def prepare_batch_input(insts, data_input_names, src_pad_idx, trg_pad_idx,
                        n_head, d_model):
    """
    Put all padded data needed by training into a dict.

    Args:
        insts: list of instances; `inst[0]` holds the source ids, `inst[1]`
            the target ids and `inst[2]` the label ids.
        data_input_names: names used as dict keys, zipped in order with
            src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
            trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight.
        src_pad_idx: padding id for the source side.
        trg_pad_idx: padding id for the target ids and the labels.
        n_head: number of attention heads.
        d_model: not used by this function; kept for interface compatibility.

    Returns:
        A tuple `(feed_dict, num_token)` where `num_token` is a float32 array
        of shape [1] holding the number of real label tokens in the batch.
    """
    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
    src_word = src_word.reshape(-1, src_max_len, 1)
    src_pos = src_pos.reshape(-1, src_max_len, 1)
    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
        [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
    trg_word = trg_word.reshape(-1, trg_max_len, 1)
    trg_pos = trg_pos.reshape(-1, trg_max_len, 1)

    # Encoder-decoder attention only needs to hide source padding. The source
    # (non-causal) bias has identical query rows, so take one row and repeat
    # it for each of the trg_max_len target queries.
    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, :1, :],
                                [1, 1, trg_max_len, 1]).astype("float32")

    lbl_word, lbl_weight, num_token = pad_batch_data(
        [inst[2] for inst in insts],
        trg_pad_idx,
        n_head,
        is_target=False,
        is_label=True,
        return_attn_bias=False,
        return_max_len=False,
        return_num_token=True)

    data_input_dict = dict(
        zip(data_input_names, [
            src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
            trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
        ]))

    return data_input_dict, np.asarray([num_token], dtype="float32")


def prepare_data_generator(args,
                           is_test,
                           count,
                           pyreader,
                           py_reader_provider_wrapper,
                           place=None):
    """
    Data generator wrapper for DataReader. If use py_reader, set the data
    provider for py_reader.

    Args:
        args: parsed command-line arguments.
        is_test: read `args.val_file_pattern` instead of
            `args.train_file_pattern`.
        count: number of devices the data is spread over.
        pyreader: py_reader that receives the tensor provider (py_reader mode).
        py_reader_provider_wrapper: callable turning a batch reader and
            `place` into a tensor provider for `pyreader`.
        place: forwarded to `py_reader_provider_wrapper`.

    Returns:
        None in py_reader mode, where data flows through `pyreader`.
        Otherwise, a callable returning a generator that yields lists of
        `count` per-device batches.
    """
    # NOTE: If num_trainers > 1, the shuffle_seed must be set, because
    # the order of batch data generated by reader
    # must be the same in the respective processes.
    shuffle_seed = 1 if num_trainers > 1 else None

    data_reader = reader.DataReader(
        fpattern=args.val_file_pattern if is_test else args.train_file_pattern,
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        token_delimiter=args.token_delimiter,
        use_token_batch=args.use_token_batch,
        batch_size=args.batch_size * (1 if args.use_token_batch else count),
        pool_size=args.pool_size,
        sort_type=args.sort_type,
        shuffle=args.shuffle,
        shuffle_seed=shuffle_seed,
        shuffle_batch=args.shuffle_batch,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2],
        # count start and end tokens out
        max_length=ModelHyperParams.max_length - 2,
        clip_last_batch=False).batch_generator

    def split_evenly(insts, parts):
        """Split `insts` into `parts` contiguous slices whose sizes differ by
        at most one, so that no instance is dropped."""
        base, extra = divmod(len(insts), parts)
        slices = []
        start = 0
        for i in range(parts):
            end = start + base + (1 if i < extra else 0)
            slices.append(insts[start:end])
            start = end
        return slices

    def stack(data_reader, count, clip_last=True):
        """Group consecutive batches into lists of `count` (one per device)."""

        def __impl__():
            res = []
            for item in data_reader():
                res.append(item)
                if len(res) == count:
                    yield res
                    res = []
            # Here len(res) < count: only a partial group is left over.
            if not clip_last:
                data = []
                for item in res:
                    data += item
                if len(data) > count:
                    yield split_evenly(data, count)

        return __impl__

    def split(data_reader, count):
        """Split each batch into `count` parts of near-equal size."""

        def __impl__():
            for item in data_reader():
                # Every device needs at least one instance; a batch with fewer
                # than `count` instances cannot be spread and is skipped
                # (empty parts would break padding downstream).
                if len(item) < count:
                    continue
                for part in split_evenly(item, count):
                    yield part

        return __impl__

    if not args.use_token_batch:
        # to make data on each device have similar token number
        data_reader = split(data_reader, count)
    if args.use_py_reader:
        train_reader = py_reader_provider_wrapper(data_reader, place)
        if num_trainers > 1:
            assert shuffle_seed is not None
            train_reader = fluid.contrib.reader.distributed_batch_reader(
                train_reader)
        pyreader.decorate_tensor_provider(train_reader)
        data_reader = None
    else:  # Data generator for multi-devices
        data_reader = stack(data_reader, count)
    return data_reader


def prepare_feed_dict_list(data_generator, init_flag, count):
    """Build the list of per-device feed dicts for one step.

    Args:
        data_generator: generator whose next item is a list with one batch
            per device, or None when data is fed through py_reader.
        init_flag: when True, also feed the position-encoding table for every
            name in `pos_enc_param_names` to each of the `count` devices.
        count: number of devices.

    Returns:
        A list of exactly `count` feed dicts, or None when the number of
        dicts built differs from `count` (e.g. py_reader mode without
        `init_flag`).

    Raises:
        StopIteration: propagated from `next(data_generator)` once the
            generator is exhausted.
    """
    feeds = []
    if data_generator is not None:  # use_py_reader == False
        input_names = (encoder_data_input_fields +
                       decoder_data_input_fields[:-1] +
                       label_data_input_fields)
        for device_batch in next(data_generator):
            input_dict, _ = prepare_batch_input(
                device_batch, input_names, ModelHyperParams.eos_idx,
                ModelHyperParams.eos_idx, ModelHyperParams.n_head,
                ModelHyperParams.d_model)
            feeds.append(input_dict)
    if init_flag:
        for dev_idx in range(count):
            tables = {
                name: position_encoding_init(ModelHyperParams.max_length + 1,
                                             ModelHyperParams.d_model)
                for name in pos_enc_param_names
            }
            if dev_idx < len(feeds):
                # Data entries take precedence over the tables on key clashes.
                merged = dict(tables)
                merged.update(feeds[dev_idx])
                feeds[dev_idx] = merged
            else:
                feeds.append(tables)

    if len(feeds) != count:
        return None
    return feeds


def py_reader_provider_wrapper(data_reader, place):
    """
    Data provider needed by fluid.layers.py_reader.

    Args:
        data_reader: callable returning an iterable of batches (lists of
            instances) in the form accepted by `prepare_batch_input`.
        place: not used by this function; kept for interface compatibility.

    Returns:
        A generator function that yields, for each batch, the padded arrays
        ordered as the encoder fields, the decoder fields except the last
        one, and the label fields.
    """

    def py_reader_provider():
        data_input_names = encoder_data_input_fields + \
                    decoder_data_input_fields[:-1] + label_data_input_fields
        for data in data_reader():
            # eos_idx is used as the padding id for both source and target.
            data_input_dict, _ = prepare_batch_input(
                data, data_input_names, ModelHyperParams.eos_idx,
                ModelHyperParams.eos_idx, ModelHyperParams.n_head,
                ModelHyperParams.d_model)
            yield [data_input_dict[name] for name in data_input_names]

    return py_reader_provider


def test_context(exe, train_exe, dev_count):
    """Build the validation program and return a function that evaluates it.

    The model is rebuilt in a fresh program with `is_test=True`, cloned for
    testing, and run by a ParallelExecutor created with
    `share_vars_from=train_exe`. Reads the module-level `args`.

    Args:
        exe: executor used to run the test startup program and, when
            `TrainTaskConfig.ckpt_path` is set, to load persistables from it.
        train_exe: training ParallelExecutor whose variables are shared.
        dev_count: number of devices.

    Returns:
        A function `test()` that runs over the whole validation set and
        returns `(avg_cost, ppl)`, where `ppl` is `exp(min(avg_cost, 100))`
        as a one-element array.
    """
    # Context to do validation.
    test_prog = fluid.Program()
    startup_prog = fluid.Program()
    if args.enable_ce:
        test_prog.random_seed = 1000
        startup_prog.random_seed = 1000
    with fluid.program_guard(test_prog, startup_prog):
        with fluid.unique_name.guard():
            sum_cost, avg_cost, predict, token_num, pyreader = transformer(
                ModelHyperParams.src_vocab_size,
                ModelHyperParams.trg_vocab_size,
                ModelHyperParams.max_length + 1,
                ModelHyperParams.n_layer,
                ModelHyperParams.n_head,
                ModelHyperParams.d_key,
                ModelHyperParams.d_value,
                ModelHyperParams.d_model,
                ModelHyperParams.d_inner_hid,
                ModelHyperParams.prepostprocess_dropout,
                ModelHyperParams.attention_dropout,
                ModelHyperParams.relu_dropout,
                ModelHyperParams.preprocess_cmd,
                ModelHyperParams.postprocess_cmd,
                ModelHyperParams.weight_sharing,
                TrainTaskConfig.label_smooth_eps,
                use_py_reader=args.use_py_reader,
                is_test=True)
    test_prog = test_prog.clone(for_test=True)
    test_data = prepare_data_generator(
        args,
        is_test=True,
        count=dev_count,
        pyreader=pyreader,
        py_reader_provider_wrapper=py_reader_provider_wrapper)

    exe.run(startup_prog)  # to init pyreader for testing
    if TrainTaskConfig.ckpt_path:
        fluid.io.load_persistables(
            exe, TrainTaskConfig.ckpt_path, main_program=test_prog)

    build_strategy = fluid.BuildStrategy()
    test_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        main_program=test_prog,
        build_strategy=build_strategy,
        share_vars_from=train_exe)

    def test(exe=test_exe, pyreader=pyreader):
        """Evaluate on the validation set with `exe` (default: test_exe)."""
        test_total_cost = 0
        test_total_token = 0

        if args.use_py_reader:
            pyreader.start()
            data_generator = None
        else:
            data_generator = test_data()
        while True:
            try:
                feed_dict_list = prepare_feed_dict_list(data_generator, False,
                                                        dev_count)
                outs = exe.run(fetch_list=[sum_cost.name, token_num.name],
                               feed=feed_dict_list)
            except (StopIteration, fluid.core.EOFException):
                # The current pass is over.
                if args.use_py_reader:
                    pyreader.reset()
                break
            sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
            test_total_cost += sum_cost_val.sum()
            test_total_token += token_num_val.sum()
        # NOTE(review): if the validation set yields no batch, both totals are
        # still the int 0 and this raises ZeroDivisionError.
        test_avg_cost = test_total_cost / test_total_token
        test_ppl = np.exp([min(test_avg_cost, 100)])
        return test_avg_cost, test_ppl

    return test


def train_loop(exe,
               train_prog,
               startup_prog,
               dev_count,
               sum_cost,
               avg_cost,
               token_num,
               predict,
               pyreader,
               nccl2_num_trainers=1,
               nccl2_trainer_id=0):
    """Run `TrainTaskConfig.pass_num` training passes.

    Steps:
    - Initialize parameters, from `TrainTaskConfig.ckpt_path` when set.
    - Attach the training data and build a ParallelExecutor.
    - For each pass, run batches until the reader is exhausted.
      - Every `args.fetch_steps` steps the summed cost and token count are
        fetched and logged.
      - Every `TrainTaskConfig.save_freq` steps (after step 0) the
        persistables and params are saved.
    - After each pass, evaluate on the validation set when
      `args.val_file_pattern` is set, and write a per-pass checkpoint unless
      `args.enable_ce`.

    `step_idx` keeps counting across passes. Reads the module-level `args`
    and `num_trainers`.

    Args:
        exe: executor used for startup, checkpoint loading and saving.
        train_prog: main training program.
        startup_prog: startup program, run once before training.
        dev_count: number of devices.
        sum_cost: summed-cost variable, fetched for logging.
        avg_cost: average-cost variable, used as the executor's loss.
        token_num: token-count variable, fetched for logging.
        predict: not used by this function.
        pyreader: py_reader feeding training data in py_reader mode.
        nccl2_num_trainers: forwarded to ParallelExecutor as num_trainers.
        nccl2_trainer_id: forwarded to ParallelExecutor as trainer_id.
    """
    # Initialize the parameters.
    if TrainTaskConfig.ckpt_path:
        exe.run(startup_prog)  # to init pyreader for training
        logging.info("load checkpoint from {}".format(
            TrainTaskConfig.ckpt_path))
        fluid.io.load_persistables(
            exe, TrainTaskConfig.ckpt_path, main_program=train_prog)
    else:
        logging.info("init fluid.framework.default_startup_program")
        exe.run(startup_prog)

    logging.info("begin reader")
    train_data = prepare_data_generator(
        args,
        is_test=False,
        count=dev_count,
        pyreader=pyreader,
        py_reader_provider_wrapper=py_reader_provider_wrapper)

    # For faster executor
    exec_strategy = fluid.ExecutionStrategy()
    exec_strategy.num_iteration_per_drop_scope = int(args.fetch_steps)
    build_strategy = fluid.BuildStrategy()
    # NOTE(review): args.use_mem_opt is not consulted here.
    build_strategy.memory_optimize = False
    build_strategy.enable_inplace = True

    # Keep the fetched variables alive so they can be read on fetch steps.
    sum_cost.persistable = True
    token_num.persistable = True
    # Since the token number differs among devices, customize gradient scale to
    # use token average cost among multi-devices. and the gradient scale is
    # `1 / token_number` for average cost. (Currently disabled.)
    # build_strategy.gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized
    build_strategy.fuse_all_optimizer_ops = True

    if num_trainers > 1 and args.use_py_reader and TrainTaskConfig.use_gpu:
        dist_utils.prepare_for_multi_process(exe, build_strategy, train_prog)
        exec_strategy.num_threads = 1

    logging.info("begin executor")
    train_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        loss_name=avg_cost.name,
        main_program=train_prog,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy,
        num_trainers=nccl2_num_trainers,
        trainer_id=nccl2_trainer_id)

    if args.val_file_pattern is not None:
        test = test_context(exe, train_exe, dev_count)

    # The best cross-entropy value with label smoothing (entropy of the
    # smoothed target distribution). The 1e-20 keeps the log finite when
    # label_smooth_eps is 0.
    eps = TrainTaskConfig.label_smooth_eps
    loss_normalizer = -((1. - eps) * np.log(1. - eps) +
                        eps * np.log(eps / (ModelHyperParams.trg_vocab_size - 1)
                                     + 1e-20))

    step_idx = 0
    init_flag = True
    logging.info("begin train")
    for pass_id in six.moves.xrange(TrainTaskConfig.pass_num):
        pass_start_time = time.time()

        if args.use_py_reader:
            pyreader.start()
            data_generator = None
        else:
            data_generator = train_data()

        batch_id = 0
        while True:
            try:
                feed_dict_list = prepare_feed_dict_list(data_generator,
                                                        init_flag, dev_count)
                outs = train_exe.run(
                    fetch_list=[sum_cost.name, token_num.name]
                    if step_idx % args.fetch_steps == 0 else [],
                    feed=feed_dict_list)

                if step_idx % args.fetch_steps == 0:
                    sum_cost_val, token_num_val = np.array(outs[0]), np.array(
                        outs[1])
                    # sum the cost from multi-devices
                    total_sum_cost = sum_cost_val.sum()
                    total_token_num = token_num_val.sum()
                    total_avg_cost = total_sum_cost / total_token_num

                    if step_idx == 0:
                        logging.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer,
                             np.exp([min(total_avg_cost, 100)])))
                    else:
                        logging.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f, speed: %.2f step/s" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer,
                             np.exp([min(total_avg_cost, 100)]),
                             args.fetch_steps / (time.time() - avg_batch_time)))
                    # Start timing the next window of fetch_steps steps.
                    avg_batch_time = time.time()

                if step_idx % TrainTaskConfig.save_freq == 0 and step_idx > 0:
                    fluid.io.save_persistables(
                        exe,
                        os.path.join(TrainTaskConfig.ckpt_dir,
                                     "latest.checkpoint"), train_prog)
                    fluid.io.save_params(
                        exe,
                        os.path.join(TrainTaskConfig.model_dir,
                                     "iter_" + str(step_idx) + ".infer.model"),
                        train_prog)

                init_flag = False
                batch_id += 1
                step_idx += 1
            except (StopIteration, fluid.core.EOFException):
                # The current pass is over.
                if args.use_py_reader:
                    pyreader.reset()
                break

        time_consumed = time.time() - pass_start_time
        # Validate and save the persistable.
        if args.val_file_pattern is not None:
            val_avg_cost, val_ppl = test()
            logging.info(
                "epoch: %d, val avg loss: %f, val normalized loss: %f, val ppl: %f,"
                " consumed %fs" % (pass_id, val_avg_cost,
                                   val_avg_cost - loss_normalizer, val_ppl,
                                   time_consumed))
        else:
            logging.info("epoch: %d, consumed %fs" % (pass_id, time_consumed))

        if not args.enable_ce:
            fluid.io.save_persistables(
                exe,
                os.path.join(TrainTaskConfig.ckpt_dir,
                             "pass_" + str(pass_id) + ".checkpoint"),
                train_prog)

    if args.enable_ce:  # For CE
        # NOTE(review): total_avg_cost, val_avg_cost and time_consumed only
        # exist if at least one pass (and one fetch step) has run.
        print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost))
        if args.val_file_pattern is not None:
            print("kpis\ttest_cost_card%d\t%f" % (dev_count, val_avg_cost))
        print("kpis\ttrain_duration_card%d\t%f" % (dev_count, time_consumed))


# Training entry: configure devices and executor, then build the training program.
def _require_env(name):
    """Return the value of environment variable `name`; exit if it is unset.

    Distributed launches depend on several PADDLE_*/POD_IP variables. Failing
    fast with a readable message beats the TypeError/AttributeError that
    `None + str` or `None.split(...)` would otherwise raise much later.

    Args:
        name: environment variable to read.

    Returns:
        The (non-empty) string value of the variable.

    Raises:
        SystemExit: with code 1 when the variable is unset or empty.
    """
    value = os.getenv(name)
    if not value:
        logging.critical("need env %s", name)
        sys.exit(1)
    return value


def _build_endpoints(ips, port):
    """Turn a comma-separated host list into a list of "host:port" strings.

    Surrounding whitespace is stripped and empty entries (e.g. from a
    trailing comma) are ignored, so they cannot yield a bogus ":port".

    Args:
        ips: comma-separated hosts, e.g. "10.0.0.1,10.0.0.2".
        port: port string appended to every host.

    Returns:
        list of "host:port" strings, in input order.
    """
    hosts = (ip.strip() for ip in ips.split(","))
    return [":".join([host, port]) for host in hosts if host]


def _build_optimizer(sync):
    """Create the optimizer for the program currently under `program_guard`.

    With `sync` true this is Adam whose learning rate is the `noam_decay`
    schedule (driven by ModelHyperParams.d_model and
    TrainTaskConfig.warmup_steps) scaled by TrainTaskConfig.learning_rate;
    otherwise plain SGD with a fixed learning rate of 0.003.

    Must be called inside the training program's `program_guard`, because
    the schedule ops are added to `fluid.default_main_program()`.
    """
    if not sync:
        return fluid.optimizer.SGD(0.003)

    lr_decay = fluid.layers.learning_rate_scheduler.noam_decay(
        ModelHyperParams.d_model, TrainTaskConfig.warmup_steps)
    logging.info("before adam")
    with fluid.default_main_program()._lr_schedule_guard():
        learning_rate = lr_decay * TrainTaskConfig.learning_rate
    return fluid.optimizer.Adam(
        learning_rate=learning_rate,
        beta1=TrainTaskConfig.beta1,
        beta2=TrainTaskConfig.beta2,
        epsilon=TrainTaskConfig.eps)


def train(args):
    """Build the Transformer training program and run it.

    Depending on `args` and the environment this either trains locally,
    runs as an NCCL2 collective trainer (`args.update_method == "nccl2"`),
    or runs as a parameter-server pserver/trainer selected by TRAINING_ROLE.

    Environment variables read (priority: ENV > args > config):
        PADDLE_IS_LOCAL: "0" forces distributed mode.
        TRAINING_ROLE: "TRAINER" (default) or "PSERVER".
        CPU_NUM / FLAGS_selected_gpus: device selection.
        PADDLE_TRAINER_ID, PADDLE_PORT, PADDLE_TRAINERS, PADDLE_PSERVERS,
        PADDLE_TRAINERS_NUM, POD_IP: distributed topology.

    Raises:
        SystemExit: when a required distributed env var is missing or
            TRAINING_ROLE is unrecognized.
    """
    # priority: ENV > args > config
    if os.getenv("PADDLE_IS_LOCAL", "1") == '0':
        args.local = False

    logging.info(args)

    if args.device == 'CPU':
        TrainTaskConfig.use_gpu = False

    training_role = os.getenv("TRAINING_ROLE", "TRAINER")

    if training_role == "PSERVER" or (not TrainTaskConfig.use_gpu):
        place = fluid.CPUPlace()
        # the default setting of CPU_NUM in paddle framework is 1
        dev_count = int(os.environ.get('CPU_NUM', 1))
    else:
        check_cuda(TrainTaskConfig.use_gpu)
        gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
        place = fluid.CUDAPlace(gpu_id)
        dev_count = get_device_num()

    exe = fluid.Executor(place)

    train_prog = fluid.Program()
    startup_prog = fluid.Program()

    if args.enable_ce:
        # Fixed seeds so continuous-evaluation (CE) runs are reproducible.
        train_prog.random_seed = 1000
        startup_prog.random_seed = 1000

    with fluid.program_guard(train_prog, startup_prog):
        with fluid.unique_name.guard():
            sum_cost, avg_cost, predict, token_num, pyreader = transformer(
                ModelHyperParams.src_vocab_size,
                ModelHyperParams.trg_vocab_size,
                ModelHyperParams.max_length + 1,
                ModelHyperParams.n_layer,
                ModelHyperParams.n_head,
                ModelHyperParams.d_key,
                ModelHyperParams.d_value,
                ModelHyperParams.d_model,
                ModelHyperParams.d_inner_hid,
                ModelHyperParams.prepostprocess_dropout,
                ModelHyperParams.attention_dropout,
                ModelHyperParams.relu_dropout,
                ModelHyperParams.preprocess_cmd,
                ModelHyperParams.postprocess_cmd,
                ModelHyperParams.weight_sharing,
                TrainTaskConfig.label_smooth_eps,
                ModelHyperParams.bos_idx,
                use_py_reader=args.use_py_reader,
                is_test=False)

            optimizer = _build_optimizer(args.sync)
            optimizer.minimize(avg_cost)

    if args.local:
        logging.info("local start_up:")
        train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
                   token_num, predict, pyreader)
        return

    if args.update_method == "nccl2":
        trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        port = _require_env("PADDLE_PORT")
        worker_endpoints = _build_endpoints(
            _require_env("PADDLE_TRAINERS"), port)
        trainers_num = len(worker_endpoints)
        current_endpoint = _require_env("POD_IP") + ":" + port
        if trainer_id == 0:
            logging.info("train_id == 0, sleep 60s")
            time.sleep(60)
        logging.info("trainers_num:{}".format(trainers_num))
        logging.info("worker_endpoints:{}".format(worker_endpoints))
        logging.info("current_endpoint:{}".format(current_endpoint))
        append_nccl2_prepare(startup_prog, trainer_id, worker_endpoints,
                             current_endpoint)
        train_loop(exe, train_prog, startup_prog, dev_count, sum_cost,
                   avg_cost, token_num, predict, pyreader, trainers_num,
                   trainer_id)
        return

    # Parameter-server mode.
    port = os.getenv("PADDLE_PORT", "6174")
    pserver_ips = _require_env("PADDLE_PSERVERS")  # ip,ip...
    pserver_endpoints = ",".join(_build_endpoints(pserver_ips,
                                                  port))  # ip:port,ip:port...
    trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
    # Uses the same port (and default) as `pserver_endpoints`, so a pserver
    # always finds its own entry in the transpiled endpoint list.
    current_endpoint = _require_env("POD_IP") + ":" + port
    trainer_id = int(_require_env("PADDLE_TRAINER_ID"))

    logging.info("pserver_endpoints:{}".format(pserver_endpoints))
    logging.info("current_endpoint:{}".format(current_endpoint))
    logging.info("trainer_id:{}".format(trainer_id))
    logging.info("pserver_ips:{}".format(pserver_ips))
    logging.info("port:{}".format(port))

    t = fluid.DistributeTranspiler()
    t.transpile(
        trainer_id,
        pservers=pserver_endpoints,
        trainers=trainers,
        program=train_prog,
        startup_program=startup_prog)

    if training_role == "PSERVER":
        logging.info("distributed: pserver started")
        pserver_prog = t.get_pserver_program(current_endpoint)
        pserver_startup = t.get_startup_program(current_endpoint,
                                                pserver_prog)

        exe.run(pserver_startup)
        exe.run(pserver_prog)
    elif training_role == "TRAINER":
        logging.info("distributed: trainer started")
        # Run the transpiled trainer program (optimizer ops moved to the
        # pservers) rather than relying on `train_prog` being mutated in place.
        trainer_prog = t.get_trainer_program()
        train_loop(exe, trainer_prog, startup_prog, dev_count, sum_cost,
                   avg_cost, token_num, predict, pyreader)
    else:
        logging.critical(
            "environment var TRAINING_ROLE should be TRAINER or PSERVER")
        sys.exit(1)
803 804 805


if __name__ == "__main__":
    # Script entry point: send every log record to stdout as
    # "[time LEVEL file:line] message", then parse flags and train.
    LOG_FORMAT = "[%(asctime)s %(levelname)s %(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(
        format=LOG_FORMAT, level=logging.DEBUG, stream=sys.stdout)
    # The effective root level is INFO regardless of what basicConfig did.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    # NOTE(review): `args` must stay a module-level global -- train_loop
    # reads `args.*` but is never passed it by its callers; confirm before
    # moving this into a main() function.
    args = parse_args()
    train(args)