“9a41986aae36188f084797a7139ea9a91d649374”上不存在“projects/Jasonchenlijian/imports.yml”
train.py 22.8 KB
Newer Older
1 2
import argparse
import ast
3
import multiprocessing
Y
Yu Yang 已提交
4 5
import os
import time
6
from functools import partial
Y
ying 已提交
7

Y
Yu Yang 已提交
8
import numpy as np
L
Luo Tao 已提交
9
import paddle.fluid as fluid
Y
ying 已提交
10

Y
Yu Yang 已提交
11 12
import reader
from config import *
13
from model import transformer, position_encoding_init
14
from optim import LearningRateScheduler
15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46


def _unescape_delimiter(value):
    """
    Interpret backslash escape sequences in a command-line string, so that a
    non-printable token delimiter can be typed on the shell as an escape.
    """
    if isinstance(value, bytes):
        # Python 2: the native ``str`` is a byte string and has this codec.
        return value.decode("string-escape")
    # Python 3: ``str`` has no ``decode``. Round-trip through bytes;
    # ``backslashreplace`` keeps characters outside latin-1 intact by turning
    # them into escapes that ``unicode_escape`` restores.
    return value.encode("latin-1", "backslashreplace").decode("unicode_escape")


def parse_args():
    """
    Parse the command line and merge dictionary-derived settings into the
    configuration classes.

    Besides the named options, every trailing positional token is collected
    into ``opts``. After parsing, both vocabularies are loaded and the vocab
    sizes plus the indices of the <bos>/<eos>/<unk> tokens (looked up in the
    source dictionary) are appended to ``opts`` and handed to
    ``merge_cfg_from_list`` for ``TrainTaskConfig`` and ``ModelHyperParams``.

    Returns:
        The ``argparse.Namespace`` holding the parsed arguments.
    """
    parser = argparse.ArgumentParser("Training for Transformer.")
    parser.add_argument(
        "--src_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of source language.")
    parser.add_argument(
        "--trg_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of target language.")
    parser.add_argument(
        "--train_file_pattern",
        type=str,
        required=True,
        help="The pattern to match training data files.")
    parser.add_argument(
        "--val_file_pattern",
        type=str,
        help="The pattern to match validation data files.")
    parser.add_argument(
        "--use_token_batch",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to "
        "produce batch data according to token number.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=2048,
        help="The number of sequences contained in a mini-batch, or the maximum "
        "number of tokens (include paddings) contained in a mini-batch. Note "
        "that this represents the number on single device and the actual batch "
        "size for multi-devices will multiply the device number.")
    parser.add_argument(
        "--pool_size",
        type=int,
        default=10000,
        help="The buffer size to pool data.")
    parser.add_argument(
        "--sort_type",
        default="pool",
        choices=("global", "pool", "none"),
        help="The grain to sort by length: global for all instances; pool for "
        "instances in pool; none for no sort.")
    parser.add_argument(
        "--shuffle",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to shuffle instances in each pass.")
    parser.add_argument(
        "--shuffle_batch",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to shuffle the data batches.")
    parser.add_argument(
        "--special_token",
        type=str,
        default=["<s>", "<e>", "<unk>"],
        nargs=3,
        help="The <bos>, <eos> and <unk> tokens in the dictionary.")
    parser.add_argument(
        "--token_delimiter",
        type=_unescape_delimiter,
        default=" ",
        # The backslash is doubled so the help text shows the escape the user
        # has to type rather than an unprintable control character.
        help="The delimiter used to split tokens in source or target sentences. "
        "For EN-DE BPE data we provided, use spaces as token delimiter. "
        "For EN-FR wordpiece data we provided, use '\\x01' as token delimiter.")
    parser.add_argument(
        'opts',
        help='See config.py for all options',
        default=None,
        nargs=argparse.REMAINDER)
    parser.add_argument(
        '--local',
        type=ast.literal_eval,
        default=True,
        help='Whether to run as local mode.')
    parser.add_argument(
        '--device',
        type=str,
        default='GPU',
        choices=['CPU', 'GPU'],
        help="The device type.")
    parser.add_argument(
        '--sync', type=ast.literal_eval, default=True, help="sync mode.")
    parser.add_argument(
        "--enable_ce",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to run the task "
        "for continuous evaluation.")

    args = parser.parse_args()
    # Append args related to dict
    src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
    trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
    dict_args = [
        "src_vocab_size", str(len(src_dict)), "trg_vocab_size",
        str(len(trg_dict)), "bos_idx", str(src_dict[args.special_token[0]]),
        "eos_idx", str(src_dict[args.special_token[1]]), "unk_idx",
        str(src_dict[args.special_token[2]])
    ]
    # ``opts`` is declared with default=None; guard the concatenation.
    merge_cfg_from_list((args.opts or []) + dict_args,
                        [TrainTaskConfig, ModelHyperParams])
    return args
125 126


127 128 129 130
def pad_batch_data(insts,
                   pad_idx,
                   n_head,
                   is_target=False,
                   is_label=False,
                   return_attn_bias=True,
                   return_max_len=True,
                   return_num_token=False):
    """
    Pad the instances to the max sequence length in batch, and generate the
    corresponding position data and attention bias.

    Args:
        insts: list of token-id lists, one per instance.
        pad_idx: id appended to instances shorter than the longest one.
        n_head: number of attention heads the bias is tiled over.
        is_target: build a bias that masks subsequent positions (upper
            triangle) instead of one that masks padding positions.
        is_label: emit per-token weights (1. for real tokens, 0. for padding)
            instead of position ids.
        return_attn_bias: append the self-attention bias.
        return_max_len: append the padded sequence length.
        return_num_token: append the number of non-padding tokens.

    Returns:
        A list made of, in order: the padded ids (int64, shape [-1, 1]); the
        weights (float32) or positions (int64), both shape [-1, 1]; then
        optionally the bias (float32, [batch, n_head, max_len, max_len]),
        ``max_len`` and the token count. A single-element result is returned
        unwrapped.
    """
    return_list = []
    max_len = max(len(inst) for inst in insts)
    # ``sum`` replaces ``reduce``, which is not a builtin on Python 3 and
    # raises on an empty sequence when given no initial value.
    num_token = sum(len(inst) for inst in insts) if return_num_token else 0
    # Any token included in dict can be used to pad, since the paddings' loss
    # will be masked out by weights and make no effect on parameter gradients.
    inst_data = np.array(
        [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
    return_list += [inst_data.astype("int64").reshape([-1, 1])]
    if is_label:  # label weight
        inst_weight = np.array(
            [[1.] * len(inst) + [0.] * (max_len - len(inst)) for inst in insts])
        return_list += [inst_weight.astype("float32").reshape([-1, 1])]
    else:  # position data
        # Positions start at 1; 0 marks padding. ``list(...)`` is required
        # because ``range`` cannot be concatenated to a list on Python 3.
        inst_pos = np.array([
            list(range(1, len(inst) + 1)) + [0] * (max_len - len(inst))
            for inst in insts
        ])
        return_list += [inst_pos.astype("int64").reshape([-1, 1])]
    if return_attn_bias:
        if is_target:
            # This is used to avoid attention on paddings and subsequent
            # words.
            slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, max_len))
            slf_attn_bias_data = np.triu(slf_attn_bias_data,
                                         1).reshape([-1, 1, max_len, max_len])
            slf_attn_bias_data = np.tile(slf_attn_bias_data,
                                         [1, n_head, 1, 1]) * [-1e9]
        else:
            # This is used to avoid attention on paddings.
            slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
                                           (max_len - len(inst))
                                           for inst in insts])
            slf_attn_bias_data = np.tile(
                slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
                [1, n_head, max_len, 1])
        return_list += [slf_attn_bias_data.astype("float32")]
    if return_max_len:
        return_list += [max_len]
    if return_num_token:
        return_list += [num_token]
    return return_list if len(return_list) > 1 else return_list[0]


