# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import os
from deepspeed import comm as dist
from packaging import version as pkg_version
from collections import OrderedDict
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from deepspeed.runtime import ZeROOptimizer
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
from deepspeed.runtime.utils import (bwc_tensor_model_parallel_rank, get_global_norm, empty_cache, see_memory_usage,
                                     inf, is_model_parallel_parameter, align_dense_tensors, all_gather_dp_groups)

from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.utils import logger
from deepspeed.moe.utils import is_moe_param
from deepspeed.git_version_info import version

from deepspeed.runtime.constants import PIPE_REPLICATED
from deepspeed.accelerator import get_accelerator

from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT,
                                            SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE, CLIP_GRAD,
                                            ZERO_STAGE, PARAM_SLICE_MAPPINGS)
from deepspeed.utils import link_hp_params
from deepspeed.checkpoint import enable_universal_checkpoint

# Toggle this to true to enable correctness test
# with gradient partitioning and without
pg_correctness_test = False


def input(msg):
    return


def split_half_float_double(tensors):
    device_type = get_accelerator().device_name()
    dtypes = [
        "torch.{}.HalfTensor".format(device_type), "torch.{}.FloatTensor".format(device_type),
        "torch.{}.DoubleTensor".format(device_type), "torch.{}.BFloat16Tensor".format(device_type)
    ]
    buckets = []
    for i, dtype in enumerate(dtypes):
        bucket = [t for t in tensors if t.type() == dtype]
        if bucket:
            buckets.append(bucket)
    return buckets


def isclose(a, b, rtol=1e-09, atol=0.0):
    return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol)


def lcm(x, y):
    from math import gcd  # fractions.gcd was removed in Python 3.9; math.gcd is the supported import
    return x * y // gcd(x, y)


def get_alignment_padding(tensor_list, alignment):
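    # Number of trailing elements needed to pad the flattened group up to a
    # multiple of `alignment`, e.g. 13 elements with alignment 8 need 3 more.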
    num_elements = sum([tensor.numel() for tensor in tensor_list])
    remainder = num_elements % alignment
    return (alignment - remainder) if remainder else remainder


def move_to_cpu(tensor_list):
    for tensor in tensor_list:
        tensor.data = tensor.data.cpu()


def print_rank_msg(msg):
    print(f"rank {dist.get_rank()} - {msg}")


def _get_padded_tensor(src_tensor, size):
    if src_tensor.numel() >= size:
        return src_tensor
    padded_tensor = torch.zeros(size, dtype=src_tensor.dtype, device=src_tensor.device)
    slice_tensor = torch.narrow(padded_tensor, 0, 0, src_tensor.numel())
    slice_tensor.data.copy_(src_tensor.data)
    return padded_tensor


class DeepSpeedZeroOptimizer(ZeROOptimizer):
    """
    DeepSpeedZeroOptimizer is designed to reduce the memory footprint
    required for training large deep learning models.

    For more details please see ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
    https://arxiv.org/abs/1910.02054

    For usage examples, refer to TODO: DeepSpeed Tutorial

    """

    def __init__(self,
                 init_optimizer,
                 param_names,
                 timers,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 dynamic_loss_args=None,
                 verbose=True,
                 contiguous_gradients=True,
                 reduce_bucket_size=500000000,
                 allgather_bucket_size=5000000000,
                 dp_process_group=None,
                 expert_parallel_group=None,
                 expert_data_parallel_group=None,
                 reduce_scatter=True,
                 overlap_comm=False,
                 cpu_offload=False,
                 mpu=None,
                 clip_grad=0.0,
                 communication_data_type=torch.float16,
                 postscale_gradients=True,
                 gradient_predivide_factor=1.0,
                 gradient_accumulation_steps=1,
                 ignore_unused_parameters=True,
                 partition_grads=True,
                 round_robin_gradients=False,
                 has_moe_layers=False,
                 fp16_master_weights_and_gradients=False,
                 elastic_checkpoint=False):

        if dist.get_rank() == 0:
            logger.info(f"Reduce bucket size {reduce_bucket_size}")
            logger.info(f"Allgather bucket size {allgather_bucket_size}")
            logger.info(f"CPU Offload: {cpu_offload}")
            logger.info(f'Round robin gradient partitioning: {round_robin_gradients}')
        # The fused optimizer does all the work. We need this layer for two reasons:
        # 1. maintain same user API from apex.fp16_utils
        # 2. keep common stuff here in case we need to add new fused optimizer later

        self.elastic_checkpoint = elastic_checkpoint
        self.param_names = param_names
        self.mpu = mpu
        # differences from apex.fp16_utils:
        # - assume all model params in fp16
        # - assume all params require grad
        # - flat by groups, not keeping state. TODO: remove state explicitly?
        # - master grad and unflat master weight never exist. TODO: a way to save out unflat master?
        if not get_accelerator().is_available():
            raise SystemError("Cannot use fp16 without accelerator.")
        self.optimizer = init_optimizer

        # Use torch (un)flatten ops
        self.flatten = _flatten_dense_tensors
        self.unflatten = _unflatten_dense_tensors

        # ZeRO stage 1 (False) or 2 (True)
        self.partition_gradients = partition_grads
        self.zero_stage_string = "ZeRO-2" if partition_grads else "ZeRO-1"

        self.timers = timers

        self.reduce_scatter = reduce_scatter

        self.overlap_comm = overlap_comm

        self.cpu_offload = cpu_offload

        self.deepspeed_adam_offload = cpu_offload

        self.device = get_accelerator().current_device_name() if not self.cpu_offload else 'cpu'

        self.dp_process_group = dp_process_group

        #expert parallel group
        self.ep_process_group = expert_parallel_group

        #data parallel group for experts
        self.expert_dp_process_group = expert_data_parallel_group

        #data parallel size for non-experts
        dp_size = dist.get_world_size(group=self.dp_process_group)

        #For MoE models this may be different for different param groups
        #It will be modified during MoE setup later in the init
        self.real_dp_process_group = [dp_process_group for i in range(len(self.optimizer.param_groups))]
        self.partition_count = [dp_size for i in range(len(self.optimizer.param_groups))]

        self.is_gradient_accumulation_boundary = True

        # CPU-Offload requires contiguous gradients
        self.contiguous_gradients = contiguous_gradients or cpu_offload

        self.has_moe_layers = has_moe_layers
        if self.has_moe_layers:
            self._configure_moe_settings()
        self._global_grad_norm = 0.

        if mpu is None:
            self.model_parallel_group = None
            self.model_parallel_world_size = 1
            self.model_parallel_rank = 0
        else:
            self.model_parallel_group = mpu.get_model_parallel_group()
            self.model_parallel_world_size = mpu.get_model_parallel_world_size()
            self.model_parallel_rank = bwc_tensor_model_parallel_rank(mpu)

        self.overflow = False
        self.clip_grad = clip_grad
        self.communication_data_type = communication_data_type
        self.gradient_predivide_factor = gradient_predivide_factor
        self.postscale_gradients = postscale_gradients
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.micro_step_id = 0
        self.ignore_unused_parameters = ignore_unused_parameters
        self.round_robin_gradients = round_robin_gradients

        self.extra_large_param_to_reduce = None
        self.fp16_master_weights_and_gradients = fp16_master_weights_and_gradients

        if self.fp16_master_weights_and_gradients:
            assert self.cpu_offload and type(self.optimizer) in [DeepSpeedCPUAdam], \
            f"fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32."\
            f"Currently only supported using ZeRO-Offload with DeepSpeedCPUAdam. But current setting is ZeRO-Offload:{self.cpu_offload} and optimizer type {type(self.optimizer)}." \
            f"Either disable fp16_master_weights_and_gradients or enable {self.zero_stage_string} Offload with DeepSpeedCPUAdam."

        if self.reduce_scatter:
            valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32)
            assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'"
            assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled"
            assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled"

        # param flattened by groups
        self.bit16_groups = []
        self.bit16_groups_flat = []

        # param partitioned by data parallel degree
        # this will contain a list of equal sized tensors
        # each of which will be updated by a different process
        self.parallel_partitioned_bit16_groups = []

        # a single 32-bit partition of the parallel partitioned parameters
        # that this process will update
        self.single_partition_of_fp32_groups = []

        # param partition info

        # These are the parameters in each group that will not be updated by this process directly
        self.params_not_in_partition = []

        # These are the parameters that will be updated by this process directly
        self.params_in_partition = []

        # Offset from the first parameter in the self.params_in_partition
        # the parameter boundaries may not align with partition boundaries
        # so we need to keep track of the offset
        self.first_offset = []

        # number of elements per partition in each group
        self.partition_size = []

        # align nccl all-gather send buffers to 4-byte boundary
        self.nccl_start_alignment_factor = 2  # 4-byte alignment/sizeof(fp16) = 2

        assert (
            allgather_bucket_size % self.nccl_start_alignment_factor == 0
        ), f"allgather_bucket_size must be a multiple of nccl_start_alignment_factor, {self.nccl_start_alignment_factor} "

        self.all_reduce_print = False
        self.dtype = self.optimizer.param_groups[0]['params'][0].dtype

        self.round_robin_bit16_groups = []
        self.round_robin_bit16_indices = []

        # Use different parallel to do all_to_all_reduce related things
        # padding on each partition for alignment purposes
        self.groups_padding = []
        # loop to deal with groups
        for i, param_group in enumerate(self.optimizer.param_groups):
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])

            # push this group to list before modify
            # TODO: Explore simplification that avoids the extra book-keeping by pushing the reordered group
            trainable_parameters = [param for param in param_group['params'] if param.requires_grad]
            self.bit16_groups.append(trainable_parameters)

            # not sure why apex was cloning the weights before flattening
            # removing cloning here

            see_memory_usage(f"Before moving param group {i} to CPU")
            # move all the parameters to cpu to free up GPU space for creating flat buffer
            move_to_cpu(self.bit16_groups[i])
            empty_cache()
            see_memory_usage(f"After moving param group {i} to CPU", force=False)

            # Reorder group parameters for load balancing of gradient partitioning during backward among ranks.
            # This ensures that gradients are reduced in a fashion such that ownership round robins among the ranks.
            # For example, rather than 3 gradients (g_n+2, g_n+1, g_n) that are reduced consecutively belonging
            # to the same rank, instead they will belong to 3 ranks (r_m+2, r_m+1, r_m).
            if self.round_robin_gradients:
                round_robin_tensors, round_robin_indices = self._round_robin_reorder(
                    self.bit16_groups[i], dist.get_world_size(group=self.real_dp_process_group[i]))
            else:
                round_robin_tensors = self.bit16_groups[i]
                round_robin_indices = list(range(len(self.bit16_groups[i])))
            self.round_robin_bit16_groups.append(round_robin_tensors)
            self.round_robin_bit16_indices.append(round_robin_indices)

            # create flat buffer in CPU and move to GPU
            self.bit16_groups_flat.append(
                self.flatten_dense_tensors_aligned(
                    self.round_robin_bit16_groups[i],
                    self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[i])).to(
                        get_accelerator().current_device_name()))
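            # Aligning to nccl_start_alignment_factor * dp_world_size ensures the flat
            # buffer splits into dp_world_size equal slices that each keep the 4-byte
            # NCCL alignment.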
            see_memory_usage(f"After flattening and moving param group {i} to GPU", force=False)

            # Record padding required for alignment
            if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1:
                padding = self.bit16_groups_flat[i].numel() - sum(
                    [t.numel() for t in self.round_robin_bit16_groups[i]])
            else:
                padding = 0
            self.groups_padding.append(padding)

            if dist.get_rank(group=self.real_dp_process_group[i]) == 0:
                see_memory_usage(f"After Flattening and after emptying param group {i} cache", force=False)

            # set model bit16 weight to slices of flattened buffer
            self._update_model_bit16_weights(i)

            # divide the flat weights into near-equal partitions equal to the data parallel degree
            # each process will compute on a different part of the partition
            data_parallel_partitions = self.get_data_parallel_partitions(self.bit16_groups_flat[i], i)
            self.parallel_partitioned_bit16_groups.append(data_parallel_partitions)

            # verify that data partition start locations are 4-byte aligned
            for partitioned_data in data_parallel_partitions:
                assert (partitioned_data.data_ptr() % (2 * self.nccl_start_alignment_factor) == 0)

            # A partition of the fp32 master weights that will be updated by this process.
            # Note that the params in single_partition_of_fp32_groups is cloned and detached
            # from the origin params of the model.
            if not fp16_master_weights_and_gradients:
                self.single_partition_of_fp32_groups.append(self.parallel_partitioned_bit16_groups[i][partition_id].to(
                    self.device).clone().float().detach())
            else:
                self.single_partition_of_fp32_groups.append(self.parallel_partitioned_bit16_groups[i][partition_id].to(
                    self.device).clone().half().detach())

            # Set local optimizer to have flat params of its own partition.
            # After this, the local optimizer will only contain its own partition of params.
            # In that case, the local optimizer only saves the states (momentum, variance, etc.) related to its partition's params (ZeRO stage 1).
            self.single_partition_of_fp32_groups[
                i].requires_grad = True  # keep this in case internal optimizer uses it
            param_group['params'] = [self.single_partition_of_fp32_groups[i]]

            partition_size = len(self.bit16_groups_flat[i]) / dist.get_world_size(group=self.real_dp_process_group[i])
            params_in_partition, params_not_in_partition, first_offset = self.get_partition_info(
                self.round_robin_bit16_groups[i], partition_size, partition_id)

            self.partition_size.append(partition_size)
            self.params_in_partition.append(params_in_partition)
            self.params_not_in_partition.append(params_not_in_partition)
            self.first_offset.append(first_offset)

        for rank in range(dist.get_world_size()):
            if dist.get_rank() == rank:
                print(
                    f"Rank: {rank} partition count {self.partition_count} and sizes{[(p.numel(), self.is_moe_param_group[i] if hasattr(self, 'is_moe_param_group') else False) for i,p in enumerate(self.single_partition_of_fp32_groups)]} "
                )
                dist.barrier()

        self.reduce_bucket_size = int(reduce_bucket_size)
        self.allgather_bucket_size = int(allgather_bucket_size)

        self.reduction_event = get_accelerator().Event(enable_timing=False, blocking=False)
        self.reduction_stream = get_accelerator().Stream()
        self.cpu_computation_stream = get_accelerator().Stream()
        self.copy_grad_stream = get_accelerator().Stream()
        self.callback_queued = False

        self.param_dict = {}

        # map between param_id and bool to specify if a param is in this partition
        self.is_param_in_current_partition = {}

        self.grads_in_ipg_bucket = []
        self.params_in_ipg_bucket = []
        self.elements_in_ipg_bucket = 0
        self.params_already_reduced = []
        self._release_ipg_buffers()
        self.previous_reduced_grads = None
        self.ipg_bucket_has_moe_params = False

        # simplified param id
        self.param_id = {}

        # assign a unique integer id to each individual parameter
        largest_param_numel = 0
        count = 0
        for i, params_group in enumerate(self.bit16_groups):
            for param in params_group:
                unique_id = id(param)
                self.param_id[unique_id] = count
                self.param_dict[count] = param
                self.params_already_reduced.append(False)
                if param.numel() > largest_param_numel:
                    largest_param_numel = param.numel()
                count = count + 1

        for param_group in self.params_in_partition:
            for param in param_group:
                self.is_param_in_current_partition[self.get_param_id(param)] = True

        for param_group in self.params_not_in_partition:
            for param in param_group:
                self.is_param_in_current_partition[self.get_param_id(param)] = False

        if self.cpu_offload:
            self.accumulated_grads_in_cpu = {}
            self.norm_for_param_grads = {}
            self.local_overflow = False
            self.grad_position = {}
            self.temp_grad_buffer_for_cpu_offload = get_accelerator().pin_memory(
                torch.zeros(largest_param_numel, device=self.device, dtype=self.dtype))
            self.temp_grad_buffer_for_gpu_offload = torch.zeros(largest_param_numel,
                                                                device=get_accelerator().current_device_name(),
                                                                dtype=self.dtype)
            for i, params_group in enumerate(self.bit16_groups):
                self.get_grad_position(i, self.params_in_partition[i], self.first_offset[i], self.partition_size[i])

        # mapping from parameter to partition that it belongs to
        self.param_to_partition_ids = {}

        # stores if a partition has been reduced in this step
        self.is_partition_reduced = {}

        # number of grads in partition that still need to be computed
        self.remaining_grads_in_partition = {}

        # total number of grads in partition
        self.total_grads_in_partition = {}

        # stores if a grad in a partition has been computed or not
        self.is_grad_computed = {}

        # stores the offset at which a parameter gradient needs to be inserted in a partition
        self.grad_partition_insertion_offset = {}

        # the offset in the gradient at which it must be inserted at the beginning of the partition
        self.grad_start_offset = {}

        # will store the averaged gradients required by this partition
        self.averaged_gradients = {}

        # For cpu_offload, will store the averaged gradients required by this partition
        self.offload_gradient_dict = {}

        # store index of first parameter in each partition
        self.first_param_index_in_partition = {}

        # initializes all data structures for implementing gradient partitioning
        self.initialize_gradient_partitioning_data_structures()

        # resets the data structure value for the next backward propagation
        self.reset_partition_gradient_structures()

        # creates backward hooks for gradient partitioning
        if self.partition_gradients or self.overlap_comm:
            self.create_reduce_and_remove_grad_hooks()

        self.custom_loss_scaler = False
        self.external_loss_scale = None

        # we may have a way of fusing dynamic scale. Do not support for now
        self.loss_scaler = CreateLossScaler(dtype=self.dtype,
                                            static_loss_scale=static_loss_scale,
                                            dynamic_scaling=dynamic_loss_scale,
                                            dynamic_loss_args=dynamic_loss_args)
        self.dynamic_loss_scale = self.loss_scaler.dynamic

        see_memory_usage("Before initializing optimizer states", force=True)
        self.initialize_optimizer_states()
        see_memory_usage("After initializing optimizer states", force=True)

        if dist.get_rank() == 0:
            logger.info(f"optimizer state initialized")

        if dist.get_rank(group=self.dp_process_group) == 0:
            see_memory_usage(f"After initializing ZeRO optimizer", force=True)

        self._link_all_hp_params()
        self._enable_universal_checkpoint()
        self._param_slice_mappings = self._create_param_mapping()

    def _enable_universal_checkpoint(self):
        for lp_param_group in self.bit16_groups:
            enable_universal_checkpoint(param_list=lp_param_group)

    def _create_param_mapping(self):
        param_mapping = []
        for i, _ in enumerate(self.optimizer.param_groups):
            param_mapping_per_group = OrderedDict()
            for lp in self.bit16_groups[i]:
                if lp._hp_mapping is not None:
                    lp_name = self.param_names[lp]
                    param_mapping_per_group[lp_name] = lp._hp_mapping.get_hp_fragment_address()
            param_mapping.append(param_mapping_per_group)

        return param_mapping

    def _link_all_hp_params(self):
        dp_world_size = dist.get_world_size(group=self.dp_process_group)
        if self.cpu_offload:
            self._get_offload_gradient_dict()

        for i, _ in enumerate(self.optimizer.param_groups):
            # Link bit16 and fp32 params in partition
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
            partition_size = self.bit16_groups_flat[i].numel() // dp_world_size
            flat_hp_partition = self.single_partition_of_fp32_groups[i]
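            # Link each bit16 param to the fragment of this flat fp32 partition that
            # holds its master copy, so checkpointing can map lp params to hp state.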
            link_hp_params(lp_param_list=self.bit16_groups[i],
                           flat_hp_partition=flat_hp_partition,
                           gradient_dict=self.averaged_gradients,
                           offload_gradient_dict=self.offload_gradient_dict,
                           use_offload=self.cpu_offload,
                           param_group_index=i,
                           partition_start=partition_id * partition_size,
                           partition_size=partition_size,
                           partition_optimizer_state=self.optimizer.state[flat_hp_partition],
                           dp_group=self.real_dp_process_group[i])

    def is_moe_group(self, group):
        return 'moe' in group and group['moe']

    def _configure_moe_settings(self):
        # if we're using ZeRO stage 2, ensure contiguous gradients are used
        if self.partition_gradients:
            assert self.contiguous_gradients, "Contiguous Gradients in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"
        # NOTE: To run ZeRO stage 1 with MoE, we need to set self.contiguous_gradients to True or ignore the assertion
        if not self.partition_gradients and not self.contiguous_gradients:
            logger.warn(
                "ZeRO Stage 1 has not been thoroughly tested with MoE. This configuration is still experimental.")
        assert self.reduce_scatter, "Reduce Scatter in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"

        assert any(
            [self.is_moe_group(group) for group in self.optimizer.param_groups]
        ), "The model has moe layers, but None of the param groups are marked as MoE. Create a param group with 'moe' key set to True before creating optimizer"
        self.is_moe_param_group = []
        for i, group in enumerate(self.optimizer.param_groups):
            if self.is_moe_group(group):
                assert all([is_moe_param(param)
                            for param in group['params']]), "All params in MoE group must be MoE params"
                self.real_dp_process_group[i] = self.expert_dp_process_group[group['name']]
                self.partition_count[i] = dist.get_world_size(group=self.expert_dp_process_group[group['name']])
                self.is_moe_param_group.append(True)
            else:
                self.is_moe_param_group.append(False)

        assert self.expert_dp_process_group is not None, "Expert data parallel group should be configured with MoE"
        assert self.ep_process_group is not None, "Expert parallel group should be configured with MoE"
    def _update_model_bit16_weights(self, group_index):
        updated_params = self.unflatten(self.bit16_groups_flat[group_index],
                                        self.round_robin_bit16_groups[group_index])
        for p, q in zip(self.round_robin_bit16_groups[group_index], updated_params):
            p.data = q.data

        # set model fp16 weight to slices of reordered flattened buffer
        for param_index, param in enumerate(self.bit16_groups[group_index]):
            new_index = self.round_robin_bit16_indices[group_index][param_index]
            param.data = self.round_robin_bit16_groups[group_index][new_index].data

    def _round_robin_reorder(self, tensor_list, num_partitions):

        # disable round robin if you need to debug something
        # return tensor_list, list(range(len(tensor_list)))

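        # Deal tensors out to partitions in round-robin order, e.g. with 2 partitions
        # [t0, t1, t2, t3] regroups as [t0, t2, t1, t3]; reordered_indices maps each
        # original position to its new position in the flattened order.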
        partition_tensors = {}

        for i, tensor in enumerate(tensor_list):
            j = i % num_partitions
            if j not in partition_tensors:
                partition_tensors[j] = []
            partition_tensors[j].append((i, tensor))

        reordered_tensors = []
        reordered_indices = {}

        for partition_index in partition_tensors.keys():
            for i, (original_index, tensor) in enumerate(partition_tensors[partition_index]):
                reordered_indices[original_index] = len(reordered_tensors)
                reordered_tensors.append(tensor)

        return reordered_tensors, reordered_indices

    def _release_ipg_buffers(self):
        if self.contiguous_gradients:
            self.ipg_buffer = None
            self.grads_in_partition = None
            self.grads_in_partition_offset = 0

    def initialize_optimizer_states(self):

        for i, group in enumerate(self.bit16_groups):
            single_grad_partition = torch.zeros(int(self.partition_size[i]),
                                                dtype=self.single_partition_of_fp32_groups[i].dtype,
                                                device=self.device)
            self.single_partition_of_fp32_groups[i].grad = get_accelerator().pin_memory(
                single_grad_partition) if self.cpu_offload else single_grad_partition

        # Initialize the optimizer states with the flattened fp32 partition.
        # State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers
        # which do lazy initialization of the state at the first call to step.
        if isinstance(self.optimizer, torch.optim.Adagrad):
            self.optimizer = torch.optim.Adagrad(self.single_partition_of_fp32_groups, **self.optimizer.defaults)
        else:
            self.optimizer.step()

        if not self.cpu_offload:
            for group in self.single_partition_of_fp32_groups:
                group.grad = None  #class init

        return

    #########################################################################
    #################### ZeRO Stage 1 - reduce gradients ####################
    #########################################################################
    def reduce_gradients(self, pipeline_parallel=False):
        world_size = dist.get_world_size(self.dp_process_group)
        my_rank = dist.get_rank(self.dp_process_group)

        # with PP we must create ipg buffer, since backward is handled outside zero
        if pipeline_parallel and self.contiguous_gradients:
            self.ipg_buffer = []
            buf_0 = torch.empty(int(self.reduce_bucket_size),
                                dtype=self.dtype,
                                device=get_accelerator().current_device_name())
            self.ipg_buffer.append(buf_0)
            self.ipg_index = 0

        if not self.overlap_comm:
            for i, group in enumerate(self.bit16_groups):
                for param in group:
                    if param.grad is not None:
                        self.reduce_ready_partitions_and_remove_grads(param, i)
        # reduce any pending grads in either hook/non-hook case
        self.overlapping_partition_gradients_reduce_epilogue()

    #########################################################################
    #########################ZeRO Partition Gradients########################
    #########################################################################

    def get_first_param_index(self, group_id, param_group, partition_id):
        for index, param in enumerate(param_group):
            param_id = self.get_param_id(param)
            if partition_id in self.param_to_partition_ids[group_id][param_id]:
                return index
        return None

    def initialize_gradient_partitioning_data_structures(self):

        for i, param_group in enumerate(self.round_robin_bit16_groups):
            total_partitions = dist.get_world_size(group=self.real_dp_process_group[i])

            self.param_to_partition_ids[i] = {}
            self.is_partition_reduced[i] = {}
            self.total_grads_in_partition[i] = {}
            self.remaining_grads_in_partition[i] = {}
            self.is_grad_computed[i] = {}
            self.grad_partition_insertion_offset[i] = {}
            self.grad_start_offset[i] = {}
            self.first_param_index_in_partition[i] = {}

            for partition_id in range(total_partitions):
                self.is_grad_computed[i][partition_id] = {}
                self.grad_partition_insertion_offset[i][partition_id] = {}
                self.grad_start_offset[i][partition_id] = {}
                self.total_grads_in_partition[i][partition_id] = 0
                self.initialize_gradient_partition(i, param_group, partition_id)
                self.is_partition_reduced[i][partition_id] = False
                self.first_param_index_in_partition[i][partition_id] = self.get_first_param_index(
                    i, param_group, partition_id)

    def independent_gradient_partition_epilogue(self):
        self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0)
        self.reduce_ipg_grads()
        self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0)

        # if dist.get_rank() == 0:
        #    logger.info("Params already reduced %s", self.params_already_reduced)
        for i in range(len(self.params_already_reduced)):
            self.params_already_reduced[i] = False

        if self.overlap_comm:
            get_accelerator().synchronize()
            # It is safe to clear previously reduced grads of other partitions
            self._clear_previous_reduced_grads()

        if self.cpu_offload is False:
            for i, _ in enumerate(self.bit16_groups):

                if i not in self.averaged_gradients or self.averaged_gradients[i] is None:
                    self.averaged_gradients[i] = self.get_flat_partition(
                        self.params_in_partition[i],
                        self.first_offset[i],
                        self.partition_size[i],
                        dtype=self.dtype,
                        device=get_accelerator().current_device_name(),
                        return_tensor_list=True)
                else:
                    avg_new = self.get_flat_partition(self.params_in_partition[i],
                                                      self.first_offset[i],
                                                      self.partition_size[i],
                                                      dtype=self.dtype,
                                                      device=get_accelerator().current_device_name(),
                                                      return_tensor_list=True)

                    for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i], avg_new):
                        accumulated_grad.add_(new_avg_grad)

        self._release_ipg_buffers()

        # No need to keep the gradients anymore.
        # All gradients required by the step
        # are in self.averaged_gradients
        self.zero_grad(set_to_none=True)
        see_memory_usage(f"End ipg_epilogue")

    # resets all partitions to not reduced
    # sets remaining grads to the total number of grads in each partition
    # sets is_grad_computed to False for all grads in partition
    def reset_partition_gradient_structures(self):
        for i, _ in enumerate(self.bit16_groups):
            total_partitions = dist.get_world_size(group=self.real_dp_process_group[i])
            for partition_id in range(total_partitions):
                self.is_partition_reduced[i][partition_id] = False
                self.remaining_grads_in_partition[i][partition_id] = self.total_grads_in_partition[i][partition_id]

                for param_id in self.is_grad_computed[i][partition_id]:
                    self.is_grad_computed[i][partition_id][param_id] = False

    def initialize_gradient_partition(self, i, param_group, partition_id):
        def set_key_value_list(dictionary, key, value):
            if key in dictionary:
                dictionary[key].append(value)
            else:
                dictionary[key] = [value]

        def increment_value(dictionary, key):
            if key in dictionary:
                dictionary[key] += 1
            else:
                dictionary[key] = 1

        partition_size = self.partition_size[i]

        start_index = partition_size * partition_id
        end_index = partition_size * (partition_id + 1)

        current_index = 0
        first_offset = 0

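        # Walk params in order, tracking current_index (each param's start within the
        # flattened group) to decide which partition(s) its gradient falls into.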
        for param in param_group:

            param_size = param.numel()
            param_id = self.get_param_id(param)

            if (current_index >= start_index and current_index < end_index):
                set_key_value_list(self.param_to_partition_ids[i], param_id, partition_id)
                increment_value(self.total_grads_in_partition[i], partition_id)

                self.is_grad_computed[i][partition_id][param_id] = False

                self.grad_partition_insertion_offset[i][partition_id][param_id] = current_index - start_index
                self.grad_start_offset[i][partition_id][param_id] = 0

            elif start_index > current_index and start_index < (current_index + param_size):
                assert (first_offset == 0
                        ), "This can happen either zero or only once as this must be the first tensor in the partition"
                first_offset = start_index - current_index

                set_key_value_list(self.param_to_partition_ids[i], param_id, partition_id)
                increment_value(self.total_grads_in_partition[i], partition_id)

                self.is_grad_computed[i][partition_id][param_id] = False

                self.grad_partition_insertion_offset[i][partition_id][param_id] = 0
                self.grad_start_offset[i][partition_id][param_id] = first_offset

            current_index = current_index + param_size

    def overlapping_partition_gradients_reduce_epilogue(self):
        self.independent_gradient_partition_epilogue()

    def create_reduce_and_remove_grad_hooks(self):
        self.grad_accs = []
        for i, param_group in enumerate(self.bit16_groups):
            for param in param_group:
                if param.requires_grad:

                    def wrapper(param, i):
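                        # param.expand_as(param) creates an autograd view of the param;
                        # its grad_fn's next_functions[0][0] is the param's AccumulateGrad
                        # node, so the hook fires once this param's gradient is ready.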
                        param_tmp = param.expand_as(param)
                        grad_acc = param_tmp.grad_fn.next_functions[0][0]

                        def reduce_partition_and_remove_grads(*notneeded):
                            self.reduce_ready_partitions_and_remove_grads(param, i)

                        grad_acc.register_hook(reduce_partition_and_remove_grads)
                        self.grad_accs.append(grad_acc)

                    wrapper(param, i)

    def get_param_id(self, param):
        unique_id = id(param)
        return self.param_id[unique_id]

    def report_ipg_memory_usage(self, tag, param_elems):
        elem_count = self.elements_in_ipg_bucket + param_elems
        percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size
        see_memory_usage(
            f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}"
        )

    # create a flat tensor aligned at the alignment boundary
    def flatten_dense_tensors_aligned(self, tensor_list, alignment):
        return self.flatten(align_dense_tensors(tensor_list, alignment))

    ############### Independent Partition Gradient ########################
    def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
        if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size:
            self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.numel())
            self.reduce_ipg_grads()
            if self.contiguous_gradients and self.overlap_comm:
                # Swap ipg_index between 0 and 1
                self.ipg_index = 1 - self.ipg_index
            self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads", param.numel())

        param_id = self.get_param_id(param)
        assert self.params_already_reduced[param_id] == False, \
            f"The parameter {param_id} has already been reduced. \
            Gradient computed twice for this partition. \
            Multiple gradient reduction is currently not supported"

        if self.contiguous_gradients:
            if param.numel() > self.reduce_bucket_size:
                self.extra_large_param_to_reduce = param
            else:
                # keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening
                new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(0, self.elements_in_ipg_bucket, param.numel())
                new_grad_tensor.copy_(param.grad.view(-1))
                param.grad.data = new_grad_tensor.data.view_as(param.grad)

        self.elements_in_ipg_bucket += param.numel()

        assert param.grad is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient"

        self.grads_in_ipg_bucket.append(param.grad)
        self.params_in_ipg_bucket.append((i, param, param_id))

        #make sure the average tensor function knows how to average the gradients
        if is_moe_param(param):
            self.ipg_bucket_has_moe_params = True

        self.report_ipg_memory_usage("End ipg_remove_grads", 0)

    def print_rank_0(self, message):
        if dist.get_rank() == 0:
            logger.info(message)

    def gradient_reduction_w_predivide(self, tensor):

        dp_world_size = dist.get_world_size(group=self.dp_process_group)

        tensor_to_allreduce = tensor

        if self.communication_data_type != tensor.dtype:
            tensor_to_allreduce = tensor.to(self.communication_data_type)

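        # Either way the gradient ends up scaled by 1/dp_world_size overall;
        # postscale_gradients only controls whether the division is applied before
        # or after the all-reduce (optionally split via gradient_predivide_factor).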
        if self.postscale_gradients:
            if self.gradient_predivide_factor != 1.0:
                tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor)

            dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)

            if self.gradient_predivide_factor != dp_world_size:
                tensor_to_allreduce.mul_(self.gradient_predivide_factor / dp_world_size)
        else:
            tensor_to_allreduce.div_(dp_world_size)
            dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)

        if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce:
            tensor.copy_(tensor_to_allreduce)

        return tensor

    def average_tensor(self, tensor):
        if self.overlap_comm:
            stream = self.reduction_stream
            stream.wait_stream(get_accelerator().current_stream())
        else:
            stream = get_accelerator().current_stream()

        with get_accelerator().stream(stream):
            if not self.reduce_scatter:
                self.gradient_reduction_w_predivide(tensor)
                return

            # Accumulate destination ranks and bucket offsets for each gradient slice.
            # Note: potential future optimization, record access pattern of parameters
            # in backward pass and partition gradients w.r.t. access pattern so that our
            # bucket is guaranteed to be contiguous w.r.t. ranks
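            # Each entry in rank_and_offsets is (destination partition id, offset of the
            # grad slice within this bucket, number of elements in the slice).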
            rank_and_offsets = []
            real_dp_process_group = []
            curr_size = 0
            prev_id, prev_process_group = -1, None

            process_group = self.dp_process_group
            # count = 0
            for i, param, param_id in self.params_in_ipg_bucket:

                process_group = self.dp_process_group
                #Averages gradients at parameter level if ipg has a moe param
                #Otherwise averaging is done at the entire buffer level at the end of the loop
                # MoE params have different groups
                if self.ipg_bucket_has_moe_params:
                    process_group = self.expert_dp_process_group[param.group_name] if is_moe_param(
                        param) else self.dp_process_group
                    param.grad.data.div_(dist.get_world_size(group=process_group))

                partition_ids = self.param_to_partition_ids[i][param_id]
                assert all([p_id < dist.get_world_size(group=process_group) for p_id in partition_ids
                            ]), f"world size {dist.get_world_size(group=process_group)} and p_ids: {partition_ids}"
                partition_size = self.partition_size[i]
                # Get all partition ids + their offsets
                partition_ids_w_offsets = []
                for partition_id in partition_ids:
                    offset = self.grad_start_offset[i][partition_id][param_id]
                    partition_ids_w_offsets.append((partition_id, offset))
                partition_ids_w_offsets.sort(key=lambda t: t[1])

                # Calculate rank and offsets for grad slices
                for idx in range(len(partition_ids_w_offsets)):
                    partition_id, offset = partition_ids_w_offsets[idx]

                    # if dist.get_rank() == 0 and count < 100:
                    #     print(f"Rank {dist.get_rank()} rank offset id {idx} calculated dp size {dist.get_world_size(group=process_group)} real dp size {dist.get_world_size(self.real_dp_process_group[i])} and dst: {partition_id}")
                    # count += 1

                    # Calculate numel for grad slice depending on partition location
                    if idx == len(partition_ids_w_offsets) - 1:
                        # Last partition_id uses its own offset
                        numel = param.numel() - offset
                    else:
                        # Set numel to next partition's offset
                        numel = partition_ids_w_offsets[idx + 1][1] - offset

                    # Merge bucket ranges if they belong to the same rank
                    if partition_id == prev_id and process_group == prev_process_group:
                        prev_pid, prev_size, prev_numel = rank_and_offsets[-1]
                        rank_and_offsets[-1] = (prev_pid, prev_size, prev_numel + numel)
                    else:
                        rank_and_offsets.append((partition_id, curr_size, numel))
                        real_dp_process_group.append(process_group)
                    curr_size += numel
                    prev_id, prev_process_group = partition_id, process_group

            if not self.ipg_bucket_has_moe_params:
                tensor.div_(dist.get_world_size(group=self.dp_process_group))

            tensor_to_reduce = tensor
            if self.communication_data_type != tensor.dtype:
                tensor_to_reduce = tensor.to(self.communication_data_type)

            async_handles = []
            for i, (dst, bucket_offset, numel) in enumerate(rank_and_offsets):
                grad_slice = tensor_to_reduce.narrow(0, int(bucket_offset), int(numel))
                # if dist.get_rank() == 0:
                #     print(f"Rank {dist.get_rank()} rank offset id {i} real dp size {dist.get_world_size(group=real_dp_process_group[i])} and dst: {dst}")
                # dist.barrier()
                #dist.barrier()
                dst_rank = dist.get_global_rank(real_dp_process_group[i], dst)
                async_handle = dist.reduce(grad_slice, dst=dst_rank, group=real_dp_process_group[i], async_op=True)
                async_handles.append(async_handle)

            for handle in async_handles:
                handle.wait()

            if self.communication_data_type != tensor.dtype:
                tensor.copy_(tensor_to_reduce)

    ##############################################################################
    ############################# CPU Offload Methods#############################
    ##############################################################################
    def get_grad_position(self, group_id, tensor_list, first_offset, partition_size):
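        # Record, for each param, [param group index, offset into the param's gradient,
        # offset into the flat partition, number of elements] in self.grad_position.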
        current_offset = 0

        for i, tensor in enumerate(tensor_list):
            param_id = self.get_param_id(tensor)
            param_start_offset = 0

            num_elements = tensor.numel()

            # we need to offset to get to the right element
            if i == 0 and first_offset > 0:
                tensor_offset = first_offset
                num_elements = num_elements - tensor_offset
                param_start_offset = first_offset

            # we don't need all elements of the tensor
            if num_elements > (partition_size - current_offset):
                num_elements = partition_size - current_offset

            self.grad_position[param_id] = [
                int(group_id), int(param_start_offset),
                int(current_offset), int(num_elements)
            ]
            current_offset += num_elements

    def update_overflow_tracker_for_param_grad(self, param):
        if param.grad is not None and self._has_inf_or_nan(param.grad.data):
            self.local_overflow = True

    def _get_offload_gradient_dict(self):
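        # Collect, per param group, views into the flat fp32 partition's .grad that
        # correspond to each param owned by this partition (used with cpu_offload).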
        for param_group_index, _ in enumerate(self.optimizer.param_groups):
            self.offload_gradient_dict[param_group_index] = []
            for lp_param in self.params_in_partition[param_group_index]:
                param_id = self.get_param_id(lp_param)
                [_, _, dest_offset, num_elements] = self.grad_position[param_id]
                dest_tensor = self.single_partition_of_fp32_groups[param_group_index].grad.view(-1).narrow(
                    0, dest_offset, num_elements)
                self.offload_gradient_dict[param_group_index].append(dest_tensor)

    def async_accumulate_grad_in_cpu_via_gpu(self, param):
        param_id = self.get_param_id(param)

        [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]

        # copy to a preexisting buffer to avoid memory allocation penalty
        dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow(0, 0, param.numel())

        #buffer for storing gradients for this parameter in CPU
        def buffer_to_accumulate_to_in_cpu():
            if not self.fp16_master_weights_and_gradients:
                return get_accelerator().pin_memory(torch.zeros(param.numel(), dtype=param.dtype, device=self.device))
            else:
                return self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements)

        # accumulate gradients into param.grad or the parts of it that belong to this partition
        def accumulate_gradients():
            if not self.fp16_master_weights_and_gradients:
                dest_buffer.copy_(self.accumulated_grads_in_cpu[param_id].view(-1), non_blocking=True)
                param.grad.data.view(-1).add_(dest_buffer)
            else:
                dest_buffer.narrow(0, source_offset,
                                   num_elements).copy_(self.accumulated_grads_in_cpu[param_id].view(-1),
                                                       non_blocking=True)
                param.grad.data.view(-1).narrow(0, source_offset,
                                                num_elements).add_(dest_buffer.narrow(0, source_offset, num_elements))

        #move accumulated gradients back to CPU
        def copy_gradients_to_cpu():
            if not self.fp16_master_weights_and_gradients:
                self.accumulated_grads_in_cpu[param_id].data.copy_(param.grad.data.view(-1), non_blocking=True)
            else:
                self.accumulated_grads_in_cpu[param_id].data.copy_(param.grad.data.view(-1).narrow(
                    0, source_offset, num_elements),
                                                                   non_blocking=True)

        if param_id not in self.accumulated_grads_in_cpu:
            self.accumulated_grads_in_cpu[param_id] = buffer_to_accumulate_to_in_cpu()

        if self.micro_step_id > 0:
            accumulate_gradients()

        # at the boundary we will send 32bit directly
        if not self.is_gradient_accumulation_boundary:
            copy_gradients_to_cpu()

    def set_norm_for_param_grad(self, param):
        param_id = self.get_param_id(param)
        accumulated_grad = self.accumulated_grads_in_cpu[
            param_id] if self.gradient_accumulation_steps > 1 else param.grad

        [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]

        start = source_offset
        accumulated_grad = accumulated_grad.view(-1).narrow(0, start, num_elements)

        self.norm_for_param_grads[param_id] = accumulated_grad.data.double().norm(2)

    def set_norm_for_param_grad_in_gpu(self, param):
        param_id = self.get_param_id(param)
        accumulated_grad = param.grad

        [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]

        start = source_offset
        accumulated_grad = accumulated_grad.view(-1).narrow(0, start, num_elements)

        self.norm_for_param_grads[param_id] = accumulated_grad.data.double().norm(2)

    def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param):
        param_id = self.get_param_id(param)

        [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]

        dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements)

        src_tensor = param.grad.view(-1).narrow(0, source_offset, num_elements)
        if not self.fp16_master_weights_and_gradients:
            src_tensor = src_tensor.float()

        dest_tensor.copy_(src_tensor, non_blocking=True)
        param.grad = None  #offload only

    def complete_grad_norm_calculation_for_cpu_offload(self, params):
        total_norm = 0.0
        norm_type = 2.0
        for p in params:
            # Pipeline parallelism may replicate parameters. Avoid multi-counting.
            if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
                continue

            if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
                param_id = self.get_param_id(p)
                # as some models have trainable parameters that are skipped in training,
                # their backward hooks in self.create_reduce_and_remove_grad_hooks() will not run,
                # so they have no norm_for_param_grads
                if param_id in self.norm_for_param_grads:
                    param_norm = self.norm_for_param_grads[param_id]
                    total_norm += param_norm.item()**2
                else:
                    # As unused parameters in modules may sometimes be unexpected,
                    # raise an explicit error message when this occurs, along with
                    # an option to avoid the error
                    assert self.ignore_unused_parameters, """
                        This assert indicates that your module has parameters that
                        were not used in producing loss.
                        You can avoid this assert by
                        (1) enabling the ignore_unused_parameters option in the zero_optimization config;
                        (2) making sure all trainable parameters and `forward` function
                            outputs participate in calculating loss.
                    """

        # Sum across all model parallel GPUs.
        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
        dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group)

        self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)

        total_norm = total_norm_cuda[0].item()**(1. / norm_type)

        if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
            total_norm = -1

        return total_norm
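
    # Note on the math above: the local squared per-parameter grad norms are summed,
    # the sums are all-reduced (SUM) across the data parallel and then the model
    # parallel group, and the result is raised to 1/norm_type, i.e.
    # total_norm = sqrt(sum_p ||g_p||_2^2). A non-finite result is signalled as -1.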

    ############################################################################################
    def copy_grads_in_partition(self, param):
        if self.cpu_offload:

            if self.gradient_accumulation_steps > 1:
                self.async_accumulate_grad_in_cpu_via_gpu(param)

            if self.is_gradient_accumulation_boundary:
                self.set_norm_for_param_grad_in_gpu(param)

                self.update_overflow_tracker_for_param_grad(param)

                self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)

            return
        #print(f"ID {self.get_param_id(param)} grad norm {param.grad.norm()}")
        if self.grads_in_partition is None:
            self.grads_in_partition_offset = 0
            total_size = 0
            for group in self.params_in_partition:
                for param_in_partition in group:
                    total_size += param_in_partition.numel()

            see_memory_usage(f"before copying {total_size} gradients into partition")
            self.grads_in_partition = torch.empty(int(total_size),
                                                  dtype=self.dtype,
                                                  device=get_accelerator().current_device_name())
            see_memory_usage(f"after copying {total_size} gradients into partition")

        # The allreduce buffer will be rewritten. Copy the gradients in partition to a new buffer
        new_grad_tensor = self.grads_in_partition.view(-1).narrow(0, self.grads_in_partition_offset, param.numel())
        new_grad_tensor.copy_(param.grad.view(-1))
        param.grad.data = new_grad_tensor.data.view_as(param.grad)
        #print(f"Grad norm after copy to contiguous_buffer {param.grad.data.norm()}")
        self.grads_in_partition_offset += param.numel()

    def reduce_ipg_grads(self):
        if self.contiguous_gradients:
            if self.extra_large_param_to_reduce is not None:
                assert len(self.params_in_ipg_bucket) == 1, "more than 1 param in ipg bucket, this shouldn't happen"
                _, _, param_id = self.params_in_ipg_bucket[0]
                assert self.get_param_id(self.extra_large_param_to_reduce
                                         ) == param_id, "param in ipg bucket does not match extra-large param"
                self.average_tensor(self.extra_large_param_to_reduce.grad.view(-1))
                self.extra_large_param_to_reduce = None
            else:
                self.average_tensor(self.ipg_buffer[self.ipg_index])
        else:
            self.buffered_reduce_fallback(None,
                                          self.grads_in_ipg_bucket,
                                          elements_per_buffer=self.elements_in_ipg_bucket)

        if self.overlap_comm:
            stream = self.reduction_stream
        elif self.cpu_offload:
            # TODO: copy_grad_stream is disabled because of race with reduce. This hurts perf and should be fixed.
            #            get_accelerator().synchronize()
            #            stream = self.copy_grad_stream
            stream = get_accelerator().current_stream()
        else:
            stream = get_accelerator().current_stream()

        with get_accelerator().stream(stream):
            for _, param, param_id in self.params_in_ipg_bucket:

                assert self.params_already_reduced[param_id] == False, \
                    f"The parameter {param_id} has already been reduced. \
                    Gradient computed twice for this partition. \
                    Multiple gradient reduction is currently not supported"

                self.params_already_reduced[param_id] = True

                if self.partition_gradients:
                    if not self.is_param_in_current_partition[param_id]:
                        if self.overlap_comm and self.contiguous_gradients is False:
                            # Clear grads of other partitions during the next reduction
                            # to avoid clearing them before the reduction is complete.
                            if self.previous_reduced_grads is None:
                                self.previous_reduced_grads = []
                            self.previous_reduced_grads.append(param)
                        else:
                            param.grad = None  #only if self.partition_gradients
                    elif self.contiguous_gradients:
                        self.copy_grads_in_partition(param)
                else:  # zero stage 1 - partition only optimizer state
                    if self.contiguous_gradients and self.is_param_in_current_partition[param_id]:
                        self.copy_grads_in_partition(param)

        self.grads_in_ipg_bucket = []
        self.params_in_ipg_bucket = []
        self.ipg_bucket_has_moe_params = False
        self.elements_in_ipg_bucket = 0
        #####################################################################

    def reduce_ready_partitions_and_remove_grads(self, param, i):
        if self.partition_gradients or self.is_gradient_accumulation_boundary:
            self.reduce_independent_p_g_buckets_and_remove_grads(param, i)

    def zero_reduced_gradients(self, partition_id, i):

        def are_all_related_partitions_reduced(params_id):
            for partition_id in self.param_to_partition_ids[i][params_id]:
                if not self.is_partition_reduced[i][partition_id]:
                    return False
            return True

        for params_id in self.is_grad_computed[i][partition_id]:
            if are_all_related_partitions_reduced(params_id):
                self.param_dict[params_id].grad = None  # dead code

    def flatten_and_print(self, message, tensors, start=0, n=5):
        flatten_tensor = self.flatten(tensors)

        def print_func():
            logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n))

        self.sequential_execution(print_func, message)

    def get_grads_to_reduce(self, i, partition_id):

        def get_reducible_portion(key):
            grad = self.param_dict[key].grad
            total_elements = grad.numel()
            start = self.grad_start_offset[i][partition_id][key]
            num_elements = min(total_elements - start,
                               self.partition_size[i] - self.grad_partition_insertion_offset[i][partition_id][key])
            if not pg_correctness_test:
                if num_elements == total_elements:
                    return grad
                else:
                    return grad.contiguous().view(-1).narrow(0, int(start), int(num_elements))
            else:
                if num_elements == total_elements:
                    return grad.clone()
                else:
                    return grad.clone().contiguous().view(-1).narrow(0, int(start), int(num_elements))

        grads_to_reduce = []
        for key in self.is_grad_computed[i][partition_id]:
            grad = get_reducible_portion(key)
            grads_to_reduce.append(grad)
        return grads_to_reduce

    def sequential_execution(self, function, message, group=None):
        if group is None:
            group = self.dp_process_group
        if dist.get_rank(group=group) == 0:
            logger.info(message)
        for id in range(dist.get_world_size(group=group)):
            if id == dist.get_rank(group=group):
                function()
            dist.barrier(group=group)

    def set_none_gradients_to_zero(self, i, partition_id):
        for param_id in self.is_grad_computed[i][partition_id]:
            param = self.param_dict[param_id]
            if param.grad is None:
                param.grad = torch.zeros_like(param)

    ###################### Reduction Related Methods ############################
    def allreduce_bucket(self, bucket, rank=None, log=None):
        rank = None
        tensor = self.flatten(bucket)

        tensor_to_allreduce = tensor

        if pg_correctness_test:
            communication_data_type = torch.float32
        else:
            communication_data_type = self.communication_data_type

        if communication_data_type != tensor.dtype:
            tensor_to_allreduce = tensor.to(communication_data_type)

        tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group))

        if rank is None:
            #    "All Reducing"
            dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
        else:
            global_rank = dist.get_global_rank(self.dp_process_group, rank)
            dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group)

        if communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce:
            if rank is None or rank == dist.get_rank(group=self.dp_process_group):
                tensor.copy_(tensor_to_allreduce)

        return tensor
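
    # Reading the two steps above together: the flattened bucket is divided by the
    # data parallel world size before the collective, so the SUM all-reduce (or the
    # reduce, when a destination rank is given) leaves the result holding the
    # average gradient rather than the sum. E.g. with 4 ranks holding g1..g4, each
    # participating rank ends up with (g1 + g2 + g3 + g4) / 4.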

    def _clear_previous_reduced_grads(self):
        if self.previous_reduced_grads is not None:
            for param in self.previous_reduced_grads:
                param.grad = None  # overlap enabled
            self.previous_reduced_grads = None
            self.previous_reduced_grads = None

    # if rank is specified do a reduction instead of an allreduce
    def allreduce_and_copy(self, small_bucket, rank=None, log=None):
        if self.overlap_comm:
            get_accelerator().synchronize()
            # It is safe to clear the previously reduced grads of other partitions
            self._clear_previous_reduced_grads()
            stream = self.reduction_stream
        else:
            stream = get_accelerator().current_stream()

        with get_accelerator().stream(stream):
            allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log)
            if rank is None or rank == dist.get_rank(group=self.dp_process_group):
                for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)):
                    buf.copy_(synced)

    def allreduce_no_retain(self, bucket, numel_per_bucket=500000000, rank=None, log=None):
        small_bucket = []
        numel = 0
        for tensor in bucket:
            small_bucket.append(tensor)
            numel = numel + tensor.numel()
            if numel > numel_per_bucket:
                self.allreduce_and_copy(small_bucket, rank=rank, log=None)
                small_bucket = []

        if len(small_bucket) > 0:
            self.allreduce_and_copy(small_bucket, rank=rank, log=log)

    # allows using reduction of gradients instead of using all_reduce

    def buffered_reduce_fallback(self, rank, grads, elements_per_buffer=500000000, log=None):
        split_buckets = split_half_float_double(grads)

        for i, bucket in enumerate(split_buckets):
            self.allreduce_no_retain(bucket, numel_per_bucket=elements_per_buffer, rank=rank, log=log)

    #############################################################################
    #############################################################################
    #############################################################################

    # views the tensor as multiple partitions and returns
    # those partitions
    def get_data_parallel_partitions(self, tensor, group_id):
        partitions = []

        dp = dist.get_world_size(group=self.real_dp_process_group[group_id])
        # dp_id = dist.get_rank(group=self.real_dp_process_group[group_id])

        total_num_elements = tensor.numel()

        base_size = total_num_elements // dp
        remaining = total_num_elements % dp

        start = 0
        for id in range(dp):
            partition_size = base_size
            if id < remaining:
                partition_size = partition_size + 1
            partitions.append(tensor.narrow(0, start, partition_size))
            start = start + partition_size
        return partitions
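
    # Example for get_data_parallel_partitions above (hypothetical sizes): a flat
    # tensor of 10 elements split across dp=3 ranks yields views of 4, 3 and 3
    # elements, since the first (total % dp) ranks each take one extra element.
    # The narrowed views share storage with the input tensor; nothing is copied.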

    def get_partition_info(self, tensor_list, partition_size, partition_id):
        params_in_partition = []
        params_not_in_partition = []

        start_index = partition_size * partition_id
        end_index = partition_size * (partition_id + 1)

        current_index = 0
        first_offset = 0

        for tensor in tensor_list:

            tensor_size = tensor.numel()

            if (current_index >= start_index and current_index < end_index):
                params_in_partition.append(tensor)

            elif start_index > current_index and start_index < (current_index + tensor_size):
                params_in_partition.append(tensor)

                assert (first_offset == 0
                        ), "This can happen at most once as this must be the first tensor in the partition"
                first_offset = start_index - current_index

            else:
                params_not_in_partition.append(tensor)

            current_index = current_index + tensor_size

        return params_in_partition, params_not_in_partition, first_offset
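
    # Worked example for get_partition_info above (hypothetical sizes): with
    # partition_size=10 and partition_id=1 the partition owns flat indices [10, 20).
    # For three tensors of 8 elements each (occupying [0,8), [8,16), [16,24)),
    # the first is not in the partition, the second is in it with first_offset=2
    # (its element 2 maps to flat index 10), and the third is in it from offset 0.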

    def zero_grad(self, set_to_none=False):
        """
        Zero FP16 parameter grads.
        """
        # FP32 grad should never exist.
        # For speed, set model fp16 grad to None by default
        for group in self.bit16_groups:
            for p in group:
                if set_to_none:
                    p.grad = None  # epilogue and in step
                else:
                    if p.grad is not None:
                        p.grad.detach_()
                        p.grad.zero_()

    def _model_parallel_all_reduce(self, tensor, op):
        """ Perform all reduce within model parallel group, if any.
        """
        if self.model_parallel_group is None or self.model_parallel_world_size == 1:
            pass
        else:
            dist.all_reduce(tensor=tensor, op=op, group=self.model_parallel_group)

    def get_grad_norm_direct(self, gradients, params, norm_type=2):
        """Clips gradient norm of an iterable of parameters.

        This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
        added functionality to handle model parallel parameters. Note that
        the gradients are modified in place.

        Arguments:
            parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
                single Tensor that will have gradients normalized
            max_norm (float or int): max norm of the gradients
            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
                infinity norm.

        Returns:
            Total norm of the parameters (viewed as a single vector).
        """
        norm_type = float(norm_type)
        if norm_type == inf:
            total_norm = max(g.data.abs().max() for g in gradients)
            total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.dp_process_group)

            # Take max across all GPUs.
            self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX)
            total_norm = total_norm_cuda[0].item()
        else:
            total_norm = 0.0
            # if dist.get_rank() == 0:
            #    logger.info(f"Total Norm beginning {total_norm}")
            for g, p in zip(gradients, params):
                # Pipeline parallelism may replicate parameters. Avoid multi-counting.
                if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
                    continue
                if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
                    param_norm = g.data.double().norm(2)
                    total_norm += param_norm.item()**2
            # Sum across all model parallel GPUs.
            total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group)

            self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)

            total_norm = total_norm_cuda[0].item()**(1. / norm_type)

        if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
            total_norm = -1

        return total_norm

    # creates a flat fused tensor from the tensor list starting at the first_offset
    # in the first tensor of the list. If there are not enough elements in the tensor
    # list then the flat tensor will be padded with zeros
    def get_flat_partition(self, tensor_list, first_offset, partition_size, dtype, device, return_tensor_list=False):
        flat_tensor_list = []
        current_size = 0
        for i, tensor in enumerate(tensor_list):
            if tensor.grad is None:
                tensor.grad = torch.zeros_like(tensor)

            tensor = tensor.grad
            num_elements = tensor.numel()
            tensor_offset = 0

            # we need to offset to get to the right element
            if i == 0 and first_offset > 0:
                tensor_offset = first_offset
                num_elements = num_elements - tensor_offset

            # we don't need all elements of the tensor
            if num_elements > (partition_size - current_size):
                num_elements = partition_size - current_size

            # we need a narrow view of the tensor based on the tensor offset and number of elements that
            # we need from this tensor
            if tensor_offset > 0 or num_elements < tensor.numel():
                flat_tensor_list.append(tensor.contiguous().view(-1).narrow(0, int(tensor_offset), int(num_elements)))
            else:
                flat_tensor_list.append(tensor)

            current_size = current_size + num_elements

        # this means it's the last partition and does not align with the dp boundary. We need to pad before flattening
        if current_size < partition_size:
            flat_tensor_list.append(torch.zeros(int(partition_size - current_size), dtype=dtype, device=device))

        if return_tensor_list:
            return flat_tensor_list

        return self.flatten(flat_tensor_list)
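
    # Sketch of get_flat_partition above (hypothetical numbers): for two grads of
    # 8 elements each, first_offset=2 and partition_size=10, the first grad
    # contributes a narrow view of elements [2, 8) (6 elements) and the second a
    # view of elements [0, 4), filling the partition exactly, so no zero padding
    # is appended.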

    def free_grad_in_param_list(self, param_list):
        for p in param_list:
            p.grad = None  # in step

    def reset_cpu_buffers(self):
        self.norm_for_param_grads = {}
        self.local_overflow = False

    def log_timers(self, timer_names):
        if self.timers is None:
            return

        self.timers.log(names=list(timer_names))

    def start_timers(self, timer_names):
        if self.timers is None:
            return

        for name in timer_names:
            self.timers(name).start()

    def stop_timers(self, timer_names):
        if self.timers is None:
            return

        for name in timer_names:
            self.timers(name).stop()

    def set_lr(self, lr):
        """Set the learning rate."""
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def get_lr(self):
        """Return the current learning rate."""
        return self.optimizer.param_groups[0]["lr"]

    def override_loss_scale(self, loss_scale):
        if loss_scale != self.external_loss_scale:
            logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}')
        self.custom_loss_scaler = True
        self.external_loss_scale = loss_scale

    def scaled_global_norm(self, norm_type=2):
        assert norm_type == 2, "only L2 norm supported"
        norm_groups = []
        for i, group in enumerate(self.bit16_groups):
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
            if self.cpu_offload:
                norm_groups.append(self.complete_grad_norm_calculation_for_cpu_offload(self.params_in_partition[i]))
                single_grad_partition = self.single_partition_of_fp32_groups[i].grad
            else:
                norm_groups.append(self.get_grad_norm_direct(self.averaged_gradients[i], self.params_in_partition[i]))

        if self.has_moe_layers:
            self._average_expert_grad_norms(norm_groups)

        # note that the get_global_norm function only supports l2 norm
        return get_global_norm(norm_list=norm_groups)
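
    # Assuming get_global_norm combines the per-group values as sqrt(sum_i norm_i^2)
    # (an L2 combination), the value returned here is the global gradient norm still
    # multiplied by the current loss scale; step() divides by the scale before
    # reporting it as _global_grad_norm.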

    def get_bit16_param_group(self, group_no):
        bit16_partitions = self.parallel_partitioned_bit16_groups[group_no]
        partition_id = dist.get_rank(group=self.real_dp_process_group[group_no])
        return [bit16_partitions[dist.get_rank(group=self.real_dp_process_group[group_no])]]

    def _optimizer_step(self, group_no):
        original_param_groups = self.optimizer.param_groups
        self.optimizer.param_groups = [original_param_groups[group_no]]
        # Disabling this as the C++ side copy & synchronize is not working correctly
        #from deepspeed.ops.adam import DeepSpeedCPUAdam
        #if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
        #    self.optimizer.step(fp16_param_groups=[self.get_bit16_param_group(group_no)])
        #else:
        #    self.optimizer.step()
        self.optimizer.step()
        self.optimizer.param_groups = original_param_groups

    def step(self, closure=None):
        """
        Not supporting closure.
        """
        self.micro_step_id = -1

        see_memory_usage(f"In step before checking overflow")

        # First compute norm for all groups so we know if there is overflow
        self.check_overflow()
        OPTIMIZER_ALLGATHER = 'optimizer_allgather'
        OPTIMIZER_GRADIENTS = 'optimizer_gradients'
        OPTIMIZER_STEP = 'optimizer_step'
        timer_names = [OPTIMIZER_ALLGATHER, OPTIMIZER_GRADIENTS, OPTIMIZER_STEP]

        prev_scale = self.loss_scale
        self._update_scale(self.overflow)
        if self.overflow:
            see_memory_usage('After overflow before clearing gradients')
            self.zero_grad(set_to_none=True)
            if self.cpu_offload:
                self.reset_cpu_buffers()
            else:
                self.averaged_gradients = {}

            see_memory_usage('After overflow after clearing gradients')

            self.start_timers(timer_names)
            self.stop_timers(timer_names)
            return

        # Step 1:- Calculate gradient norm using fp16 grads
        if self.dtype == torch.float16:
            see_memory_usage('Before norm calculation')
            scaled_global_grad_norm = self.scaled_global_norm()
            self._global_grad_norm = scaled_global_grad_norm / prev_scale
            see_memory_usage('After norm before optimizer')

        # Step 2:- run optimizer and upscaling simultaneously
        for i, group in enumerate(self.bit16_groups):
            self.start_timers([OPTIMIZER_GRADIENTS])
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
            if self.cpu_offload:
                single_grad_partition = self.single_partition_of_fp32_groups[i].grad
                if self.dtype == torch.float16:
                    self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)

                self.stop_timers([OPTIMIZER_GRADIENTS])
                self.start_timers([OPTIMIZER_STEP])
                self._optimizer_step(i)

                # Disabled, this is not currently working
                #from deepspeed.ops.adam import DeepSpeedCPUAdam
                #if not (type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half):
                #    bit16_partitions = self.parallel_partitioned_bit16_groups[i]
                #    fp32_partition = self.single_partition_of_fp32_groups[i]
                #    bit16_partitions[partition_id].data.copy_(fp32_partition.data)
                bit16_partitions = self.parallel_partitioned_bit16_groups[i]
                fp32_partition = self.single_partition_of_fp32_groups[i]
                bit16_partitions[partition_id].data.copy_(fp32_partition.data)

                self.stop_timers([OPTIMIZER_STEP])
            else:
                # free gradients for all the parameters that are not updated by this process (ZeRO stage 2)
                self.free_grad_in_param_list(self.params_not_in_partition[i])

                # create flat gradients for parameters updated by this process
                # If we are the last partition, ensure the flattened grads match the partition size; if not, pad with zero tensors
                if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1:
                    single_grad_partition = self.flatten_dense_tensors_aligned(
                        self.averaged_gradients[i],
                        int(self.partition_size[i])).to(self.single_partition_of_fp32_groups[i].dtype)
                else:
                    single_grad_partition = self.flatten(self.averaged_gradients[i]).to(
                        self.single_partition_of_fp32_groups[i].dtype)
                assert single_grad_partition.numel() == self.partition_size[i], \
                    "averaged gradients have different number of elements than partition size {} {} {} {}".format(
                        single_grad_partition.numel(), self.partition_size[i], i, partition_id)

                self.single_partition_of_fp32_groups[i].grad = single_grad_partition
                # release all the gradients since we have already created a necessary copy in dp_grad_partition (ZeRO stage 2)
                self.free_grad_in_param_list(self.params_in_partition[i])

                self.averaged_gradients[i] = None

                if self.dtype == torch.float16:
                    self.unscale_and_clip_grads([single_grad_partition], scaled_global_grad_norm)

                self.stop_timers([OPTIMIZER_GRADIENTS])

                # Step 3:- run the optimizer if no offloading
                self.start_timers([OPTIMIZER_STEP])
                self._optimizer_step(i)
                # Step 4:- get rid of the fp32 gradients. Not needed anymore
                self.single_partition_of_fp32_groups[i].grad = None
                del single_grad_partition
                bit16_partitions = self.parallel_partitioned_bit16_groups[i]
                fp32_partition = self.single_partition_of_fp32_groups[i]
                bit16_partitions[partition_id].data.copy_(fp32_partition.data)
                self.stop_timers([OPTIMIZER_STEP])

        see_memory_usage('After optimizer before all-gather')
        if self.cpu_offload:
            self.reset_cpu_buffers()

        self.start_timers([OPTIMIZER_ALLGATHER])
        # Gather the updated weights from everyone.
        # Then all partitions of the model parameters are updated and ready for next round forward.
        all_gather_dp_groups(partitioned_param_groups=self.parallel_partitioned_bit16_groups,
                             dp_process_group=self.real_dp_process_group,
                             start_alignment_factor=self.nccl_start_alignment_factor,
                             allgather_bucket_size=self.allgather_bucket_size)

        self.stop_timers([OPTIMIZER_ALLGATHER])

        # TODO: we probably don't need this? just to be safe
        for i in range(len(self.bit16_groups)):
            self._update_model_bit16_weights(i)

        self.log_timers(timer_names)
        see_memory_usage('After zero_optimizer step')

        return

    @torch.no_grad()
    def update_lp_params(self):
        for i, (bit16_partitions, fp32_partition) in enumerate(
                zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
            bit16_partitions[partition_id].data.copy_(fp32_partition.data)
            # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)
            # if i == 0:
            #     print_rank_0(f'{fp32_partition[:10]=}', force=True)

        all_gather_dp_groups(partitioned_param_groups=self.parallel_partitioned_bit16_groups,
                             dp_process_group=self.real_dp_process_group,
                             start_alignment_factor=self.nccl_start_alignment_factor,
                             allgather_bucket_size=self.allgather_bucket_size)

    def _average_expert_grad_norms(self, norm_groups):
        for i, norm in enumerate(norm_groups):
            if self.is_moe_param_group[i]:
                scaled_norm = norm * 1.0 / float(dist.get_world_size(group=self.real_dp_process_group[i]))
                scaled_norm_tensor = torch.tensor(scaled_norm,
                                                  device=get_accelerator().device_name(),
                                                  dtype=torch.float)
                dist.all_reduce(scaled_norm_tensor, group=self.real_dp_process_group[i])
                norm_groups[i] = scaled_norm_tensor.item()

    def unscale_and_clip_grads(self, grad_groups_flat, total_norm):
        # compute combined scale factor for this group
        combined_scale = self.loss_scale
        if self.clip_grad > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
            if clip > 1:
                combined_scale = clip * self.loss_scale

        for grad in grad_groups_flat:
            if isinstance(grad, list):
                sub_partitions = grad
                for g in sub_partitions:
                    g.data.mul_(1. / combined_scale)
            else:
                grad.data.mul_(1. / combined_scale)
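
    # Worked example of the scaling above (hypothetical numbers): with
    # loss_scale=1024, clip_grad=1.0 and total_norm=5120 (i.e. a true, unscaled
    # norm of 5.0), clip = (5120/1024 + 1e-6)/1.0 is about 5.0 > 1, so
    # combined_scale is about 5 * 1024 = 5120 and every grad is multiplied by
    # 1/5120, which removes the loss scale and clips the global norm to roughly
    # clip_grad in one pass.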

    def _check_overflow(self, partition_gradients=True):
        self.overflow = self.has_overflow(partition_gradients)

    # `params` is a list / generator of torch.Variable
    def has_overflow_serial(self, params, is_grad_list=False):
        for p in params:
            if p.grad is not None and self._has_inf_or_nan(p.grad.data):
                return True

        return False

    def has_overflow_partitioned_grads_serial(self):
        for i in range(len(self.bit16_groups)):
            for j, grad in enumerate(self.averaged_gradients[i]):
                if grad is not None and self._has_inf_or_nan(grad.data, j):
                    return True
        return False

    def has_overflow(self, partition_gradients=True):
        if partition_gradients:
            overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial()
            overflow_gpu = get_accelerator().ByteTensor([overflow])
            '''This will capture overflow across all data parallel and expert parallel processes
            since expert parallel processes are a subset of data parallel processes'''
            dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group)

        else:
            params = []
            for group in self.bit16_groups:
                for param in group:
                    params.append(param)

            overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients)
            overflow_gpu = get_accelerator().ByteTensor([overflow])

        # Since each model parallel GPU carries only part of the model,
        # make sure overflow flag is synced across all the model parallel GPUs
        self._model_parallel_all_reduce(tensor=overflow_gpu, op=dist.ReduceOp.MAX)

        overflow = overflow_gpu[0].item()
        return bool(overflow)
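
    # The overflow flag above is propagated as a 0/1 byte tensor via MAX all-reduces
    # (across the data parallel / expert parallel group when partitioning gradients,
    # then across the model parallel group), so if any rank saw inf/NaN every rank
    # returns True and skips the optimizer step together.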

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x, j=None):
        try:
            # if x is half, the .float() incurs an additional deep copy, but it's necessary if
            # Pytorch's .sum() creates a one-element tensor of the same type as x
            # (which is true for some recent version of pytorch).
            cpu_sum = float(x.float().sum())
            # More efficient version that can be used if .sum() returns a Python scalar
            # cpu_sum = float(x.sum())
        except RuntimeError as instance:
            # We want to check if instance is actually an overflow exception.
            # RuntimeError could come from a different error.
            # If so, we still want the exception to propagate.
            if "value cannot be converted" not in instance.args[0]:
                raise
            return True
        else:
            if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
                return True
            return False

    def backward(self, loss, retain_graph=False):
        """
        :attr:`backward` performs the following steps:

        1. fp32_loss = loss.float()
        2. scaled_loss = fp32_loss*loss_scale
        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
        """
        self.micro_step_id += 1

        if self.contiguous_gradients:
            self.ipg_buffer = []
            buf_0 = torch.empty(int(self.reduce_bucket_size),
                                dtype=self.dtype,
                                device=get_accelerator().current_device_name())
            self.ipg_buffer.append(buf_0)

            # Use double buffers to avoid data access conflict when overlap_comm is enabled.
            if self.overlap_comm:
                buf_1 = torch.empty(int(self.reduce_bucket_size),
                                    dtype=self.dtype,
                                    device=get_accelerator().current_device_name())
                self.ipg_buffer.append(buf_1)
            self.ipg_index = 0

        if self.custom_loss_scaler:
            scaled_loss = self.external_loss_scale * loss
            scaled_loss.backward()
        else:
            self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)

    def check_overflow(self, partition_gradients=True):
        self._check_overflow(partition_gradients)

    def _update_scale(self, has_overflow=False):
        self.loss_scaler.update_scale(has_overflow)

    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)

    # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
    def _get_loss_scale(self):
        if self.custom_loss_scaler:
            return self.external_loss_scale
        else:
            return self.loss_scaler.cur_scale

    def _set_loss_scale(self, value):
        self.loss_scaler.cur_scale = value

    loss_scale = property(_get_loss_scale, _set_loss_scale)
    cur_scale = property(_get_loss_scale, _set_loss_scale)

    # Return group tensor after removing paddings that are added for alignment to DP world size.
    # This method works on the assumption that each group contains a single flattened tensor.
    def _get_groups_without_padding(self, groups_with_padding):
        groups_without_padding = []
        for i, group in enumerate(groups_with_padding):
            lean_length = group.numel() - self.groups_padding[i]
            groups_without_padding.append(group[:lean_length])

        return groups_without_padding

    # Return optimizer state after removing paddings that are added for alignment.
    def _get_state_without_padding(self, state_with_padding, padding):
        lean_state = {}
        for key, value in state_with_padding.items():
            if torch.is_tensor(value):
                lean_length = value.numel() - padding
                lean_state[key] = value[:lean_length]
            else:
                lean_state[key] = value

        return lean_state

    # Return base optimizer states.
    # This method assumes that each param group contains a single flattened tensor.
    def _get_base_optimizer_state(self):
        optimizer_groups_state = []
        for i, group in enumerate(self.optimizer.param_groups):
            p = group['params'][0]
            lean_optimizer_state = self._get_state_without_padding(self.optimizer.state[p], self.groups_padding[i])
            optimizer_groups_state.append(lean_optimizer_state)

        return optimizer_groups_state

    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
        of the contained Pytorch optimizer.
        Example::
            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        state_dict = {}
        state_dict['loss_scaler'] = self.loss_scaler
        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
        state_dict['overflow'] = self.overflow
        state_dict[CLIP_GRAD] = self.clip_grad

        if self.elastic_checkpoint:
            state_dict[BASE_OPTIMIZER_STATE] = self._get_base_optimizer_state()
        else:
            state_dict[BASE_OPTIMIZER_STATE] = self.optimizer.state_dict()

        # Remove paddings for DP alignment to enable loading for other alignment values
        fp32_groups_without_padding = self._get_groups_without_padding(self.single_partition_of_fp32_groups)
        state_dict[SINGLE_PARTITION_OF_FP32_GROUPS] = fp32_groups_without_padding

        state_dict[
            ZERO_STAGE] = ZeroStageEnum.gradients if self.partition_gradients else ZeroStageEnum.optimizer_states
        state_dict[GROUP_PADDINGS] = self.groups_padding
        state_dict[PARTITION_COUNT] = self.partition_count

        state_dict[DS_VERSION] = version
        state_dict[PARAM_SLICE_MAPPINGS] = self._param_slice_mappings

        return state_dict

    # Restore base optimizer fp32 weights from elastic checkpoint by:
    # 1) Merging fp32 weights from checkpoints of all partitions
    # 2) Extracting fp32 weights for current partition from merged weights
    # 3) Using extracted weights to update base optimizer weights directly.
    def _restore_from_elastic_fp32_weights(self, all_state_dict):
        merged_single_partition_of_fp32_groups = []

        for i in range(len(self.single_partition_of_fp32_groups)):
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
            merged_partitions = [sd[SINGLE_PARTITION_OF_FP32_GROUPS][i] for sd in all_state_dict]
            if self.is_moe_group(self.optimizer.param_groups[i]):
                ranks = self.get_ep_ranks(group_name=self.optimizer.param_groups[i]['name'])
                merged_partitions = [merged_partitions[i] for i in ranks]
            flat_merged_partitions = self.flatten_dense_tensors_aligned(
                merged_partitions,
                self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[i]))
            dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions, i)
            merged_single_partition_of_fp32_groups.append(dp_partitions[partition_id])

        for current, saved in zip(self.single_partition_of_fp32_groups, merged_single_partition_of_fp32_groups):
            current.data.copy_(saved.data)

    # Restore base optimizer fp32 weights from ZeRO fp16 or bfloat16 weights
    def _restore_from_bit16_weights(self):
        for group_id, (bit16_partitions, fp32_partition) in enumerate(
                zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):
            partition_id = dist.get_rank(group=self.real_dp_process_group[group_id])
            fp32_partition.data.copy_(bit16_partitions[partition_id].data)

    # Refresh the fp32 master params from the fp16 or bfloat16 copies.
    def refresh_fp32_params(self):
        self._restore_from_bit16_weights()

    # Extract optimizer state for current partition from merged states of all partitions
    def _partition_base_optimizer_state(self, state_key, all_partition_states, group_id):
        partition_id = dist.get_rank(group=self.real_dp_process_group[group_id])
        alignment = dist.get_world_size(group=self.real_dp_process_group[group_id])
        if torch.is_tensor(all_partition_states[0]):
            flat_merged_partitions = self.flatten_dense_tensors_aligned(all_partition_states, alignment)
            dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions, group_id)
            return dp_partitions[partition_id]
        else:
            # Assume non-tensor states are not partitioned and equal across ranks, so return first one
            return all_partition_states[0]

    def _restore_base_optimizer_state(self, base_optimizer_group_states):
        if type(base_optimizer_group_states) == dict:
            base_optimizer_group_states = base_optimizer_group_states['state']
        for i, group in enumerate(self.optimizer.param_groups):
            p = group['params'][0]
            for key, saved in base_optimizer_group_states[i].items():
                if torch.is_tensor(self.optimizer.state[p][key]):
                    dst_tensor = self.optimizer.state[p][key]
                    src_tensor = _get_padded_tensor(saved, dst_tensor.numel())
                    self.optimizer.state[p][key].data.copy_(src_tensor.data)
                else:
                    self.optimizer.state[p][key] = saved

2067
    def get_ep_ranks(self, rank=0, group_name=None):
2068
        from deepspeed.utils import groups
2069 2070 2071
        expert_parallel_size_ = groups._get_expert_parallel_world_size(group_name)
        world_size = groups._get_data_parallel_world_size()
        rank = groups._get_expert_parallel_rank(group_name)
        ranks = range(rank, world_size, expert_parallel_size_)
        return list(ranks)

    # Restore base optimizer state from elastic checkpoint by
    # 1) Merging optimizer state from checkpoints of all partitions
    # 2) Extracting optimizer state for current partition from the merged state
    # 3) Using the extracted value to directly update the base optimizer.
    def _restore_elastic_base_optimizer_state(self, all_state_dict):
        base_optimizer_group_states = []
        for i in range(len(self.optimizer.param_groups)):
            partition_states = {}
            all_partition_group_states = [sd[BASE_OPTIMIZER_STATE][i] for sd in all_state_dict]

            if self.is_moe_group(self.optimizer.param_groups[i]):
                ranks = self.get_ep_ranks(group_name=self.optimizer.param_groups[i]['name'])
                all_partition_group_states = [all_partition_group_states[i] for i in ranks]

            for key in all_partition_group_states[0].keys():
                all_partition_states = [all_states[key] for all_states in all_partition_group_states]
                partition_states[key] = self._partition_base_optimizer_state(key, all_partition_states, i)
            base_optimizer_group_states.append(partition_states)

        self._restore_base_optimizer_state(base_optimizer_group_states)

    def load_state_dict(self,
                        state_dict_list,
                        load_optimizer_states=True,
                        load_from_fp32_weights=False,
                        checkpoint_folder=None):
        if checkpoint_folder:
            self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights)
        else:
            self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights)

    def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
        self._load_hp_checkpoint_state(checkpoint_folder)

    @property
    def param_groups(self):
        """Forward the wrapped optimizer's parameters."""
        return self.optimizer.param_groups

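    # Load the high-precision (fp32) optimizer state for every bit16 param that has an hp mapping,
    # reading from per-parameter paths under <checkpoint_dir>/zero and slicing by tensor-parallel rank.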
    def _load_hp_checkpoint_state(self, checkpoint_dir):
        checkpoint_dir = os.path.join(checkpoint_dir, "zero")
        tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
        tp_world_size = self.mpu.get_slice_parallel_world_size()

        for i, _ in enumerate(self.optimizer.param_groups):
            for lp in self.bit16_groups[i]:
                if lp._hp_mapping is not None:
                    #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
                    lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
                                                tp_world_size)

    def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, load_from_fp32_weights=False):
        r"""Loading ZeRO checkpoint

        Arguments:
            state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition.
                Note that the number of saved partitions may differ from the number of loading partitions, in order
                to support changing the GPU count (specifically the DP world size) between saving and loading.
            load_optimizer_states: Boolean indicating whether to load base optimizer states.
            load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from the fp32
                copies in checkpoints (no precision loss) or from the model's fp16 copies (with precision loss).

        Loads a state_dict created by an earlier call to state_dict().
        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``fp16_optimizer_instance.load_state_dict()`` is called.
        Example::
            model = torch.nn.Linear(D_in, D_out).to(get_accelerator().device_name()).half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """

        # I think it should actually be ok to reload the optimizer before the model.
        dp_rank = dist.get_rank(group=self.dp_process_group)
        current_rank_sd = state_dict_list[dp_rank]
        self.loss_scaler = current_rank_sd.get('loss_scaler', self.loss_scaler)
        self.dynamic_loss_scale = current_rank_sd.get('dynamic_loss_scale', self.dynamic_loss_scale)
        self.overflow = current_rank_sd.get('overflow', self.overflow)
        self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad)

        ckpt_version = current_rank_sd.get(DS_VERSION, False)
        assert ckpt_version, f"Empty ds_version in checkpoint, not clear how to proceed"
        ckpt_version = pkg_version.parse(ckpt_version)

        # zero stage 1 mode
        if not self.partition_gradients:
            required_version = pkg_version.parse("0.3.17")
            error_str = f"ZeRO stage 1 changed in {required_version} and is not backwards compatible " \
                "with older stage 1 checkpoints. If you'd like to load an old ZeRO-1 checkpoint " \
                "please use an older version of DeepSpeed (<= 0.5.8) and set 'legacy_stage1': true in your zero config json."
            assert required_version <= ckpt_version, f"Old version: {ckpt_version} {error_str}"

        ckpt_is_rigid = isinstance(current_rank_sd[BASE_OPTIMIZER_STATE], dict)

        # padding is always at the last rank/partition
        # if DP=1024 and param-group elems=16 -> padding will be 1024-16 across all but one rank
        # scenario-1 (shrink): saving w. 4 gpus -> loading w. 2 gpus
        # scenario-2 (expand): saving w. 2 gpus -> loading w. 4 gpus
        # if load_optimizer_states:
        #     if new_dp_size:
        #         self.strip_padding()
        #         self.add_padding_w_new_dp_size()
        #     self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE])

        if load_optimizer_states:
            if ckpt_is_rigid:
                # loading rigid ckpt into either rigid or elastic exec
                self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE])
            else:
                if self.elastic_checkpoint:
                    # loading elastic into elastic exec
                    self._restore_elastic_base_optimizer_state(state_dict_list)
                else:
                    # loading an elastic checkpoint into rigid exec
                    self._restore_base_optimizer_state(current_rank_sd[BASE_OPTIMIZER_STATE])

        # At this point, the optimizer's references to the model's fp32 parameters are up to date.
        # The optimizer's hyperparameters and internal buffers are also up to date.
        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
        # out of date.  There are two options.
        # 1:  Refresh the master params from the model's fp16 params.
        # This requires less storage but incurs precision loss.
        # 2:  Save and restore the fp32 master copies separately.
        # We choose option 1 if changing DP degree and option 2 otherwise.
        #
        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
        # of their associated parameters, because it's possible those buffers might not exist yet in
        # the current optimizer instance.  In our case, as long as the current FP16_Optimizer has been
        # constructed in the same way as the one whose state_dict we are loading, the same master params
        # are guaranteed to exist, so we can just copy_() from the saved master params.

        if load_from_fp32_weights:
            # option 2 from above
            if self.elastic_checkpoint and not ckpt_is_rigid:
                self._restore_from_elastic_fp32_weights(state_dict_list)
            else:
                # For non-elastic checkpoint, simply copying from saved weights of current rank is sufficient.
                for current, saved in zip(self.single_partition_of_fp32_groups,
                                          current_rank_sd[SINGLE_PARTITION_OF_FP32_GROUPS]):
                    src_tensor = _get_padded_tensor(saved, current.numel())
                    current.data.copy_(src_tensor.data)
        else:
            # option 1 from above
            self._restore_from_bit16_weights()

        if load_optimizer_states:
            self._link_all_hp_params()


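# Debug helper: on rank 0, locate and log the index of the first non-finite element of tensor `x`.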
def _handle_overflow(cpu_sum, x, i):
    import math
    rank = dist.get_rank()
    if rank == 0:
        t_i = -1
        for v_i, v in enumerate(x.data.contiguous().view(-1)):
            if not math.isfinite(float(v)):
                t_i = v_i
                break
        logger.info(f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}")


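# Rough worked example (illustrative sizes, assuming the default buffer factor of 1.5):
# for total_params=1e9 on 1 node with 8 GPUs,
#   cpu_offload=True  -> gpu_mem = 2 * 1e9 = 2e9 B (~1.9GB),
#                        cpu_mem = 1e9 * max(4 * 8, 16) * 1.5 = 48e9 B (~44.7GB)
#   cpu_offload=False -> gpu_mem = 4 * 1e9 + 16 * 1e9 / 8 = 6e9 B (~5.6GB),
#                        cpu_mem = 1e9 * 4 * 8 * 1.5 = 48e9 B (~44.7GB)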
def estimate_zero2_model_states_mem_needs(total_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          additional_buffer_factor=1.5):

    total_gpus = num_nodes * num_gpus_per_node

    if cpu_offload:
        gpu_mem = 2 * total_params
        cpu_mem = total_params * max(4 * total_gpus, 16) * additional_buffer_factor
    else:
        gpu_mem = 4 * total_params + int(16 * total_params / total_gpus)
        cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem)


def model_to_params(model):
    # shared params calculated only once
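    # (keying by data_ptr() keeps one entry per underlying storage, so tied weights are not double-counted)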
    total_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
    return total_params


def estimate_zero2_model_states_mem_needs_all_live(model,
                                                   num_gpus_per_node=1,
                                                   num_nodes=1,
                                                   additional_buffer_factor=1.5):
    """
    Print out estimates on memory usage requirements for ZeRO 2 params, optim states and gradients
    for a given ``model`` and hardware setup.

    If you have an actual model object, use this function and everything will be derived
    automatically.

    If it's a hypothetical model, use ``estimate_zero2_model_states_mem_needs_all_cold`` where you have to pass
    the ``total_params`` explicitly.

    Args:
        - ``model``: ``nn.Module`` object
        - ``num_gpus_per_node``: how many gpus per node (defaults to 1)
        - ``num_nodes``: how many nodes (defaults to 1)
        - ``additional_buffer_factor``: estimation factor (defaults to 1.5)

    """

    total_params = model_to_params(model)

    estimate_zero2_model_states_mem_needs_all_cold(total_params=total_params,
                                                   num_gpus_per_node=num_gpus_per_node,
                                                   num_nodes=num_nodes,
                                                   additional_buffer_factor=additional_buffer_factor)


def estimate_zero2_model_states_mem_needs_all_cold(total_params,
                                                   num_gpus_per_node=1,
                                                   num_nodes=1,
                                                   additional_buffer_factor=1.5):
    """
    Print out estimates on memory usage requirements for ZeRO 2 params, optim states and gradients
    for a given ``total_params`` count and hardware setup.

    If it's a hypothetical model, use this function, where you have to pass
    the ``total_params`` explicitly.

    If you have an actual model object, use ``estimate_zero2_model_states_mem_needs_all_live`` and everything
    will be derived automatically.

    Args:
        - ``total_params``: total model params
        - ``num_gpus_per_node``: how many gpus per node (defaults to 1)
        - ``num_nodes``: how many nodes (defaults to 1)
        - ``additional_buffer_factor``: estimation factor (defaults to 1.5)

    """

    def format_options(cpu_offload):
        enabled = []
        device = f'{OffloadDeviceEnum.cpu:4}' if cpu_offload else "none"
        enabled.append(f"offload_optimizer={device}")
        return ", ".join(enabled)

    nodes_str = "nodes" if num_nodes > 1 else "node"
    gpus_str = "GPUs" if num_gpus_per_node > 1 else "GPU"
    print("Estimated memory needed for params, optim states and gradients for a:\n"
          f"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\n"
          f"SW: Model with {int(total_params/1e6)}M total params.")
    print("  per CPU  |  per GPU |   Options")
    for cpu_offload in [True, False]:
        cpu_mem, gpu_mem = estimate_zero2_model_states_mem_needs(total_params=total_params,
                                                                 num_gpus_per_node=num_gpus_per_node,
                                                                 num_nodes=num_nodes,
                                                                 cpu_offload=cpu_offload,
                                                                 additional_buffer_factor=additional_buffer_factor)

        options_str = format_options(cpu_offload=cpu_offload)
        print(f" {cpu_mem/2**30:7.2f}GB | {gpu_mem/2**30:6.2f}GB | {options_str}")