提交 2c1e986f 编写于 作者: H heqiaozhi

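barrier_all to barrier_worker: in AsyncExecutor, workers now synchronize with barrier_worker() instead of barrier_all() after download_data, in stop, and after init_model; the corresponding extra barrier_all() calls are removed from init_server; stop_server() is renamed to stop().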

上级 009c7cf6
...@@ -170,7 +170,8 @@ class AsyncExecutor(object): ...@@ -170,7 +170,8 @@ class AsyncExecutor(object):
self.instance.get_worker_index(), self.instance.get_worker_index(),
self.instance.get_node_cnt() / 2, self.instance.get_node_cnt() / 2,
multi_processes=process_num) multi_processes=process_num)
self.instance.barrier_all() #wait for download_data #TODO only barriere worker #self.instance.barrier_all() #wait for download_data #TODO only barriere worker
self.instance.barrier_worker() #wait for download_data #TODO only barriere worker
def config_distributed_nodes(self): def config_distributed_nodes(self):
self.instance = ps_instance.PaddlePSInstance(1, 2) self.instance = ps_instance.PaddlePSInstance(1, 2)
...@@ -187,13 +188,13 @@ class AsyncExecutor(object): ...@@ -187,13 +188,13 @@ class AsyncExecutor(object):
raise ValueError('instance is None, please run config_distributed_nodes init instance') raise ValueError('instance is None, please run config_distributed_nodes init instance')
return self.instance return self.instance
def stop_server(self): def stop(self):
if self.instance is None: if self.instance is None:
raise ValueError('instance is None, please run config_distributed_nodes init instance') raise ValueError('instance is None, please run config_distributed_nodes init instance')
self.instance.barrier_all() #worker do all things self.instance.barrier_worker() #worker do all things
if self.instance.is_first_worker(): if self.instance.is_first_worker():
self.executor.stop_server() self.executor.stop_server()
self.instance.barrier_all() #sync self.instance.barrier_worker() #sync
def init_server(self, dist_desc): def init_server(self, dist_desc):
if self.instance is None: if self.instance is None:
...@@ -205,10 +206,6 @@ class AsyncExecutor(object): ...@@ -205,10 +206,6 @@ class AsyncExecutor(object):
ips = self.instance.gather_ips() ips = self.instance.gather_ips()
self.executor.gather_servers(ips, self.instance.get_node_cnt()) self.executor.gather_servers(ips, self.instance.get_node_cnt())
self.instance.barrier_all() #wait all worker start self.instance.barrier_all() #wait all worker start
self.instance.barrier_all() #wait init model
self.instance.barrier_all() #wait for download_data #TODO remove this after only barrier worker
self.instance.barrier_all() #wait worker do all things
self.instance.barrier_all() #sync
def init_worker(self, dist_desc, startup_program): def init_worker(self, dist_desc, startup_program):
if self.instance is None: if self.instance is None:
...@@ -223,7 +220,7 @@ class AsyncExecutor(object): ...@@ -223,7 +220,7 @@ class AsyncExecutor(object):
self.instance.barrier_all() #wait all worker start self.instance.barrier_all() #wait all worker start
if self.instance.is_first_worker(): if self.instance.is_first_worker():
self.executor.init_model() self.executor.init_model()
self.instance.barrier_all() #wait init model self.instance.barrier_worker() #wait init model
def init_model(self): def init_model(self):
if self.instance is None: if self.instance is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册