183 184
def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
                        trg_pad_idx, n_head, d_model):
    """
    Put all padded data needed by training into a dict.

    Each instance is indexed as ``inst[0]`` (source ids), ``inst[1]`` (target
    ids) and ``inst[2]`` (label ids).

    Returns:
        A 3-tuple ``(data_input_dict, util_input_dict, num_token)``. The two
        dicts map ``data_input_names`` / ``util_input_names`` positionally
        onto the padded tensors and the shape tensors respectively;
        ``num_token`` is a one-element float32 array holding the number of
        non-padding label tokens.
    """

    def _softmax_shapes(attn_bias):
        # Shapes used to flatten the attention scores to 2-D before softmax
        # and to restore the bias' trailing dimensions afterwards.
        pre_softmax = np.array([-1, attn_bias.shape[-1]], dtype="int32")
        post_softmax = np.array(
            [-1] + list(attn_bias.shape[1:]), dtype="int32")
        return pre_softmax, post_softmax

    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
        [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
    # Every row of the source padding bias is identical, so stepping by
    # src_max_len keeps only row 0, which is then repeated once per target
    # position.
    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                [1, 1, trg_max_len, 1]).astype("float32")

    # These shape tensors are used in reshape_op.
    src_data_shape = np.array([-1, src_max_len, d_model], dtype="int32")
    trg_data_shape = np.array([-1, trg_max_len, d_model], dtype="int32")
    (src_slf_attn_pre_softmax_shape,
     src_slf_attn_post_softmax_shape) = _softmax_shapes(src_slf_attn_bias)
    (trg_slf_attn_pre_softmax_shape,
     trg_slf_attn_post_softmax_shape) = _softmax_shapes(trg_slf_attn_bias)
    (trg_src_attn_pre_softmax_shape,
     trg_src_attn_post_softmax_shape) = _softmax_shapes(trg_src_attn_bias)

    lbl_word, lbl_weight, num_token = pad_batch_data(
        [inst[2] for inst in insts],
        trg_pad_idx,
        n_head,
        is_target=False,
        is_label=True,
        return_attn_bias=False,
        return_max_len=False,
        return_num_token=True)

    data_input_dict = dict(
        zip(data_input_names, [
            src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
            trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
        ]))
    util_input_dict = dict(
        zip(util_input_names, [
            src_data_shape, src_slf_attn_pre_softmax_shape,
            src_slf_attn_post_softmax_shape, trg_data_shape,
            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape,
            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape
        ]))
    return data_input_dict, util_input_dict, np.asarray(
        [num_token], dtype="float32")
