'''
Copyright 2019 The Microsoft DeepSpeed Team
'''

import os
import torch
import warnings
import torch.distributed as dist

import apex
from apex import amp
from torch.nn.modules import Module
from torch.distributed.distributed_c10d import _get_global_rank
from tensorboardX import SummaryWriter

from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
from deepspeed.runtime.zero.utils import is_zero_supported_optimizer
from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.config import DeepSpeedConfig, \
    ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, DEEPSPEED_ADAM, DEEPSPEED_OPTIMIZERS
from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.constants import \
    ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
    TORCH_DISTRIBUTED_DEFAULT_PORT
from deepspeed.runtime.zero.constants import \
    ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS
from deepspeed.runtime.csr_tensor import CSRTensor
import deepspeed.runtime.lr_schedules as lr_schedules
from deepspeed.utils import logger, log_dist
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer

from .utils import ensure_directory_exists

MEMORY_OPT_ALLREDUCE_SIZE = 500000000
SUMMARY_WRITER_DIR_NAME = "JobId"

try:
    from apex_C import flatten
    from apex_C import unflatten
except ImportError:
    try:
        _ = warned_flatten
    except NameError:
        logger.warning(
            "Warning:  apex was installed without --cpp_ext.  Falling back to Python flatten and unflatten."
        )
        warned_flatten = True
    from torch._utils import _flatten_dense_tensors as flatten
    from torch._utils import _unflatten_dense_tensors as unflatten


def split_half_float_double_csr(tensors):
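    """Bucket tensors by type (fp16, fp32, fp64 CUDA tensors, or CSR sparse),
    returning (type, bucket) pairs for the non-empty buckets."""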
    dtypes = [
        "torch.cuda.HalfTensor",
        "torch.cuda.FloatTensor",
        "torch.cuda.DoubleTensor",
        CSRTensor.type()
    ]
    buckets = []
    for i, dtype in enumerate(dtypes):
        bucket = [t for t in tensors if t.type() == dtype]
        if bucket:
            buckets.append((dtype, bucket))
    return buckets


def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
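    """Partition all ranks into consecutive groups of parameter_parallel_size
    ranks and return the group containing the current rank."""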
    data_parallel_size = int(dist.get_world_size())
    if parameter_parallel_size is None:
        parameter_parallel_size = int(data_parallel_size)
    logger.info("data_parallel_size: %s, parameter_parallel_size: %s",
                data_parallel_size,
                parameter_parallel_size)
    assert data_parallel_size % parameter_parallel_size == 0, \
        'world size should be divisible by parameter parallel size'
    rank = dist.get_rank()
    my_group = None
    for i in range(dist.get_world_size() // parameter_parallel_size):
        ranks = range(i * parameter_parallel_size, (i + 1) * parameter_parallel_size)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            my_group = group
    return my_group


def print_configuration(args, name):
    logger.info('{}:'.format(name))
    for arg in sorted(vars(args)):
        dots = '.' * (29 - len(arg))
        logger.info('  {} {} {}'.format(arg, dots, getattr(args, arg)))


class DeepSpeedEngine(Module):
    r"""DeepSpeed engine for training.
    """
    def __init__(self,
                 args,
                 model,
                 optimizer=None,
                 model_parameters=None,
                 training_data=None,
                 lr_scheduler=None,
                 mpu=None,
                 dist_init_required=None,
                 collate_fn=None,
                 config_params=None):
        super(DeepSpeedEngine, self).__init__()
        self.client_optimizer = optimizer
        self.client_model_parameters = model_parameters
        self.client_lr_scheduler = lr_scheduler
        self.training_data = training_data
        self.collate_fn = collate_fn
        self.mpu = mpu
        self.data_parallel_group = None
        self.global_steps = 0
        self.global_samples = 0
        self.micro_steps = 0
        self.skipped_steps = 0
        self.gradient_average = True
        self.warn_unscaled_loss = True
        self.config_params = config_params
        self.loaded_checkpoint_mp_world_size = None
        self.loaded_checkpoint_dp_world_size = None
        self.enable_backward_allreduce = True

        if dist_init_required is None:
            dist_init_required = not dist.is_initialized()

        self._mpi_check(args, dist_init_required)

        self.dist_backend = "nccl"
        if dist_init_required:
            if not dist.is_initialized():
                logger.info("Initializing torch distributed with backend: {}".format(
                    self.dist_backend))
                dist.init_process_group(backend=self.dist_backend)
            else:
                logger.warning(
                    "Was given dist_init_required=True but detected that torch"
                    "distributed was already initialized, cannot initialize twice.")

        self._do_args_sanity_check(args)
        self._configure_with_arguments(args, mpu)
        self._do_sanity_check()

        self._init_distributed(dist_init_required)

        if self.tensorboard_enabled() and self.global_rank == 0:
            self.summary_writer = self.get_summary_writer()

        # Configure distributed model
        self._configure_distributed_model(model)

        # Configure wall clock timer
        self.timers = SynchronizedWallClockTimer()

        # Throughput timer
        self.tput_timer = ThroughputTimer(
            batch_size=self.train_micro_batch_size_per_gpu(),
            num_workers=self.dp_world_size,
            steps_per_output=self.steps_per_print(),
            monitor_memory=False)

        if training_data:
            self.training_dataloader = self.deepspeed_io(training_data)
        else:
            self.training_dataloader = None

        # Configure optimizer and scheduler
        self.optimizer = None
        self.lr_scheduler = None
        if model_parameters or optimizer:
            self._configure_optimizer(optimizer, model_parameters)
            self._configure_lr_scheduler(lr_scheduler)
            self._report_progress(0)

        # Bookkeeping for csr support
        self.csr_tensor_module_names = set()
        if self.sparse_gradients_enabled():
            for name, module in self.module.named_modules():
                if isinstance(module, torch.nn.Embedding):
                    self.csr_tensor_module_names.add(name + ".weight")
                    logger.info("Will convert {} to sparse (csr) "
                                "tensor during training".format(name))

        self.save_non_zero_checkpoint = False
        self.save_zero_checkpoint = False
        self._configure_checkpointing(dist_init_required)

        if self.global_rank == 0:
            self._config.print('DeepSpeedEngine configuration')
            if self.dump_state():
                print_configuration(self, 'DeepSpeedEngine')

    def _mpi_check(self, args, dist_init_required):
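        """When launched with the deepspeed_mpi flag, derive rank, world size,
        master address and master port from MPI and export them so that
        torch.distributed can be initialized from the environment."""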
        if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
            from mpi4py import MPI
            import subprocess
            comm = MPI.COMM_WORLD
            rank = comm.Get_rank()
            world_size = comm.Get_size()

            master_addr = None
            if rank == 0:
                hostname_cmd = ["hostname -I"]
                result = subprocess.check_output(hostname_cmd, shell=True)
                master_addr = result.decode('utf-8').split()[0]
            master_addr = comm.bcast(master_addr, root=0)

            # Determine local rank by assuming hostnames are unique
            proc_name = MPI.Get_processor_name()
            all_procs = comm.allgather(proc_name)
            local_rank = sum([i == proc_name for i in all_procs[:rank]])

            os.environ['RANK'] = str(rank)
            os.environ['WORLD_SIZE'] = str(world_size)
            args.local_rank = local_rank
            os.environ['MASTER_ADDR'] = master_addr
            os.environ['MASTER_PORT'] = TORCH_DISTRIBUTED_DEFAULT_PORT

            logger.info(
                "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
                .format(os.environ['RANK'],
                        args.local_rank,
                        os.environ['WORLD_SIZE'],
                        os.environ['MASTER_ADDR'],
                        os.environ['MASTER_PORT']))

            if not dist_init_required and dist.is_initialized():
                assert dist.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank())
                assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
                    world_size, dist.get_world_size())

    def tensorboard_enabled(self):
        return self._config.tensorboard_enabled

    def tensorboard_output_path(self):
        return self._config.tensorboard_output_path

    def tensorboard_job_name(self):
        return self._config.tensorboard_job_name

    def get_summary_writer(self,
                           name="DeepSpeedJobName",
                           base=os.path.join(os.environ["HOME"],
                                             "tensorboard")):
        if self.tensorboard_output_path():
            log_dir = self.tensorboard_output_path()
        else:
            summary_writer_dir_name = SUMMARY_WRITER_DIR_NAME
            if self.tensorboard_job_name():
                name = self.tensorboard_job_name()
            if 'DLWS_JOB_ID' in os.environ:
                # Use a local name here; assigning to the module-level constant
                # inside this branch would leave it unbound when DLWS_JOB_ID is not set.
                summary_writer_dir_name = os.path.join(os.environ['DLWS_JOB_ID'], "logs")
            log_dir = os.path.join(base, summary_writer_dir_name, name)

        os.makedirs(log_dir, exist_ok=True)

        return SummaryWriter(log_dir=log_dir)

    def wall_clock_breakdown(self):
        return self._config.wall_clock_breakdown

    def memory_breakdown(self):
        return self._config.memory_breakdown

    def sparse_gradients_enabled(self):
        return self._config.sparse_gradients_enabled

    def train_batch_size(self):
        return self._config.train_batch_size

    def train_micro_batch_size_per_gpu(self):
        return self._config.train_micro_batch_size_per_gpu

    def optimizer_name(self):
        return self.client_optimizer.__class__.__name__ if self.client_optimizer else self._config.optimizer_name

    def optimizer_params(self):
        return self._config.optimizer_params

    def optimizer_legacy_fusion(self):
        return self._config.optimizer_legacy_fusion

    def scheduler_name(self):
        return self._config.scheduler_name

    def scheduler_params(self):
        return self._config.scheduler_params

    def zero_optimization(self):
        return self._config.zero_enabled

    def zero_allow_untested_optimizer(self):
        return self._config.zero_allow_untested_optimizer

    def zero_reduce_scatter(self):
        return self._config.zero_config.reduce_scatter

    def zero_overlap_comm(self):
        return self._config.zero_config.overlap_comm

    def zero_cpu_offload(self):
        return self._config.zero_config.cpu_offload

    def zero_optimization_stage(self):
        return self._config.zero_optimization_stage

    def zero_reduce_bucket_size(self):
        return self._config.zero_config.reduce_bucket_size

    def zero_allgather_bucket_size(self):
        return self._config.zero_config.allgather_bucket_size

    def zero_optimization_partition_gradients(self):
        return self.zero_optimization_stage() >= ZERO_OPTIMIZATION_GRADIENTS

    def zero_contiguous_gradients(self):
        return self._config.zero_config.contiguous_gradients

    def zero_load_from_fp32_weights(self):
        return self._config.zero_config.load_from_fp32_weights

    def fp16_enabled(self):
        return self._config.fp16_enabled

    def amp_enabled(self):
        return self._config.amp_enabled

    def amp_params(self):
        return self._config.amp_params

    def loss_scale(self):
        return self._config.loss_scale

    def gradient_accumulation_steps(self):
        return self._config.gradient_accumulation_steps

    def allreduce_always_fp32(self):
        return self._config.allreduce_always_fp32

    def postscale_gradients(self):
        return not self._config.prescale_gradients

    def gradient_predivide_factor(self):
        return self._config.gradient_predivide_factor

    def steps_per_print(self):
        return self._config.steps_per_print

    def zero_allgather_partitions(self):
        return self._config.zero_config.allgather_partitions

    def dump_state(self):
        return self._config.dump_state

    def gradient_clipping(self):
        return self._config.gradient_clipping

    def dynamic_loss_scale(self):
        return self._config.loss_scale == 0

    def initial_dynamic_scale(self):
        return self._config.initial_dynamic_scale

    def dynamic_loss_scale_args(self):
        return self._config.dynamic_loss_scale_args

    def _configure_lr_scheduler(self, client_lr_scheduler):
        # First check for scheduler in json configuration
        lr_scheduler = self._scheduler_from_config(self.optimizer)
        if lr_scheduler:
            if self.global_rank == 0:
                logger.info(
                    f'DeepSpeed using configured LR scheduler = {self.scheduler_name()}')
            self.lr_scheduler = lr_scheduler
        else:
            if self.global_rank == 0:
                logger.info('DeepSpeed using client LR scheduler')
            self.lr_scheduler = client_lr_scheduler
        log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0])

    def _configure_checkpointing(self, dist_init_required):

        dp_rank = self.global_rank
        if self.mpu:
            dp_rank = self.mpu.get_data_parallel_rank()

        # only the first data parallel process needs to store the model checkpoint
        self.save_non_zero_checkpoint = (dp_rank == 0)

        if self.zero_optimization():
            param_rank = torch.distributed.get_rank(
                group=self.optimizer.dp_process_group)

            # Only the first parameter parallel process needs to store the
            # optimizer state checkpoints for zero
            self.save_zero_checkpoint = (param_rank == dp_rank)

    def _scheduler_from_config(self, optimizer):
        scheduler_name = self.scheduler_name()
        if scheduler_name is not None:
            if hasattr(lr_schedules, scheduler_name):
                scheduler = getattr(lr_schedules, scheduler_name)
            else:
                assert hasattr(torch.optim.lr_scheduler, scheduler_name), \
                    f"DeepSpeed does not recognize LR scheduler {scheduler_name}"

                scheduler = getattr(torch.optim.lr_scheduler, scheduler_name)

            scheduler_params = self.scheduler_params()
            instantiated_scheduler = scheduler(optimizer, **scheduler_params)
            return instantiated_scheduler
        else:
            return None

    def _init_distributed(self, dist_init_required):
        if self.local_rank >= 0:
            torch.cuda.set_device(self.local_rank)
            self.device = torch.device("cuda", self.local_rank)
            self.world_size = dist.get_world_size()
            self.global_rank = dist.get_rank()
        else:
            self.world_size = 1
            self.global_rank = 0
            self.device = torch.device("cuda")

    # Configure based on command line arguments
    def _configure_with_arguments(self, args, mpu):
        self.local_rank = args.local_rank if hasattr(args, 'local_rank') else 0
        self._config = DeepSpeedConfig(args.deepspeed_config,
                                       mpu,
                                       param_dict=self.config_params)

    # Validate command line arguments
    def _do_args_sanity_check(self, args):
        if hasattr(args, 'deepscale_config') and args.deepscale_config is not None:
            logger.warning(
                "************ --deepscale_config is deprecated, please use --deepspeed_config ************"
            )
            if hasattr(args, 'deepspeed_config'):
                assert args.deepspeed_config is None, "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config"
            args.deepspeed_config = args.deepscale_config

        assert hasattr(args, 'local_rank') and type(args.local_rank) == int, \
            'DeepSpeed requires integer command line parameter --local_rank'

        if self.config_params is None:
            assert hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None, \
                'DeepSpeed requires --deepspeed_config to specify configuration file'

            assert os.path.isfile(args.deepspeed_config), \
                'DeepSpeed configuration file: {} is not an existing file'.format(args.deepspeed_config)

    def _is_supported_optimizer(self, optimizer_name):
        return optimizer_name in DEEPSPEED_OPTIMIZERS or \
            getattr(torch.optim, optimizer_name, None) is not None

    # Validate configuration based on command line arguments
    def _do_sanity_check(self):
        if not self.client_optimizer:
            assert self._is_supported_optimizer(self.optimizer_name()), \
                '{} is not a supported DeepSpeed Optimizer'.format(self.optimizer_name())
            assert self.client_model_parameters, \
                'DeepSpeed {} optimizer requires parameters in initialize() call'.format(self.optimizer_name())

        if self.optimizer_name() == LAMB_OPTIMIZER:
            assert self.dynamic_loss_scale(), \
                'DeepSpeed {} optimizer requires dynamic loss scaling'.format(self.optimizer_name())

    def _broadcast_model(self):
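        """Broadcast module parameters from the source rank to every rank in the
        data-parallel group."""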
        for p in self.module.parameters():
            if torch.is_tensor(p):
                dist.broadcast(p,
                               self.broadcast_src_rank,
                               group=self.data_parallel_group)

    def _configure_distributed_model(self, model):
        self.module = model
        if self.fp16_enabled():
            self.module.half()
        self.module.to(self.device)

        if self.mpu is None:
            self.data_parallel_group = _initialize_parameter_parallel_groups()
            self.dp_world_size = dist.get_world_size()
            self.mp_world_size = 1
            self.broadcast_src_rank = 0
        else:
            self.data_parallel_group = self.mpu.get_data_parallel_group()
            self.dp_world_size = self.mpu.get_data_parallel_world_size()
            self.mp_world_size = self.mpu.get_model_parallel_world_size()
            self.broadcast_src_rank = _get_global_rank(
                self.mpu.get_data_parallel_group(),
                0)

        if not self.amp_enabled():
            self._broadcast_model()

    # Configure optimizer
    def _configure_optimizer(self, client_optimizer, model_parameters):

        if client_optimizer is not None:
            basic_optimizer = client_optimizer
            if self.global_rank == 0:
                logger.info('Using client Optimizer as basic optimizer')
        else:
            basic_optimizer = self._configure_basic_optimizer(model_parameters)
            if self.global_rank == 0:
                logger.info(
                    'Using DeepSpeed Optimizer param name {} as basic optimizer'.format(
                        self.optimizer_name()))

        if self.global_rank == 0:
            logger.info('DeepSpeed Basic Optimizer = {}'.format(basic_optimizer))

        if self.zero_optimization():
            assert not self.amp_enabled(), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similarly to amp opt_level=O2"
            if not is_zero_supported_optimizer(basic_optimizer):
                assert self.zero_allow_untested_optimizer(), \
                    'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'

                if self.global_rank == 0:
                    logger.warning(
                        "**** You are using ZeRO with an untested optimizer, proceed with caution *****"
                    )
            self.optimizer = self._configure_zero_optimizer(basic_optimizer)
        elif self.amp_enabled():
            assert not self.fp16_enabled(), "Cannot enable amp together with (legacy) fp16 mode"
            amp_params = self.amp_params()
            if self.global_rank == 0:
                logger.info(f"Initializing AMP with these params: {amp_params}")
            self.module, self.optimizer = amp.initialize(self.module, basic_optimizer, **amp_params)
            self._broadcast_model()
        elif self.fp16_enabled():
            self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
        else:
            self.optimizer = basic_optimizer
        logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer))
        logger.info('DeepSpeed Final Optimizer state = {}'.format(self.optimizer.state_dict()))

    def _configure_basic_optimizer(self, model_parameters):
        optimizer_parameters = self.optimizer_params()
        # print(optimizer_parameters.keys())
        if 'max_grad_norm' in optimizer_parameters.keys():
            raise ValueError(
                "'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details"
            )
        if self.optimizer_name() == ADAM_OPTIMIZER:
            if self.zero_cpu_offload():
                optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters)
            else:
                from apex.optimizers.fused_adam import FusedAdam
                optimizer = FusedAdam(model_parameters, **optimizer_parameters)
        elif self.optimizer_name() == DEEPSPEED_ADAM:
            from deepspeed.ops.adam import DeepSpeedCPUAdam
            optimizer = DeepSpeedCPUAdam(model_parameters, **optimizer_parameters)
        elif self.optimizer_name() == LAMB_OPTIMIZER:
            from deepspeed.ops.lamb import FusedLamb
            optimizer = FusedLamb(model_parameters, **optimizer_parameters)
        elif self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER:
            from deepspeed.runtime.fp16.onebit_adam import OnebitAdam
            optimizer = OnebitAdam(model_parameters, self, **optimizer_parameters)
        else:
            torch_optimizer = getattr(torch.optim, self.optimizer_name())
            optimizer = torch_optimizer(model_parameters, **optimizer_parameters)
        return optimizer

    def _configure_fp16_optimizer(self, optimizer):
        initial_dynamic_scale = self.initial_dynamic_scale()
        dynamic_loss_args = self.dynamic_loss_scale_args()
        clip_grad = self.gradient_clipping()
        if isinstance(optimizer, apex.optimizers.FusedAdam) \
                or self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER:
            if self.dynamic_loss_scale():
                logger.info('Creating fp16 optimizer with dynamic loss scale')
                timers = self.timers if self.wall_clock_breakdown() else None
                optimizer = FP16_Optimizer(
                    optimizer,
                    dynamic_loss_scale=True,
                    initial_dynamic_scale=initial_dynamic_scale,
                    dynamic_loss_args=dynamic_loss_args,
                    mpu=self.mpu,
                    clip_grad=clip_grad,
                    fused_adam_legacy=self.optimizer_legacy_fusion(),
                    timers=timers)
            else:
                logger.info('Creating fp16 optimizer with static loss scale: {}'.format(
                    self.loss_scale()))
                optimizer = FP16_Optimizer(
                    optimizer,
                    static_loss_scale=self.loss_scale(),
                    mpu=self.mpu,
                    clip_grad=clip_grad,
                    fused_adam_legacy=self.optimizer_legacy_fusion())
        else:
            logger.info('Creating fp16 unfused optimizer with dynamic loss scale')
            optimizer = FP16_UnfusedOptimizer(
                optimizer,
                dynamic_loss_scale=self.dynamic_loss_scale(),
                dynamic_loss_args=dynamic_loss_args,
                mpu=self.mpu,
                clip_grad=clip_grad,
                fused_lamb_legacy=self.optimizer_name() == LAMB_OPTIMIZER)

        return optimizer

    def _configure_zero_optimizer(self, optimizer):
        zero_stage = self.zero_optimization_stage()
        logger.info('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage))

        if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
            assert self.zero_reduce_scatter(), 'Stage 1 only supports reduce scatter mode'
            optimizer = FP16_DeepSpeedZeroOptimizer_Stage1(
                optimizer,
                static_loss_scale=self.loss_scale(),
                dynamic_loss_scale=self.dynamic_loss_scale(),
                dynamic_loss_args=self.dynamic_loss_scale_args(),
                clip_grad=self.gradient_clipping(),
                all_gather_partitions=self.zero_allgather_partitions(),
                allgather_size=self.zero_allgather_bucket_size(),
                max_elements_per_comm=self.zero_reduce_bucket_size(),
                dp_process_group=self.data_parallel_group,
                mpu=self.mpu)
        elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
            optimizer = FP16_DeepSpeedZeroOptimizer(
                optimizer,
                timers=self.timers,
                static_loss_scale=self.loss_scale(),
                dynamic_loss_scale=self.dynamic_loss_scale(),
                dynamic_loss_args=self.dynamic_loss_scale_args(),
                clip_grad=self.gradient_clipping(),
                contiguous_gradients=self.zero_contiguous_gradients(),
                reduce_bucket_size=self.zero_reduce_bucket_size(),
                allgather_bucket_size=self.zero_allgather_bucket_size(),
                dp_process_group=self.data_parallel_group,
                reduce_scatter=self.zero_reduce_scatter(),
                overlap_comm=self.zero_overlap_comm(),
                cpu_offload=self.zero_cpu_offload(),
                mpu=self.mpu,
                postscale_gradients=self.postscale_gradients(),
                gradient_predivide_factor=self.gradient_predivide_factor(),
                gradient_accumulation_steps=self.gradient_accumulation_steps())
        else:
            raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))

        return optimizer

    def deepspeed_io(self,
                     dataset,
                     batch_size=None,
                     route=ROUTE_TRAIN,
                     pin_memory=True,
                     data_sampler=None,
                     collate_fn=None,
                     num_local_io_workers=None):
        if not isinstance(dataset, torch.utils.data.Dataset):
            raise ValueError("Training data must be a torch Dataset")

        if data_sampler is None and (route == ROUTE_PREDICT or route == ROUTE_EVAL):
            data_sampler = torch.utils.data.SequentialSampler(dataset)

        if batch_size is None:
            batch_size = self.train_micro_batch_size_per_gpu()

        if collate_fn is None:
            collate_fn = self.collate_fn

        # Currently we only use timer in train route
        deepspeed_io_timer = None
        if route == ROUTE_TRAIN:
            deepspeed_io_timer = self.tput_timer

        # If mpu is provided, forward world size and parallel rank to sampler.
        data_parallel_world_size = None
        data_parallel_rank = None
        if self.mpu is not None:
            data_parallel_world_size = self.mpu.get_data_parallel_world_size()
            data_parallel_rank = self.mpu.get_data_parallel_rank()

        return DeepSpeedDataLoader(dataset=dataset,
                                   batch_size=batch_size,
                                   pin_memory=pin_memory,
                                   collate_fn=collate_fn,
                                   local_rank=self.local_rank,
                                   tput_timer=deepspeed_io_timer,
                                   num_local_io_workers=num_local_io_workers,
                                   data_sampler=data_sampler,
                                   data_parallel_world_size=data_parallel_world_size,
                                   data_parallel_rank=data_parallel_rank)

    def train(self):
        r"""
        """

        self.warn_unscaled_loss = True
        self.module.train()

    def eval(self):
        r"""
        """

        self.warn_unscaled_loss = True
        self.module.train(False)

    def _scale_loss(self, prescaled_loss):
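        """Scale the loss by 1/gradient_accumulation_steps; tuples and lists are
        scaled element-wise and non-tensor entries are passed through unchanged."""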
        if isinstance(prescaled_loss, torch.Tensor):
            scaled_loss = prescaled_loss / self.gradient_accumulation_steps()
        elif isinstance(prescaled_loss, tuple) or isinstance(prescaled_loss, list):
            scaled_loss = []
            for l in prescaled_loss:
                if isinstance(l, torch.Tensor):
                    scaled_loss.append(l / self.gradient_accumulation_steps())
                else:
                    scaled_loss.append(l)
        else:
            scaled_loss = prescaled_loss
            if self.warn_unscaled_loss:
                logger.warning(
                    f'DeepSpeed unable to scale loss because of type: {type(prescaled_loss)}'
                )
                self.warn_unscaled_loss = False

        return scaled_loss

    def forward(self, *inputs, **kwargs):
        r"""Execute forward propagation

        Arguments:
            *inputs: Variable length input list
            **kwargs: variable length keyword arguments
        """

        if self.wall_clock_breakdown():
            self.timers('forward_microstep').start()
            self.timers('forward').start()

        if self.training_dataloader is None:
            self.tput_timer.start()
        loss = self.module(*inputs, **kwargs)

        if self.wall_clock_breakdown():
            self.timers('forward').stop()
            self.timers('forward_microstep').stop()

        return loss

    def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
        # ZeRO stage 2 communicates during non gradient accumulation boundaries as well
        if self.zero_optimization_partition_gradients():
            self.optimizer.overlapping_partition_gradients_reduce_epilogue()

        # Communicate only at gradient accumulation boundaries
        elif self.is_gradient_accumulation_boundary():
            if self.zero_optimization_stage() == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
                assert self.zero_reduce_scatter()
                self.optimizer.reduce_scatter_gradients(
                    postscale_gradients=self.postscale_gradients(),
                    gradient_predivide_factor=self.gradient_predivide_factor(),
                    gradient_average=self.gradient_average)
            else:
                self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)

    def backward(self, loss, allreduce_gradients=True, release_loss=False):
        r"""Execute backward pass on the loss

        Arguments:
            loss: Torch tensor on which to execute backward propagation
            allreduce_gradients: If this is False, then gradient averaging will be skipped. Default is True.
        """

        # scale loss w.r.t. gradient accumulation if needed
        if self.gradient_accumulation_steps() > 1:
            loss = self._scale_loss(loss.float())

        # Log training Loss
        if self.tensorboard_enabled():
            if self.is_gradient_accumulation_boundary():
                if self.global_rank == 0:
                    self.summary_events = [
                        (f'Train/Samples/train_loss',
                         loss.mean().item() * self.gradient_accumulation_steps(),
                         self.global_samples)
                    ]
                    for event in self.summary_events:  # write_summary_events
                        self.summary_writer.add_scalar(event[0], event[1], event[2])
                    self.summary_writer.flush()

        if self.wall_clock_breakdown():
            self.timers('backward_microstep').start()
            self.timers('backward').start()

        assert self.optimizer is not None, "must provide optimizer during " \
                                           "init in order to use backward"

        if self.wall_clock_breakdown():
            self.timers('backward_inner_microstep').start()
            self.timers('backward_inner').start()

        if self.zero_optimization():
            self.optimizer.is_gradient_accumulation_boundary = \
                self.is_gradient_accumulation_boundary()
            self.optimizer.backward(loss)
        elif self.amp_enabled():
            # AMP requires delaying unscale when inside gradient accumulation boundaries
            # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
            delay_unscale = not self.is_gradient_accumulation_boundary()
            with amp.scale_loss(loss,
                                self.optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
        elif self.fp16_enabled():
            self.optimizer.backward(loss)
        else:
            loss.backward()

        if self.wall_clock_breakdown():
            self.timers('backward_inner').stop()
            self.timers('backward_inner_microstep').stop()

        if self.wall_clock_breakdown():
            self.timers('backward_allreduce_microstep').start()
            self.timers('backward_allreduce').start()

        if allreduce_gradients and self.enable_backward_allreduce:
            self.allreduce_gradients()

        if self.wall_clock_breakdown():
            self.timers('backward_allreduce').stop()
            self.timers('backward_allreduce_microstep').stop()
            self.timers('backward').stop()
            self.timers('backward_microstep').stop()

        if release_loss:
            # loss.data = None
            pass

        return loss

    def is_gradient_accumulation_boundary(self):
        """Query whether the current micro-batch is at the boundary of
        gradient accumulation, and thus will trigger gradient reductions and
        an optimizer step.

        Returns:
            bool: if the current step is a gradient accumulation boundary.
        """
        return (self.micro_steps + 1) % \
            self.gradient_accumulation_steps() == 0

    def zero_grad(self):
        """
        Zero parameter grads.
        """
        for param_name, param in self.module.named_parameters():
            param.grad = None

    def clip_fp32_gradients(self):
        torch.nn.utils.clip_grad_norm_(parameters=self.module.parameters(),
                                       max_norm=self.gradient_clipping())

    def _take_model_step(self):
        if self.gradient_clipping() > 0.0:
            if not self.fp16_enabled() and not self.amp_enabled():
                self.clip_fp32_gradients()
            elif self.amp_enabled():
                # AMP's recommended way of doing clipping
                # https://nvidia.github.io/apex/advanced.html#gradient-clipping
                master_params = amp.master_params(self.optimizer)
                torch.nn.utils.clip_grad_norm_(parameters=master_params,
                                               max_norm=self.gradient_clipping())
        self.optimizer.step()

        # zero grad in basic optimizer could be unreliable and may not exhibit
        # the behaviour that we want
        if not self.zero_optimization() and not self.fp16_enabled() \
                and not self.amp_enabled():
            self.zero_grad()
        else:
            self.optimizer.zero_grad()

        report_progress = self.global_rank == 0 if self.global_rank else True

        # Check overflow here since in the DS fp16 optimizer, the overflow is updated in the step() call above.
        overflow = False
        if hasattr(self.optimizer, 'overflow'):
            overflow = self.optimizer.overflow

        if overflow:
            self.skipped_steps += 1
        else:
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0:
                self._report_progress(self.global_steps + 1)

        self.global_steps += 1
        self.global_samples += self.train_batch_size()

    def step(self):
        r"""Execute the weight update step after forward and backward propagation
        on effective_train_batch.
        """
        if self.wall_clock_breakdown():
            self.timers('step_microstep').start()
            self.timers('step').start()

        assert self.optimizer is not None, "must provide optimizer during " \
                                           "init in order to use step"
        report_progress = self.global_rank == 0 if self.global_rank else True

        # Update the model when we reach gradient accumulation boundaries
        if self.is_gradient_accumulation_boundary():
            self._take_model_step()

        self.tput_timer.stop(report_progress)

        # Log learning rate
        if self.tensorboard_enabled():
            if self.is_gradient_accumulation_boundary():
                if self.global_rank == 0:
                    self.summary_events = [(f'Train/Samples/lr',
                                            self.get_lr()[0],
                                            self.global_samples)]
                    if self.fp16_enabled() and hasattr(self.optimizer, 'cur_scale'):
                        self.summary_events.append((f'Train/Samples/loss_scale',
                                                    self.optimizer.cur_scale,
                                                    self.global_samples))
                    for event in self.summary_events:  # write_summary_events
                        self.summary_writer.add_scalar(event[0], event[1], event[2])
                    self.summary_writer.flush()

        if self.wall_clock_breakdown():
            self.timers('step').stop()
            self.timers('step_microstep').stop()
            timer_names = [
                'forward_microstep',
                'backward_microstep',
                'backward_inner_microstep',
                'backward_allreduce_microstep',
                'step_microstep'
            ]
            self.timers.log(names=timer_names, memory_breakdown=self.memory_breakdown())

            # Log timing
            if self.is_gradient_accumulation_boundary():
                if self.tensorboard_enabled():
                    if self.global_rank == 0:
                        self.summary_events = [
                            (f'Train/Samples/elapsed_time_ms_forward',
                             self.timers('forward').elapsed(reset=False) * 1000.0,
                             self.global_samples),
                            (f'Train/Samples/elapsed_time_ms_backward',
                             self.timers('backward').elapsed(reset=False) * 1000.0,
                             self.global_samples),
                            (f'Train/Samples/elapsed_time_ms_backward_inner',
                             self.timers('backward_inner').elapsed(reset=False) * 1000.0,
                             self.global_samples),
                            (f'Train/Samples/elapsed_time_ms_backward_allreduce',
                             self.timers('backward_allreduce').elapsed(reset=False) *
                             1000.0,
                             self.global_samples),
                            (f'Train/Samples/elapsed_time_ms_step',
                             self.timers('step').elapsed(reset=False) * 1000.0,
                             self.global_samples)
                        ]
                        for event in self.summary_events:  # write_summary_events
                            self.summary_writer.add_scalar(event[0], event[1], event[2])
                        self.summary_writer.flush()

            if self.wall_clock_breakdown():
                self.timers.log([
                    'forward',
                    'backward',
                    'backward_inner',
                    'backward_allreduce',
                    'step'
                ])

        self.micro_steps += 1

    def _get_optimizer_param(self, param_name):
        result = []
        if not self.optimizer:
            return result
        for group in self.optimizer.param_groups:
            if param_name in group:
                result.append(group[param_name])
            else:
                result.append(0.0)
        return result

    def get_lr(self):
        return self._get_optimizer_param('lr')

    def get_type(self):
        return self._get_optimizer_param('type')

    def get_mom(self):
        return self._get_optimizer_param('betas')

    def _report_progress(self, step):
        lr = self.get_lr()
        mom = self.get_mom()
        log_dist(f'step={step}, skipped={self.skipped_steps}, lr={lr}, mom={mom}',
                 ranks=[0])

    def allreduce_bucket(self, bucket):
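        """Flatten a bucket of gradients and all-reduce it across the data-parallel
        group, applying the configured pre/post scaling and optional fp32 accumulation."""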
        tensor = flatten(bucket)

        tensor_to_allreduce = tensor

        if self.allreduce_always_fp32():
            tensor_to_allreduce = tensor.float()

        if self.postscale_gradients():
            if self.gradient_predivide_factor() != 1.0:
                tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor())

            dist.all_reduce(tensor_to_allreduce, group=self.data_parallel_group)

            if self.gradient_average:
                if self.gradient_predivide_factor() != self.dp_world_size:
                    tensor_to_allreduce.mul_(self.gradient_predivide_factor() /
                                             self.dp_world_size)
        else:
            tensor_to_allreduce.div_(self.dp_world_size)
            dist.all_reduce(tensor_to_allreduce, group=self.data_parallel_group)

        if self.allreduce_always_fp32() and tensor is not tensor_to_allreduce:
            tensor.copy_(tensor_to_allreduce)

        return tensor

    def allreduce_and_copy(self, small_bucket):
        allreduced = self.allreduce_bucket(small_bucket)
        for buf, synced in zip(small_bucket, unflatten(allreduced, small_bucket)):
            buf.copy_(synced)

    def allreduce_no_retain(self, bucket, numel_per_bucket=500000000):
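        """All-reduce a bucket in sub-buckets of roughly numel_per_bucket elements
        to bound the size of the temporary flattened buffer."""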
        small_bucket = []
        numel = 0
        for tensor in bucket:
            small_bucket.append(tensor)
            numel = numel + tensor.numel()
            if numel > numel_per_bucket:
                self.allreduce_and_copy(small_bucket)
                small_bucket = []
                numel = 0
        if len(small_bucket) > 0:
            self.allreduce_and_copy(small_bucket)

    def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):
        grads = []
        for param_name, param in self.module.named_parameters():
            if param.grad is None:
                # In cases where there is an imbalance of empty grads across
                # ranks we must create empty grads, this will ensure that every
                # rank is reducing the same size. In some cases it may make
                # sense in the future to support the ability to average not
                # w.r.t. world size but with a different value.
                param.grad = torch.zeros(param.size(),
                                         dtype=param.dtype,
                                         device=param.device)
                grads.append(param.grad.data)
            else:
                grad_data = param.grad.data
                if self.sparse_gradients_enabled() \
                        and param_name in self.csr_tensor_module_names:
                    grads.append(CSRTensor(grad_data))
                else:
                    grads.append(grad_data)

        split_buckets = split_half_float_double_csr(grads)

        for i, bucket_tuple in enumerate(split_buckets):
            bucket_type, bucket = bucket_tuple
            if bucket_type == CSRTensor.type():
                self.csr_allreduce_no_retain(bucket)
            else:
                self.allreduce_no_retain(bucket, numel_per_bucket=elements_per_buffer)

    def csr_allreduce_no_retain(self, bucket):
        allreduced_csrs = self.csr_allreduce_bucket(bucket)
        # Densify csr tensor and copy back to original location
        for csr in allreduced_csrs:
            dense_tensor = csr.to_dense()
            csr.orig_dense_tensor.copy_(dense_tensor)

    def csr_allreduce_bucket(self, bucket):
        csr_list = []
        for csr in bucket:
            csr_list.append(self.csr_allreduce(csr))
        return csr_list

    def csr_allreduce(self, csr):
        # Pre-divide for fp16 stability
        csr.values.div_(self.dp_world_size)

        indices_device_list = self.csr_all_gather(csr.indices)
        values_device_list = self.csr_all_gather(csr.values)

        csr.indices = torch.cat(indices_device_list)
        csr.values = torch.cat(values_device_list)
        return csr

    def csr_all_gather(self, value):
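        """All-gather variable-length CSR index/value tensors by padding each rank's
        tensor to the maximum length and trimming the padding after the gather."""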
        my_size = torch.LongTensor([value.size()[0]]).to(self.device)
        all_sizes = self.all_gather_scalar(my_size)
        max_size = torch.cat(all_sizes).max()
        fill_size = (max_size - my_size)

        assert value.dim() in [1, 2]
        if value.dim() == 1:
            if fill_size > 0:
                value = torch.cat([value, value.new_zeros(fill_size)])
            tensor_list = [value.new_zeros(max_size) for _ in range(self.dp_world_size)]
        else:
            if fill_size > 0:
                value = torch.cat([value, value.new_zeros(fill_size, value.size()[1])])
            tensor_list = [
                value.new_zeros(max_size,
1128
                                value.size()[1]) for _ in range(self.dp_world_size)
O

        dist.all_gather(tensor_list, value, group=self.data_parallel_group)
        tensors = []
        for dev_idx, t in enumerate(tensor_list):
            size = all_sizes[dev_idx][0]
1135 1136 1137
            tensors.append(
                t.index_select(0,
                               torch.LongTensor(range(size)).to(self.device)))
O
        return tensors

    def all_gather_scalar(self, value):
        tensor_list = [value.new_zeros(value.size()) for _ in range(self.dp_world_size)]
        dist.all_gather(tensor_list, value, group=self.data_parallel_group)
        return tensor_list

    def module_state_dict(self, destination=None, prefix='', keep_vars=False):
        sd = self.module.state_dict(destination, prefix, keep_vars)
        return sd

    def load_module_state_dict(self, state_dict, strict=True):
        self.module.load_state_dict(state_dict, strict=strict)

    def _get_rank_zero_ckpt_name(self, checkpoints_path, tag, mp_rank, dp_rank):
        filename = 'zero_pp_rank_{}'.format(dp_rank)
        zero_ckpt_name = os.path.join(
            checkpoints_path,
            str(tag),
            filename + '_mp_rank_{:02d}'.format(mp_rank) + 'optim_states.pt')
        return zero_ckpt_name

    def _get_zero_ckpt_name(self, checkpoints_path, tag):
        mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
        pp_rank = torch.distributed.get_rank(group=self.optimizer.dp_process_group)
        return self._get_rank_zero_ckpt_name(checkpoints_path, tag, mp_rank, pp_rank)

    def _get_ckpt_name(self, checkpoints_path, tag):
        mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
        ckpt_name = os.path.join(checkpoints_path,
                                 str(tag),
                                 'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt')
        return ckpt_name

    def load_checkpoint(self,
                        load_dir,
                        tag,
                        load_module_strict=True,
                        load_optimizer_states=True,
                        load_lr_scheduler_states=True):
        r"""Load training checkpoint

        Arguments:
            load_dir: Required. Directory to load the checkpoint from
            tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step.
            load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match.
            load_optimizer_states: Optional. Boolean to load the training optimizer states from Checkpoint. Ex. ADAM's momentum and variance
            load_lr_scheduler_states: Optional. Boolean to load the learning rate scheduler states from Checkpoint.
        Return:
            load_path: Path of the loaded checkpoint. None if loading the checkpoint failed
            client_state: State dictionary used for loading required training states in the client code.
        """

        load_path, client_states = self._load_checkpoint(load_dir,
                                                         tag,
                                                         load_module_strict=load_module_strict,
                                                         load_optimizer_states=load_optimizer_states,
                                                         load_lr_scheduler_states=load_lr_scheduler_states)

        if self.zero_optimization() and load_path is not None:
            self._load_zero_checkpoint(load_dir,
                                       tag,
                                       load_optimizer_states=load_optimizer_states)

        return load_path, client_states

    def _load_checkpoint(self,
                         load_dir,
                         tag,
                         load_module_strict=True,
                         load_optimizer_states=True,
                         load_lr_scheduler_states=True):

        load_path = self._get_ckpt_name(load_dir, tag)

        if not os.path.exists(load_path):
            logger.warning(
                'Client provided checkpoint load path: {} does not exist ... skipping checkpoint load'
                .format(load_path))
            return None, None

        logger.info(f'rank: {self.global_rank} loading checkpoint: {load_path}')
        checkpoint = torch.load(load_path, map_location=lambda storage, loc: storage)

        self.load_module_state_dict(state_dict=checkpoint['module'],
                                    strict=load_module_strict)
        if not self.zero_optimization():
            if self.fp16_enabled():
                self.optimizer.load_state_dict(
                    checkpoint['optimizer'],
                    load_optimizer_states=load_optimizer_states)
            elif load_optimizer_states:
                self.optimizer.load_state_dict(checkpoint['optimizer'])

        if load_lr_scheduler_states and self.lr_scheduler is not None:
            self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

        self.csr_tensor_module_names = checkpoint['csr_tensor_module_names']
        self.global_steps = checkpoint['global_steps']
        self.global_samples = checkpoint.get('global_samples',
                                             self.global_steps * self.train_batch_size())
        self.skipped_steps = checkpoint['skipped_steps']
        self.loaded_checkpoint_mp_world_size = checkpoint['mp_world_size']
        self.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size']
        deepspeed_states = [
            'module',
            'optimizer',
            'lr_scheduler',
            'csr_tensor_module_names',
            'skipped_steps',
            'global_steps',
            'global_samples',
            'dp_world_size',
            'mp_world_size'
        ]
        client_state = {
            key: value
            for key,
            value in checkpoint.items() if key not in deepspeed_states
        }

        return load_path, client_state

    def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
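        # Collect this model-parallel rank's ZeRO partition checkpoints (one per
        # data-parallel rank) and hand the whole list to the ZeRO optimizer's
        # load_state_dict so it can restore its own partition.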
        zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
        if zero_sd_list is None:
            return

        self.optimizer.load_state_dict(
            state_dict_list=zero_sd_list,
            load_optimizer_states=load_optimizer_states,
            load_from_fp32_weights=self.zero_load_from_fp32_weights())
        logger.info(
            f'loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}'
        )

    def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size):
        zero_ckpt_names = []
        for dp_rank in range(dp_world_size):
            ckpt_name = self._get_rank_zero_ckpt_name(checkpoints_path=load_dir,
                                                      tag=tag,
                                                      mp_rank=mp_rank,
                                                      dp_rank=dp_rank)
            zero_ckpt_names.append(ckpt_name)

        return zero_ckpt_names

    def _get_all_zero_checkpoint_names(self,
                                       load_dir,
                                       tag,
                                       mp_world_size,
                                       dp_world_size):
        zero_ckpt_names = []
        for mp_rank in range(mp_world_size):
            mp_rank_ckpt_names = self._get_mp_rank_zero_checkpoint_names(
                load_dir=load_dir,
                tag=tag,
                mp_rank=mp_rank,
                dp_world_size=dp_world_size)
            zero_ckpt_names += mp_rank_ckpt_names

        return zero_ckpt_names

    def _get_all_zero_checkpoints(self, load_dir, tag):
        mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
        zero_ckpt_names = self._get_mp_rank_zero_checkpoint_names(
            load_dir=load_dir,
            tag=tag,
            mp_rank=mp_rank,
            dp_world_size=self.loaded_checkpoint_dp_world_size)
        invalid_zero_ckpt_paths = []
        for ckpt_name in zero_ckpt_names:
            if not os.path.exists(ckpt_name):
                invalid_zero_ckpt_paths.append(ckpt_name)

        if len(invalid_zero_ckpt_paths) > 0:
            logger.warning(
                f"Client provided zero checkpoint load paths: {invalid_zero_ckpt_paths} do not exist"
            )
            return None

        zero_sd_list = []
        for ckpt_name in zero_ckpt_names:
            zero_sd_list.append(torch.load(ckpt_name, map_location='cpu'))

        zero_optimizer_sd = [sd['optimizer_state_dict'] for sd in zero_sd_list]
        logger.info(
            f"successfully loaded {len(zero_optimizer_sd)} ZeRO state_dicts for rank {self.global_rank}"
        )
        return zero_optimizer_sd

    def save_checkpoint(self, save_dir, tag, client_state={}):
        r"""Save training checkpoint

        Arguments:
            save_dir: Required. Directory for saving the checkpoint
            tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step.
            client_state: Optional. State dictionary used for saving required training states in the client code.
        """

        # This is to make sure the checkpoint names are created without collision.
        # There seems to be an issue creating them in parallel.

        if self.save_non_zero_checkpoint:
            self._create_checkpoint_file(save_dir, tag, False)
            self._save_checkpoint(save_dir, tag, client_state=client_state)

        if self.save_zero_checkpoint:
            self._create_zero_checkpoint_files(save_dir, tag)
            self._save_zero_checkpoint(save_dir, tag)

        return True

    def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint):
        name_function = self._get_zero_ckpt_name if zero_checkpoint else self._get_ckpt_name
        try:
            checkpoint_name = name_function(save_dir, tag)
            ensure_directory_exists(checkpoint_name)
        except:
            logger.error(f'Failed saving model checkpoint to {save_dir} with tag {tag}')
            return False

        return True

    def _create_zero_checkpoint_files(self, save_dir, tag):
        success = True
        # zero checkpoint files are created sequentially
        for rank in range(self.world_size):
            if rank == self.global_rank:
                success = self._create_checkpoint_file(save_dir, tag, True)

            dist.barrier()

        return success

    def _save_checkpoint(self, save_dir, tag, client_state={}):

        save_path = self._get_ckpt_name(save_dir, tag)
        # A hack to save the checkpointing directory. Pipeline parallelism overrides
        # module_state_dict() and uses this path to save the model. module_state_dict()
        # then instead just returns self._curr_save_path.
        self._curr_save_path = os.path.dirname(save_path)

        state = {
            'module':
            self.module_state_dict(),
            'optimizer':
            self.optimizer.state_dict()
            if self.optimizer and not self.zero_optimization() else None,
            'lr_scheduler':
            self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
            'csr_tensor_module_names':
            self.csr_tensor_module_names,
            'skipped_steps':
            self.skipped_steps,
            'global_steps':
            self.global_steps,
            'global_samples':
            self.global_samples,
            'dp_world_size':
            self.dp_world_size,
            'mp_world_size':
            self.mp_world_size
        }
        state.update(client_state)

        log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0])
        #logger.info('Saving model checkpoint: {}'.format(save_path))
        torch.save(state, save_path)
        self._curr_save_path = None

    def _save_zero_checkpoint(self, save_path, tag):
        zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag)
        zero_sd = {'optimizer_state_dict': self.optimizer.state_dict()}
        torch.save(zero_sd, zero_checkpoint_name)
        logger.info('ZeRO checkpoint saved: {}'.format(zero_checkpoint_name))