From e665655c5c591b27262752591b1244fdcf5036dd Mon Sep 17 00:00:00 2001 From: mapingshuo Date: Wed, 23 Sep 2020 14:51:20 +0800 Subject: [PATCH] move _wait to end --- .../distributed/fleet/meta_optimizers/common.py | 6 ++++++ .../fleet/meta_optimizers/zero_optimizer.py | 15 ++++++++++++--- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/common.py b/python/paddle/distributed/fleet/meta_optimizers/common.py index 8ff4114bf8..4ea0bd5585 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/common.py +++ b/python/paddle/distributed/fleet/meta_optimizers/common.py @@ -99,6 +99,12 @@ class CollectiveHelper(object): OP_ROLE_KEY: OpRole.Forward }) + def _wait(self, current_endpoint, endpoints, wait_port): + assert (wait_port) + other_endpoints = endpoints[:] + other_endpoints.remove(current_endpoint) + wait_server_ready(other_endpoints) + def _broadcast_params(self): block = self.startup_program.global_block() ring_id = -1 diff --git a/python/paddle/distributed/fleet/meta_optimizers/zero_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/zero_optimizer.py index 5856dd070f..3f9222577c 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/zero_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/zero_optimizer.py @@ -196,6 +196,7 @@ class ZeroOptimizer(MetaOptimizerBase): core.VarDesc.VarType.BOOL: 1, core.VarDesc.VarType.UINT8: 1, } + self._collective_helper = None def _get_var_size(self, param): """ @@ -778,11 +779,12 @@ class ZeroOptimizer(MetaOptimizerBase): print("work idx: ", self.role_maker._worker_index()) endpoints = self.role_maker._get_trainer_endpoints() current_endpoint = endpoints[self.role_maker._worker_index()] - collective_helper = CollectiveHelper(self.role_maker, self._nrings) + self._collective_helper = CollectiveHelper(self.role_maker, + self._nrings) for ring_id in range(self._nrings): - collective_helper._init_communicator( + self._collective_helper._init_communicator( self._startup_program, current_endpoint, endpoints, - self.role_maker._worker_index(), ring_id, '6174') + self.role_maker._worker_index(), ring_id, None) startup_block = self._startup_program.global_block() startup_block._sync_with_cpp() @@ -794,6 +796,12 @@ class ZeroOptimizer(MetaOptimizerBase): self._fp16_params, self._broadcast_vars, self._fp16_to_params = self._find_broadcast_params( self._params, self._param2device) + def _wait(self, ): + endpoints = self.role_maker._get_trainer_endpoints() + current_endpoint = endpoints[self.role_maker._worker_index()] + if self.role_maker._worker_index() == 0: + self._collective_helper._wait(current_endpoint, endpoints, '6174') + def minimize_impl(self, loss, startup_program=None, @@ -855,6 +863,7 @@ class ZeroOptimizer(MetaOptimizerBase): # check op dependecy for broadcast self._check_broadcast(main_block) + self._wait() return optimize_ops, params_grads def _check_broadcast(self, block): -- GitLab