Unverified · Commit 28064c2d authored by Dong Daxiang, committed by GitHub

fix gen nccl id bug (#25669)

* fix gen nccl id bug
Parent e5bbffa8
@@ -309,6 +309,7 @@ class Fleet(object):
                 loss, self._role_maker, self.user_defined_optimizer,
                 self.user_defined_strategy, valid_optimizer_list,
                 valid_graph_optimizer_list)
+        optimize_ops = []
         params_grads = []
         if meta_optimizer:
......
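The first hunk only gives optimize_ops a default before the if meta_optimizer: branch. A plausible reading is that the surrounding minimize call must still return something when no meta optimizer is selected; a minimal sketch of that pattern follows (function and argument names are hypothetical, not the actual Fleet.minimize implementation):

# Minimal sketch of the initialization pattern above; names are
# hypothetical and do not reproduce the actual Fleet.minimize code.
def minimize_sketch(meta_optimizer, loss):
    optimize_ops = []   # defaults, so the return below is always defined
    params_grads = []
    if meta_optimizer:
        optimize_ops, params_grads = meta_optimizer.minimize(loss)
    return optimize_ops, params_grads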
@@ -78,14 +78,14 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
     # should fix the variable
     def _setup_nccl_op(self, startup_program, main_program):
         trainer_endpoints = self.role_maker.get_trainer_endpoints()
-        trainers = trainer_endpoints
         trainer_id = self.role_maker.worker_index()
         current_endpoint = self.role_maker.get_trainer_endpoints()[trainer_id]
         trainer_endpoints_env = ",".join(trainer_endpoints)
         trainers_num = self.role_maker.worker_num()
-        trainer_endpoints.remove(current_endpoint)
         if trainer_id == 0:
-            wait_server_ready(trainer_endpoints)
+            other_trainer_endpoints = trainer_endpoints[:]
+            other_trainer_endpoints.remove(current_endpoint)
+            wait_server_ready(other_trainer_endpoints)
         nccl_id_var = startup_program.global_block().create_var(
             name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
         for i in range(1, self.user_defined_strategy.nccl_comm_num):
@@ -110,7 +110,7 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
             inputs={},
             outputs={"NCCLID": nccl_id_var},
             attrs={
-                "trainers": trainers,
+                "trainers": trainer_endpoints,
                 "trainer_id": trainer_id,
                 "nccl_comm_num": self.user_defined_strategy.nccl_comm_num,
                 "use_hierarchical_allreduce":
......
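The core of the fix in _setup_nccl_op appears to be Python list aliasing: trainers = trainer_endpoints bound a second name to the same list object, so the later trainer_endpoints.remove(current_endpoint) also shrank the list that was passed to the gen-NCCL-ID op as the "trainers" attribute. Copying with trainer_endpoints[:] before removing keeps the full endpoint list intact. A standalone sketch of the behaviour (illustrative endpoint values only, not taken from the commit):

# Illustrative endpoints only; not taken from the commit.
trainer_endpoints = ["127.0.0.1:6170", "127.0.0.1:6171", "127.0.0.1:6172"]
current_endpoint = trainer_endpoints[0]

# Buggy pattern: "trainers" aliases the same list object, so removing the
# current endpoint also removes it from the list handed to the op attrs.
trainers = trainer_endpoints
trainer_endpoints.remove(current_endpoint)
assert trainers == ["127.0.0.1:6171", "127.0.0.1:6172"]  # full list was lost

# Fixed pattern: mutate a shallow copy and leave the original list intact.
trainer_endpoints = ["127.0.0.1:6170", "127.0.0.1:6171", "127.0.0.1:6172"]
other_trainer_endpoints = trainer_endpoints[:]
other_trainer_endpoints.remove(current_endpoint)
assert trainer_endpoints == ["127.0.0.1:6170", "127.0.0.1:6171", "127.0.0.1:6172"]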