From 22db82c05358ce112cc4f93299da26f6b546a8cd Mon Sep 17 00:00:00 2001
From: Wu Yi
Date: Thu, 24 Jan 2019 16:24:56 +0800
Subject: [PATCH] fix tangwei merge issue test=develop (#15506)

---
 python/paddle/fluid/parallel_executor.py                | 2 +-
 python/paddle/fluid/transpiler/distribute_transpiler.py | 9 +++++++--
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py
index a1b1d2f584..a07ff6ac69 100644
--- a/python/paddle/fluid/parallel_executor.py
+++ b/python/paddle/fluid/parallel_executor.py
@@ -159,7 +159,7 @@ class ParallelExecutor(object):
             trainers_endpoints = main._trainers_endpoints
             if num_trainers > 1 and trainers_endpoints:
                 assert num_trainers == len(
-                    trainers_endpoints), "num_trainers == len(end_points)"
+                    trainers_endpoints), "num_trainers == len(endpoints)"
                 build_strategy.trainers_endpoints = trainers_endpoints
 
         # step6: get persistable_vars, places. persistable_vars
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index c61cb54e1f..e58f34e375 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -477,13 +477,16 @@ class DistributeTranspiler(object):
                          trainer_id,
                          trainers,
                          current_endpoint,
-                         startup_program=None):
+                         startup_program=None,
+                         wait_port=True):
         if not startup_program:
             startup_program = default_startup_program()
         if trainer_id >= 0:
             worker_endpoints = trainers.split(",")
             # send NCCL_ID to others or recv from trainer 0
             worker_endpoints.remove(current_endpoint)
+            if trainer_id == 0 and wait_port:
+                wait_server_ready(worker_endpoints)
 
             nccl_id_var = startup_program.global_block().create_var(
                 name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
@@ -564,11 +567,13 @@ class DistributeTranspiler(object):
 
         if self.config.mode == "nccl2":
             assert (isinstance(trainers, str))
+            self.origin_program._trainers_endpoints = trainers.split(",")
             self._transpile_nccl2(
                 trainer_id,
                 trainers,
                 current_endpoint,
-                startup_program=startup_program)
+                startup_program=startup_program,
+                wait_port=self.config.wait_port)
             return
 
         self.trainer_num = trainers
--
GitLab
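
Note (not part of the patch): a minimal usage sketch of the wait_port knob this change threads through to _transpile_nccl2, assuming the DistributeTranspilerConfig API in this repo. The endpoint strings and trainer_id below are made-up values for illustration.

    import paddle.fluid as fluid
    from paddle.fluid.transpiler import DistributeTranspiler, DistributeTranspilerConfig

    config = DistributeTranspilerConfig()
    config.mode = "nccl2"
    # With wait_port left at its default (True), trainer 0 blocks in
    # wait_server_ready() until the other trainers' ports accept
    # connections before broadcasting the NCCL ID; set it to False
    # to skip that handshake.
    config.wait_port = True

    t = DistributeTranspiler(config=config)
    t.transpile(
        trainer_id=0,
        trainers="192.168.0.1:6170,192.168.0.2:6170",  # hypothetical endpoints
        current_endpoint="192.168.0.1:6170",
        startup_program=fluid.default_startup_program())

In nccl2 mode the patch also records trainers.split(",") on origin_program._trainers_endpoints, which is what ParallelExecutor later checks against num_trainers in the assertion fixed above.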