partition_parameters.py 68.2 KB
Newer Older
J
Jeff Rasley 已提交
1 2 3 4 5
"""
Copyright 2020 The Microsoft DeepSpeed Team.
Licensed under the MIT license.
"""

6
import math
S
Samyam Rajbhandari 已提交
7 8
import os
import types
9
from typing import Callable, Iterable
S
Samyam Rajbhandari 已提交
10 11 12
from enum import Enum
import functools
import itertools
13
from typing import List
S
Samyam Rajbhandari 已提交
14 15

import torch
16
from torch import Tensor
K
Karim Foda 已提交
17
from deepspeed import comm as dist
18 19
from torch.nn import Module
from torch.nn import Parameter
S
Samyam Rajbhandari 已提交
20

21
from .linear import zero3_linear_wrap
J
Jeff Rasley 已提交
22

23
import deepspeed
24 25
from ..utils import get_only_unique_item, see_memory_usage
from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks
26 27
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.utils import instrument_w_nvtx, logger
28
from deepspeed.comm.comm import init_distributed
29 30 31 32
from deepspeed.utils.debug import (debug_param2name_id_shape,
                                   debug_param2name_id_shape_device,
                                   debug_module2name,
                                   debug_param2name_id,
33
                                   debug_param2name_id_shape_status)
J
Jeff Rasley 已提交
34 35
from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper, PartitionedParamStatus

S
Samyam Rajbhandari 已提交
36
# Module-level state shared by the zero.Init machinery below.

# Reported (in billions) on rank 0 when an Init context exits.
param_count = 0
# NOTE(review): not referenced in the visible part of this file; confirm its use.
partitioned_param_data_shape = [0]
# True while an Init context has its torch / nn.Module patches installed;
# set by __enter__ and cleared by shutdown_init_context().
zero_init_enabled = False
39 40


41
def _dist_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group=None):
    """Launch an asynchronous all-gather of ``input_tensor`` into ``output_tensor``.

    The call goes through ``dist.allgather_fn`` wrapped by ``instrument_w_nvtx``.
    ``allgather_fn`` takes the *output* tensor first, hence the swapped order.
    Returns whatever ``allgather_fn`` yields for ``async_op=True``
    (presumably a waitable work handle).
    """
    instrumented_allgather = instrument_w_nvtx(dist.allgather_fn)
    return instrumented_allgather(output_tensor, input_tensor, group=group, async_op=True)
S
Samyam Rajbhandari 已提交
46 47 48


def print_rank_0(message, debug=False, force=False):
    """Print ``message`` on global rank 0, but only when ``debug`` or ``force`` is set.

    The rank is queried unconditionally, before the flags are looked at.

    Other variations that have been used while debugging:
    print on every rank without interleaving via ``printflock(f"[{rank}] {message}")``,
    or write a per-rank log file via ``log_rank_file(rank, message)``.
    """
    current_rank = dist.get_rank()
    wants_output = debug or force
    if current_rank == 0 and wants_output:
        print(message)
S
Samyam Rajbhandari 已提交
57 58


59
def debug_rank0(msg: str) -> None:
    """Send ``msg`` to ``logger.debug`` on global rank 0; do nothing on other ranks."""
    if dist.get_rank() != 0:
        return
    logger.debug(msg)


S
Samyam Rajbhandari 已提交
64
def is_zero_param(parameter):
    """Return True if ``parameter`` is a tensor carrying a DeepSpeed ``ds_id``.

    The presence of ``ds_id`` is what this module uses to decide whether a
    tensor has already been converted to a ZeRO parameter. Non-tensors are
    never ZeRO parameters, whatever attributes they have.
    """
    return torch.is_tensor(parameter) and hasattr(parameter, 'ds_id')


def _init_external_params(module):
    """Attach external-parameter bookkeeping to ``module`` if it is not already present.

    Adds a ``_external_params`` dict (keyed by ``id(parameter)``) and two bound
    methods:

    * ``ds_external_parameters()`` -- the ``(id, parameter)`` items of that dict.
    * ``all_parameters()`` -- the module's own ``named_parameters(recurse=False)``
      chained with its external ``(id, parameter)`` items.

    Calling it again on an already-initialized module does nothing.
    """
    if hasattr(module, '_external_params'):
        return

    module._external_params = {}

    def external_parameters(self):
        return self._external_params.items()

    def all_parameters(self):
        # ``named_parameters`` is already bound to ``self``. Passing ``self`` again
        # would bind it to the ``prefix`` argument, and building the parameter
        # names would then fail with a TypeError.
        return itertools.chain(self.named_parameters(recurse=False),
                               external_parameters(self))

    module.ds_external_parameters = types.MethodType(external_parameters, module)
    module.all_parameters = types.MethodType(all_parameters, module)


def register_external_parameter(module, parameter):
    """Instruct DeepSpeed to coordinate ``parameter``'s collection and partitioning in
    the forward and backward passes of ``module``.

    This is used when a parameter is accessed outside of its owning module's
    ``forward()``. DeepSpeed must know to collect it from its partitioned
    state and when to release the memory.

    .. note::
        This is only applicable to training with ZeRO stage 3.

    Args:
        module (``torch.nn.Module``): The module that requires ``parameter`` in its forward pass.
        parameter (``torch.nn.Parameter``): The parameter to register.

    Raises:
        RuntimeError: If ``parameter`` is not of type ``torch.nn.Parameter``.

    Examples
    ========

    #. Register a weight that is used in another module's forward pass (line 6).
       Parameter ``layer1.weight`` is used by ``layer2`` (line 11).

        .. code-block:: python
            :linenos:
            :emphasize-lines: 6,11

            class ModuleZ3(torch.nn.Module):
                def __init__(self, *args):
                    super().__init__(*args)
                    self.layer1 = SomeLayer()
                    self.layer2 = OtherLayer()
                    deepspeed.zero.register_external_parameter(self, self.layer1.weight)

                def forward(self, input):
                    x = self.layer1(input)
                    # self.layer1.weight is required by self.layer2.forward
                    y = self.layer2(x, self.layer1.weight)
                    return y
    """
    if not isinstance(parameter, torch.nn.Parameter):
        raise RuntimeError('Parameter is not a torch.nn.Parameter')

    if not hasattr(module, '_external_params'):
        _init_external_params(module)

    # Keyed by object identity, so registering the same parameter twice just
    # overwrites the existing entry.
    key = id(parameter)
    module._external_params[key] = parameter


J
Jeff Rasley 已提交
138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
def unregister_external_parameter(module, parameter):
    """Undo :meth:`register_external_parameter` for ``parameter`` on ``module``.

    Args:
        module (``torch.nn.Module``): The module to affect.
        parameter (``torch.nn.Parameter``): The parameter to unregister.

    Raises:
        RuntimeError: If ``parameter`` is not of type ``torch.nn.Parameter``.
        RuntimeError: If ``parameter`` is not a registered external parameter of ``module``.
    """
    if not isinstance(parameter, torch.nn.Parameter):
        raise RuntimeError('Parameter is not a torch.nn.Parameter')

    param_key = id(parameter)
    is_registered = (hasattr(module, '_external_params')
                     and param_key in module._external_params)
    if not is_registered:
        raise RuntimeError('Parameter is not a registered external parameter of module.')

    del module._external_params[param_key]


S
Samyam Rajbhandari 已提交
160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184
class ZeroParamType(Enum):
    """How a parameter's data is held across the participating processes."""

    NORMAL = 1  # held the same way as an ordinary PyTorch parameter
    PARTITIONED = 2  # split across the data-parallel processes
    REMOTE = 3  # held by a single process rank, absent on all others


class ZeroParamStatus(Enum):
    """Whether a ZeRO parameter's full data is currently materialized."""

    AVAILABLE = 1  # fully present and ready for use on all processes
    NOT_AVAILABLE = 2  # partitioned or remote in some or all processes
    INFLIGHT = 3  # currently being gathered


# Unpatched torch factory functions, captured at import time. The Init context
# swaps torch.empty/zeros/ones/full for dtype/device-adjusting wrappers that call
# these originals, and shutdown_init_context() puts them back.
_orig_torch_empty, _orig_torch_zeros, _orig_torch_ones, _orig_torch_full = (
    torch.empty,
    torch.zeros,
    torch.ones,
    torch.full,
)
S
Samyam Rajbhandari 已提交
188 189


190 191 192 193 194 195 196 197
def zero_wrapper_for_fp_tensor_constructor(fn: Callable,
                                           target_fp_dtype: torch.dtype) -> Callable:
    """Wrap a torch tensor factory (e.g. ``torch.zeros``) for use under ``zero.Init``.

    The returned callable forwards every argument to ``fn`` with two adjustments:

    * when ``device`` is missing or ``None``, the tensor is created on
      ``cuda:<LOCAL_RANK>``; ``LOCAL_RANK`` is read from the environment on every
      such call, so it must be set;
    * a floating-point result is cast to ``target_fp_dtype``; results of any other
      dtype are returned unchanged.
    """
    def wrapped_fn(*args, **kwargs) -> Tensor:
        if kwargs.get("device", None) is None:
            local_rank = os.environ["LOCAL_RANK"]
            kwargs['device'] = torch.device('cuda:{}'.format(local_rank))
        result: Tensor = fn(*args, **kwargs)
        return result.to(target_fp_dtype) if result.is_floating_point() else result

    return wrapped_fn

