train.py 27.3 KB
Newer Older
1 2
import argparse
import ast
3
import multiprocessing
Y
Yu Yang 已提交
4
import os
G
guosheng 已提交
5
import six
Y
Yu Yang 已提交
6
import time
Y
ying 已提交
7

Y
Yu Yang 已提交
8
import numpy as np
L
Luo Tao 已提交
9
import paddle.fluid as fluid
Y
ying 已提交
10

Y
Yu Yang 已提交
11 12
import reader
from config import *
13
from model import transformer, position_encoding_init
14

G
fix  
gongweibao 已提交
15 16
import logging
import sys
17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47

def parse_args():
    """
    Parse the command line arguments for Transformer training.

    After parsing, both vocabulary files are loaded with
    reader.DataReader.load_dict. The vocabulary sizes and the indices of the
    <bos>, <eos> and <unk> tokens (looked up in the source vocabulary) are
    appended to the free-form `opts` and merged into TrainTaskConfig and
    ModelHyperParams.

    Returns:
        The argparse namespace holding the parsed arguments.
    """
    parser = argparse.ArgumentParser("Training for Transformer.")
    add_arg = parser.add_argument
    # Boolean flags are parsed as Python literals, e.g. "True" or "False".
    literal = ast.literal_eval

    # Vocabulary and data files.
    add_arg(
        "--src_vocab_fpath",
        required=True,
        type=str,
        help="The path of vocabulary file of source language.")
    add_arg(
        "--trg_vocab_fpath",
        required=True,
        type=str,
        help="The path of vocabulary file of target language.")
    add_arg(
        "--train_file_pattern",
        required=True,
        type=str,
        help="The pattern to match training data files.")
    add_arg(
        "--val_file_pattern",
        type=str,
        help="The pattern to match validation data files.")

    # Batching, pooling and shuffling.
    add_arg(
        "--use_token_batch",
        default=True,
        type=literal,
        help="The flag indicating whether to "
        "produce batch data according to token number.")
    add_arg(
        "--batch_size",
        default=4096,
        type=int,
        help="The number of sequences contained in a mini-batch, or the maximum "
        "number of tokens (include paddings) contained in a mini-batch. Note "
        "that this represents the number on single device and the actual batch "
        "size for multi-devices will multiply the device number.")
    add_arg(
        "--pool_size",
        default=200000,
        type=int,
        help="The buffer size to pool data.")
    add_arg(
        "--sort_type",
        choices=("global", "pool", "none"),
        default="pool",
        help="The grain to sort by length: global for all instances; pool for "
        "instances in pool; none for no sort.")
    add_arg(
        "--shuffle",
        default=True,
        type=literal,
        help="The flag indicating whether to shuffle instances in each pass.")
    add_arg(
        "--shuffle_batch",
        default=True,
        type=literal,
        help="The flag indicating whether to shuffle the data batches.")

    # Tokens.
    add_arg(
        "--special_token",
        default=["<s>", "<e>", "<unk>"],
        nargs=3,
        type=str,
        help="The <bos>, <eos> and <unk> tokens in the dictionary.")
    add_arg(
        "--token_delimiter",
        default=" ",
        # Interpret escape sequences typed on the command line.
        type=lambda x: str(x.encode().decode("unicode-escape")),
        help="The delimiter used to split tokens in source or target sentences. "
        "For EN-DE BPE data we provided, use spaces as token delimiter. "
        "For EN-FR wordpiece data we provided, use '\x01' as token delimiter.")

    # Everything left on the command line is treated as config overrides.
    add_arg(
        'opts',
        default=None,
        nargs=argparse.REMAINDER,
        help='See config.py for all options')

    # Running mode.
    add_arg(
        '--local',
        default=True,
        type=literal,
        help='Whether to run as local mode.')
    add_arg(
        '--device',
        choices=['CPU', 'GPU'],
        default='GPU',
        type=str,
        help="The device type.")
    add_arg(
        '--update_method',
        choices=("pserver", "nccl2"),
        default="pserver",
        help='Update method.')
    add_arg('--sync', default=True, type=literal, help="sync mode.")
    add_arg(
        "--enable_ce",
        default=False,
        type=literal,
        help="The flag indicating whether to run the task "
        "for continuous evaluation.")
    add_arg(
        "--use_mem_opt",
        default=True,
        type=literal,
        help="The flag indicating whether to use memory optimization.")
    add_arg(
        "--use_py_reader",
        default=True,
        type=literal,
        help="The flag indicating whether to use py_reader.")
    add_arg(
        "--fetch_steps",
        default=100,
        type=int,
        help="Fetch outputs steps.")

    args = parser.parse_args()

    # Append args related to dict
    src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
    trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
    bos_token, eos_token, unk_token = args.special_token
    dict_args = [
        "src_vocab_size", str(len(src_dict)),
        "trg_vocab_size", str(len(trg_dict)),
        "bos_idx", str(src_dict[bos_token]),
        "eos_idx", str(src_dict[eos_token]),
        "unk_idx", str(src_dict[unk_token]),
    ]
    merge_cfg_from_list(args.opts + dict_args,
                        [TrainTaskConfig, ModelHyperParams])
    return args
