Unverified commit 852c524a, authored by sdtblck, committed by GitHub

Add optional timeout parameter to deepspeed.init_distributed (#637)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent 34c83a5a
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
from datetime import timedelta
#############################################
# Torch distributed constants
#############################################
TORCH_DISTRIBUTED_DEFAULT_PORT = 29500
# Default process group wide timeout, if applicable.
# This only applies to the gloo and nccl backends
# (only if NCCL_BLOCKING_WAIT or NCCL_ASYNC_ERROR_HANDLING is set to 1).
# To make an attempt at backwards compatibility with THD, we use an
# extraordinarily high default timeout, given that THD did not have timeouts.
default_pg_timeout = timedelta(minutes=30)
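For context, a minimal sketch (not part of this commit) of how a caller might prepare an alternative timeout. Per the comment above, NCCL only enforces the process-group timeout when NCCL_BLOCKING_WAIT or NCCL_ASYNC_ERROR_HANDLING is set to 1; the 60-minute value and the setdefault call below are illustrative assumptions only.

import os
from datetime import timedelta

# Must be set before the process group is created for NCCL to enforce the timeout.
os.environ.setdefault("NCCL_ASYNC_ERROR_HANDLING", "1")

# An arbitrary example value overriding the 30-minute default above.
custom_timeout = timedelta(minutes=60)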
@@ -3,25 +3,25 @@ Copyright 2020 The Microsoft DeepSpeed Team
 '''
 import os
 import torch
+from datetime import timedelta
 from .logging import logger
-from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
+from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, default_pg_timeout
 def init_distributed(dist_backend="nccl",
                      auto_mpi_discovery=True,
                      distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT,
-                     verbose=True):
-    """Initialize torch.distributed backend, potentially performing MPI discovery if needed.
+                     verbose=True,
+                     timeout=default_pg_timeout):
+    """
+    Initialize torch.distributed backend, potentially performing MPI discovery if needed
     Arguments:
-        dist_backend: torch distributed backend, e.g., nccl, mpi, gloo
-        auto_mpi_discovery: if distributed environment variables are not set, attempt to discover them from MPI
-        distributed_port: torch distributed backend port
-        verbose: verbose logging
+        dist_backend (str): torch distributed backend, e.g., nccl, mpi, gloo
+        auto_mpi_discovery (bool): if distributed environment variables are not set, attempt to discover them from MPI
+        distributed_port (int, optional): torch distributed backend port
+        verbose (bool, optional): verbose logging
+        timeout (timedelta, optional): Timeout for operations executed against the process group. Default value equals 30 minutes.
     """
     required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
     if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)):
@@ -38,7 +38,8 @@ def init_distributed(dist_backend="nccl",
         if verbose:
             logger.info(
                 "Initializing torch distributed with backend: {}".format(dist_backend))
-        torch.distributed.init_process_group(backend=dist_backend)
+        assert isinstance(timeout, timedelta)
+        torch.distributed.init_process_group(backend=dist_backend, timeout=timeout)
 def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True):
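As a quick usage sketch (illustrative, not part of the diff): after this change, a caller can pass a custom timeout through deepspeed.init_distributed, which forwards it to torch.distributed.init_process_group. The 2-hour value below is an arbitrary example.

from datetime import timedelta
import deepspeed

# Forwarded to torch.distributed.init_process_group(timeout=...); the assert in the
# diff above rejects any value that is not a datetime.timedelta.
deepspeed.init_distributed(dist_backend="nccl", timeout=timedelta(hours=2))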