diff --git a/fluid/PaddleNLP/neural_machine_translation/transformer/train.py b/fluid/PaddleNLP/neural_machine_translation/transformer/train.py index f1b1acdc4362ca02661a7a2ecbf5d626c859f1e6..0e9c18416f62c85e76dd060f1fad44073e5841fc 100644 --- a/fluid/PaddleNLP/neural_machine_translation/transformer/train.py +++ b/fluid/PaddleNLP/neural_machine_translation/transformer/train.py @@ -145,14 +145,15 @@ def parse_args(): return args -def append_nccl2_prepare(trainer_id, worker_endpoints, current_endpoint): +def append_nccl2_prepare(startup_prog, trainer_id, worker_endpoints, + current_endpoint): assert (trainer_id >= 0 and len(worker_endpoints) > 1 and current_endpoint in worker_endpoints) eps = copy.deepcopy(worker_endpoints) eps.remove(current_endpoint) - nccl_id_var = fluid.default_startup_program().global_block().create_var( + nccl_id_var = startup_prog.global_block().create_var( name="NCCLID", persistable=True, type=fluid.core.VarDesc.VarType.RAW) - fluid.default_startup_program().global_block().append_op( + startup_prog.global_block().append_op( type="gen_nccl_id", inputs={}, outputs={"NCCLID": nccl_id_var}, @@ -403,7 +404,7 @@ def test_context(exe, train_exe, dev_count): TrainTaskConfig.label_smooth_eps, use_py_reader=args.use_py_reader, is_test=True) - + test_prog = test_prog.clone(for_test=True) test_data = prepare_data_generator( args, is_test=True, count=dev_count, pyreader=pyreader) @@ -680,10 +681,11 @@ def train(args): logging.info("trainers_num:{}".format(trainers_num)) logging.info("worker_endpoints:{}".format(worker_endpoints)) logging.info("current_endpoint:{}".format(current_endpoint)) - append_nccl2_prepare(trainer_id, worker_endpoints, current_endpoint) - train_loop(exe, - fluid.default_main_program(), dev_count, sum_cost, - avg_cost, token_num, predict, trainers_num, trainer_id) + append_nccl2_prepare(startup_prog, trainer_id, worker_endpoints, + current_endpoint) + train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, + avg_cost, token_num, predict, pyreader, trainers_num, + trainer_id) return port = os.getenv("PADDLE_PORT", "6174")