未验证 提交 22db82c0 编写于 作者: W Wu Yi 提交者: GitHub

fix tangwei merge issue test=develop (#15506)

上级 dec89bd7
...@@ -159,7 +159,7 @@ class ParallelExecutor(object): ...@@ -159,7 +159,7 @@ class ParallelExecutor(object):
trainers_endpoints = main._trainers_endpoints trainers_endpoints = main._trainers_endpoints
if num_trainers > 1 and trainers_endpoints: if num_trainers > 1 and trainers_endpoints:
assert num_trainers == len( assert num_trainers == len(
trainers_endpoints), "num_trainers == len(end_points)" trainers_endpoints), "num_trainers == len(endpoints)"
build_strategy.trainers_endpoints = trainers_endpoints build_strategy.trainers_endpoints = trainers_endpoints
# step6: get persistable_vars, places. persistable_vars # step6: get persistable_vars, places. persistable_vars
......
...@@ -477,13 +477,16 @@ class DistributeTranspiler(object): ...@@ -477,13 +477,16 @@ class DistributeTranspiler(object):
trainer_id, trainer_id,
trainers, trainers,
current_endpoint, current_endpoint,
startup_program=None): startup_program=None,
wait_port=True):
if not startup_program: if not startup_program:
startup_program = default_startup_program() startup_program = default_startup_program()
if trainer_id >= 0: if trainer_id >= 0:
worker_endpoints = trainers.split(",") worker_endpoints = trainers.split(",")
# send NCCL_ID to others or recv from trainer 0 # send NCCL_ID to others or recv from trainer 0
worker_endpoints.remove(current_endpoint) worker_endpoints.remove(current_endpoint)
if trainer_id == 0 and wait_port:
wait_server_ready(worker_endpoints)
nccl_id_var = startup_program.global_block().create_var( nccl_id_var = startup_program.global_block().create_var(
name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW) name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
...@@ -564,11 +567,13 @@ class DistributeTranspiler(object): ...@@ -564,11 +567,13 @@ class DistributeTranspiler(object):
if self.config.mode == "nccl2": if self.config.mode == "nccl2":
assert (isinstance(trainers, str)) assert (isinstance(trainers, str))
self.origin_program._trainers_endpoints = trainers.split(",")
self._transpile_nccl2( self._transpile_nccl2(
trainer_id, trainer_id,
trainers, trainers,
current_endpoint, current_endpoint,
startup_program=startup_program) startup_program=startup_program,
wait_port=self.config.wait_port)
return return
self.trainer_num = trainers self.trainer_num = trainers
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册