Unverified commit 0e831e23 authored by Ammar Ahmad Awan, committed by GitHub

Simplify dist init and only init if needed. (#553)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent 6e65c2cc
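For context, the change below affects how deepspeed.initialize() sets up torch.distributed: it now initializes the process group only when the caller has not already done so. A minimal usage sketch of the two alternative call patterns this supports (args, model, and the NCCL backend choice are assumptions drawn from the code in the diff, not part of the commit itself):

    import torch.distributed as dist
    import deepspeed

    # Pattern 1: let DeepSpeed initialize torch.distributed itself.
    # dist_init_required defaults to None, which now means
    # "initialize only if it is not already initialized".
    engine, optimizer, _, _ = deepspeed.initialize(
        args=args, model=model, model_parameters=model.parameters())

    # Pattern 2: the caller initializes torch.distributed first;
    # DeepSpeed detects the existing process group and skips init.
    dist.init_process_group(backend="nccl")
    engine, optimizer, _, _ = deepspeed.initialize(
        args=args, model=model, model_parameters=model.parameters(),
        dist_init_required=False)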
@@ -121,25 +121,31 @@ class DeepSpeedEngine(Module):
         self.loaded_checkpoint_dp_world_size = None
         self.enable_backward_allreduce = True
         self.progressive_layer_drop = None
+        self.dist_backend = "nccl"

         if dist_init_required is None:
             dist_init_required = not dist.is_initialized()

-        if self._in_aml():
-            self._set_environment_variables_for_nccl_backend(args)
-        else:
-            self._mpi_check(args, dist_init_required)
+        if dist_init_required is False:
+            assert (dist.is_initialized()==True), "Torch distributed not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()"

-        self.dist_backend = "nccl"
-        if dist_init_required:
-            if not dist.is_initialized():
-                logger.info("Initializing torch distributed with backend: {}".format(
-                    self.dist_backend))
-                dist.init_process_group(backend=self.dist_backend)
+        # DeepSpeed will initialize torch distributed only if the user has not already initialized it.
+        if dist_init_required and not dist.is_initialized():
+            # discover using mpi4py if user specifies the flag
+            if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
+                # if in Azure ML environment and user specified this flag, notify the user to remove the flag.
+                if self._in_aml():
+                    logger.warning(
+                        "Please remove the --deepspeed_mpi flag if running on AzureML.")
+                self._mpi_check(args, dist_init_required)
             else:
-                logger.warning(
-                    "Was given dist_init_required=True but detected that torch"
-                    "distributed was already initialized, cannot initialize twice.")
+                # detect if we are in Azure ML environment
+                if self._in_aml():
+                    self._set_environment_variables_for_nccl_backend(args)
+
+            logger.info("Initializing torch distributed with backend: {}".format(
+                self.dist_backend))
+            dist.init_process_group(backend=self.dist_backend)

         self._do_args_sanity_check(args)
         self._configure_with_arguments(args, mpu)
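The added block above can be read as a single decision rule. The following is a paraphrased sketch for illustration only; the function and argument names are my own, and only the logic mirrors the diff:

    def should_init_distributed(dist_init_required, already_initialized):
        # None means: initialize only if the user has not done it already
        if dist_init_required is None:
            dist_init_required = not already_initialized
        # explicit False requires the user to have initialized torch.distributed first
        if dist_init_required is False:
            assert already_initialized, (
                "Torch distributed not initialized. Please set dist_init_required "
                "to True or initialize before calling deepspeed.initialize()")
        # initialize exactly once, and only when needed
        return dist_init_required and not already_initialized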
@@ -203,7 +209,7 @@ class DeepSpeedEngine(Module):
         self.unflatten = util_ops.unflatten

     def _in_aml(self):
-        # read and environment variable to detect if we are using an Azure ML environment
+        # read AzureML environment variable to detect if we are using an Azure ML environment
         if 'AZUREML_EXPERIMENT_ID' in os.environ:
             return True
         else:
@@ -246,43 +252,42 @@
                             os.environ['MASTER_PORT']))

     def _mpi_check(self, args, dist_init_required):
-        if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
-            from mpi4py import MPI
-            import subprocess
-            comm = MPI.COMM_WORLD
-            rank = comm.Get_rank()
-            world_size = comm.Get_size()
-
-            master_addr = None
-            if rank == 0:
-                hostname_cmd = ["hostname -I"]
-                result = subprocess.check_output(hostname_cmd, shell=True)
-                master_addr = result.decode('utf-8').split()[0]
-            master_addr = comm.bcast(master_addr, root=0)
-
-            # Determine local rank by assuming hostnames are unique
-            proc_name = MPI.Get_processor_name()
-            all_procs = comm.allgather(proc_name)
-            local_rank = sum([i == proc_name for i in all_procs[:rank]])
-
-            os.environ['RANK'] = str(rank)
-            os.environ['WORLD_SIZE'] = str(world_size)
-            args.local_rank = local_rank
-            os.environ['MASTER_ADDR'] = master_addr
-            os.environ['MASTER_PORT'] = TORCH_DISTRIBUTED_DEFAULT_PORT
-
-            logger.info(
-                "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
-                .format(os.environ['RANK'],
-                        args.local_rank,
-                        os.environ['WORLD_SIZE'],
-                        os.environ['MASTER_ADDR'],
-                        os.environ['MASTER_PORT']))
-
-            if not dist_init_required and dist.is_initialized():
-                assert dist.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank())
-                assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
-                    world_size, dist.get_world_size())
+        from mpi4py import MPI
+        import subprocess
+        comm = MPI.COMM_WORLD
+        rank = comm.Get_rank()
+        world_size = comm.Get_size()
+
+        master_addr = None
+        if rank == 0:
+            hostname_cmd = ["hostname -I"]
+            result = subprocess.check_output(hostname_cmd, shell=True)
+            master_addr = result.decode('utf-8').split()[0]
+        master_addr = comm.bcast(master_addr, root=0)
+
+        # Determine local rank by assuming hostnames are unique
+        proc_name = MPI.Get_processor_name()
+        all_procs = comm.allgather(proc_name)
+        local_rank = sum([i == proc_name for i in all_procs[:rank]])
+
+        os.environ['RANK'] = str(rank)
+        os.environ['WORLD_SIZE'] = str(world_size)
+        args.local_rank = local_rank
+        os.environ['MASTER_ADDR'] = master_addr
+        os.environ['MASTER_PORT'] = TORCH_DISTRIBUTED_DEFAULT_PORT
+
+        logger.info(
+            "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
+            .format(os.environ['RANK'],
+                    args.local_rank,
+                    os.environ['WORLD_SIZE'],
+                    os.environ['MASTER_ADDR'],
+                    os.environ['MASTER_PORT']))
+
+        if not dist_init_required and dist.is_initialized():
+            assert dist.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank())
+            assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
+                world_size, dist.get_world_size())

     def pld_enabled(self):
         return self._config.pld_enabled
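The local-rank computation in _mpi_check counts how many lower MPI ranks report the same processor name. A tiny standalone illustration of that expression (the hostnames below are made up for the example):

    # allgathered processor names for a 4-process job, two processes per node
    all_procs = ['node-a', 'node-a', 'node-b', 'node-b']
    for rank, proc_name in enumerate(all_procs):
        local_rank = sum([i == proc_name for i in all_procs[:rank]])
        print(rank, proc_name, local_rank)
    # rank 0 -> local_rank 0, rank 1 -> 1, rank 2 -> 0, rank 3 -> 1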