import argparse
import ast
import multiprocessing
import os
import sys
import time

import numpy as np
import paddle.fluid as fluid
import six

import reader
from config import *
from model import transformer, position_encoding_init


def parse_args():
    """Parse command-line arguments and merge config overrides.

    Trailing positional ``opts`` tokens (see config.py) are merged into
    ``TrainTaskConfig`` and ``ModelHyperParams`` via ``merge_cfg_from_list``.
    The vocabulary sizes and the bos/eos/unk indices looked up in the source
    vocabulary are appended after the user-supplied overrides.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser("Training for Transformer.")
    parser.add_argument(
        "--src_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of source language.")
    parser.add_argument(
        "--trg_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of target language.")
    parser.add_argument(
        "--train_file_pattern",
        type=str,
        required=True,
        help="The pattern to match training data files.")
    parser.add_argument(
        "--val_file_pattern",
        type=str,
        help="The pattern to match validation data files.")
    parser.add_argument(
        "--use_token_batch",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to "
        "produce batch data according to token number.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=4096,
        help="The number of sequences contained in a mini-batch, or the maximum "
        "number of tokens (include paddings) contained in a mini-batch. Note "
        "that this represents the number on single device and the actual batch "
        "size for multi-devices will multiply the device number.")
    parser.add_argument(
        "--pool_size",
        type=int,
        default=200000,
        help="The buffer size to pool data.")
    parser.add_argument(
        "--sort_type",
        default="pool",
        choices=("global", "pool", "none"),
        help="The grain to sort by length: global for all instances; pool for "
        "instances in pool; none for no sort.")
    parser.add_argument(
        "--shuffle",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to shuffle instances in each pass.")
    parser.add_argument(
        "--shuffle_batch",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to shuffle the data batches.")
    parser.add_argument(
        "--special_token",
        type=str,
        default=["<s>", "<e>", "<unk>"],
        nargs=3,
        help="The <bos>, <eos> and <unk> tokens in the dictionary.")
    parser.add_argument(
        "--token_delimiter",
        # Decode escape sequences (e.g. a typed backslash-x01) into the
        # character they denote.
        type=lambda x: str(x.encode().decode("unicode-escape")),
        default=" ",
        help="The delimiter used to split tokens in source or target sentences. "
        "For EN-DE BPE data we provided, use spaces as token delimiter. "
        "For EN-FR wordpiece data we provided, use '\\x01' as token delimiter.")
    parser.add_argument(
        'opts',
        help='See config.py for all options',
        default=None,
        nargs=argparse.REMAINDER)
    parser.add_argument(
        '--local',
        type=ast.literal_eval,
        default=True,
        help='Whether to run as local mode.')
    parser.add_argument(
        '--device',
        type=str,
        default='GPU',
        choices=['CPU', 'GPU'],
        help="The device type.")
    parser.add_argument(
        '--sync', type=ast.literal_eval, default=True, help="sync mode.")
    parser.add_argument(
        "--enable_ce",
        type=ast.literal_eval,
        default=False,
        help="The flag indicating whether to run the task "
        "for continuous evaluation.")
    parser.add_argument(
        "--use_mem_opt",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to use memory optimization.")
    parser.add_argument(
        "--use_py_reader",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to use py_reader.")

    args = parser.parse_args()
    # Append args related to dict; the special-token indices are all looked
    # up in the source vocabulary.
    src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
    trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
    bos_token, eos_token, unk_token = args.special_token
    dict_args = [
        "src_vocab_size", str(len(src_dict)),
        "trg_vocab_size", str(len(trg_dict)),
        "bos_idx", str(src_dict[bos_token]),
        "eos_idx", str(src_dict[eos_token]),
        "unk_idx", str(src_dict[unk_token]),
    ]
    merge_cfg_from_list(args.opts + dict_args,
                        [TrainTaskConfig, ModelHyperParams])
    return args


def pad_batch_data(insts,
                   pad_idx,
                   n_head,
                   is_target=False,
                   is_label=False,
                   return_attn_bias=True,
                   return_max_len=True,
                   return_num_token=False):
    """
    Pad the instances to the max sequence length in batch, and generate the
    corresponding position data and attention bias.

    Args:
        insts: list of token-id lists, one per instance.
        pad_idx: id written into padded positions.
        n_head: number of attention heads the bias is tiled over.
        is_target: build a causal (strict upper-triangular) self-attention
            bias instead of a padding bias.
        is_label: return per-token weights (1. for real tokens, 0. for
            paddings) instead of position ids.
        return_attn_bias: append the float32 attention bias of shape
            ``[batch, n_head, max_len, max_len]``.
        return_max_len: append the max sequence length.
        return_num_token: append the total number of non-padding tokens.

    Returns:
        ``[word_ids, pos_or_weight, (attn_bias), (max_len), (num_token)]`` in
        that order, where ``word_ids``/``pos_or_weight`` are reshaped to
        ``[-1, 1]``. A single item is returned unwrapped.
    """
    return_list = []
    max_len = max(len(inst) for inst in insts)
    # Any token included in dict can be used to pad, since the paddings' loss
    # will be masked out by weights and make no effect on parameter gradients.
    inst_data = np.array(
        [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
    return_list += [inst_data.astype("int64").reshape([-1, 1])]
    if is_label:  # label weight
        inst_weight = np.array(
            [[1.] * len(inst) + [0.] * (max_len - len(inst)) for inst in insts])
        return_list += [inst_weight.astype("float32").reshape([-1, 1])]
    else:  # position data
        inst_pos = np.array([
            list(range(0, len(inst))) + [0] * (max_len - len(inst))
            for inst in insts
        ])
        return_list += [inst_pos.astype("int64").reshape([-1, 1])]
    if return_attn_bias:
        if is_target:
            # Mask attention on subsequent words (strict upper triangle).
            # Trailing paddings always come after every real token, so they
            # are masked for real-token queries as well.
            slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, max_len))
            slf_attn_bias_data = np.triu(slf_attn_bias_data,
                                         1).reshape([-1, 1, max_len, max_len])
            slf_attn_bias_data = np.tile(slf_attn_bias_data,
                                         [1, n_head, 1, 1]) * [-1e9]
        else:
            # This is used to avoid attention on paddings.
            slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
                                           (max_len - len(inst))
                                           for inst in insts])
            slf_attn_bias_data = np.tile(
                slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
                [1, n_head, max_len, 1])
        return_list += [slf_attn_bias_data.astype("float32")]
    if return_max_len:
        return_list += [max_len]
    if return_num_token:
        return_list += [sum(len(inst) for inst in insts)]
    return return_list if len(return_list) > 1 else return_list[0]


def prepare_batch_input(insts, data_input_names, src_pad_idx, trg_pad_idx,
                        n_head, d_model):
    """
    Put all padded data needed by training into a dict.

    Each instance in ``insts`` is indexed as ``(src_ids, trg_ids, lbl_ids)``.
    ``d_model`` is accepted for interface compatibility but is not used here.

    Returns:
        ``(data_input_dict, num_token)``: a dict zipping ``data_input_names``
        with the arrays src_word, src_pos, src_slf_attn_bias, trg_word,
        trg_pos, trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
        (in that order), and a 1-element float32 array holding the number of
        non-padding label tokens.
    """
    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
    src_word = src_word.reshape(-1, src_max_len, 1)
    src_pos = src_pos.reshape(-1, src_max_len, 1)
    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
        [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
    trg_word = trg_word.reshape(-1, trg_max_len, 1)
    trg_pos = trg_pos.reshape(-1, trg_max_len, 1)

    # All query rows of the source padding bias are identical; the
    # `::src_max_len` step keeps only the first row, which is then repeated
    # once per target position.
    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                [1, 1, trg_max_len, 1]).astype("float32")

    lbl_word, lbl_weight, num_token = pad_batch_data(
        [inst[2] for inst in insts],
        trg_pad_idx,
        n_head,
        is_target=False,
        is_label=True,
        return_attn_bias=False,
        return_max_len=False,
        return_num_token=True)

    data_input_dict = dict(
        zip(data_input_names, [
            src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
            trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
        ]))

    return data_input_dict, np.asarray([num_token], dtype="float32")


def prepare_data_generator(args, is_test, count, pyreader):
    """
    Data generator wrapper for DataReader. If use py_reader, set the data
    provider for py_reader.

    Args:
        args: parsed command-line arguments.
        is_test: read ``args.val_file_pattern`` instead of
            ``args.train_file_pattern``.
        count: number of devices the data is distributed over.
        pyreader: the py_reader to decorate when ``args.use_py_reader``.

    Returns:
        None when py_reader is used (data then flows through ``pyreader``);
        otherwise a callable returning a generator that yields lists of
        ``count`` batches, one per device.
    """
    data_reader = reader.DataReader(
        fpattern=args.val_file_pattern if is_test else args.train_file_pattern,
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        token_delimiter=args.token_delimiter,
        use_token_batch=args.use_token_batch,
        # With sequence-count batching, read one large batch and split it
        # across the devices below.
        batch_size=args.batch_size * (1 if args.use_token_batch else count),
        pool_size=args.pool_size,
        sort_type=args.sort_type,
        shuffle=args.shuffle,
        shuffle_batch=args.shuffle_batch,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2],
        # count start and end tokens out
        max_length=ModelHyperParams.max_length - 2,
        clip_last_batch=False).batch_generator

    def stack(data_reader, count, clip_last=True):
        """Group every ``count`` consecutive batches into one list."""

        def __impl__():
            res = []
            for item in data_reader():
                res.append(item)
                if len(res) == count:
                    yield res
                    res = []
            # Here fewer than `count` batches remain (for count >= 1).
            if not clip_last:
                # Re-split the leftover instances evenly over the devices,
                # dropping the remainder, as long as every device gets one.
                data = []
                for item in res:
                    data += item
                if len(data) >= count:
                    inst_num_per_part = len(data) // count
                    yield [
                        data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
                        for i in range(count)
                    ]

        return __impl__

    def split(data_reader, count):
        """Cut each batch into ``count`` equal parts, dropping the remainder."""

        def __impl__():
            for item in data_reader():
                inst_num_per_part = len(item) // count
                for i in range(count):
                    start = inst_num_per_part * i
                    yield item[start:start + inst_num_per_part]

        return __impl__

    if not args.use_token_batch:
        # to make data on each device have similar token number
        data_reader = split(data_reader, count)
    if args.use_py_reader:
        pyreader.decorate_tensor_provider(
            py_reader_provider_wrapper(data_reader))
        data_reader = None
    else:  # Data generator for multi-devices
        data_reader = stack(data_reader, count)
    return data_reader


def prepare_feed_dict_list(data_generator, init_flag, count):
    """
    Prepare the list of feed dict for multi-devices.

    Args:
        data_generator: generator yielding, per step, a list of batches (one
            per device); ``None`` when py_reader supplies the data.
        init_flag: also feed the position encoding tables.
        count: number of devices.

    Returns:
        A list of ``count`` feed dicts, or None when the number of prepared
        dicts differs from ``count`` (e.g. py_reader is used and
        ``init_flag`` is False, so nothing needs feeding).

    Raises:
        StopIteration: when ``data_generator`` is exhausted.
    """
    feed_dict_list = []
    if data_generator is not None:  # use_py_reader == False
        data_input_names = encoder_data_input_fields + \
                    decoder_data_input_fields[:-1] + label_data_input_fields
        data = next(data_generator)
        for data_buffer in data:
            data_input_dict, _ = prepare_batch_input(
                data_buffer, data_input_names, ModelHyperParams.eos_idx,
                ModelHyperParams.eos_idx, ModelHyperParams.n_head,
                ModelHyperParams.d_model)
            feed_dict_list.append(data_input_dict)
    if init_flag:
        # The tables depend only on model hyper-parameters, so build them once
        # rather than once per device.
        pos_enc_tables = dict()
        for pos_enc_param_name in pos_enc_param_names:
            pos_enc_tables[pos_enc_param_name] = position_encoding_init(
                ModelHyperParams.max_length + 1, ModelHyperParams.d_model)
        for idx in range(count):
            merged = dict(pos_enc_tables)
            if len(feed_dict_list) <= idx:
                feed_dict_list.append(merged)
            else:
                # Batch inputs override same-named table entries.
                merged.update(feed_dict_list[idx])
                feed_dict_list[idx] = merged

    return feed_dict_list if len(feed_dict_list) == count else None


def py_reader_provider_wrapper(data_reader):
    """
    Data provider needed by fluid.layers.py_reader.

    Args:
        data_reader: callable returning an iterable of batches, each a list
            of instances accepted by ``prepare_batch_input``.

    Returns:
        A generator function that yields, per batch, the list of input
        arrays ordered like the encoder/decoder/label input field names.
    """

    def py_reader_provider():
        data_input_names = encoder_data_input_fields + \
                    decoder_data_input_fields[:-1] + label_data_input_fields
        for data in data_reader():
            data_input_dict, _ = prepare_batch_input(
                data, data_input_names, ModelHyperParams.eos_idx,
                ModelHyperParams.eos_idx, ModelHyperParams.n_head,
                ModelHyperParams.d_model)
            yield [data_input_dict[item] for item in data_input_names]

    return py_reader_provider


def test_context(exe, train_exe, dev_count):
    """Build the validation program and return a function that evaluates it.

    The test program gets its own Program pair; its executor is created with
    ``share_vars_from=train_exe`` so validation runs on the variables of the
    training executor.

    NOTE: reads the module-level ``args`` bound in ``__main__``.

    Returns:
        A callable ``test()`` returning ``(avg_cost, ppl)`` over the whole
        validation set, where ``ppl`` is a 1-element array.
    """
    # Context to do validation.
    startup_prog = fluid.Program()
    test_prog = fluid.Program()
    with fluid.program_guard(test_prog, startup_prog):
        # Fresh name scope; presumably so the parameter names match those of
        # the training program for variable sharing — confirm.
        with fluid.unique_name.guard():
            sum_cost, avg_cost, predict, token_num, pyreader = transformer(
                ModelHyperParams.src_vocab_size,
                ModelHyperParams.trg_vocab_size,
                ModelHyperParams.max_length + 1,
                ModelHyperParams.n_layer,
                ModelHyperParams.n_head,
                ModelHyperParams.d_key,
                ModelHyperParams.d_value,
                ModelHyperParams.d_model,
                ModelHyperParams.d_inner_hid,
                ModelHyperParams.prepostprocess_dropout,
                ModelHyperParams.attention_dropout,
                ModelHyperParams.relu_dropout,
                ModelHyperParams.preprocess_cmd,
                ModelHyperParams.postprocess_cmd,
                ModelHyperParams.weight_sharing,
                TrainTaskConfig.label_smooth_eps,
                use_py_reader=args.use_py_reader,
                is_test=True)

    test_data = prepare_data_generator(
        args, is_test=True, count=dev_count, pyreader=pyreader)

    exe.run(startup_prog)
    test_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        main_program=test_prog,
        share_vars_from=train_exe)

    def test(exe=test_exe, pyreader=pyreader):
        """Run one full pass over the validation data."""
        test_total_cost = 0
        test_total_token = 0

        if args.use_py_reader:
            pyreader.start()
            data_generator = None
        else:
            data_generator = test_data()
        while True:
            try:
                feed_dict_list = prepare_feed_dict_list(data_generator, False,
                                                        dev_count)
                outs = exe.run(fetch_list=[sum_cost.name, token_num.name],
                               feed=feed_dict_list)
            except (StopIteration, fluid.core.EOFException):
                # The current pass is over.
                if args.use_py_reader:
                    pyreader.reset()
                break
            sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
            # sum the cost from multi-devices
            test_total_cost += sum_cost_val.sum()
            test_total_token += token_num_val.sum()
        test_avg_cost = test_total_cost / test_total_token
        # Cap the exponent to avoid overflow in the perplexity.
        test_ppl = np.exp([min(test_avg_cost, 100)])
        return test_avg_cost, test_ppl

    return test


def train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
               token_num, predict, pyreader):
    """Run the training passes, logging the loss and saving checkpoints.

    Parameters are restored from ``TrainTaskConfig.ckpt_path`` when it is
    set, otherwise initialized by running ``startup_prog``.

    Every ``TrainTaskConfig.save_freq`` steps the persistables are saved to
    ``ckpt_dir/latest.checkpoint`` and the parameters to
    ``model_dir/iter_<step>.infer.model``. After each pass the model is
    validated when ``args.val_file_pattern`` is given, and the persistables
    are saved to ``ckpt_dir/pass_<id>.checkpoint``.

    NOTE: reads the module-level ``args`` bound in ``__main__``.
    ``predict`` is accepted but not used here.
    """
    # Initialize the parameters.
    if TrainTaskConfig.ckpt_path:
        fluid.io.load_persistables(exe, TrainTaskConfig.ckpt_path)
    else:
        print("init fluid.framework.default_startup_program")
        exe.run(startup_prog)

    train_data = prepare_data_generator(
        args, is_test=False, count=dev_count, pyreader=pyreader)

    # For faster executor
    exec_strategy = fluid.ExecutionStrategy()
    exec_strategy.use_experimental_executor = True
    # exec_strategy.num_iteration_per_drop_scope = 5
    build_strategy = fluid.BuildStrategy()
    # Since the token number differs among devices, customize gradient scale to
    # use token average cost among multi-devices. and the gradient scale is
    # `1 / token_number` for average cost.
    build_strategy.gradient_scale_strategy = (
        fluid.BuildStrategy.GradientScaleStrategy.Customized)
    train_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        loss_name=avg_cost.name,
        main_program=train_prog,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)

    if args.val_file_pattern is not None:
        test = test_context(exe, train_exe, dev_count)

    # the best cross-entropy value with label smoothing
    eps = TrainTaskConfig.label_smooth_eps
    loss_normalizer = -((1. - eps) * np.log((1. - eps)) + eps * np.log(
        eps / (ModelHyperParams.trg_vocab_size - 1) + 1e-20))

    # Coerce once so both sides of the checkpoint test use the same int.
    save_freq = int(TrainTaskConfig.save_freq)
    step_idx = 0
    init_flag = True
    for pass_id in six.moves.xrange(TrainTaskConfig.pass_num):
        pass_start_time = time.time()

        if args.use_py_reader:
            pyreader.start()
            data_generator = None
        else:
            data_generator = train_data()

        batch_id = 0
        while True:
            try:
                feed_dict_list = prepare_feed_dict_list(data_generator,
                                                        init_flag, dev_count)

                outs = train_exe.run(
                    fetch_list=[sum_cost.name, token_num.name],
                    feed=feed_dict_list)
                sum_cost_val, token_num_val = np.array(outs[0]), np.array(
                    outs[1])
                # sum the cost from multi-devices
                total_sum_cost = sum_cost_val.sum()
                total_token_num = token_num_val.sum()
                total_avg_cost = total_sum_cost / total_token_num

                # Scalar (not 1-element array) so "%f" formatting stays valid.
                print("step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                      "normalized loss: %f, ppl: %f" %
                      (step_idx, pass_id, batch_id, total_avg_cost,
                       total_avg_cost - loss_normalizer,
                       np.exp(min(total_avg_cost, 100))))

                if step_idx % save_freq == save_freq - 1:
                    fluid.io.save_persistables(
                        exe,
                        os.path.join(TrainTaskConfig.ckpt_dir,
                                     "latest.checkpoint"), train_prog)
                    fluid.io.save_params(
                        exe,
                        os.path.join(TrainTaskConfig.model_dir,
                                     "iter_" + str(step_idx) + ".infer.model"),
                        train_prog)
                # Position encoding tables only need feeding on the first step.
                init_flag = False
                batch_id += 1
                step_idx += 1
            except (StopIteration, fluid.core.EOFException):
                # The current pass is over.
                if args.use_py_reader:
                    pyreader.reset()
                break

        time_consumed = time.time() - pass_start_time
        # Validate and save the persistable.
        if args.val_file_pattern is not None:
            val_avg_cost, val_ppl = test()
            print(
                "epoch: %d, val avg loss: %f, val normalized loss: %f, val ppl: %f,"
                " consumed %fs" % (pass_id, val_avg_cost,
                                   val_avg_cost - loss_normalizer, val_ppl,
                                   time_consumed))
        else:
            print("epoch: %d, consumed %fs" % (pass_id, time_consumed))
        fluid.io.save_persistables(
            exe,
            os.path.join(TrainTaskConfig.ckpt_dir,
                         "pass_" + str(pass_id) + ".checkpoint"), train_prog)

    if args.enable_ce:  # For CE
        print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost))
        if args.val_file_pattern is not None:
            print("kpis\ttest_cost_card%d\t%f" % (dev_count, val_avg_cost))
        print("kpis\ttrain_duration_card%d\t%f" % (dev_count, time_consumed))


def train(args):
    """Build the training program and run it locally or distributed.

    Priority is environment > ``args`` > config.py.
    ``PADDLE_IS_LOCAL=0`` forces distributed mode, and ``TRAINING_ROLE``
    (``TRAINER`` or ``PSERVER``, default ``TRAINER``) selects this process's
    role. Distributed mode also reads ``PADDLE_PSERVERS``, ``PADDLE_PORT``
    (default ``"6174"``), ``PADDLE_TRAINERS_NUM`` and ``PADDLE_TRAINER_ID``.
    A parameter server additionally reads ``POD_IP``.

    Optimizer: local and synchronous distributed training use Adam with the
    Noam learning-rate schedule; asynchronous distributed training uses SGD.
    """
    # priority: ENV > args > config
    is_local = os.getenv("PADDLE_IS_LOCAL", "1")
    if is_local == '0':
        args.local = False
    print(args)

    if args.device == 'CPU':
        TrainTaskConfig.use_gpu = False

    training_role = os.getenv("TRAINING_ROLE", "TRAINER")

    if training_role == "PSERVER" or (not TrainTaskConfig.use_gpu):
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
    else:
        place = fluid.CUDAPlace(0)
        dev_count = fluid.core.get_cuda_device_count()

    exe = fluid.Executor(place)

    train_prog = fluid.Program()
    startup_prog = fluid.Program()

    if args.enable_ce:
        train_prog.random_seed = 1000
        startup_prog.random_seed = 1000

    with fluid.program_guard(train_prog, startup_prog):
        with fluid.unique_name.guard():
            sum_cost, avg_cost, predict, token_num, pyreader = transformer(
                ModelHyperParams.src_vocab_size,
                ModelHyperParams.trg_vocab_size,
                ModelHyperParams.max_length + 1,
                ModelHyperParams.n_layer,
                ModelHyperParams.n_head,
                ModelHyperParams.d_key,
                ModelHyperParams.d_value,
                ModelHyperParams.d_model,
                ModelHyperParams.d_inner_hid,
                ModelHyperParams.prepostprocess_dropout,
                ModelHyperParams.attention_dropout,
                ModelHyperParams.relu_dropout,
                ModelHyperParams.preprocess_cmd,
                ModelHyperParams.postprocess_cmd,
                ModelHyperParams.weight_sharing,
                TrainTaskConfig.label_smooth_eps,
                use_py_reader=args.use_py_reader,
                is_test=False)

            # Every mode must bind an optimizer; sync distributed training
            # uses the same Adam setup as local training.
            if args.local or args.sync:
                lr_decay = fluid.layers.learning_rate_scheduler.noam_decay(
                    ModelHyperParams.d_model, TrainTaskConfig.warmup_steps)
                optimizer = fluid.optimizer.Adam(
                    learning_rate=lr_decay * TrainTaskConfig.learning_rate,
                    beta1=TrainTaskConfig.beta1,
                    beta2=TrainTaskConfig.beta2,
                    epsilon=TrainTaskConfig.eps)
            else:
                optimizer = fluid.optimizer.SGD(0.003)
            optimizer.minimize(avg_cost)

    if args.use_mem_opt:
        fluid.memory_optimize(train_prog)

    if args.local:
        print("local start_up:")
        train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
                   token_num, predict, pyreader)
        return

    port = os.getenv("PADDLE_PORT", "6174")
    pserver_ips = os.getenv("PADDLE_PSERVERS")  # ip,ip...
    if not pserver_ips:
        print("need env PADDLE_PSERVERS")
        sys.exit(1)
    # ip:port,ip:port...
    pserver_endpoints = ",".join(
        ":".join([ip, port]) for ip in pserver_ips.split(","))
    trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
    trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
    t = fluid.DistributeTranspiler()
    t.transpile(
        trainer_id,
        pservers=pserver_endpoints,
        trainers=trainers,
        program=train_prog,
        startup_program=startup_prog)

    if training_role == "PSERVER":
        # Use the same port (and default) the trainers were given above.
        pod_ip = os.getenv("POD_IP")
        if not pod_ip:
            print("need env POD_IP")
            sys.exit(1)
        current_endpoint = pod_ip + ":" + port
        pserver_prog = t.get_pserver_program(current_endpoint)
        pserver_startup = t.get_startup_program(current_endpoint,
                                                pserver_prog)

        print("psserver begin run")
        with open('pserver_startup.desc', 'w') as f:
            f.write(str(pserver_startup))
        with open('pserver_prog.desc', 'w') as f:
            f.write(str(pserver_prog))
        exe.run(pserver_startup)
        exe.run(pserver_prog)
    elif training_role == "TRAINER":
        # NOTE(review): training continues on `train_prog`, which assumes
        # get_trainer_program() rewrites it in place — confirm.
        trainer_prog = t.get_trainer_program()
        with open('trainer_prog.desc', 'w') as f:
            f.write(str(trainer_prog))
        train_loop(exe, train_prog, startup_prog, dev_count, sum_cost,
                   avg_cost, token_num, predict, pyreader)
    else:
        print("environment var TRAINING_ROLE should be TRAINER or PSERVER")


if __name__ == "__main__":
    # Bound at module level on purpose: train_loop and test_context read
    # `args` as a global.
    args = parse_args()
    train(args)