提交 e665655c 编写于 作者: M mapingshuo

move _wait to end

上级 21a276b6
......@@ -99,6 +99,12 @@ class CollectiveHelper(object):
OP_ROLE_KEY: OpRole.Forward
})
def _wait(self, current_endpoint, endpoints, wait_port):
    """Call ``wait_server_ready`` on every endpoint except the current one.

    Args:
        current_endpoint: entry to leave out; its first occurrence is
            removed from a copy of ``endpoints``.
        endpoints: endpoints to consider. It is sliced before the removal,
            so the caller's object is not modified.
        wait_port: must be truthy (checked with ``assert``); it is not
            otherwise used in this method.
    """
    assert wait_port
    # Operate on a copy so the caller's endpoint list stays intact.
    peers = endpoints[:]
    peers.remove(current_endpoint)
    wait_server_ready(peers)
def _broadcast_params(self):
block = self.startup_program.global_block()
ring_id = -1
......
......@@ -196,6 +196,7 @@ class ZeroOptimizer(MetaOptimizerBase):
core.VarDesc.VarType.BOOL: 1,
core.VarDesc.VarType.UINT8: 1,
}
self._collective_helper = None
def _get_var_size(self, param):
"""
......@@ -778,11 +779,12 @@ class ZeroOptimizer(MetaOptimizerBase):
print("work idx: ", self.role_maker._worker_index())
endpoints = self.role_maker._get_trainer_endpoints()
current_endpoint = endpoints[self.role_maker._worker_index()]
collective_helper = CollectiveHelper(self.role_maker, self._nrings)
self._collective_helper = CollectiveHelper(self.role_maker,
self._nrings)
for ring_id in range(self._nrings):
collective_helper._init_communicator(
self._collective_helper._init_communicator(
self._startup_program, current_endpoint, endpoints,
self.role_maker._worker_index(), ring_id, '6174')
self.role_maker._worker_index(), ring_id, None)
startup_block = self._startup_program.global_block()
startup_block._sync_with_cpp()
......@@ -794,6 +796,12 @@ class ZeroOptimizer(MetaOptimizerBase):
self._fp16_params, self._broadcast_vars, self._fp16_to_params = self._find_broadcast_params(
self._params, self._param2device)
def _wait(self):
    """Ask the collective helper to wait, but only on worker index 0.

    Looks up the trainer endpoints and this worker's own endpoint through
    ``self.role_maker``. When the worker index is 0, delegates to
    ``self._collective_helper._wait``; every other worker returns without
    waiting. ``self._collective_helper`` must already be set to a helper
    object when this runs on worker 0.
    """
    trainer_endpoints = self.role_maker._get_trainer_endpoints()
    own_endpoint = trainer_endpoints[self.role_maker._worker_index()]
    if self.role_maker._worker_index() != 0:
        return
    # '6174' is passed as the helper's ``wait_port`` argument.
    self._collective_helper._wait(own_endpoint, trainer_endpoints, '6174')
def minimize_impl(self,
loss,
startup_program=None,
......@@ -855,6 +863,7 @@ class ZeroOptimizer(MetaOptimizerBase):
# check op dependency for broadcast
self._check_broadcast(main_block)
self._wait()
return optimize_ops, params_grads
def _check_broadcast(self, block):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册