diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 7803d7916aa2240d2fbf82d5fc0d54da8d627170..06ed83622abf49584050f74f41b0561045b246f5 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -125,7 +125,10 @@ class DeepSpeedEngine(Module): if dist_init_required is None: dist_init_required = not dist.is_initialized() - self._mpi_check(args, dist_init_required) + if self._in_aml(): + self._set_environment_variables_for_nccl_backend(args) + else: + self._mpi_check(args, dist_init_required) self.dist_backend = "nccl" if dist_init_required: @@ -199,6 +202,49 @@ class DeepSpeedEngine(Module): self.flatten = util_ops.flatten self.unflatten = util_ops.unflatten + def _in_aml(self): + # read and environment variable to detect if we are using an Azure ML environment + if 'AZUREML_EXPERIMENT_ID' in os.environ: + return True + else: + return False + + def _set_environment_variables_for_nccl_backend(self, + args, + master_port=6105, + verbose=True): + """Helper routine to get and set environment variables. + This is adapted from Azure ML's documentation available from: + https://azure.github.io/azureml-web/docs/cheatsheet/distributed-training/#environment-variables-from-openmpi + """ + os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"] + os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"] + single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int( + os.environ["WORLD_SIZE"]) + if not single_node: + master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":") + os.environ["MASTER_ADDR"] = master_node_params[0] + # Do not overwrite master port with that defined in AZ_BATCH_MASTER_NODE + if "MASTER_PORT" not in os.environ: + os.environ["MASTER_PORT"] = str(master_port) + else: + os.environ["MASTER_ADDR"] = os.environ["AZ_BATCHAI_MPI_MASTER_NODE"] + os.environ["MASTER_PORT"] = "54965" + print("NCCL_SOCKET_IFNAME original value = {}".format( + os.environ["NCCL_SOCKET_IFNAME"])) + + os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo" + args.local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) + + if verbose: + logger.info( + "Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}" + .format(os.environ['RANK'], + args.local_rank, + os.environ['WORLD_SIZE'], + os.environ['MASTER_ADDR'], + os.environ['MASTER_PORT'])) + def _mpi_check(self, args, dist_init_required): if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi: from mpi4py import MPI