train.py 21.5 KB
Newer Older
1 2
import argparse
import ast
3
import multiprocessing
Y
Yu Yang 已提交
4
import os
G
guosheng 已提交
5
import six
Y
Yu Yang 已提交
6
import time
Y
ying 已提交
7

Y
Yu Yang 已提交
8
import numpy as np
L
Luo Tao 已提交
9
import paddle.fluid as fluid
Y
ying 已提交
10

Y
Yu Yang 已提交
11 12
import reader
from config import *
13
from model import transformer, position_encoding_init
14
from optim import LearningRateScheduler
15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46


def parse_args(argv=None):
    """
    Parse the command-line arguments and merge the configuration overrides.

    The trailing positional ``opts`` (``KEY VALUE`` pairs, see config.py) are
    merged into ``TrainTaskConfig`` / ``ModelHyperParams`` together with the
    vocabulary sizes and special-token ids derived from the vocabulary files.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser("Training for Transformer.")
    parser.add_argument(
        "--src_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of source language.")
    parser.add_argument(
        "--trg_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of target language.")
    parser.add_argument(
        "--train_file_pattern",
        type=str,
        required=True,
        help="The pattern to match training data files.")
    parser.add_argument(
        "--val_file_pattern",
        type=str,
        help="The pattern to match validation data files.")
    parser.add_argument(
        "--use_token_batch",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to "
        "produce batch data according to token number.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=2048,
        help="The number of sequences contained in a mini-batch, or the maximum "
        "number of tokens (include paddings) contained in a mini-batch. Note "
        "that this represents the number on single device and the actual batch "
        "size for multi-devices will multiply the device number.")
    parser.add_argument(
        "--pool_size",
        type=int,
        default=10000,
        help="The buffer size to pool data.")
    parser.add_argument(
        "--sort_type",
        default="pool",
        choices=("global", "pool", "none"),
        help="The grain to sort by length: global for all instances; pool for "
        "instances in pool; none for no sort.")
    parser.add_argument(
        "--shuffle",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to shuffle instances in each pass.")
    parser.add_argument(
        "--shuffle_batch",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to shuffle the data batches.")
    parser.add_argument(
        "--special_token",
        type=str,
        default=["<s>", "<e>", "<unk>"],
        nargs=3,
        help="The <bos>, <eos> and <unk> tokens in the dictionary.")
    parser.add_argument(
        "--token_delimiter",
        # Decode escape sequences typed on the command line (e.g. "\x01")
        # into the actual character.
        type=lambda x: str(x.encode().decode("unicode-escape")),
        default=" ",
        help="The delimiter used to split tokens in source or target sentences. "
        "For EN-DE BPE data we provided, use spaces as token delimiter. "
        "For EN-FR wordpiece data we provided, use '\x01' as token delimiter.")
    parser.add_argument(
        'opts',
        help='See config.py for all options',
        default=None,
        nargs=argparse.REMAINDER)
    parser.add_argument(
        '--local',
        type=ast.literal_eval,
        default=True,
        help='Whether to run as local mode.')
    parser.add_argument(
        '--device',
        type=str,
        default='GPU',
        choices=['CPU', 'GPU'],
        help="The device type.")
    parser.add_argument(
        '--sync', type=ast.literal_eval, default=True, help="sync mode.")
    parser.add_argument(
        "--enable_ce",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to run the task "
        "for continuous evaluation.")

    args = parser.parse_args(argv)

    # Append args related to dict. The special-token ids are looked up in the
    # source dictionary and used for both sides.
    src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
    trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
    dict_args = [
        "src_vocab_size", str(len(src_dict)),
        "trg_vocab_size", str(len(trg_dict)),
        "bos_idx", str(src_dict[args.special_token[0]]),
        "eos_idx", str(src_dict[args.special_token[1]]),
        "unk_idx", str(src_dict[args.special_token[2]]),
    ]
    # `opts` is a REMAINDER positional; guard against None anyway.
    merge_cfg_from_list((args.opts or []) + dict_args,
                        [TrainTaskConfig, ModelHyperParams])
    return args
124 125


126 127 128 129
def pad_batch_data(insts,
                   pad_idx,
                   n_head,
                   is_target=False,
                   is_label=False,
                   return_attn_bias=True,
                   return_max_len=True,
                   return_num_token=False):
    """
    Pad the instances to the max sequence length in batch, and generate the
    corresponding position data and attention bias.

    Args:
        insts: Non-empty iterable of token-id sequences, one per instance.
            Lists, tuples and 1-D numpy arrays are all accepted.
        pad_idx: Token id appended to instances shorter than the longest one.
        n_head: Number of attention heads; the bias is tiled along this axis.
        is_target: If True, build a causal self-attention bias (each position
            may only attend to itself and earlier positions) instead of a
            padding-only bias.
        is_label: If True, produce per-token loss weights (1. for real tokens,
            0. for paddings) instead of position ids.
        return_attn_bias: Whether to append the self-attention bias.
        return_max_len: Whether to append the padded length.
        return_num_token: Whether to append the total number of real tokens.

    Returns:
        A list, in this order, of:
        - padded token ids, int64, shape [batch * max_len, 1];
        - label weights (float32) or position ids (int64, 1-based, 0 for
          padding), shape [batch * max_len, 1];
        - (optional) attention bias, float32,
          shape [batch, n_head, max_len, max_len], 0 where attention is
          allowed and -1e9 where it is masked;
        - (optional) max_len;
        - (optional) number of real tokens.
        If only one item is produced, it is returned without the list.

    Raises:
        ValueError: If `insts` is empty.
    """
    # Materialize every instance as a list: `inst + [pad_idx] * k` below would
    # raise for tuples and silently broadcast-add for numpy arrays.
    insts = [list(inst) for inst in insts]
    if not insts:
        raise ValueError("pad_batch_data() requires at least one instance.")

    return_list = []
    max_len = max(len(inst) for inst in insts)
    # Any token included in dict can be used to pad, since the paddings' loss
    # will be masked out by weights and make no effect on parameter gradients.
    inst_data = np.array(
        [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
    return_list += [inst_data.astype("int64").reshape([-1, 1])]
    if is_label:  # label weight
        inst_weight = np.array(
            [[1.] * len(inst) + [0.] * (max_len - len(inst)) for inst in insts])
        return_list += [inst_weight.astype("float32").reshape([-1, 1])]
    else:  # position data
        inst_pos = np.array([
            list(range(1, len(inst) + 1)) + [0] * (max_len - len(inst))
            for inst in insts
        ])
        return_list += [inst_pos.astype("int64").reshape([-1, 1])]
    if return_attn_bias:
        if is_target:
            # Mask subsequent words: the strict upper triangle (future
            # positions) gets -1e9. Paddings are appended after the real
            # tokens, so real tokens never attend to them either.
            slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, max_len))
            slf_attn_bias_data = np.triu(slf_attn_bias_data,
                                         1).reshape([-1, 1, max_len, max_len])
            slf_attn_bias_data = np.tile(slf_attn_bias_data,
                                         [1, n_head, 1, 1]) * [-1e9]
        else:
            # Mask paddings only; every query row gets the same key mask.
            slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
                                           (max_len - len(inst))
                                           for inst in insts])
            slf_attn_bias_data = np.tile(
                slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
                [1, n_head, max_len, 1])
        return_list += [slf_attn_bias_data.astype("float32")]
    if return_max_len:
        return_list += [max_len]
    if return_num_token:
        return_list += [sum(len(inst) for inst in insts)]
    return return_list if len(return_list) > 1 else return_list[0]


183 184
def prepare_batch_input(insts, data_input_names, src_pad_idx, trg_pad_idx,
                        n_head, d_model):
    """
    Put all padded data needed by training into a dict.

    Args:
        insts: Instances, each indexable as ``inst[0]`` (source ids),
            ``inst[1]`` (target input ids) and ``inst[2]`` (label ids).
        data_input_names: Feed names, paired in order with src_word, src_pos,
            src_slf_attn_bias, trg_word, trg_pos, trg_slf_attn_bias,
            trg_src_attn_bias, lbl_word, lbl_weight.
        src_pad_idx: Padding id for source sequences.
        trg_pad_idx: Padding id for target and label sequences.
        n_head: Number of attention heads (for tiling attention biases).
        d_model: Not used by this function; kept for interface compatibility.

    Returns:
        A tuple ``(feed_dict, num_token)`` where ``num_token`` is a float32
        array of shape [1] holding the number of real label tokens.
    """
    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
    # pad_batch_data returns flat [batch * max_len, 1] arrays; restore the
    # batch axis as [batch, max_len, 1].
    src_word = src_word.reshape(-1, src_max_len, 1)
    src_pos = src_pos.reshape(-1, src_max_len, 1)

    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
        [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
    trg_word = trg_word.reshape(-1, trg_max_len, 1)
    trg_pos = trg_pos.reshape(-1, trg_max_len, 1)

    # Encoder-decoder attention only has to mask source paddings. All query
    # rows of the source padding bias are identical, so take the first row
    # and repeat it once per target position.
    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, :1, :],
                                [1, 1, trg_max_len, 1]).astype("float32")

    lbl_word, lbl_weight, num_token = pad_batch_data(
        [inst[2] for inst in insts],
        trg_pad_idx,
        n_head,
        is_target=False,
        is_label=True,
        return_attn_bias=False,
        return_max_len=False,
        return_num_token=True)

    data_input_dict = dict(
        zip(data_input_names, [
            src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
            trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
        ]))
    return data_input_dict, np.asarray([num_token], dtype="float32")
216 217


Q
Qiao Longfei 已提交
218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259
def read_multiple(reader, count, clip_last=True):
    """
    Stack data from reader for multi-devices.

    Groups every `count` consecutive items produced by `reader()` into one
    list, so that each device receives one item per step.

    Args:
        reader: A callable returning an iterable of items. Each item must be
            a list when `clip_last` is False, because leftover items are
            concatenated.
        count: Number of items per group (typically the device count).
        clip_last: If True, drop the trailing items that do not fill a full
            group. If False, concatenate them and redistribute all of their
            elements into `count` parts whose sizes differ by at most one.
            Nothing is yielded if there are fewer than `count` elements,
            since some part would be empty.

    Returns:
        A generator function yielding lists of length `count`.
    """

    def __impl__():
        res = []
        for item in reader():
            res.append(item)
            if len(res) == count:
                yield res
                res = []
        # `res` is reset whenever it reaches `count`, so it now holds only the
        # leftover (fewer than `count`) items.
        if res and not clip_last:
            data = [inst for item in res for inst in item]
            if len(data) >= count:
                # Spread the remainder over the first parts instead of
                # dropping it.
                base, extra = divmod(len(data), count)
                parts = []
                start = 0
                for part_id in range(count):
                    end = start + base + (1 if part_id < extra else 0)
                    parts.append(data[start:end])
                    start = end
                yield parts

    return __impl__


def split_data(data, num_part):
    """
    Split data for each device.

    Args:
        data: Either a list that already holds one part per device (returned
            unchanged when its length equals `num_part`), or a list whose
            first element is a batch of instances to be split.
        num_part: Number of devices/parts.

    Returns:
        A list of `num_part` parts. When splitting a batch, every instance is
        kept and part sizes differ by at most one: the instances left over
        after an even split go to the first parts instead of being dropped.
    """
    if len(data) == num_part:
        return data
    batch = data[0]
    base, extra = divmod(len(batch), num_part)
    parts = []
    start = 0
    for part_id in range(num_part):
        end = start + base + (1 if part_id < extra else 0)
        parts.append(batch[start:end])
        start = end
    return parts


260
def test_context(test_program, avg_cost, train_exe, dev_count, data_input_names,
                 sum_cost, token_num):
    """
    Build a closure that evaluates the current parameters on the validation
    set.

    Reads the module-level `args` produced by `parse_args()`.

    Args:
        test_program: Program cloned from the main program for testing.
        avg_cost: Not used by this function; kept for interface compatibility.
        train_exe: Training ParallelExecutor whose variables are shared.
        dev_count: Number of devices each step is split across.
        data_input_names: Ordered names of the data feed variables.
        sum_cost: Variable holding the summed token cost of a batch.
        token_num: Variable holding the number of tokens of a batch.

    Returns:
        A function ``test(exe=test_exe)`` returning ``(avg_cost, ppl)`` over
        the whole validation set, where ``ppl`` is a 1-element array.
    """
    val_data = reader.DataReader(
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        fpattern=args.val_file_pattern,
        token_delimiter=args.token_delimiter,
        use_token_batch=args.use_token_batch,
        batch_size=args.batch_size * (1 if args.use_token_batch else dev_count),
        pool_size=args.pool_size,
        sort_type=args.sort_type,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2],
        # count start and end tokens out
        max_length=ModelHyperParams.max_length - 2,
        clip_last_batch=False,
        shuffle=False,
        shuffle_batch=False)

    test_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        main_program=test_program,
        share_vars_from=train_exe)

    def test(exe=test_exe):
        """Return (average token cost, perplexity) over the validation set."""
        test_total_cost = 0
        test_total_token = 0
        test_data = read_multiple(
            reader=val_data.batch_generator,
            count=dev_count if args.use_token_batch else 1,
            # The reader keeps its last batch (clip_last_batch=False); keep
            # the trailing batches here too, otherwise part of the validation
            # set is skipped on multi-device runs.
            clip_last=False)
        for data in test_data():
            feed_list = []
            for data_buffer in split_data(data, num_part=dev_count):
                data_input_dict, _ = prepare_batch_input(
                    data_buffer, data_input_names, ModelHyperParams.eos_idx,
                    ModelHyperParams.eos_idx, ModelHyperParams.n_head,
                    ModelHyperParams.d_model)
                feed_list.append(data_input_dict)

            outs = exe.run(feed=feed_list,
                           fetch_list=[sum_cost.name, token_num.name])
            sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
            # Accumulate over all devices.
            test_total_cost += sum_cost_val.sum()
            test_total_token += token_num_val.sum()
        if not test_total_token:
            raise ValueError(
                "No validation tokens were read; check --val_file_pattern.")
        test_avg_cost = test_total_cost / test_total_token
        # Cap the exponent at 100 so a diverged cost yields a finite ppl.
        test_ppl = np.exp([min(test_avg_cost, 100)])
        return test_avg_cost, test_ppl

    return test


def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
               token_num, predict, test_program):
    """
    Run all training passes.

    Saves ``latest.checkpoint`` every 1000 batches. After each pass it
    validates (if a validation pattern was given) and saves the persistables
    and an inference model. Reads the module-level `args` produced by
    `parse_args()`.

    Args:
        exe: Executor used for initialization and saving.
        train_progm: Program to train (the main program locally, or the
            transpiled trainer program in distributed mode).
        dev_count: Number of devices each step is split across.
        sum_cost: Variable holding the summed token cost; used as the loss.
        avg_cost: Forwarded to `test_context`.
        lr_scheduler: Its learning rate is updated and fed each step in local
            mode.
        token_num: Variable holding the number of tokens of a batch.
        predict: Output variable exported with the inference model.
        test_program: Program used for validation.
    """
    # Initialize the parameters.
    if TrainTaskConfig.ckpt_path:
        fluid.io.load_persistables(exe, TrainTaskConfig.ckpt_path)
        lr_scheduler.current_steps = TrainTaskConfig.start_step
    else:
        print("init fluid.framework.default_startup_program")
        exe.run(fluid.framework.default_startup_program())

    train_data = reader.DataReader(
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        fpattern=args.train_file_pattern,
        token_delimiter=args.token_delimiter,
        use_token_batch=args.use_token_batch,
        batch_size=args.batch_size * (1 if args.use_token_batch else dev_count),
        pool_size=args.pool_size,
        sort_type=args.sort_type,
        shuffle=args.shuffle,
        shuffle_batch=args.shuffle_batch,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2],
        # count start and end tokens out
        max_length=ModelHyperParams.max_length - 2,
        clip_last_batch=False)
    train_data = read_multiple(
        reader=train_data.batch_generator,
        count=dev_count if args.use_token_batch else 1)

    build_strategy = fluid.BuildStrategy()
    # Since the token number differs among devices, customize gradient scale to
    # use token average cost among multi-devices. and the gradient scale is
    # `1 / token_number` for average cost.
    build_strategy.gradient_scale_strategy = (
        fluid.BuildStrategy.GradientScaleStrategy.Customized)
    train_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        loss_name=sum_cost.name,
        main_program=train_progm,
        build_strategy=build_strategy)

    # The last decoder field is not fed from data.
    data_input_names = (encoder_data_input_fields +
                        decoder_data_input_fields[:-1] +
                        label_data_input_fields)

    if args.val_file_pattern is not None:
        test = test_context(test_program, avg_cost, train_exe, dev_count,
                            data_input_names, sum_cost, token_num)

    # the best cross-entropy value with label smoothing
    eps = TrainTaskConfig.label_smooth_eps
    loss_normalizer = -((1. - eps) * np.log(1. - eps) + eps * np.log(
        eps / (ModelHyperParams.trg_vocab_size - 1) + 1e-20))

    step_idx = 0
    inst_num = 0
    init = False
    # Results of the latest batch / pass, reported for CE after training.
    # They stay None when no batch, pass or validation ran, so the CE report
    # skips them instead of failing with a NameError.
    total_avg_cost = None
    val_avg_cost = None
    time_consumed = None
    for pass_id in six.moves.xrange(TrainTaskConfig.pass_num):
        pass_start_time = time.time()
        for batch_id, data in enumerate(train_data()):
            feed_list = []
            total_num_token = 0
            if args.local:
                lr_rate = lr_scheduler.update_learning_rate()
            for data_buffer in split_data(data, num_part=dev_count):
                data_input_dict, num_token = prepare_batch_input(
                    data_buffer, data_input_names, ModelHyperParams.eos_idx,
                    ModelHyperParams.eos_idx, ModelHyperParams.n_head,
                    ModelHyperParams.d_model)
                total_num_token += num_token
                inst_num += len(data_buffer)
                feed_dict = dict(data_input_dict)
                if args.local:
                    feed_dict[lr_scheduler.learning_rate.name] = lr_rate
                if not init:
                    # The position-encoding tables are fed only with the first
                    # batch (`init` becomes True after it).
                    for pos_enc_param_name in pos_enc_param_names:
                        feed_dict[pos_enc_param_name] = position_encoding_init(
                            ModelHyperParams.max_length + 1,
                            ModelHyperParams.d_model)
                feed_list.append(feed_dict)
            # Customized gradient scale: 1 / (tokens across all devices).
            for feed_dict in feed_list:
                feed_dict[sum_cost.name + "@GRAD"] = 1. / total_num_token
            outs = train_exe.run(fetch_list=[sum_cost.name, token_num.name],
                                 feed=feed_list)
            sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
            # sum the cost from multi-devices
            total_sum_cost = sum_cost_val.sum()
            total_token_num = token_num_val.sum()
            total_avg_cost = total_sum_cost / total_token_num
            print(
                "step_idx: %d, total samples: %d, epoch: %d, batch: %d, avg loss: %f, "
                "normalized loss: %f, ppl: %f" %
                (step_idx, inst_num, pass_id, batch_id, total_avg_cost,
                 total_avg_cost - loss_normalizer,
                 np.exp(min(total_avg_cost, 100))))
            if batch_id > 0 and batch_id % 1000 == 0:
                fluid.io.save_persistables(
                    exe,
                    os.path.join(TrainTaskConfig.ckpt_dir, "latest.checkpoint"))
            step_idx += 1
            init = True

        time_consumed = time.time() - pass_start_time
        # Validate and save the model for inference.
        if args.val_file_pattern is not None:
            val_avg_cost, val_ppl = test()
            print(
                "epoch: %d, val avg loss: %f, val normalized loss: %f, val ppl: %f,"
                " consumed %fs" % (pass_id, val_avg_cost,
                                   val_avg_cost - loss_normalizer,
                                   float(np.ravel(val_ppl)[0]), time_consumed))
        else:
            print("epoch: %d, consumed %fs" % (pass_id, time_consumed))
        fluid.io.save_persistables(
            exe,
            os.path.join(TrainTaskConfig.ckpt_dir,
                         "pass_" + str(pass_id) + ".checkpoint"))
        fluid.io.save_inference_model(
            os.path.join(TrainTaskConfig.model_dir,
                         "pass_" + str(pass_id) + ".infer.model"),
            data_input_names[:-2], [predict], exe)

    if args.enable_ce:  # For CE
        if total_avg_cost is not None:
            print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost))
        if val_avg_cost is not None:
            print("kpis\ttest_cost_card%d\t%f" % (dev_count, val_avg_cost))
        if time_consumed is not None:
            print("kpis\ttrain_duration_card%d\t%f" % (dev_count, time_consumed))
Q
Qiao Longfei 已提交
447 448


449 450 451 452 453
def train(args):
    """
    Build the Transformer training program and run it locally or as part of
    a distributed (parameter-server) job.

    Settings priority: environment variables > command-line args > config.

    Environment variables read:
        PADDLE_IS_LOCAL: "0" switches to distributed mode.
        TRAINING_ROLE: "TRAINER" (default) or "PSERVER".
        CPU_NUM: device count when running on CPU.
        PADDLE_PORT (default "6174"), PADDLE_PSERVERS (comma-separated IPs),
        PADDLE_TRAINERS_NUM, POD_IP, PADDLE_TRAINER_ID: distributed setup.

    Args:
        args: Namespace returned by `parse_args()`; `args.local` may be
            overridden by PADDLE_IS_LOCAL.
    """
    import sys

    # priority: ENV > args > config
    if os.getenv("PADDLE_IS_LOCAL", "1") == '0':
        args.local = False
    print(args)

    if args.device == 'CPU':
        TrainTaskConfig.use_gpu = False

    training_role = os.getenv("TRAINING_ROLE", "TRAINER")

    if training_role == "PSERVER" or (not TrainTaskConfig.use_gpu):
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
    else:
        place = fluid.CUDAPlace(0)
        dev_count = fluid.core.get_cuda_device_count()

    exe = fluid.Executor(place)

    if args.enable_ce:
        fluid.default_startup_program().random_seed = 1000

    sum_cost, avg_cost, predict, token_num = transformer(
        ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
        ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
        ModelHyperParams.n_head, ModelHyperParams.d_key,
        ModelHyperParams.d_value, ModelHyperParams.d_model,
        ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
        ModelHyperParams.weight_sharing, TrainTaskConfig.label_smooth_eps)
    lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
                                         TrainTaskConfig.warmup_steps,
                                         TrainTaskConfig.learning_rate)

    # Cloned before minimize() below, so it contains no optimizer ops.
    test_program = fluid.default_main_program().clone(for_test=True)

    if args.local:
        # The learning rate variable is fed by train_loop in local mode.
        optimizer = fluid.optimizer.Adam(
            learning_rate=lr_scheduler.learning_rate,
            beta1=TrainTaskConfig.beta1,
            beta2=TrainTaskConfig.beta2,
            epsilon=TrainTaskConfig.eps)
        optimizer.minimize(sum_cost)
    elif not args.sync:
        optimizer = fluid.optimizer.SGD(0.003)
        optimizer.minimize(sum_cost)
    else:
        # Distributed sync mode computes the Noam schedule in-graph.
        lr_decay = fluid.layers.learning_rate_scheduler.noam_decay(
            ModelHyperParams.d_model, TrainTaskConfig.warmup_steps)
        optimizer = fluid.optimizer.Adam(
            learning_rate=lr_decay,
            beta1=TrainTaskConfig.beta1,
            beta2=TrainTaskConfig.beta2,
            epsilon=TrainTaskConfig.eps)
        optimizer.minimize(sum_cost)

    if args.local:
        print("local start_up:")
        train_loop(exe,
                   fluid.default_main_program(), dev_count, sum_cost, avg_cost,
                   lr_scheduler, token_num, predict, test_program)
        return

    # Distributed mode: fail early with a clear message instead of a
    # TypeError/AttributeError on an unset variable.
    missing_envs = [
        name for name in ("PADDLE_PSERVERS", "POD_IP", "PADDLE_TRAINER_ID")
        if not os.getenv(name)
    ]
    if missing_envs:
        print("need env %s for distributed training" % ", ".join(missing_envs))
        sys.exit(1)

    port = os.getenv("PADDLE_PORT", "6174")
    pserver_ips = os.getenv("PADDLE_PSERVERS")  # ip,ip...
    pserver_endpoints = ",".join(
        ':'.join([ip, port]) for ip in pserver_ips.split(","))  # ip:port,...
    trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
    current_endpoint = os.getenv("POD_IP") + ":" + port
    trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
    t = fluid.DistributeTranspiler()
    t.transpile(trainer_id, pservers=pserver_endpoints, trainers=trainers)

    if training_role == "PSERVER":
        # Use the endpoint built above, so PADDLE_PORT gets the same default
        # on both pserver and trainer sides.
        pserver_prog = t.get_pserver_program(current_endpoint)
        pserver_startup = t.get_startup_program(current_endpoint,
                                                pserver_prog)

        print("psserver begin run")
        with open('pserver_startup.desc', 'w') as f:
            f.write(str(pserver_startup))
        with open('pserver_prog.desc', 'w') as f:
            f.write(str(pserver_prog))
        exe.run(pserver_startup)
        exe.run(pserver_prog)
    elif training_role == "TRAINER":
        trainer_prog = t.get_trainer_program()
        with open('trainer_prog.desc', 'w') as f:
            f.write(str(trainer_prog))
        train_loop(exe, trainer_prog, dev_count, sum_cost, avg_cost,
                   lr_scheduler, token_num, predict, test_program)
    else:
        print("environment var TRAINING_ROLE should be TRAINER or PSERVER")
        sys.exit(1)
553 554 555


if __name__ == "__main__":
    # `args` must stay a module-level global: test_context() and
    # train_loop() read it directly.
    args = parse_args()
    train(args)