未验证 提交 28064c2d 编写于 作者: D Dong Daxiang 提交者: GitHub

Fix gen nccl id bug: stop mutating the trainer endpoint list in `_setup_nccl_op` (#25669)

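* Fix gen nccl id bug: `_setup_nccl_op` previously removed the current endpoint from `trainer_endpoints` in place, and `trainers` aliased that same list, so the `"trainers"` attribute lost the current endpoint. The removal is now done on a copy (`other_trainer_endpoints`), only when `trainer_id == 0`, and the attribute receives the full `trainer_endpoints` list.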
上级 e5bbffa8
...@@ -309,6 +309,7 @@ class Fleet(object): ...@@ -309,6 +309,7 @@ class Fleet(object):
loss, self._role_maker, self.user_defined_optimizer, loss, self._role_maker, self.user_defined_optimizer,
self.user_defined_strategy, valid_optimizer_list, self.user_defined_strategy, valid_optimizer_list,
valid_graph_optimizer_list) valid_graph_optimizer_list)
optimize_ops = [] optimize_ops = []
params_grads = [] params_grads = []
if meta_optimizer: if meta_optimizer:
......
...@@ -78,14 +78,14 @@ class GraphExecutionOptimizer(MetaOptimizerBase): ...@@ -78,14 +78,14 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
# should fix the variable # should fix the variable
def _setup_nccl_op(self, startup_program, main_program): def _setup_nccl_op(self, startup_program, main_program):
trainer_endpoints = self.role_maker.get_trainer_endpoints() trainer_endpoints = self.role_maker.get_trainer_endpoints()
trainers = trainer_endpoints
trainer_id = self.role_maker.worker_index() trainer_id = self.role_maker.worker_index()
current_endpoint = self.role_maker.get_trainer_endpoints()[trainer_id] current_endpoint = self.role_maker.get_trainer_endpoints()[trainer_id]
trainer_endpoints_env = ",".join(trainer_endpoints) trainer_endpoints_env = ",".join(trainer_endpoints)
trainers_num = self.role_maker.worker_num() trainers_num = self.role_maker.worker_num()
trainer_endpoints.remove(current_endpoint)
if trainer_id == 0: if trainer_id == 0:
wait_server_ready(trainer_endpoints) other_trainer_endpoints = trainer_endpoints[:]
other_trainer_endpoints.remove(current_endpoint)
wait_server_ready(other_trainer_endpoints)
nccl_id_var = startup_program.global_block().create_var( nccl_id_var = startup_program.global_block().create_var(
name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW) name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
for i in range(1, self.user_defined_strategy.nccl_comm_num): for i in range(1, self.user_defined_strategy.nccl_comm_num):
...@@ -110,7 +110,7 @@ class GraphExecutionOptimizer(MetaOptimizerBase): ...@@ -110,7 +110,7 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
inputs={}, inputs={},
outputs={"NCCLID": nccl_id_var}, outputs={"NCCLID": nccl_id_var},
attrs={ attrs={
"trainers": trainers, "trainers": trainer_endpoints,
"trainer_id": trainer_id, "trainer_id": trainer_id,
"nccl_comm_num": self.user_defined_strategy.nccl_comm_num, "nccl_comm_num": self.user_defined_strategy.nccl_comm_num,
"use_hierarchical_allreduce": "use_hierarchical_allreduce":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册