未验证 提交 0e831e23 编写于 作者: A Ammar Ahmad Awan 提交者: GitHub

Simplify dist init and only init if needed. (#553)

Co-authored-by: NJeff Rasley <jerasley@microsoft.com>
上级 6e65c2cc
......@@ -121,25 +121,31 @@ class DeepSpeedEngine(Module):
self.loaded_checkpoint_dp_world_size = None
self.enable_backward_allreduce = True
self.progressive_layer_drop = None
self.dist_backend = "nccl"
if dist_init_required is None:
dist_init_required = not dist.is_initialized()
if dist_init_required is False:
assert (dist.is_initialized()==True), "Torch distributed not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()"
# DeepSpeed will initialize torch distributed only if the user has not already intialized it.
if dist_init_required and not dist.is_initialized():
# discover using mpi4py if user specifies the flag
if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
# if in Azure ML environment and user specified this flag, notify the user to remove the flag.
if self._in_aml():
self._set_environment_variables_for_nccl_backend(args)
else:
logger.warning(
"Please remove the --deepspeed_mpi flag if running on AzureML.")
self._mpi_check(args, dist_init_required)
else:
# detect if we are in Azure ML environment
if self._in_aml():
self._set_environment_variables_for_nccl_backend(args)
self.dist_backend = "nccl"
if dist_init_required:
if not dist.is_initialized():
logger.info("Initializing torch distributed with backend: {}".format(
self.dist_backend))
dist.init_process_group(backend=self.dist_backend)
else:
logger.warning(
"Was given dist_init_required=True but detected that torch"
"distributed was already initialized, cannot initialize twice.")
self._do_args_sanity_check(args)
self._configure_with_arguments(args, mpu)
......@@ -203,7 +209,7 @@ class DeepSpeedEngine(Module):
self.unflatten = util_ops.unflatten
def _in_aml(self):
# read and environment variable to detect if we are using an Azure ML environment
# read AzureML environment variable to detect if we are using an Azure ML environment
if 'AZUREML_EXPERIMENT_ID' in os.environ:
return True
else:
......@@ -246,7 +252,6 @@ class DeepSpeedEngine(Module):
os.environ['MASTER_PORT']))
def _mpi_check(self, args, dist_init_required):
if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
from mpi4py import MPI
import subprocess
comm = MPI.COMM_WORLD
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册