146

G
fix  
gongweibao 已提交
147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166
def append_nccl2_prepare(trainer_id, worker_endpoints, current_endpoint):
    """
    Append the op generating the NCCL id to the default startup program.

    A persistable RAW variable named "NCCLID" is created in the global block
    of fluid.default_startup_program() and a `gen_nccl_id` op writing to it
    is appended to the same block.

    Args:
        trainer_id: The index of the current trainer, must be >= 0.
        worker_endpoints: The list of the endpoints of all trainers, must
            hold more than one endpoint and include `current_endpoint`.
        current_endpoint: The endpoint of the current trainer.

    Returns:
        The created "NCCLID" variable.
    """
    assert (trainer_id >= 0 and len(worker_endpoints) > 1 and
            current_endpoint in worker_endpoints)
    # Copy before removing so the caller's list is left untouched. The
    # original used copy.deepcopy, but `copy` is not imported by this module;
    # a plain list copy is equivalent for a list of endpoint strings.
    eps = list(worker_endpoints)
    eps.remove(current_endpoint)

    startup_block = fluid.default_startup_program().global_block()
    nccl_id_var = startup_block.create_var(
        name="NCCLID",
        persistable=True,
        type=fluid.core.VarDesc.VarType.RAW)
    startup_block.append_op(
        type="gen_nccl_id",
        inputs={},
        outputs={"NCCLID": nccl_id_var},
        attrs={
            "endpoint": current_endpoint,
            "endpoint_list": eps,
            "trainer_id": trainer_id
        })
    return nccl_id_var
167