S
Samyam Rajbhandari 已提交
203

204 205 206 207 208 209
def get_new_tensor_fn_for_dtype(dtype: torch.dtype) -> Callable:
    """Build the replacement for ``torch.Tensor.__new__`` installed by ``zero.Init``.

    The returned function ignores ``cls`` and creates an uninitialized tensor
    with ``new_empty(*args)`` on ``cuda:<LOCAL_RANK>``. ``LOCAL_RANK`` is read
    from the environment on each call. Floating-point results are cast to ``dtype``.
    """
    def new_tensor(cls, *args) -> Tensor:
        target_device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"]))
        seed = _orig_torch_empty(0, device=target_device)
        created = seed.new_empty(*args)
        if not created.is_floating_point():
            return created
        return created.to(dtype)

    return new_tensor
214 215


216 217 218 219 220 221 222 223 224 225 226 227 228 229
# https://stackoverflow.com/a/63851681/9201239
def get_all_subclasses(cls):
    """Return the set of every direct and indirect subclass of ``cls`` (not ``cls`` itself)."""
    found = set()
    pending = [cls]
    while pending:
        current = pending.pop()
        for child in current.__subclasses__():
            found.add(child)
            pending.append(child)
    return found


230 231 232 233 234 235 236 237 238 239 240 241 242
@instrument_w_nvtx
def free_param(param: Parameter) -> None:
    """Drop the storage behind ``param.data`` and mark the parameter NOT_AVAILABLE.

    Asserts that no sub-module still lists the parameter as active. For CUDA
    data, the current stream is recorded on the tensor first, so the caching
    allocator does not hand the memory out again while queued work on that
    stream may still read it.
    """
    assert not param.ds_active_sub_modules, param.ds_summary()
    data = param.data
    if data.is_cuda:
        data.record_stream(torch.cuda.current_stream())
    # In the partitioned state param.data carries nothing meaningful, so an
    # empty tensor with the same dtype and device takes its place.
    param.data = torch.empty(0, dtype=param.dtype, device=param.device)
    param.ds_status = ZeroParamStatus.NOT_AVAILABLE


S
Samyam Rajbhandari 已提交
243 244 245 246 247 248 249 250
# Module-level buffer bookkeeping.
# NOTE(review): none of these is read in the visible part of this file; confirm
# where they are used before changing them.
reuse_buffers = False
temp_contiguous_tensor = None
empty_buffers = {}


# Arranges for _post_init_method to run at the end of __init__ for every
# subclass of torch.nn.Module, both existing ones and those defined later.
class InsertPostInitMethodToModuleSubClasses(object):
    """Context manager that patches ``torch.nn.Module`` construction and torch factories.

    While active (``enabled=True``) it:

    * wraps ``__init__`` of every existing and newly defined ``torch.nn.Module``
      subclass so that :meth:`_post_init_method` runs once per constructed
      module, after its outermost ``__init__`` has finished;
    * wraps ``Module.apply`` so the applied function sees gathered parameters;
    * replaces ``torch.Tensor.__new__`` and ``torch.empty/zeros/ones/full`` so
      new tensors default to ``cuda:<LOCAL_RANK>`` and floating-point tensors
      are cast to ``self.dtype``;
    * optionally replaces ``torch.nn.functional.linear`` with
      ``zero3_linear_wrap``. This override is not undone on exit.

    Subclasses implement :meth:`_post_init_method`.
    """
    def __init__(self,
                 enabled=True,
                 mem_efficient_linear=True,
                 ds_config=None,
                 dtype=None):
        self.mem_efficient_linear = mem_efficient_linear
        self.enabled = enabled
        self._set_dtype(ds_config, dtype)
        assert self.dtype in [torch.half, torch.bfloat16, torch.float], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.bfloat16, torch.float]"

    def __enter__(self):
        global zero_init_enabled
        if not self.enabled:
            return
        zero_init_enabled = True

        def apply_with_gather(orig_module_apply_fn: Callable) -> Callable:
            """Wrap ``Module.apply`` so the applied function works on partitioned modules.

            Many models use child modules like Linear or Embedding, which
            initialize their own weights in ``__init__``. A parent module's
            ``__init__`` then often re-initializes child weights, typically
            through ``Module.apply``.

            The Init context partitions each child module right after it is
            constructed. Without this wrapper, any initialization done by the
            parent through ``apply`` would never reach the real weights.

            The wrapper therefore wraps the function passed to ``Module.apply``
            so that it is applied to each child module's gathered parameters.
            """
            def get_wrapped_fn_to_apply(fn_to_apply: Callable) -> Callable:
                # Module.apply recurses into children via their (patched) apply,
                # so the same function comes back here; wrap it only once.
                if hasattr(fn_to_apply, "wrapped"):
                    return fn_to_apply

                @functools.wraps(fn_to_apply)
                def wrapped_fn_to_apply(module_to_apply_fn_to: Module) -> None:
                    """Gather parameters, apply the function, then sync and re-partition.

                    Steps:
                    1. all-gathers the parameters of the current module
                    2. calls the original function
                    3. broadcasts root rank's parameters to the other ranks
                    4. re-partitions the parameters
                    """
                    if not all(
                            is_zero_param(p)
                            for p in module_to_apply_fn_to.parameters(recurse=False)):
                        raise RuntimeError(
                            f"not all parameters for {module_to_apply_fn_to.__class__.__name__}, "
                            f"were zero params, is it possible that the parameters were "
                            f"overwritten after they were initialized? "
                            f"params: {[p for p in module_to_apply_fn_to.parameters(recurse=False)]} "
                        )

                    # Ordered by ds_id, presumably so every rank issues the
                    # collectives below in the same order.
                    params_to_apply_fn_to: Iterable[Parameter] = list(
                        sorted(module_to_apply_fn_to.parameters(recurse=False),
                               key=lambda p: p.ds_id))

                    for param in params_to_apply_fn_to:
                        param.all_gather()

                    fn_to_apply(module_to_apply_fn_to)

                    # NOTE(review): src=0 is passed as-is; if dist.broadcast
                    # interprets it as a global rank, a process group that excludes
                    # global rank 0 would need a different source -- confirm.
                    for param in params_to_apply_fn_to:
                        dist.broadcast(param.data, 0, group=param.ds_process_group)

                    for param in params_to_apply_fn_to:
                        param.partition(has_been_updated=True)

                wrapped_fn_to_apply.wrapped = True

                return wrapped_fn_to_apply

            @functools.wraps(orig_module_apply_fn)
            def wrapped_apply(module: Module, fn_to_apply: Callable) -> Module:
                # Propagate the original return value. Module.apply returns the
                # module itself, and callers commonly chain on it or assign it.
                return orig_module_apply_fn(module, get_wrapped_fn_to_apply(fn_to_apply))

            return wrapped_apply

        def partition_after(f):
            @functools.wraps(f)
            def wrapper(module, *args, **kwargs):

                # important logic: We want to run post_init only after child's __init__ is
                # completed, and do nothing after __init__ of any of its parents and grandparents in
                # the inheritance ancestry. This way the partitioning will need to happen only once
                # when the whole object is ready to be partitioned and not before. This is because
                # often the child module will need to tweak the weights - for example running a
                # custom weights init function. So if a parent created the weights param, the child
                # won't need to gather it in order to tweak it

                print_rank_0(f'Before initializing {module.__class__.__name__}',
                             force=False)

                is_child_module = False
                if not hasattr(module, "_ds_child_entered"):
                    # child's __init__ was called, since parents all see the same object they can now skip post_init
                    is_child_module = True
                    setattr(module, "_ds_child_entered", True)

                f(module, *args, **kwargs)

                if is_child_module:
                    # child's __init__ is done, now we can run a single post_init on the child object
                    delattr(module, "_ds_child_entered")

                    print_rank_0(f'Running post_init for {module.__class__.__name__}',
                                 force=False)
                    self._post_init_method(module)

                print_rank_0(
                    f'After initializing followed by post init for {module.__class__.__name__}',
                    force=False)

            return wrapper

        def _enable_class(cls):
            cls._old_init = cls.__init__
            cls.__init__ = partition_after(cls.__init__)

        def _init_subclass(cls, **kwargs):
            # A subclass defined inside the context that has its own __init__ must
            # record it. Otherwise it inherits a base class's _old_init, and
            # shutdown_init_context() would overwrite the subclass's constructor
            # with the base's. A subclass without its own __init__ keeps
            # inheriting the nearest recorded _old_init, as before.
            if '__init__' in cls.__dict__:
                cls._old_init = cls.__init__
            cls.__init__ = partition_after(cls.__init__)

        # Replace .__init__() for all existing subclasses of torch.nn.Module recursively
        for subclass in get_all_subclasses(torch.nn.modules.module.Module):
            _enable_class(subclass)

        # holding onto some methods so we can put them back the way they were in __exit__
        torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__
        torch.nn.modules.module.Module._old_apply = torch.nn.modules.module.Module.apply
        torch.Tensor.__old_new__ = torch.Tensor.__new__

        # Replace .__init__() for future subclasses of torch.nn.Module
        torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass)
        torch.nn.modules.module.Module.apply = apply_with_gather(
            torch.nn.modules.module.Module._old_apply)

        torch.Tensor.__new__ = get_new_tensor_fn_for_dtype(self.dtype)
        torch.empty = zero_wrapper_for_fp_tensor_constructor(_orig_torch_empty,
                                                             self.dtype)
        torch.zeros = zero_wrapper_for_fp_tensor_constructor(_orig_torch_zeros,
                                                             self.dtype)
        torch.ones = zero_wrapper_for_fp_tensor_constructor(_orig_torch_ones, self.dtype)
        torch.full = zero_wrapper_for_fp_tensor_constructor(_orig_torch_full, self.dtype)

        if self.mem_efficient_linear:
            print_rank_0(
                "nn.functional.linear has been overridden with a more memory efficient version. This will persist unless manually reset.",
                force=False)
            self.linear_bk = torch.nn.functional.linear
            torch.nn.functional.linear = zero3_linear_wrap

    def __exit__(self, exc_type, exc_value, traceback):
        if not self.enabled:
            return

        shutdown_init_context()

        if dist.get_rank() == 0:
            logger.info("finished initializing model with %.2fB parameters",
                        param_count / 1e9)

        # The torch patches are undone at this point; returning False lets any
        # exception raised inside the context propagate.
        if exc_type is not None:
            return False

    # To be implemented by inheriting classes
    def _post_init_method(self, module):
        pass

    def _set_dtype(self, ds_config, dtype):
        """Pick ``self.dtype``: from ``ds_config`` when only a config is given, else ``dtype`` or half."""
        if ds_config is not None and dtype is None:
            if ds_config.bfloat16_enabled and ds_config.fp16_enabled:
                raise RuntimeError("bfloat16 and fp16 cannot be enabled at once")

            if ds_config.bfloat16_enabled:
                self.dtype = torch.bfloat16
            elif ds_config.fp16_enabled:
                self.dtype = torch.half
            else:
                self.dtype = torch.float
        else:
            self.dtype = dtype or torch.half