235 236


Q
Qiao Longfei 已提交
237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289
def read_multiple(reader, count, clip_last=True):
    """
    Wrap ``reader`` so that it yields lists of ``count`` consecutive items,
    one item per device.

    When the input ends with fewer than ``count`` items pending and
    ``clip_last`` is false, the pending items are concatenated and, provided
    they hold more than ``count`` instances, re-split into ``count`` equal
    parts (instances that do not divide evenly are dropped). Otherwise the
    pending items are discarded.
    """

    def __impl__():
        pending = []
        for batch in reader():
            pending.append(batch)
            if len(pending) == count:
                yield pending
                pending = []

        if len(pending) == count:
            yield pending
            return
        if clip_last:
            return

        # Flatten the leftover batches into a single list of instances.
        leftover = []
        for batch in pending:
            leftover += batch
        if len(leftover) <= count:
            return

        part_size = len(leftover) // count
        yield [
            leftover[part_size * idx:part_size * (idx + 1)]
            for idx in range(count)
        ]

    return __impl__


def split_data(data, num_part):
    """
    Return one chunk of data per device.

    If ``data`` already holds ``num_part`` entries it is returned untouched.
    Otherwise its first entry is cut into ``num_part`` equally sized
    consecutive slices; instances that do not divide evenly are dropped.
    """
    if len(data) != num_part:
        pooled = data[0]
        part_size = len(pooled) // num_part
        parts = []
        for idx in range(num_part):
            begin = part_size * idx
            parts.append(pooled[begin:begin + part_size])
        return parts
    return data