168 169 170 171
def pad_batch_data(insts,
                   pad_idx,
                   n_head,
                   is_target=False,
                   is_label=False,
                   return_attn_bias=True,
                   return_max_len=True,
                   return_num_token=False):
    """
    Pad the instances to the max sequence length in batch, and generate the
    accompanying data.

    The returned list holds, in this order:
      - the padded token ids, int64 with shape [-1, 1];
      - if `is_label`, the label weights (1 for tokens, 0 for paddings) as
        float32 with shape [-1, 1]; otherwise the position ids (0 for
        paddings) as int64 with shape [-1, 1];
      - if `return_attn_bias`, the self attention bias as float32 with shape
        [batch, n_head, max_len, max_len];
      - if `return_max_len`, the max sequence length in the batch;
      - if `return_num_token`, the number of tokens excluding paddings.
    """
    lengths = [len(inst) for inst in insts]
    max_len = max(lengths)
    pad_counts = [max_len - length for length in lengths]

    # Any token included in dict can be used to pad, since the paddings' loss
    # will be masked out by weights and make no effect on parameter gradients.
    padded = np.array(
        [inst + [pad_idx] * extra for inst, extra in zip(insts, pad_counts)])
    outputs = [padded.astype("int64").reshape([-1, 1])]

    if is_label:
        # label weight
        weights = np.array([[1.] * length + [0.] * extra
                            for length, extra in zip(lengths, pad_counts)])
        outputs.append(weights.astype("float32").reshape([-1, 1]))
    else:
        # position data
        positions = np.array([
            list(range(length)) + [0] * extra
            for length, extra in zip(lengths, pad_counts)
        ])
        outputs.append(positions.astype("int64").reshape([-1, 1]))

    if return_attn_bias:
        if is_target:
            # Strictly upper triangular mask: every position is biased away
            # from the positions after it. The mask is the same for every
            # instance of the batch.
            future_mask = np.triu(
                np.ones((padded.shape[0], max_len, max_len)), 1)
            future_mask = future_mask.reshape([-1, 1, max_len, max_len])
            attn_bias = np.tile(future_mask, [1, n_head, 1, 1]) * [-1e9]
        else:
            # This is used to avoid attention on paddings.
            pad_mask = np.array([[0] * length + [-1e9] * extra
                                 for length, extra in zip(lengths, pad_counts)])
            attn_bias = np.tile(
                pad_mask.reshape([-1, 1, 1, max_len]),
                [1, n_head, max_len, 1])
        outputs.append(attn_bias.astype("float32"))

    if return_max_len:
        outputs.append(max_len)
    if return_num_token:
        outputs.append(sum(lengths))
    return outputs if len(outputs) > 1 else outputs[0]


225 226
def prepare_batch_input(insts, data_input_names, src_pad_idx, trg_pad_idx,
                        n_head, d_model):
    """
    Pad one batch and put all data needed by training into a dict.

    Each instance is indexed as (source ids, target ids, label ids). The
    padded arrays are paired with `data_input_names` in this order: source
    word, source position, source self attention bias, target word, target
    position, target self attention bias, target-source attention bias,
    label word, label weight. `d_model` is not used here.

    Returns:
        A tuple of the feed dict and a float32 array of shape [1] holding the
        number of label tokens in the batch.
    """
    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
        [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)

    # Restore the [batch, max_len, 1] layout from the flattened ids.
    src_word = src_word.reshape(-1, src_max_len, 1)
    src_pos = src_pos.reshape(-1, src_max_len, 1)
    trg_word = trg_word.reshape(-1, trg_max_len, 1)
    trg_pos = trg_pos.reshape(-1, trg_max_len, 1)

    # Take a single row of the source bias for every head and repeat it for
    # each target position.
    first_src_row = src_slf_attn_bias[:, :, ::src_max_len, :]
    trg_src_attn_bias = np.tile(first_src_row,
                                [1, 1, trg_max_len, 1]).astype("float32")

    lbl_word, lbl_weight, num_token = pad_batch_data(
        [inst[2] for inst in insts],
        trg_pad_idx,
        n_head,
        is_target=False,
        is_label=True,
        return_attn_bias=False,
        return_max_len=False,
        return_num_token=True)

    padded_inputs = [
        src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
        trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
    ]
    data_input_dict = dict(zip(data_input_names, padded_inputs))
    return data_input_dict, np.asarray([num_token], dtype="float32")
259 260


