diff --git a/python/paddle/fleet/base/fleet_base.py b/python/paddle/fleet/base/fleet_base.py
index 46d06e5d026cae0cb2fcc78360b2b3d0faf0acd9..13b9fc3220a0911415b6abc6f4cf91298038e298 100644
--- a/python/paddle/fleet/base/fleet_base.py
+++ b/python/paddle/fleet/base/fleet_base.py
@@ -309,6 +309,7 @@ class Fleet(object):
                 loss, self._role_maker, self.user_defined_optimizer,
                 self.user_defined_strategy, valid_optimizer_list,
                 valid_graph_optimizer_list)
+        optimize_ops = []
         params_grads = []
 
         if meta_optimizer:
diff --git a/python/paddle/fleet/meta_optimizers/graph_execution_optimizer.py b/python/paddle/fleet/meta_optimizers/graph_execution_optimizer.py
index 13d62a6d462a2b76a609739a63785922221bc3b0..cc3d1cd2128bd6b4e632e45e025cded7ace0fa41 100644
--- a/python/paddle/fleet/meta_optimizers/graph_execution_optimizer.py
+++ b/python/paddle/fleet/meta_optimizers/graph_execution_optimizer.py
@@ -78,14 +78,14 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
     # should fix the variable
     def _setup_nccl_op(self, startup_program, main_program):
         trainer_endpoints = self.role_maker.get_trainer_endpoints()
-        trainers = trainer_endpoints
         trainer_id = self.role_maker.worker_index()
         current_endpoint = self.role_maker.get_trainer_endpoints()[trainer_id]
         trainer_endpoints_env = ",".join(trainer_endpoints)
         trainers_num = self.role_maker.worker_num()
-        trainer_endpoints.remove(current_endpoint)
         if trainer_id == 0:
-            wait_server_ready(trainer_endpoints)
+            other_trainer_endpoints = trainer_endpoints[:]
+            other_trainer_endpoints.remove(current_endpoint)
+            wait_server_ready(other_trainer_endpoints)
         nccl_id_var = startup_program.global_block().create_var(
             name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
         for i in range(1, self.user_defined_strategy.nccl_comm_num):
@@ -110,7 +110,7 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
                 inputs={},
                 outputs={"NCCLID": nccl_id_var},
                 attrs={
-                    "trainers": trainers,
+                    "trainers": trainer_endpoints,
                     "trainer_id": trainer_id,
                     "nccl_comm_num": self.user_defined_strategy.nccl_comm_num,
                     "use_hierarchical_allreduce":
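
For context, the graph_execution_optimizer change above works around a Python list-aliasing bug: `trainers = trainer_endpoints` bound a second name to the same list object, so the later in-place `trainer_endpoints.remove(current_endpoint)` also dropped the current endpoint from the list handed to the `gen_nccl_id` op's "trainers" attribute. The fix mutates a shallow copy instead. A minimal stand-alone sketch of the pitfall (endpoint values are made up for illustration):

    # Hypothetical endpoints; only the aliasing behavior matters here.
    trainer_endpoints = ["127.0.0.1:6170", "127.0.0.1:6171"]
    current_endpoint = trainer_endpoints[0]

    # Buggy pattern: assignment aliases the list, remove() mutates it in place.
    trainers = trainer_endpoints              # no copy; both names share one list
    trainer_endpoints.remove(current_endpoint)
    print(trainers)                           # ['127.0.0.1:6171'], entry silently lost

    # Fixed pattern (as in the patch): mutate a shallow copy instead.
    trainer_endpoints = ["127.0.0.1:6170", "127.0.0.1:6171"]
    other_trainer_endpoints = trainer_endpoints[:]  # shallow copy
    other_trainer_endpoints.remove(current_endpoint)
    print(trainer_endpoints)                  # full list preserved for gen_nccl_id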