439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470
def shutdown_init_context():
    """Undo the patches installed when an Init context was entered.

    If no context is active (``zero_init_enabled`` is falsy), this returns
    immediately. Otherwise it:

    * restores every ``torch.nn.Module`` subclass's ``__init__`` from its
      recorded ``_old_init``;
    * puts back ``Module.__init_subclass__``, ``Module.apply`` and
      ``torch.Tensor.__new__``;
    * puts back the ``torch.empty/zeros/ones/full`` factories;
    * clears ``zero_init_enabled``.

    ``torch.nn.functional.linear`` is deliberately left overridden. Restoring
    it here would also undo the memory-efficient linear during training.
    """
    global zero_init_enabled

    if not zero_init_enabled:
        return

    module_cls = torch.nn.modules.module.Module

    for subclass in get_all_subclasses(module_cls):
        subclass.__init__ = subclass._old_init

    module_cls.__init_subclass__ = module_cls._old_init_subclass
    module_cls.apply = module_cls._old_apply

    torch.Tensor.__new__ = torch.Tensor.__old_new__
    torch.empty = _orig_torch_empty
    torch.zeros = _orig_torch_zeros
    torch.ones = _orig_torch_ones
    torch.full = _orig_torch_full

    zero_init_enabled = False


471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533
class AllGatherHandle:
    """Wait handle for an asynchronous all-gather of a single parameter.

    The parameter must already be INFLIGHT when the handle is created.
    :meth:`wait` blocks on the underlying communication handle and then marks
    the parameter AVAILABLE.
    """
    def __init__(self, handle, param: Parameter) -> None:
        if param.ds_status != ZeroParamStatus.INFLIGHT:
            # The check requires INFLIGHT; the previous message wrongly said the
            # parameter was expected to be "available".
            raise RuntimeError(f"expected param {param.ds_summary()} to be inflight")

        self.__handle = handle
        self.__param = param

    def wait(self) -> None:
        """Block until the gather completes, then mark the parameter AVAILABLE."""
        instrument_w_nvtx(self.__handle.wait)()
        self.__param.ds_status = ZeroParamStatus.AVAILABLE


class AllGatherCoalescedHandle:
    """Wait handle for a single coalesced all-gather covering several parameters.

    ``partitions`` holds one flat tensor per rank. Within every rank's tensor,
    each parameter's shard starts at the same offset, and the parameters are
    laid out back-to-back in ``params`` order. :meth:`wait` blocks on the
    communication once, then rebuilds each parameter's full ``data`` from
    those shards.
    """
    def __init__(
        self,
        allgather_handle,
        params: List[Parameter],
        partitions: List[Tensor],
        world_size: int,
    ) -> None:
        self.__allgather_handle = allgather_handle
        self.__params = params
        self.__partitions = partitions
        self.__world_size = world_size
        self.__complete = False

        offender = next(
            (p for p in self.__params if p.ds_status != ZeroParamStatus.INFLIGHT),
            None)
        if offender is not None:
            raise RuntimeError(
                f"expected param {offender.ds_summary()} to not be available")

    @instrument_w_nvtx
    def wait(self) -> None:
        """Block on the gather (only the first time), then materialize every parameter."""
        if self.__complete:
            return

        instrument_w_nvtx(self.__allgather_handle.wait)()

        offset = 0
        for param in self.__params:
            assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight"
            shard_numel = param.ds_tensor.ds_numel
            full_numel = param.ds_numel

            # Take this parameter's slice from each rank's buffer. Shard starts
            # grow with the rank, so once one starts past the end of the
            # parameter, no later rank contributes anything (presumably padding).
            pieces: List[Tensor] = []
            for rank in range(self.__world_size):
                start = rank * shard_numel
                if start >= full_numel:
                    break
                pieces.append(self.__partitions[rank].narrow(
                    0,
                    offset,
                    min(full_numel - start, shard_numel)))

            param.data = instrument_w_nvtx(torch.cat)(pieces).view(param.ds_shape)
            param.ds_status = ZeroParamStatus.AVAILABLE

            for piece in pieces:
                piece.record_stream(torch.cuda.current_stream())

            offset += shard_numel

        self.__complete = True
534

S
Samyam Rajbhandari 已提交
535 536 537 538 539 540 541 542 543 544 545

