Unverified commit db15ef57, authored by noabauma, committed by GitHub

deepspeed.init_distributed() support for TCP protocols (#2905)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent 0acf7e9c
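The change below adds optional rank and world_size arguments to deepspeed.init_distributed() so that init_method URLs that require both values, such as tcp://, can be used directly. A minimal usage sketch, assuming an illustrative master address, port, and NCCL backend (in practice these normally come from your launcher or environment):

```python
import deepspeed

# Illustrative values; rank and world_size are usually supplied by the launcher.
master_addr = "10.0.0.1"   # assumed address of the rank-0 node
master_port = 29500        # assumed free TCP port on that node

deepspeed.init_distributed(dist_backend="nccl",
                           init_method=f"tcp://{master_addr}:{master_port}",
                           rank=0,          # this process's rank, specified manually
                           world_size=2)    # total number of participating processes
```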
@@ -603,7 +603,9 @@ def init_distributed(dist_backend=None,
                      timeout=default_pg_timeout,
                      init_method=None,
                      dist_init_required=None,
-                     config=None):
+                     config=None,
+                     rank=-1,
+                     world_size=-1):
     ''' Initialize dist backend, potentially performing MPI discovery if needed
 
     Arguments:
@@ -614,6 +616,8 @@ def init_distributed(dist_backend=None,
         timeout: Optional (timedelta). Timeout for operations executed against the process group. Default value equals 30 minutes.
         init_method: Optional (string). Torch distributed, URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified.
         config: Optional (dict). DeepSpeed configuration for setting up comms options (e.g. Comms profiling)
+        rank: Optional (int). The current manually specified rank. Some init_method like “tcp://” need the rank and world_size as well (see: https://pytorch.org/docs/stable/distributed.html#tcp-initialization)
+        world_size: Optional (int). Desired world_size for the TCP or Shared file-system initialization.
     '''
 
     global cdb
@@ -658,7 +662,7 @@
             'Initializing TorchBackend in DeepSpeed with backend {}'.format(
                 dist_backend))
         # Create a torch backend object, initialize torch distributed, and assign to cdb
-        cdb = TorchBackend(dist_backend, timeout, init_method)
+        cdb = TorchBackend(dist_backend, timeout, init_method, rank, world_size)
 
 
 def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True):
@@ -16,7 +16,13 @@ class TorchBackend(Backend):
         so no need to wrap all the functions. We can keep adding wrappers as
         needed.
     """
-    def __init__(self, backend, timeout, init_method, name='torch'):
+    def __init__(self,
+                 backend,
+                 timeout,
+                 init_method,
+                 rank=-1,
+                 world_size=-1,
+                 name='torch'):
         super(TorchBackend, self).__init__()
         self.torch_version_before_18 = older_torch()
         self.has_allgather_base = has_allgather_base()
@@ -27,13 +33,15 @@ class TorchBackend(Backend):
         # The idea is to fake that dist backend is initialized even when
         # it is not so we can run on a single GPU without doing any init_process_group
         self.single_gpu_mode = True
-        self.init_process_group(backend, timeout, init_method)
+        self.init_process_group(backend, timeout, init_method, rank, world_size)
 
-    def init_process_group(self, backend, timeout, init_method):
+    def init_process_group(self, backend, timeout, init_method, rank, world_size):
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group(backend,
                                                  timeout=timeout,
-                                                 init_method=init_method)
+                                                 init_method=init_method,
+                                                 rank=rank,
+                                                 world_size=world_size)
         self.using_mpi = torch.distributed.get_backend() == 'mpi'
 
     def all_reduce(self,
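For context, the new rank and world_size arguments are forwarded unchanged to torch.distributed.init_process_group, so the TorchBackend path above is roughly equivalent to the following plain PyTorch TCP initialization; the address, port, backend, and process counts are illustrative assumptions:

```python
from datetime import timedelta

import torch.distributed as dist

# Hypothetical stand-alone equivalent of what TorchBackend does internally;
# DeepSpeed's default process-group timeout is 30 minutes.
dist.init_process_group(backend="gloo",                      # assumed backend
                        init_method="tcp://10.0.0.1:29500",  # assumed rank-0 address and port
                        timeout=timedelta(minutes=30),
                        rank=0,
                        world_size=2)
```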