未验证 提交 1b45917c 编写于 作者: A Ammar Ahmad Awan 提交者: GitHub

Discover variables for NCCL backend on AML without mpi4py (#542)

* Use AML method to set env vars instead of using mpi4py.
Co-authored-by: NJeff Rasley <jerasley@microsoft.com>
上级 d81cb26d
......@@ -125,7 +125,10 @@ class DeepSpeedEngine(Module):
if dist_init_required is None:
dist_init_required = not dist.is_initialized()
self._mpi_check(args, dist_init_required)
if self._in_aml():
self._set_environment_variables_for_nccl_backend(args)
else:
self._mpi_check(args, dist_init_required)
self.dist_backend = "nccl"
if dist_init_required:
......@@ -199,6 +202,49 @@ class DeepSpeedEngine(Module):
self.flatten = util_ops.flatten
self.unflatten = util_ops.unflatten
def _in_aml(self):
    """Report whether the process is running in an Azure ML environment.

    Detection relies solely on the presence of the
    ``AZUREML_EXPERIMENT_ID`` environment variable; its value is not
    inspected.

    Returns:
        bool: True when the variable is set, False otherwise.
    """
    return 'AZUREML_EXPERIMENT_ID' in os.environ
def _set_environment_variables_for_nccl_backend(self,
                                                args,
                                                master_port=6105,
                                                verbose=True):
    """Populate the rendezvous environment variables from OpenMPI/Azure ones.

    This is adapted from Azure ML's documentation available from:
    https://azure.github.io/azureml-web/docs/cheatsheet/distributed-training/#environment-variables-from-openmpi

    The following variables are written to ``os.environ``: ``RANK``,
    ``WORLD_SIZE``, ``MASTER_ADDR``, ``MASTER_PORT`` and
    ``NCCL_SOCKET_IFNAME``. ``args.local_rank`` is set as a side effect.

    Args:
        args: Object that receives the ``local_rank`` attribute (an int
            parsed from ``OMPI_COMM_WORLD_LOCAL_RANK``).
        master_port: Port used for ``MASTER_PORT`` on multi-node runs when
            ``MASTER_PORT`` is not already present in the environment.
        verbose: When true, log the discovered settings via ``logger.info``.

    Raises:
        KeyError: If one of the required ``OMPI_COMM_WORLD_*`` variables, or
            the master-node variable for the detected topology
            (``AZ_BATCH_MASTER_NODE`` / ``AZ_BATCHAI_MPI_MASTER_NODE``), is
            missing from the environment.
    """
    os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
    os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]

    # The job is single-node when every rank in the world is local.
    single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int(
        os.environ["WORLD_SIZE"])

    if not single_node:
        # AZ_BATCH_MASTER_NODE is split on ":"; only the first field (the
        # address) is used here.
        master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":")
        os.environ["MASTER_ADDR"] = master_node_params[0]
        # Do not overwrite master port with that defined in AZ_BATCH_MASTER_NODE
        if "MASTER_PORT" not in os.environ:
            os.environ["MASTER_PORT"] = str(master_port)
    else:
        os.environ["MASTER_ADDR"] = os.environ["AZ_BATCHAI_MPI_MASTER_NODE"]
        os.environ["MASTER_PORT"] = "54965"

    # NCCL_SOCKET_IFNAME may legitimately be unset; use .get() so reporting
    # the previous value cannot raise KeyError before the override below.
    print("NCCL_SOCKET_IFNAME original value = {}".format(
        os.environ.get("NCCL_SOCKET_IFNAME")))
    os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo"

    args.local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])

    if verbose:
        logger.info(
            "Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
            .format(os.environ['RANK'],
                    args.local_rank,
                    os.environ['WORLD_SIZE'],
                    os.environ['MASTER_ADDR'],
                    os.environ['MASTER_PORT']))
def _mpi_check(self, args, dist_init_required):
if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
from mpi4py import MPI
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册