# Replaces all parameters of a module with scattered (partitioned) ZeRO-3 parameters.
class Init(InsertPostInitMethodToModuleSubClasses):
    param_id = 0

    def __init__(self,
                 module=None,
                 data_parallel_group=None,
                 mem_efficient_linear=True,
                 remote_device=None,
                 pin_memory=False,
                 config_dict_or_path=None,
                 config=None,
                 enabled=True,
                 dtype=None,
                 mpu=None):
        """A context to enable massive model construction for training with
        ZeRO-3. Models are automatically partitioned (or, sharded) across the
        system and converted to reduced precision (half precision unless
        ``dtype`` or the config selects otherwise).

        Args:
            module (``torch.nn.Module``, optional): If provided, partition the model as
                if it was constructed in the context.
            data_parallel_group (``deepspeed.comm`` process group, optional):
                The group of processes to partition among. Defaults to all processes.
            mem_efficient_linear (bool, optional): Replace
                torch.nn.functional.linear with an implementation that allows
                DeepSpeed to partition parameters. Defaults to ``True``.
            remote_device (string, optional): The initial device to store model
                weights e.g., ``cpu``, ``nvme``. Passing ``"cpu"`` will create the model in CPU
                memory. The model may still be moved to GPU based on the
                offload settings for training. Defaults to GPU. If the config
                defines ``offload_param``, its device is used instead of this argument.
            pin_memory (bool, optional): Potentially increase performance by
                using pinned memory for model weights. Only honored when the
                remote device is ``"cpu"`` or ``"nvme"``. Defaults to ``False``.
                If the config defines ``offload_param``, its ``pin_memory`` value
                is used instead of this argument.
            config_dict_or_path (dict or ``json file``, optional): If provided, provides configuration
                for swapping fp16 params to NVMe, the parameter offload device and
                the default precision.
            config (dict or ``json file``, optional): Deprecated, use config_dict_or_path instead.
                If given, it replaces ``config_dict_or_path``.
            enabled (bool, optional): If ``False``, this context has no
                effect. Defaults to ``True``.
            dtype (``dtype``, optional): Can be used to change the data type of the parameters.
                Supported options are ``torch.half``, ``torch.bfloat16`` and ``torch.float``.
                Defaults to ``None``. The precision then follows the config
                (bf16 / fp16 / fp32) or, without a config, ``torch.half``.
            mpu (``object``, optional): A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}.

        This context accelerates model initialization and enables models that
        are too large to allocate in their entirety in CPU memory. It has the
        following effects:

        #. allocates tensors to either GPU or CPU memory or NVMe
        #. converts floating point tensors to the selected precision (see ``dtype``)
        #. immediately partitions tensors among the group of data-parallel devices
        #. (*optional*) replaces ``torch.nn.functional.linear`` with a more
           memory-efficient implementation

        These modifications allow for models that exceed the size of local CPU/GPU
        memory/NVMe, but fit within the total NVMe capacity (*i.e.*, aggregate CPU
        or GPU memory or NVMe) across all nodes. Consider initializing a model with one
        trillion parameters, whose weights occupy two terabytes (TB) in half
        precision. The initial CPU allocation in full precision requires 4TB of
        memory *per process*, and so a system with 8 GPUs per node would need 32TB of
        CPU memory due to data-parallel redundancies. Instead, by immediately
        partitioning tensors we remove the redundancies. The result is that
        regardless of the number of GPUs, we still only require the original 4TB. This
        allows for a linear increase in model size with the aggregate system memory.
        For example, if a node has 1TB of memory and 8 GPUs, we could fit a trillion
        parameter model with 4 nodes and 32 GPUs.

        Important: If the fp16 weights of the model can't fit onto a single GPU memory
        this feature must be used.

        .. note::
            Initializes ``deepspeed.comm`` if it has not already been done so.
            See :meth:`deepspeed.init_distributed` for more information.

        .. note::
            Can also be used as a decorator:

            .. code-block:: python

                @deepspeed.zero.Init()
                def get_model():
                    return MyLargeModel()

        .. note::
            Only applicable to training with ZeRO-3.

        Examples
        --------

        #. Allocate a model and partition it among all processes:

            .. code-block:: python

                with deepspeed.zero.Init():
                    model = MyLargeModel()


        #. Allocate a model in pinned CPU memory and partition it among a subgroup of processes:

            .. code-block:: python

                with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
                                         remote_device="cpu",
                                         pin_memory=True):
                    model = MyLargeModel()


        #. Partition an already-allocated model in CPU memory:

            .. code-block:: python

                # Partitions ``model``'s parameters in place. The expression
                # evaluates to the ``Init`` object, not to the model, so do not
                # assign it back to ``model``.
                deepspeed.zero.Init(module=model)
        """
        if config is not None:
            config_dict_or_path = config
            logger.warning(
                'zero.Init: the `config` argument is deprecated. Please use `config_dict_or_path` instead.'
            )

        _ds_config = deepspeed.runtime.config.DeepSpeedConfig(
            config_dict_or_path,
            mpu) if config_dict_or_path is not None else None
        super().__init__(enabled=enabled,
                         mem_efficient_linear=mem_efficient_linear,
                         ds_config=_ds_config,
                         dtype=dtype)
        if not dist.is_initialized():
            init_distributed()
            assert dist.is_initialized(), "Parameters cannot be scattered without initializing deepspeed.comm"
        if data_parallel_group is None:
            self.ds_process_group = dist.get_world_group()
        else:
            self.ds_process_group = data_parallel_group

        self.rank = dist.get_rank(group=self.ds_process_group)
        self.world_size = dist.get_world_size(group=self.ds_process_group)

        # Local device is the device where the parameters are consumed, must be default device.
        # It is the device where parameters are fully instantiated using allgather
        self.local_device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"]))
        torch.cuda.set_device(self.local_device)

        # A configured offload_param takes precedence over the explicit arguments.
        if _ds_config is not None and _ds_config.zero_config.offload_param is not None:
            remote_device = _ds_config.zero_config.offload_param.device
            pin_memory = _ds_config.zero_config.offload_param.pin_memory

        self._validate_remote_device(remote_device, _ds_config)

        # Remote device is the device where parameter partitions are stored
        # It can be same as local_device or it could be CPU or NVMe.
        self.remote_device = self.local_device if remote_device in [
            None,
            OffloadDeviceEnum.none
        ] else remote_device
        self.pin_memory = pin_memory if (
            self.remote_device in [OffloadDeviceEnum.cpu,
                                   OffloadDeviceEnum.nvme]) else False

        # Enable fp16 param swapping to NVMe
        if self.remote_device == OffloadDeviceEnum.nvme:
            self.param_swapper = AsyncPartitionedParameterSwapper(_ds_config, self.dtype)
        else:
            self.param_swapper = None

        # Set the all-gather flavour before converting any parameters below.
        # _convert_to_deepspeed_param / partition() are defined outside this
        # method; setting the flag first guarantees it exists if they consult it.
        self.use_all_gather_base = False
        if dist.has_allgather_base():
            self.use_all_gather_base = True
        else:
            logger.info(
                f"_all_gather_base API is not available in torch {torch.__version__}")

        # If we are provided an already-allocated module to prepare.
        if module is not None:
            assert isinstance(module, torch.nn.Module)
            self._convert_to_zero_parameters(module.parameters(recurse=True))

712 713 714 715 716 717
    def _convert_to_zero_parameters(self, param_list):
        for param in param_list:
            if is_zero_param(param):
                continue
            self._convert_to_deepspeed_param(param)
            param.partition()
S
Samyam Rajbhandari 已提交
718

719
    def _validate_remote_device(self, remote_device, ds_config):
J
Jeff Rasley 已提交
720
        if ds_config is not None:
721
            if remote_device in [None, OffloadDeviceEnum.cpu]:
O
Olatunji Ruwase 已提交
722
                if ds_config.zero_config.offload_param is not None:
723 724 725
                    offload_param_device = ds_config.zero_config.offload_param.device
                    assert offload_param_device != OffloadDeviceEnum.nvme, \
                        f"'device' in DeepSpeed Config cannot be {offload_param_device} if remote device is {remote_device}."
J
Jeff Rasley 已提交
726

727
            if remote_device == OffloadDeviceEnum.nvme:
O
Olatunji Ruwase 已提交
728
                assert ds_config.zero_config.offload_param is not None, \
729
                f'"offload_param" must be defined in DeepSpeed Config if remote device is {OffloadDeviceEnum.nvme}.'
J
Jeff Rasley 已提交
730

731 732
                assert ds_config.zero_config.offload_param.nvme_path is not None, \
                f'"nvme_path" in DeepSpeed Config cannot be None if remote device is {OffloadDeviceEnum.nvme}'
J
Jeff Rasley 已提交
733

