diff --git a/python/paddle/distributed/fleet/base/role_maker.py b/python/paddle/distributed/fleet/base/role_maker.py index 70e2b9209a6f9d9b75ed3652a84a8bbd3e3cb15f..874cb5335746dddbec3d4813d84f03931e2539ba 100644 --- a/python/paddle/distributed/fleet/base/role_maker.py +++ b/python/paddle/distributed/fleet/base/role_maker.py @@ -78,10 +78,10 @@ class Gloo(object): self._worker_num = worker_num self._server_num = server_num self._need_init_all = need_init_all - self._start_http_server = kwargs.get("start_http_server", False) self._iface = "" self._prefix = kwargs.get("store.prefix", "") + http_server = None if self._rendezvous == Gloo.RENDEZVOUS.HDFS: dfs_name = kwargs.get("dfs.name", "") dfs_ugi = kwargs.get("dfs.ugi", "") @@ -101,17 +101,18 @@ class Gloo(object): elif self._rendezvous == Gloo.RENDEZVOUS.HTTP: ip = kwargs.get("http.host", "") port = kwargs.get("http.port", "") + start_http_server = kwargs.get("start_http_server", False) + http_server_d = kwargs.get("http_server_d") if not ip or not port: raise ValueError(self._err_type) - self._init_http(ip, port, self._prefix, self._start_http_server) - ep = ":".join([ip, port]) - wait_server_ready([ep]) - + http_server = self._init_http(ip, port, self._prefix, + start_http_server, http_server_d) else: raise ValueError(self._err_type) self._is_initialized = True + self._http_server = http_server def _init_fs(self, fs_path, prefix): def init(rank, nodes, role): @@ -167,7 +168,7 @@ class Gloo(object): gloo = init(rank, nodes, "ALL") self._nodes_comm = gloo - def _init_http(self, ip, port, prefix, start_http_server): + def _init_http(self, ip, port, prefix, start_http_server, http_server_d): def __start_kv_server(http_server_d, size_d): from paddle.distributed.fleet.utils.http_server import KVServer http_server = KVServer(port, size_d) @@ -177,21 +178,22 @@ class Gloo(object): time.sleep(wait_seconds) http_server.stop() - def init_kv_server(): + def init_kv_server(http_server_d): size_d = { "trainer": self._worker_num, "pserver": self._server_num, "all": self._worker_num + self._server_num } - _http_server_d = {"running": True} + http_server_d["running"] = True # child process for http server _http_server = Process( - target=__start_kv_server, args=(_http_server_d, size_d)) + target=__start_kv_server, args=(http_server_d, size_d)) _http_server.daemon = True # set running status to True # start child process _http_server.start() + return _http_server def init(rank, nodes, role): gloo = fluid.core.Gloo() @@ -202,12 +204,15 @@ class Gloo(object): gloo.set_timeout_seconds(self._init_timeout_seconds, self._run_timeout_seconds) gloo.set_http_store(ip, port, role) + ep = ":".join([ip, str(port)]) + wait_server_ready([ep]) + gloo.init() return gloo port = int(port) if start_http_server: - init_kv_server() + http_server = init_kv_server(http_server_d) if self._role == Role.WORKER: rank, nodes = self._get_rank_nodes(Role.WORKER) @@ -222,6 +227,9 @@ class Gloo(object): rank, nodes = self._get_rank_nodes(Role.ALL) gloo = init(rank, nodes, "ALL") self._nodes_comm = gloo + if start_http_server: + http_server_d["running"] = False + http_server.join() def _get_rank_nodes(self, role): nodes = 0 @@ -804,6 +812,9 @@ class PaddleCloudRoleMaker(RoleMakerBase): } elif rendezvous_type == Gloo.RENDEZVOUS.HTTP: start_http_server = False + manager = Manager() + http_server_d = manager.dict() + http_server_d["running"] = False if self._is_collective: ep_rank_0 = self._worker_endpoints[0] if self._is_first_worker(): @@ -818,6 +829,7 @@ class PaddleCloudRoleMaker(RoleMakerBase): "http.port": port, "store.prefix": prefix, 'start_http_server': start_http_server, + 'http_server_d': http_server_d, } else: dfs_path = os.getenv("PADDLE_GLOO_FS_PATH", "") @@ -844,6 +856,9 @@ class PaddleCloudRoleMaker(RoleMakerBase): need_init_all=need_init_all, kwargs=kwargs) + if rendezvous_type == Gloo.RENDEZVOUS.HTTP: + http_server_d['running'] = False + def _generate_role(self): """ generate role for role maker