未验证 提交 22db82c0 编写于 作者: W Wu Yi 提交者: GitHub

fix tangwei merge issue test=develop (#15506)

上级 dec89bd7
......@@ -159,7 +159,7 @@ class ParallelExecutor(object):
trainers_endpoints = main._trainers_endpoints
if num_trainers > 1 and trainers_endpoints:
assert num_trainers == len(
trainers_endpoints), "num_trainers == len(end_points)"
trainers_endpoints), "num_trainers == len(endpoints)"
build_strategy.trainers_endpoints = trainers_endpoints
# step6: get persistable_vars, places. persistable_vars
......
......@@ -477,13 +477,16 @@ class DistributeTranspiler(object):
trainer_id,
trainers,
current_endpoint,
startup_program=None):
startup_program=None,
wait_port=True):
if not startup_program:
startup_program = default_startup_program()
if trainer_id >= 0:
worker_endpoints = trainers.split(",")
# send NCCL_ID to others or recv from trainer 0
worker_endpoints.remove(current_endpoint)
if trainer_id == 0 and wait_port:
wait_server_ready(worker_endpoints)
nccl_id_var = startup_program.global_block().create_var(
name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
......@@ -564,11 +567,13 @@ class DistributeTranspiler(object):
if self.config.mode == "nccl2":
assert (isinstance(trainers, str))
self.origin_program._trainers_endpoints = trainers.split(",")
self._transpile_nccl2(
trainer_id,
trainers,
current_endpoint,
startup_program=startup_program)
startup_program=startup_program,
wait_port=self.config.wait_port)
return
self.trainer_num = trainers
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册