partition_parameters.py 50.7 KB
Newer Older
J
Jeff Rasley 已提交
1 2 3 4 5
"""
"Copyright 2020 The Microsoft DeepSpeed Team.
Licensed under the MIT license.
"""

S
Samyam Rajbhandari 已提交
6 7 8 9 10 11 12 13 14 15
import os
import time
import types
from enum import Enum
import functools
import itertools

import torch
from torch.distributed.distributed_c10d import _get_global_rank

J
Jeff Rasley 已提交
16 17 18 19
from .linear import LinearModuleForZeroStage3, LinearFunctionForZeroStage3
from .offload_constants import *

from ..utils import see_memory_usage
S
Stas Bekman 已提交
20
from deepspeed.utils import log_dist, init_distributed, logger
S
Stas Bekman 已提交
21
from deepspeed.utils.debug import debug_param2name_id_shape, debug_param2name_id_shape_device, debug_module2name, debug_param2name, debug_param2name_id_shape_status, printflock, log_rank_file
S
Samyam Rajbhandari 已提交
22

J
Jeff Rasley 已提交
23 24 25
from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper, PartitionedParamStatus
from ..config import DeepSpeedConfig

S
Samyam Rajbhandari 已提交
26
# Running total of elements over every parameter visited by
# Init._post_init_method (incremented there through ``global param_count``).
param_count = 0
# NOTE(review): not referenced in the visible part of this file; presumably a
# placeholder shape used for partitioned ``param.data`` further down -- verify.
partitioned_param_data_shape = [1]
S
Samyam Rajbhandari 已提交
28 29 30


def print_rank_0(message, debug=False, force=False):
    """Print ``message`` on global rank 0 when ``debug`` or ``force`` is set.

    Printing is opt-in, so the rank is only queried once printing was actually
    requested. Calls with ``debug=False`` and ``force=False`` (the common case
    in this file) therefore do no ``torch.distributed`` work and do not require
    the default process group to be initialized.

    Args:
        message: Object to print.
        debug (bool): Enable printing.
        force (bool): Enable printing (same effect as ``debug``).
    """
    if not (debug or force):
        return
    rank = torch.distributed.get_rank()
    if rank == 0:
        print(message)
    # other variations
    # - print for all ranks w/o interleaving
    # printflock(f"[{rank}] {message}")
    # - print to log file per rank
    # log_rank_file(rank, message)
S
Samyam Rajbhandari 已提交
39 40 41


def is_zero_param(parameter):
    """Return True if ``parameter`` is a tensor already converted to a ZeRO-3 param.

    Conversion (``Init._convert_to_deepspeed_param``) attaches a ``ds_id``
    attribute, so its presence on a tensor is the marker checked here.
    Non-tensors always yield False.
    """
    return torch.is_tensor(parameter) and hasattr(parameter, 'ds_id')


def _init_external_params(module):
    if not hasattr(module, '_external_params'):
        module._external_params = {}

        def external_parameters(self):
            return self._external_params.items()

        def all_parameters(self):
            return itertools.chain(self.named_parameters(self,
                                                         recurse=False),
                                   external_parameters(self))

        module.ds_external_parameters = types.MethodType(external_parameters, module)
        module.all_parameters = types.MethodType(all_parameters, module)


def register_external_parameter(module, parameter):
    """Instruct DeepSpeed to coordinate ``parameter``'s collection and partitioning in
    the forward and backward passes of ``module``.

    This is used when a parameter is accessed outside of its owning module's
    ``forward()``. DeepSpeed must know to collect it from its partitioned
    state and when to release the memory.

    Registering the same parameter twice is harmless; it is keyed by ``id()``.

    .. note::
        This is only applicable to training with ZeRO stage 3.

    Args:
        module (``torch.nn.Module``): The module that requires ``parameter`` in its forward pass.
        parameter (``torch.nn.Parameter``): The parameter to register.

    Raises:
        RuntimeError: If ``parameter`` is not of type ``torch.nn.Parameter``.


    Examples
    ========

    #. Register a weight that is used in another module's forward pass (line 6).
       Parameter ``layer1.weight`` is used by ``layer2`` (line 11).

        .. code-block:: python
            :linenos:
            :emphasize-lines: 6,11

            class ModuleZ3(torch.nn.Module):
                def __init__(self, *args):
                    super().__init__(*args)
                    self.layer1 = SomeLayer()
                    self.layer2 = OtherLayer()
                    deepspeed.zero.register_external_parameter(self, self.layer1.weight)

                def forward(self, input):
                    x = self.layer1(input)
                    # self.layer1.weight is required by self.layer2.forward
                    y = self.layer2(x, self.layer1.weight)
                    return y
    """
    if not isinstance(parameter, torch.nn.Parameter):
        raise RuntimeError('Parameter is not a torch.nn.Parameter')

    # _init_external_params returns immediately if the module is already set up.
    _init_external_params(module)

    module._external_params[id(parameter)] = parameter


J
Jeff Rasley 已提交
115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
def unregister_external_parameter(module, parameter):
    """Undo a previous :meth:`register_external_parameter` call.

    Args:
        module (``torch.nn.Module``): Module the parameter was registered on.
        parameter (``torch.nn.Parameter``): Parameter to remove from ``module``'s
            external-parameter registry.

    Raises:
        RuntimeError: If ``parameter`` is not of type ``torch.nn.Parameter``.
        RuntimeError: If ``parameter`` is not a registered external parameter of ``module``.
    """
    if not isinstance(parameter, torch.nn.Parameter):
        raise RuntimeError('Parameter is not a torch.nn.Parameter')

    missing = object()
    registry = getattr(module, '_external_params', missing)
    param_key = id(parameter)
    if registry is missing or param_key not in registry:
        raise RuntimeError('Parameter is not a registered external parameter of module.')

    del registry[param_key]


S
Samyam Rajbhandari 已提交
137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
class ZeroParamType(Enum):
    """How a parameter's storage is laid out across data-parallel processes."""

    NORMAL = 1  # same as a regular PyTorch parameter
    PARTITIONED = 2  # split across the data-parallel processes
    REMOTE = 3  # held by a single process rank, absent on all others


class ZeroParamStatus(Enum):
    """Whether a parameter's full data is currently usable."""

    AVAILABLE = 1  # fully present and ready for use on all processes
    NOT_AVAILABLE = 2  # partitioned or remote on some or all processes
    INFLIGHT = 3  # currently being gathered


# Handle on the real ``torch.empty``, captured before any patching. While the
# init context is active, ``torch.empty`` is replaced by empty_cuda_tensor(_half),
# which allocate through this reference; the context's __exit__ restores it from here.
_orig_torch_empty = torch.empty