S
Samyam Rajbhandari 已提交
734 735 736 737 738 739 740 741
    def _post_init_method(self, module):
        """Convert and partition the parameters registered directly on ``module``.

        Only the module's own parameters are visited (``recurse=False``). Each one
        adds its element count to the module-level ``param_count``. Params that are
        not yet ZeRO params are then:

        1. converted (ZeRO attributes and methods attached),
        2. broadcast from rank 0 over ``self.ds_process_group`` if they live on a
           CUDA device (presumably so every rank starts from the same values);
           otherwise rank 0 logs a warning that no broadcast happened,
        3. partitioned.

        Args:
            module: a ``torch.nn.Module`` whose constructor has just finished.
        """
        print_rank_0(f'Converting Params in {module.__class__.__name__}', force=False)
        see_memory_usage(
            f"Before converting and partitioning parmas in {module.__class__.__name__}",
            force=False)

        global param_count
        for name, param in module.named_parameters(recurse=False):
            param_count += param.numel()
            if is_zero_param(param):
                continue

            self._convert_to_deepspeed_param(param)
            print_rank_0(
                f"Partitioning param {debug_param2name_id_shape(param)} module={debug_module2name(module)}"
            )

            if param.is_cuda:
                dist.broadcast(param, 0, self.ds_process_group)
            elif dist.get_rank() == 0:
                # Logger.warn is a deprecated alias; use Logger.warning.
                logger.warning(f"param `{name}` in {module.__class__.__name__} "
                               f"not on GPU so was not broadcasted from rank 0")

            param.partition()
        see_memory_usage(
            f"Param count {param_count}. After converting and partitioning parmas in {module.__class__.__name__}",
            force=False)

    def _convert_to_deepspeed_param(self, param):
        """Turn ``param`` into a ZeRO-3 managed parameter, in place.

        Attaches the ``ds_*`` bookkeeping attributes and binds per-parameter helpers:

        - Bookkeeping: status, original shape and numel, partition slot, process
          group, NVMe swapper, and a unique ``ds_id`` taken from ``Init.param_id``.
        - Collectives: ``all_gather``, ``all_gather_coalesced``, ``partition``, and
          gradient reduce/partition.
        - Utilities: size helpers and ``ds_summary``.

        ``param.item`` is wrapped so that the param is all-gathered before
        ``item()`` runs. The param is left AVAILABLE. Partitioning it is up to the
        caller (e.g. by calling ``param.partition()``).

        Args:
            param: the ``torch.nn.Parameter`` to convert.
        """
        # Partitioned, Normal, Remote
        param.ds_param_type = ZeroParamType.PARTITIONED

        # Replicated vs Partitioned vs Inflight
        param.ds_status = ZeroParamStatus.AVAILABLE

        # Stores the shape of the original tensor
        param.ds_shape = param.shape

        # Stores the number of elements in the original parameter without padding
        param.ds_numel = param.numel()

        # Stores the partitioned copy of the tensor
        param.ds_tensor = None

        # Keeps track of how many active sub-modules need this param at any given point in time
        param.ds_active_sub_modules = set()

        # If this flag is true, then the parameters are replicated throughout training
        # And only partitioned before the step
        param.ds_persist = False

        param.is_external_param = False

        # The group that the parameter is scattered across.
        param.ds_process_group = self.ds_process_group

        # This is set to the Async Param swapper if remote device is nvme
        # else this is set to None
        param.nvme_swapper = self.param_swapper

        # DeepSpeed Param ID
        param.ds_id = Init.param_id
        Init.param_id += 1

        def all_gather(param_list=None, async_op=False, hierarchy=0):
            """Gather ``param_list`` (default: just this param) via ``Init._all_gather``."""
            cls = param
            if param_list is None:
                param_list = [cls]
            return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy)

        @instrument_w_nvtx
        def all_gather_coalesced(params: Iterable[Parameter],
                                 safe_mode: bool = False) -> AllGatherCoalescedHandle:
            """Launch one async all-gather covering all of ``params``.

            Each param must be NOT_AVAILABLE and is marked INFLIGHT. A single param
            is gathered straight into its own padded buffer and an
            ``AllGatherHandle`` is returned. Several params are concatenated
            (sorted by ``ds_id``) into one flat buffer per rank and an
            ``AllGatherCoalescedHandle`` is returned.

            Args:
                params: the params to gather. Any iterable is accepted; it is
                    materialized once.
                safe_mode: when True, check with the other ranks that the same
                    ``ds_id`` list and partition sizes are being gathered.

            Raises:
                RuntimeError: if a param is not NOT_AVAILABLE. Params visited
                    before it have already been marked INFLIGHT.
            """
            # ``params`` is walked several times below (availability check, status
            # update, sort). Materialize it so a one-shot iterator is not exhausted
            # by the first pass.
            params = list(params)

            # fetches from nvme if the partition is not available and in nvme
            self._ensure_availability_of_partitioned_params(params)

            for param in params:
                if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
                    raise RuntimeError(param.ds_summary())
                param.ds_status = ZeroParamStatus.INFLIGHT

            # ensure that each rank has params in same order. the allgather
            # is done by flattening the parameter list into a single tensor that
            # can be allgathered in a single call - this means that if each rank
            # gives a list of the same parameters in a different order we will
            # silently get incorrect parameter values, and have very difficult
            # to debug correctness issues.
            params = sorted(params, key=lambda p: p.ds_id)

            debug_rank0(f"-allgather_coalesced: {[p.ds_id for p in params]}")

            if safe_mode:
                # ensure that same list (with same ordering) of parameters are
                # being allgathered across all ranks, otherwise could mix
                # data between tensors.
                assert_ints_same_as_other_ranks([p.ds_id for p in params])
                # ensure that tensors from each rank agree on the same ds_numel
                # otherwise could mix data between tensors.
                assert_ints_same_as_other_ranks([p.ds_tensor.ds_numel for p in params])

            if len(params) == 1:
                # have an opportunity to avoid some intermediate memory allocations
                param, = params
                param_buffer = torch.empty(
                    math.ceil(param.ds_numel / self.world_size) * self.world_size,
                    dtype=param.dtype,
                    device=torch.cuda.current_device(),
                    requires_grad=False,
                )
                handle = _dist_allgather_fn(
                    param.ds_tensor.to(torch.cuda.current_device()),
                    param_buffer,
                    self.ds_process_group)
                param.data = param_buffer.narrow(0,
                                                 0,
                                                 param.ds_numel).view(param.ds_shape).to(
                                                     param.device)
                return AllGatherHandle(handle, param)

            partition_sz = sum(p.ds_tensor.ds_numel for p in params)
            flat_tensor = torch.empty(partition_sz * self.world_size,
                                      dtype=get_only_unique_item(p.dtype
                                                                 for p in params),
                                      device=torch.cuda.current_device(),
                                      requires_grad=False)
            # One contiguous slice of flat_tensor per rank.
            partitions: List[Parameter] = []
            for i in range(self.world_size):
                partitions.append(
                    flat_tensor.narrow(0,
                                       partition_sz * i,
                                       partition_sz))

            # Pack this rank's shards into its own slice, in ds_id order.
            instrument_w_nvtx(torch.cat)(
                [p.ds_tensor.to(torch.cuda.current_device()) for p in params],
                out=partitions[self.rank])
            handle = _dist_allgather_fn(partitions[self.rank],
                                        flat_tensor,
                                        self.ds_process_group)

            return AllGatherCoalescedHandle(
                allgather_handle=handle,
                params=params,
                partitions=partitions,
                world_size=self.world_size,
            )

        def partition(param_list=None, hierarchy=0, has_been_updated=False):
            """Partition ``param_list`` (default: just this param) via ``Init._partition``."""
            cls = param
            print_rank_0(
                f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}"
            )
            if param_list is None:
                param_list = [cls]
            self._partition(param_list, has_been_updated=has_been_updated)

        def reduce_gradients_at_owner(param_list=None, hierarchy=0):
            """Reduce-scatter the gradients of ``param_list`` (default: this param)."""
            cls = param
            if param_list is None:
                param_list = [cls]
            print_rank_0(
                f"{'--'*hierarchy}----Reducing Gradients for param with ids {[p.ds_id for p in param_list]} to owner"
            )
            self._reduce_scatter_gradients(param_list)

        def partition_gradients(param_list=None,
                                partition_buffers=None,
                                hierarchy=0,
                                accumulate=False):
            """Partition the gradients of ``param_list`` (default: this param)."""
            cls = param
            print_rank_0(
                f"{'--'*hierarchy}----Partitioning param gradient with id {debug_param2name_id_shape_device(cls)}"
            )
            if param_list is None:
                param_list = [cls]
                # A bare tensor buffer for the single-param case is wrapped into a list.
                if isinstance(partition_buffers, torch.Tensor):
                    partition_buffers = [partition_buffers]

            self._partition_gradients(param_list,
                                      partition_buffers=partition_buffers,
                                      accumulate=accumulate)

        def aligned_size():
            return self._aligned_size(param)

        def padding_size():
            return self._padding_size(param)

        def partition_numel():
            return self._partition_numel(param)

        def ds_summary(slf: torch.Tensor, use_debug_name: bool = False) -> dict:
            """Return a dict snapshot of the param's ZeRO state, for debugging/errors."""
            return {
                "id": debug_param2name_id(slf) if use_debug_name else slf.ds_id,
                "status": slf.ds_status.name,
                "numel": slf.numel(),
                "ds_numel": slf.ds_numel,
                "shape": tuple(slf.shape),
                "ds_shape": tuple(slf.ds_shape),
                "requires_grad": slf.requires_grad,
                "grad_shape": tuple(slf.grad.shape) if slf.grad is not None else None,
                "persist": slf.ds_persist,
                "active_sub_modules": slf.ds_active_sub_modules,
            }

        def convert_to_zero_parameters(param_list):
            self._convert_to_zero_parameters(param_list)

        def allgather_before(func: Callable) -> Callable:
            """Wrap ``func`` so this param is all-gathered before ``func`` runs."""
            @functools.wraps(func)
            def wrapped(*args, **kwargs):
                param.all_gather()
                return func(*args, **kwargs)

            return wrapped

        # Collectives for gathering and partitioning parameters
        param.all_gather = all_gather
        param.all_gather_coalesced = all_gather_coalesced
        param.partition = partition

        # Collective for averaging gradients
        param.reduce_gradients_at_owner = reduce_gradients_at_owner
        param.partition_gradients = partition_gradients

        # Partitioning size utilities
        param.aligned_size = aligned_size
        param.padding_size = padding_size
        param.partition_numel = partition_numel
        param.ds_summary = types.MethodType(ds_summary, param)

        # item() needs the full value, so gather first.
        param.item = allgather_before(param.item)

        param.convert_to_zero_parameters = convert_to_zero_parameters

S
Samyam Rajbhandari 已提交
973 974 975 976 977 978 979
    def _aligned_size(self, param):
        return param.ds_numel + self._padding_size(param)

    def _padding_size(self, param):
        remainder = param.ds_numel % self.world_size
        return (self.world_size - remainder) if remainder else 0

O
Olatunji Ruwase 已提交
980
    def _partition_numel(self, param):
J
Jeff Rasley 已提交
981 982 983 984 985 986 987
        return param.ds_tensor.ds_numel

    def _ensure_availability_of_partitioned_params(self, params):
        swap_in_list = []
        swap_in_flight = []
        for param in params:
            if param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE:
988
                assert param.ds_tensor.final_location == OffloadDeviceEnum.nvme and param.ds_status == ZeroParamStatus.NOT_AVAILABLE
J
Jeff Rasley 已提交
989 990
                swap_in_list.append(param)
            if param.ds_tensor.status == PartitionedParamStatus.INFLIGHT:
991
                assert param.ds_tensor.final_location == OffloadDeviceEnum.nvme and param.ds_status == ZeroParamStatus.NOT_AVAILABLE
J
Jeff Rasley 已提交
992 993 994 995 996 997
                swap_in_flight.append(param)
        if len(swap_in_list) > 0:
            swap_in_list[0].nvme_swapper.swap_in(swap_in_list, async_op=False)
        elif len(swap_in_flight) > 0:
            swap_in_flight[0].nvme_swapper.synchronize_reads()

998
    @instrument_w_nvtx
    def _all_gather(self, param_list, async_op=False, hierarchy=None):
        """All-gather every NOT_AVAILABLE param in ``param_list``.

        Params in any other status are left untouched. NVMe-resident partitions are
        made available first.

        Args:
            param_list: sized sequence of ZeRO params (``len()`` is taken).
            async_op: when True, launch one async gather per param, mark each
                INFLIGHT and return the list of handles. When False, gather
                synchronously and mark the gathered params AVAILABLE.
            hierarchy: nesting depth, used only to indent debug messages. ``None``
                is treated as 0.

        Returns:
            The list of handles when ``async_op`` is True. Otherwise whatever the
            synchronous gather helper returns.
        """
        if hierarchy is None:
            # _allgather_param formats ``'--' * hierarchy`` into its debug
            # messages, which raises TypeError for None.
            hierarchy = 0

        # fetches from nvme if the partition is not available and in nvme
        self._ensure_availability_of_partitioned_params(param_list)

        handles = []
        all_gather_list = []
        for param in param_list:
            if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
                continue
            if async_op:
                handle = self._allgather_param(param,
                                               async_op=async_op,
                                               hierarchy=hierarchy)
                param.ds_status = ZeroParamStatus.INFLIGHT
                handles.append(handle)
            else:
                all_gather_list.append(param)

        if async_op:
            return handles

        if len(param_list) == 1:
            ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy)
        else:
            ret_value = self._allgather_params_coalesced(all_gather_list, hierarchy)

        for param in all_gather_list:
            param.ds_status = ZeroParamStatus.AVAILABLE
        return ret_value

    def _partition(self, param_list, force=False, has_been_updated=False):
        for param in param_list:
            #print_rank_0(f"Before Partitioning Param {param.ds_id}")
1032
            # self._param_status(param)
S
Samyam Rajbhandari 已提交
1033 1034
            self._partition_param(param, has_been_updated=has_been_updated)
            param.ds_status = ZeroParamStatus.NOT_AVAILABLE
1035
            # if param.ds_tensor is not None:
S
Samyam Rajbhandari 已提交
1036 1037 1038 1039 1040
            #    assert id(param.data) == id(param.ds_tensor.data), \
            #    "After the parameters are initially partitioned, make sure we are not recreating the partition."
            #print_rank_0(f"After Partitioning Param {param.ds_id}")
            # self._param_status(param)

1041
    @instrument_w_nvtx
    def _partition_param(self, param, buffer=None, has_been_updated=False):
        """Drop ``param``'s full data, keeping only this rank's slice in ``param.ds_tensor``.

        Only AVAILABLE params are processed. Any other non-INFLIGHT status is a no-op.

        - If a partition already exists and ``has_been_updated`` is False, the
          existing partition is kept as-is. The full data is freed, and an NVMe
          partition's swap buffers are released.
        - Otherwise a partition tensor is allocated if there is none yet (an NVMe
          swap buffer when the swapper accepts it, else on ``remote_device``, with
          NVMe mapped to CPU). This rank's slice of the flattened param is copied
          into it, the full data is freed, and NVMe partitions are swapped out.

        Args:
            param: a param set up by ``_convert_to_deepspeed_param``.
            buffer: accepted but never read. The NVMe path fetches its own buffer.
            has_been_updated: when True, recopy into an existing partition instead of
                only dropping the full copy.

        Raises:
            AssertionError: if ``param`` is INFLIGHT.
        """
        assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight"

        global reuse_buffers
        if param.ds_status is not ZeroParamStatus.AVAILABLE:
            return

        print_rank_0(
            f"Partitioning param id {param.ds_id} reuse buffers {reuse_buffers}",
            force=False)

        if param.ds_tensor is not None and not has_been_updated:
            # Keep the existing partition; only the full copy is discarded.
            see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}',
                             force=False)
            # param.data does not store anything meaningful in partitioned state
            free_param(param)
            see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}',
                             force=False)

            if param.ds_tensor.final_location == OffloadDeviceEnum.nvme:
                print_rank_0(
                    f"Param {param.ds_id} partition released since it exists in nvme",
                    force=False)
                param.nvme_swapper.remove_partition_and_release_buffers([param])
            return

        partition_size = self._aligned_size(param) // self.world_size

        if param.ds_tensor is None:
            final_location = None
            goes_to_nvme = (self.remote_device == OffloadDeviceEnum.nvme
                            and self.param_swapper.swappable_tensor(numel=partition_size))
            if goes_to_nvme:
                final_location = OffloadDeviceEnum.nvme
                swap_buffer = self.param_swapper.get_buffer(param, partition_size)
                partitioned_tensor = torch.empty(0,
                                                 dtype=param.dtype,
                                                 device=swap_buffer.device)
                partitioned_tensor.data = swap_buffer.data
                print_rank_0(
                    f"ID {param.ds_id} Initializing partition for the first time for nvme offload."
                )
            else:
                if self.remote_device == OffloadDeviceEnum.nvme:
                    # No swap buffer available: keep the partition on CPU instead.
                    target_device = OffloadDeviceEnum.cpu
                else:
                    target_device = self.remote_device
                partitioned_tensor = torch.empty(partition_size,
                                                 dtype=param.dtype,
                                                 device=target_device)
                if self.pin_memory:
                    partitioned_tensor = partitioned_tensor.pin_memory()

            partitioned_tensor.requires_grad = False
            param.ds_tensor = partitioned_tensor
            param.ds_tensor.ds_numel = partition_size
            param.ds_tensor.status = PartitionedParamStatus.AVAILABLE
            param.ds_tensor.final_location = final_location

        first = partition_size * self.rank
        last = first + partition_size
        flat_param = param.contiguous().view(-1)

        if first < param.ds_numel and last <= param.ds_numel:
            # This rank's slice lies fully inside the param.
            param.ds_tensor.copy_(flat_param.narrow(0, first, partition_size))
        elif first < param.ds_numel:
            # The slice runs past the end of the param: copy only the real tail;
            # the remainder of the partition is left as allocated (padding).
            tail = param.ds_numel - first
            param.ds_tensor.narrow(0, 0, tail).copy_(flat_param.narrow(0, first, tail))

        # param.data does not store anything meaningful in partitioned state
        see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}',
                         force=False)
        free_param(param)
        see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}',
                         force=False)

        if param.ds_tensor.final_location == OffloadDeviceEnum.nvme:
            self.param_swapper.swap_out_and_release([param])
            print_rank_0(
                f"ID {param.ds_id} Offloaded to nvme offload and buffers released.")
            see_memory_usage(
                f"ID {param.ds_id} Offloaded to nvme offload and buffers released.",
                force=False)

        print_rank_0(
            f"ID {param.ds_id} partitioned type {param.dtype} dev {param.device} shape {param.shape}"
        )

    def _param_status(self, param):
        """Debug helper: print (via ``print_rank_0``) a one-line ZeRO status summary of ``param``."""
        if param.ds_tensor is None:
            print_rank_0(
                f"Param id {param.ds_id}, param status: {param.ds_status}, param numel {param.ds_numel}, partitioned ds_tensor {param.ds_tensor}, data numel {param.data.numel()}"
            )
            return
        print_rank_0(
            f"Param id {param.ds_id}, param status: {param.ds_status}, param numel {param.ds_numel}, partitioned numel {param.ds_tensor.numel()}, data numel {param.data.numel()}"
        )

    def _allgather_param(self, param, async_op=False, hierarchy=0):
        """All-gather one param's partitions into a freshly zeroed flat buffer.

        The buffer holds ``world_size * ds_tensor.ds_numel`` elements. Afterwards
        ``param.data`` is pointed at its first ``ds_numel`` elements, reshaped to
        ``ds_shape``. With ``async_op=True`` the gather may still be in flight at
        that point; callers presumably wait on the returned handle.

        Args:
            param: ZeRO param whose ``ds_tensor`` holds this rank's partition.
            async_op: forwarded to the collective call.
            hierarchy: nesting depth used to indent debug messages.

        Returns:
            The handle returned by ``dist.all_gather_base`` / ``dist.all_gather``.

        Raises:
            AssertionError: if ``world_size * partition size`` differs from the
                param's aligned size.
        """
        partition_size = param.ds_tensor.ds_numel

        tensor_size = partition_size * self.world_size
        aligned_param_size = self._aligned_size(param)
        assert tensor_size == aligned_param_size, f'param id {param.ds_id} aligned size {aligned_param_size} does not match tensor size {tensor_size}'

        indent = '--' * hierarchy
        print_rank_0(
            f"{indent}---- Before allocating allgather param {debug_param2name_id_shape_status(param)} partition size={partition_size}"
        )

        see_memory_usage(
            f'Before allocate allgather param {debug_param2name_id_shape_status(param)} partition_size={partition_size} ',
            force=False)
        flat_tensor = torch.zeros(aligned_param_size,
                                  dtype=param.dtype,
                                  device=param.device).view(-1)
        see_memory_usage(
            f'After allocate allgather param {debug_param2name_id_shape_status(param)} {aligned_param_size} {partition_size} ',
            force=False)

        torch.cuda.synchronize()

        print_rank_0(
            f"{indent}----allgather param with {debug_param2name_id_shape_status(param)} partition size={partition_size}"
        )

        group = self.ds_process_group
        if self.use_all_gather_base:
            # try the _all_gather_base on PyTorch master branch
            handle = dist.all_gather_base(flat_tensor,
                                          param.ds_tensor.cuda(),
                                          group=group,
                                          async_op=async_op)
        else:
            # One slice of the flat buffer per rank; seed our own slice locally.
            partitions = [
                flat_tensor.narrow(0, partition_size * i, partition_size)
                for i in range(self.world_size)
            ]
            rank_in_group = dist.get_rank(group=group)
            if 0 <= rank_in_group < len(partitions):
                partitions[rank_in_group].data.copy_(param.ds_tensor.data,
                                                     non_blocking=True)

            handle = dist.all_gather(partitions,
                                     partitions[self.rank],
                                     group=group,
                                     async_op=async_op)

        param.data = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape).data
        return handle

