"""
    Copyright 2021 The Microsoft DeepSpeed Team

    DeepSpeed Communication Package: deepspeed.comm

    deepspeed.comm
        -- import and use deepspeed.ops.comm
        -- use torch.distributed directly if both this package and torch.distributed use the same NCCL version
        -- use custom collectives
            -- can either use torch.dist or ds.ops.comm?

        Note: the old 1-bit compressed allreduce variants that resided in deepspeed.runtime.comm will be moved here as well.

    deepspeed.comm API
        -- must be kept fully compatible with the torch.dist API (same signatures) to ensure backward and cross-framework compatibility.
        -- e.g. if client code used
            from deepspeed import comm as dist

            instead of
            import torch.distributed as dist

            The code should work without breaking any of the public torch.distributed functionality

    Future:
        -- deepspeed groups API should be brought into ds.comm
"""

from enum import Enum
import torch
import os

from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, default_pg_timeout
from .constants import *
from deepspeed.accelerator import get_accelerator


class ReduceOp(Enum):
    SUM = 0
    PRODUCT = 1
    MIN = 2
    MAX = 3
    BAND = 4
    BOR = 5
    BXOR = 6
    AVG = 7
    UNUSED = 8


from deepspeed.utils.comms_logging import CommsLogger
from deepspeed.utils import timer, get_caller_func
from deepspeed.comm.torch import TorchBackend

from deepspeed import utils
from datetime import timedelta

# Current deepspeed.comm backend (cdb) global object for simple access by client code
use_ds_backend = False
cdb = None

# Create global timer for ops
timers = timer.SynchronizedWallClockTimer()
timer_summary = {}

comms_logger = CommsLogger()

# Maintain objects of all initialized ds backends and assign them using the API functions in this file
nccl_backend = None
mpi_backend = None

# This should be set here so all rank/size information from the launcher can be propagated
from deepspeed.comm.utils import *


def _configure_using_config_file(config):
    if config.comms_logger_enabled:
        comms_logger.configure(config)


def configure(
    deepspeed_config=None,
    enabled=None,
    prof_all=None,
    prof_ops=None,
    verbose=None,
    debug=None,
):

    if deepspeed_config is not None:
        _configure_using_config_file(deepspeed_config.comms_config)

    if enabled is not None:
        comms_logger.enabled = enabled

    if prof_all is not None:
        comms_logger.prof_all = prof_all

    if prof_ops is not None:
        comms_logger.prof_ops = prof_ops

    if verbose is not None:
        comms_logger.verbose = verbose

    if debug is not None:
        comms_logger.debug = debug
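

# Example (illustrative sketch, not executed on import): enabling the comms
# logger programmatically and profiling all ops; assumes deepspeed.comm has
# been initialized first:
#
#     import deepspeed.comm as dist
#     dist.init_distributed()
#     dist.configure(enabled=True, prof_all=True)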


# Logging wrapper for timing ops
def timed_op(func):

    def log_wrapper(*args, **kwargs):
        # Check the enabled flag first so that the overhead added to each comm op is at most two if-conditions
        if comms_logger.enabled:
            if ('prof' in kwargs
                    and kwargs['prof']) or comms_logger.prof_all or ('log_name' in kwargs
                                                                     and kwargs['log_name'] in comms_logger.prof_ops):
                # Need func args for their defaults
                func_args = get_default_args(func)
                func_args.update(kwargs)
                msg_size = get_msg_size_from_args(func, *args, **kwargs)
                log_name = get_debug_log_name(func_args, comms_logger.debug)
                timers(log_name).start()
        # Return the op, then stop the op's timer
        try:
            return func(*args, **kwargs)
        finally:
            if comms_logger.enabled:
                # Need to make op blocking for accurate logging
                get_accelerator().synchronize()
                # If we're using MPI, we can't simply sync the stream
                if cdb.using_mpi:
                    cdb.barrier()
                if ('prof' in kwargs and kwargs['prof']) or comms_logger.prof_all or (
                        'log_name' in kwargs and kwargs['log_name'] in comms_logger.prof_ops):
                    log_name = get_debug_log_name(func_args, comms_logger.debug)
                    raw_name = func.__name__
                    timers(log_name).stop()
                    # need temp var since 'elapsed' resets events
                    time_elapsed = timers(log_name).elapsed(reset=False)
                    comms_logger.append(raw_name, log_name, time_elapsed, msg_size)

    return log_wrapper
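

# With the logger enabled, a single call can be profiled by passing `prof=True`,
# or by listing its `log_name` in `prof_ops`. A minimal sketch (assumes an
# initialized backend and a device tensor `t`):
#
#     all_reduce(t, prof=True)
#     log_summary()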


# For compatibility with torch.distributed's init_process_group, we retain the signature from the PyTorch code.
# The DeepSpeed NCCL/MPI backends may not need all of these params since we will have our own implementations.
# Please read the full torch.distributed API docs at https://pytorch.org/docs/stable/distributed.html


# UNUSED: Future helper function to initialize DS backends
def init_deepspeed_backend(ds_backend):
    global cdb
    global nccl_backend
    global mpi_backend
    global use_ds_backend

    if ds_backend == NCCL_BACKEND:
        utils.logger.warn("NCCL backend in DeepSpeed not yet implemented")
    elif ds_backend == MPI_BACKEND:
        utils.logger.warn("MPI backend in DeepSpeed not yet implemented")
    elif ds_backend == GLOO_BACKEND:
        utils.logger.warn("Gloo backend in DeepSpeed not yet implemented")
    else:
        utils.logger.warn(f"DeepSpeed does not support {ds_backend} backend")


def is_initialized():
    #assert cdb is not None, 'DeepSpeed backend not set, please initialize it using init_process_group()'
    if cdb is None:
        return False
    else:
        return cdb.is_initialized()


def destroy_process_group(group=None):
    global cdb
    return cdb.destroy_process_group(group=group)


def new_group(ranks):
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    return cdb.new_group(ranks)


def is_available() -> bool:
    """Returns ``True`` if the deepspeed comm package is available."""
    # TODO: load other ops. Clients, including deepspeed itself, should use deepspeed.comm to import
    # any communication-related primitives from this package;
    # use hasattr(deepspeed.csrc.ops, "_comm") or something similar.
    return True


def set_backend(backend_name):
    if not use_ds_backend:
        utils.logger.error(
            "DeepSpeed communication backend is required. Please use deepspeed.comm.init_distributed(backend, use_deepspeed=True) to use this functionality"
        )
        raise RuntimeError('Error: Custom DeepSpeed backend called without initializing DeepSpeed distributed.')

    global cdb
    global nccl_backend
    global mpi_backend

    try:
        if backend_name == NCCL_BACKEND:
            if nccl_backend is not None and nccl_backend.is_initialized():
                cdb = nccl_backend
        elif backend_name == MPI_BACKEND:
            if mpi_backend is not None and mpi_backend.is_initialized():
                cdb = mpi_backend
    except Exception as inst:
        print(inst)


@timed_op
def broadcast(tensor, src, group=None, async_op=False, prof=False, log_name='broadcast', debug=get_caller_func()):
    global cdb
    return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)


@timed_op
def all_gather(tensor_list,
               tensor,
               group=None,
               async_op=False,
               prof=False,
               log_name='all_gather',
               debug=get_caller_func()):
    global cdb
    return cdb.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)


def has_reduce_scatter_tensor():
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    return cdb.has_reduce_scatter_tensor()


def reduce_scatter_fn(output_tensor,
                      tensor,
                      op=ReduceOp.SUM,
                      group=None,
                      async_op=False,
                      prof=False,
                      debug=get_caller_func()):
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    if cdb.has_reduce_scatter_tensor():
        return reduce_scatter_tensor(output_tensor,
                                     tensor,
                                     op=op,
                                     group=group,
                                     async_op=async_op,
                                     prof=prof,
                                     debug=debug)
    else:
        if get_rank() == 0:
            utils.logger.warning_once("unable to find torch.distributed.reduce_scatter_tensor. will fall back to "
                                      "torch.distributed.all_gather which will result in suboptimal performance. "
                                      "please consider upgrading your pytorch installation.")
        input_tensor_lst = list(torch.chunk(tensor, cdb.get_world_size(group)))
        return reduce_scatter(output_tensor,
                              input_tensor_lst,
                              op=op,
                              group=group,
                              async_op=async_op,
                              prof=prof,
                              debug=debug)
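

# Sketch of the fallback above: for world_size W, `tensor` is viewed as W equal
# chunks and rank r ends up with the fully reduced chunk r:
#
#     chunks = list(torch.chunk(tensor, W))  # W views, no copies
#     reduce_scatter(output_tensor, chunks)  # output_tensor <- reduced chunk of this rank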


@timed_op
def reduce_scatter_tensor(output_tensor,
                          tensor,
                          op=ReduceOp.SUM,
                          group=None,
                          async_op=False,
                          prof=False,
                          log_name='reduce_scatter_tensor',
                          debug=get_caller_func()):
    global cdb
    return cdb.reduce_scatter_tensor(output_tensor=output_tensor,
                                     input_tensor=tensor,
                                     op=op,
                                     group=group,
                                     async_op=async_op)


@timed_op
def all_gather_into_tensor(output_tensor,
                           tensor,
                           group=None,
                           async_op=False,
                           prof=False,
                           log_name='all_gather_into_tensor',
                           debug=get_caller_func()):
    global cdb
    return cdb.all_gather_into_tensor(output_tensor=output_tensor, input_tensor=tensor, group=group, async_op=async_op)


def has_all_gather_into_tensor():
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    return cdb.has_all_gather_into_tensor()


def allgather_fn(output_tensor, input_tensor, group=None, async_op=False, debug=get_caller_func()):
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    if cdb.has_all_gather_into_tensor():
        return all_gather_into_tensor(output_tensor, input_tensor, group=group, async_op=async_op, debug=debug)
    else:
        if get_rank() == 0:
            utils.logger.warning_once("unable to find torch.distributed.all_gather_into_tensor. will fall back to "
                                      "torch.distributed.all_gather which will result in suboptimal performance. "
                                      "please consider upgrading your pytorch installation.")
        output_tensors = list(torch.chunk(output_tensor, cdb.get_world_size(group)))
        return all_gather(output_tensors, input_tensor, group=group, async_op=async_op, debug=debug)
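

# Shape contract for the helpers above (sketch): with world_size W and an
# input_tensor of shape [N], output_tensor must have shape [W * N]; rank r's
# contribution lands in the r-th chunk of output_tensor.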


@timed_op
def all_to_all_single(output,
                      tensor,
                      output_split_sizes=None,
                      input_split_sizes=None,
                      group=None,
                      async_op=False,
                      prof=False,
                      log_name='all_to_all_single',
                      debug=get_caller_func()):
    global cdb
    return cdb.all_to_all_single(output=output,
                                 input=tensor,
                                 output_split_sizes=output_split_sizes,
                                 input_split_sizes=input_split_sizes,
                                 group=group,
                                 async_op=async_op)


@timed_op
def send(tensor, dst, group=None, tag=0, prof=False, log_name='send', debug=get_caller_func()):
    global cdb
    return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)


@timed_op
def recv(tensor, src=None, group=None, tag=0, prof=False, log_name='recv', debug=get_caller_func()):
    global cdb
    return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)


@timed_op
def isend(tensor, dst, group=None, tag=0, prof=False, log_name='isend', debug=get_caller_func()):
    global cdb
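    # Note: currently routed to the backend's blocking send rather than a true isend.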
    return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)


@timed_op
def irecv(tensor, src=None, group=None, tag=0, prof=False, log_name='irecv', debug=get_caller_func()):
    global cdb
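    # Note: currently routed to the backend's blocking recv rather than a true irecv.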
    return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)


@timed_op
def gather(tensor,
           gather_list=None,
           dst=0,
           group=None,
           async_op=False,
           prof=False,
           log_name='gather',
           debug=get_caller_func()):
    global cdb
    return cdb.gather(tensor=tensor, gather_list=gather_list, dst=dst, group=group, async_op=async_op)


@timed_op
def scatter(tensor,
            scatter_list=None,
            src=0,
            group=None,
            async_op=False,
            prof=False,
            log_name='scatter',
            debug=get_caller_func()):
    global cdb
    return cdb.scatter(tensor=tensor, scatter_list=scatter_list, src=src, group=group, async_op=async_op)


@timed_op
def barrier(group=None, async_op=False, device_ids=None, prof=False, log_name='barrier', debug=get_caller_func()):
    global cdb
    return cdb.barrier(group=group, async_op=async_op, device_ids=device_ids)


@timed_op
def monitored_barrier(group=None,
                      timeout=None,
                      wait_all_ranks=False,
                      prof=False,
                      log_name='monitored_barrier',
                      debug=get_caller_func()):
    global cdb
    return cdb.monitored_barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks)


def log_summary():
    global cdb
    barrier(log_name='log_summary_barrier')
    if cdb.get_rank() == 0:
        comms_logger.log_all()
    barrier(log_name='log_summary_barrier')


@timed_op
def reduce(tensor,
           dst,
           op=ReduceOp.SUM,
           group=None,
           async_op=False,
           prof=False,
           log_name='reduce',
           debug=get_caller_func()):
    global cdb
    return cdb.reduce(tensor=tensor, dst=dst, op=op, group=group, async_op=async_op)


@timed_op
def reduce_scatter(output,
                   input_list,
                   op=ReduceOp.SUM,
                   group=None,
                   async_op=False,
                   prof=False,
                   log_name='reduce_scatter',
                   debug=get_caller_func()):
    global cdb
    return cdb.reduce_scatter(output=output, input_list=input_list, op=op, group=group, async_op=async_op)


@timed_op
def all_reduce(tensor,
               op=ReduceOp.SUM,
               group=None,
               async_op=False,
               prof=False,
               log_name='all_reduce',
               debug=get_caller_func()):
    #if profile_comm:
    # context of the timers?
    # timers.start()
    # TensorBoard logging for comm calls.?
    global cdb
    #print(f'op = {op}, cdb= {cdb.name}')
    return cdb.all_reduce(tensor, op, group, async_op)


def get_world_group():
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    return cdb.get_world_group()


def get_world_size(group=None) -> int:
    """
    Returns the number of processes in the current process group
    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
    Returns:
        The world size of the process group
        -1, if not part of the group
    """
    global cdb

    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    return cdb.get_world_size(group)


def get_rank(group=None):
    """
    Returns the rank of the current process in the provided ``group`` or the
    default group if none was provided.
    Rank is a unique identifier assigned to each process within a distributed
    process group. They are always consecutive integers ranging from 0 to
    ``world_size - 1``.
    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
    Returns:
        The rank of the process group
        -1, if not part of the group
    """
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    return cdb.get_rank(group)


def get_local_rank():
    """
        Helper function to get local rank after a backend has been set and initialized
        Args:
            None
        Returns:
            local rank (= GPU device ID)
    """
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    return get_local_rank_from_launcher()


def get_global_rank(group=None, group_rank=0):
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    return cdb.get_global_rank(group, group_rank)


# Main DeepSpeed Comms. public API.
def init_distributed(dist_backend=None,
                     auto_mpi_discovery=True,
                     distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT,
                     verbose=True,
                     timeout=default_pg_timeout,
                     init_method=None,
                     dist_init_required=None,
                     config=None,
                     rank=-1,
                     world_size=-1):
    ''' Initialize dist backend, potentially performing MPI discovery if needed

    Arguments:
        dist_backend: Optional (str). torch distributed backend, e.g., nccl, mpi, gloo
        auto_mpi_discovery: Optional (bool). If distributed environment variables are not set, attempt to discover them from MPI
        distributed_port: Optional (int). torch distributed backend port
        verbose: Optional (bool). verbose logging
        timeout: Optional (timedelta). Timeout for operations executed against the process group. Default value equals 30 minutes.
        init_method: Optional (string). Torch distributed, URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified.
        config: Optional (dict). DeepSpeed configuration for setting up comms options (e.g. Comms profiling)
        rank: Optional (int). The current manually specified rank. Some init_method like “tcp://” need the rank and world_size as well (see: https://pytorch.org/docs/stable/distributed.html#tcp-initialization)
        world_size: Optional (int). Desired world_size for the TCP or Shared file-system initialization.
    '''
    global cdb

    configure(deepspeed_config=config)

    if dist_init_required is None:
        dist_init_required = cdb is None or not cdb.is_initialized()

    if cdb is None and torch.distributed.is_initialized():
        # The user initialized torch.dist themselves, create cdb and short-circuit
        cdb = TorchBackend(dist_backend, timeout, init_method)
        return

    if dist_init_required is False:
        assert (
            cdb is not None and cdb.is_initialized() is True
        ), "Distributed backend is not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()"
    else:
        # Initialize torch distributed if needed
        required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
        if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)):
            if verbose:
                utils.logger.info("Not using the DeepSpeed or dist launchers, attempting to detect MPI environment...")
            if in_aml() and not in_dlts():
                patch_aml_env_for_torch_nccl_backend(verbose=verbose)
            elif in_aws_sm():
                patch_aws_sm_env_for_torch_nccl_backend(verbose=verbose)
            else:
                mpi_discovery(distributed_port=distributed_port, verbose=verbose)

        if cdb is not None and cdb.is_initialized():
            if int(os.getenv('RANK', '0')) == 0:
                utils.logger.info('Distributed backend already initialized')
        else:
            assert isinstance(timeout, timedelta)
            if dist_backend is None:
                dist_backend = get_accelerator().communication_backend_name()
            if int(os.getenv('RANK', '0')) == 0:
                utils.logger.info('Initializing TorchBackend in DeepSpeed with backend {}'.format(dist_backend))
            # Create a torch backend object, initialize torch distributed, and assign to cdb
            cdb = TorchBackend(dist_backend, timeout, init_method, rank, world_size)
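

# Example (illustrative sketch): typical initialization when launched via the
# deepspeed or torchrun launchers, which export RANK, WORLD_SIZE, MASTER_ADDR,
# MASTER_PORT and LOCAL_RANK:
#
#     import deepspeed.comm as dist
#     dist.init_distributed(dist_backend='nccl')  # None would default to the accelerator's backend
#     if dist.get_rank() == 0:
#         print(dist.get_world_size())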


def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True):
    '''
    Discover the MPI environment via mpi4py and map it to the relevant dist state
    '''
    from mpi4py import MPI
    import subprocess
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    world_size = comm.Get_size()

    master_addr = None
    if rank == 0:
        hostname_cmd = ["hostname -I"]
        result = subprocess.check_output(hostname_cmd, shell=True)
        master_addr = result.decode('utf-8').split()[0]
    master_addr = comm.bcast(master_addr, root=0)

    # Determine local rank by assuming hostnames are unique
    proc_name = MPI.Get_processor_name()
    all_procs = comm.allgather(proc_name)
    local_rank = sum([i == proc_name for i in all_procs[:rank]])

    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['LOCAL_RANK'] = str(local_rank)
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = str(distributed_port)

    if verbose:
        utils.logger.info(
            "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}".
            format(os.environ['RANK'], os.environ['LOCAL_RANK'], os.environ['WORLD_SIZE'], os.environ['MASTER_ADDR'],
                   os.environ['MASTER_PORT']))

    if cdb is not None and cdb.is_initialized():
        assert cdb.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, cdb.get_rank())
        assert cdb.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
            world_size, cdb.get_world_size())
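

# e.g. (sketch) an MPI-style launch that exercises the discovery above, assuming
# mpi4py is installed and no RANK/WORLD_SIZE variables are pre-set
# (train.py is a hypothetical script that calls init_distributed()):
#
#     mpirun -n 8 python train.py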


def in_aml():
    # Are we running inside an Azure Machine Learning (AML) environment?
    return 'AZUREML_EXPERIMENT_ID' in os.environ


def in_aws_sm():
    # Are we running inside an AWS SageMaker environment?
    return 'SM_TRAINING_ENV' in os.environ


def in_dlts():
    # Are we running on a DLTS cluster?
    return 'DLTS_JOB_ID' in os.environ


def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True):
    """Helper routine to get and set environment variables.
    This is adapted from Azure ML's documentation available from:
    https://azure.github.io/azureml-web/docs/cheatsheet/distributed-training/#environment-variables-from-openmpi
    """
    os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
    os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
    single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int(os.environ["WORLD_SIZE"])

    if not single_node:
        master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":")
        os.environ["MASTER_ADDR"] = master_node_params[0]
        # Do not overwrite master port with that defined in AZ_BATCH_MASTER_NODE
        if "MASTER_PORT" not in os.environ:
            os.environ["MASTER_PORT"] = str(master_port)
    else:
        os.environ["MASTER_ADDR"] = os.environ["AZ_BATCHAI_MPI_MASTER_NODE"]
        os.environ["MASTER_PORT"] = DEFAULT_AML_MASTER_PORT

    if verbose:
        utils.logger.info("NCCL_SOCKET_IFNAME original value = {}".format(os.environ["NCCL_SOCKET_IFNAME"]))

    os.environ["NCCL_SOCKET_IFNAME"] = DEFAULT_AML_NCCL_SOCKET_IFNAME
    os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]

    if verbose:
        utils.logger.info(
            "Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
            .format(os.environ['RANK'], os.environ['LOCAL_RANK'], os.environ['WORLD_SIZE'], os.environ['MASTER_ADDR'],
                    os.environ['MASTER_PORT']))


def patch_aws_sm_env_for_torch_nccl_backend(verbose=True):
    """Helper routine to get and set environment variables when running inside an AWS SageMaker environment.
    """
    os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
    os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
    os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]

    if verbose:
        utils.logger.info(
            "Discovered AWS SageMaker settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
            .format(os.environ['RANK'], os.environ['LOCAL_RANK'], os.environ['WORLD_SIZE'], os.environ['MASTER_ADDR'],
                    os.environ['MASTER_PORT']))