164
def empty_cuda_tensor_half(*size, **kwargs):
    """Stand-in for ``torch.empty`` used by the init context when dtype is ``torch.half``.

    Allocates on ``cuda:$LOCAL_RANK`` unless the caller supplies ``device``.
    Floating-point results are cast to half precision; other dtypes are
    returned unchanged.
    """
    if 'device' not in kwargs:
        # Only read LOCAL_RANK when we actually need a default device.
        kwargs['device'] = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"]))
    tensor = _orig_torch_empty(*size, **kwargs)
    return tensor.half() if tensor.is_floating_point() else tensor


174
def new_cuda_tensor_half(cls, *args):
    """Stand-in for ``torch.Tensor.__new__`` used when dtype is ``torch.half``.

    Creates an uninitialized tensor of shape ``args`` on ``cuda:$LOCAL_RANK``
    and casts it to half precision. ``cls`` is unused; it is only present
    because this function replaces ``__new__``.
    """
    device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"]))
    # new_empty inherits the template's (floating) dtype; cast once instead of
    # twice as before.
    tensor = torch.ones((1, 1), device=device).new_empty(*args)
    if tensor.is_floating_point():
        return tensor.half()
    return tensor


183 184 185 186 187 188 189 190 191 192 193 194 195
def empty_cuda_tensor(*size, **kwargs):
    """Stand-in for ``torch.empty`` used by the init context for non-half dtypes.

    Uses ``cuda:$LOCAL_RANK`` as the default device; an explicit ``device``
    keyword wins. The dtype is left untouched.
    """
    if 'device' not in kwargs:
        local_rank = os.environ["LOCAL_RANK"]
        kwargs['device'] = torch.device(f'cuda:{local_rank}')
    return _orig_torch_empty(*size, **kwargs)


def new_cuda_tensor(cls, *args):
    """Stand-in for ``torch.Tensor.__new__`` used for non-half dtypes.

    Returns an uninitialized tensor of shape ``args`` on ``cuda:$LOCAL_RANK``;
    ``cls`` is ignored.
    """
    local_rank = os.environ["LOCAL_RANK"]
    template = torch.ones((1, 1), device=torch.device(f'cuda:{local_rank}'))
    return template.new_empty(*args)


196 197 198 199 200 201 202 203 204 205 206 207 208 209
# https://stackoverflow.com/a/63851681/9201239
def get_all_subclasses(cls):
    """Return a set of every direct and indirect subclass of ``cls``.

    ``cls`` itself is not included.
    """
    found = set()
    pending = list(cls.__subclasses__())
    while pending:
        candidate = pending.pop()
        if candidate in found:
            continue
        found.add(candidate)
        pending.extend(candidate.__subclasses__())
    return found


S
Samyam Rajbhandari 已提交
210 211 212 213 214 215 216 217
# Module-level state for parameter partitioning.
# ``reuse_buffers`` is read (through ``global``) by Init._partition_param.
# NOTE(review): ``temp_contiguous_tensor`` and ``empty_buffers`` are not used in
# the visible part of this file; presumably consumed further down -- verify.
reuse_buffers = False
temp_contiguous_tensor = None
empty_buffers = {}


# Inserts _post_init_method at the end of init method
# for all sub classes of torch.nn.Module
class InsertPostInitMethodToModuleSubClasses(object):
    """Context manager that hooks ``torch.nn.Module`` construction.

    While the context is active:

    * every subclass of ``torch.nn.Module`` -- those that exist on entry and
      those defined inside the context (via ``__init_subclass__``) -- has its
      ``__init__`` wrapped so that ``self._post_init_method(module)`` runs once
      per new module object, after the outermost wrapped ``__init__`` returns;
    * ``torch.Tensor.__new__`` and ``torch.empty`` are redirected to the
      ``new_cuda_tensor*`` / ``empty_cuda_tensor*`` helpers (the ``_half``
      variants when ``self.dtype`` is ``torch.half``);
    * if ``mem_efficient_linear`` is set, ``torch.nn.functional.linear`` is
      replaced by ``LinearFunctionForZeroStage3.apply``.

    On exit the ``__init__``, ``__init_subclass__``, ``Tensor.__new__`` and
    ``torch.empty`` patches are undone. The ``linear`` override is deliberately
    left in place. Subclasses implement ``_post_init_method``.

    Args:
        enabled (bool): If ``False``, entering and exiting are no-ops.
        mem_efficient_linear (bool): Whether to override ``torch.nn.functional.linear``.
        ds_config: DeepSpeed config object. When ``dtype`` is ``None``, its
            ``fp16_enabled`` flag selects ``torch.half`` or ``torch.float``.
        dtype: ``torch.half`` or ``torch.float``. Defaults to ``torch.half``
            when neither ``dtype`` nor ``ds_config`` is given.
    """

    def __init__(self,
                 enabled=True,
                 mem_efficient_linear=True,
                 ds_config=None,
                 dtype=None):
        self.mem_efficient_linear = mem_efficient_linear
        self.enabled = enabled
        self._set_dtype(ds_config, dtype)
        assert self.dtype in [torch.half, torch.float], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.float]"

    def __enter__(self):
        if not self.enabled:
            return

        def partition_after(f):
            """Wrap module ``__init__`` ``f`` so post-init runs once per object."""
            @functools.wraps(f)
            def wrapper(module, *args, **kwargs):

                # important logic: We want to run post_init only after child's __init__ is
                # completed, and do nothing after __init__ of any of its parents and grandparents in
                # the inheritance ancestry. This way the partitioning will need to happen only once
                # when the whole object is ready to be partitioned and not before. This is because
                # often the child module will need to tweak the weights - for example running a
                # custom weights init function. So if a parent created the weights param, the child
                # won't need to gather it in order to tweak it

                print_rank_0(f'Before initializing {module.__class__.__name__}',
                             force=False)

                is_child_module = False
                if not hasattr(module, "_ds_child_entered"):
                    # child's __init__ was called, since parents all see the same object they can now skip post_init
                    is_child_module = True
                    setattr(module, "_ds_child_entered", True)

                f(module, *args, **kwargs)

                if is_child_module:
                    # child's __init__ is done, now we can run a single post_init on the child object
                    delattr(module, "_ds_child_entered")

                    print_rank_0(f'Running post_init for {module.__class__.__name__}',
                                 force=False)
                    self._post_init_method(module)

                print_rank_0(
                    f'After initializing followed by post init for {module.__class__.__name__}',
                    force=False)

            return wrapper

        def _enable_class(cls):
            # Record whether cls defines __init__ itself or only inherits it. An
            # inherited cls.__init__ may already be an ancestor's wrapper; restoring
            # that "old" value on exit would pin the wrapper onto cls for good, so
            # __exit__ instead removes cls.__init__ and lets lookup fall back to the
            # ancestor's restored __init__.
            cls._ds_init_inherited = '__init__' not in cls.__dict__
            cls._old_init = cls.__init__
            cls.__init__ = partition_after(cls.__init__)

        def _init_subclass(cls, **kwargs):
            # Classes defined inside the context get the same bookkeeping as
            # pre-existing ones. Without their own _old_init, __exit__ would resolve
            # cls._old_init through the MRO and install an ancestor's __init__ on
            # cls, or fail if no ancestor had one.
            _enable_class(cls)

        # Replace .__init__() for all existing subclasses of torch.nn.Module recursively
        for subclass in get_all_subclasses(torch.nn.modules.module.Module):
            _enable_class(subclass)

        # holding on to the current __init__subclass__ for exit
        torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__
        torch.Tensor.__old_new__ = torch.Tensor.__new__

        # Replace .__init__() for future subclasses of torch.nn.Module
        torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass)
        if self.dtype == torch.half:
            torch.Tensor.__new__ = new_cuda_tensor_half
            torch.empty = empty_cuda_tensor_half
        else:
            torch.Tensor.__new__ = new_cuda_tensor
            torch.empty = empty_cuda_tensor

        if self.mem_efficient_linear:
            print_rank_0(
                "nn.functional.linear has been overridden with a more memory efficient version. This will persist unless manually reset.",
                force=False)
            self.linear_bk = torch.nn.functional.linear
            torch.nn.functional.linear = LinearFunctionForZeroStage3.apply

    def __exit__(self, exc_type, exc_value, traceback):
        if not self.enabled:
            return

        def _disable_class(cls):
            if '_old_init' not in cls.__dict__:
                # Never patched by this context manager; nothing to undo.
                return
            if cls.__dict__.get('_ds_init_inherited', False):
                # cls had no __init__ of its own: drop the wrapper so lookup
                # falls back to the (restored) ancestor __init__.
                if '__init__' in cls.__dict__:
                    del cls.__init__
            else:
                cls.__init__ = cls._old_init

        # Restore .__init__() for all existing subclasses of torch.nn.Module
        for subclass in get_all_subclasses(torch.nn.modules.module.Module):
            _disable_class(subclass)

        # Restore .__init_subclass__() for future subclasses of torch.nn.Module
        torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass

        torch.Tensor.__new__ = torch.Tensor.__old_new__
        torch.empty = _orig_torch_empty

        # The nn.functional.linear override is intentionally not undone here:
        # undoing it here would also undo it during training.

        # Now that we cleaned up the metaclass injection, raise the exception.
        if exc_type is not None:
            return False

    # To be implemented by inheriting classes
    def _post_init_method(self, module):
        pass

    def _set_dtype(self, ds_config, dtype):
        """Pick ``self.dtype``: an explicit ``dtype`` wins; else the config's fp16 flag; else half."""
        if ds_config is not None and dtype is None:
            self.dtype = torch.half if ds_config.fp16_enabled else torch.float
        elif dtype is None:
            self.dtype = torch.half
        else:
            self.dtype = dtype

