# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from collections import namedtuple
import pickle
import time
import datetime

import paddle.fluid.core as core
from paddle.distributed.utils.launch_utils import logger
from paddle.distributed.rpc.internal import _serialize, PythonFunc
from paddle.distributed.launch.context import Node

WorkerInfo = namedtuple("WorkerInfo", ["name", "rank", "ip", "port"])

_DEFAULT_RPC_TIMEOUT = -1
_MAX_RPC_TIMEOUT_MS = 0x7FFFFFFF
_BARRIER_TIMEOUT_MAX_DAYS = 99999999
# TCP store used by `_barrier_never_timeout`
_barrier_store = None
# count how many times `_barrier_never_timeout` has been called, so that
# each barrier uses a unique key prefix
_barrier_count = 0


def _set_barrier_store(store):
    global _barrier_store
    _barrier_store = store


def _del_barrier_store():
    global _barrier_store
    del _barrier_store


def _set_self_info(name, rank, ip, port):
    self_info = pickle.dumps(WorkerInfo(name, rank, ip, port))
    _barrier_store.set(str(rank), self_info)


def _exchange_all_service_infos(world_size):
    all_infos = []
    s = set()
    for rank in range(world_size):
        info = pickle.loads(_barrier_store.get(str(rank)))
        assert (
            info.name not in s
        ), "The worker name must be unique, but name `{}` is repeated.".format(
            info.name
        )
        s.add(info.name)
        all_infos.append(info)
    return all_infos


def _gen_endpoint():
    node = Node()
    ip = node.get_host_ip()
    free_port = node.get_free_port()
    return "{}:{}".format(ip, free_port)

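# The helpers above implement a simple rendezvous: each worker pickles its
# WorkerInfo into the shared store under its rank, then reads back the
# entries of every other rank. A minimal sketch of the same pattern, using
# a hypothetical in-memory `FakeStore` in place of the real TCPStore:
#
#     class FakeStore(dict):
#         def set(self, k, v):
#             self[k] = v
#
#         def get(self, k):
#             return self[k]
#
#     store = FakeStore()
#     store.set("0", pickle.dumps(WorkerInfo("worker0", 0, "127.0.0.1", 9000)))
#     pickle.loads(store.get("0"))
#     # -> WorkerInfo(name='worker0', rank=0, ip='127.0.0.1', port=9000)
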
def init_rpc(name, rank=None, world_size=None, master_endpoint=None):
    """
    Initialize RPC.

    Args:
        name (str): worker name.
        rank (int, optional): worker id, default is None.
        world_size (int, optional): number of workers, default is None.
        master_endpoint (str, optional): ip address of the master; other
            nodes communicate with the master to get the information of all
            worker nodes, default is None.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle.distributed.rpc as rpc

            rpc.init_rpc("worker0", rank=0, world_size=1, master_endpoint="127.0.0.1:8001")
            rpc.shutdown()

    """
    rank = int(os.environ["PADDLE_TRAINER_ID"]) if rank is None else rank
    world_size = (
        int(os.environ["PADDLE_TRAINERS_NUM"])
        if world_size is None
        else world_size
    )
    worker_endpoint = os.getenv("PADDLE_WORKER_ENDPOINT", None)
    if worker_endpoint is None:
        worker_endpoint = _gen_endpoint()
    logger.info("Trainer {}: worker endpoint: {}".format(rank, worker_endpoint))
    master_endpoint = (
        master_endpoint
        if master_endpoint is not None
        else os.environ["PADDLE_MASTER_ENDPOINT"]
    )
    master_addr, master_port = master_endpoint.split(":")
    master_port = int(master_port)
    stop_check_timeout = int(os.getenv("FLAGS_stop_check_timeout", "900"))
    store = core.TCPStore(
        master_addr,
        master_port,
        rank == 0,
        world_size,
        timeout=stop_check_timeout,
    )
    _set_barrier_store(store)
    ip, port = worker_endpoint.split(":")
    port = int(port)
    _set_self_info(name, rank, ip, port)
    all_infos = _exchange_all_service_infos(world_size)
    c_infos = []
    for node_info in all_infos:
        info = core.WorkerInfo(
            node_info.name, node_info.rank, node_info.ip, node_info.port
        )
        c_infos.append(info)
    core.init_and_set_agent_instance(name, c_infos)
    core.rpc_start_worker()
    # ensure that all the workers are started
    _barrier_never_timeout(rank, world_size)
    core.rpc_start_client()
    logger.info("Trainer {}: Init RPC done!".format(rank))


def rpc_sync(to, fn, args=None, kwargs=None, timeout=_DEFAULT_RPC_TIMEOUT):
    """
    Make a blocking RPC call to run function ``fn`` on worker ``to``.

    Args:
        to (str): name of the destination worker.
        fn (fn): a callable function, such as Python callables.
        args (tuple, optional): the argument tuple for the ``fn`` invocation,
            default is None.
        kwargs (dict, optional): a dictionary of keyword arguments for the
            ``fn`` invocation, default is None.
        timeout (int, optional): timeout in seconds to use for this RPC. If
            the RPC does not complete in this amount of time, an exception
            indicating it has timed out will be raised. A value less than or
            equal to 0 indicates an infinite timeout, i.e. a timeout error
            will never be raised. The default value is -1.

    Returns:
        Returns the result of running ``fn`` with ``args`` and ``kwargs``.

    Examples:
        .. code-block:: python

            import paddle.distributed.rpc as rpc

            def add(a, b):
                return a + b

            rpc.init_rpc("worker0", rank=0, world_size=1, master_endpoint="127.0.0.1:8002")
            ret = rpc.rpc_sync("worker0", add, args=(2, 3))
            rpc.shutdown()

    """
    fut = _invoke_rpc(to, fn, args, kwargs, timeout)
    return fut.wait()

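# Note: `rpc_sync` above is just `rpc_async` followed by a blocking wait on
# the returned future, so the two calls below are equivalent (assuming `add`
# is a plain Python function already defined):
#
#     ret = rpc_sync("worker0", add, args=(2, 3))
#     ret = rpc_async("worker0", add, args=(2, 3)).wait()
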
def rpc_async(to, fn, args=None, kwargs=None, timeout=_DEFAULT_RPC_TIMEOUT):
    """
    Make a non-blocking RPC call to run function ``fn`` on worker ``to``.

    Args:
        to (str): name of the destination worker.
        fn (fn): a callable function, such as Python callables.
        args (tuple, optional): the argument tuple for the ``fn`` invocation,
            default is None.
        kwargs (dict, optional): a dictionary of keyword arguments for the
            ``fn`` invocation, default is None.
        timeout (int, optional): timeout in seconds to use for this RPC. If
            the RPC does not complete in this amount of time, an exception
            indicating it has timed out will be raised. A value less than or
            equal to 0 indicates an infinite timeout, i.e. a timeout error
            will never be raised. The default value is -1.

    Returns:
        Returns a :class:`FutureWrapper` object that can be waited on. When
        completed, the return value of ``fn`` on ``args`` and ``kwargs`` can
        be got by ``fut.wait()``.

    Examples:
        .. code-block:: python

            import paddle.distributed.rpc as rpc

            def add(a, b):
                return a + b

            rpc.init_rpc("worker0", rank=0, world_size=1, master_endpoint="127.0.0.1:8003")
            fut = rpc.rpc_async("worker0", add, args=(2, 3))
            print(fut.wait())
            rpc.shutdown()

    """
    return _invoke_rpc(to, fn, args, kwargs, timeout)


def _invoke_rpc(to, fn, args, kwargs, timeout):
    args = args if args else ()
    kwargs = kwargs if kwargs else {}
    serial_obj = _serialize(PythonFunc(fn, args, kwargs))
    # convert the timeout from seconds to milliseconds; non-positive values
    # mean "never time out"
    timeout_ms = timeout * 1000
    timeout_ms = _MAX_RPC_TIMEOUT_MS if timeout_ms <= 0 else timeout_ms
    future = core.invoke_rpc(to, serial_obj, timeout_ms)
    return future


def _barrier_never_timeout(global_rank, global_world_size):
    # effectively infinite timeout
    timeout = datetime.timedelta(days=_BARRIER_TIMEOUT_MAX_DAYS)

    if global_world_size < 2:
        return

    global _barrier_count
    barrier_prefix = "Barrier/" + str(_barrier_count) + "/"
    _barrier_count += 1
    is_master = global_rank == 0

    def _check_keys_ready(wait_keys):
        start_time = time.time()
        while len(wait_keys) > 0:
            time.sleep(0.1)
            elapse_time = time.time() - start_time
            if datetime.timedelta(seconds=elapse_time) > timeout:
                raise RuntimeError(
                    "Keys {} are not ready while rank {} is waiting for them.".format(
                        wait_keys, global_rank
                    )
                )
            wait_keys = list(
                filter(lambda key: int(_barrier_store.get(key)) != 1, wait_keys)
            )

    if is_master:
        # The master adds its own key, waits for every worker's key, and
        # exits last.
        # Note: the master must exit last to ensure that the TCP server is
        # destroyed last.
        wait_keys = [
            barrier_prefix + str(rank) for rank in range(1, global_world_size)
        ]
        _barrier_store.add(barrier_prefix + str(0), 1)
        _check_keys_ready(wait_keys)
    else:
        wait_keys = [barrier_prefix + str(0)]
        _check_keys_ready(wait_keys)
        _barrier_store.add(barrier_prefix + str(global_rank), 1)

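# The barrier above coordinates ranks through per-rank keys in the TCP
# store. For example, with world_size=3 the first barrier uses:
#
#     Barrier/0/0   # added by the master (rank 0) first
#     Barrier/0/1   # added by rank 1 once it has seen Barrier/0/0
#     Barrier/0/2   # added by rank 2 once it has seen Barrier/0/0
#
# The master waits for the keys of ranks 1..world_size-1, so it is always
# the last process to leave the barrier.
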
def shutdown():
    """
    Perform a shutdown of the RPC agent, stop the worker and destroy the
    agent. This will block until all local and remote RPC processes reach
    this method and wait for all outstanding work to complete.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle.distributed.rpc as rpc

            rpc.init_rpc("worker0", rank=0, world_size=1, master_endpoint="127.0.0.1:8004")
            rpc.shutdown()

    """
    info = get_current_worker_info()
    rank = info.rank
    world_size = len(get_all_worker_infos())
    # the master will exit last
    _barrier_never_timeout(rank, world_size)
    core.rpc_stop_worker()
    _del_barrier_store()
    logger.info("Trainer {}: rpc shutdown!".format(rank))


def get_worker_info(name):
    """
    Get worker information by worker name.

    Args:
        name (str): name of the worker.

    Returns:
        class `WorkerInfo` with attributes `name`, `rank`, `ip` and `port`.

    Examples:
        .. code-block:: python

            import paddle.distributed.rpc as rpc
            import os

            os.environ["PADDLE_WORKER_ENDPOINT"] = "127.0.0.1:9002"
            rpc.init_rpc("worker0", rank=0, world_size=1, master_endpoint="127.0.0.1:8005")

            print(rpc.get_worker_info("worker0"))
            # {name: worker0, rank: 0, ip: 127.0.0.1, port: 9002}

            rpc.shutdown()

    """
    return core.rpc_get_worker_info(name)


def get_all_worker_infos():
    """
    Get the information of all workers.

    Returns:
        List[WorkerInfo].

    Examples:
        .. code-block:: python

            import paddle.distributed.rpc as rpc
            import os

            os.environ["PADDLE_WORKER_ENDPOINT"] = "127.0.0.1:9003"
            rpc.init_rpc("worker0", rank=0, world_size=1, master_endpoint="127.0.0.1:8006")

            print(rpc.get_all_worker_infos())
            # [{name: worker0, rank: 0, ip: 127.0.0.1, port: 9003}]

            rpc.shutdown()

    """
    return core.rpc_get_all_worker_infos()


def get_current_worker_info():
    """
    Get the information of the current worker.

    Returns:
        class `WorkerInfo` with attributes `name`, `rank`, `ip` and `port`.

    Examples:
        .. code-block:: python

            import paddle.distributed.rpc as rpc
            import os

            os.environ["PADDLE_WORKER_ENDPOINT"] = "127.0.0.1:9004"
            rpc.init_rpc("worker0", rank=0, world_size=1, master_endpoint="127.0.0.1:8007")

            print(rpc.get_current_worker_info())
            # {name: worker0, rank: 0, ip: 127.0.0.1, port: 9004}

            rpc.shutdown()

    """
    return core.rpc_get_current_worker_info()
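
# A minimal end-to-end usage sketch, mirroring the docstring examples above
# (the endpoint port is arbitrary and assumed to be free):
#
#     import paddle.distributed.rpc as rpc
#
#     def add(a, b):
#         return a + b
#
#     rpc.init_rpc("worker0", rank=0, world_size=1, master_endpoint="127.0.0.1:8001")
#     fut = rpc.rpc_async("worker0", add, args=(2, 3))
#     assert fut.wait() == 5
#     rpc.shutdown()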