# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from multiprocessing import Process, Manager

# deprecated module import
from paddle.fluid import core
from paddle.distributed.fleet.base.private_helper_function import wait_server_ready

__all__ = []

# Gloo context used by gloo_barrier()/gloo_release(). It is set by
# gloo_init_parallel_env() and cleared again by gloo_release().
_global_gloo_ctx = None


def _start_kv_server(port, http_server_d, size):
    """Run a KVServer on ``port`` until it is allowed to stop.

    This is the target of the child process that rank 0 spawns in
    gloo_init_parallel_env().

    Args:
        port: TCP port to listen on (converted with ``int``).
        http_server_d: shared dict proxy. Rank 0 keeps ``"running"`` True
            while it still needs the server.
        size: per-scope worker counts forwarded to ``KVServer``.
    """
    from paddle.distributed.fleet.utils.http_server import KVServer

    http_server = KVServer(int(port), size=size)
    http_server.start()
    wait_seconds = 3
    # Keep serving until rank 0 has cleared "running" AND the server itself
    # reports via should_stop() that it may stop.
    while http_server_d.get("running", False) or not http_server.should_stop():
        time.sleep(wait_seconds)
    http_server.stop()


def _parse_endpoint(server_endpoint):
    """Split an ``ip:port`` endpoint string into ``(ip, int(port))``.

    Only the first two ``:``-separated fields are used, so IPv6 literals are
    not supported.
    """
    fields = server_endpoint.split(":")
    return fields[0], int(fields[1])


def _init_gloo_context(rank_id, rank_num, server_endpoint):
    """Wait for the KV server, then create and init the global gloo context."""
    # all processes in this parallel environment should wait until server is ready
    wait_server_ready([server_endpoint])

    ip_address, ip_port = _parse_endpoint(server_endpoint)
    gloo_strategy = core.GlooParallelStrategy()
    gloo_strategy.rank = rank_id
    gloo_strategy.rank_num = rank_num
    gloo_strategy.ip_address = ip_address
    gloo_strategy.ip_port = ip_port
    # default_init_timeout_seconds
    gloo_strategy.init_seconds = 3600
    # default_run_timeout_seconds
    gloo_strategy.run_seconds = 9999999

    global _global_gloo_ctx
    _global_gloo_ctx = core.GlooParallelContext(gloo_strategy)
    _global_gloo_ctx.init()


def gloo_init_parallel_env(rank_id, rank_num, server_endpoint):
    """
    Initialize parallel environment with gloo for cpu only.

    Rank 0 additionally starts a KV server process at ``server_endpoint``.
    Every rank waits for that server before building its gloo context.

    Args:
        - rank_id(int, required) - the index of current rank;
        - rank_num (int, required) - the number of ranks in this parallel env;
        - server_endpoint (str, required) - endpoint of server to init gloo
          context in ip:port format;

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle
            import multiprocessing
            from contextlib import closing
            import socket

            port_set = set()

            def find_free_port():
                def _free_port():
                    with closing(socket.socket(
                            socket.AF_INET,
                            socket.SOCK_STREAM)) as s:
                        s.bind(('', 0))
                        return s.getsockname()[1]
                while True:
                    port = _free_port()
                    if port not in port_set:
                        port_set.add(port)
                        return port

            def test_gloo_init(rank_id, rank_num, server_endpoint):
                paddle.distributed.gloo_init_parallel_env(
                    rank_id, rank_num, server_endpoint)

            def test_gloo_init_with_multiprocess(num_of_ranks):
                jobs = []
                server_endpoint = "127.0.0.1:%s" % (find_free_port())
                for rank_id in range(num_of_ranks):
                    p = multiprocessing.Process(
                        target=test_gloo_init,
                        args=(rank_id, num_of_ranks, server_endpoint))
                    jobs.append(p)
                    p.start()
                for proc in jobs:
                    proc.join()

            if __name__ == '__main__':
                # Arg: number of ranks (processes)
                test_gloo_init_with_multiprocess(2)
    """
    # Compare directly instead of `(rank_num < 2) is False`: for numpy
    # integers the comparison yields numpy.bool_, which is never `False`,
    # so valid sizes were rejected.
    assert rank_num >= 2, \
        "rank_num should greater than or equal to 2 for parallel environment initialzation."

    if rank_id != 0:
        # Only rank 0 hosts the KV server, so the other ranks need no shared
        # status dict and therefore no Manager process.
        _init_gloo_context(rank_id, rank_num, server_endpoint)
        return

    # The `with` block shuts the Manager process down deterministically once
    # the KV server child has been joined, or when an exception propagates.
    with Manager() as manager:
        # shared dict polled by the KV server child for rank 0's status
        http_server_status = manager.dict()
        http_server_status["running"] = True
        # The scope for worker used by http server is '_worker'
        size = {'_worker': rank_num}
        _, server_port = _parse_endpoint(server_endpoint)
        http_server_proc = Process(
            target=_start_kv_server,
            args=(server_port, http_server_status, size))
        http_server_proc.daemon = True
        http_server_proc.start()

        _init_gloo_context(rank_id, rank_num, server_endpoint)

        # Allow the server child to exit once its own stop condition holds.
        http_server_status["running"] = False
        http_server_proc.join()


def gloo_barrier():
    """
    Call barrier function with initialized gloo context.

    Args:
        None

    Returns:
        None

    Raises:
        AssertionError: if no gloo context is currently initialized.

    Examples:
        .. code-block:: python

            import paddle
            import multiprocessing
            from contextlib import closing
            import socket

            port_set = set()

            def find_free_port():
                def _free_port():
                    with closing(socket.socket(
                            socket.AF_INET,
                            socket.SOCK_STREAM)) as s:
                        s.bind(('', 0))
                        return s.getsockname()[1]
                while True:
                    port = _free_port()
                    if port not in port_set:
                        port_set.add(port)
                        return port

            def test_gloo_barrier(rank_id, rank_num, server_endpoint):
                paddle.distributed.gloo_init_parallel_env(
                    rank_id, rank_num, server_endpoint)
                paddle.distributed.gloo_barrier()

            def test_gloo_barrier_with_multiprocess(num_of_ranks):
                jobs = []
                server_endpoint = "127.0.0.1:%s" % (find_free_port())
                for rank_id in range(num_of_ranks):
                    p = multiprocessing.Process(
                        target=test_gloo_barrier,
                        args=(rank_id, num_of_ranks, server_endpoint))
                    jobs.append(p)
                    p.start()
                for proc in jobs:
                    proc.join()

            if __name__ == '__main__':
                # Arg: number of ranks (processes)
                test_gloo_barrier_with_multiprocess(2)
    """
    assert _global_gloo_ctx is not None, "gloo context is not initialzed."
    _global_gloo_ctx.barrier()


def gloo_release():
    """
    Release the parallel environment initialized by gloo.

    Does nothing if no context is initialized. After a release, the global
    context is cleared: gloo_barrier() fails its "not initialized" check, and
    calling gloo_release() again is a no-op.

    Args:
        None

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle
            import multiprocessing
            from contextlib import closing
            import socket

            port_set = set()

            def find_free_port():
                def _free_port():
                    with closing(socket.socket(
                            socket.AF_INET,
                            socket.SOCK_STREAM)) as s:
                        s.bind(('', 0))
                        return s.getsockname()[1]
                while True:
                    port = _free_port()
                    if port not in port_set:
                        port_set.add(port)
                        return port

            def test_gloo_release(rank_id, rank_num, server_endpoint):
                paddle.distributed.gloo_init_parallel_env(
                    rank_id, rank_num, server_endpoint)
                paddle.distributed.gloo_barrier()
                paddle.distributed.gloo_release()

            def test_gloo_release_with_multiprocess(num_of_ranks):
                jobs = []
                server_endpoint = "127.0.0.1:%s" % (find_free_port())
                for rank_id in range(num_of_ranks):
                    p = multiprocessing.Process(
                        target=test_gloo_release,
                        args=(rank_id, num_of_ranks, server_endpoint))
                    jobs.append(p)
                    p.start()
                for proc in jobs:
                    proc.join()

            if __name__ == '__main__':
                # Arg: number of ranks (processes)
                test_gloo_release_with_multiprocess(2)
    """
    global _global_gloo_ctx
    if _global_gloo_ctx is not None:
        _global_gloo_ctx.release()
        # Drop the released context so later calls cannot operate on it.
        _global_gloo_ctx = None