S
Samyam Rajbhandari 已提交
340 341 342 343 344 345 346 347 348 349 350

# Replaces all parameters in module with Scattered Parameters
class Init(InsertPostInitMethodToModuleSubClasses):
    param_id = 0

    def __init__(self,
                 module=None,
                 data_parallel_group=None,
                 mem_efficient_linear=True,
                 remote_device=None,
                 pin_memory=False,
                 config_dict_or_path=None,
                 config=None,
                 enabled=True,
                 dtype=None,
                 mpu=None):
        """A context to enable massive model construction for training with
        ZeRO-3. Models are automatically partitioned (or, sharded) across the
        system and, by default, converted to half precision.

        Args:
            module (``torch.nn.Module``, optional): If provided, partition the model as
                if it was constructed in the context.
            data_parallel_group (``torch.distributed`` process group, optional):
                The group of processes to partition among. Defaults to all processes.
            mem_efficient_linear (bool, optional): Replace
                torch.nn.functional.linear with an implementation that allows
                DeepSpeed to partition parameters. Defaults to ``True``.
            remote_device (string, optional): The initial device to store model
                weights e.g., ``cpu``, ``nvme``. Passing ``"cpu"`` will create the model in CPU
                memory. The model may still be moved to GPU based on the
                offload settings for training. Defaults to the local GPU.
            pin_memory (bool, optional): Potentially increase performance by
                using pinned memory for model weights. ``remote_device`` must be
                ``"cpu"``; otherwise this is ignored. Defaults to ``False``.
            config_dict_or_path (dict or ``json file``, optional): If provided, provides configuration
                for swapping fp16 params to NVMe. Required when ``remote_device`` is ``nvme``.
            config (dict or ``json file``, optional): Deprecated, use config_dict_or_path instead.
            enabled (bool, optional): If ``False``, this context has no
                effect. Defaults to ``True``.
            dtype (``dtype``, optional): Can be used to change the data type of the parameters.
                Supported options are ``torch.half`` and ``torch.float``. Defaults to ``None``,
                which selects ``torch.float`` if a config with fp16 disabled is given and
                ``torch.half`` otherwise.
            mpu (``object``, optional): A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}.

        This context accelerates model initialization and enables models that
        are too large to allocate in their entirety in CPU memory. It has the
        following effects:

        #. allocates tensors to either GPU or CPU memory or NVMe
        #. converts floating point tensors to half precision
        #. immediately partitions tensors among the group of data-parallel devices
        #. (*optional*) replaces ``torch.nn.functional.linear`` with a more
           memory-efficient implementation

        These modifications allow for models that exceed the size of local CPU/GPU
        memory or NVMe, but fit within the aggregate CPU/GPU memory or NVMe capacity
        across all nodes. Consider initializing a model with one
        trillion parameters, whose weights occupy two terabytes (TB) in half
        precision. The initial CPU allocation in full precision requires 4TB of
        memory *per process*, and so a system with 8 GPUs per node would need 32TB of
        CPU memory due to data-parallel redundancies. Instead, by immediately
        partitioning tensors we remove the redundancies. The result is that
        regardless of the number of GPUs, we still only require the original 4TB. This
        allows for a linear increase in model size with the aggregate system memory.
        For example, if a node has 1TB of memory and 8 GPUs, we could fit a trillion
        parameter model with 4 nodes and 32 GPUs.

        Important: If the fp16 weights of the model can't fit onto a single GPU memory
        this feature must be used.

        .. note::
            Initializes ``torch.distributed`` if it has not already been done so.
            See :meth:`deepspeed.init_distributed` for more information.

        .. note::
            Can also be used as a decorator:

            .. code-block:: python

                @deepspeed.zero.Init()
                def get_model():
                    return MyLargeModel()

        .. note::
            Only applicable to training with ZeRO-3.

        Examples
        --------

        #. Allocate a model and partition it among all processes:

            .. code-block:: python

                with deepspeed.zero.Init():
                    model = MyLargeModel()


        #. Allocate a model in pinned CPU memory and partition it among a subgroup of processes:

            .. code-block:: python

                with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
                                         remote_device="cpu",
                                         pin_memory=True):
                    model = MyLargeModel()


        #. Partition an already-allocated model in CPU memory:

            .. code-block:: python

                model = deepspeed.zero.Init(module=model)
        """
        if config is not None:
            config_dict_or_path = config
            logger.warning(
                f'zero.Init: the `config` argument is deprecated. Please use `config_dict_or_path` instead.'
            )

        _ds_config = DeepSpeedConfig(config_dict_or_path,
                                     mpu) if config_dict_or_path is not None else None
        super().__init__(enabled=enabled,
                         mem_efficient_linear=mem_efficient_linear,
                         ds_config=_ds_config,
                         dtype=dtype)
        if not torch.distributed.is_initialized():
            init_distributed()
            assert torch.distributed.is_initialized(), "Parameters cannot be scattered without initializing torch.distributed"
        if data_parallel_group is None:
            self.ds_process_group = torch.distributed.group.WORLD
        else:
            self.ds_process_group = data_parallel_group

        self.rank = torch.distributed.get_rank(group=self.ds_process_group)
        self.world_size = torch.distributed.get_world_size(group=self.ds_process_group)

        # Local device is the device where the parameters are consumed.
        # It is the device where parameters are fully instantiated using allgather.
        self.local_device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"]))

        self._validate_remote_device(remote_device, _ds_config)

        # Remote device is the device where parameter partitions are stored.
        # It can be the same as local_device, or it could be CPU or NVMe.
        self.remote_device = self.local_device if remote_device is None else remote_device
        # Pinned memory only makes sense for CPU-resident partitions.
        self.pin_memory = pin_memory if (
            self.remote_device == OFFLOAD_CPU_DEVICE) else False

        # Enable fp16 param swapping to NVMe
        if self.remote_device == OFFLOAD_NVME_DEVICE:
            self.param_swapper = AsyncPartitionedParameterSwapper(_ds_config)
        else:
            self.param_swapper = None

        # If we are provided an already-allocated module to prepare.
        if module is not None:
            assert isinstance(module, torch.nn.Module)
            for param in module.parameters(recurse=True):
                if is_zero_param(param):
                    continue
                self._convert_to_deepspeed_param(param)
                param.partition()

503
    def _validate_remote_device(self, remote_device, ds_config):
        """Check that ``remote_device`` agrees with the DeepSpeed config.

        Args:
            remote_device: Requested device for parameter partitions (``None``,
                ``OFFLOAD_CPU_DEVICE`` or ``OFFLOAD_NVME_DEVICE``).
            ds_config: Parsed ``DeepSpeedConfig``, or ``None`` if none was given.

        Raises:
            AssertionError: if ``remote_device`` is NVMe but no config was given;
                if the config's ``offload_param`` device is NVMe while
                ``remote_device`` is ``None``/CPU; or if ``remote_device`` is NVMe
                and the config lacks an ``offload_param`` section or an NVMe path.
        """
        if ds_config is None:
            # NVMe swapping is configured from the DeepSpeed config (it is passed
            # straight to AsyncPartitionedParameterSwapper), so reject the missing
            # config here with a clear message instead of handing it None.
            assert remote_device != OFFLOAD_NVME_DEVICE, \
                f'A DeepSpeed config (config_dict_or_path) must be provided if remote device is {OFFLOAD_NVME_DEVICE}.'
            return

        offload_param = ds_config.zero_config.offload_param

        if remote_device in [None, OFFLOAD_CPU_DEVICE]:
            if offload_param is not None:
                offload_param_device = offload_param[OFFLOAD_PARAM_DEVICE]
                assert offload_param_device != OFFLOAD_NVME_DEVICE, \
                    f"{OFFLOAD_PARAM_DEVICE} in DeepSpeed Config cannot be {offload_param_device} if remote device is {remote_device}."

        if remote_device == OFFLOAD_NVME_DEVICE:
            assert offload_param is not None, \
                f'{OFFLOAD_PARAM} must be defined in DeepSpeed Config if remote device is {OFFLOAD_NVME_DEVICE}.'

            assert offload_param[OFFLOAD_PARAM_NVME_PATH] is not None, \
                f'{OFFLOAD_PARAM_NVME_PATH} in DeepSpeed Config cannot be None if remote device is {OFFLOAD_NVME_DEVICE}'

S
Samyam Rajbhandari 已提交
519 520 521 522 523 524 525 526
    def _post_init_method(self, module):
        """Convert and partition the parameters registered directly on ``module``.

        Invoked by the init-context wrapper once ``module``'s construction has
        finished. Only ``module.parameters(recurse=False)`` are visited. Each
        visited parameter's ``numel()`` is added to the module-level
        ``param_count``. Parameters that are not yet ZeRO params are converted
        and immediately partitioned; ones already converted are left as they are.
        """
        global param_count

        print_rank_0(f'Converting Params in {module.__class__.__name__}', force=False)
        see_memory_usage(
            f"Before converting and partitioning parmas in {module.__class__.__name__}",
            force=False)

        for param in module.parameters(recurse=False):
            param_count += param.numel()
            if not is_zero_param(param):
                self._convert_to_deepspeed_param(param)
                print_rank_0(
                    f"Partitioning param {debug_param2name_id_shape(param)} module={debug_module2name(module)}"
                )
                param.partition()
        see_memory_usage(
            f"Param count {param_count}. After converting and partitioning parmas in {module.__class__.__name__}",
            force=False)

    def _convert_to_deepspeed_param(self, param):
        """Turn ``param`` into a ZeRO-3 parameter by attaching bookkeeping
        attributes and closures bound to this ``Init`` instance.

        On return the parameter is still whole: ``ds_status`` is AVAILABLE and
        ``ds_tensor`` is None. Callers in this class follow up with
        ``param.partition()``. Each converted parameter receives a fresh
        ``ds_id`` from the class-wide ``Init.param_id`` counter.
        """
        # Partitioned, Normal, Remote
        param.ds_param_type = ZeroParamType.PARTITIONED

        # Replicated vs Partitioned vs Inflight
        param.ds_status = ZeroParamStatus.AVAILABLE

        # Stores the shape of the original tensor
        param.ds_shape = param.shape

        # Stores the number of elements in the original parameter without padding
        param.ds_numel = param.numel()

        # Stores the partitioned copy of the tensor
        param.ds_tensor = None

        # Keeps track of how many active sub-modules need this param at any given point in time
        param.ds_active_sub_modules = 0

        # If this flag is true, then the parameters are replicated throughout training
        # and only partitioned before the step
        param.ds_persist = False

        # The group that the parameter is scattered across.
        param.ds_process_group = self.ds_process_group

        # This is set to the Async Param swapper if remote device is nvme
        # else this is set to None
        param.nvme_swapper = self.param_swapper

        # DeepSpeed Param ID
        param.ds_id = Init.param_id
        Init.param_id += 1

        def all_gather(param_list=None, async_op=False, hierarchy=0):
            cls = param
            if param_list is None:
                param_list = [cls]
            return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy)

        def partition(param_list=None, hierarchy=0, has_been_updated=False):
            cls = param
            print_rank_0(
                f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}"
            )
            if param_list is None:
                param_list = [cls]
            self._partition(param_list, has_been_updated=has_been_updated)

        def reduce_gradients_at_owner(param_list=None, hierarchy=0):
            cls = param
            if param_list is None:
                param_list = [cls]
            print_rank_0(
                f"{'--'*hierarchy}----Reducing Gradients for param with ids {[p.ds_id for p in param_list]} to owner"
            )
            self._reduce_scatter_gradients(param_list)

        def partition_gradients(param_list=None,
                                partition_buffers=None,
                                hierarchy=0,
                                accumulate=False):
            cls = param
            print_rank_0(
                f"{'--'*hierarchy}----Partitioning param gradient with id {debug_param2name_id_shape_device(cls)}"
            )
            if param_list is None:
                param_list = [cls]
                # A single buffer for the single default param is normalized to a list.
                if isinstance(partition_buffers, torch.Tensor):
                    partition_buffers = [partition_buffers]

            self._partition_gradients(param_list,
                                      partition_buffers=partition_buffers,
                                      accumulate=accumulate)

        def aligned_size():
            return self._aligned_size(param)

        def padding_size():
            return self._padding_size(param)

        def partitioned_size():
            return self._partitioned_size(param)

        # Collectives for gathering and partitioning parameters
        param.all_gather = all_gather
        param.partition = partition

        # Collective for averaging gradients
        param.reduce_gradients_at_owner = reduce_gradients_at_owner
        param.partition_gradients = partition_gradients

        # Partitioning size utilities
        param.aligned_size = aligned_size
        param.padding_size = padding_size
        param.partitioned_size = partitioned_size
S
Samyam Rajbhandari 已提交
636 637 638 639 640 641 642 643

    def _aligned_size(self, param):
        return param.ds_numel + self._padding_size(param)

    def _padding_size(self, param):
        remainder = param.ds_numel % self.world_size
        return (self.world_size - remainder) if remainder else 0

J
Jeff Rasley 已提交
644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661
    def _partitioned_size(self, param):
        return param.ds_tensor.ds_numel

    def _ensure_availability_of_partitioned_params(self, params):
        """Bring NVMe-resident partitions of ``params`` back before they are used.

        Partitions whose ``ds_tensor.status`` is NOT_AVAILABLE are swapped in
        synchronously through the first such param's ``nvme_swapper``. If none
        need swapping in but some are INFLIGHT, that swapper's
        ``synchronize_reads()`` is called instead. Both cases assert that the
        partition lives on NVMe and the param itself is NOT_AVAILABLE.
        """
        needs_swap_in = []
        being_swapped_in = []
        for candidate in params:
            partition_status = candidate.ds_tensor.status
            if partition_status == PartitionedParamStatus.NOT_AVAILABLE:
                bucket = needs_swap_in
            elif partition_status == PartitionedParamStatus.INFLIGHT:
                bucket = being_swapped_in
            else:
                continue
            assert candidate.ds_tensor.final_location == OFFLOAD_NVME_DEVICE and candidate.ds_status == ZeroParamStatus.NOT_AVAILABLE
            bucket.append(candidate)

        if needs_swap_in:
            needs_swap_in[0].nvme_swapper.swap_in(needs_swap_in, async_op=False)
        elif being_swapped_in:
            being_swapped_in[0].nvme_swapper.synchronize_reads()

S
Samyam Rajbhandari 已提交
662
    def _all_gather(self, param_list, async_op=False, hierarchy=None):
        """Gather the full data of every NOT_AVAILABLE param in ``param_list``.

        Partitions that are out on NVMe are first made resident. Params that are
        not NOT_AVAILABLE are skipped.

        With ``async_op=True``, each param is gathered individually through
        ``_allgather_param``, marked INFLIGHT, and the list of returned handles
        is the result. Otherwise the params are gathered together through
        ``_allgather_params``, marked AVAILABLE, and that call's return value is
        returned.
        """
        # fetches from nvme if the partition is not available and in nvme
        self._ensure_availability_of_partitioned_params(param_list)

        handles = []
        all_gather_list = []
        for param in param_list:
            if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
                continue
            if async_op:
                handle = self._allgather_param(param,
                                               async_op=async_op,
                                               hierarchy=hierarchy)
                # The gather may still be running; completion is tracked via the handle.
                param.ds_status = ZeroParamStatus.INFLIGHT
                handles.append(handle)
            else:
                all_gather_list.append(param)

        if not async_op:
            ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy)
            for param in all_gather_list:
                param.ds_status = ZeroParamStatus.AVAILABLE
            return ret_value

        return handles

    def _partition(self, param_list, force=False, has_been_updated=False):
        for param in param_list:
            #print_rank_0(f"Before Partitioning Param {param.ds_id}")
691
            # self._param_status(param)
S
Samyam Rajbhandari 已提交
692 693
            self._partition_param(param, has_been_updated=has_been_updated)
            param.ds_status = ZeroParamStatus.NOT_AVAILABLE
694
            # if param.ds_tensor is not None:
S
Samyam Rajbhandari 已提交
695 696 697 698 699
            #    assert id(param.data) == id(param.ds_tensor.data), \
            #    "After the parameters are initially partitioned, make sure we are not recreating the partition."
            #print_rank_0(f"After Partitioning Param {param.ds_id}")
            # self._param_status(param)

    def _partition_param(self, param, buffer=None, has_been_updated=False):
        """Release the full data of ``param`` and keep only this rank's partition.

        Only params whose ``ds_status`` is AVAILABLE are processed. INFLIGHT
        params are rejected by an assertion, and any other status is a no-op.

        * If the param already has a partition (``param.ds_tensor``) and
          ``has_been_updated`` is False, the partition is kept as-is and only
          the full data is dropped. NVMe-resident partitions also have their
          in-memory buffers released.
        * Otherwise the partition is allocated on first use and this rank's
          slice of the full data is copied into it. First use means: an NVMe
          swap buffer when ``self.remote_device`` is NVMe and the swapper
          accepts the size; otherwise CPU memory when the remote device is
          NVMe, or ``self.remote_device``. NVMe partitions are then swapped
          out.

        In both cases ``param.data`` is replaced by a 1-element placeholder.

        Args:
            param: the parameter to partition.
            buffer: not read on entry; kept for interface compatibility.
            has_been_updated: copy the current full data into the partition
                even when one already exists.
        """
        global reuse_buffers
        assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight"

        if param.ds_status is not ZeroParamStatus.AVAILABLE:
            return

        print_rank_0(
            f"Partitioning param id {param.ds_id} reuse buffers {reuse_buffers}",
            force=False)

        if param.ds_tensor is not None and not has_been_updated:
            # The existing partition is still current: just drop the full data.
            see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}',
                             force=False)
            # param.data does not store anything meaningful in partitioned state
            param.data = torch.ones(1, dtype=self.dtype, device=param.device)
            see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}',
                             force=False)

            if param.ds_tensor.final_location == OFFLOAD_NVME_DEVICE:
                print_rank_0(
                    f"Param {param.ds_id} partition released since it exists in nvme",
                    force=False)
                param.nvme_swapper.remove_partition_and_release_buffers([param])
            return

        tensor_size = self._aligned_size(param)
        partition_size = tensor_size // self.world_size

        if param.ds_tensor is None:
            final_location = None
            if self.remote_device == OFFLOAD_NVME_DEVICE and self.param_swapper.swappable_tensor(
                    numel=partition_size):
                final_location = OFFLOAD_NVME_DEVICE
                buffer = self.param_swapper.get_buffer(param, partition_size)
                # wrap the swap buffer's storage in a fresh tensor object
                partitioned_tensor = torch.zeros(1,
                                                 dtype=param.dtype,
                                                 device=buffer.device)
                partitioned_tensor.data = buffer.data
                print_rank_0(
                    f"ID {param.ds_id} Initializing partition for the first time for nvme offload."
                )
            else:
                # NVMe-bound partitions that the swapper rejects are kept on CPU
                partition_device = (OFFLOAD_CPU_DEVICE
                                    if self.remote_device == OFFLOAD_NVME_DEVICE else
                                    self.remote_device)
                partitioned_tensor = torch.zeros(partition_size,
                                                 dtype=param.dtype,
                                                 device=partition_device)
                if self.pin_memory:
                    partitioned_tensor = partitioned_tensor.pin_memory()

            partitioned_tensor.requires_grad = False
            param.ds_tensor = partitioned_tensor
            param.ds_tensor.ds_numel = partition_size
            param.ds_tensor.status = PartitionedParamStatus.AVAILABLE
            param.ds_tensor.final_location = final_location

        start = partition_size * self.rank
        end = start + partition_size

        one_dim_param = param.contiguous().view(-1)

        if start < param.ds_numel and end <= param.ds_numel:
            # this rank's slice lies entirely inside the param
            src_tensor = one_dim_param.narrow(0, start, partition_size)
            param.ds_tensor.copy_(src_tensor)
        elif start < param.ds_numel:
            # the slice runs past the end of the param: copy only the real part
            elements_to_copy = param.ds_numel - start
            param.ds_tensor.narrow(0, 0, elements_to_copy).copy_(
                one_dim_param.narrow(0, start, elements_to_copy))

        see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}',
                         force=False)
        # param.data does not store anything meaningful in partitioned state
        param.data = torch.ones(1, dtype=self.dtype, device=param.device)
        see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}',
                         force=False)

        if param.ds_tensor.final_location == OFFLOAD_NVME_DEVICE:
            self.param_swapper.swap_out_and_release([param])
            print_rank_0(
                f"ID {param.ds_id} Offloaded to nvme offload and buffers released.")
            see_memory_usage(
                f"ID {param.ds_id} Offloaded to nvme offload and buffers released.",
                force=False)

        print_rank_0(
            f"ID {param.ds_id} partitioned type {param.dtype} dev {param.device} shape {param.shape}"
        )

    def _param_status(self, param):
        """Pass a one-line status summary of ``param`` to ``print_rank_0``.

        The summary covers id, status, full numel, partition numel (or the
        ``None`` partition) and current data numel. With its default
        ``debug``/``force`` flags, ``print_rank_0`` does not actually print.
        """
        summary = f"Param id {param.ds_id}, param status: {param.ds_status}, param numel {param.ds_numel}, "
        if param.ds_tensor is None:
            summary += f"partitioned ds_tensor {param.ds_tensor}, "
        else:
            summary += f"partitioned numel {param.ds_tensor.numel()}, "
        summary += f"data numel {param.data.numel()}"
        print_rank_0(summary)

    def _allgather_param(self, param, async_op=False, hierarchy=0):
        """All-gather the partitions of one param into a freshly allocated buffer.

        A flat zero buffer of the aligned size is allocated on ``param.device``.
        This rank's partition is copied into its slot, and
        ``torch.distributed.all_gather`` fills every slot. ``param.data`` is
        then pointed at the first ``param.ds_numel`` elements, viewed as
        ``param.ds_shape``.

        When ``async_op`` is True, ``param.data`` is assigned before the
        collective completes, so the caller must wait on the returned handle
        before reading it.

        Args:
            param: the partitioned parameter to gather.
            async_op: forwarded to ``torch.distributed.all_gather``.
            hierarchy: nesting depth used to indent debug messages. ``None`` is
                treated as 0.

        Returns:
            Whatever ``torch.distributed.all_gather`` returns (a work handle
            when ``async_op`` is True).

        Raises:
            AssertionError: if ``partition_size * world_size`` differs from the
                param's aligned size.
        """
        # '--' * None would raise TypeError; _all_gather defaults hierarchy to None
        indent = '--' * (hierarchy or 0)
        partition_size = param.ds_tensor.ds_numel

        tensor_size = partition_size * self.world_size
        aligned_param_size = self._aligned_size(param)
        assert tensor_size == aligned_param_size, f'param id {param.ds_id} aligned size {aligned_param_size} does not match tensor size {tensor_size}'

        print_rank_0(
            f"{indent}---- Before allocating allgather param {debug_param2name_id_shape_status(param)} partition size={partition_size}"
        )

        see_memory_usage(
            f'Before allocate allgather param {debug_param2name_id_shape_status(param)} partition_size={partition_size} ',
            force=False)
        flat_tensor = torch.zeros(aligned_param_size,
                                  dtype=param.dtype,
                                  device=param.device).view(-1)
        see_memory_usage(
            f'After allocate allgather param {debug_param2name_id_shape_status(param)} {aligned_param_size} {partition_size} ',
            force=False)

        torch.cuda.synchronize()

        print_rank_0(
            f"{indent}----allgather param with {debug_param2name_id_shape_status(param)} partition size={partition_size}"
        )

        group_rank = torch.distributed.get_rank(group=self.ds_process_group)
        partitions = []
        for i in range(self.world_size):
            partitions.append(flat_tensor.narrow(0, partition_size * i, partition_size))
            if i == group_rank:
                partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True)

        handle = torch.distributed.all_gather(partitions,
                                              partitions[self.rank],
                                              group=self.ds_process_group,
                                              async_op=async_op)

        # drop the alignment padding and restore the original shape
        replicated_tensor = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape)
        param.data = replicated_tensor.data
        return handle

    def _allgather_params(self, param_list, hierarchy=0):
        if len(param_list) == 0:
            return

