diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py index ba8f4c24ce92b470bf4347736a4a4fa1d5b1e9da..50cb93d0c84b72b6a84e083e7abfaa8259d26395 100644 --- a/fluid/neural_machine_translation/transformer/train.py +++ b/fluid/neural_machine_translation/transformer/train.py @@ -13,6 +13,8 @@ import reader from config import * from model import transformer, position_encoding_init +from paddle.fluid.transpiler.details import program_to_code + import logging import sys import copy @@ -642,8 +644,13 @@ def train(args): if args.sync: lr_decay = fluid.layers.learning_rate_scheduler.noam_decay( ModelHyperParams.d_model, TrainTaskConfig.warmup_steps) + print("before adam") + + with fluid.default_main_program()._lr_schedule_guard(): + learning_rate = lr_decay * TrainTaskConfig.learning_rate + optimizer = fluid.optimizer.Adam( - learning_rate=lr_decay * TrainTaskConfig.learning_rate, + learning_rate=learning_rate, beta1=TrainTaskConfig.beta1, beta2=TrainTaskConfig.beta2, epsilon=TrainTaskConfig.eps) @@ -688,6 +695,13 @@ def train(args): trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "0")) current_endpoint = os.getenv("POD_IP") + ":" + port trainer_id = int(os.getenv("PADDLE_TRAINER_ID")) + + print("pserver_endpoints", pserver_endpoints) + print("current_endpoint", current_endpoint) + print("trainer_id", trainer_id) + print("pserver_ips", pserver_ips) + print("port", port) + t = fluid.DistributeTranspiler() t.transpile( trainer_id, @@ -707,11 +721,26 @@ def train(args): pserver_startup = t.get_startup_program(current_endpoint, pserver_prog) + print("pserver start:") + program_to_code(pserver_startup) + print("pserver train:") + program_to_code(pserver_prog) + #sys.exit(0) + exe.run(pserver_startup) exe.run(pserver_prog) elif training_role == "TRAINER": logging.info("distributed: trainer started") trainer_prog = t.get_trainer_program() + + ''' + print("trainer start:") + program_to_code(pserver_startup) + print("trainer train:") + program_to_code(trainer_prog) + sys.exit(0) + ''' + train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost, token_num, predict, pyreader) else: