From 4912e0ad7efcaf97389ae944259aa0e9f331038a Mon Sep 17 00:00:00 2001 From: Justin Chiu <31414860+jfc4050@users.noreply.github.com> Date: Thu, 20 Jan 2022 18:14:13 -0800 Subject: [PATCH] Various ZeRO Stage3 Optimizations + Improvements (including bfloat16 support) (#1453) * Changes for bfloat16 Zero2 * ZeRO stage3 optimizations, with some bug fixes optimizations for stage3: - prefetching improvements - batching allgather calls to amortize fixed overhead and improve bandwidth utilization - batching reduce_scatter calls to amortize fixed overhead and improve bandwidth utilization - using *_base variants of allgather and reduce scatter to reduce memory allocations and data movement - more fine grained synchronization for communication that allows blocking on less work - precomputation of fetching code - using a fetch queue rather than deciding what to (pre)fetch at each iteration - limiting queued coalesced communication ops to reduce memory pressure on pytorch cuda caching allocator (not elegant solution) optimizations for stage3-offload: - made some host-device tensor copies async to improve performance bug fixes and qol improvements: - fix init context method when parent modules modify child weights - speed up model initialization by moving model to GPU before weight initialization - fixed unit test imports so that unit tests can be run from any directory - change performance logging to include memory consumption - add logging w/ model size when done partitioning model new features - bfloat16 support for ZeRO 3 * fix import in ut * ran yapf * improvements to cache flush warn log * backwards compatibility with older versions of pytorch * handle edge case where reduced tensor smaller than world size * moved event synchronization to allgather handle wait() call * removed unnecessary barrier call * formatting fix after resolving merge conflict * skip nvme prefetch when trace not complete * opportunistically avoid memory allocation in allgather coalesced where possible * fix indentation after merge * fixes to account for parameter offload * accounting for torch.cuda.memory_stats not being available * moved partition_all_params to optimizer step * allgathering on params before item gets called * fix param status checks needed after moving partition_all_parameters call to optimizer step * fix grad accumulation with optimizer offload * grad norm computation fix for optimizer offload * change post divide in reduce-scatter to pre divide * fix gradient race condition w/ optimizer offload * improve inf/nan gradient tracking * don't prefetch when not in training mode * format fix after merging * fix prefetching issue when using NVME offload * improved defragmentation for fp16 parameters * relative imports for bf16 tests * changes for bwd compatibility with pytorch 1.2 * remove buffered_reduce_fallback * removed unused parameter offset bookkeeping * fixed tracking for multiple param groups * unbroke bfloat16 config after merge conflict * using base allgather params when only 1 param * cleanup/fixes for fp16 partition defragmentation * switch to CRLF * convert to same new-line style as master * align new line with master * Fix merge issues * switch to CRLF * fix to LF line endings * minor merge fixes * remove extra bfloat16_enabled definition * asserting params inflight for AllGatherHandle * remove get_cuda_mem_allocated_str * Format fixes * fix bfloat16 zero stage check (broken after merge commit) * +self.communication_data_type, -self.allreduce_always_fp32; delete dead code * Add self.reduce_scatter * Format fix * Fix merge issues * iterate over params_to_fetch rather than make another iterator * add some TODOs * remove unnecessary division by micro_step_id * rename config keys "bfloat16" -> "bf16" * rename stage3_gather_fp16_weights_on_model_save -> stage3_gather_16bit_weights_on_model_save * add unit test to check backwards compatibility for gather_16bit_weights * added test to confirm bf16 key bwd compatibility * Format fixes Co-authored-by: Rana Ali Amjad Co-authored-by: Justin Chiu Co-authored-by: Olatunji Ruwase Co-authored-by: Jeff Rasley --- DeepSpeedExamples | 2 +- .../config_templates/template_zero3.json | 2 +- .../runtime/comm/coalesced_collectives.py | 116 ++ deepspeed/runtime/config.py | 14 +- deepspeed/runtime/constants.py | 5 +- deepspeed/runtime/engine.py | 52 +- deepspeed/runtime/utils.py | 9 + deepspeed/runtime/zero/config.py | 16 +- deepspeed/runtime/zero/constants.py | 7 +- .../runtime/zero/partition_parameters.py | 422 ++++- deepspeed/runtime/zero/stage3.py | 1620 ++++++++--------- deepspeed/runtime/zero/utils.py | 36 + deepspeed/utils/__init__.py | 1 + deepspeed/utils/nvtx.py | 15 + deepspeed/utils/timer.py | 14 +- docs/_pages/config-json.md | 8 +- docs/_tutorials/zero.md | 6 +- docs/code-docs/source/training.rst | 2 +- tests/unit/__init__.py | 0 tests/unit/megatron_model.py | 3 +- tests/unit/test_activation_checkpointing.py | 2 +- tests/unit/test_adamw.py | 4 +- tests/unit/test_aio.py | 2 +- tests/unit/test_autotuning.py | 2 +- tests/unit/test_bf16.py | 18 +- tests/unit/test_checkpointing.py | 6 +- tests/unit/test_coalesced_collectives.py | 62 + tests/unit/test_config.py | 35 +- tests/unit/test_configurable_parallel.py | 8 +- tests/unit/test_cuda_backward.py | 11 +- tests/unit/test_cuda_forward.py | 5 +- tests/unit/test_curriculum_learning.py | 4 +- tests/unit/test_data.py | 4 +- tests/unit/test_dist.py | 2 +- tests/unit/test_ds_initialize.py | 6 +- tests/unit/test_dynamic_loss_scale.py | 4 +- tests/unit/test_elastic.py | 4 +- tests/unit/test_flops_profiler.py | 4 +- tests/unit/test_fp16.py | 6 +- tests/unit/test_ignore_unused_parameters.py | 4 +- tests/unit/test_lr_schedulers.py | 4 +- tests/unit/test_moe.py | 6 +- tests/unit/test_multi_output_model.py | 6 +- tests/unit/test_onebit.py | 6 +- tests/unit/test_partition.py | 2 +- tests/unit/test_pipe.py | 4 +- tests/unit/test_pipe_module.py | 4 +- tests/unit/test_pld.py | 5 +- tests/unit/test_runtime_utils.py | 2 +- tests/unit/test_sparse_grads.py | 3 +- tests/unit/test_topology.py | 2 +- tests/unit/test_zero.py | 762 +++++++- tests/unit/test_zero_context.py | 3 +- 53 files changed, 2295 insertions(+), 1057 deletions(-) create mode 100644 deepspeed/runtime/comm/coalesced_collectives.py create mode 100644 deepspeed/utils/nvtx.py create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/test_coalesced_collectives.py diff --git a/DeepSpeedExamples b/DeepSpeedExamples index 1fed12e8..174ae3bc 160000 --- a/DeepSpeedExamples +++ b/DeepSpeedExamples @@ -1 +1 @@ -Subproject commit 1fed12e8b375b0c54902827e7140d8266dfccd59 +Subproject commit 174ae3bc8dbb688cfaccb4afa15d6e2cdbe19ce5 diff --git a/deepspeed/autotuning/config_templates/template_zero3.json b/deepspeed/autotuning/config_templates/template_zero3.json index e00f47f6..620d7eb1 100644 --- a/deepspeed/autotuning/config_templates/template_zero3.json +++ b/deepspeed/autotuning/config_templates/template_zero3.json @@ -11,7 +11,7 @@ "stage3_max_reuse_distance": 1e9, "stage3_prefetch_bucket_size": 5e8, "stage3_param_persistence_threshold": 1e6, - "stage3_gather_fp16_weights_on_model_save": false, + "stage3_gather_16bit_weights_on_model_save": false, "sub_group_size": 1e12 } } diff --git a/deepspeed/runtime/comm/coalesced_collectives.py b/deepspeed/runtime/comm/coalesced_collectives.py new file mode 100644 index 00000000..880a3cc4 --- /dev/null +++ b/deepspeed/runtime/comm/coalesced_collectives.py @@ -0,0 +1,116 @@ +"""batched collective operations for overhead amortization and better +bandwidth utilization""" + +import math +from typing import List + +import torch +from torch import Tensor +import torch.distributed +from torch.distributed import ProcessGroup +import torch.nn.functional + +from deepspeed.utils import instrument_w_nvtx +from deepspeed.utils.logging import logger + +if hasattr(torch.distributed, "_reduce_scatter_base"): + + def torch_reduce_scatter_fn(input_tensor: Tensor, output_tensor: Tensor, group): + instrument_w_nvtx(torch.distributed._reduce_scatter_base)( + output_tensor, + input_tensor, + group=group, + ) +else: + logger.warning( + "unable to find torch.distributed._reduce_scatter_base. will fall back to " + "torch.distributed.reduce_scatter which will result in suboptimal performance. " + "please consider upgrading your pytorch installation.") + + def torch_reduce_scatter_fn(input_tensor: Tensor, output_tensor: Tensor, group): + input_tensor_lst = list( + torch.chunk(input_tensor, + torch.distributed.get_world_size(group))) + instrument_w_nvtx(torch.distributed.reduce_scatter)( + output_tensor, + input_tensor_lst, + group=group, + ) + + +@instrument_w_nvtx +@torch.no_grad() +def reduce_scatter_coalesced( + tensors: List[Tensor], + group: ProcessGroup = None, +) -> List[Tensor]: + """simultaneously reduce-scatter a list of tensors - this can be done more + efficiently than individual reduce scatter calls + + TODO. see if PyTorch team wants a c++ verson of this for ProcessGroupNCCL + """ + this_rank = torch.distributed.get_rank(group) + world_sz = torch.distributed.get_world_size(group) + + partition_lst_for_each_tensor = [None] * len(tensors) + for tensor_idx, tensor in enumerate(tensors): + flattened_tensor = tensor.view(-1) + chunk_sz = math.ceil(tensor.numel() / world_sz) + partition_lst_for_each_tensor[tensor_idx] = [ + flattened_tensor[rank * chunk_sz:rank * chunk_sz + chunk_sz] + for rank in range(0, + world_sz) + ] + + padded_partition_sz_for_each_tensor = tuple( + math.ceil(t.numel() / world_sz) for t in tensors) + + if len(tensors) == 1 and tensors[0].numel() % world_sz == 0: + # if there's only one tensor being reduced and we don't need to pad + # we have an opportunity to avoid a memory allocation + tensor_partition_flat_buffer = tensors[0].view(-1) + else: + # interleave tensor partitions such that the correct reduced partitions of each tensor + # end up at each rank + tensor_partitions_lst_with_padding = [] + for rank in range(world_sz): + for tensor_idx in range(len(tensors)): + # add tensor content + tensor_chunk = partition_lst_for_each_tensor[tensor_idx][rank] + tensor_partitions_lst_with_padding.append(tensor_chunk) + + # add padding if necessary + padding_sz = padded_partition_sz_for_each_tensor[ + tensor_idx] - tensor_chunk.numel() + if padding_sz > 0: + tensor_partitions_lst_with_padding.append( + torch.empty(padding_sz, + dtype=tensor_chunk.dtype, + device=tensor_chunk.device)) + + tensor_partition_flat_buffer = instrument_w_nvtx( + torch.cat)(tensor_partitions_lst_with_padding) + + tensor_partition_flat_buffer.div_(world_sz) # pre-divide + tensor_partition_buffer_for_each_rank: List[Tensor] = torch.chunk( + tensor_partition_flat_buffer, + world_sz) + + # batched reduce-scatter call + torch_reduce_scatter_fn(tensor_partition_flat_buffer, + tensor_partition_buffer_for_each_rank[this_rank], + group) + + # reverse procedure of the interleaving done previously, done on the + # result of the batched reduce-scatter + output_lst: List[Tensor] = [None] * len(tensors) + offset = 0 + for tensor_idx in range(len(tensors)): + output_lst[tensor_idx] = tensor_partition_buffer_for_each_rank[this_rank].narrow( + 0, + offset, + partition_lst_for_each_tensor[tensor_idx][this_rank].numel()) + + offset += padded_partition_sz_for_each_tensor[tensor_idx] + + return output_lst diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 5cb6deb9..81db59b5 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -134,12 +134,12 @@ def get_fp16_enabled(param_dict): def get_bfloat16_enabled(param_dict): - if BFLOAT16 in param_dict.keys(): - return get_scalar_param(param_dict[BFLOAT16], - BFLOAT16_ENABLED, - BFLOAT16_ENABLED_DEFAULT) - else: - return False + for key in [BFLOAT16, BFLOAT16_OLD]: + if key in param_dict.keys(): + return get_scalar_param(param_dict[key], + BFLOAT16_ENABLED, + BFLOAT16_ENABLED_DEFAULT) + return False def get_fp16_master_weights_and_grads_enabled(param_dict): @@ -899,7 +899,7 @@ class DeepSpeedConfig(object): self.fp16_enabled = get_fp16_enabled(param_dict) self.bfloat16_enabled = get_bfloat16_enabled(param_dict) assert not (self.fp16_enabled and self.bfloat16_enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled' - assert not (self.bfloat16_enabled and (self.zero_optimization_stage != 2)), 'bfloat16 mode is only enabled for Zero2 currently' + assert not (self.bfloat16_enabled and (self.zero_optimization_stage not in {2, 3})), f'bfloat16 mode is only enabled for Zero 2 and 3 currently. got {self.zero_optimization_stage}' self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled( param_dict) self.amp_enabled = get_amp_enabled(param_dict) diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index e524a460..84a8325b 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -114,11 +114,12 @@ SPARSE_GRADIENTS_DEFAULT = False # Users can configure in ds_config.json as below example: BFLOAT16_FORMAT = ''' BFLOAT16 parameters should be of the format: -"bfloat16": { +"bf16": { "enabled": true } ''' -BFLOAT16 = "bfloat16" +BFLOAT16 = "bf16" +BFLOAT16_OLD = "bfloat16" # keeping for backwards compatibility BFLOAT16_ENABLED = "enabled" BFLOAT16_ENABLED_DEFAULT = False diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 6271ee88..7d22e103 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -47,7 +47,7 @@ from deepspeed.runtime.sparse_tensor import SparseTensor import deepspeed.runtime.lr_schedules as lr_schedules import deepspeed.utils.groups as groups from deepspeed.runtime.utils import get_grad_norm -from deepspeed.utils import logger, log_dist, init_distributed +from deepspeed.utils import logger, log_dist, init_distributed, instrument_w_nvtx from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer from deepspeed.utils.debug import debug_extract_module_and_param_names from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop @@ -706,8 +706,8 @@ class DeepSpeedEngine(Module): def zero_param_persistence_threshold(self): return self._config.zero_config.param_persistence_threshold - def zero_gather_fp16_weights_on_model_save(self): - return self._config.zero_config.gather_fp16_weights_on_model_save + def zero_gather_16bit_weights_on_model_save(self): + return self._config.zero_config.gather_16bit_weights_on_model_save def zero_grad_hooks(self): return self._config.zero_config.grad_hooks @@ -969,6 +969,16 @@ class DeepSpeedEngine(Module): self.broadcast_src_rank, group=self.data_parallel_group) + @staticmethod + def __check_params(model: Module, dtype: torch.dtype) -> None: + if not all(param.dtype == dtype + for param in model.parameters()) and dist.get_rank() == 0: + raise ValueError( + f"{dtype} is enabled but the following parameters have dtype that is " + f"not {dtype}: " + f"{[(n, p.dtype) for n, p in model.named_parameters() if p.dtype != dtype]}" + ) + def _configure_distributed_model(self, model): self.module = model if self.fp16_enabled(): @@ -986,17 +996,13 @@ class DeepSpeedEngine(Module): ) self.module.half() elif self.bfloat16_enabled(): + if self.zero_optimization_partition_weights() and any( + hasattr(param, + 'ds_id') for param in self.module.parameters()): + self.__check_params(self.module, torch.bfloat16) self.module.bfloat16() else: - if not all( - [param.dtype == torch.float for param in self.module.parameters()]): - names = [ - n for n, - p in self.module.named_parameters() if p.dtype != torch.float - ] - raise ValueError( - f"fp32 is enabled but the following parameters have dtype that is not fp32: {', '.join(names)}" - ) + self.__check_params(self.module, torch.float) if not self.dont_change_device: self.module.to(self.device) @@ -1542,6 +1548,7 @@ class DeepSpeedEngine(Module): return scaled_loss + @instrument_w_nvtx def forward(self, *inputs, **kwargs): r"""Execute forward propagation Arguments: @@ -1637,6 +1644,7 @@ class DeepSpeedEngine(Module): f"rank={torch.distributed.get_rank()} time (ms) | forward: {fwd_time:.2f} (forward_moe: {moe_time:.2f}, 1st alltoall: {falltoall:.2f}, 2nd alltoall: {salltoall:.2f}, top-k: {gate_time:.2f})", ranks=[0]) + @instrument_w_nvtx def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): # Pass (PP) gas boundary flag to optimizer (required for zero) self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary( @@ -1654,6 +1662,7 @@ class DeepSpeedEngine(Module): else: self.buffered_allreduce_fallback(elements_per_buffer=bucket_size) + @instrument_w_nvtx def backward(self, loss, allreduce_gradients=True, release_loss=False): r"""Execute backward pass on the loss @@ -3013,7 +3022,7 @@ class DeepSpeedEngine(Module): self._copy_recovery_script(save_path) logger.info('zero checkpoint saved {}'.format(zero_checkpoint_name)) - def _zero3_consolidated_fp16_state_dict(self): + def _zero3_consolidated_16bit_state_dict(self): """ Get a full non-partitioned state_dict with fp16 weights on cpu. @@ -3082,9 +3091,14 @@ class DeepSpeedEngine(Module): return state_dict def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"): - r"""Save fp16 model weights + """has been renamed to save_16bit_model, keeping this around for backwards + compatibility""" + return self.save_16bit_model(save_dir, save_filename) + + def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin"): + r"""Save 16bit model weights - This method saves the fp16 model weights at the desired destination. + This method saves the 16bit model weights at the desired destination. Arguments: save_dir: Required. Directory for saving the model @@ -3092,7 +3106,7 @@ class DeepSpeedEngine(Module): Returns: ``True`` when a model has been saved, ``False`` otherwise. It will not be saved if - stage3_gather_fp16_weights_on_model_save is ``False``. + stage3_gather_16bit_weights_on_model_save is ``False``. Important: all processes must call this method and not just the process with rank 0. It is because the processes need to work in sync to gather the weights. This method will hang @@ -3103,13 +3117,13 @@ class DeepSpeedEngine(Module): path = os.path.join(save_dir, save_filename) if self.zero_optimization_partition_weights(): - if self.zero_gather_fp16_weights_on_model_save(): + if self.zero_gather_16bit_weights_on_model_save(): # consolidation is expensive in time and memory and therefore isn't a default - state_dict = self._zero3_consolidated_fp16_state_dict() + state_dict = self._zero3_consolidated_16bit_state_dict() else: # the model will be bogus if not consolidated so don't confuse the user by saving it logger.info( - f"Did not save the model {path} because `stage3_gather_fp16_weights_on_model_save` is False" + f"Did not save the model {path} because `stage3_gather_16bit_weights_on_model_save` is False" ) return False else: diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 22014aaf..2decd12f 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -858,3 +858,12 @@ def call_to_str(base, *args, **kwargs): name += ', '.join(f'{key}={repr(arg)}' for key, arg in kwargs.items()) name += ')' return name + + +def get_only_unique_item(items): + item_set = set(items) + if len(item_set) != 1: + raise RuntimeError(f"expected there to be only one unique element in {items}") + unique_item, = item_set + + return unique_item diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 782d4d9e..3804fb50 100755 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -36,7 +36,7 @@ class DeepSpeedZeroConfig(DeepSpeedConfigObject): self.param_persistence_threshold = None self.max_live_parameters = None self.max_reuse_distance = None - self.gather_fp16_weights_on_model_save = None + self.gather_16bit_weights_on_model_save = None self.ignore_unused_parameters = None self.round_robin_gradients = None @@ -171,10 +171,16 @@ class DeepSpeedZeroConfig(DeepSpeedConfigObject): ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD, ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT) - self.gather_fp16_weights_on_model_save = get_scalar_param( - zero_config_dict, - ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE, - ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT) + # config key has been renamed to use "16bit" instead of "fp16." falling back + # to old config name in order to preserve backwards compatibility + self.gather_16bit_weights_on_model_save = ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE_DEFAULT + for key in [ + ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE, + ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE + ]: + if key in zero_config_dict: + self.gather_16bit_weights_on_model_save = zero_config_dict[key] + break self.ignore_unused_parameters = get_scalar_param( zero_config_dict, diff --git a/deepspeed/runtime/zero/constants.py b/deepspeed/runtime/zero/constants.py index 69f3eea6..804f15ba 100755 --- a/deepspeed/runtime/zero/constants.py +++ b/deepspeed/runtime/zero/constants.py @@ -116,7 +116,8 @@ ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT = 100000 # gathers params for saving a model - inefficient but is required in certain situations ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE = 'stage3_gather_fp16_weights_on_model_save' -ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT = False +ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE = 'stage3_gather_16bit_weights_on_model_save' +ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE_DEFAULT = False # Now just used in stage2 complete_grad_norm_calculation_for_cpu_offload # Enable this option to avoid: @@ -164,8 +165,8 @@ ZERO_OPTIMIZATION_DEFAULT = { ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE_DEFAULT, ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD: ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT, - ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE: - ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT, + ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE: + ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE_DEFAULT, ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS: ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS_DEFAULT, ZERO_OPTIMIZATION_LEGACY_STAGE1: diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 2ad72280..42d56227 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -3,29 +3,68 @@ Licensed under the MIT license. """ +import math import os import time import types +from typing import Callable, Iterable from enum import Enum import functools import itertools +from typing import List import torch -from torch.distributed.distributed_c10d import _get_global_rank, group +from torch import Tensor import torch.distributed as dist +from torch.distributed.distributed_c10d import _get_global_rank, group +from torch.nn import Module +from torch.nn import Parameter from .linear import LinearModuleForZeroStage3, LinearFunctionForZeroStage3 from .offload_constants import * -from ..utils import see_memory_usage -from deepspeed.utils import log_dist, init_distributed, logger +from ..utils import get_only_unique_item, see_memory_usage +from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks +from deepspeed.utils import init_distributed, instrument_w_nvtx, logger from deepspeed.utils.debug import debug_param2name_id_shape, debug_param2name_id_shape_device, debug_module2name, debug_param2name, debug_param2name_id_shape_status, printflock, log_rank_file +from deepspeed.utils.logging import logger from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper, PartitionedParamStatus from ..config import DeepSpeedConfig param_count = 0 -partitioned_param_data_shape = [1] +partitioned_param_data_shape = [0] + +if hasattr(torch.distributed, "_all_gather_base"): + + def torch_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group): + try: + return instrument_w_nvtx(torch.distributed._all_gather_base)( + output_tensor, + input_tensor, + group=group, + async_op=True, + ) + except RuntimeError as e: + raise RuntimeError( + f"output_tensor: {output_tensor.device}, input_tensor: {input_tensor.device}" + ) from e +else: + logger.warning( + "unable to find torch.distributed._all_gather_base. will fall back to " + "torch.distributed.all_gather which will result in suboptimal performance. " + "please consider upgrading your pytorch installation.") + + def torch_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group): + output_tensors = list( + torch.chunk(output_tensor, + torch.distributed.get_world_size(group))) + return instrument_w_nvtx(torch.distributed.all_gather)( + output_tensors, + input_tensor, + group=group, + async_op=True, + ) def print_rank_0(message, debug=False, force=False): @@ -39,6 +78,11 @@ def print_rank_0(message, debug=False, force=False): # log_rank_file(rank, message) +def debug_rank0(msg: str) -> None: + if torch.distributed.get_rank() == 0: + logger.debug(msg) + + def is_zero_param(parameter): if not torch.is_tensor(parameter): return False @@ -160,38 +204,35 @@ class ZeroParamStatus(Enum): _orig_torch_empty = torch.empty +_orig_torch_zeros = torch.zeros +_orig_torch_ones = torch.ones +_orig_torch_full = torch.full -def empty_cuda_tensor_half(*size, **kwargs): - if not 'device' in kwargs.keys(): - kwargs['device'] = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"])) - tensor = _orig_torch_empty(*size, **kwargs) - if tensor.is_floating_point(): - return tensor.half() - else: - return tensor - +def zero_wrapper_for_fp_tensor_constructor(fn: Callable, + target_fp_dtype: torch.dtype) -> Callable: + def wrapped_fn(*args, **kwargs) -> Tensor: + if kwargs.get("device", None) is None: + kwargs['device'] = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"])) + tensor: Tensor = fn(*args, **kwargs) + if tensor.is_floating_point(): + tensor = tensor.to(target_fp_dtype) -def new_cuda_tensor_half(cls, *args): - device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"])) - tensor = torch.ones((1, 1), device=device).new_empty(*args).half() - if tensor.is_floating_point(): - return tensor.half() - else: return tensor + return wrapped_fn -def empty_cuda_tensor(*size, **kwargs): - if not 'device' in kwargs.keys(): - kwargs['device'] = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"])) - tensor = _orig_torch_empty(*size, **kwargs) - return tensor +def get_new_tensor_fn_for_dtype(dtype: torch.dtype) -> Callable: + def new_tensor(cls, *args) -> Tensor: + device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"])) + tensor = _orig_torch_empty(0, device=device).new_empty(*args) + if tensor.is_floating_point(): + tensor = tensor.to(dtype) -def new_cuda_tensor(cls, *args): - device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"])) - tensor = torch.ones((1, 1), device=device).new_empty(*args) - return tensor + return tensor + + return new_tensor # https://stackoverflow.com/a/63851681/9201239 @@ -208,6 +249,19 @@ def get_all_subclasses(cls): return set(subclass_list) +@instrument_w_nvtx +def free_param(param: Parameter) -> None: + """Free underlying storage of a parameter.""" + assert not param.ds_active_sub_modules, param.ds_summary() + if param.data.is_cuda: + # need to make sure that we don't free the parameter while it is still + # being used for computation + param.data.record_stream(torch.cuda.current_stream()) + # param.data doesn't store anything meaningful in partitioned state + param.data = torch.empty(0, dtype=param.dtype, device=param.device) + param.ds_status = ZeroParamStatus.NOT_AVAILABLE + + reuse_buffers = False temp_contiguous_tensor = None empty_buffers = {} @@ -224,12 +278,79 @@ class InsertPostInitMethodToModuleSubClasses(object): self.mem_efficient_linear = mem_efficient_linear self.enabled = enabled self._set_dtype(ds_config, dtype) - assert self.dtype in [torch.half, torch.float], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.float]" + assert self.dtype in [torch.half, torch.bfloat16, torch.float], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.bfloat16, torch.float]" def __enter__(self): if not self.enabled: return + def apply_with_gather(orig_module_apply_fn: Callable) -> Callable: + """many models make use of child modules like Linear or Embedding which + perform their own weight initialization in their __init__ methods, + but will then have more weight initialization in a parent module's __init__ + method that modifies weights of child modules, which is typically done + using the Module.apply method. + + since the Init context manager partitions child modules immediately after + they are initialized, without modifying apply we would entirely skip + any initialization done by parent modules. + + to get around this issue, we wrap the function passed to Module.apply + so that the applied function is applied to child modules correctly. + """ + def get_wrapped_fn_to_apply(fn_to_apply: Callable) -> Callable: + if hasattr(fn_to_apply, "wrapped"): + return fn_to_apply + + @functools.wraps(fn_to_apply) + def wrapped_fn_to_apply(module_to_apply_fn_to: Module) -> None: + """gathers parameters before calling apply function. afterwards + parameters are broadcasted to ensure consistency across all ranks + then re-partitioned. + + takes the following steps: + 1. allgathers parameters for the current module being worked on + 2. calls the original function + 3. broadcasts root rank's parameters to the other ranks + 4. re-partitions the parameters + """ + if not all( + is_zero_param(p) + for p in module_to_apply_fn_to.parameters(recurse=False)): + raise RuntimeError( + f"not all parameters for {module_to_apply_fn_to.__class__.__name__}, " + f"were zero params, is it possible that the parameters were " + f"overwritten after they were initialized? " + f"params: {[p for p in module_to_apply_fn_to.parameters(recurse=False)]} " + ) + + params_to_apply_fn_to: Iterable[Parameter] = list( + sorted(module_to_apply_fn_to.parameters(recurse=False), + key=lambda p: p.ds_id)) + + for param in params_to_apply_fn_to: + param.all_gather() + + fn_to_apply(module_to_apply_fn_to) + + for param in params_to_apply_fn_to: + torch.distributed.broadcast(param.data, + 0, + group=param.ds_process_group) + + for param in params_to_apply_fn_to: + param.partition(has_been_updated=True) + + wrapped_fn_to_apply.wrapped = True + + return wrapped_fn_to_apply + + @functools.wraps(orig_module_apply_fn) + def wrapped_apply(module: Module, fn_to_apply: Callable) -> None: + orig_module_apply_fn(module, get_wrapped_fn_to_apply(fn_to_apply)) + + return wrapped_apply + def partition_after(f): @functools.wraps(f) def wrapper(module, *args, **kwargs): @@ -279,18 +400,23 @@ class InsertPostInitMethodToModuleSubClasses(object): # print(f"subclass={subclass.__module__}.{subclass.__qualname__}") _enable_class(subclass) - # holding on to the current __init__subclass__ for exit + # holding onto some methods so we can put them back the way they were in __exit__ torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__ + torch.nn.modules.module.Module._old_apply = torch.nn.modules.module.Module.apply torch.Tensor.__old_new__ = torch.Tensor.__new__ # Replace .__init__() for future subclasses of torch.nn.Module torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass) - if self.dtype == torch.half: - torch.Tensor.__new__ = new_cuda_tensor_half - torch.empty = empty_cuda_tensor_half - else: - torch.Tensor.__new__ = new_cuda_tensor - torch.empty = empty_cuda_tensor + torch.nn.modules.module.Module.apply = apply_with_gather( + torch.nn.modules.module.Module._old_apply) + + torch.Tensor.__new__ = get_new_tensor_fn_for_dtype(self.dtype) + torch.empty = zero_wrapper_for_fp_tensor_constructor(_orig_torch_empty, + self.dtype) + torch.zeros = zero_wrapper_for_fp_tensor_constructor(_orig_torch_zeros, + self.dtype) + torch.ones = zero_wrapper_for_fp_tensor_constructor(_orig_torch_ones, self.dtype) + torch.full = zero_wrapper_for_fp_tensor_constructor(_orig_torch_full, self.dtype) if self.mem_efficient_linear: print_rank_0( @@ -310,11 +436,15 @@ class InsertPostInitMethodToModuleSubClasses(object): for subclass in get_all_subclasses(torch.nn.modules.module.Module): _disable_class(subclass) - # Replace .__init__() for future subclasses of torch.nn.Module + # putting methods back the way we found them torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass + torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply torch.Tensor.__new__ = torch.Tensor.__old_new__ torch.empty = _orig_torch_empty + torch.zeros = _orig_torch_zeros + torch.ones = _orig_torch_ones + torch.full = _orig_torch_full # un doing it here will undo it during training # if self.mem_efficient_linear: @@ -322,6 +452,10 @@ class InsertPostInitMethodToModuleSubClasses(object): # if self.mem_efficient_linear: # torch.nn.functional.linear = self.linear_bk + if torch.distributed.get_rank() == 0: + logger.info("finished initializing model with %.2fB parameters", + param_count / 1e9) + # Now that we cleaned up the metaclass injection, raise the exception. if exc_type is not None: return False @@ -332,11 +466,82 @@ class InsertPostInitMethodToModuleSubClasses(object): def _set_dtype(self, ds_config, dtype): if ds_config is not None and dtype is None: - self.dtype = torch.half if ds_config.fp16_enabled else torch.float - elif dtype is None: - self.dtype = torch.half + if ds_config.bfloat16_enabled and ds_config.fp16_enabled: + raise RuntimeError("bfloat16 and fp16 cannot be enabled at once") + + if ds_config.bfloat16_enabled: + self.dtype = torch.bfloat16 + elif ds_config.fp16_enabled: + self.dtype = torch.half + else: + self.dtype = torch.float else: - self.dtype = dtype + self.dtype = dtype or torch.half + + +class AllGatherHandle: + def __init__(self, handle, param: Parameter) -> None: + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError(f"expected param {param.ds_summary()} to be available") + + self.__handle = handle + self.__param = param + + def wait(self) -> None: + instrument_w_nvtx(self.__handle.wait)() + self.__param.ds_status = ZeroParamStatus.AVAILABLE + + +class AllGatherCoalescedHandle: + def __init__( + self, + allgather_handle, + params: List[Parameter], + partitions: List[Tensor], + world_size: int, + ) -> None: + self.__allgather_handle = allgather_handle + self.__params = params + self.__partitions = partitions + self.__world_size = world_size + self.__complete = False + + for param in self.__params: + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError( + f"expected param {param.ds_summary()} to not be available") + + @instrument_w_nvtx + def wait(self) -> None: + if self.__complete: + return + + instrument_w_nvtx(self.__allgather_handle.wait)() + + # split the single tensor out into individual tensors + param_offset = 0 + for param in self.__params: + assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight" + partitions: List[Tensor] = [] + for rank in range(self.__world_size): + param_start = rank * param.ds_tensor.ds_numel + if param_start < param.ds_numel: + part_to_copy = self.__partitions[rank].narrow( + 0, + param_offset, + min(param.ds_numel - param_start, + param.ds_tensor.ds_numel)) + partitions.append(part_to_copy) + + param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape) + param.ds_status = ZeroParamStatus.AVAILABLE + + for part_to_copy in partitions: + part_to_copy.record_stream(torch.cuda.current_stream()) + + param_offset += param.ds_tensor.ds_numel + + self.__complete = True # Replaces all parameters in module with Scattered Parameters @@ -549,6 +754,14 @@ class Init(InsertPostInitMethodToModuleSubClasses): print_rank_0( f"Partitioning param {debug_param2name_id_shape(param)} module={debug_module2name(module)}" ) + + if param.is_cuda: + torch.distributed.broadcast(param, 0, self.ds_process_group) + else: + if torch.distributed.get_rank() == 0: + logger.warn(f"param in {module.__class__.__name__} " + f"not on GPU so was not broadcasted from rank 0") + param.partition() see_memory_usage( f"Param count {param_count}. After converting and partitioning parmas in {module.__class__.__name__}", @@ -572,12 +785,14 @@ class Init(InsertPostInitMethodToModuleSubClasses): param.ds_tensor = None # Keeps track of how many active sub-modules need this param at any given point in time - param.ds_active_sub_modules = 0 + param.ds_active_sub_modules = set() # If this flag is true, then the parameters are replicated throughput training # And only partitioned before the step param.ds_persist = False + param.is_external_param = False + # The group that the parameter is scattered across. param.ds_process_group = self.ds_process_group @@ -595,6 +810,85 @@ class Init(InsertPostInitMethodToModuleSubClasses): param_list = [cls] return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy) + @instrument_w_nvtx + def all_gather_coalesced(params: Iterable[Parameter], + safe_mode: bool = False) -> AllGatherCoalescedHandle: + + # fetches from nvme if the partition is not available and in nvme + self._ensure_availability_of_partitioned_params(params) + + for param in params: + if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: + raise RuntimeError(param.ds_summary()) + param.ds_status = ZeroParamStatus.INFLIGHT + + # ensure that each rank has params in same order. the allgather + # is done by flattening the parameter list into a single tensor that + # can be allgathered in a single call - this means that if each rank + # gives a list of the same parameters in a different order we will + # silently get incorrect parameter values, and have very difficult + # to debug correctness issues. + params = sorted(params, key=lambda p: p.ds_id) + + debug_rank0(f"-allgather_coalesced: {[p.ds_id for p in params]}") + + if safe_mode: + # ensure that same list (with same ordering) of parameters are + # being allgathered across all ranks, otherwise could mix + # data between tensors. + assert_ints_same_as_other_ranks([p.ds_id for p in params]) + # ensure that tensors from each rank agree on the same ds_numel + # otherwise could mix data between tensors. + assert_ints_same_as_other_ranks([p.ds_tensor.ds_numel for p in params]) + + if len(params) == 1: + # have an opportunity to avoid some intermediate memory allocations + param, = params + param_buffer = torch.empty( + math.ceil(param.ds_numel / self.world_size) * self.world_size, + dtype=param.dtype, + device=torch.cuda.current_device(), + requires_grad=False, + ) + handle = torch_allgather_fn( + param.ds_tensor.to(torch.cuda.current_device()), + param_buffer, + self.ds_process_group, + ) + param.data = param_buffer.narrow(0, + 0, + param.ds_numel).view(param.ds_shape).to( + param.device) + return AllGatherHandle(handle, param) + else: + partition_sz = sum(p.ds_tensor.ds_numel for p in params) + flat_tensor = torch.empty(partition_sz * self.world_size, + dtype=get_only_unique_item(p.dtype + for p in params), + device=torch.cuda.current_device(), + requires_grad=False) + partitions: List[Parameter] = [] + for i in range(self.world_size): + partitions.append( + flat_tensor.narrow(0, + partition_sz * i, + partition_sz)) + + instrument_w_nvtx(torch.cat)( + [p.ds_tensor.to(torch.cuda.current_device()) for p in params], + out=partitions[self.rank]) + + handle = torch_allgather_fn(partitions[self.rank], + flat_tensor, + self.ds_process_group) + + return AllGatherCoalescedHandle( + allgather_handle=handle, + params=params, + partitions=partitions, + world_size=self.world_size, + ) + def partition(param_list=None, hierarchy=0, has_been_updated=False): cls = param print_rank_0( @@ -639,11 +933,37 @@ class Init(InsertPostInitMethodToModuleSubClasses): def partitioned_size(): return self._partitioned_size(param) + def item_override(): + param.all_gather() + return param._orig_item() + + def ds_summary(slf: torch.Tensor) -> dict: + return { + "id": slf.ds_id, + "status": slf.ds_status.name, + "numel": slf.numel(), + "ds_numel": slf.ds_numel, + "shape": tuple(slf.shape), + "ds_shape": tuple(slf.ds_shape), + "requires_grad": slf.requires_grad, + "grad_shape": tuple(slf.grad.shape) if slf.grad is not None else None, + "persist": slf.ds_persist, + "active_sub_modules": slf.ds_active_sub_modules, + } + def convert_to_zero_parameters(param_list): self._convert_to_zero_parameters(param_list) + def allgather_before(func: Callable) -> Callable: + def wrapped(*args, **kwargs): + param.all_gather() + return func(*args, **kwargs) + + return wrapped + # Collectives for gathering and partitioning parameters param.all_gather = all_gather + param.all_gather_coalesced = all_gather_coalesced param.partition = partition # Collective for averaging gradients @@ -654,6 +974,9 @@ class Init(InsertPostInitMethodToModuleSubClasses): param.aligned_size = aligned_size param.padding_size = padding_size param.partitioned_size = partitioned_size + param.ds_summary = types.MethodType(ds_summary, param) + + param.item = allgather_before(param.item) param.convert_to_zero_parameters = convert_to_zero_parameters @@ -682,6 +1005,7 @@ class Init(InsertPostInitMethodToModuleSubClasses): elif len(swap_in_flight) > 0: swap_in_flight[0].nvme_swapper.synchronize_reads() + @instrument_w_nvtx def _all_gather(self, param_list, async_op=False, hierarchy=None): # fetches from nvme if the partition is not available and in nvme @@ -701,8 +1025,10 @@ class Init(InsertPostInitMethodToModuleSubClasses): all_gather_list.append(param) if not async_op: - # ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy) - ret_value = self._allgather_params_coalesced(all_gather_list, hierarchy) + if len(param_list) == 1: + ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy) + else: + ret_value = self._allgather_params_coalesced(all_gather_list, hierarchy) for param in all_gather_list: param.ds_status = ZeroParamStatus.AVAILABLE @@ -722,6 +1048,7 @@ class Init(InsertPostInitMethodToModuleSubClasses): #print_rank_0(f"After Partitioning Param {param.ds_id}") # self._param_status(param) + @instrument_w_nvtx def _partition_param(self, param, buffer=None, has_been_updated=False): assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight" @@ -750,8 +1077,7 @@ class Init(InsertPostInitMethodToModuleSubClasses): f'Before partitioning param {param.ds_id} {param.shape}', force=False) # param.data does not store anything meaningful in partitioned state - param.data = torch.empty(1, dtype=self.dtype, device=param.device) - + free_param(param) see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False) @@ -772,7 +1098,7 @@ class Init(InsertPostInitMethodToModuleSubClasses): numel=partition_size): final_location = OFFLOAD_NVME_DEVICE buffer = self.param_swapper.get_buffer(param, partition_size) - partitioned_tensor = torch.empty(1, + partitioned_tensor = torch.empty(0, dtype=param.dtype, device=buffer.device) partitioned_tensor.data = buffer.data @@ -831,7 +1157,7 @@ class Init(InsertPostInitMethodToModuleSubClasses): see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False) - param.data = torch.ones(1, dtype=self.dtype).to(param.device) + free_param(param) see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index cb8ab3c8..cdec9bda 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -3,19 +3,26 @@ Licensed under the MIT license. """ -import sys +import gc +from dataclasses import dataclass +import functools import os -from collections import defaultdict, OrderedDict +import collections +from collections import OrderedDict, UserDict import itertools +from typing import Deque, Dict, Iterable, Set, Tuple import torch -from torch.distributed.distributed_c10d import _get_global_rank +from torch.cuda import Event, Stream +from torch.nn import Module, Parameter import torch.distributed as dist import math from torch._six import inf -from torch.autograd import Variable +from torch.nn import Module +from torch.nn.parameter import Parameter from deepspeed.utils.logging import logger from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler +from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced from deepspeed.runtime.utils import get_global_norm, see_memory_usage, is_model_parallel_parameter, DummyOptim from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.partition_parameters import _init_external_params @@ -51,20 +58,6 @@ def input(msg): return -def split_half_float_double(tensors): - dtypes = [ - "torch.cuda.HalfTensor", - "torch.cuda.FloatTensor", - "torch.cuda.DoubleTensor" - ] - buckets = [] - for i, dtype in enumerate(dtypes): - bucket = [t for t in tensors if t.type() == dtype] - if bucket: - buckets.append(bucket) - return buckets - - def isclose(a, b, rtol=1e-09, atol=0.0): return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol) @@ -74,16 +67,26 @@ def lcm(x, y): return x * y // gcd(x, y) +def debug_rank0(message: str) -> None: + if dist.get_rank() == 0: + logger.debug(message) + + def move_to_cpu(tensor_list): for tensor in tensor_list: tensor.data = tensor.data.cpu() +@instrument_w_nvtx def get_all_parameters(sub_module, recurse=False): return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters()) +def iter_params(module: Module, recurse=False) -> Iterable[Parameter]: + return map(lambda pair: pair[1], get_all_parameters(module, recurse)) + + #apply torch.autograd.Function that calls a backward_function to tensors in output def _apply_to_tensors_only(module, functional, backward_function, outputs): if type(outputs) is tuple: @@ -166,383 +169,329 @@ def _inject_parameters(module, cls): module._parameters = new_param -# TODO Needs to be implemented -class PrefetchCoordinator(object): - def __init__(self): - # step_id keeps track of the number of sub-modules invoked so far - # the step_id is tracking forward and backward sequence of sub-modules - self.step_id = 0 - - # stores the sequence of sub modules in forward+backward pass - self.sub_module_trace = [] - - # maps sub_module id to submodule objects - self.id_to_sub_module_map = {} - - # stores the total number of parameters in each sub_module - self.id_to_sub_module_size_map = {} - - self.trace_completed = False - - self.most_recent_sub_module_step = {} +class PartitionedParameterCoordinator: + """Handles partitioning and gathering of parameters.""" + class __InflightParamRegistry(UserDict): + """registry for parameters in flight""" + def __setitem__(self, + param: Parameter, + handle: AllGatherCoalescedHandle) -> None: + if param in self.data: + raise RuntimeError(f"{param.ds_summary()} already in registry") + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError( + f"attempted to add non-inflight parameter to registry {param.ds_summary()}" + ) + self.data[param] = handle + + @dataclass + class __ParamInTrace: + param: Parameter + step_id_last_used_at: int + + def __init__( + self, + prefetch_bucket_sz: int, + max_reuse_distance_in_numel: int, + max_available_parameters_in_numel: int, + allgather_stream: Stream, + prefetch_nvme: bool = False, + ) -> None: + # mapping of param -> handle for each param that is currently in flight + self.__inflight_param_registry = __class__.__InflightParamRegistry() + # keeps track of the number of submodules invoked so far. + self.__step_id: int = 0 + # whether or not we have completed a trace of the entire network. This should + # always be true after the first forward pass + backward pass. + self.trace_complete: bool = False + # sequence of submodules/parameters in forward pass + backward pass + self.__submodule_order: Iterable[Module] = [] + self.__param_order: Iterable[__class__.__ParamInTrace] = [] + self.__most_recent_step_id_param_fetched_for = collections.defaultdict( + lambda: int(-1e10)) + # number of available params, and max number of available params + self.__n_available_params: int = 0 + self.__max_n_available_params: int = max_available_parameters_in_numel + # max distance between two use of the module beyond which module is released + self.__max_reuse_dist_in_numel: int = max_reuse_distance_in_numel + # queue for parameters to fetch. parameters will be popped off the left + # side of the dequeue as they are fetched + self.__param_queue: Deque[__class__.__ParamInTrace] = None + self.__prefetch_bucket_sz: int = prefetch_bucket_sz + self.__prefetch_nvme: bool = prefetch_nvme + self.hierarchy: int = 0 + + # stream that will be used for allgather operations + self.__allgather_stream: Stream = allgather_stream + + # limit the number of fetch events that can be queued at once + # otherwise, what happens is memory is allocated by the host thread at the + # time of the call, but not used until later by the asynchronous cuda stream. + # allowing an infinite number of these to queue up causes a lot of memory + # pressure that then becomes detrimental to performance. + # this is a much less elegant way of fixing this vs something like using + # cudaMallocAsync/cudaFreeAsync. Choosing to not expose this to the user now + # because ideally in the future its replaced by an async allocation + # mechanism which doesnt require any configuration by the user. + self.__ongoing_fetch_events: Deque[Event] = collections.deque() + # TODO. make this configurable via JSON + self.__max_ongoing_fetch_events: int = 2 + + """Tracing and Tracking + TODO. consider performing trace before initializing PartitionedParameterCoordinator + and passing trace results into constructor. This way all the code in here can + just assume that the trace is complete and the results can be entirely + immutable. + + Bookkeeping operations used to track where we are in the forward/backward pass + """ - # reuse distances - self.reuse_numel_for_step_id = {} + def record_trace(self, sub_module: Module) -> None: + """adds sub module to trace""" + if self.trace_complete: + raise RuntimeError( + "attemted to record trace when trace was already complete") + + self.__submodule_order.append(sub_module) + for param in sorted(set(iter_params(sub_module)), key=lambda p: p.ds_id): + self.__param_order.append( + __class__.__ParamInTrace(param=param, + step_id_last_used_at=self.__step_id)) + + def reset_step(self) -> None: + """indicate that we have completed one fwd+bwd for the model""" + if self.__inflight_param_registry: + raise RuntimeError( + f"still have inflight params " + f"{[p.ds_summary for p in self.__inflight_param_registry.keys()]}") + + if not self.trace_complete: + # make sure that recorded parameter and submodule orders are + # identical across ranks + assert_ints_same_as_other_ranks([m.id for m in self.__submodule_order]) + assert_ints_same_as_other_ranks([p.param.ds_id for p in self.__param_order]) + assert_ints_same_as_other_ranks( + [p.step_id_last_used_at for p in self.__param_order]) + + self.__submodule_order = tuple(self.__submodule_order) # freeze + self.__param_order = tuple(self.__param_order) # freeze + self.trace_complete = True + print_rank_0(f"completed trace: {[m.id for m in self.__submodule_order]}", + force=True) + + self.__param_queue = collections.deque(self.__param_order) # reset fetch queue + self.__most_recent_step_id_param_fetched_for = collections.defaultdict( + lambda: int(-1e10)) + self.__step_id = 0 + self.__n_available_params = 0 + + """Fetch and Release + Fetching, prefetching, and releasing parameters + """ - def record_trace(self, sub_module): - if not self.trace_completed: - self.sub_module_trace.append(sub_module.id) - self.id_to_sub_module_map[sub_module.id] = sub_module + @instrument_w_nvtx + @torch.no_grad() + def fetch_sub_module(self, current_submodule: Module) -> None: + """This method does the following (in order): + 1. kick off fetch for parameters in immediately required sub module + 2. kick off fetch for next few parameters we will need later (prefetch) + 3. block on parameters in immediately required sub module + """ + debug_rank0( + f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule)]} " + + str({ + "avail": f"{self.__n_available_params:.1e}", + "queue_sz": f"{len(self.__param_queue or [])}", + "inflight": [p.ds_id for p in self.__inflight_param_registry], + })) - def print_trace(self): - print_rank_0( - f"The module trace is : {[self.id_to_sub_module_map[module_id].id for module_id in self.sub_module_trace]}" - ) + params_to_fetch = frozenset(iter_params(current_submodule)) - def increment_step(self, sub_module): - self.most_recent_sub_module_step[sub_module.id] = self.step_id - self.step_id += 1 + # kick off all gather for params in the immediately required submodule + for param in params_to_fetch: + debug_rank0(f"-fetch: {param.ds_summary()}") + self.__all_gather_params(params_to_fetch) - def reset_step(self): - self.step_id = 0 + # wait for parameters in the immediately needed submodule to become available + for param in params_to_fetch: + param.ds_active_sub_modules.add(current_submodule.id) + debug_rank0(f"-wait: {param.ds_summary()}") + if param in self.__inflight_param_registry: + with torch.cuda.stream(self.__allgather_stream): + while self.__ongoing_fetch_events and self.__ongoing_fetch_events[ + 0].query(): + self.__ongoing_fetch_events.popleft() + if len(self.__ongoing_fetch_events + ) > self.__max_ongoing_fetch_events: + self.__ongoing_fetch_events.popleft().synchronize() + + self.__inflight_param_registry.pop(param).wait() + + event = Event() + event.record() + self.__ongoing_fetch_events.append(event) + + assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary() + torch.cuda.current_stream().wait_stream(self.__allgather_stream) + + # kick off parameter prefetches for upcoming modules + # don't prefetch if we dont have a completed model trace, or if we aren't + # training (throws off the tracing and don't want to prefetch modules for bwd) + if self.trace_complete and current_submodule.training: + # go through the parameters we need for the current module and pop them + # off the fetch queue so that they aren't prefetched later. + # if params have already been popped off the fetch queue by earlier + # prefetches we won't look for them here + discarded_from_prefetch_queue = set() + params_not_already_fetched = set( + filter( + lambda p: self.__most_recent_step_id_param_fetched_for[p] < self. + __step_id, + params_to_fetch)) + while self.__param_queue and len(discarded_from_prefetch_queue) < len( + params_not_already_fetched): + param_in_trace = self.__param_queue.popleft() + self.__most_recent_step_id_param_fetched_for[ + param_in_trace.param] = param_in_trace.step_id_last_used_at + discarded_from_prefetch_queue.add(param_in_trace.param) + if discarded_from_prefetch_queue != params_not_already_fetched: + raise RuntimeError( + f"tracing error at step {self.__step_id}: " + f"expected the next {len(params_not_already_fetched)} parameters in the " + f"parameter fetch queue to be {tuple(p.ds_summary() for p in params_not_already_fetched)} " + f"but got {tuple(p.ds_summary() for p in discarded_from_prefetch_queue)}." + ) - # returns the next numel parameters that will be used next but are not available or inflight - def get_params_to_prefetch(self, sub_module, numel=2000000): + # kick off all gather for params in the next few submodules (prefetch) + max_params_to_prefetch = min( + self.__max_n_available_params - self.__n_available_params, + self.__prefetch_bucket_sz) + params_to_prefetch = set() + numel_prefetching = 0 + while self.__param_queue and numel_prefetching < max_params_to_prefetch: + param_in_trace: __class__.__ParamInTrace = self.__param_queue.popleft() + self.__most_recent_step_id_param_fetched_for[ + param_in_trace.param] = param_in_trace.step_id_last_used_at + if param_in_trace.param not in params_to_prefetch: + params_to_prefetch.add(param_in_trace.param) + numel_prefetching += param_in_trace.param.ds_numel + for param in params_to_prefetch: + debug_rank0(f"-prefetch: {param.ds_summary()}") + self.__all_gather_params(params_to_prefetch) + + if self.__prefetch_nvme: + self.__prefetch_nvme_param_partitions() + + self.__step_id += 1 + + @instrument_w_nvtx + @torch.no_grad() + def release_sub_module(self, submodule: Module) -> None: + """release the parameters of a sub module, assuming they meet conditions to + be released.""" + params_to_release = (self.__params_to_release(submodule, + self.__step_id) + if self.trace_complete else set( + p.ds_id for p in iter_params(submodule))) + + for param in iter_params(submodule): + param.ds_active_sub_modules.discard(submodule.id) + if param.ds_id in params_to_release and not param.is_external_param: + self.__release_param(param) + + @instrument_w_nvtx + @torch.no_grad() + def release_and_reset_all(self) -> None: + """release all module parameters""" + for param in map(lambda p: p.param, self.__param_order): + if param in self.__inflight_param_registry: + raise RuntimeError(f"param {param.ds_summary()} still in flight") + + # TODO. make this throw if if there are still active submodules. currently + # there's a hook execution issue + param.ds_active_sub_modules.clear() + self.__release_param(param) + + for param_in_trace in self.__param_order: + if param_in_trace.param.ds_status != ZeroParamStatus.NOT_AVAILABLE: + raise RuntimeError( + f"{param_in_trace.param.ds_summary()} expected to be released") + + @instrument_w_nvtx + def __all_gather_params(self, params: Set[Parameter]) -> None: + """for each partitioned parameter, kick off an async allgather and store + the work handle for the in flight parameters.""" + partitioned_params = [] + for param in params: + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + partitioned_params.append(param) + self.__n_available_params += param.ds_numel - # numel_in_sub_module = 0 - # for name, param in sub_module.named_parameters(recurse=False): - # numel_in_sub_module += param.ds_numel + if partitioned_params: + with torch.cuda.stream(self.__allgather_stream): + handle = partitioned_params[0].all_gather_coalesced(partitioned_params) - # #if numel_in_sub_module < (numel // 2): - # return [] + for param in partitioned_params: + assert param.ds_status == ZeroParamStatus.INFLIGHT, param.ds_summary() + self.__inflight_param_registry[param] = handle - # tracing failed. The sub_module passed at the step_id must match with the sub_module during tracing - if sub_module.id != self.sub_module_trace[self.step_id]: - print_rank_0( - f"Tracing failed. Prefetching is disabled at sub-module: {debug_module2name_id(sub_module)}" - ) - return [] - - params_to_prefetch = [] - total_numel_to_prefetch = 0 - - for i in range(self.step_id, len(self.sub_module_trace)): - module_id = self.sub_module_trace[i] - for _, param in get_all_parameters(self.id_to_sub_module_map[module_id]): - if param.ds_status is ZeroParamStatus.NOT_AVAILABLE and ( - param.ds_id not in [p.ds_id for p in params_to_prefetch]): - params_to_prefetch.append(param) - total_numel_to_prefetch += param.ds_numel - #print_rank_0(f"Total numel to prefetch: {total_numel_to_prefetch}. Param: {param.ds_shape} and numel {param.ds_numel}, numel limit {numel}") - if total_numel_to_prefetch >= numel: # and total_numel_to_prefetch > (numel_in_sub_module // 2): - return params_to_prefetch - - return params_to_prefetch - - # checks if this sub_module will be used again and if so then returns the number of elements - # in the parameters used between this sub_module and the reuse of this sub_module - def get_reuse_distance_in_numel(self, sub_module, sub_module_step_id=None): - #assert is_forward is not None, "is_forward must be set to True for Forward Propagation and False for backward Propagation" - is_there_reuse = False - reuse_distance_in_numel = 1000000000000 - - # set the appropriate trace - trace = self.sub_module_trace - total_steps = len(trace) - if sub_module_step_id is None: - sub_module_step_id = self.most_recent_sub_module_step[sub_module.id] - - # tracing failed. The sub_module passed at the step_id must match with the sub_module during tracing - if sub_module.id != trace[sub_module_step_id]: - print_rank_0( - f"Tracing failed. Cannot tell if the sub_module: {sub_module.id} is reused" - ) - return reuse_distance_in_numel - - # return cached value - if sub_module_step_id in self.reuse_numel_for_step_id: - return self.reuse_numel_for_step_id[sub_module_step_id] - - start_step = self.step_id - print_rank_0(f"Step id is {self.step_id} ") - for step_id in range(start_step, total_steps): - print_rank_0(f"Trace id {trace[step_id]} and sub_module id {sub_module.id}") - if sub_module.id == trace[step_id]: - end_step = step_id - - is_there_reuse = True - reuse_distance_in_numel = self._distance_in_numel( - start_step, - end_step, - trace) + @instrument_w_nvtx + def __release_param(self, param: Parameter) -> None: + if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules: + debug_rank0(f"-release: {param.ds_summary()}") + param.partition() + self.__n_available_params -= param.ds_numel + + @instrument_w_nvtx + @functools.lru_cache(maxsize=None) + def __params_to_release(self, + submodule_to_release: Module, + step_id: int) -> Set[int]: + if not self.trace_complete: + raise RuntimeError("expected trace to be complete") + + params_to_release = set(p.ds_id for p in iter_params(submodule_to_release) + if not p.ds_persist) + + # examine all modules within `max_reuse_dist_in_numel` of the current step, + # if we see any of the candidate parameters to be released reoccur while + # doing this, remove them from the set of parameters to release. + params_traversed = 0 + for module in self.__submodule_order[step_id:]: + if params_traversed > self.__max_reuse_dist_in_numel: break + for param in iter_params(module): + params_to_release.discard(param.ds_id) + params_traversed += param.ds_numel - self.reuse_numel_for_step_id[sub_module_step_id] = reuse_distance_in_numel + return params_to_release - return reuse_distance_in_numel - - def _distance_in_numel(self, start_step, end_step, trace): - distance_in_numel = 0 - for step_id in range(start_step, end_step): - module_id = trace[step_id] - for _, param in self.id_to_sub_module_map[module_id].named_parameters(recurse=False): - distance_in_numel += param.ds_numel - for _, param in self.id_to_sub_module_map[module_id].ds_external_parameters(): - distance_in_numel += param.ds_numel - return distance_in_numel - - -class PartitionedParameterCoordinator(object): - def __init__(self, - comm_stream=None, - max_reuse_distance_in_numel=500000000, - max_available_parameters_in_numel=700000000): - - self.in_flight_handles = [] - self.params_in_flight = [] - self.comm_stream = comm_stream if comm_stream is not None else torch.cuda.current_stream( - ) - self.prefetch_coordinator = PrefetchCoordinator() - self.hierarchy = 0 - - self.total_available_parameter_numel = 0 - self.max_available_parameters_in_numel = max_available_parameters_in_numel - - # max distance between two use of the module beyond which module is released - self.max_reuse_distance_in_numel = max_reuse_distance_in_numel - - def _increment_available_parameter_numel(self, increment): - self.total_available_parameter_numel += increment - - def _decrement_available_parameter_numel(self, decrement): - self.total_available_parameter_numel -= decrement - - '''-----------------------Tracing and Prefetching ---------------''' - - def record_trace(self, sub_module): - self.prefetch_coordinator.record_trace(sub_module) - - def finish_tracing(self, print_trace=False): - self.prefetch_coordinator.trace_completed = True + @instrument_w_nvtx + def __prefetch_nvme_param_partitions(self) -> None: + """swap in parameter partitions from nvme for those parameters that will be used + after the ones that are already being prefetched into full parameters + """ + if not self.trace_complete: + return - if print_trace: - self.prefetch_coordinator.print_trace() + numel_in_flight = sum(param.ds_numel for param in self.__inflight_param_registry) - #swap in parameter partitions from nvme for those parameters that will be used - # after the ones that are already being prefetched into full parameters - def _prefetch_nvme_param_partitions(self, sub_module, params_in_flight): - numel_in_flight = sum([param.ds_tensor.ds_numel for param in params_in_flight]) - upcoming_param_list = self.prefetch_coordinator.get_params_to_prefetch( - sub_module, - numel=2 * numel_in_flight) + numel_considered = 0 swap_in_params = [] - for param in upcoming_param_list: - if len(swap_in_params) >= param.nvme_swapper.available_swap_in_buffers(): + for param_in_trace in self.__param_queue: + param = param_in_trace.param + if param.nvme_swapper is None: + continue + if (numel_considered > 2 * numel_in_flight or len(swap_in_params) >= + param.nvme_swapper.available_swap_in_buffers()): break if param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE: swap_in_params.append(param) + numel_considered += param.ds_numel - if len(swap_in_params) > 0: + if swap_in_params: swap_in_params[0].nvme_swapper.swap_in(swap_in_params, async_op=True) - # Pre fetches the parameters for sub_modules that comes after - # the current sub_module. This call is asynchronous - def prefetch_next_sub_modules(self, sub_module, numel=5000000, nvme=False): - - params_to_prefetch = [] - if not self.prefetch_coordinator.trace_completed: - return params_to_prefetch - - # prefetch if there is no current prefetching in flight - if not self.in_flight_handles and self.total_available_parameter_numel < self.max_available_parameters_in_numel: - params_to_prefetch = self.prefetch_coordinator.get_params_to_prefetch( - sub_module, - numel=numel) - - self._all_gather(params_to_prefetch, async_op=True) - for param in params_to_prefetch: - param.ds_status = ZeroParamStatus.INFLIGHT - - # keeping track of number of elements consumed by available parameters - self._increment_available_parameter_numel(param.ds_numel) - - if nvme: - self._prefetch_nvme_param_partitions(sub_module, params_to_prefetch) - - self._print_prefetch_elements_info(sub_module, params_to_prefetch) - print_rank_0( - f"{'--' * self.hierarchy}--PreFetching parameters {[param.ds_id for param in params_to_prefetch]} and available {self.total_available_parameter_numel}, max limit {self.max_available_parameters_in_numel}", - force=False) - - def _print_prefetch_elements_info(self, sub_module, params_to_prefetch): - sub_module_numel = 0.0 - for name, param in sub_module.named_parameters(recurse=False): - sub_module_numel += param.ds_numel - numel_being_prefetched = 0 - for param in params_to_prefetch: - numel_being_prefetched = param.ds_numel - print_rank_0( - f"{'--' * self.hierarchy}--PreFetching {numel_being_prefetched} numels and number of numel in the next sub module is {sub_module_numel}", - force=False) - - def increment_step(self, sub_module): - self.prefetch_coordinator.increment_step(sub_module) - - def reset_step(self): - self.prefetch_coordinator.reset_step() - - '''----------------------------------------------------------------------''' - - # Fetches the parameters in the sub_module - # This call is blocking - def fetch_sub_module(self, sub_module): - partitioned_params = [] - params_in_flight = False - print_rank_0( - f"{'--' * self.hierarchy}Fetching params in module {debug_module2name_class(sub_module)}" - ) - params_to_fetch = [ - param for _, - param in sub_module.named_parameters(recurse=False) - ] - # print([n for n,p in sub_module.named_parameters(recurse=False)]) - - if hasattr(sub_module, 'ds_external_parameters'): - print_rank_0( - f"{'--' * self.hierarchy}--Fetching external parameters {sub_module.ds_external_parameters()}" - ) - params_to_fetch += [ - param for _, - param in sub_module.ds_external_parameters() - ] - # for _, param in sub_module.named_parameters(recurse=False): - for param in params_to_fetch: - param.ds_active_sub_modules += 1 - print_rank_0( - f"{'--' * self.hierarchy}--Fetching parameters {debug_param2name_id_shape(param)} with active sub modules {param.ds_active_sub_modules}" - ) - - if param.ds_status == ZeroParamStatus.AVAILABLE: - print_rank_0( - f"{'--' * self.hierarchy}--Parameter {debug_param2name_id(param)} is already available" - ) - - if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: - print_rank_0( - f"{'--' * self.hierarchy}--Parameter {debug_param2name_id(param)} is being fetched" - ) - partitioned_params.append(param) - - # keeping track of number of elements consumed by available parameters - self._increment_available_parameter_numel(param.ds_numel) - print_rank_0(f"Incrementing with parameter id {param.ds_id}") - - if param.ds_status == ZeroParamStatus.INFLIGHT: - params_in_flight = True - print_rank_0( - f"{'--' * self.hierarchy}--Parameters {debug_param2name_id(param)} is already in flight (prefetched)" - ) - self.hierarchy += 1 - - # parameters are partitioned and need to be allgathered - self._all_gather(partitioned_params, async_op=False) - - # parameters are inflight and communication needs to be completed - if partitioned_params or params_in_flight: - self._synchronize_communication() - - for _, param in sub_module.named_parameters(recurse=False): - param.ds_status = ZeroParamStatus.AVAILABLE - print_rank_0( - f"Param {debug_param2name_id_shape_device(param)} norm={param.norm()}", - force=False) - #print_rank_0(f"After fetching (id, shape, device): {[(param.ds_id, param.shape, param.device) for param in sub_module.named_parameters(recurse=False)]}") - - def release_sub_module(self, sub_module): - self.hierarchy -= 1 - print_rank_0( - f"{'--' * self.hierarchy}Releasing params in module {debug_module2name_class(sub_module)}" - ) - params_to_release = [ - param for _, - param in sub_module.named_parameters(recurse=False) - ] - - if hasattr(sub_module, 'ds_external_parameters'): - #print_rank_0(f"Releasing external parameters {sub_module.ds_external_parameters()}") - params_to_release += [ - param for _, - param in sub_module.ds_external_parameters() - ] - - # for _, param in sub_module.named_parameters(recurse=False): - for param in params_to_release: - param.ds_active_sub_modules -= 1 - if not param.ds_active_sub_modules and not self._keep_for_later( - sub_module) and not param.ds_persist: - print_rank_0( - f"{'--' * self.hierarchy}--Releasing parameter {debug_param2name_id_numel(param)} active sub modules {param.ds_active_sub_modules} and keep for later {self._keep_for_later(sub_module)}", - force=False) - - # Keeping track of number of elements that are consumed by available parameters - self._decrement_available_parameter_numel(param.ds_numel) - see_memory_usage( - f"Before releasing param {debug_param2name_id_numel(param)}", - force=False) - param.partition(hierarchy=self.hierarchy) - see_memory_usage( - f"After releasing param {debug_param2name_id_numel(param)}", - force=False) - - param.ds_status = ZeroParamStatus.NOT_AVAILABLE - else: - - print_rank_0( - f"{'--' * self.hierarchy}--Did not release param {debug_param2name_id_numel(param)} with active sub modules {param.ds_active_sub_modules}, keep for later={self._keep_for_later(sub_module)} and persistence={param.ds_persist}", - force=False) - - def release_and_reset_parameter(self, param): - param.ds_active_sub_modules = 0 - if param.ds_status == ZeroParamStatus.AVAILABLE: - print_rank_0( - f"Releasing unpartitioned param {debug_param2name_id_numel(param)} active sub-modules {param.ds_active_sub_modules} and persistence {param.ds_persist}" - ) - self._decrement_available_parameter_numel(param.ds_numel) - param.partition() - - def _keep_for_later(self, sub_module): - if not self.prefetch_coordinator.trace_completed: - return False - if self.max_reuse_distance_in_numel == 0: - return False - reuse_distance_in_numel = self.prefetch_coordinator.get_reuse_distance_in_numel( - sub_module) - #print_rank_0(f"Reuse distance and numel for sub_module id {sub_module.id} is {reuse_distance_in_numel}") - return reuse_distance_in_numel < self.max_reuse_distance_in_numel - - def _all_gather(self, partitioned_params, async_op=False): - with torch.cuda.stream(self.comm_stream): - handles = partitioned_params[0].all_gather( - param_list=partitioned_params, - async_op=async_op, - hierarchy=self.hierarchy) if partitioned_params else None - - if handles is not None: - self.in_flight_handles.extend(handles) - self.params_in_flight.extend(partitioned_params) - - def _synchronize_communication(self, synchronize_streams=True): - assert len(self.params_in_flight) == len(self.in_flight_handles) - for handle, param in zip(self.in_flight_handles, self.params_in_flight): - if handle is not None: - with torch.cuda.stream(self.comm_stream): - handle.wait() - param.ds_status = ZeroParamStatus.AVAILABLE - self.comm_stream.synchronize() - torch.cuda.synchronize() if synchronize_streams else None - self.in_flight_handles = [] - self.params_in_flight = [] - class PreBackwardFunction(torch.autograd.Function): @staticmethod @@ -637,6 +586,7 @@ class DeepSpeedZeroOptimizer_Stage3(object): see_memory_usage("Stage 3 initialize beginning", force=False) if dist.get_rank() == 0: + logger.info(f"initialized {__class__.__name__} with args: {locals()}") logger.info(f"Reduce bucket size {reduce_bucket_size}") logger.info(f"Allgather bucket size {prefetch_bucket_size}") # The fused optimizer does all the work. We need this layer for two reason: @@ -679,32 +629,42 @@ class DeepSpeedZeroOptimizer_Stage3(object): self.module = module self.elastic_checkpoint = elastic_checkpoint - self.overlap_comm = overlap_comm # Replace ._parameters with a new class to enable auto-registration of # external parameters _inject_parameters(module, ZeROOrderedDict) - if self.overlap_comm: - self.gpu_sum = torch.zeros(1, dtype=torch.float).cuda() + self.__inf_or_nan_tracker: Tensor = torch.zeros( + 1, + dtype=torch.bool, + device=torch.cuda.current_device(), + requires_grad=False) self.deepspeed_adam_offload = (self.offload_optimizer and type(init_optimizer) == DeepSpeedCPUAdam) self.device = torch.cuda.current_device( ) if not self.offload_optimizer else OFFLOAD_CPU_DEVICE + ### streams used for overlapping computation with communication + self.__allgather_stream = Stream( + ) if overlap_comm else torch.cuda.default_stream() + self.__reduce_and_partition_stream = Stream( + ) if overlap_comm else torch.cuda.default_stream() + ############################################################################ see_memory_usage("Before Partitioned Parameter Coordinator", force=False) - - fetch_stream = torch.cuda.Stream() if self.overlap_comm else None self.param_coordinator = PartitionedParameterCoordinator( - comm_stream=fetch_stream, + prefetch_bucket_sz=int(prefetch_bucket_size), max_reuse_distance_in_numel=int(max_reuse_distance), - max_available_parameters_in_numel=int(max_live_parameters)) - + max_available_parameters_in_numel=int(max_live_parameters), + allgather_stream=self.__allgather_stream, + prefetch_nvme=self.params_in_nvme_and_cpu, + ) see_memory_usage("After Partitioned Parameter Coordinator", force=False) + self.__n_caching_allocator_flushes = 0 + #self.param_coordinator = PartitionedParameterCoordinator(comm_stream=torch.cuda.Stream()) #-------------Stage 3 Setup-------------------# # parameters smaller than the threshold will be collectively gathered at the @@ -742,7 +702,7 @@ class DeepSpeedZeroOptimizer_Stage3(object): self.gradient_predivide_factor = gradient_predivide_factor self.postscale_gradients = postscale_gradients self.gradient_accumulation_steps = gradient_accumulation_steps - self.micro_step_id = INITIAL_MICRO_STEP_ID + self.micro_step_id = 0 if self.reduce_scatter: assert self.communication_data_type in (torch.float16, torch.bfloat16), f"ZeRO-3 supports only float16 or bfloat16 communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'" @@ -813,12 +773,41 @@ class DeepSpeedZeroOptimizer_Stage3(object): self.reduce_bucket_size = int(reduce_bucket_size) - self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False) + # IPG + if contiguous_gradients: + self.__ipg_bucket_flat_buffer: Tensor = torch.empty( + int(reduce_bucket_size), + dtype=self.dtype, + device=torch.cuda.current_device()) + + self.__param_id_to_grad_partition: Dict[int, Tensor] = {} + + all_params = list(itertools.chain.from_iterable(self.fp16_groups)) + + grad_partitions_flat_buffer: Tensor = torch.zeros( + sum(p.ds_tensor.ds_numel for p in all_params), + dtype=self.dtype, + device=self.device, + pin_memory=self.offload_optimizer_pin_memory) + + offset = 0 + for param in all_params: + self.__param_id_to_grad_partition[ + param.ds_id] = grad_partitions_flat_buffer.narrow( + 0, + offset, + param.ds_tensor.numel()) + offset += param.ds_tensor.numel() + + self.__params_in_ipg_bucket: List[Parameter] = [] + self.is_gradient_accumulation_boundary: bool = True + + self.__param_reduce_events: Deque[Event] = collections.deque() + # TODO. make this configurable via JSON + self.__max_param_reduce_events: int = 2 - self.reduction_stream = torch.cuda.Stream( - ) if self.overlap_comm else torch.cuda.current_stream() - self.callback_queued = False - self.copy_grad_stream = torch.cuda.Stream() + if dist.get_rank() == 0: + logger.info(f"optimizer state initialized") self.param_dict = {} @@ -829,7 +818,7 @@ class DeepSpeedZeroOptimizer_Stage3(object): self.extra_large_param_to_reduce = None self.grads_in_ipg_bucket = [] self.params_in_ipg_bucket = [] - self.elements_in_ipg_bucket = 0 + self.params_already_reduced = [] self.is_gradient_accumulation_boundary = True self._release_ipg_buffers() @@ -867,16 +856,9 @@ class DeepSpeedZeroOptimizer_Stage3(object): self.grads_in_partition = None if self.offload_optimizer: - self.accumulated_grads_in_cpu = {} self.norm_for_param_grads = {} self.local_overflow = False - self.temp_grad_buffer_for_gpu_offload = torch.zeros( - largest_partitioned_param_numel, - device=torch.cuda.current_device(), - dtype=self.dtype) - self.temp_grad_gpu_buffer = torch.zeros(largest_partitioned_param_numel, - device=torch.cuda.current_device(), - dtype=self.dtype) + see_memory_usage(f"After CPU Offload initialization", force=False) # stores if a partition has been reduced in this step @@ -913,6 +895,44 @@ class DeepSpeedZeroOptimizer_Stage3(object): if dist.get_rank(group=self.dp_process_group) == 0: see_memory_usage(f"After initializing ZeRO optimizer", force=False) + # TODO. factor out to a utility outside of stage3 + @staticmethod + def defragment(tensors: List[Tensor]) -> Tensor: + """move provided tensors into a contiguous flat buffer, with some additional + measures taken to reduce memory fragmentation""" + assert len(set(t.dtype for t in tensors)) == 1 + assert len(set(t.device for t in tensors)) == 1 + + cpu_buffer = torch.empty(sum(p.numel() for p in tensors), + dtype=get_only_unique_item(t.dtype for t in tensors), + device="cpu") + tensor_infos: List[Tuple[Tensor, int, int]] = [] + orig_device = get_only_unique_item(t.device for t in tensors) + + offset = 0 + for tensor in tensors: + tensor_numel = tensor.numel() + # move the tensor from device memory to host memory + cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor) + tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device) + + # record some data so we can restore the device tensor later + tensor_infos.append((tensor, offset, tensor_numel)) + + offset += tensor_numel + + gc.collect() + torch.cuda.empty_cache() + + # copy tensors (now flattened and contiguous) back to GPU + device_buffer = cpu_buffer.to(orig_device) + + # restore device tensors + for tensor, offset, tensor_numel in tensor_infos: + tensor.data = device_buffer.narrow(0, offset, tensor_numel) + + return device_buffer + def _configure_offloading(self, offload_optimizer_config, offload_param_config): ###################### offload optimizer setup ################################## if offload_optimizer_config is not None: @@ -985,6 +1005,10 @@ class DeepSpeedZeroOptimizer_Stage3(object): dtype=torch.float32, timers=self.timers) + @property + def elements_in_ipg_bucket(self): + return sum(p.ds_numel for p in self.__params_in_ipg_bucket) + def _move_to_flat_buffer(self, param_list, flat_buffer, avoid_copy=False): '''If flat buffer is None then the parameters in the param_list are not copied to the flat buffer. This is because they excede the number of max_params_in_cpu @@ -1053,95 +1077,73 @@ class DeepSpeedZeroOptimizer_Stage3(object): def _create_fp16_partitions_with_defragmentation(self): dist.barrier() - partition_id = dist.get_rank(group=self.dp_process_group) - create_fp16_flat_reuse_buffer = False - largest_partition_numel = [] - max_partition_numel = 0 + param_groups: List[List[Parameter]] = tuple( + self._create_fp16_sub_groups(param_group["params"]) + for param_group in self.optimizer.param_groups) - #create a flat CPU memory allocation for each param group - if self.offload_param: - self._create_param_groups_fp16_flat_cpu_memory() + # bookkeeping related to param groups + for param_group_idx, param_group in enumerate(param_groups): + for sub_group in param_group: + sub_group_idx = len(self.fp16_groups) - # loop to deal with groups - for j, param_group in enumerate(self.optimizer.param_groups): - - sub_groups = self._create_fp16_sub_groups(param_group['params']) - print_rank_0(f'fp16 group {j} has {len(sub_groups)} subgroups', force=False) - - flat_offset = 0 - for sub_group in sub_groups: - i = len(self.fp16_groups) - - # push this group to list before modify + # record sub group and partitions self.fp16_groups.append(sub_group) - self.sub_group_to_group_id[i] = j - - # comment out for zero_to_fp32 debug - # if torch.distributed.get_rank() == 0: - # for param in self.fp16_groups[i]: - # print(f"{debug_param2name_id_shape(param)} {param.ds_shape}") - - #These are the list of the partitioned parameters self.fp16_partitioned_groups.append( - [param.ds_tensor for param in self.fp16_groups[i]]) - - total_elements = sum( - [t.ds_numel for t in self.fp16_partitioned_groups[i]]) - self.fp16_partitioned_groups_flat_numel.append(total_elements) - - if total_elements > max_partition_numel: - largest_partition_numel = [ - t.ds_numel for t in self.fp16_partitioned_groups[i] - ] - max_partition_numel = total_elements - - print_rank_0( - f"fp16 group {i} partitioned_param norms : {[param.ds_tensor.norm().item() for param in self.fp16_groups[i]]}" - ) - - # Record padding required to align group to world size (only applies to last rank) - if partition_id == dist.get_world_size(group=self.dp_process_group) - 1: - padding = [p.padding_size() for p in self.fp16_groups[i]] - else: - padding = [0] * len(self.fp16_groups[i]) - self.groups_padding.append(padding) - - #not sure why apex was cloning the weights before flattening - #removing cloning here - see_memory_usage(f"Before Flattening param subgroup {i}", force=False) - - #all partitioned parameters remain in GPU during training - if not self.offload_param: - see_memory_usage(f"Before moving param subgroup group {i} to CPU", - force=False) - #move all the parameters to cpu to free up GPU space for creating flat buffer - move_to_cpu(self.fp16_partitioned_groups[i]) - see_memory_usage(f"After moving param subgroup {i} to CPU", - force=False) - - #create flat buffer in CPU and move to GPU - self.fp16_partitioned_groups_flat.append( - self.flatten_dense_tensors_aligned( - self.fp16_partitioned_groups[i], - 1).cuda(torch.cuda.current_device())) - see_memory_usage( - f"After flattening and moving param subgroup {i} to GPU", - force=False) - - #all partitioned parameters are in CPU during training - else: + [param.ds_tensor for param in sub_group]) + + # record sub group -> group mapping + self.sub_group_to_group_id[sub_group_idx] = param_group_idx + + # record total elements of parameter partitions in sub group + self.fp16_partitioned_groups_flat_numel.append( + sum(p.ds_tensor.ds_numel for p in sub_group)) + + # record padding required to align group to world size (only applies to last rank) + rank_requires_padding = dist.get_rank( + self.dp_process_group) == dist.get_world_size( + self.dp_process_group) - 1 + self.groups_padding.append([ + p.padding_size() if rank_requires_padding else 0 for p in sub_group + ]) + + # move parameters to flattened buffer + if not self.offload_param: # partitioned params remain in GPU during training + # move parameter partitions into a single contiguous flat buffer + parameter_partitions: List[Tensor] = [] + for sub_group in self.fp16_groups: + for param in sub_group: + parameter_partitions.append(param.ds_tensor) + device_buffer = __class__.defragment(parameter_partitions) + + # setup flat buffers per subgroup, these are each just sections of the + # contiguous flat buffer for all parameters that we created earlier + offset = 0 + for sub_group in self.fp16_groups: + sub_group_numel = sum(param.ds_tensor.ds_numel for param in sub_group) + self.fp16_partitioned_groups_flat.append( + device_buffer.narrow(0, + offset, + sub_group_numel)) + offset += sub_group_numel + else: # partitioned params offloaded to CPU when not in use + # create a flat CPU memory allocation for each param group + self._create_param_groups_fp16_flat_cpu_memory() + for param_group_idx, param_group in enumerate(param_groups): + flat_offset = 0 + for i, sub_group in enumerate(param_group): + total_elements = sum(p.ds_tensor.ds_numel for p in sub_group) print_rank_0(f"Params in nvme and cpu {self.params_in_nvme_and_cpu}") #Flat buffer may not be available for parameters that reside in NVME if not self.params_in_nvme_and_cpu or flat_offset + total_elements <= self.param_groups_fp16_flat_cpu_memory[ - j].numel(): + param_group_idx].numel(): fp16_partitioned_group_flat = self.param_groups_fp16_flat_cpu_memory[ - j].narrow(0, - flat_offset, - total_elements) + param_group_idx].narrow(0, + flat_offset, + total_elements) print_rank_0( f"Creating a flat buffer for subgroup {i} requiring {total_elements} elements, and cumulative CPU elements {flat_offset + total_elements}", force=False) - #these parameters reside in NVME and + elif self.params_in_nvme_and_cpu: fp16_partitioned_group_flat = None print_rank_0( @@ -1153,20 +1155,23 @@ class DeepSpeedZeroOptimizer_Stage3(object): self.fp16_partitioned_groups_flat.append(fp16_partitioned_group_flat) flat_offset += total_elements - # move param to flat buffer for both param offload on/off - self._move_to_flat_buffer(self.fp16_groups[i], - self.fp16_partitioned_groups_flat[i], - avoid_copy=not self.offload_param) - - see_memory_usage(f"After Flattening param group {i}", force=False) - - #create a pinned memory to be used for swapping out params to NVME after optimizer step - if self.fp16_partitioned_groups_flat[-1] is None: - create_fp16_flat_reuse_buffer = True - - see_memory_usage(f"After Flattening param subgroup {i}", force=False) + self._move_to_flat_buffer(sub_group, + fp16_partitioned_group_flat, + avoid_copy=not self.offload_param) + + # if necessary, create a pinned memory buffer to be used for swapping out + # params to NVME after optimizer step + should_create_fp16_flat_reuse_buffer = any( + flattened_partition_group is None + for flattened_partition_group in self.fp16_partitioned_groups_flat) + if should_create_fp16_flat_reuse_buffer: + max_partition_numel, largest_partition_numel = 0, None + for sub_group in self.fp16_groups: + total_elements = sum(t.ds_tensor.ds_numel for t in sub_group) + if total_elements > max_partition_numel: + largest_partition_numel = [t.ds_numel for t in sub_group] + max_partition_numel = total_elements - if create_fp16_flat_reuse_buffer: assert len(largest_partition_numel) > 0, f'Unexpected that largest partition is empty' self.fp16_groups[0][0].nvme_swapper.reserve_partitioned_swap_space( largest_partition_numel) @@ -1358,21 +1363,17 @@ class DeepSpeedZeroOptimizer_Stage3(object): def setup_zero_stage3_hooks(self): self.hierarchy = 0 - self._register_hooks_recursively(self.module) - - #reset step at the beginning of forward - def _pre_forward_hook(module, *args): - self.param_coordinator.reset_step() #reset step if in inference mode + @instrument_w_nvtx def _end_of_forward_hook(module, *args): if not torch._C.is_grad_enabled(): self.param_coordinator.reset_step() #likely one of them should be enough but just to be safe + self._register_hooks_recursively(self.module) self.module.register_forward_hook(_end_of_forward_hook) - self.module.register_forward_pre_hook(_pre_forward_hook) # Add top module to stack trace global FWD_MODULE_STACK @@ -1404,9 +1405,11 @@ class DeepSpeedZeroOptimizer_Stage3(object): count[0] = count[0] + 1 self._register_hooks_recursively(child, count=count) + @instrument_w_nvtx def _pre_forward_module_hook(module, *args): self.pre_sub_module_forward_function(module) + @instrument_w_nvtx def _post_forward_module_hook(module, input, output): global FWD_MODULE_STACK FWD_MODULE_STACK.pop() @@ -1427,7 +1430,7 @@ class DeepSpeedZeroOptimizer_Stage3(object): for item in filter(lambda item: is_zero_param(item), output): if not any(id(item) in m._external_params for m in FWD_MODULE_STACK): - item.ds_active_sub_modules += 1 + item.is_external_param = True module_to_register = FWD_MODULE_STACK[-1] print_rank_0( f'Registering dangling parameter for module {module_to_register.__class__.__name__}.', @@ -1447,6 +1450,7 @@ class DeepSpeedZeroOptimizer_Stage3(object): self.post_sub_module_forward_function(module) def _pre_backward_module_hook(module, inputs, output): + @instrument_w_nvtx def _run_before_backward_function(sub_module): # some models (e.g. Albert) may run multiple forwards on the same layer in a loop # before doing backwards, so each backward will need a pre-fetch - using reference @@ -1488,6 +1492,7 @@ class DeepSpeedZeroOptimizer_Stage3(object): def _post_backward_module_hook(module, inputs): module.ds_grads_remaining = 0 + @instrument_w_nvtx def _run_after_backward_function(sub_module): if sub_module.ds_grads_remaining == 0: self.post_sub_module_backward_function(sub_module) @@ -1508,6 +1513,7 @@ class DeepSpeedZeroOptimizer_Stage3(object): # post backward hook module.register_forward_pre_hook(_post_backward_module_hook) + @torch.no_grad() def pre_sub_module_forward_function(self, sub_module): see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", force=False) @@ -1515,23 +1521,15 @@ class DeepSpeedZeroOptimizer_Stage3(object): global FWD_MODULE_STACK FWD_MODULE_STACK.append(sub_module) - self.param_coordinator.record_trace(sub_module) + if not self.param_coordinator.trace_complete: + self.param_coordinator.record_trace(sub_module) self.param_coordinator.fetch_sub_module(sub_module) see_memory_usage( f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False) - self.param_coordinator.prefetch_next_sub_modules( - sub_module, - numel=self.prefetch_elements, - nvme=self.params_in_nvme_and_cpu) - see_memory_usage( - f"Before sub module function {sub_module.__class__.__name__} after prefetch", - force=False) - - self.param_coordinator.increment_step(sub_module) - + @torch.no_grad() def post_sub_module_forward_function(self, sub_module): see_memory_usage( f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release", @@ -1543,16 +1541,13 @@ class DeepSpeedZeroOptimizer_Stage3(object): f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release", force=False) + @torch.no_grad() def pre_sub_module_backward_function(self, sub_module): - self.param_coordinator.record_trace(sub_module) - + if not self.param_coordinator.trace_complete: + self.param_coordinator.record_trace(sub_module) self.param_coordinator.fetch_sub_module(sub_module) - self.param_coordinator.prefetch_next_sub_modules(sub_module, - numel=self.prefetch_elements) - - self.param_coordinator.increment_step(sub_module) - + @torch.no_grad() def post_sub_module_backward_function(self, sub_module): see_memory_usage( f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release", @@ -1719,16 +1714,13 @@ class DeepSpeedZeroOptimizer_Stage3(object): param_group, partition_id) + @instrument_w_nvtx def independent_gradient_partition_epilogue(self): self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0) - self.reduce_ipg_grads() + self.__reduce_and_partition_ipg_grads() self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0) - if self.overlap_comm: - self.reduction_stream.synchronize() - - with torch.cuda.stream(self.reduction_stream): - self.partition_previous_reduced_grads() + self.__reduce_and_partition_stream.synchronize() # if dist.get_rank() == 0: # logger.info("Params already reduced %s", self.params_already_reduced) @@ -1740,10 +1732,8 @@ class DeepSpeedZeroOptimizer_Stage3(object): if not self.offload_optimizer: for i, sub_group in enumerate(self.fp16_groups): self.averaged_gradients[i] = [ - torch.zeros_like(param.ds_tensor) if param.grad is None else - param.grad.data.narrow(0, - 0, - param.ds_tensor.numel()) + self.__param_id_to_grad_partition[param.ds_id] + if param.requires_grad else torch.zeros_like(param.ds_tensor) for param in sub_group ] # self.averaged_gradients[i] = self.get_flat_partition( @@ -1752,82 +1742,15 @@ class DeepSpeedZeroOptimizer_Stage3(object): # self.fp32_partitioned_groups_flat[i].numel(), # return_tensor_list=True) - self._release_ipg_buffers() - - see_memory_usage(f"End ipg_epilogue", force=False) - - # resets all partition to no reduced - # sets remaining grads to the total number of grads in each partition - # set is grad computed to false for all grads in partition - def reset_partition_gradient_structures(self): - total_partitions = dist.get_world_size(group=self.dp_process_group) - for i, _ in enumerate(self.fp16_groups): - for partition_id in range(total_partitions): - self.is_partition_reduced[i][partition_id] = False - self.remaining_grads_in_partition[i][ - partition_id] = self.total_grads_in_partition[i][partition_id] - - for param_id in self.is_grad_computed[i][partition_id]: - self.is_grad_computed[i][partition_id][param_id] = False - - def initialize_gradient_partition(self, i, param_group, partition_id): - def set_key_value_list(dictionary, key, value): - if key in dictionary: - dictionary[key].append(value) - else: - dictionary[key] = [value] - - def increment_value(dictionary, key): - if key in dictionary: - dictionary[key] += 1 - else: - dictionary[key] = 1 - - partition_size = self.partition_size[i] - - start_index = partition_size * partition_id - end_index = partition_size * (partition_id + 1) - - current_index = 0 - first_offset = 0 - - for param in param_group: - - param_size = param.numel() - param_id = self.get_param_id(param) - - if (current_index >= start_index and current_index < end_index): - set_key_value_list(self.param_to_partition_ids[i], - param_id, - partition_id) - increment_value(self.total_grads_in_partition[i], partition_id) - - self.is_grad_computed[i][partition_id][param_id] = False - - self.grad_partition_insertion_offset[i][partition_id][ - param_id] = current_index - start_index - self.grad_start_offset[i][partition_id][param_id] = 0 - - elif start_index > current_index and start_index < (current_index + - param_size): - assert (first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition" - first_offset = start_index - current_index - - set_key_value_list(self.param_to_partition_ids[i], - param_id, - partition_id) - increment_value(self.total_grads_in_partition[i], partition_id) - - self.is_grad_computed[i][partition_id][param_id] = False - - self.grad_partition_insertion_offset[i][partition_id][param_id] = 0 - self.grad_start_offset[i][partition_id][param_id] = first_offset - - current_index = current_index + param_size + # this method gets called after every backward. need to increment + # here because if it gets incremented in backward() the micro step + # id will be off by one when we do the reduce and partition at the. + # start of this method. + # TODO. make this less error prone + self.micro_step_id += 1 def overlapping_partition_gradients_reduce_epilogue(self): self.independent_gradient_partition_epilogue() - self.zero_grad() def create_reduce_and_remove_grad_hooks(self): print_rank_0(f'[Begin] Create gradient reduction hooks') @@ -1845,6 +1768,7 @@ class DeepSpeedZeroOptimizer_Stage3(object): param_tmp = param.expand_as(param) grad_acc = param_tmp.grad_fn.next_functions[0][0] + @instrument_w_nvtx def reduce_partition_and_remove_grads(*notneeded): self.reduce_ready_partitions_and_remove_grads(param, i) @@ -1882,13 +1806,7 @@ class DeepSpeedZeroOptimizer_Stage3(object): self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.ds_numel) - self.reduce_ipg_grads() - - if self.contiguous_gradients and self.overlap_comm: - # Swap ipg_index between 0 and 1 - self.ipg_index = 1 - self.ipg_index - self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads", - param.ds_numel) + self.__reduce_and_partition_ipg_grads() param_id = self.get_param_id(param) assert self.params_already_reduced[param_id] == False, \ @@ -1896,68 +1814,91 @@ class DeepSpeedZeroOptimizer_Stage3(object): Gradient computed twice for this partition. \ Multiple gradient reduction is currently not supported" - # keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening - if param.ds_numel > self.reduce_bucket_size: - self.extra_large_param_to_reduce = param + self.__add_grad_to_ipg_bucket(param) - elif self.contiguous_gradients: - #print_rank_0("before new grad tensor move") - new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow( - 0, - self.elements_in_ipg_bucket, - param.ds_numel) - #print_rank_0("after new grad tensor move") - new_grad_tensor.copy_(param.grad.view(-1)) - param.grad.data = new_grad_tensor.data.view_as(param.grad) + @instrument_w_nvtx + @torch.no_grad() + def __add_grad_to_ipg_bucket(self, param: Parameter) -> None: + self.__reduce_and_partition_stream.wait_stream(torch.cuda.default_stream()) - self.elements_in_ipg_bucket += param.ds_numel - self.grads_in_ipg_bucket.append(param.grad) - self.params_in_ipg_bucket.append((i, param, param_id)) - self.report_ipg_memory_usage("End ipg_remove_grads", 0) + if self.contiguous_gradients and self.elements_in_ipg_bucket + param.grad.numel( + ) < self.reduce_bucket_size: + # move the gradient to a contiguous buffer + with torch.cuda.stream(self.__reduce_and_partition_stream): + # move the parameter's gradient to the contiguous flat buffer + new_grad_tensor = self.__ipg_bucket_flat_buffer.narrow( + 0, + self.elements_in_ipg_bucket, + param.grad.numel()).view_as(param.grad) + new_grad_tensor.copy_(param.grad, non_blocking=True) + param.grad.record_stream(torch.cuda.current_stream()) + param.grad.data = new_grad_tensor + + self.__params_in_ipg_bucket.append(param) + + @instrument_w_nvtx + @torch.no_grad() + def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None: + if not self.__params_in_ipg_bucket: + return - def gradient_reduction_w_predivide(self, tensor): - dp_world_size = dist.get_world_size(group=self.dp_process_group) + for param in self.__params_in_ipg_bucket: + if param.grad.numel() != param.ds_numel: + raise RuntimeError( + f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter " + f"gradients whose size is not same as the params") - tensor_to_allreduce = tensor + self.__params_in_ipg_bucket.sort(key=lambda p: p.ds_id) - if self.communication_data_type != tensor.dtype: - tensor_to_allreduce = tensor.to(self.communication_data_type) + assert len(set(p.ds_id for p in self.__params_in_ipg_bucket)) == len( + self.__params_in_ipg_bucket) - if self.postscale_gradients: - if self.gradient_predivide_factor != 1.0: - tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor) + while self.__param_reduce_events and self.__param_reduce_events[0].query(): + self.__param_reduce_events.popleft() + if len(self.__param_reduce_events) > self.__max_param_reduce_events: + self.__param_reduce_events.popleft().synchronize() - dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) + with torch.cuda.stream(self.__reduce_and_partition_stream): + if safe_mode: + assert_ints_same_as_other_ranks( + [p.ds_id for p in self.__params_in_ipg_bucket]) - if self.gradient_predivide_factor != dp_world_size: - tensor_to_allreduce.mul_(self.gradient_predivide_factor / dp_world_size) - else: - tensor_to_allreduce.div_(dp_world_size) - dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) + grad_partitions = self.__avg_scatter_grads(self.__params_in_ipg_bucket) + self.__partition_grads(self.__params_in_ipg_bucket, grad_partitions) - if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: - tensor.copy_(tensor_to_allreduce) + self.__params_in_ipg_bucket.clear() - return tensor + event = Event() + event.record() + self.__param_reduce_events.append(event) - def average_tensor(self, tensors, params_to_reduce): - with torch.cuda.stream(self.reduction_stream): - if not self.reduce_scatter: - for tensor in tensors: - self.gradient_reduction_w_predivide(tensor) - return - - for tensor in tensors: - tensor.div_(dist.get_world_size(group=self.dp_process_group)) - - # reduction resulting with each rank only holding the gradient partition it owns - # This could either be a reduce scatter or a reduce op depending on how - # parameters are partitionied. The method is implemented by the - # DeepSpeed param extensions to the pytorch parameter, so its up to - # the extension to define what happens here - params_to_reduce[0].reduce_gradients_at_owner( - param_list=params_to_reduce, - hierarchy=self.param_coordinator.hierarchy) + @instrument_w_nvtx + def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor]: + """average gradients and scatter partitions across ranks""" + dtype = get_only_unique_item(p.grad.dtype for p in params_to_reduce) + + full_grads_for_rank = [p.grad for p in params_to_reduce] + if self.communication_data_type == torch.float32: + full_grads_for_rank = [g.float() for g in full_grads_for_rank] + + if self.postscale_gradients and self.gradient_predivide_factor != 1.0: + full_grads_for_rank = [ + g.div(self.gradient_predivide_factor) for g in full_grads_for_rank + ] + + grad_partitions_for_rank = reduce_scatter_coalesced(full_grads_for_rank, + self.dp_process_group) + + if self.postscale_gradients and self.gradient_predivide_factor != dist.get_world_size( + self.dp_process_group): + grad_partitions_for_rank = [ + g.mul(self.gradient_predivide_factor) for g in grad_partitions_for_rank + ] + + if self.communication_data_type == torch.float32: + grad_partitions_for_rank = [g.to(dtype) for g in grad_partitions_for_rank] + + return grad_partitions_for_rank def set_grad_positions(self): for i, group in enumerate(self.fp16_groups): @@ -1974,23 +1915,6 @@ class DeepSpeedZeroOptimizer_Stage3(object): #print(f"param id {param_id} i:{i}, ds_tensor {num_elements} numel {param.numel()}") current_offset += num_elements - def async_accumulate_grad_in_cpu_via_gpu(self, param, acc_grad_cpu_partition): - - # copy to a preexisiting buffer to avoid memory allocation penalty - dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow( - 0, - 0, - param.ds_tensor.ds_numel) - - if self.micro_step_id > 0: - dest_buffer.copy_(acc_grad_cpu_partition.view(-1), non_blocking=True) - param.grad.data.view(-1).add_(dest_buffer) - - # at the boundary we will send 32bit directly - if not self.is_gradient_accumulation_boundary: - acc_grad_cpu_partition.data.copy_(param.grad.data.view(-1), - non_blocking=True) - def _constant_buffered_norm2(self, input, buffer_size=250000000): norm = None for part in input.view(-1).split(buffer_size): @@ -2006,14 +1930,6 @@ class DeepSpeedZeroOptimizer_Stage3(object): #Using a more memory efficient version self.norm_for_param_grads[param_id] = self._constant_buffered_norm2(param.grad) - def update_overflow_tracker_for_param_grad(self, param): - #Credit to our user David Minn - if param.grad is not None: - if self.overlap_comm: - self.gpu_sum = self.gpu_sum + param.grad.data.float().sum() - elif self._has_inf_or_nan(param.grad.data): - self.local_overflow = True - def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param, fp32_grad_tensor): with torch.cuda.stream(self.copy_grad_stream): param_id = self.get_param_id(param) @@ -2050,143 +1966,82 @@ class DeepSpeedZeroOptimizer_Stage3(object): return total_norm - def partition_previous_reduced_grads(self): - if not self.previous_reduced_grads: - return - - if self.offload_optimizer: - allocate_grads_in_partition = self.grads_in_partition is None\ - and self.gradient_accumulation_steps > 1 - else: - allocate_grads_in_partition = self.grads_in_partition is None - - if allocate_grads_in_partition: - self.grads_in_partition = [] - - for i, group in enumerate(self.fp16_groups): - total_size = 0 - for param_in_partition in group: - total_size += param_in_partition.ds_tensor.ds_numel - - see_memory_usage( - f"group {i} before creating {total_size} reduced gradients into partition", - force=False) - if self.offload_param_pin_memory: - self.grads_in_partition.append( - torch.zeros(int(total_size), - dtype=self.dtype, - device=self.device).pin_memory()) - else: - self.grads_in_partition.append( - torch.zeros(int(total_size), - dtype=self.dtype, - device=self.device)) - see_memory_usage( - f"group {i} after creating {total_size} reduced gradients into partition", - force=False) - - if self.offload_optimizer: - offload_fp32_gradients = {} - offload_fp32_offsets = {} - - with torch.cuda.stream(self.copy_grad_stream): - self.reduction_stream.synchronize() - for param in self.previous_reduced_grads: - - [i, - dest_offset, - num_elements] = self.grad_position[self.get_param_id(param)] - - if self.offload_optimizer: - param.partition_gradients( - partition_buffers=self.temp_grad_gpu_buffer) - #with torch.cuda.stream(self.copy_grad_stream): - # self.reduction_stream.synchronize() - - if self.gradient_accumulation_steps > 1: - # The allreduce buffer will be rewritten. Copy the gradients in partition to a new buffer - fp16_grad_tensor = self.grads_in_partition[i].narrow( - 0, - dest_offset, - num_elements) - self.async_accumulate_grad_in_cpu_via_gpu( - param, - fp16_grad_tensor) - - if self.is_gradient_accumulation_boundary: - - self.set_norm_for_param_grad_in_gpu(param) - - self.update_overflow_tracker_for_param_grad(param) - - if self._swappable_optimizer_subgroup(i): - if not i in offload_fp32_gradients.keys(): - offload_fp32_gradients[i] = [] - offload_fp32_offsets[i] = [] - - offload_fp32_gradients[i].append(param.grad.view(-1).float()) - param.grad = None - offload_fp32_offsets[i].append(dest_offset) - else: - fp32_grad_tensor = self.fp32_partitioned_groups_flat[ - i].grad.narrow(0, - dest_offset, - num_elements) - - self.async_inplace_copy_grad_to_fp32_buffer_from_gpu( - param, - fp32_grad_tensor) - else: - # The allreduce buffer will be rewritten. Copy the gradients in partition to a new buffer - fp16_grad_tensor = self.grads_in_partition[i].narrow( - 0, - dest_offset, - num_elements) - param.partition_gradients( - partition_buffers=fp16_grad_tensor, - accumulate=True if self.micro_step_id > 0 else False) - - if self.offload_optimizer and self.swap_optimizer: - for i in offload_fp32_gradients.keys(): - self.optimizer_swapper.swap_out_gradients( - parameter=self.fp32_partitioned_groups_flat[i], - gradient_offsets=offload_fp32_offsets[i], - gradient_tensors=offload_fp32_gradients[i]) - - self.previous_reduced_grads = [] + @instrument_w_nvtx + def __partition_grads(self, + params_to_release: List[Parameter], + grad_partitions: List[Tensor]) -> None: + for param, grad_partition in zip(params_to_release, grad_partitions): + if param.ds_tensor.ds_numel * dist.get_rank( + self.dp_process_group) > param.ds_numel: + # this grad partition is empty - don't need to do anything + continue - def reduce_ipg_grads(self, extra_param=None): - if self.overlap_comm: - self.reduction_stream.synchronize() + # move or accumulate gradient partition to target buffer + grad_buffer = self.__param_id_to_grad_partition[param.ds_id].narrow( + 0, + 0, + grad_partition.numel()) + if self.micro_step_id == 0: # don't accumulate + grad_buffer.copy_(grad_partition, non_blocking=True) + # ensure grad buffer is a CUDA buffer to speed up the next few + # operations and so it can be used asynchronously + grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True) + elif grad_buffer.is_cuda: + grad_buffer.add_(grad_partition) + else: + # if dst is CPU, copy first to src device, do the addition + # there, then move back to dst. adding directly to cpu is very slow + cuda_grad_buffer = grad_buffer.to(grad_partition.device, + non_blocking=True) + cuda_grad_buffer.add_(grad_partition) + grad_buffer.copy_(cuda_grad_buffer, non_blocking=True) + # ensure grad buffer is a CUDA buffer to speed up the next few + # operations and so it can be used asynchronously + grad_buffer = cuda_grad_buffer + + if hasattr(self.__inf_or_nan_tracker, "logical_or_"): + self.__inf_or_nan_tracker.logical_or_(torch.isinf(grad_buffer).any()) + self.__inf_or_nan_tracker.logical_or_(torch.isnan(grad_buffer).any()) + else: + # logical_or_ not available in older versions of pytorch + self.__inf_or_nan_tracker += torch.isinf(grad_buffer).any() + self.__inf_or_nan_tracker += torch.isnan(grad_buffer).any() + self.__inf_or_nan_tracker = self.__inf_or_nan_tracker > 0 - with torch.cuda.stream(self.reduction_stream): - self.partition_previous_reduced_grads() + # offload the gradient partition if applicable + if self.offload_optimizer: + i, dest_offset, _ = self.grad_position[self.get_param_id(param)] + offload_fp32_gradients = {} + offload_fp32_offsets = {} - params_to_reduce = [param for i, param, param_id in self.params_in_ipg_bucket] - #print(f"Params in ipg bucket {self.params_in_ipg_bucket}") - #print(f"Reducing {[(debug_param2name_id_shape(param), param.grad) for param in params_to_reduce]}") - #exit(0) - if self.contiguous_gradients: - reduction_list = [self.ipg_buffer[self.ipg_index]] - if self.extra_large_param_to_reduce is not None: - reduction_list.append(self.extra_large_param_to_reduce.grad) - self.extra_large_param_to_reduce = None - self.average_tensor(reduction_list, params_to_reduce) - else: - self.buffered_reduce_fallback( - None, - self.grads_in_ipg_bucket, - elements_per_buffer=self.elements_in_ipg_bucket) + if self.is_gradient_accumulation_boundary: + self.norm_for_param_grads[self.get_param_id( + param)] = self._constant_buffered_norm2(grad_buffer) - for _, param, param_id in self.params_in_ipg_bucket: - self.params_already_reduced[param_id] = True + if self._swappable_optimizer_subgroup(i): + if not i in offload_fp32_gradients.keys(): + offload_fp32_gradients[i] = [] + offload_fp32_offsets[i] = [] - self.previous_reduced_grads = params_to_reduce + offload_fp32_gradients[i].append(grad_buffer.float()) + offload_fp32_offsets[i].append(dest_offset) + else: + fp32_grad_tensor = self.fp32_partitioned_groups_flat[ + i].grad.narrow(0, + dest_offset, + grad_buffer.numel()) + fp32_grad_tensor.copy_(grad_buffer) + + # free the gradient + param.grad.record_stream(torch.cuda.current_stream()) + param.grad = None - self.grads_in_ipg_bucket = [] - self.params_in_ipg_bucket = [] - self.elements_in_ipg_bucket = 0 - ##################################################################### + if self.offload_optimizer and self.swap_optimizer: + for i in offload_fp32_gradients.keys(): + self.optimizer_swapper.swap_out_gradients( + parameter=self.fp32_partitioned_groups_flat[i], + gradient_offsets=offload_fp32_offsets[i], + gradient_tensors=offload_fp32_gradients[i]) def reduce_ready_partitions_and_remove_grads(self, param, i): #print_rank_0(f"Backward {debug_param2name_id_shape(param)}", force=True) @@ -2315,20 +2170,6 @@ class DeepSpeedZeroOptimizer_Stage3(object): if len(small_bucket) > 0: self.allreduce_and_copy(small_bucket, rank=rank, log=log) - # allows using reduction of gradients instead of using all_reduce - def buffered_reduce_fallback(self, - rank, - grads, - elements_per_buffer=500000000, - log=None): - split_buckets = split_half_float_double(grads) - - for i, bucket in enumerate(split_buckets): - self.allreduce_no_retain(bucket, - numel_per_bucket=elements_per_buffer, - rank=rank, - log=log) - ############################################################################# ############################################################################# ############################################################################# @@ -2386,15 +2227,20 @@ class DeepSpeedZeroOptimizer_Stage3(object): return params_in_partition, params_not_in_partition, first_offset + @instrument_w_nvtx def zero_grad(self, set_grads_to_None=True): """ Zero FP16 parameter grads. """ + self.micro_step_id = 0 + # FP32 grad should never exist. # For speed, set model fp16 grad to None by default for group in self.fp16_groups: for p in group: if set_grads_to_None: + if p.grad is not None and p.grad.is_cuda: + p.grad.record_stream(torch.cuda.current_stream()) p.grad = None else: if p.grad is not None: @@ -2411,6 +2257,7 @@ class DeepSpeedZeroOptimizer_Stage3(object): op=op, group=self.model_parallel_group) + @instrument_w_nvtx def get_grad_norm_direct(self, gradients, params, norm_type=2): """Clips gradient norm of an iterable of parameters. @@ -2441,15 +2288,15 @@ class DeepSpeedZeroOptimizer_Stage3(object): op=torch.distributed.ReduceOp.MAX) total_norm = total_norm_cuda[0].item() else: - total_norm = 0.0 # if dist.get_rank() == 0: # logger.info(f"Total Norm beginning {total_norm}") + grad_norms = [] for g, p in zip(gradients, params): if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): - param_norm = g.data.double().norm(2) - total_norm += param_norm.item()**2 + grad_norms.append(g.cuda(non_blocking=True).double().norm(2)) + # Sum across all model parallel GPUs. - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + total_norm_cuda = torch.sum(torch.pow(torch.stack(grad_norms), 2)) torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, @@ -2458,7 +2305,7 @@ class DeepSpeedZeroOptimizer_Stage3(object): self._model_parallel_all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM) - total_norm = total_norm_cuda[0].item()**(1. / norm_type) + total_norm = total_norm_cuda.item()**(1. / norm_type) if total_norm == float( 'inf') or total_norm == -float('inf') or total_norm != total_norm: @@ -2546,19 +2393,17 @@ class DeepSpeedZeroOptimizer_Stage3(object): self.timers(name).stop() def _pre_step(self): - self.micro_step_id = INITIAL_MICRO_STEP_ID + self.micro_step_id = 0 print_rank_0(f"Inside Step function") see_memory_usage(f"In step before checking overflow", force=False) print_rank_0("Finished Tracing at Beginning of Step") self.param_coordinator.hierarchy = 0 - self.param_coordinator.finish_tracing(print_trace=True) - - self.param_coordinator.reset_step() print_rank_0("Finished Tracing at Beginning of Step") + @instrument_w_nvtx def _get_norm_groups(self): norm_groups = [] for i, group in enumerate(self.fp16_groups): @@ -2572,6 +2417,7 @@ class DeepSpeedZeroOptimizer_Stage3(object): self.fp16_groups[i])) return norm_groups + @instrument_w_nvtx def _prepare_fp32_grad_for_sub_group(self, sub_group_id): partition_id = dist.get_rank(group=self.dp_process_group) @@ -2587,8 +2433,12 @@ class DeepSpeedZeroOptimizer_Stage3(object): # release all the gradient since we have already created a necessary copy in dp_grad_partition self.zero_grad() + for grad in filter(lambda g: g.is_cuda, self.averaged_gradients[sub_group_id]): + grad.record_stream(torch.cuda.current_stream()) + self.averaged_gradients[sub_group_id] = None + @instrument_w_nvtx def _prepare_sub_group(self, sub_group_id, timer_names=set()): see_memory_usage(f'Before prepare optimizer sub group {sub_group_id}', force=False) @@ -2619,6 +2469,7 @@ class DeepSpeedZeroOptimizer_Stage3(object): see_memory_usage(f'pre-step After swapping in optimizer tensors {sub_group_id}', force=False) + @instrument_w_nvtx def _release_sub_group(self, sub_group_id, timer_names=set()): see_memory_usage(f'Before release optimizer sub group {sub_group_id}', force=False) @@ -2632,6 +2483,7 @@ class DeepSpeedZeroOptimizer_Stage3(object): force=False) # create a flat tensor aligned at the alignment boundary + @instrument_w_nvtx def flatten_dense_tensors_aligned(self, tensor_list, alignment): num_elements = 0 for tens in tensor_list: @@ -2703,6 +2555,7 @@ class DeepSpeedZeroOptimizer_Stage3(object): prev_scale, self.loss_scale)) + @instrument_w_nvtx def _overflow_check_and_loss_scale_update(self): # First compute norm for all group so we know if there is overflow @@ -2717,6 +2570,7 @@ class DeepSpeedZeroOptimizer_Stage3(object): return self.overflow + @instrument_w_nvtx def _post_step(self, timer_names=set()): if self.offload_optimizer: self.reset_cpu_buffers() @@ -2733,6 +2587,7 @@ class DeepSpeedZeroOptimizer_Stage3(object): see_memory_usage('After zero_optimizer step', force=False) print_rank_0(f"------------------Finishing Step-----------------------") + @instrument_w_nvtx def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id): if self.fp16_partitioned_groups_flat[sub_group_id] is not None: self.fp16_partitioned_groups_flat[sub_group_id].data.copy_( @@ -2743,11 +2598,13 @@ class DeepSpeedZeroOptimizer_Stage3(object): else: self._partitioned_params_swap_out(sub_group_id) + @instrument_w_nvtx def step(self, closure=None): """ Not supporting closure. """ self._pre_step() + self._partition_all_parameters() #checks for overflow, adjust the loss scale accordingly if self._overflow_check_and_loss_scale_update(): @@ -2784,7 +2641,23 @@ class DeepSpeedZeroOptimizer_Stage3(object): self.stop_timers(['optimizer_step']) self._post_step(timer_names) - return + + # warn user about caching allocator flushes + alloc_retries = torch.cuda.memory_stats()["num_alloc_retries"] if hasattr( + torch.cuda, + "memory_stats") else 0 + if alloc_retries > self.__n_caching_allocator_flushes: + if dist.get_rank() == 0: + logger.warning( + "%d pytorch allocator cache flushes since last step. this happens " + "when there is high memory pressure and is detrimental to " + "performance. if this is happening frequently consider adjusting " + "settings to reduce memory consumption. If you are unable to " + "make the cache flushes go away consider adding " + "torch.cuda.empty_cache() calls in your training loop to ensure " + "that all ranks flush their caches at the same time", + alloc_retries - self.__n_caching_allocator_flushes) + self.__n_caching_allocator_flushes = alloc_retries def dump_pre_step_gradients(self, debug_fp32_grads): # Dump gradient norms for debugging @@ -2819,9 +2692,8 @@ class DeepSpeedZeroOptimizer_Stage3(object): norm_list = [param_norm, ds_norm] + unflat_norm print(f'Post-Step Norms {i} {param_id} = {norm_list}') + @instrument_w_nvtx def unscale_and_clip_grads(self, sub_group_id, total_norm): - grad_groups_flat = [self.fp32_partitioned_groups_flat[sub_group_id].grad] - # compute combined scale factor for this group combined_scale = self.loss_scale if self.clip_grad > 0.: @@ -2830,13 +2702,7 @@ class DeepSpeedZeroOptimizer_Stage3(object): if clip > 1: combined_scale = clip * self.loss_scale - for grad in grad_groups_flat: - if isinstance(grad, list): - sub_partitions = grad - for g in sub_partitions: - g.data.mul_(1. / combined_scale) - else: - grad.data.mul_(1. / combined_scale) + self.fp32_partitioned_groups_flat[sub_group_id].grad.mul_(1. / combined_scale) def _check_overflow(self, partition_gradients=True): self.overflow = self.has_overflow(partition_gradients) @@ -2856,14 +2722,14 @@ class DeepSpeedZeroOptimizer_Stage3(object): return True return False + @instrument_w_nvtx def has_overflow(self, partition_gradients=True): if partition_gradients: - if self.overlap_comm: - self.local_overflow = self._has_inf_or_nan(self.gpu_sum) - self.gpu_sum = torch.zeros(1, dtype=torch.float).cuda() + with torch.cuda.stream(self.__reduce_and_partition_stream): + self.local_overflow = bool(self.__inf_or_nan_tracker.item()) + self.__inf_or_nan_tracker.zero_() - overflow = self.local_overflow if self.offload_optimizer else self.has_overflow_partitioned_grads_serial( - ) + overflow = self.local_overflow #overflow = self.has_overflow_partitioned_grads_serial() overflow_gpu = torch.cuda.ByteTensor([overflow]) torch.distributed.all_reduce(overflow_gpu, @@ -2909,6 +2775,7 @@ class DeepSpeedZeroOptimizer_Stage3(object): return True return False + @instrument_w_nvtx def backward(self, loss, retain_graph=False): """ :attr:`backward` performs the following steps: @@ -2917,42 +2784,49 @@ class DeepSpeedZeroOptimizer_Stage3(object): 2. scaled_loss = fp32_loss*loss_scale 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves """ - self.micro_step_id += 1 - print_rank_0( - f"Total fully available parameters {self.param_coordinator.total_available_parameter_numel}" - ) - if self.swap_optimizer: self.optimizer_swapper.pre_backward() see_memory_usage(f"Before backward", force=False) - if self.contiguous_gradients: - self.ipg_buffer = [] - buf_0 = torch.empty(self.reduce_bucket_size, - dtype=self.dtype, - device=torch.cuda.current_device()) - self.ipg_buffer.append(buf_0) - - # Use double buffers to avoid data access conflict when overlap_comm is enabled. - if self.overlap_comm: - buf_1 = torch.empty(self.reduce_bucket_size, - dtype=self.dtype, - device=torch.cuda.current_device()) - self.ipg_buffer.append(buf_1) - self.ipg_index = 0 self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) - '''Partitioning Parameters that were not partitioned - Usually if parameters of modules whose input parameters do not require - grad computation do not trigger post call and will therefore will remain unpartitioned ''' - self._partition_all_parameters() + + self.param_coordinator.reset_step() if self.swap_optimizer: self.optimizer_swapper.post_backward() + def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]: + """get fp32 gradient partition dictionary + accessed as grad_dict[parameter_group_index][parameter_index] + """ + self.__reduce_and_partition_stream.synchronize() + grad_dict = collections.defaultdict(dict) + if self.offload_optimizer: + for group in self.fp16_groups: + for param_idx, param in enumerate(group): + group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] + fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow( + 0, + dest_offset, + num_elements) + grad_dict[group_idx][param_idx] = fp32_grad + else: + for group_idx, group in self.averaged_gradients.items(): + for param_idx, gradient in enumerate(group): + grad_dict[group_idx][param_idx] = gradient.float() + + return grad_dict + + @instrument_w_nvtx def _partition_all_parameters(self): - for name, param in self.module.named_parameters(recurse=True): - self.param_coordinator.release_and_reset_parameter(param) + """Partitioning Parameters that were not partitioned usually if parameters + of modules whose input parameters do not require grad computation do not + trigger post call and will therefore will remain unpartitioned""" + self.param_coordinator.release_and_reset_all() + for param in iter_params(self.module, recurse=True): + if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: + raise RuntimeError(f"{param.ds_summary()} expected to be released") def check_overflow(self, partition_gradients=True): self._check_overflow(partition_gradients) diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py index 4554e3c0..0017213a 100755 --- a/deepspeed/runtime/zero/utils.py +++ b/deepspeed/runtime/zero/utils.py @@ -1,8 +1,12 @@ +import os +from typing import List + import torch import torch.distributed as dist from deepspeed.utils import logger from deepspeed.ops.adam import DeepSpeedCPUAdam from deepspeed.ops.adam import FusedAdam +from deepspeed.utils.nvtx import instrument_w_nvtx def _initialize_parameter_parallel_groups(parameter_parallel_size=None): @@ -47,5 +51,37 @@ def is_zero_supported_optimizer(optimizer): return type(optimizer) in ZERO_SUPPORTED_OPTIMIZERS +def get_lst_from_rank0(lst: List[int]) -> None: + """ + NOTE: creates both communication and synchronization overhead so should be used + sparingly + """ + lst_tensor = torch.tensor( + lst if dist.get_rank() == 0 else [-1] * len(lst), + dtype=int, + # device=torch.cuda.current_device(), + device=torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"])), + requires_grad=False, + ) + dist.broadcast(lst_tensor, src=0, async_op=False) + + return list(lst_tensor.cpu().numpy()) + + +@instrument_w_nvtx +def assert_ints_same_as_other_ranks(ints: List[int]) -> None: + """ + NOTE: creates both communication and synchronization overhead so should be + used sparingly + + takes a list of ints from each rank and ensures that they are the same + across ranks, throwing an exception if they are not. + """ + rank0_ints = get_lst_from_rank0(ints) + if ints != rank0_ints: + raise RuntimeError(f"disagreement between rank0 and rank{dist.get_rank()}: " + f"rank0: {rank0_ints}, rank{dist.get_rank()}: {ints}") + + class ZeRORuntimeException(Exception): pass diff --git a/deepspeed/utils/__init__.py b/deepspeed/utils/__init__.py index 892eea16..efa45853 100644 --- a/deepspeed/utils/__init__.py +++ b/deepspeed/utils/__init__.py @@ -1,4 +1,5 @@ from .logging import logger, log_dist from .distributed import init_distributed from .groups import * +from .nvtx import instrument_w_nvtx from deepspeed.runtime.dataloader import RepeatingLoader diff --git a/deepspeed/utils/nvtx.py b/deepspeed/utils/nvtx.py new file mode 100644 index 00000000..11579ffb --- /dev/null +++ b/deepspeed/utils/nvtx.py @@ -0,0 +1,15 @@ +import torch + + +def instrument_w_nvtx(func): + """decorator that causes an NVTX range to be recorded for the duration of the + function call.""" + if hasattr(torch.cuda.nvtx, "range"): + + def wrapped_fn(*args, **kwargs): + with torch.cuda.nvtx.range(func.__qualname__): + return func(*args, **kwargs) + + return wrapped_fn + else: + return func diff --git a/deepspeed/utils/timer.py b/deepspeed/utils/timer.py index 386784d1..6f5076ea 100755 --- a/deepspeed/utils/timer.py +++ b/deepspeed/utils/timer.py @@ -178,11 +178,15 @@ class ThroughputTimer: self.total_elapsed_time += duration if self.local_step_count % self.steps_per_output == 0: if report_speed: - self.logging("{}/{}, SamplesPerSec={}".format( - self.epoch_count, - self.local_step_count, - self.avg_samples_per_sec(), - )) + self.logging( + "{}/{}, SamplesPerSec={}, MemAllocated={}GB, MaxMemAllocated={}GB" + .format(self.epoch_count, + self.local_step_count, + self.avg_samples_per_sec(), + round(torch.cuda.memory_allocated() / 1024**3, + 2), + round(torch.cuda.max_memory_allocated() / 1024**3, + 2))) if self.monitor_memory: virt_mem = psutil.virtual_memory() swap = psutil.swap_memory() diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index d7b47de2..84a7cd1f 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -251,7 +251,7 @@ Example of **scheduler** | Configuration for using [bfloat16](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format) floating-point format as an alternative to FP16. BFLOAT16 requires hardware support (e.g., NVIDIA A100). An example, including the available dictionary keys is illustrated below. Training with bfloat16 does not require loss scaling. | None | ```json -"bfloat16": { +"bf16": { "enabled": true } ``` @@ -329,7 +329,7 @@ Enabling and configuring ZeRO memory optimizations "stage3_param_persistence_threshold" : 1e6, "sub_group_size" : 1e12, "elastic_checkpoint" : [true|false], - "stage3_gather_fp16_weights_on_model_save": [true|false], + "stage3_gather_16bit_weights_on_model_save": [true|false], "ignore_unused_parameters": [true|false] "round_robin_gradients": [true|false] } @@ -433,11 +433,11 @@ Enabling and configuring ZeRO memory optimizations | Do not partition parameters smaller than this threshold. Smaller values use less memory, but can greatly increase communication (especially latency-bound messages). | `1e6` | -***stage3_gather_fp16_weights_on_model_save***: [boolean] +***stage3_gather_16bit_weights_on_model_save***: [boolean] | Description | Default | |--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| ------- | -| Consolidate the weights before saving the model by `save_fp16_model()`. Since the weights are partitioned across GPUs, they aren't part of `state_dict`, so this function automatically gathers the weights when this option is enabled and then saves the fp16 model weights. | `False` | +| Consolidate the weights before saving the model by `save_16bit_model()`. Since the weights are partitioned across GPUs, they aren't part of `state_dict`, so this function automatically gathers the weights when this option is enabled and then saves the fp16 model weights. | `False` | ***cpu_offload***: [boolean] diff --git a/docs/_tutorials/zero.md b/docs/_tutorials/zero.md index 411ddf34..7721f45e 100644 --- a/docs/_tutorials/zero.md +++ b/docs/_tutorials/zero.md @@ -252,19 +252,19 @@ If you need to take the pretrained weights out of Deepspeed here is what you can ```json "zero_optimization": { - "stage3_gather_fp16_weights_on_model_save": true + "stage3_gather_16bit_weights_on_model_save": true }, ``` And then save the model using: ```python if self.deepspeed: - self.deepspeed.save_fp16_model(output_dir, output_file) + self.deepspeed.save_16bit_model(output_dir, output_file) ``` Because it requires consolidation of the weights on one GPU it can be slow and memory demanding, so only use this feature when needed. -Note that if `stage3_gather_fp16_weights_on_model_save` is `False`, no weights will be saved (again, because `state_dict` doesn't have them). +Note that if `stage3_gather_16bit_weights_on_model_save` is `False`, no weights will be saved (again, because `state_dict` doesn't have them). You can use this method to save ZeRO-2 weights as well. If you'd like to get the fp32 weights, we supply a special script that can do offline consolidation. It requires no configuration files or GPUs. Here is an example of its usage: diff --git a/docs/code-docs/source/training.rst b/docs/code-docs/source/training.rst index 52e124fc..e3a7029a 100644 --- a/docs/code-docs/source/training.rst +++ b/docs/code-docs/source/training.rst @@ -35,7 +35,7 @@ Gradient Accumulation Model Saving ------------ -.. autofunction:: deepspeed.DeepSpeedEngine.save_fp16_model +.. autofunction:: deepspeed.DeepSpeedEngine.save_16bit_model Additionally when a DeepSpeed checkpoint is created, a script ``zero_to_fp32.py`` is added there which can be used to reconstruct fp32 master weights into a single pytorch ``state_dict`` file. diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/megatron_model.py b/tests/unit/megatron_model.py index 0957eab0..fd2ef69b 100644 --- a/tests/unit/megatron_model.py +++ b/tests/unit/megatron_model.py @@ -1,9 +1,10 @@ +from pathlib import Path import torch import os import sys import math -from common import get_test_path +from .common import get_test_path from deepspeed.pipe import PipelineModule, LayerSpec diff --git a/tests/unit/test_activation_checkpointing.py b/tests/unit/test_activation_checkpointing.py index 6c96b5a3..e66f2abf 100644 --- a/tests/unit/test_activation_checkpointing.py +++ b/tests/unit/test_activation_checkpointing.py @@ -10,7 +10,7 @@ import deepspeed ckpt = deepspeed.checkpointing.checkpoint -from common import distributed_test +from .common import distributed_test def _compute(module, *inputs, do_checkpoint=False): diff --git a/tests/unit/test_adamw.py b/tests/unit/test_adamw.py index 83e0b543..b4bfbf3c 100644 --- a/tests/unit/test_adamw.py +++ b/tests/unit/test_adamw.py @@ -2,10 +2,10 @@ import deepspeed import torch import pytest -from common import distributed_test from deepspeed.ops.adam import FusedAdam from deepspeed.ops.adam import DeepSpeedCPUAdam -from simple_model import SimpleModel, args_from_dict +from .common import distributed_test +from .simple_model import SimpleModel, args_from_dict # yapf: disable #'optimizer, zero_offload, torch_adam, adam_w_mode, resulting_optimizer diff --git a/tests/unit/test_aio.py b/tests/unit/test_aio.py index 48272bad..fdec95a3 100755 --- a/tests/unit/test_aio.py +++ b/tests/unit/test_aio.py @@ -4,8 +4,8 @@ import filecmp import torch import deepspeed import torch.distributed as dist -from common import distributed_test from deepspeed.ops.aio import AsyncIOBuilder +from .common import distributed_test MEGA_BYTE = 1024**2 BLOCK_SIZE = MEGA_BYTE diff --git a/tests/unit/test_autotuning.py b/tests/unit/test_autotuning.py index 96617f8f..2a7898b8 100644 --- a/tests/unit/test_autotuning.py +++ b/tests/unit/test_autotuning.py @@ -1,7 +1,7 @@ import os import pytest import torch -from simple_model import create_config_from_dict +from .simple_model import create_config_from_dict from deepspeed.launcher import runner as dsrun from deepspeed.autotuning.autotuner import Autotuner from deepspeed.autotuning.scheduler import ResourceManager diff --git a/tests/unit/test_bf16.py b/tests/unit/test_bf16.py index 9220ce7e..aa2ab132 100644 --- a/tests/unit/test_bf16.py +++ b/tests/unit/test_bf16.py @@ -3,10 +3,10 @@ import torch import deepspeed import pytest from deepspeed.ops.adam import FusedAdam -from common import distributed_test +from .common import distributed_test from deepspeed.ops.op_builder import CPUAdamBuilder -from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict -from util import bf16_required_version_check +from .simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict +from .util import bf16_required_version_check @pytest.mark.parametrize('zero_stage, use_cpu_offload', [(2, False)]) @@ -45,7 +45,7 @@ def test_adam_bf16_zero_onecycle_compatibility(tmpdir, zero_stage, use_cpu_offlo "fp16": { "enabled": False }, - "bfloat16": { + "bf16": { "enabled": True }, "zero_optimization": { @@ -95,7 +95,7 @@ def test_zero_allow_untested_optimizer(tmpdir, zero_stage, use_cpu_offload): "fp16": { "enabled": False, }, - "bfloat16": { + "bf16": { "enabled": True }, "zero_optimization": { @@ -139,7 +139,7 @@ def test_zero_empty_partition(tmpdir, zero_stage, use_cpu_offload): "fp16": { "enabled": False }, - "bfloat16": { + "bf16": { "enabled": True }, "optimizer": { @@ -199,7 +199,7 @@ def test_zero_supported_client_optimizer(tmpdir, zero_stage, optimizer_construct "fp16": { "enabled": False }, - "bfloat16": { + "bf16": { "enabled": True }, "zero_optimization": { @@ -250,7 +250,7 @@ def test_zero2_reduce_scatter_off(tmpdir): "fp16": { "enabled": False }, - "bfloat16": { + "bf16": { "enabled": True } } @@ -290,7 +290,7 @@ def test_zero_empty_grad(tmpdir, stage): "fp16": { "enabled": False }, - "bfloat16": { + "bf16": { "enabled": True }, "zero_optimization": { diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py index 28c2099e..2f487eef 100755 --- a/tests/unit/test_checkpointing.py +++ b/tests/unit/test_checkpointing.py @@ -16,7 +16,7 @@ PipeTopo = PipeDataParallelTopology from deepspeed.ops.op_builder import FusedLambBuilder, CPUAdamBuilder from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 -from util import required_torch_version +from .util import required_torch_version import itertools import argparse @@ -24,8 +24,8 @@ import pytest import json import os import numbers -from common import distributed_test -from simple_model import * +from .common import distributed_test +from .simple_model import * def compare_deepspeed_states(saved_model, loaded_model): diff --git a/tests/unit/test_coalesced_collectives.py b/tests/unit/test_coalesced_collectives.py new file mode 100644 index 00000000..fb6b5354 --- /dev/null +++ b/tests/unit/test_coalesced_collectives.py @@ -0,0 +1,62 @@ +"""unit tests for coalesced collectives""" + +import pytest + +import torch +import torch.distributed as dist +from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced + +from .common import distributed_test + + +@distributed_test(world_size=2) +def test_reduce_scatter_coalesced_single_input(): + input = torch.full((6, + ), + dist.get_rank(), + dtype=torch.half, + device=torch.cuda.current_device()) + + (output, ) = reduce_scatter_coalesced([input], dist.group.WORLD) + + assert output.shape == (3, ) + assert torch.allclose(output, torch.full_like(output, 0.5)) + + +@distributed_test(world_size=2) +def test_reduce_scatter_coalesced_two_inputs(): + tensor_kwargs = {"device": torch.cuda.current_device(), "dtype": torch.half} + inputs = [ + dist.get_rank() * torch.arange(0, + 6, + **tensor_kwargs), + dist.get_rank() * torch.arange(6, + 9, + **tensor_kwargs), + ] + + output1, output2 = reduce_scatter_coalesced(inputs, dist.group.WORLD) + + if dist.get_rank() == 0: + assert output1.shape == (3, ) + assert torch.allclose(output1, torch.arange(0, 3, **tensor_kwargs) / 2) + assert output2.shape == (2, ) + assert torch.allclose(output2, torch.arange(6, 8, **tensor_kwargs) / 2) + elif dist.get_rank() == 1: + assert output1.shape == (3, ) + assert torch.allclose(output1, torch.arange(3, 6, **tensor_kwargs) / 2) + assert output2.shape == (1, ) + assert torch.allclose(output2, torch.arange(8, 9, **tensor_kwargs) / 2) + + +@distributed_test(world_size=2) +def test_reduce_scatter_coalesced_tensor_smaller_than_world_sz(): + input = torch.zeros((1, ), dtype=torch.half, device=torch.cuda.current_device()) + + (output, ) = reduce_scatter_coalesced([input], dist.group.WORLD) + + if dist.get_rank() == 0: + assert output.shape == (1, ) + assert torch.allclose(output, torch.zeros_like(output)) + elif dist.get_rank() == 1: + assert output.shape == (0, ) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index ad06a851..a88cb293 100755 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -3,13 +3,16 @@ import torch import pytest import json import argparse -from common import distributed_test, get_test_path -from simple_model import SimpleModel, create_config_from_dict, random_dataloader + +from deepspeed.runtime.zero.config import DeepSpeedZeroConfig + +from .common import distributed_test, get_test_path +from .simple_model import SimpleModel, create_config_from_dict, random_dataloader import torch.distributed as dist # A test on its own import deepspeed -from deepspeed.runtime.config import DeepSpeedConfig +from deepspeed.runtime.config import DeepSpeedConfig, get_bfloat16_enabled def test_cuda(): @@ -114,6 +117,32 @@ def test_temp_config_json(tmpdir): assert 'train_batch_size' in config_json +@pytest.mark.parametrize("gather_weights_key", + [ + "stage3_gather_16bit_weights_on_model_save", + "stage3_gather_fp16_weights_on_model_save" + ]) +def test_gather_16bit_params_on_model_save(gather_weights_key): + config_dict = { + "zero_optimization": { + gather_weights_key: True, + }, + } + config = DeepSpeedZeroConfig(config_dict) + + assert config.gather_16bit_weights_on_model_save == True + + +@pytest.mark.parametrize("bf16_key", ["bf16", "bfloat16"]) +def test_get_bfloat16_enabled(bf16_key): + cfg = { + bf16_key: { + "enabled": True, + }, + } + assert get_bfloat16_enabled(cfg) == True + + def test_deprecated_deepscale_config(tmpdir): config_dict = { "train_batch_size": 1, diff --git a/tests/unit/test_configurable_parallel.py b/tests/unit/test_configurable_parallel.py index e6933421..d31e89a7 100755 --- a/tests/unit/test_configurable_parallel.py +++ b/tests/unit/test_configurable_parallel.py @@ -7,10 +7,10 @@ import random import numpy as np import torch.multiprocessing as mp import torch.distributed as dist -from common import distributed_test -from simple_model import args_from_dict, create_deepspeed_args -from megatron_model import get_gpt2_model, get_megatron_version -from megatron_model import MockGPT2ModelPipe as GPT2ModelPipe +from .common import distributed_test +from .simple_model import args_from_dict, create_deepspeed_args +from .megatron_model import get_gpt2_model, get_megatron_version +from .megatron_model import MockGPT2ModelPipe as GPT2ModelPipe from deepspeed.utils import RepeatingLoader TORCH_MAJOR = int(torch.__version__.split('.')[0]) diff --git a/tests/unit/test_cuda_backward.py b/tests/unit/test_cuda_backward.py index 4a449c39..d7faee7c 100755 --- a/tests/unit/test_cuda_backward.py +++ b/tests/unit/test_cuda_backward.py @@ -1,20 +1,13 @@ -import argparse import numpy as np import torch import torch.nn.functional as F import pytest -import json import random -import time import copy from torch import nn -from modelingpreln import BertEncoder as BertEncoderPreln -from modeling import BertEncoder as BertEncoderPostln -from modeling import BertConfig, BertLayerNorm from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig -import deepspeed - -import sys +from .modeling import BertConfig, BertLayerNorm, BertEncoder as BertEncoderPostln +from .modelingpreln import BertEncoder as BertEncoderPreln #if not deepspeed.ops.__installed_ops__['transformer']: #pytest.skip( diff --git a/tests/unit/test_cuda_forward.py b/tests/unit/test_cuda_forward.py index 6f63b695..e07ef16a 100755 --- a/tests/unit/test_cuda_forward.py +++ b/tests/unit/test_cuda_forward.py @@ -8,9 +8,8 @@ import random import time import copy from torch import nn -from modelingpreln import BertEncoder as BertEncoderPreln -from modeling import BertEncoder as BertEncoderPostln -from modeling import BertLayerNorm, BertConfig +from .modelingpreln import BertEncoder as BertEncoderPreln +from .modeling import BertLayerNorm, BertConfig, BertEncoder as BertEncoderPostln from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig import deepspeed diff --git a/tests/unit/test_curriculum_learning.py b/tests/unit/test_curriculum_learning.py index cf0562ab..3677b596 100644 --- a/tests/unit/test_curriculum_learning.py +++ b/tests/unit/test_curriculum_learning.py @@ -7,8 +7,8 @@ import json import os import numpy as np import time -from common import distributed_test -from simple_model import Curriculum_SimpleModel, random_dataloader, args_from_dict +from .common import distributed_test +from .simple_model import Curriculum_SimpleModel, random_dataloader, args_from_dict def test_curriculum_scheduler_fixed_discrete(tmpdir): diff --git a/tests/unit/test_data.py b/tests/unit/test_data.py index 290ca0c3..93510e55 100644 --- a/tests/unit/test_data.py +++ b/tests/unit/test_data.py @@ -2,8 +2,8 @@ from deepspeed.utils import RepeatingLoader import torch import pytest import deepspeed -from common import distributed_test -from simple_model import SimpleModel, args_from_dict, random_dataset +from .common import distributed_test +from .simple_model import SimpleModel, args_from_dict, random_dataset def test_repeating_loader(): diff --git a/tests/unit/test_dist.py b/tests/unit/test_dist.py index 25a5fd22..d3713360 100644 --- a/tests/unit/test_dist.py +++ b/tests/unit/test_dist.py @@ -1,7 +1,7 @@ import torch import torch.distributed as dist -from common import distributed_test +from .common import distributed_test import pytest diff --git a/tests/unit/test_ds_initialize.py b/tests/unit/test_ds_initialize.py index 04e6545b..a9756af6 100644 --- a/tests/unit/test_ds_initialize.py +++ b/tests/unit/test_ds_initialize.py @@ -4,9 +4,9 @@ import torch from torch.optim import Optimizer, Adam, AdamW from torch.optim.lr_scheduler import _LRScheduler, LambdaLR -from simple_model import args_from_dict, SimpleModel, random_dataloader -from common import distributed_test -from util import required_torch_version +from .simple_model import args_from_dict, SimpleModel, random_dataloader +from .common import distributed_test +from .util import required_torch_version import deepspeed from deepspeed.ops.adam import FusedAdam diff --git a/tests/unit/test_dynamic_loss_scale.py b/tests/unit/test_dynamic_loss_scale.py index 302de55c..65a679d9 100755 --- a/tests/unit/test_dynamic_loss_scale.py +++ b/tests/unit/test_dynamic_loss_scale.py @@ -5,8 +5,8 @@ import pytest import json import os import numpy as np -from common import distributed_test -from simple_model import SimpleModel, args_from_dict +from .common import distributed_test +from .simple_model import SimpleModel, args_from_dict def run_model_step(model, gradient_list): diff --git a/tests/unit/test_elastic.py b/tests/unit/test_elastic.py index 62d948d5..353d6def 100644 --- a/tests/unit/test_elastic.py +++ b/tests/unit/test_elastic.py @@ -1,8 +1,8 @@ import pytest import deepspeed -from common import distributed_test +from .common import distributed_test from deepspeed.git_version_info import version as ds_version -from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict +from .simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict base_ds_config = { "elasticity": { diff --git a/tests/unit/test_flops_profiler.py b/tests/unit/test_flops_profiler.py index 48b8f983..173fa7ee 100644 --- a/tests/unit/test_flops_profiler.py +++ b/tests/unit/test_flops_profiler.py @@ -3,8 +3,8 @@ import pytest import deepspeed import deepspeed.runtime.utils as ds_utils from deepspeed.profiling.flops_profiler import FlopsProfiler, get_model_profile -from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict -from common import distributed_test +from .simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict +from .common import distributed_test TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MINOR = int(torch.__version__.split('.')[1]) diff --git a/tests/unit/test_fp16.py b/tests/unit/test_fp16.py index e5b5b8ef..fa21e263 100755 --- a/tests/unit/test_fp16.py +++ b/tests/unit/test_fp16.py @@ -8,10 +8,10 @@ import pytest import json import os from deepspeed.ops.adam import FusedAdam -from common import distributed_test +from .common import distributed_test from deepspeed.ops.op_builder import CPUAdamBuilder -from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args, SimpleMoEModel, sequence_dataloader -from util import required_torch_version +from .simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args, SimpleMoEModel, sequence_dataloader +from .util import required_torch_version try: from apex import amp diff --git a/tests/unit/test_ignore_unused_parameters.py b/tests/unit/test_ignore_unused_parameters.py index 19fd5081..eb26f46c 100644 --- a/tests/unit/test_ignore_unused_parameters.py +++ b/tests/unit/test_ignore_unused_parameters.py @@ -3,8 +3,8 @@ import pytest import json import argparse import os -from common import distributed_test -from simple_model import UnusedParametersModel, random_dataloader, args_from_dict +from .common import distributed_test +from .simple_model import UnusedParametersModel, random_dataloader, args_from_dict from deepspeed.ops.op_builder import CPUAdamBuilder import deepspeed diff --git a/tests/unit/test_lr_schedulers.py b/tests/unit/test_lr_schedulers.py index 85966494..47bcfb1e 100755 --- a/tests/unit/test_lr_schedulers.py +++ b/tests/unit/test_lr_schedulers.py @@ -4,8 +4,8 @@ import argparse import pytest import json import os -from common import distributed_test -from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict +from .common import distributed_test +from .simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict from deepspeed.runtime.lr_schedules import LR_RANGE_TEST, LR_RANGE_TEST_MIN_LR, LR_RANGE_TEST_STEP_RATE, LR_RANGE_TEST_STEP_SIZE, LR_RANGE_TEST_STAIRCASE from deepspeed.runtime.lr_schedules import WARMUP_LR, WARMUP_MIN_LR, WARMUP_MAX_LR, WARMUP_NUM_STEPS, WARMUP_TYPE, WARMUP_LOG_RATE, WARMUP_LINEAR_RATE from deepspeed.runtime.lr_schedules import ONE_CYCLE, CYCLE_MIN_LR, CYCLE_MAX_LR, CYCLE_FIRST_STEP_SIZE, DECAY_LR_RATE, DECAY_STEP_SIZE diff --git a/tests/unit/test_moe.py b/tests/unit/test_moe.py index 126ac510..381e6082 100644 --- a/tests/unit/test_moe.py +++ b/tests/unit/test_moe.py @@ -8,10 +8,10 @@ import pytest import json import os from deepspeed.ops.adam import FusedAdam -from common import distributed_test +from .common import distributed_test from deepspeed.ops.op_builder import CPUAdamBuilder -from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args, SimpleMoEModel, sequence_dataloader -from util import required_torch_version +from .simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args, SimpleMoEModel, sequence_dataloader +from .util import required_torch_version try: from apex import amp diff --git a/tests/unit/test_multi_output_model.py b/tests/unit/test_multi_output_model.py index ccbe7f48..478bdc8d 100755 --- a/tests/unit/test_multi_output_model.py +++ b/tests/unit/test_multi_output_model.py @@ -5,9 +5,9 @@ import pytest from pytest import approx import json import os -from common import distributed_test -from simple_model import args_from_dict -from multi_output_model import MultiOutputModel, multi_output_dataloader +from .common import distributed_test +from .simple_model import args_from_dict +from .multi_output_model import MultiOutputModel, multi_output_dataloader def create_config_dict(micro_batch_size, grad_accumulation_steps, world_size): diff --git a/tests/unit/test_onebit.py b/tests/unit/test_onebit.py index b1a0d8ba..6f113c83 100644 --- a/tests/unit/test_onebit.py +++ b/tests/unit/test_onebit.py @@ -15,9 +15,9 @@ from deepspeed.runtime.pipe.topology import PipeDataParallelTopology, PipeModelD PipeTopo = PipeDataParallelTopology from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec -from common import distributed_test -from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args -from test_pipe import AlexNetPipe, train_cifar +from .common import distributed_test +from .simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args +from .test_pipe import AlexNetPipe, train_cifar TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MINOR = int(torch.__version__.split('.')[1]) diff --git a/tests/unit/test_partition.py b/tests/unit/test_partition.py index 7cd26475..f766e459 100644 --- a/tests/unit/test_partition.py +++ b/tests/unit/test_partition.py @@ -8,7 +8,7 @@ from deepspeed.runtime.utils import partition_balanced from deepspeed.runtime.utils import prefix_sum_inc from deepspeed.runtime.utils import PartitionedTensor -from common import distributed_test +from .common import distributed_test @distributed_test(world_size=4) diff --git a/tests/unit/test_pipe.py b/tests/unit/test_pipe.py index 495d4d72..f7f2b1a1 100755 --- a/tests/unit/test_pipe.py +++ b/tests/unit/test_pipe.py @@ -17,7 +17,7 @@ from deepspeed.runtime.pipe.topology import PipeDataParallelTopology, PipeModelD PipeTopo = PipeDataParallelTopology from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec -from common import distributed_test +from .common import distributed_test def rel_diff(A, B): @@ -25,7 +25,7 @@ def rel_diff(A, B): # All models -from simple_model import args_from_dict +from .simple_model import args_from_dict class AlexNet(nn.Module): diff --git a/tests/unit/test_pipe_module.py b/tests/unit/test_pipe_module.py index e4eb3e53..28110149 100644 --- a/tests/unit/test_pipe_module.py +++ b/tests/unit/test_pipe_module.py @@ -15,8 +15,8 @@ PipeTopo = PipeDataParallelTopology from deepspeed.pipe import PipelineModule, LayerSpec from deepspeed.utils import RepeatingLoader -from common import distributed_test -from simple_model import args_from_dict +from .common import distributed_test +from .simple_model import args_from_dict HIDDEN_DIM = 32 LAYERS = 8 diff --git a/tests/unit/test_pld.py b/tests/unit/test_pld.py index d8fa8488..5d275d16 100755 --- a/tests/unit/test_pld.py +++ b/tests/unit/test_pld.py @@ -2,8 +2,9 @@ import numpy as np import deepspeed import pytest from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop -from common import distributed_test -from simple_model import SimpleModel, PLD_SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict + +from .common import distributed_test +from .simple_model import SimpleModel, PLD_SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict @pytest.mark.parametrize('theta', [0, 0.1, 0.9, 1.0]) diff --git a/tests/unit/test_runtime_utils.py b/tests/unit/test_runtime_utils.py index 612fa130..c27f3e74 100644 --- a/tests/unit/test_runtime_utils.py +++ b/tests/unit/test_runtime_utils.py @@ -8,7 +8,7 @@ import deepspeed.runtime.utils as ds_utils from deepspeed.utils.logging import log_dist import deepspeed.utils.groups as groups -from common import distributed_test +from .common import distributed_test def test_call_to_str(): diff --git a/tests/unit/test_sparse_grads.py b/tests/unit/test_sparse_grads.py index 765550ae..c0e72721 100644 --- a/tests/unit/test_sparse_grads.py +++ b/tests/unit/test_sparse_grads.py @@ -2,7 +2,8 @@ import torch import torch.distributed as dist import deepspeed import pytest -from common import distributed_test +from .common import distributed_test + import deepspeed.utils.groups as groups diff --git a/tests/unit/test_topology.py b/tests/unit/test_topology.py index 17636368..89bb8ec3 100644 --- a/tests/unit/test_topology.py +++ b/tests/unit/test_topology.py @@ -7,7 +7,7 @@ from deepspeed.runtime.pipe.topology import PipelineParallelGrid as Grid from deepspeed.runtime.pipe.topology import ProcessTopology as Topo from deepspeed.runtime.pipe.topology import _prime_factors -from common import distributed_test +from .common import distributed_test def test_topology_2d(): diff --git a/tests/unit/test_zero.py b/tests/unit/test_zero.py index b58eacfa..c2ff33a1 100755 --- a/tests/unit/test_zero.py +++ b/tests/unit/test_zero.py @@ -1,14 +1,20 @@ -import torch +import math +from typing import Dict, List, Set import pytest -import json -import argparse -import os import torch.distributed as dist +import torch +from torch import Tensor +from torch.nn import Linear, Module +from torch.nn.modules.container import ModuleList +from torch.nn.modules.loss import L1Loss +from torch.nn.parameter import Parameter -from common import distributed_test -from simple_model import SimpleModel, random_dataloader, args_from_dict +from .common import distributed_test +from .simple_model import SimpleModel, random_dataloader, args_from_dict import deepspeed +from deepspeed.runtime.engine import DeepSpeedEngine +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint @@ -429,3 +435,747 @@ def test_partition_nccl_alignment(tmpdir, zero_stage, world_size): (2 * nccl_start_alignment_factor) == 0) _test_partition_nccl_alignment(model=model, hidden_dim=hidden_dim) + + +def _ds_initialize_for_param_partitioning_testing(model: Module, + cfg: dict) -> DeepSpeedEngine: + ds_engine, _, _, _ = deepspeed.initialize( + config=cfg, + model=model, + model_parameters=model.parameters() + ) + + return ds_engine + + +def _assert_partition_status(model: Module, + valid_statuses: Set[ZeroParamStatus]) -> None: + for _, param in model.named_parameters(): + assert param.ds_status in valid_statuses, param.ds_summary() + + +def _assert_fully_available(model: Module) -> None: + for _, param in model.named_parameters(): + assert param.ds_status == ZeroParamStatus.AVAILABLE + + +class EltwiseMultiplicationModule(Module): + def __init__(self, weight: Parameter) -> None: + super().__init__() + self.weight = weight + + def forward(self, x: Tensor) -> Tensor: + _assert_fully_available(self) + result = self.weight * x + + return result + + +class EltwiseMultiplicationTestNetwork(Module): + """used for testing purposes""" + def __init__( + self, + weight1: Parameter, + weight2: Parameter, + weight3: Parameter, + ) -> None: + super().__init__() + self.__layer1 = EltwiseMultiplicationModule(weight1) + self.__layer2 = EltwiseMultiplicationModule(weight2) + self.__layer3 = EltwiseMultiplicationModule(weight3) + + self.loss = L1Loss(reduction="none") + + def forward(self, x: Tensor, y: Tensor, prefetching: bool) -> Dict[str, Tensor]: + _assert_partition_status( + self, + { + ZeroParamStatus.NOT_AVAILABLE, + ZeroParamStatus.INFLIGHT, + ZeroParamStatus.AVAILABLE + } if prefetching else {ZeroParamStatus.NOT_AVAILABLE}) + + layerwise_expected_states = { + ZeroParamStatus.INFLIGHT if prefetching else ZeroParamStatus.NOT_AVAILABLE, + ZeroParamStatus.AVAILABLE, + } + + _assert_partition_status(self.__layer1, layerwise_expected_states) + hidden1 = self.__layer1(x) + _assert_partition_status(self.__layer1, {ZeroParamStatus.NOT_AVAILABLE}) + + _assert_partition_status(self.__layer2, layerwise_expected_states) + hidden2 = self.__layer2(hidden1) + _assert_partition_status(self.__layer2, {ZeroParamStatus.NOT_AVAILABLE}) + + _assert_partition_status(self.__layer3, layerwise_expected_states) + y_hat = self.__layer3(hidden2) + _assert_partition_status(self.__layer3, + { + ZeroParamStatus.AVAILABLE + if prefetching else ZeroParamStatus.NOT_AVAILABLE + }) + + loss = self.loss(y_hat, y) + + _assert_partition_status( + self, + { + ZeroParamStatus.NOT_AVAILABLE, + ZeroParamStatus.INFLIGHT, + ZeroParamStatus.AVAILABLE + } if prefetching else {ZeroParamStatus.NOT_AVAILABLE}) + + return { + "hidden1": hidden1, + "hidden2": hidden2, + "y_hat": y_hat, + "loss": loss, + } + + +@pytest.mark.parametrize("param_persistence_threshold", [0, 10]) +@pytest.mark.parametrize("fp16_enabled", [True, False]) +@pytest.mark.parametrize("contiguous_gradients", [True, False]) +@pytest.mark.parametrize("offload_optimizer", [True, False]) +@pytest.mark.parametrize("zero_grad", [True, False]) +@pytest.mark.parametrize("iteration", list(range(1))) +def test_zero3_param_partitioning_base( + param_persistence_threshold: int, + fp16_enabled: bool, + contiguous_gradients: bool, + offload_optimizer: bool, + zero_grad: bool, + iteration: int, +) -> None: + @distributed_test(world_size=[2]) + def _test_zero3_param_partitioning(): + if offload_optimizer and not contiguous_gradients: + return + + m = 3 + n = 5 + weights = [Parameter(torch.zeros((m, n), dtype=torch.float32)) for _ in range(3)] + model = EltwiseMultiplicationTestNetwork(*weights) + + cfg = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 3, + "stage3_max_reuse_distance": 0, + "stage3_param_persistence_threshold": param_persistence_threshold, + "contiguous_gradients": contiguous_gradients, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1. + } + }, + "fp16": { + "enabled": fp16_enabled, + "loss_scale": 1., + } + } + + if offload_optimizer: + cfg["zero_optimization"]["offload_optimizer"] = { + "device": "cpu", + "pin_memory": True, + } + + ds_engine = _ds_initialize_for_param_partitioning_testing(model, cfg) + for i, weight in enumerate(weights): + weight.ds_tensor.data = torch.full_like(weight.ds_tensor.data, + (i + 1) * (1 + dist.get_rank())) + + def create_tensor(vals, dtype: torch.dtype = None) -> Tensor: + return torch.as_tensor(vals, + dtype=dtype + or (torch.float16 if fp16_enabled else torch.float32), + device=ds_engine.device) + + expected_hidden1 = create_tensor([ + [1, + 1, + 1, + 1, + 1], + [1, + 1, + 1, + 2, + 2], + [2, + 2, + 2, + 2, + 2], + ]) + expected_hidden2 = create_tensor([ + [2, + 2, + 2, + 2, + 2], + [2, + 2, + 2, + 8, + 8], + [8, + 8, + 8, + 8, + 8], + ]) + expected_yhat = create_tensor([[6, + 6, + 6, + 6, + 6], + [6, + 6, + 6, + 48, + 48], + [48, + 48, + 48, + 48, + 48]]) + expected_loss = create_tensor([ + [5, + 5, + 5, + 5, + 5], + [5, + 5, + 5, + 47, + 47], + [47, + 47, + 47, + 47, + 47], + ]) + + for train_iter in range(3): + activations = ds_engine( + x=torch.ones((m, + n), + dtype=torch.float16 if fp16_enabled else torch.float32, + device=ds_engine.device), + y=torch.ones((m, + n), + dtype=torch.float16 if fp16_enabled else torch.float32, + device=ds_engine.device), + prefetching=train_iter > 0, + ) + assert torch.allclose(activations["hidden1"], expected_hidden1) + assert torch.allclose(activations["hidden2"], expected_hidden2) + assert torch.allclose(activations["y_hat"], expected_yhat) + assert torch.allclose(activations["loss"], expected_loss) + + ds_engine.backward(activations["loss"].sum()) + + # check the gradients + grad_partitions = ds_engine.optimizer.get_fp32_grad_partitions() + assert set(grad_partitions.keys()) == {0}, f"should have one parameter group but got {len(grad_partitions)}" + assert set(grad_partitions[0].keys()) == {0, 1, 2} + dloss_wrt_layer1 = grad_partitions[0][0] + dloss_wrt_layer2 = grad_partitions[0][1] + dloss_wrt_layer3 = grad_partitions[0][2] + + assert dloss_wrt_layer1.dtype == torch.float + assert dloss_wrt_layer2.dtype == torch.float + assert dloss_wrt_layer3.dtype == torch.float + + # layer1 = [..., 1, 2, ...] + # layer2 = [..., 2, 4, ...] + # layer3 = [..., 3, 6, ...] + # dloss_wrt_layer3 = hidden2 + # dloss_wrt_layer2 = layer3 * hidden1 + # dloss_wrt_layer1 = layer3 * layer2 * x + + grad_multiplier = 1 if zero_grad else (train_iter + 1) + if dist.get_rank() == 0: + assert torch.allclose( + dloss_wrt_layer3.cuda(), + grad_multiplier * create_tensor([2] * 8, + torch.float)) + assert torch.allclose( + dloss_wrt_layer2.cuda(), + grad_multiplier * create_tensor([3 * 1] * 8, + torch.float)) + assert torch.allclose( + dloss_wrt_layer1.cuda(), + grad_multiplier * create_tensor([3 * 2 * 1] * 8, + torch.float)) + elif dist.get_rank() == 1: + # parameters dont split evenly across ranks so rank 1 has a zero-padded + # partition + assert torch.allclose( + dloss_wrt_layer3.cuda(), + grad_multiplier * create_tensor(([8] * 7) + [0], + torch.float)) + assert torch.allclose( + dloss_wrt_layer2.cuda(), + grad_multiplier * create_tensor(([6 * 2] * 7) + [0], + torch.float)) + assert torch.allclose( + dloss_wrt_layer1.cuda(), + grad_multiplier * create_tensor(([6 * 4 * 1] * 7) + [0], + torch.float)) + else: + raise RuntimeError("test has world size of two") + + if zero_grad: + ds_engine.optimizer.zero_grad() + + # TODO. add testing for this - for now we just call it to make sure it + # doesnt throw + ds_engine.optimizer.step() + # taking an optimizer step invalidates all parameters, make sure everything + # has been partitioned afterwards + _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) + assert not math.isclose(ds_engine.optimizer._global_grad_norm, 0.0) + + _test_zero3_param_partitioning() + + +@pytest.mark.parametrize("world_sz", [1, 2, 4]) +@pytest.mark.parametrize("param_sz", [8100]) +@pytest.mark.parametrize("init_context_manager", [True, False]) +def test_zero3_param_partitioning_large_param(world_sz: int, + param_sz: int, + init_context_manager: bool) -> None: + class LargeParamModel(Module): + def __init__(self): + super().__init__() + self.param = Parameter(torch.zeros((param_sz, ), dtype=torch.float32)) + + # only do weight initialization on root rank to + # make sure we are broadcasting correctly from rank 0 + if dist.get_rank() == 0: + partition_sz = math.ceil(self.param.numel() / dist.get_world_size()) + offset = 0 + for rank in range(dist.get_world_size()): + with torch.no_grad(): + self.param[offset:offset + partition_sz].fill_(rank) + offset += partition_sz + + def forward(self, x: Tensor) -> Tensor: + return x * self.param + + @distributed_test(world_size=[world_sz]) + def _distributed_test(): + ds_config = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 3, + "stage3_max_reuse_distance": 0, + "contiguous_gradients": True, + "overlap_comm": True, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1. + } + }, + "fp16": { + "enabled": True, + "loss_scale": 1., + } + } + with deepspeed.zero.Init(mem_efficient_linear=False, + enabled=init_context_manager): + model = LargeParamModel() + ds_engine = _ds_initialize_for_param_partitioning_testing(model, ds_config) + + for train_iter in range(3): # test multiple iterations to cover prefetching + activation: Tensor = ds_engine( + torch.ones(param_sz, + dtype=torch.float16, + device=ds_engine.device)) + + partition_sz = math.ceil(param_sz / world_sz) + for rank_idx, start_idx in enumerate(range(0, param_sz, partition_sz)): + activation_from_partition = activation[start_idx:start_idx + + partition_sz] + assert torch.allclose( + activation_from_partition, + torch.full_like(activation_from_partition, + rank_idx)) + + ds_engine.backward(activation.sum()) + ds_engine.allreduce_gradients() + + avgd_gradients = ds_engine.optimizer.averaged_gradients + assert set(avgd_gradients.keys()) == {0}, "should only have one parameter group" + weight_gradient, = avgd_gradients[0] + expected_weight_gradient = (train_iter + 1) * torch.full_like( + weight_gradient, + 1) + + assert torch.allclose(weight_gradient, expected_weight_gradient) + + _distributed_test() + + +@pytest.mark.parametrize("world_sz", [1, 2, 4]) +@pytest.mark.parametrize("param_sz", [100, 1_000, 10_000]) +@pytest.mark.parametrize("n_layers", [100, 1_000]) +@pytest.mark.parametrize("init_context_manager", [True, False]) +def test_zero3_param_partitioning_many_params(world_sz: int, + param_sz: int, + n_layers: int, + init_context_manager: bool) -> None: + class ManyParamModel(Module): + def __init__(self) -> None: + super().__init__() + + self.modulelist = ModuleList( + EltwiseMultiplicationModule( + weight=Parameter(torch.empty((param_sz, + ), + dtype=torch.float32))) + for _ in range(n_layers)) + + for layer_num, module in enumerate(self.modulelist): + if dist.get_rank() == 0: + param: Parameter = module.weight + partition_sz = math.ceil(param.numel() / dist.get_world_size()) + offset = 0 + for rank in range(dist.get_world_size()): + with torch.no_grad(): + param[offset:offset + partition_sz].fill_(2 * layer_num * + rank) + offset += partition_sz + + def forward(self, x: Tensor) -> Tensor: + activations = [] + + for module in self.modulelist: + x = module(x) + activations.append(x) + + return activations + + @distributed_test(world_size=[world_sz]) + def _distributed_test(): + ds_cfg = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 3, + "stage3_max_reuse_distance": 0, + "contiguous_gradients": True, + "overlap_comm": True, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1. + } + }, + "fp16": { + "enabled": True, + "loss_scale": 1., + } + } + + with deepspeed.zero.Init(config=ds_cfg, + mem_efficient_linear=False, + enabled=init_context_manager): + model = ManyParamModel() + + ds_engine = _ds_initialize_for_param_partitioning_testing(model, ds_cfg) + + for _ in range(3): # test multiple iterations to cover prefetching + activations: List[Tensor] = ds_engine( + torch.ones((param_sz, + ), + dtype=torch.float16, + device=ds_engine.device)) + assert len(activations) == n_layers + + partition_sz = math.ceil(param_sz / world_sz) + expected_activations = torch.empty(param_sz, + dtype=torch.float16, + device=ds_engine.device) + for start_idx in range(0, param_sz, partition_sz): + expected_activations[start_idx:start_idx + + partition_sz] = dist.get_rank() + + for layer_num, activation in enumerate(activations): + expected_activations *= 2 * layer_num + assert torch.allclose(activation, expected_activations) + + # TODO. finish writing this test + ds_engine.backward(activations[-1].sum()) + + avgd_gradients = ds_engine.optimizer.averaged_gradients + assert set(avgd_gradients.keys()) == {0}, "should only have one parameter group" + weight_gradients: List[Tensor] = avgd_gradients[0] + + for layer_num, activation in enumerate(weight_gradients): + pass + + _distributed_test() + + +@pytest.mark.parametrize("world_sz", [1, 2, 4]) +def test_zero3_init_for_parent_weight_initialization(world_sz): + class ModelWhereParentInitializesChildWeights(Module): + def __init__(self) -> None: + super().__init__() + + self.linear = Linear(12, 1) + + self.apply(self.__init_weights) + + def __init_weights(self, module): + if isinstance(module, Linear): + with torch.no_grad(): + module.weight.fill_(1 + dist.get_rank()) + + @distributed_test(world_size=[world_sz]) + def _distributed_test(): + ds_cfg = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 3, + "stage3_max_reuse_distance": 0, + "contiguous_gradients": True, + "overlap_comm": True, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1. + } + }, + "fp16": { + "enabled": True, + "loss_scale": 1., + } + } + + with deepspeed.zero.Init(config=ds_cfg, + mem_efficient_linear=False, + enabled=True): + model = ModelWhereParentInitializesChildWeights() + + assert model.linear.weight.ds_tensor.numel() == math.ceil(12 / world_sz) + assert torch.allclose(model.linear.weight.ds_tensor, + torch.full_like(model.linear.weight.ds_tensor, + 1)) + + _distributed_test() + + +@pytest.mark.skip( + reason="depends on upgraded pytorch and nccl that isnt always available") +@pytest.mark.parametrize("param_persistence_threshold", [0, 10]) +@pytest.mark.parametrize("contiguous_gradients", [True, False]) +@pytest.mark.parametrize("offload_optimizer", [True, False]) +@pytest.mark.parametrize("zero_grad", [True]) +@pytest.mark.parametrize("iteration", list(range(1))) +def test_zero3_param_partitioning_base_bf16( + param_persistence_threshold: int, + contiguous_gradients: bool, + offload_optimizer: bool, + zero_grad: bool, + iteration: int, +) -> None: + @distributed_test(world_size=[2]) + def _test_zero3_param_partitioning(): + if offload_optimizer and not contiguous_gradients: + return + + m = 3 + n = 5 + weights = [Parameter(torch.zeros((m, n), dtype=torch.float32)) for _ in range(3)] + model = EltwiseMultiplicationTestNetwork(*weights) + + cfg = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 3, + "stage3_max_reuse_distance": 0, + "stage3_param_persistence_threshold": param_persistence_threshold, + "contiguous_gradients": contiguous_gradients, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1. + } + }, + "bf16": { + "enabled": True, + "loss_scale": 1., + } + } + + if offload_optimizer: + cfg["zero_optimization"]["offload_optimizer"] = { + "device": "cpu", + "pin_memory": True, + } + + ds_engine = _ds_initialize_for_param_partitioning_testing(model, cfg) + for i, weight in enumerate(weights): + weight.ds_tensor.data = torch.full_like(weight.ds_tensor.data, + (i + 1) * (1 + dist.get_rank())) + + def create_tensor(vals): + return torch.as_tensor(vals, dtype=torch.bfloat16, device=ds_engine.device) + + expected_hidden1 = create_tensor([ + [1, + 1, + 1, + 1, + 1], + [1, + 1, + 1, + 2, + 2], + [2, + 2, + 2, + 2, + 2], + ]) + expected_hidden2 = create_tensor([ + [2, + 2, + 2, + 2, + 2], + [2, + 2, + 2, + 8, + 8], + [8, + 8, + 8, + 8, + 8], + ]) + expected_yhat = create_tensor([[6, + 6, + 6, + 6, + 6], + [6, + 6, + 6, + 48, + 48], + [48, + 48, + 48, + 48, + 48]]) + expected_loss = create_tensor([ + [5, + 5, + 5, + 5, + 5], + [5, + 5, + 5, + 47, + 47], + [47, + 47, + 47, + 47, + 47], + ]) + + for train_iter in range(3): + _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) + activations = ds_engine( + x=torch.ones((m, + n), + dtype=torch.bfloat16, + device=ds_engine.device), + y=torch.ones((m, + n), + dtype=torch.bfloat16, + device=ds_engine.device), + prefetching=train_iter > 0, + ) + assert torch.allclose(activations["hidden1"], expected_hidden1) + assert torch.allclose(activations["hidden2"], expected_hidden2) + assert torch.allclose(activations["y_hat"], expected_yhat) + assert torch.allclose(activations["loss"], expected_loss) + + ds_engine.backward(activations["loss"].sum()) + _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) + + # check the gradients + grad_partitions = ds_engine.optimizer.get_fp32_grad_partitions() + assert set(grad_partitions.keys()) == {0}, f"should have one parameter group but got {len(grad_partitions)}" + assert set(grad_partitions[0].keys()) == {0, 1, 2} + dloss_wrt_layer1 = grad_partitions[0][0] + dloss_wrt_layer2 = grad_partitions[0][1] + dloss_wrt_layer3 = grad_partitions[0][2] + + # layer1 = [..., 1, 2, ...] + # layer2 = [..., 2, 4, ...] + # layer3 = [..., 3, 6, ...] + # dloss_wrt_layer3 = hidden2 + # dloss_wrt_layer2 = layer3 * hidden1 + # dloss_wrt_layer1 = layer3 * layer2 * x + + expected_grad_dtype = torch.float32 if offload_optimizer else torch.bfloat16 + + grad_multiplier = 1 if zero_grad else (train_iter + 1) + if dist.get_rank() == 0: + assert torch.allclose( + dloss_wrt_layer3.cuda(), + grad_multiplier * create_tensor([2] * 8).to(expected_grad_dtype)) + assert torch.allclose( + dloss_wrt_layer2.cuda(), + grad_multiplier * create_tensor([3 * 1] * 8).to(expected_grad_dtype)) + assert torch.allclose( + dloss_wrt_layer1.cuda(), + grad_multiplier * + create_tensor([3 * 2 * 1] * 8).to(expected_grad_dtype)) + elif dist.get_rank() == 1: + # parameters dont split evenly across ranks so rank 1 has a zero-padded + # partition + assert torch.allclose( + dloss_wrt_layer3.cuda(), + grad_multiplier * + create_tensor(([8] * 7) + [0]).to(expected_grad_dtype)) + assert torch.allclose( + dloss_wrt_layer2.cuda(), + grad_multiplier * + create_tensor(([6 * 2] * 7) + [0]).to(expected_grad_dtype)) + assert torch.allclose( + dloss_wrt_layer1.cuda(), + grad_multiplier * + create_tensor(([6 * 4 * 1] * 7) + [0]).to(expected_grad_dtype)) + else: + raise RuntimeError("test has world size of two") + + if zero_grad: + ds_engine.optimizer.zero_grad() + + # TODO. add testing for this - for now we just call it to make sure it + # doesnt throw + ds_engine.optimizer.step() + _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) + + _test_zero3_param_partitioning() diff --git a/tests/unit/test_zero_context.py b/tests/unit/test_zero_context.py index 1f5202e2..66521e07 100644 --- a/tests/unit/test_zero_context.py +++ b/tests/unit/test_zero_context.py @@ -1,5 +1,4 @@ import os -import sys from types import SimpleNamespace import torch @@ -8,7 +7,7 @@ import pytest import deepspeed from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus, partitioned_param_data_shape -from common import distributed_test, get_master_port +from .common import distributed_test, get_master_port def setup_serial_env(): -- GitLab