J
Jeff Rasley 已提交
888
        partition_size = sum([param.ds_tensor.ds_numel for param in param_list])
S
Samyam Rajbhandari 已提交
889 890 891 892 893 894 895 896 897 898 899 900 901 902 903

        tensor_size = partition_size * self.world_size
        flat_tensor = torch.empty(tensor_size,
                                  dtype=param_list[0].dtype,
                                  device=self.local_device)
        flat_tensor.requres_grad = False
        partitions = []
        for i in range(self.world_size):
            start = partition_size * i

            partitions.append(flat_tensor.narrow(0, start, partition_size))

            if i == self.rank:
                offset = 0
                for param in param_list:
J
Jeff Rasley 已提交
904
                    param_numel = param.ds_tensor.ds_numel
S
Samyam Rajbhandari 已提交
905 906 907 908 909 910 911 912 913 914 915 916 917 918

                    partitions[i].narrow(0,
                                         offset,
                                         param_numel).copy_(param.ds_tensor.data)

                    offset += param_numel

        torch.distributed.all_gather(partitions,
                                     partitions[self.rank],
                                     group=self.ds_process_group,
                                     async_op=False)
        param_offset = 0

        for param in param_list:
J
Jeff Rasley 已提交
919
            param_partition_size = param.ds_tensor.ds_numel
S
Samyam Rajbhandari 已提交
920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939
            param_size = param.ds_numel
            replicated_tensor = torch.empty(param.ds_shape,
                                            dtype=param.dtype,
                                            device=self.local_device)

            for i in range(self.world_size):

                start = i * partition_size

                param_start = i * param_partition_size

                if param_start < param_size:
                    numel_to_copy = min(param_size - param_start, param_partition_size)

                    part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy)

                    replicated_tensor.view(-1).narrow(0,
                                                      param_start,
                                                      numel_to_copy).copy_(part_to_copy)
            #param_offset += param.data.numel()
J
Jeff Rasley 已提交
940
            param_offset += param.ds_tensor.ds_numel
S
Samyam Rajbhandari 已提交
941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963

            param.data = replicated_tensor.data

        return None

    def _reduce_scatter_gradients(self, param_list):
        #print_rank_0([param.grad for param in param_list])
        #assert any([param.grad is None for param in param_list]), "None gradients cannot be reduce scattered"

        handles_and_reduced_partitions = []
        for param in param_list:
            assert param.grad.numel(
            ) == param.ds_numel, f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter gradients whose size is not same as the params"

            handles_and_reduced_partitions.append(self._reduce_scatter_gradient(param))

        for param, (handle, reduced_partition) in zip(param_list, handles_and_reduced_partitions):
            if handle is not None:
                handle.wait()

            # some ranks may have partitions that are padded to go beyond the grad size.
            # For these ranks the output of reduce scatter is a separate buffer and needs
            # to be copied in
J
Jeff Rasley 已提交
964
            partition_size = param.ds_tensor.ds_numel
S
Samyam Rajbhandari 已提交
965 966 967 968 969 970 971 972 973 974 975 976 977 978
            start = self.rank * partition_size
            end = start + partition_size
            #print_rank_0("REduce scatter was executed for praam {param.ds_id}")
            if start < param.ds_numel and end > param.ds_numel:
                elements = param.ds_numel - start
                param.grad.view(-1).narrow(0,
                                           start,
                                           elements).copy_(
                                               reduced_partition.narrow(0,
                                                                        0,
                                                                        elements))

    def _reduce_scatter_gradient(self, param):
        """Launch an asynchronous reduce-scatter of ``param.grad``.

        The flattened gradient is split into ``self.world_size`` chunks of
        ``param.ds_tensor.ds_numel`` elements each. Chunks that lie fully
        inside the gradient are views of ``param.grad``. Chunks that reach past
        ``param.ds_numel`` are zero-filled copies holding only the valid
        elements.

        Returns:
            ``(handle, output)``: the async work handle and this rank's chunk,
            which receives the reduced result.
        """
        partition_size = param.ds_tensor.ds_numel
        flat_grad = param.grad.view(-1)
        input_list = []

        for i in range(self.world_size):
            start = i * partition_size
            end = start + partition_size

            if start < param.ds_numel and end <= param.ds_numel:
                chunk = flat_grad.narrow(0, start, partition_size)
            else:
                chunk = torch.zeros(partition_size,
                                    dtype=param.dtype,
                                    device=param.device)
                if start < param.ds_numel:
                    elements = param.ds_numel - start
                    chunk.narrow(0, 0, elements).copy_(flat_grad.narrow(0, start, elements))
            input_list.append(chunk)

        rank = torch.distributed.get_rank(group=self.ds_process_group)
        handle = torch.distributed.reduce_scatter(input_list[rank],
                                                  input_list,
                                                  group=self.ds_process_group,
                                                  async_op=True)

        return handle, input_list[rank]

    def _partition_gradients(self, param_list, partition_buffers=None, accumulate=False):
        if partition_buffers is None:
            partition_buffers = [None] * len(param_list)

        for param, partition_buffer in zip(param_list, partition_buffers):
            self._partition_gradient(param,
                                     partition_buffer=partition_buffer,
                                     accumulate=accumulate)

    def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
        """Shrink ``param.grad`` to this rank's partition of it.

        This rank's slice of the flattened gradient is copied (or, with
        ``accumulate``, added) into the first ``partition_size`` elements of
        ``partition_buffer``, and ``param.grad`` is re-pointed at that region.
        When not accumulating, the padding tail of the region (positions past
        ``param.ds_numel``) is zeroed, so stale contents of a caller-supplied
        buffer never leak into the gradient.

        Args:
            param: parameter whose full ``grad`` is being partitioned.
            partition_buffer: optional destination with at least
                ``param.ds_tensor.ds_numel`` elements. If omitted, a
                zero-filled buffer is allocated on ``param.device``.
            accumulate: add to the buffer's current contents instead of
                overwriting them. Requires ``partition_buffer``.

        Raises:
            AssertionError: if ``accumulate`` is set without a buffer, or the
                buffer is smaller than the partition.
        """
        print_rank_0(
            f"Partitioning param {param.ds_id} gradient of size {param.grad.numel()} type {param.grad.dtype} part_size {param.ds_tensor.ds_numel}"
        )
        see_memory_usage("Before partitioning gradients", force=False)
        partition_size = param.ds_tensor.ds_numel

        buffer_supplied = partition_buffer is not None
        if not buffer_supplied:
            assert not accumulate, "No buffer to accumulate to"
            partition_buffer = torch.zeros(partition_size,
                                           dtype=param.dtype,
                                           device=param.device)
        else:
            assert partition_buffer.numel() >= partition_size, f"The partition buffer size {partition_buffer.numel()} should match the size of param.ds_tensor {partition_size}"

        rank = torch.distributed.get_rank(group=self.ds_process_group)
        start = partition_size * rank

        dest_tensor_full_buffer = partition_buffer.view(-1).narrow(0, 0, partition_size)

        # number of real (non-padding) gradient elements in this rank's partition
        elements = max(0, min(param.ds_numel - start, partition_size))

        if elements > 0:
            dest_tensor = dest_tensor_full_buffer.narrow(0, 0, elements)
            src_tensor = param.grad.view(-1).narrow(0, start, elements)

            if not accumulate:
                # just copy the grad partition to the buffer
                dest_tensor.copy_(src_tensor)
            elif src_tensor.device == dest_tensor.device:
                # same device: add directly into the provided buffer
                dest_tensor.add_(src_tensor)
            else:
                # Different devices: copy to the source side, add, and move back.
                # This runs faster when src is gpu and dest is cpu, since adding
                # directly on cpu is very slow.
                acc_tensor = torch.empty(src_tensor.numel(),
                                         dtype=param.dtype,
                                         device=param.device)
                acc_tensor.copy_(dest_tensor)
                acc_tensor.add_(src_tensor)
                dest_tensor.copy_(acc_tensor)

        if buffer_supplied and not accumulate and elements < partition_size:
            # A freshly allocated buffer is already zero here; a supplied one may
            # still hold values from earlier use.
            dest_tensor_full_buffer.narrow(0, elements, partition_size - elements).zero_()

        param.grad.data = dest_tensor_full_buffer.data
        see_memory_usage("After partitioning gradients", force=False)