261
def prepare_data_generator(args, is_test, count, pyreader):
    """
    Data generator wrapper for DataReader.

    If py_reader is used, the data provider is set on `pyreader` and None is
    returned. Otherwise a callable is returned which creates a generator
    yielding lists of `count` batches, one batch for each device.

    Args:
        args: The parsed command line arguments.
        is_test: Whether to read the validation files rather than the
            training files.
        count: The number of devices.
        pyreader: The py_reader to decorate, only used if
            `args.use_py_reader` is set.
    """
    if is_test:
        file_pattern = args.val_file_pattern
    else:
        file_pattern = args.train_file_pattern
    # Without token batching one big batch is read for all devices and split
    # into `count` parts below.
    reader_batch_size = args.batch_size * (1 if args.use_token_batch else
                                           count)

    data_reader = reader.DataReader(
        fpattern=file_pattern,
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        token_delimiter=args.token_delimiter,
        use_token_batch=args.use_token_batch,
        batch_size=reader_batch_size,
        pool_size=args.pool_size,
        sort_type=args.sort_type,
        shuffle=args.shuffle,
        shuffle_batch=args.shuffle_batch,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2],
        # count start and end tokens out
        max_length=ModelHyperParams.max_length - 2,
        clip_last_batch=False).batch_generator

    def stack(batch_source, parts, clip_last=True):
        """Group every `parts` consecutive batches into one list."""

        def __impl__():
            pending = []
            for batch in batch_source():
                pending.append(batch)
                if len(pending) == parts:
                    yield pending
                    pending = []
            if len(pending) == parts:
                yield pending
            elif not clip_last:
                # Spread the instances of the remaining batches evenly over
                # the parts; the instances which do not fit are dropped.
                leftovers = []
                for batch in pending:
                    leftovers.extend(batch)
                if len(leftovers) > parts:
                    part_size = len(leftovers) // parts
                    yield [
                        leftovers[part_size * idx:part_size * (idx + 1)]
                        for idx in range(parts)
                    ]

        return __impl__

    def split(batch_source, parts):
        """Cut every batch into `parts` equally sized batches."""

        def __impl__():
            for batch in batch_source():
                part_size = len(batch) // parts
                for idx in range(parts):
                    begin = part_size * idx
                    yield batch[begin:begin + part_size]

        return __impl__

    if not args.use_token_batch:
        # to make data on each device have similar token number
        data_reader = split(data_reader, count)

    if not args.use_py_reader:
        # Data generator for multi-devices
        return stack(data_reader, count)

    pyreader.decorate_tensor_provider(py_reader_provider_wrapper(data_reader))
    return None


def prepare_feed_dict_list(data_generator, init_flag, count):
    """
    Prepare the list of feed dict for multi-devices.

    Args:
        data_generator: The generator yielding one list of batches for all
            devices at a time, or None if py_reader is used.
        init_flag: Whether to add the position encoding tables to the feed
            dict of each device.
        count: The number of devices.

    Returns:
        The list of feed dicts if it holds exactly `count` dicts, otherwise
        None.
    """
    feed_dicts = []

    if data_generator is not None:  # use_py_reader == False
        input_names = (encoder_data_input_fields +
                       decoder_data_input_fields[:-1] +
                       label_data_input_fields)
        for device_batch in next(data_generator):
            batch_dict, _ = prepare_batch_input(
                device_batch, input_names, ModelHyperParams.eos_idx,
                ModelHyperParams.eos_idx, ModelHyperParams.n_head,
                ModelHyperParams.d_model)
            feed_dicts.append(batch_dict)

    if init_flag:
        for device_idx in range(count):
            tables = {}
            for param_name in pos_enc_param_names:
                tables[param_name] = position_encoding_init(
                    ModelHyperParams.max_length + 1, ModelHyperParams.d_model)
            if device_idx < len(feed_dicts):
                # The batch data takes precedence if a name is in both dicts.
                merged = dict(tables)
                merged.update(feed_dicts[device_idx])
                feed_dicts[device_idx] = merged
            else:
                feed_dicts.append(tables)

    if len(feed_dicts) == count:
        return feed_dicts
    return None


def py_reader_provider_wrapper(data_reader):
    """
    Data provider needed by fluid.layers.py_reader.

    Returns a generator function which pads every batch read from
    `data_reader` and yields the padded arrays as a list ordered by the
    encoder, decoder (all but the last) and label input fields.
    """

    def py_reader_provider():
        input_names = (encoder_data_input_fields +
                       decoder_data_input_fields[:-1] +
                       label_data_input_fields)
        for batch in data_reader():
            feed, _ = prepare_batch_input(
                batch, input_names, ModelHyperParams.eos_idx,
                ModelHyperParams.eos_idx, ModelHyperParams.n_head,
                ModelHyperParams.d_model)
            yield [feed[name] for name in input_names]

    return py_reader_provider


