提交 72968400 编写于 作者: H heqiaozhi

stop server out of run from file

上级 60d71a9e
...@@ -86,7 +86,7 @@ class AsyncExecutor(object): ...@@ -86,7 +86,7 @@ class AsyncExecutor(object):
scope = global_scope() scope = global_scope()
self.executor = core.AsyncExecutor(scope, p) self.executor = core.AsyncExecutor(scope, p)
self.instance = ps_instance.PaddlePSInstance("init_param", 1, 2) self.instance = ps_instance.PaddlePSInstance(1, 2)
def run(self, program, data_feed, filelist, thread_num, fetch, debug=False): def run(self, program, data_feed, filelist, thread_num, fetch, debug=False):
""" """
...@@ -151,10 +151,7 @@ class AsyncExecutor(object): ...@@ -151,10 +151,7 @@ class AsyncExecutor(object):
self.executor.run_from_files(program_desc, self.executor.run_from_files(program_desc,
data_feed.desc(), filelist, thread_num, data_feed.desc(), filelist, thread_num,
fetch_var_names, debug) fetch_var_names, debug)
self.instance.barrier_all() #worker do all things
if self.instance.is_first_worker():
self.executor.stop_server()
self.instance.barrier_all() #sync
def config_distributed_nodes(self, dist_opt): def config_distributed_nodes(self, dist_opt):
...@@ -167,8 +164,11 @@ class AsyncExecutor(object): ...@@ -167,8 +164,11 @@ class AsyncExecutor(object):
def get_instance(self): def get_instance(self):
return self.instance return self.instance
#def stop_server(self): def stop_server(self):
# self.executor.stop_server() self.instance.barrier_all() #worker do all things
if self.instance.is_first_worker():
self.executor.stop_server()
self.instance.barrier_all() #sync
def init_server(self, dist_desc): def init_server(self, dist_desc):
self.executor.init_server(dist_desc, self.instance._rankid) self.executor.init_server(dist_desc, self.instance._rankid)
......
...@@ -5,9 +5,8 @@ import sys ...@@ -5,9 +5,8 @@ import sys
class PaddlePSInstance(object): class PaddlePSInstance(object):
def __init__(self, init_param, server_worker_mode, proc_per_node): def __init__(self, server_worker_mode, proc_per_node):
self.dh = dist_helper.MPIHelper() self.dh = dist_helper.MPIHelper()
self._config = init_param
self._rankid = self.dh.get_rank() self._rankid = self.dh.get_rank()
self._server_worker_mode = server_worker_mode self._server_worker_mode = server_worker_mode
self._proc_per_node = proc_per_node self._proc_per_node = proc_per_node
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册