1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266
    def _allgather_params_coalesced(self, param_list, hierarchy=0):
        """ blocking call
        avoid explicit memory copy in _allgather_params
        """
        if len(param_list) == 0:
            return
        # collect local tensors and partition sizes
        partition_sizes = []
        local_tensors = []
        for param in param_list:
            partition_sizes.append(param.ds_tensor.ds_numel)
            local_tensors.append(param.ds_tensor.cuda())

        # allocate memory for allgather params
        allgather_params = []
        for psize in partition_sizes:
            tensor_size = psize * self.world_size
            flat_tensor = torch.empty(tensor_size,
                                      dtype=param_list[0].dtype,
                                      device=self.local_device).view(-1)
            flat_tensor.requires_grad = False
            allgather_params.append(flat_tensor)

        # launch
        launch_handles = []
        # backend = get_backend(self.ds_process_group)
        # with _batch_p2p_manager(backend):
        for param_idx, param in enumerate(param_list):
            input_tensor = local_tensors[param_idx].view(-1)

            if self.use_all_gather_base:
                # try the _all_gather_base from Pytorch master
1267 1268 1269 1270
                h = dist.all_gather_base(allgather_params[param_idx],
                                         input_tensor,
                                         group=self.ds_process_group,
                                         async_op=True)
1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303
            else:
                output_list = []
                for i in range(self.world_size):
                    psize = partition_sizes[param_idx]
                    partition = allgather_params[param_idx].narrow(0, i * psize, psize)
                    output_list.append(partition)
                    if not partition.is_cuda:
                        logger.warning(
                            f'param {param_idx}, partition {i} is not on CUDA, partition shape {partition.size()}'
                        )

                # back to old all_gather function signature
                h = dist.all_gather(output_list,
                                    input_tensor,
                                    group=self.ds_process_group,
                                    async_op=True)
            launch_handles.append(h)

        # Wait ensures the operation is enqueued, but not necessarily complete.
        launch_handles[-1].wait()

        # assign to param.data (not copy)
        for i, param in enumerate(param_list):
            gathered_tensor = allgather_params[i]
            param.data = gathered_tensor.narrow(0,
                                                0,
                                                param.ds_numel).view(param.ds_shape).data

        # guarantee the communication to be completed
        torch.cuda.synchronize()

        return None

    def _allgather_params(self, param_list, hierarchy=0):
        """Synchronously all-gather a list of partitioned parameters into full tensors.

        Each rank packs its local shards (``param.ds_tensor``) for every param in
        ``param_list`` back-to-back into its own slice of one flat buffer. A single
        blocking ``dist.all_gather`` fills the other ranks' slices. Each param is
        then reassembled from the per-rank slices into a tensor of shape
        ``param.ds_shape``, which is assigned to ``param.data``.

        Args:
            param_list: parameters to gather. The flat buffer is allocated with
                ``param_list[0].dtype`` (presumably all params share one dtype).
            hierarchy: not used by this implementation; kept for interface
                compatibility.

        Returns:
            None. ``param.data`` of every param is replaced in place.
        """
        if len(param_list) == 0:
            return

        # Number of shard elements this rank contributes across all params.
        partition_size = sum(param.ds_tensor.ds_numel for param in param_list)

        tensor_size = partition_size * self.world_size
        flat_tensor = torch.empty(tensor_size,
                                  dtype=param_list[0].dtype,
                                  device=self.local_device)
        # Previously misspelled ``requres_grad``, which set an unrelated attribute.
        flat_tensor.requires_grad = False

        # One view per rank into the flat buffer; rank i's shards go in partitions[i].
        partitions = [
            flat_tensor.narrow(0, partition_size * i, partition_size)
            for i in range(self.world_size)
        ]

        # Pack this rank's shards contiguously into its own slice.
        offset = 0
        for param in param_list:
            param_numel = param.ds_tensor.ds_numel
            partitions[self.rank].narrow(0,
                                         offset,
                                         param_numel).copy_(param.ds_tensor.data)
            offset += param_numel

        dist.all_gather(partitions,
                        partitions[self.rank],
                        group=self.ds_process_group,
                        async_op=False)

        # Unpack: a param's shard from rank i sits at ``param_offset`` inside
        # partitions[i] and belongs at ``i * param_partition_size`` in the full tensor.
        param_offset = 0
        for param in param_list:
            param_partition_size = param.ds_tensor.ds_numel
            param_size = param.ds_numel
            replicated_tensor = torch.empty(param.ds_shape,
                                            dtype=param.dtype,
                                            device=self.local_device)
            flat_replica = replicated_tensor.view(-1)

            for i in range(self.world_size):
                param_start = i * param_partition_size
                # Trailing ranks may hold only padding for this param; skip them.
                if param_start < param_size:
                    numel_to_copy = min(param_size - param_start, param_partition_size)
                    part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy)
                    flat_replica.narrow(0, param_start, numel_to_copy).copy_(part_to_copy)

            param_offset += param_partition_size
            param.data = replicated_tensor.data

        return None

    def _reduce_scatter_gradients(self, param_list):
        """Reduce-scatter the full gradient of every param in ``param_list``.

        Each ``param.grad`` must hold exactly ``param.ds_numel`` elements. All
        reduce-scatters are launched first (via ``self._reduce_scatter_gradient``)
        and only then waited on.

        If this rank's partition straddles the end of a gradient, the reduced
        values were written into a separate padded buffer rather than into
        ``param.grad``. In that case the valid prefix is copied back into
        ``param.grad``.
        """
        launched = []
        for p in param_list:
            assert p.grad.numel() == p.ds_numel, f"{p.grad.numel()} != {p.ds_numel} Cannot reduce scatter gradients whose size is not same as the params"
            launched.append(self._reduce_scatter_gradient(p))

        for p, (work, reduced) in zip(param_list, launched):
            if work is not None:
                work.wait()

            shard_numel = p.ds_tensor.ds_numel
            shard_begin = self.rank * shard_numel
            shard_end = shard_begin + shard_numel

            # Only a partition that runs past the gradient's end needs a copy-back;
            # partitions fully inside were reduced directly into views of the grad.
            if shard_begin < p.ds_numel < shard_end:
                valid = p.ds_numel - shard_begin
                target = p.grad.view(-1).narrow(0, shard_begin, valid)
                target.copy_(reduced.narrow(0, 0, valid))

    def _reduce_scatter_gradient(self, param):
        """Launch an asynchronous reduce-scatter of ``param.grad``.

        The flattened gradient is split into ``self.world_size`` chunks of
        ``param.ds_tensor.ds_numel`` elements each. A chunk that lies fully inside
        the gradient is a view of ``param.grad``. A chunk that runs past the end of
        the gradient is a zero-filled buffer holding whatever valid elements remain.

        Args:
            param: partitioned parameter whose ``grad`` holds the full gradient.

        Returns:
            tuple: ``(handle, output)``. ``handle`` is the async work handle from
            ``dist.reduce_scatter``. ``output`` is this rank's chunk, which is
            passed as the reduce-scatter output tensor.
        """
        partition_size = param.ds_tensor.ds_numel
        flat_grad = param.grad.view(-1)

        input_list = []
        for i in range(self.world_size):
            start = i * partition_size
            end = start + partition_size

            if start < param.ds_numel and end <= param.ds_numel:
                # Chunk fits inside the gradient: reduce directly into a view of it.
                chunk = flat_grad.narrow(0, start, partition_size)
            else:
                # Chunk runs past the gradient (padding): use a zeroed buffer and
                # copy in any valid trailing elements.
                chunk = torch.zeros(partition_size,
                                    dtype=param.dtype,
                                    device=param.device)
                if start < param.ds_numel:
                    elements = param.ds_numel - start
                    chunk.narrow(0, 0, elements).copy_(flat_grad.narrow(0, start, elements))
            input_list.append(chunk)

        rank = dist.get_rank(group=self.ds_process_group)
        handle = dist.reduce_scatter(input_list[rank],
                                     input_list,
                                     group=self.ds_process_group,
                                     async_op=True)

        return handle, input_list[rank]

    def _partition_gradients(self, param_list, partition_buffers=None, accumulate=False):
        """Partition the gradient of each param, pairing params with buffers in order.

        ``partition_buffers`` defaults to one ``None`` per param, which lets each
        call allocate its own buffer. Pairing uses ``zip``, so it stops at the
        shorter of the two sequences.
        """
        buffers = [None] * len(param_list) if partition_buffers is None else partition_buffers
        for p, buf in zip(param_list, buffers):
            self._partition_gradient(p, partition_buffer=buf, accumulate=accumulate)

    def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
        """Replace ``param.grad`` with this rank's partition of the full gradient.

        This rank's slice of the flattened gradient starts at
        ``rank * param.ds_tensor.ds_numel``. It is copied (or, with
        ``accumulate=True``, added) into the front of ``partition_buffer``.
        ``param.grad.data`` is then pointed at the first ``param.ds_tensor.ds_numel``
        elements of that buffer.

        Args:
            param: partitioned parameter whose ``grad`` holds the full gradient.
            partition_buffer: optional destination with at least
                ``param.ds_tensor.ds_numel`` elements. When omitted, a zero-filled
                buffer is allocated on ``param.device``.
            accumulate: add into ``partition_buffer`` instead of overwriting it.
                Requires ``partition_buffer``.
        """
        print_rank_0(
            f"Partitioning param {param.ds_id} gradient of size {param.grad.numel()} type {param.grad.dtype} part_size {param.ds_tensor.ds_numel}"
        )
        see_memory_usage("Before partitioning gradients", force=False)
        partition_size = param.ds_tensor.ds_numel

        if partition_buffer is None:
            assert not accumulate, "No buffer to accumulate to"
            partition_buffer = torch.zeros(partition_size,
                                           dtype=param.dtype,
                                           device=param.device)
        else:
            assert partition_buffer.numel(
            ) >= partition_size, f"The partition buffer size {partition_buffer.numel()} should match the size of param.ds_tensor {partition_size}"

        rank = dist.get_rank(group=self.ds_process_group)
        start = partition_size * rank

        dest_tensor_full_buffer = partition_buffer.view(-1).narrow(0, 0, partition_size)

        # A rank whose slice starts past the end of the gradient holds only padding.
        # NOTE(review): with a caller-supplied buffer, only the first `elements`
        # entries are written here; any padding tail keeps its previous contents.
        # Confirm downstream ignores it.
        if start < param.ds_numel:
            elements = min(param.ds_numel - start, partition_size)

            dest_tensor = dest_tensor_full_buffer.narrow(0, 0, elements)
            src_tensor = param.grad.view(-1).narrow(0, start, elements)

            if not accumulate:
                # Plain copy of the grad partition into the buffer.
                dest_tensor.copy_(src_tensor)
            elif src_tensor.device == dest_tensor.device:
                # Same device: accumulate in place.
                dest_tensor.add_(src_tensor)
            else:
                # Different devices: stage the sum in a temporary on param.device,
                # then copy it back. Per the original authors this is faster when
                # src is GPU and dest is CPU, since adding directly on CPU is very slow.
                acc_tensor = torch.empty(src_tensor.numel(),
                                         dtype=param.dtype,
                                         device=param.device)
                acc_tensor.copy_(dest_tensor)
                acc_tensor.add_(src_tensor)
                dest_tensor.copy_(acc_tensor)

        param.grad.data = dest_tensor_full_buffer.data
        see_memory_usage("After partitioning gradients", force=False)