def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
                 util_input_names, sum_cost, token_num):
    """
    Build the validation context and return a callable that runs it.

    An inference program is derived from a clone of ``train_progm``, a
    validation ``DataReader`` is created from the module-level ``args``, and a
    ``ParallelExecutor`` sharing variables with ``train_exe`` is set up.

    Returns:
        A function ``test(exe=test_exe)`` that iterates over the validation
        data once and returns ``(average cost per token, perplexity)``, where
        the perplexity is ``exp`` of the average cost capped at 100.
    """
    # Context to do validation.
    test_program = train_progm.clone()
    with fluid.program_guard(test_program):
        test_program = fluid.io.get_inference_program([avg_cost])

    val_data = reader.DataReader(
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        fpattern=args.val_file_pattern,
        token_delimiter=args.token_delimiter,
        use_token_batch=args.use_token_batch,
        batch_size=args.batch_size * (1 if args.use_token_batch else dev_count),
        pool_size=args.pool_size,
        sort_type=args.sort_type,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2],
        # count start and end tokens out
        max_length=ModelHyperParams.max_length - 2,
        clip_last_batch=False,
        shuffle=False,
        shuffle_batch=False)

    test_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        main_program=test_program,
        share_vars_from=train_exe)

    def test(exe=test_exe):
        test_total_cost = 0
        test_total_token = 0
        test_data = read_multiple(
            reader=val_data.batch_generator,
            count=dev_count if args.use_token_batch else 1)
        for batch_id, data in enumerate(test_data()):
            feed_list = []
            for data_buffer in split_data(data, num_part=dev_count):
                data_input_dict, util_input_dict, _ = prepare_batch_input(
                    data_buffer, data_input_names, util_input_names,
                    ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
                    ModelHyperParams.n_head, ModelHyperParams.d_model)
                # Copy-then-update merges the two dicts on both Python 2 and
                # 3; ``dict.items()`` views cannot be added on Python 3.
                feed_dict = dict(data_input_dict)
                feed_dict.update(util_input_dict)
                feed_list.append(feed_dict)

            outs = exe.run(feed=feed_list,
                           fetch_list=[sum_cost.name, token_num.name])
            sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
            # Accumulate over devices and batches to get a per-token average.
            test_total_cost += sum_cost_val.sum()
            test_total_token += token_num_val.sum()
        test_avg_cost = test_total_cost / test_total_token
        test_ppl = np.exp([min(test_avg_cost, 100)])
        return test_avg_cost, test_ppl

    return test


