From db15ef578a4015617d6c33df43aa832d708a0bc9 Mon Sep 17 00:00:00 2001
From: noabauma <62301037+noabauma@users.noreply.github.com>
Date: Tue, 7 Mar 2023 18:42:22 +0100
Subject: [PATCH] deepspeed.init_distributed() support for TCP protocols
 (#2905)

Co-authored-by: Jeff Rasley
---
 deepspeed/comm/comm.py  |  8 ++++++--
 deepspeed/comm/torch.py | 16 ++++++++++++----
 2 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/deepspeed/comm/comm.py b/deepspeed/comm/comm.py
index dc7a38a2..302e3eaf 100644
--- a/deepspeed/comm/comm.py
+++ b/deepspeed/comm/comm.py
@@ -603,7 +603,9 @@ def init_distributed(dist_backend=None,
                      timeout=default_pg_timeout,
                      init_method=None,
                      dist_init_required=None,
-                     config=None):
+                     config=None,
+                     rank=-1,
+                     world_size=-1):
     ''' Initialize dist backend, potentially performing MPI discovery if needed
 
     Arguments:
@@ -614,6 +616,8 @@ def init_distributed(dist_backend=None,
         timeout: Optional (timedelta). Timeout for operations executed against the process group. Default value equals 30 minutes.
         init_method: Optional (string). Torch distributed, URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified.
         config: Optional (dict). DeepSpeed configuration for setting up comms options (e.g. Comms profiling)
+        rank: Optional (int). The current manually specified rank. Some init_method like “tcp://” need the rank and world_size as well (see: https://pytorch.org/docs/stable/distributed.html#tcp-initialization)
+        world_size: Optional (int). Desired world_size for the TCP or Shared file-system initialization.
     '''
 
     global cdb
@@ -658,7 +662,7 @@ def init_distributed(dist_backend=None,
                 'Initializing TorchBackend in DeepSpeed with backend {}'.format(
                     dist_backend))
         # Create a torch backend object, initialize torch distributed, and assign to cdb
-        cdb = TorchBackend(dist_backend, timeout, init_method)
+        cdb = TorchBackend(dist_backend, timeout, init_method, rank, world_size)
 
 
 def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True):
diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py
index db4604ca..0753e7ce 100644
--- a/deepspeed/comm/torch.py
+++ b/deepspeed/comm/torch.py
@@ -16,7 +16,13 @@ class TorchBackend(Backend):
         so no need to wrap all the functions. We can keep adding wrappers as
         needed.
     """
-    def __init__(self, backend, timeout, init_method, name='torch'):
+    def __init__(self,
+                 backend,
+                 timeout,
+                 init_method,
+                 rank=-1,
+                 world_size=-1,
+                 name='torch'):
         super(TorchBackend, self).__init__()
         self.torch_version_before_18 = older_torch()
         self.has_allgather_base = has_allgather_base()
@@ -27,13 +33,15 @@ class TorchBackend(Backend):
         # The idea is to fake that dist backend is initialized even when
         # it is not so we can run on a single GPU without doing any init_process_group
         self.single_gpu_mode = True
-        self.init_process_group(backend, timeout, init_method)
+        self.init_process_group(backend, timeout, init_method, rank, world_size)
 
-    def init_process_group(self, backend, timeout, init_method):
+    def init_process_group(self, backend, timeout, init_method, rank, world_size):
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group(backend,
                                                  timeout=timeout,
-                                                 init_method=init_method)
+                                                 init_method=init_method,
+                                                 rank=rank,
+                                                 world_size=world_size)
         self.using_mpi = torch.distributed.get_backend() == 'mpi'
 
     def all_reduce(self,
--
GitLab
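
A minimal usage sketch (not part of the patch above). The hunks thread the new rank and world_size arguments from deepspeed.init_distributed() down to torch.distributed.init_process_group(), which is what a TCP rendezvous needs, since a "tcp://" init_method cannot read the rank and world size from environment variables the way the default "env://" method does. The backend choice, master address, port, rank, and world size below are placeholder assumptions that a real launcher or scheduler would normally supply; every process in the job would run this with its own rank.

    import deepspeed

    # Assumed two-process job initialized over TCP instead of "env://".
    deepspeed.init_distributed(
        dist_backend="nccl",                  # or "gloo"/"mpi", depending on the cluster
        init_method="tcp://10.1.1.20:23456",  # hypothetical address/port of the rank-0 host
        rank=0,                               # this process's rank; tcp:// init needs it explicitly
        world_size=2,                         # total number of participating processes
    )

Because the new arguments default to -1, existing callers are unaffected: torch.distributed.init_process_group() itself treats rank=-1 and world_size=-1 as "derive them from the environment or store", so omitting them preserves the previous "env://" behavior.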