def _wait(self, current_endpoint, endpoints, wait_port):
    """Wait on every trainer endpoint except the current one.

    Args:
        current_endpoint: This trainer's own endpoint. It is excluded from
            the list handed to ``wait_server_ready``.
        endpoints: Endpoints of all trainers; expected to contain
            ``current_endpoint``. The caller's sequence is never modified.
        wait_port: Must be truthy. Only its truthiness is checked here; the
            value itself is not used by this method.

    Raises:
        AssertionError: If ``wait_port`` is falsy.
        ValueError: If ``current_endpoint`` is not in ``endpoints``.
    """
    # Raise explicitly instead of using an `assert` statement: asserts are
    # removed under `python -O`, which would silently drop this check. The
    # exception type is kept so existing behaviour is unchanged.
    if not wait_port:
        raise AssertionError(
            "wait_port must be set before waiting on other endpoints")

    # Work on a copy so the caller's endpoint list stays intact. list() also
    # accepts tuples, which have no .remove().
    other_endpoints = list(endpoints)
    try:
        # Drops only the first occurrence, as list.remove does.
        other_endpoints.remove(current_endpoint)
    except ValueError:
        raise ValueError("current_endpoint %r is not in endpoints %r" %
                         (current_endpoint, other_endpoints))

    # NOTE(review): presumably blocks until the remaining endpoints are
    # reachable -- confirm against wait_server_ready's implementation.
    wait_server_ready(other_endpoints)
def _wait(self):
    """On worker 0 only, wait on the other trainers via the collective helper.

    Workers whose index is not 0 return immediately without doing anything.

    Raises:
        RuntimeError: If called on worker 0 before ``self._collective_helper``
            has been created (it is initialised to ``None``).
    """
    worker_index = self.role_maker._worker_index()
    if worker_index != 0:
        # Only the first worker waits; everyone else has nothing to do.
        return

    if self._collective_helper is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise RuntimeError(
            "_wait() called before the collective helper was initialized")

    endpoints = self.role_maker._get_trainer_endpoints()
    # NOTE(review): '6174' is passed as the helper's `wait_port` argument;
    # it looks like a port number -- confirm what the helper expects.
    self._collective_helper._wait(endpoints[worker_index], endpoints, '6174')