From 729684007d70cad38e9d34317748e3fedd477886 Mon Sep 17 00:00:00 2001 From: heqiaozhi Date: Tue, 11 Dec 2018 14:44:02 +0800 Subject: [PATCH] stop server out of run from file --- python/paddle/fluid/async_executor.py | 14 +++++++------- python/paddle/fluid/distributed/ps_instance.py | 3 +-- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/python/paddle/fluid/async_executor.py b/python/paddle/fluid/async_executor.py index 76fdb5b0e2..787a6a6b9e 100644 --- a/python/paddle/fluid/async_executor.py +++ b/python/paddle/fluid/async_executor.py @@ -86,7 +86,7 @@ class AsyncExecutor(object): scope = global_scope() self.executor = core.AsyncExecutor(scope, p) - self.instance = ps_instance.PaddlePSInstance("init_param", 1, 2) + self.instance = ps_instance.PaddlePSInstance(1, 2) def run(self, program, data_feed, filelist, thread_num, fetch, debug=False): """ @@ -151,10 +151,7 @@ class AsyncExecutor(object): self.executor.run_from_files(program_desc, data_feed.desc(), filelist, thread_num, fetch_var_names, debug) - self.instance.barrier_all() #worker do all things - if self.instance.is_first_worker(): - self.executor.stop_server() - self.instance.barrier_all() #sync + def config_distributed_nodes(self, dist_opt): @@ -167,8 +164,11 @@ class AsyncExecutor(object): def get_instance(self): return self.instance - #def stop_server(self): - # self.executor.stop_server() + def stop_server(self): + self.instance.barrier_all() #worker do all things + if self.instance.is_first_worker(): + self.executor.stop_server() + self.instance.barrier_all() #sync def init_server(self, dist_desc): self.executor.init_server(dist_desc, self.instance._rankid) diff --git a/python/paddle/fluid/distributed/ps_instance.py b/python/paddle/fluid/distributed/ps_instance.py index b4045327e1..94e123c2ce 100644 --- a/python/paddle/fluid/distributed/ps_instance.py +++ b/python/paddle/fluid/distributed/ps_instance.py @@ -5,9 +5,8 @@ import sys class PaddlePSInstance(object): - def __init__(self, init_param, server_worker_mode, proc_per_node): + def __init__(self, server_worker_mode, proc_per_node): self.dh = dist_helper.MPIHelper() - self._config = init_param self._rankid = self.dh.get_rank() self._server_worker_mode = server_worker_mode self._proc_per_node = proc_per_node -- GitLab