def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
               token_num, predict):
    """
    Run the training passes on ``train_progm`` with a ``ParallelExecutor``.

    Parameters are either restored from ``TrainTaskConfig.ckpt_path`` or
    initialised by running the default startup program. For every pass the
    training data is fed device by device, the loss is printed per batch, a
    "latest" checkpoint is saved every 1000 batches, and at the end of the
    pass validation is run (when ``args.val_file_pattern`` is set) and both a
    checkpoint and an inference model are saved. Reads the module-level
    ``args``.
    """
    # Initialize the parameters.
    if TrainTaskConfig.ckpt_path:
        fluid.io.load_persistables(exe, TrainTaskConfig.ckpt_path)
        lr_scheduler.current_steps = TrainTaskConfig.start_step
    else:
        print("init fluid.framework.default_startup_program")
        exe.run(fluid.framework.default_startup_program())

    train_data = reader.DataReader(
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        fpattern=args.train_file_pattern,
        token_delimiter=args.token_delimiter,
        use_token_batch=args.use_token_batch,
        batch_size=args.batch_size * (1 if args.use_token_batch else dev_count),
        pool_size=args.pool_size,
        sort_type=args.sort_type,
        shuffle=args.shuffle,
        shuffle_batch=args.shuffle_batch,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2],
        # count start and end tokens out
        max_length=ModelHyperParams.max_length - 2,
        clip_last_batch=False)
    train_data = read_multiple(
        reader=train_data.batch_generator,
        count=dev_count if args.use_token_batch else 1)

    build_strategy = fluid.BuildStrategy()
    # Since the token number differs among devices, customize gradient scale to
    # use token average cost among multi-devices. and the gradient scale is
    # `1 / token_number` for average cost.
    build_strategy.gradient_scale_strategy = (
        fluid.BuildStrategy.GradientScaleStrategy.Customized)
    train_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        loss_name=sum_cost.name,
        main_program=train_progm,
        build_strategy=build_strategy)

    data_input_names = (encoder_data_input_fields +
                        decoder_data_input_fields[:-1] +
                        label_data_input_fields)
    util_input_names = encoder_util_input_fields + decoder_util_input_fields

    if args.val_file_pattern is not None:
        test = test_context(train_progm, avg_cost, train_exe, dev_count,
                            data_input_names, util_input_names, sum_cost,
                            token_num)

    # the best cross-entropy value with label smoothing
    loss_normalizer = -((1. - TrainTaskConfig.label_smooth_eps) * np.log(
        (1. - TrainTaskConfig.label_smooth_eps
         )) + TrainTaskConfig.label_smooth_eps *
                        np.log(TrainTaskConfig.label_smooth_eps / (
                            ModelHyperParams.trg_vocab_size - 1) + 1e-20))

    # Pre-bind the values reported by the CE block below so that it cannot hit
    # an unbound name when no batch ran or when validation is disabled.
    total_avg_cost = None
    val_avg_cost = None
    time_consumed = None

    init = False
    for pass_id in range(TrainTaskConfig.pass_num):
        pass_start_time = time.time()
        for batch_id, data in enumerate(train_data()):
            feed_list = []
            total_num_token = 0
            if args.local:
                lr_rate = lr_scheduler.update_learning_rate()
            for place_id, data_buffer in enumerate(
                    split_data(
                        data, num_part=dev_count)):
                data_input_dict, util_input_dict, num_token = prepare_batch_input(
                    data_buffer, data_input_names, util_input_names,
                    ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
                    ModelHyperParams.n_head, ModelHyperParams.d_model)
                total_num_token += num_token
                # Copy-then-update merges the feeds on both Python 2 and 3;
                # ``dict.items()`` views cannot be concatenated on Python 3.
                feed_dict = dict(data_input_dict)
                feed_dict.update(util_input_dict)
                if args.local:
                    feed_dict[lr_scheduler.learning_rate.name] = lr_rate
                feed_list.append(feed_dict)

                if not init:
                    # The position encoding tables are fed only with the very
                    # first batch.
                    for pos_enc_param_name in pos_enc_param_names:
                        pos_enc = position_encoding_init(
                            ModelHyperParams.max_length + 1,
                            ModelHyperParams.d_model)
                        feed_list[place_id][pos_enc_param_name] = pos_enc
            for feed_dict in feed_list:
                feed_dict[sum_cost.name + "@GRAD"] = 1. / total_num_token
            outs = train_exe.run(fetch_list=[sum_cost.name, token_num.name],
                                 feed=feed_list)
            sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
            # sum the cost from multi-devices
            total_sum_cost = sum_cost_val.sum()
            total_token_num = token_num_val.sum()
            total_avg_cost = total_sum_cost / total_token_num
            print("epoch: %d, batch: %d, avg loss: %f, normalized loss: %f,"
                  " ppl: %f" % (pass_id, batch_id, total_avg_cost,
                                total_avg_cost - loss_normalizer,
                                np.exp([min(total_avg_cost, 100)])))
            if batch_id > 0 and batch_id % 1000 == 0:
                fluid.io.save_persistables(
                    exe,
                    os.path.join(TrainTaskConfig.ckpt_dir, "latest.checkpoint"))
            init = True

        time_consumed = time.time() - pass_start_time
        # Validate and save the model for inference.
        if args.val_file_pattern is not None:
            val_avg_cost, val_ppl = test()
            print(
                "epoch: %d, val avg loss: %f, val normalized loss: %f, val ppl: %f,"
                " consumed %fs" % (pass_id, val_avg_cost,
                                   val_avg_cost - loss_normalizer, val_ppl,
                                   time_consumed))
        else:
            print("epoch: %d, consumed %fs" % (pass_id, time_consumed))
        fluid.io.save_persistables(
            exe,
            os.path.join(TrainTaskConfig.ckpt_dir,
                         "pass_" + str(pass_id) + ".checkpoint"))
        fluid.io.save_inference_model(
            os.path.join(TrainTaskConfig.model_dir,
                         "pass_" + str(pass_id) + ".infer.model"),
            data_input_names[:-2] + util_input_names, [predict], exe)
    if args.enable_ce:  # For CE
        # Each KPI is reported only when its value was actually produced.
        if total_avg_cost is not None:
            print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost))
        if val_avg_cost is not None:
            print("kpis\ttest_cost_card%d\t%f" % (dev_count, val_avg_cost))
        if time_consumed is not None:
            print("kpis\ttrain_duration_card%d\t%f" %
                  (dev_count, time_consumed))
Q
Qiao Longfei 已提交
468 469