def test_context(exe, train_exe, dev_count):
    """
    Context to do validation.

    Builds the test program, prepares the validation data and creates a
    ParallelExecutor sharing the variables from `train_exe`. The command line
    arguments are read from the module level `args`.

    Args:
        exe: The executor used to run the startup program of the test program.
        train_exe: The ParallelExecutor of training to share variables from.
        dev_count: The number of devices.

    Returns:
        A function running one pass over the validation data and returning
        the average cost per token and the perplexity.
    """
    startup_prog = fluid.Program()
    test_prog = fluid.Program()
    with fluid.program_guard(test_prog, startup_prog), \
            fluid.unique_name.guard():
        sum_cost, avg_cost, predict, token_num, pyreader = transformer(
            ModelHyperParams.src_vocab_size,
            ModelHyperParams.trg_vocab_size,
            ModelHyperParams.max_length + 1,
            ModelHyperParams.n_layer,
            ModelHyperParams.n_head,
            ModelHyperParams.d_key,
            ModelHyperParams.d_value,
            ModelHyperParams.d_model,
            ModelHyperParams.d_inner_hid,
            ModelHyperParams.prepostprocess_dropout,
            ModelHyperParams.attention_dropout,
            ModelHyperParams.relu_dropout,
            ModelHyperParams.preprocess_cmd,
            ModelHyperParams.postprocess_cmd,
            ModelHyperParams.weight_sharing,
            TrainTaskConfig.label_smooth_eps,
            use_py_reader=args.use_py_reader,
            is_test=True)

    test_data = prepare_data_generator(
        args, is_test=True, count=dev_count, pyreader=pyreader)

    exe.run(startup_prog)
    test_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        main_program=test_prog,
        share_vars_from=train_exe)
    fetch_names = [sum_cost.name, token_num.name]

    def test(exe=test_exe, pyreader=pyreader):
        # NOTE: the batches are always run with `test_exe`, the `exe`
        # argument is not used.
        cost_total = 0
        token_total = 0

        if args.use_py_reader:
            pyreader.start()
            data_generator = None
        else:
            data_generator = test_data()

        while True:
            try:
                feed_dict_list = prepare_feed_dict_list(data_generator, False,
                                                        dev_count)
                outs = test_exe.run(fetch_list=fetch_names,
                                    feed=feed_dict_list)
            except (StopIteration, fluid.core.EOFException):
                # The current pass is over.
                if args.use_py_reader:
                    pyreader.reset()
                break
            # sum the values from multi-devices
            cost_total += np.array(outs[0]).sum()
            token_total += np.array(outs[1]).sum()

        test_avg_cost = cost_total / token_total
        # The cost is capped at 100 before the exponential is taken.
        test_ppl = np.exp([min(test_avg_cost, 100)])
        return test_avg_cost, test_ppl

    return test