class GatheredParameters:
    def __init__(self, params, modifier_rank=None, fwd_module=None, enabled=True):
        """A context that collects parameters that were partitioned via a
        :class:`deepspeed.zero.Init` context. When ``modifier_rank`` is given,
        the parameters are broadcast from that rank and partitioned again upon
        exit.

        Args:
            params (``torch.nn.Parameter``): A single parameter, or a list, tuple or other
                iterable of parameters (e.g. ``module.parameters()``) to collect. Parameters
                that are not zero params are ignored; if none are zero params the context
                is a no-op.
            modifier_rank (int, optional): If specified, this rank's parameter will be
                broadcast on exit from the context. The rank is interpreted within the
                parameters' process group. This argument is required if ``params`` are
                modified, so that all processes have a consistent view of the data. Defaults
                to ``None``.
            fwd_module (``torch.nn.Module``, optional): If specified, ``params`` will be
                registered as external parameters of ``fwd_module``. See :meth:`deepspeed.zero.register_external_parameter`.
            enabled (bool, optional): If ``False``, this context is a no-op. Defaults to ``True``.

        Important: Make sure to use ``modifier_rank`` that is not ``None`` (e.g. ``modifier_rank=0``)
        if you need the GPU memory allocated by gather to be released upon exit from the context manager.

        Examples
        ========

        #. Allocate a partitioned module, initialize its weight on rank 0, and update all
           processes.

            .. code-block:: python

                with deepspeed.zero.Init():
                    linear = torch.nn.Linear(1000,1000)

                with deepspeed.zero.GatheredParameters(linear.weight,
                                                       modifier_rank=0):
                    if torch.distributed.get_rank() == 0:
                        linear.weight.zero_()

        #. Collect a partitioned weight to pass to another module during
           training. The parameter will be registered as an external parameter
           and made available during the backward pass.

            .. code-block:: python
                :emphasize-lines: 6

                def forward(self, input):
                    x = self.layer1(input)

                    # self.layer1.weight is required by self.layer2.forward
                    with deepspeed.zero.GatheredParameters(self.layer1.weight,
                                                           fwd_module=self):
                        y = self.layer2(x, self.layer1.weight)
                    return y


        #. Pretrained model loading

            .. code-block:: python

                with deepspeed.zero.Init():
                    model = MyModel()

                state_dict = torch.load(model_path, map_location="cpu")

                def load(module: nn.Module, prefix=""):
                    # because zero3 puts placeholders in model params, this context
                    # manager gathers (unpartitions) the params of the current layer, then loads from
                    # the state dict and then re-partitions them again
                    with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
                        if torch.distributed.get_rank() == 0:
                            module._load_from_state_dict(state_dict, prefix)

                    for name, child in module._modules.items():
                        if child is not None:
                            load(child, prefix + name + ".")

                load(model, prefix="")

        If this approach is not used, then the full model will first get copied to each GPU. For models
        bigger than the memory of a single GPU this approach is required.
        """

        self.enabled = enabled
        if not enabled:
            return

        if isinstance(params, torch.Tensor):
            # a single parameter (tensors are iterable, so check this first)
            params = [params]
        else:
            try:
                # Accept any iterable, including generators such as
                # module.parameters(); materialize it because it is read twice.
                params = list(params)
            except TypeError:
                params = [params]

        self.params = [p for p in params if is_zero_param(p)]

        # enable if at least one is zero-param, otherwise a noop
        if not self.params:
            self.enabled = False
            return

        self.src_rank = None
        if modifier_rank is not None:
            if self.params[0].ds_process_group == torch.distributed.group.WORLD:
                self.src_rank = modifier_rank
            else:
                # A group was specified; convert DP rank to global rank
                self.src_rank = _get_global_rank(self.params[0].ds_process_group,
                                                 modifier_rank)
        self.fwd_module = fwd_module
        if self.fwd_module is not None:
            # is a no-op if already registered
            for p in self.params:
                register_external_parameter(self.fwd_module, p)

    def __enter__(self):
        if not self.enabled:
            return
        self.params[0].all_gather(param_list=self.params)

    def __exit__(self, *exc):
        if not self.enabled:
            return
        if self.src_rank is None:
            # without a modifier rank the gathered params are left in place
            return

        handles = [
            torch.distributed.broadcast(p,
                                        self.src_rank,
                                        group=p.ds_process_group,
                                        async_op=True) for p in self.params
        ]
        for h in handles:
            h.wait()
        self.params[0].partition(param_list=self.params, has_been_updated=True)