diff --git a/python/paddle/fluid/async_executor.py b/python/paddle/fluid/async_executor.py index 76fdb5b0e263de4517a5ef4b6634dbe0fcee31b3..787a6a6b9e19e6253664187ee166609c2e2abc1f 100644 --- a/python/paddle/fluid/async_executor.py +++ b/python/paddle/fluid/async_executor.py @@ -86,7 +86,7 @@ class AsyncExecutor(object): scope = global_scope() self.executor = core.AsyncExecutor(scope, p) - self.instance = ps_instance.PaddlePSInstance("init_param", 1, 2) + self.instance = ps_instance.PaddlePSInstance(1, 2) def run(self, program, data_feed, filelist, thread_num, fetch, debug=False): """ @@ -151,10 +151,7 @@ class AsyncExecutor(object): self.executor.run_from_files(program_desc, data_feed.desc(), filelist, thread_num, fetch_var_names, debug) - self.instance.barrier_all() #worker do all things - if self.instance.is_first_worker(): - self.executor.stop_server() - self.instance.barrier_all() #sync + def config_distributed_nodes(self, dist_opt): @@ -167,8 +164,11 @@ class AsyncExecutor(object): def get_instance(self): return self.instance - #def stop_server(self): - # self.executor.stop_server() + def stop_server(self): + self.instance.barrier_all() #worker do all things + if self.instance.is_first_worker(): + self.executor.stop_server() + self.instance.barrier_all() #sync def init_server(self, dist_desc): self.executor.init_server(dist_desc, self.instance._rankid) diff --git a/python/paddle/fluid/distributed/ps_instance.py b/python/paddle/fluid/distributed/ps_instance.py index b4045327e1b17b559a82c7fab811e7ef7adbc7b4..94e123c2ceba837375b3adca075e5c0dc144c510 100644 --- a/python/paddle/fluid/distributed/ps_instance.py +++ b/python/paddle/fluid/distributed/ps_instance.py @@ -5,9 +5,8 @@ import sys class PaddlePSInstance(object): - def __init__(self, init_param, server_worker_mode, proc_per_node): + def __init__(self, server_worker_mode, proc_per_node): self.dh = dist_helper.MPIHelper() - self._config = init_param self._rankid = self.dh.get_rank() self._server_worker_mode = server_worker_mode self._proc_per_node = proc_per_node