444
def train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
               token_num, predict, pyreader, nccl2_num_trainers=1,
               nccl2_trainer_id=0):
    """
    Run all training passes with a ParallelExecutor.

    The command line arguments are read from the module level `args`. A
    checkpoint and the parameters are saved every TrainTaskConfig.save_freq
    steps and the persistables are saved after every pass. If a validation
    file pattern is given, validation runs after every pass.

    Args:
        exe: The executor used to initialize the parameters and to save them.
        train_prog: The main program to train.
        startup_prog: The startup program, run only if no checkpoint path is
            configured.
        dev_count: The number of devices.
        sum_cost: The variable of summed cost, fetched by name.
        avg_cost: The variable of average cost, used as the loss.
        token_num: The variable of token number, fetched by name.
        predict: Not used.
        pyreader: The py_reader started and reset in each pass if
            `args.use_py_reader` is set.
        nccl2_num_trainers: The number of trainers passed to the
            ParallelExecutor.
        nccl2_trainer_id: The trainer id passed to the ParallelExecutor.
    """
    # Initialize the parameters.
    if TrainTaskConfig.ckpt_path:
        fluid.io.load_persistables(exe, TrainTaskConfig.ckpt_path)
    else:
        logging.info("init fluid.framework.default_startup_program")
        exe.run(startup_prog)

    logging.info("begin reader")
    train_data = prepare_data_generator(
        args, is_test=False, count=dev_count, pyreader=pyreader)

    # NOTE: the token number differs among devices, but the customized
    # gradient scale strategy (token average cost among multi-devices) is
    # left disabled here, as it was in the original code.
    build_strategy = fluid.BuildStrategy()

    # NOTE: the original created an ExecutionStrategy with
    # use_experimental_executor enabled and then replaced it with a fresh
    # one before use, so the flag never took effect. Only the strategy which
    # was effectively used is kept.
    exec_strategy = fluid.ExecutionStrategy()
    if args.update_method == "nccl2":
        exec_strategy.num_threads = 1

    logging.info("begin executor")
    train_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        loss_name=avg_cost.name,
        main_program=train_prog,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy,
        num_trainers=nccl2_num_trainers,
        trainer_id=nccl2_trainer_id)

    if args.val_file_pattern is not None:
        test = test_context(exe, train_exe, dev_count)

    # the best cross-entropy value with label smoothing
    smooth_eps = TrainTaskConfig.label_smooth_eps
    loss_normalizer = -(
        (1. - smooth_eps) * np.log(1. - smooth_eps) + smooth_eps * np.log(
            smooth_eps / (ModelHyperParams.trg_vocab_size - 1) + 1e-20))

    step_idx = 0
    init_flag = True
    # Reported for CE after training. They stay None if no value was
    # measured, e.g. if a pass is shorter than `fetch_steps` batches or if
    # no pass is run at all.
    total_avg_cost = None
    val_avg_cost = None
    time_consumed = None

    logging.info("begin train")
    for pass_id in six.moves.xrange(TrainTaskConfig.pass_num):
        pass_start_time = time.time()

        if args.use_py_reader:
            pyreader.start()
            data_generator = None
        else:
            data_generator = train_data()

        batch_id = 0
        avg_batch_time = time.time()
        while True:
            try:
                feed_dict_list = prepare_feed_dict_list(data_generator,
                                                        init_flag, dev_count)

                # The costs are fetched every `fetch_steps` batches only.
                is_fetch_step = batch_id % args.fetch_steps == 0
                outs = train_exe.run(
                    fetch_list=[sum_cost.name, token_num.name]
                    if is_fetch_step else [],
                    feed=feed_dict_list)
                # The first batch of a pass is fetched but not reported.
                report = is_fetch_step and batch_id > 0

                if report:
                    sum_cost_val = np.array(outs[0])
                    token_num_val = np.array(outs[1])
                    # sum the cost from multi-devices
                    total_sum_cost = sum_cost_val.sum()
                    total_token_num = token_num_val.sum()
                    total_avg_cost = total_sum_cost / total_token_num

                    logging.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f, speed: %.2f step/s",
                        step_idx, pass_id, batch_id, total_avg_cost,
                        total_avg_cost - loss_normalizer,
                        np.exp(min(total_avg_cost, 100)),
                        args.fetch_steps / (time.time() - avg_batch_time))

                if step_idx % int(TrainTaskConfig.
                                  save_freq) == TrainTaskConfig.save_freq - 1:
                    fluid.io.save_persistables(
                        exe,
                        os.path.join(TrainTaskConfig.ckpt_dir,
                                     "latest.checkpoint"), train_prog)
                    fluid.io.save_params(
                        exe,
                        os.path.join(TrainTaskConfig.model_dir,
                                     "iter_" + str(step_idx) + ".infer.model"),
                        train_prog)

                if report:
                    # Restart the timer of the speed measurement.
                    avg_batch_time = time.time()
                init_flag = False
                batch_id += 1
                step_idx += 1
            except (StopIteration, fluid.core.EOFException):
                # The current pass is over.
                if args.use_py_reader:
                    pyreader.reset()
                break

        time_consumed = time.time() - pass_start_time
        # Validate and save the persistable.
        if args.val_file_pattern is not None:
            val_avg_cost, val_ppl = test()
            logging.info(
                "epoch: %d, val avg loss: %f, val normalized loss: %f, val ppl: %f,"
                " consumed %fs", pass_id, val_avg_cost,
                val_avg_cost - loss_normalizer, val_ppl, time_consumed)
        else:
            logging.info("epoch: %d, consumed %fs", pass_id, time_consumed)
        fluid.io.save_persistables(
            exe,
            os.path.join(TrainTaskConfig.ckpt_dir,
                         "pass_" + str(pass_id) + ".checkpoint"), train_prog)

    if args.enable_ce:  # For CE
        # A KPI is printed only if it was measured; the original raised a
        # NameError here if one of the values had never been assigned.
        if total_avg_cost is not None:
            print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost))
        else:
            logging.warning("no train cost was fetched, kpi is skipped")
        if args.val_file_pattern is not None and val_avg_cost is not None:
            print("kpis\ttest_cost_card%d\t%f" % (dev_count, val_avg_cost))
        if time_consumed is not None:
            print("kpis\ttrain_duration_card%d\t%f" %
                  (dev_count, time_consumed))