470 471 472 473 474 475
def train(args):
    """
    Build the Transformer network and optimizer, then start training.

    In local mode ``train_loop`` is run on the default main program. Otherwise
    the program is transpiled with ``fluid.DistributeTranspiler`` and,
    depending on the ``TRAINING_ROLE`` environment variable, this process
    either serves parameters (``PSERVER``) or runs ``train_loop`` on the
    trainer program (``TRAINER``).

    Environment variables read: ``PADDLE_IS_LOCAL``, ``TRAINING_ROLE``,
    ``CPU_NUM`` and, in distributed mode, ``PADDLE_PORT``, ``PADDLE_PSERVERS``,
    ``PADDLE_TRAINERS_NUM``, ``POD_IP`` and ``PADDLE_TRAINER_ID``.
    """
    # priority: ENV > args > config
    is_local = os.getenv("PADDLE_IS_LOCAL", "1")
    if is_local == '0':
        args.local = False
    print(args)

    if args.device == 'CPU':
        TrainTaskConfig.use_gpu = False

    training_role = os.getenv("TRAINING_ROLE", "TRAINER")

    if training_role == "PSERVER" or (not TrainTaskConfig.use_gpu):
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
    else:
        place = fluid.CUDAPlace(0)
        dev_count = fluid.core.get_cuda_device_count()

    exe = fluid.Executor(place)

    if args.enable_ce:
        # Fixed seed for the continuous-evaluation runs.
        fluid.default_startup_program().random_seed = 1000

    sum_cost, avg_cost, predict, token_num = transformer(
        ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
        ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
        ModelHyperParams.n_head, ModelHyperParams.d_key,
        ModelHyperParams.d_value, ModelHyperParams.d_model,
        ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
        ModelHyperParams.weight_sharing, TrainTaskConfig.label_smooth_eps)
    lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
                                         TrainTaskConfig.warmup_steps,
                                         TrainTaskConfig.learning_rate)

    if args.local:
        optimizer = fluid.optimizer.Adam(
            learning_rate=lr_scheduler.learning_rate,
            beta1=TrainTaskConfig.beta1,
            beta2=TrainTaskConfig.beta2,
            epsilon=TrainTaskConfig.eps)
        optimizer.minimize(sum_cost)
    # The explicit comparison is kept on purpose: a ``--sync None`` value must
    # still fall through to the synchronous branch below.
    elif args.sync == False:
        optimizer = fluid.optimizer.SGD(0.003)
        optimizer.minimize(sum_cost)
    else:
        lr_decay = fluid.layers.learning_rate_scheduler.noam_decay(
            ModelHyperParams.d_model, TrainTaskConfig.warmup_steps)

        optimizer = fluid.optimizer.Adam(
            learning_rate=lr_decay,
            beta1=TrainTaskConfig.beta1,
            beta2=TrainTaskConfig.beta2,
            epsilon=TrainTaskConfig.eps)
        optimizer.minimize(sum_cost)

    if args.local:
        print("local start_up:")
        train_loop(exe,
                   fluid.default_main_program(), dev_count, sum_cost, avg_cost,
                   lr_scheduler, token_num, predict)
    else:
        port = os.getenv("PADDLE_PORT", "6174")
        pserver_ips = os.getenv("PADDLE_PSERVERS")  # ip,ip...
        pod_ip = os.getenv("POD_IP")
        trainer_id_env = os.getenv("PADDLE_TRAINER_ID")

        # Fail with a readable message instead of a TypeError/AttributeError
        # raised from operating on ``None`` further down.
        missing = [
            name
            for name, value in (("PADDLE_PSERVERS", pserver_ips),
                                ("POD_IP", pod_ip),
                                ("PADDLE_TRAINER_ID", trainer_id_env))
            if not value
        ]
        if missing:
            print("need env " + ", ".join(missing))
            raise SystemExit(1)

        eplist = []
        for ip in pserver_ips.split(","):
            eplist.append(':'.join([ip, port]))
        pserver_endpoints = ",".join(eplist)  # ip:port,ip:port...
        trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
        current_endpoint = pod_ip + ":" + port
        trainer_id = int(trainer_id_env)
        t = fluid.DistributeTranspiler()
        t.transpile(trainer_id, pservers=pserver_endpoints, trainers=trainers)

        if training_role == "PSERVER":
            # ``current_endpoint`` already carries the default port, so the
            # endpoint stays consistent with ``pserver_endpoints`` even when
            # PADDLE_PORT is unset.
            pserver_prog = t.get_pserver_program(current_endpoint)
            pserver_startup = t.get_startup_program(current_endpoint,
                                                    pserver_prog)

            print("psserver begin run")
            with open('pserver_startup.desc', 'w') as f:
                f.write(str(pserver_startup))
            with open('pserver_prog.desc', 'w') as f:
                f.write(str(pserver_prog))
            exe.run(pserver_startup)
            exe.run(pserver_prog)
        elif training_role == "TRAINER":

            trainer_prog = t.get_trainer_program()
            with open('trainer_prog.desc', 'w') as f:
                f.write(str(trainer_prog))
            train_loop(exe, trainer_prog, dev_count, sum_cost, avg_cost,
                       lr_scheduler, token_num, predict)
        else:
            print("environment var TRAINING_ROLE should be TRAINER or PSERVER")
572 573 574


if __name__ == "__main__":
    # ``args`` is deliberately bound at module scope: test_context() and
    # train_loop() read it as a global instead of taking it as a parameter.
    args = parse_args()
    train(args)