class GatheredParameters:
    def __init__(self, params, modifier_rank=None, fwd_module=None, enabled=True):
        """A context that collects parameters that were partitioned via a
        :class:`deepspeed.zero.Init` context. When ``modifier_rank`` is given, the
        parameters are broadcast from that rank and partitioned again upon exit.

        Args:
            params (``torch.nn.Parameter``): A single parameter, a list, or a tuple of parameters to collect.
                The context is a no-op unless at least one of them is a zero param.
                Only parameters carrying a ``ds_id`` are gathered.
            modifier_rank (int, optional): If specified, this rank's parameter will be
                broadcast on exit from the context. This argument is required if ``params`` are
                modified, so that all processes have a consistent view of the data. When the
                parameters' ``ds_process_group`` is not the world group, this is a rank within
                that group and is converted to the corresponding global rank. Defaults
                to ``None``.
            fwd_module (``torch.nn.Module``, optional): If specified, ``params`` will be
                registered as external parameters of ``fwd_module``. See :meth:`deepspeed.zero.register_external_parameter`.
            enabled (bool, optional): If ``False``, this context is a no-op. Defaults to ``True``.

        Important: Make sure to use ``modifier_rank`` that is not ``None`` (e.g., ``modifier_rank=0``)
        if you need the GPU memory allocated by gather to be released upon exit from the context manager.
        With ``modifier_rank=None``, nothing is broadcast and the parameters are not re-partitioned on exit.

        Examples
        ========

        #. Allocate a partitioned module, initialize its weight on rank 0, and update all
           processes.

            .. code-block:: python

                with deepspeed.zero.Init():
                    linear = torch.nn.Linear(1000,1000)

                with deepspeed.zero.GatheredParameters(linear.weight,
                                                       modifier_rank=0):
                    if deepspeed.comm.get_rank() == 0:
                        linear.weight.zero_()

        #. Collect a partitioned weight to pass to another module during
           training. The parameter will be registered as an external parameter
           and made available during the backward pass.

            .. code-block:: python
                :emphasize-lines: 6

                def forward(self, input):
                    x = self.layer1(input)

                    # self.layer1.weight is required by self.layer2.forward
                    with deepspeed.zero.GatheredParameters(self.layer1.weight,
                                                           fwd_module=self):
                        y = self.layer2(x, self.layer1.weight)
                    return y


        #. Pretrained model loading

            .. code-block:: python

                with deepspeed.zero.Init():
                    model = MyModel()

                state_dict = torch.load(model_path, map_location="cpu")

                def load(module: torch.nn.Module, prefix=""):
                    # because zero3 puts placeholders in model params, this context
                    # manager gathers (unpartitions) the params of the current layer, then loads from
                    # the state dict and then re-partitions them again
                    with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
                        if deepspeed.comm.get_rank() == 0:
                            module._load_from_state_dict(state_dict, prefix)

                    for name, child in module._modules.items():
                        if child is not None:
                            load(child, prefix + name + ".")

                load(model, prefix="")

        If this approach is not used, then the full model will first be copied to each GPU. For models
        bigger than the memory of a single GPU, this method is required.
        """

        self.enabled = enabled
        if not enabled:
            return

        if not isinstance(params, (list, tuple)):
            params = [params]

        # enable if at least one is zero-param, otherwise a noop
        if not any(is_zero_param(p) for p in params):
            self.enabled = False
            return

        # Only ZeRO-managed params (those carrying a ``ds_id``) take part.
        self.params = [p for p in params if hasattr(p, "ds_id")]
        self.src_rank = None
        if modifier_rank is not None:
            if self.params[0].ds_process_group == dist.get_world_group():
                self.src_rank = modifier_rank
            else:
                # A group was specified; convert DP rank to global rank
                self.src_rank = dist.get_global_rank(self.params[0].ds_process_group,
                                                     modifier_rank)
        self.fwd_module = fwd_module
        if self.fwd_module is not None:
            # is a no-op if already registered
            for p in self.params:
                register_external_parameter(self.fwd_module, p)

    def __enter__(self):
        """Gather the full data of all collected parameters (no-op when disabled)."""
        if not self.enabled:
            return
        self.params[0].all_gather(param_list=self.params)

    def __exit__(self, *exc):
        """Broadcast from ``src_rank`` and re-partition; exceptions are not suppressed."""
        if not self.enabled:
            return
        if self.src_rank is None:
            # No modifier rank: skip broadcast and re-partitioning, so the gathered
            # parameters stay materialized (see the note in ``__init__``).
            return

        handles = [
            dist.broadcast(p,
                           self.src_rank,
                           group=p.ds_process_group,
                           async_op=True) for p in self.params
        ]
        for h in handles:
            h.wait()
        self.params[0].partition(param_list=self.params, has_been_updated=True)