Q
Qiao Longfei 已提交
573 574


575 576 577 578 579
def train(args):
    """
    Build the training program and run it in local or distributed mode.

    The mode is selected from the environment and the arguments:
      - local, if `args.local` is set and PADDLE_IS_LOCAL is not "0";
      - nccl2, if `args.update_method` is "nccl2";
      - parameter server otherwise, in which TRAINING_ROLE selects whether
        this process runs as "PSERVER" or as "TRAINER".

    Args:
        args: The parsed command line arguments. `args.local` is overwritten
            if PADDLE_IS_LOCAL is "0".
    """
    # priority: ENV > args > config
    is_local = os.getenv("PADDLE_IS_LOCAL", "1")
    if is_local == '0':
        args.local = False
    logging.info(args)

    if args.device == 'CPU':
        TrainTaskConfig.use_gpu = False

    training_role = os.getenv("TRAINING_ROLE", "TRAINER")

    if training_role == "PSERVER" or (not TrainTaskConfig.use_gpu):
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
    else:
        place = fluid.CUDAPlace(0)
        dev_count = fluid.core.get_cuda_device_count()

    exe = fluid.Executor(place)

    train_prog = fluid.Program()
    startup_prog = fluid.Program()

    if args.enable_ce:
        # Fixed seeds for continuous evaluation.
        train_prog.random_seed = 1000
        startup_prog.random_seed = 1000

    with fluid.program_guard(train_prog, startup_prog):
        with fluid.unique_name.guard():
            sum_cost, avg_cost, predict, token_num, pyreader = transformer(
                ModelHyperParams.src_vocab_size,
                ModelHyperParams.trg_vocab_size,
                ModelHyperParams.max_length + 1,
                ModelHyperParams.n_layer,
                ModelHyperParams.n_head,
                ModelHyperParams.d_key,
                ModelHyperParams.d_value,
                ModelHyperParams.d_model,
                ModelHyperParams.d_inner_hid,
                ModelHyperParams.prepostprocess_dropout,
                ModelHyperParams.attention_dropout,
                ModelHyperParams.relu_dropout,
                ModelHyperParams.preprocess_cmd,
                ModelHyperParams.postprocess_cmd,
                ModelHyperParams.weight_sharing,
                TrainTaskConfig.label_smooth_eps,
                use_py_reader=args.use_py_reader,
                is_test=False)

            # The original left `optimizer` undefined for distributed
            # training in sync mode. Adam with noam decay is now used for
            # that case as well; async distributed training keeps SGD.
            if args.local or args.sync:
                lr_decay = fluid.layers.learning_rate_scheduler.noam_decay(
                    ModelHyperParams.d_model, TrainTaskConfig.warmup_steps)
                optimizer = fluid.optimizer.Adam(
                    learning_rate=lr_decay * TrainTaskConfig.learning_rate,
                    beta1=TrainTaskConfig.beta1,
                    beta2=TrainTaskConfig.beta2,
                    epsilon=TrainTaskConfig.eps)
            else:
                optimizer = fluid.optimizer.SGD(0.003)
            optimizer.minimize(avg_cost)

    if args.use_mem_opt:
        fluid.memory_optimize(train_prog)

    if args.local:
        print("local start_up:")
        train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
                   token_num, predict, pyreader)
        return

    if args.update_method == "nccl2":
        trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        port = os.getenv("PADDLE_PORT")
        worker_ips = os.getenv("PADDLE_TRAINERS")
        worker_endpoints = [
            ':'.join([ip, port]) for ip in worker_ips.split(",")
        ]
        trainers_num = len(worker_endpoints)
        current_endpoint = os.getenv("POD_IP") + ":" + port
        if trainer_id == 0:
            logging.info("train_id == 0, sleep 60s")
            time.sleep(60)
        print("trainers_num:", trainers_num)
        print("worker_endpoints:", worker_endpoints)
        print("current_endpoint:", current_endpoint)
        # append_nccl2_prepare adds its op to the default startup program.
        # The guard makes `startup_prog` that default, so the op is run by
        # train_loop together with the parameter initialization.
        with fluid.program_guard(train_prog, startup_prog):
            append_nccl2_prepare(trainer_id, worker_endpoints,
                                 current_endpoint)
        # The original call passed the wrong positional arguments and an
        # undefined `lr_scheduler`.
        train_loop(exe, train_prog, startup_prog, dev_count, sum_cost,
                   avg_cost, token_num, predict, pyreader, trainers_num,
                   trainer_id)
        return

    # Parameter server mode.
    port = os.getenv("PADDLE_PORT", "6174")
    pserver_ips = os.getenv("PADDLE_PSERVERS")  # ip,ip...
    eplist = [':'.join([ip, port]) for ip in pserver_ips.split(",")]
    pserver_endpoints = ",".join(eplist)  # ip:port,ip:port...
    trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
    trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
    t = fluid.DistributeTranspiler()
    t.transpile(
        trainer_id,
        pservers=pserver_endpoints,
        trainers=trainers,
        program=train_prog,
        startup_program=startup_prog)

    if training_role == "PSERVER":
        logging.info("distributed: pserver started")
        pod_ip = os.getenv("POD_IP")
        pserver_port = os.getenv("PADDLE_PORT")
        # Check the variables before they are joined; the original check ran
        # on the joined string and could never fail.
        if not pod_ip or not pserver_port:
            print("need env POD_IP and PADDLE_PORT")
            sys.exit(1)
        current_endpoint = pod_ip + ":" + pserver_port
        pserver_prog = t.get_pserver_program(current_endpoint)
        pserver_startup = t.get_startup_program(current_endpoint,
                                                pserver_prog)

        exe.run(pserver_startup)
        exe.run(pserver_prog)
    elif training_role == "TRAINER":
        logging.info("distributed: trainer started")
        # NOTE(review): the returned program is not used, training runs on
        # `train_prog`. The call is kept in case it has side effects on the
        # transpiled program - confirm.
        t.get_trainer_program()
        train_loop(exe, train_prog, startup_prog, dev_count, sum_cost,
                   avg_cost, token_num, predict, pyreader)
    else:
        logging.critical(
            "environment var TRAINING_ROLE should be TRAINER or PSERVER")
        sys.exit(1)


if __name__ == "__main__":
    # Write all log records, from DEBUG level on, to stdout.
    LOG_FORMAT = (
        "[%(asctime)s %(levelname)s %(filename)s:%(lineno)d] %(message)s")
    logging.basicConfig(
        format=LOG_FORMAT, level=logging.DEBUG, stream=sys.stdout)

    # `args` has to be a module level name, since test_context and
    # train_loop read it as a global.
    args = parse_args()
    train(args)