Commit e665655c authored by mapingshuo

move _wait to end

Parent 21a276b6
@@ -99,6 +99,12 @@ class CollectiveHelper(object):
                 OP_ROLE_KEY: OpRole.Forward
             })

+    def _wait(self, current_endpoint, endpoints, wait_port):
+        assert (wait_port)
+        other_endpoints = endpoints[:]
+        other_endpoints.remove(current_endpoint)
+        wait_server_ready(other_endpoints)
+
     def _broadcast_params(self):
         block = self.startup_program.global_block()
         ring_id = -1
...
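For context, the new CollectiveHelper._wait delegates to wait_server_ready, which blocks until every other trainer endpoint accepts a TCP connection. Paddle's implementation is not shown in this diff; the sketch below is only a simplified stand-in illustrating that polling behavior (the function name and retry timings are made up here).

import socket
import time


def wait_server_ready_sketch(endpoints, connect_timeout=1.0, retry_interval=3.0):
    """Block until every "ip:port" endpoint accepts a TCP connection.

    Simplified stand-in for the wait_server_ready used by CollectiveHelper;
    illustration only, not Paddle's actual implementation.
    """
    pending = list(endpoints)
    while pending:
        still_down = []
        for ep in pending:
            ip, port = ep.split(":")
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.settimeout(connect_timeout)
                try:
                    sock.connect((ip, int(port)))
                except OSError:
                    still_down.append(ep)  # peer not listening yet, retry later
        pending = still_down
        if pending:
            time.sleep(retry_interval)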
@@ -196,6 +196,7 @@ class ZeroOptimizer(MetaOptimizerBase):
             core.VarDesc.VarType.BOOL: 1,
             core.VarDesc.VarType.UINT8: 1,
         }
+        self._collective_helper = None

     def _get_var_size(self, param):
         """
@@ -778,11 +779,12 @@ class ZeroOptimizer(MetaOptimizerBase):
         print("work idx: ", self.role_maker._worker_index())
         endpoints = self.role_maker._get_trainer_endpoints()
         current_endpoint = endpoints[self.role_maker._worker_index()]
-        collective_helper = CollectiveHelper(self.role_maker, self._nrings)
+        self._collective_helper = CollectiveHelper(self.role_maker,
+                                                   self._nrings)
         for ring_id in range(self._nrings):
-            collective_helper._init_communicator(
+            self._collective_helper._init_communicator(
                 self._startup_program, current_endpoint, endpoints,
-                self.role_maker._worker_index(), ring_id, '6174')
+                self.role_maker._worker_index(), ring_id, None)
         startup_block = self._startup_program.global_block()
         startup_block._sync_with_cpp()
@@ -794,6 +796,12 @@ class ZeroOptimizer(MetaOptimizerBase):
         self._fp16_params, self._broadcast_vars, self._fp16_to_params = self._find_broadcast_params(
             self._params, self._param2device)

+    def _wait(self, ):
+        endpoints = self.role_maker._get_trainer_endpoints()
+        current_endpoint = endpoints[self.role_maker._worker_index()]
+        if self.role_maker._worker_index() == 0:
+            self._collective_helper._wait(current_endpoint, endpoints, '6174')
+
     def minimize_impl(self,
                       loss,
                       startup_program=None,
@@ -855,6 +863,7 @@ class ZeroOptimizer(MetaOptimizerBase):
         # check op dependecy for broadcast
         self._check_broadcast(main_block)
+        self._wait()
         return optimize_ops, params_grads

     def _check_broadcast(self, block):
...
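Taken together, the ZeroOptimizer hunks reorder the startup handshake: the communicators are now created with wait_port=None, and the blocking wait is deferred to a single _wait() call at the end of minimize_impl, issued only by worker 0. The condensed sketch below shows that ordering; it assumes role_maker and collective_helper behave like the objects used in the diff and is not the actual method body.

def init_comm_then_wait(role_maker, collective_helper, startup_program, nrings):
    """Ordering sketch of this commit: build communicators first, wait last."""
    endpoints = role_maker._get_trainer_endpoints()
    current_endpoint = endpoints[role_maker._worker_index()]

    # Create one communicator per ring, passing None for wait_port so the
    # communicator setup itself does not block on peer trainers.
    for ring_id in range(nrings):
        collective_helper._init_communicator(
            startup_program, current_endpoint, endpoints,
            role_maker._worker_index(), ring_id, None)

    # ... sharding transformations, broadcast insertion, dependency checks ...

    # Deferred wait at the very end: only worker 0 blocks until the other
    # trainer endpoints are reachable (this commit uses port '6174').
    if role_maker._worker_index() == 0:
        collective_helper._wait(current_endpoint, endpoints, '6174')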