'''
Copyright 2019 The Microsoft DeepSpeed Team
'''

import os
import torch
import warnings
import torch.distributed as dist

import apex
from apex import amp
from torch.nn.modules import Module
from torch.distributed.distributed_c10d import _get_global_rank
from tensorboardX import SummaryWriter

from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
from deepspeed.runtime.zero.utils import is_zero_supported_optimizer
from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.config import DeepSpeedConfig, \
    ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, DEEPSPEED_ADAM, DEEPSPEED_OPTIMIZERS
from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.constants import \
    ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
    TORCH_DISTRIBUTED_DEFAULT_PORT
from deepspeed.runtime.zero.constants import \
    ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS
from deepspeed.runtime.csr_tensor import CSRTensor
import deepspeed.runtime.lr_schedules as lr_schedules
from deepspeed.utils import logger
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer

MEMORY_OPT_ALLREDUCE_SIZE = 500000000
SUMMARY_WRITER_DIR_NAME = "JobId"

try:
    from apex_C import flatten
    from apex_C import unflatten
except ImportError:
    try:
        _ = warned_flatten
    except NameError:
        logger.warning(
            "Warning:  apex was installed without --cpp_ext.  Falling back to Python flatten and unflatten."
        )
        warned_flatten = True
    from torch._utils import _flatten_dense_tensors as flatten
    from torch._utils import _unflatten_dense_tensors as unflatten


def split_half_float_double_csr(tensors):
    dtypes = [
        "torch.cuda.HalfTensor",
        "torch.cuda.FloatTensor",
        "torch.cuda.DoubleTensor",
        CSRTensor.type()
    ]
    buckets = []
    for i, dtype in enumerate(dtypes):
        bucket = [t for t in tensors if t.type() == dtype]
        if bucket:
            buckets.append((dtype, bucket))
    return buckets
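# Illustrative example (a sketch, not executed here): given a mix of CUDA tensor
# types, split_half_float_double_csr([a_fp16, b_fp32, c_fp16]) would return one
# bucket per dtype that is actually present, e.g.
# [("torch.cuda.HalfTensor", [a_fp16, c_fp16]), ("torch.cuda.FloatTensor", [b_fp32])],
# where a_fp16, b_fp32 and c_fp16 are hypothetical tensors.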


def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
    data_parallel_size = int(dist.get_world_size())
    if parameter_parallel_size is None:
        parameter_parallel_size = int(data_parallel_size)
    logger.info("data_parallel_size: %s, parameter_parallel_size: %s",
                data_parallel_size,
                parameter_parallel_size)
    assert data_parallel_size % parameter_parallel_size == 0, \
        'world size should be divisible by parameter parallel size'
    rank = dist.get_rank()
    my_group = None
    for i in range(dist.get_world_size() // parameter_parallel_size):
        ranks = range(i * parameter_parallel_size, (i + 1) * parameter_parallel_size)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            my_group = group
    return my_group
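# Illustrative example: with a world size of 4 and parameter_parallel_size=2,
# _initialize_parameter_parallel_groups() builds the rank groups {0, 1} and {2, 3}
# and returns, on each rank, the group that contains that rank.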


def print_configuration(args, name):
    logger.info('{}:'.format(name))
    for arg in sorted(vars(args)):
        dots = '.' * (29 - len(arg))
        logger.info('  {} {} {}'.format(arg, dots, getattr(args, arg)))


class DeepSpeedEngine(Module):
    r"""DeepSpeed engine for training.
    """
    def __init__(self,
                 args,
                 model,
                 optimizer=None,
                 model_parameters=None,
                 training_data=None,
                 lr_scheduler=None,
                 mpu=None,
                 dist_init_required=None,
                 collate_fn=None,
                 config_params=None):
        super(DeepSpeedEngine, self).__init__()
        self.client_optimizer = optimizer
        self.client_model_parameters = model_parameters
        self.client_lr_scheduler = lr_scheduler
        self.training_data = training_data
        self.collate_fn = collate_fn
        self.mpu = mpu
        self.data_parallel_group = None
        self.global_steps = 0
        self.micro_steps = 0
        self.skipped_steps = 0
        self.gradient_average = True
        self.warn_unscaled_loss = True
        self.config_params = config_params
        self.loaded_checkpoint_mp_world_size = None
        self.loaded_checkpoint_dp_world_size = None
        self.enable_backward_allreduce = True

        if dist_init_required is None:
            dist_init_required = not dist.is_initialized()

        self._mpi_check(args, dist_init_required)

        self.dist_backend = "nccl"
        if dist_init_required:
            if not dist.is_initialized():
                logger.info("Initializing torch distributed with backend: {}".format(
                    self.dist_backend))
                dist.init_process_group(backend=self.dist_backend)
            else:
                logger.warning(
                    "Was given dist_init_required=True but detected that torch"
                    "distributed was already initialized, cannot initialize twice.")

        self._do_args_sanity_check(args)
        self._configure_with_arguments(args, mpu)
        self._do_sanity_check()

        self._init_distributed(dist_init_required)

        self.sample_count = 0
        if self.tensorboard_enabled() and self.global_rank == 0:
            self.summary_writer = self.get_summary_writer()

        # Configure distributed model
        self._configure_distributed_model(model)

        # Configure wall clock timer
        self.timers = SynchronizedWallClockTimer()

        # Throughput timer
        self.tput_timer = ThroughputTimer(
            batch_size=self.train_micro_batch_size_per_gpu(),
            num_workers=self.dp_world_size,
            steps_per_output=self.steps_per_print(),
            monitor_memory=False)

        self.training_dataloader = self.deepspeed_io(
            training_data) if training_data else None

        # Configure optimizer and scheduler
        self.optimizer = None
        self.lr_scheduler = None
        if model_parameters or optimizer:
            self._configure_optimizer(optimizer, model_parameters)
            self._configure_lr_scheduler(lr_scheduler)
            self._report_progress(0)

        # Bookkeeping for csr support
        self.csr_tensor_module_names = set()
        if self.sparse_gradients_enabled():
            for name, module in self.module.named_modules():
                if isinstance(module, torch.nn.Embedding):
                    self.csr_tensor_module_names.add(name + ".weight")
                    logger.info("Will convert {} to sparse (csr) "
                                "tensor during training".format(name))

        self.save_non_zero_checkpoint = False
        self.save_zero_checkpoint = False
        self._configure_checkpointing(dist_init_required)

        if self.global_rank == 0:
            self._config.print('DeepSpeedLight configuration')
            if self.dump_state():
                print_configuration(self, 'DeepSpeedLight')

    def _mpi_check(self, args, dist_init_required):
        if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
            from mpi4py import MPI
            import subprocess
            comm = MPI.COMM_WORLD
            rank = comm.Get_rank()
            world_size = comm.Get_size()

            master_addr = None
            if rank == 0:
                hostname_cmd = ["hostname -I"]
                result = subprocess.check_output(hostname_cmd, shell=True)
                master_addr = result.decode('utf-8').split()[0]
            master_addr = comm.bcast(master_addr, root=0)

            # Determine local rank by assuming hostnames are unique
            proc_name = MPI.Get_processor_name()
            all_procs = comm.allgather(proc_name)
            local_rank = sum([i == proc_name for i in all_procs[:rank]])

            os.environ['RANK'] = str(rank)
            os.environ['WORLD_SIZE'] = str(world_size)
            args.local_rank = local_rank
            os.environ['MASTER_ADDR'] = master_addr
            os.environ['MASTER_PORT'] = TORCH_DISTRIBUTED_DEFAULT_PORT

            logger.info(
                "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
                .format(os.environ['RANK'],
                        args.local_rank,
                        os.environ['WORLD_SIZE'],
                        os.environ['MASTER_ADDR'],
                        os.environ['MASTER_PORT']))

            if not dist_init_required and dist.is_initialized():
                assert dist.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank())
                assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
                    world_size, dist.get_world_size())

    def tensorboard_enabled(self):
        return self._config.tensorboard_enabled

    def tensorboard_output_path(self):
        return self._config.tensorboard_output_path

    def tensorboard_job_name(self):
        return self._config.tensorboard_job_name

    def get_summary_writer(self,
                           name="DeepSpeedJobName",
                           base=os.environ["HOME"] + "/tensorboard"):
        if self.tensorboard_job_name():
            name = self.tensorboard_job_name()
        if self.tensorboard_output_path():
            return SummaryWriter(log_dir=self.tensorboard_output_path())
        # Read the module-level constant into a local first; assigning to
        # SUMMARY_WRITER_DIR_NAME inside this function would make the name local
        # and leave it unbound when DLWS_JOB_ID is not set.
        summary_writer_dir_name = SUMMARY_WRITER_DIR_NAME
        if 'DLWS_JOB_ID' in os.environ:
            summary_writer_dir_name = os.environ['DLWS_JOB_ID'] + "/logs"
        return SummaryWriter(log_dir=os.path.join(base, summary_writer_dir_name, name))

    def wall_clock_breakdown(self):
        return self._config.wall_clock_breakdown

    def memory_breakdown(self):
        return self._config.memory_breakdown

    def sparse_gradients_enabled(self):
        return self._config.sparse_gradients_enabled

    def train_batch_size(self):
        return self._config.train_batch_size

    def train_micro_batch_size_per_gpu(self):
        return self._config.train_micro_batch_size_per_gpu

    def optimizer_name(self):
        return self.client_optimizer.__class__.__name__ if self.client_optimizer else self._config.optimizer_name

    def optimizer_params(self):
        return self._config.optimizer_params

    def optimizer_legacy_fusion(self):
        return self._config.optimizer_legacy_fusion

    def scheduler_name(self):
        return self._config.scheduler_name

    def scheduler_params(self):
        return self._config.scheduler_params

    def zero_optimization(self):
        return self._config.zero_enabled

    def zero_allow_untested_optimizer(self):
        return self._config.zero_allow_untested_optimizer

    def zero_reduce_scatter(self):
        return self._config.zero_config.reduce_scatter

    def zero_overlap_comm(self):
        return self._config.zero_config.overlap_comm

    def zero_cpu_offload(self):
        return self._config.zero_config.cpu_offload

    def zero_optimization_stage(self):
        return self._config.zero_optimization_stage

    def zero_reduce_bucket_size(self):
        return self._config.zero_config.reduce_bucket_size

    def zero_allgather_bucket_size(self):
        return self._config.zero_config.allgather_bucket_size

    def zero_optimization_partition_gradients(self):
        return self.zero_optimization_stage() >= ZERO_OPTIMIZATION_GRADIENTS

    def zero_contiguous_gradients(self):
        return self._config.zero_config.contiguous_gradients

    def zero_load_from_fp32_weights(self):
        return self._config.zero_config.load_from_fp32_weights

    def fp16_enabled(self):
        return self._config.fp16_enabled

    def amp_enabled(self):
        return self._config.amp_enabled

    def amp_params(self):
        return self._config.amp_params

    def loss_scale(self):
        return self._config.loss_scale

    def gradient_accumulation_steps(self):
        return self._config.gradient_accumulation_steps

    def allreduce_always_fp32(self):
        return self._config.allreduce_always_fp32

    def postscale_gradients(self):
        return not self._config.prescale_gradients

    def gradient_predivide_factor(self):
        return self._config.gradient_predivide_factor

    def steps_per_print(self):
        return self._config.steps_per_print

    def zero_allgather_partitions(self):
        return self._config.zero_config.allgather_partitions

    def dump_state(self):
        return self._config.dump_state

    def gradient_clipping(self):
        return self._config.gradient_clipping

    def dynamic_loss_scale(self):
        return self._config.loss_scale == 0

    def initial_dynamic_scale(self):
        return self._config.initial_dynamic_scale

    def dynamic_loss_scale_args(self):
        return self._config.dynamic_loss_scale_args

    def _configure_lr_scheduler(self, client_lr_scheduler):
        # First check for scheduler in json configuration
        lr_scheduler = self._scheduler_from_config(self.optimizer)
        if lr_scheduler:
            logger.info(
                f'DeepSpeed using configured LR scheduler = {self.scheduler_name()}')
            self.lr_scheduler = lr_scheduler
        else:
            logger.warning('DeepSpeed using client LR scheduler')
            self.lr_scheduler = client_lr_scheduler
        logger.info(f'DeepSpeed LR Scheduler = {self.lr_scheduler}')

    def _configure_checkpointing(self, dist_init_required):

        dp_rank = self.global_rank
        if self.mpu:
            dp_rank = self.mpu.get_data_parallel_rank()

        # only the first data parallel process needs to store the model checkpoint
        self.save_non_zero_checkpoint = (dp_rank == 0)

        if self.zero_optimization():
            pp_rank = torch.distributed.get_rank(group=self.optimizer.dp_process_group)

            # Only the first parameter parallel process needs to store the
            # optimizer state checkpoints for zero
            self.save_zero_checkpoint = (pp_rank == dp_rank)

    def _scheduler_from_config(self, optimizer):
        scheduler_name = self.scheduler_name()
        if scheduler_name is not None:
            if hasattr(lr_schedules, scheduler_name):
                scheduler = getattr(lr_schedules, scheduler_name)
            else:
                assert hasattr(torch.optim.lr_scheduler, scheduler_name), \
                    f"DeepSpeed does not recognize LR scheduler {scheduler_name}"

                scheduler = getattr(torch.optim.lr_scheduler, scheduler_name)

            scheduler_params = self.scheduler_params()
            instantiated_scheduler = scheduler(optimizer, **scheduler_params)
            return instantiated_scheduler
        else:
            return None

    def _init_distributed(self, dist_init_required):
        if self.local_rank >= 0:
            torch.cuda.set_device(self.local_rank)
            self.device = torch.device("cuda", self.local_rank)
            self.world_size = dist.get_world_size()
            self.global_rank = dist.get_rank()
            logger.info("Set device to local rank {} within node.".format(
                self.local_rank))
        else:
            self.world_size = 1
            self.global_rank = 0
            self.device = torch.device("cuda")

    # Configure based on command line arguments
    def _configure_with_arguments(self, args, mpu):
        self.local_rank = args.local_rank if hasattr(args, 'local_rank') else 0
        self._config = DeepSpeedConfig(args.deepspeed_config,
                                       mpu,
                                       param_dict=self.config_params)

    # Validate command line arguments
    def _do_args_sanity_check(self, args):
        if hasattr(args, 'deepscale_config') and args.deepscale_config is not None:
            logger.warning(
                "************ --deepscale_config is deprecated, please use --deepspeed_config ************"
            )
            if hasattr(args, 'deepspeed_config'):
                assert args.deepspeed_config is None, "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config"
            args.deepspeed_config = args.deepscale_config

        assert hasattr(args, 'local_rank') and type(args.local_rank) == int, \
            'DeepSpeed requires integer command line parameter --local_rank'

        if self.config_params is None:
            assert hasattr(args, 'deepspeed_config') and args.deepspeed_config is not None, \
                'DeepSpeed requires --deepspeed_config to specify configuration file'

            assert os.path.isfile(args.deepspeed_config), \
                'DeepSpeed configuration file: {} is not an existing file'.format(args.deepspeed_config)

    def _is_supported_optimizer(self, optimizer_name):
        return optimizer_name in DEEPSPEED_OPTIMIZERS or \
            getattr(torch.optim, optimizer_name, None) is not None

    # Validate configuration based on command line arguments
    def _do_sanity_check(self):
        if not self.client_optimizer:
            assert self._is_supported_optimizer(self.optimizer_name()), \
                '{} is not a supported DeepSpeed Optimizer'.format(self.optimizer_name())
            assert self.client_model_parameters, \
                'DeepSpeed {} optimizer requires parameters in initialize() call'.format(self.optimizer_name())

        if self.optimizer_name() == LAMB_OPTIMIZER:
            assert self.dynamic_loss_scale(), \
                'DeepSpeed {} optimizer requires dynamic loss scaling'.format(self.optimizer_name())

    def _broadcast_model(self):
        for p in self.module.parameters():
            if torch.is_tensor(p):
                dist.broadcast(p,
                               self.broadcast_src_rank,
                               group=self.data_parallel_group)
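    # _broadcast_model sends every parameter tensor from broadcast_src_rank to the
    # rest of the data-parallel group, so all replicas start from identical weights.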

    def _configure_distributed_model(self, model):
        self.module = model
        if self.fp16_enabled():
            self.module.half()
        self.module.to(self.device)

        if self.mpu is None:
            self.data_parallel_group = _initialize_parameter_parallel_groups()
            self.dp_world_size = dist.get_world_size()
            self.mp_world_size = 1
            self.broadcast_src_rank = 0
        else:
            self.data_parallel_group = self.mpu.get_data_parallel_group()
            self.dp_world_size = self.mpu.get_data_parallel_world_size()
            self.mp_world_size = self.mpu.get_model_parallel_world_size()
            self.broadcast_src_rank = _get_global_rank(
                self.mpu.get_data_parallel_group(),
                0)
            logger.info(f"global src_rank={self.broadcast_src_rank}")

        if not self.amp_enabled():
            self._broadcast_model()

    # Configure optimizer
    def _configure_optimizer(self, client_optimizer, model_parameters):

        if client_optimizer is not None:
            basic_optimizer = client_optimizer
            logger.info('Using client Optimizer as basic optimizer')
        else:
            basic_optimizer = self._configure_basic_optimizer(model_parameters)
            logger.info(
                'Using DeepSpeed Optimizer param name {} as basic optimizer'.format(
                    self.optimizer_name()))

        logger.info('DeepSpeed Basic Optimizer = {}'.format(basic_optimizer))

        if self.zero_optimization():
            assert not self.amp_enabled(), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2"
            if not is_zero_supported_optimizer(basic_optimizer):
                assert self.zero_allow_untested_optimizer(), \
                    'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'

                logger.warning(
                    "**** You are using ZeRO with an untested optimizer, proceed with caution *****"
                )

            self.optimizer = self._configure_zero_optimizer(basic_optimizer)
        elif self.amp_enabled():
            assert not self.fp16_enabled(), "Cannot enable both amp with (legacy) fp16 mode"
            amp_params = self.amp_params()
            logger.info(f"Initializing AMP with these params: {amp_params}")
            self.module, self.optimizer = amp.initialize(self.module, basic_optimizer, **amp_params)
            self._broadcast_model()
        elif self.fp16_enabled():
            self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
        else:
            self.optimizer = basic_optimizer
        logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer))
        logger.info('DeepSpeed Final Optimizer state = {}'.format(self.optimizer.state_dict()))

    def _configure_basic_optimizer(self, model_parameters):
        optimizer_parameters = self.optimizer_params()
        # print(optimizer_parameters.keys())
        if 'max_grad_norm' in optimizer_parameters.keys():
            raise ValueError(
                "'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details"
            )
        if self.optimizer_name() == ADAM_OPTIMIZER:
            if self.zero_cpu_offload():
                optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters)
            else:
                from apex.optimizers.fused_adam import FusedAdam
                optimizer = FusedAdam(model_parameters, **optimizer_parameters)
        elif self.optimizer_name() == DEEPSPEED_ADAM:
            from deepspeed.ops.adam import DeepSpeedCPUAdam
            optimizer = DeepSpeedCPUAdam(model_parameters, **optimizer_parameters)
        elif self.optimizer_name() == LAMB_OPTIMIZER:
            from deepspeed.ops.lamb import FusedLamb
            optimizer = FusedLamb(model_parameters, **optimizer_parameters)
        elif self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER:
            from deepspeed.runtime.fp16.onebit_adam import OnebitAdam
            optimizer = OnebitAdam(model_parameters, self, **optimizer_parameters)
        else:
            torch_optimizer = getattr(torch.optim, self.optimizer_name())
            optimizer = torch_optimizer(model_parameters, **optimizer_parameters)
        return optimizer

    def _configure_fp16_optimizer(self, optimizer):
        initial_dynamic_scale = self.initial_dynamic_scale()
        dynamic_loss_args = self.dynamic_loss_scale_args()
        clip_grad = self.gradient_clipping()
        if isinstance(optimizer,
                      apex.optimizers.FusedAdam) or self.optimizer_name(
                      ) == ONEBIT_ADAM_OPTIMIZER:
            if self.dynamic_loss_scale():
                logger.info('Creating fp16 optimizer with dynamic loss scale')
                timers = self.timers if self.wall_clock_breakdown() else None
                optimizer = FP16_Optimizer(
                    optimizer,
                    dynamic_loss_scale=True,
                    initial_dynamic_scale=initial_dynamic_scale,
                    dynamic_loss_args=dynamic_loss_args,
                    mpu=self.mpu,
                    clip_grad=clip_grad,
                    fused_adam_legacy=self.optimizer_legacy_fusion(),
                    timers=timers)
            else:
                logger.info('Creating fp16 optimizer with static loss scale: {}'.format(
                    self.loss_scale()))
                optimizer = FP16_Optimizer(
                    optimizer,
                    static_loss_scale=self.loss_scale(),
                    mpu=self.mpu,
                    clip_grad=clip_grad,
                    fused_adam_legacy=self.optimizer_legacy_fusion())
        else:
            logger.info('Creating fp16 unfused optimizer with dynamic loss scale')
            optimizer = FP16_UnfusedOptimizer(
                optimizer,
                dynamic_loss_scale=self.dynamic_loss_scale(),
                dynamic_loss_args=dynamic_loss_args,
                mpu=self.mpu,
                clip_grad=clip_grad,
                fused_lamb_legacy=self.optimizer_name() == LAMB_OPTIMIZER)

        return optimizer

    def _configure_zero_optimizer(self, optimizer):
        zero_stage = self.zero_optimization_stage()
        logger.info('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage))

        if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
            assert self.zero_reduce_scatter(), 'Stage 1 only supports reduce scatter mode'
            optimizer = FP16_DeepSpeedZeroOptimizer_Stage1(
                optimizer,
                static_loss_scale=self.loss_scale(),
                dynamic_loss_scale=self.dynamic_loss_scale(),
                dynamic_loss_args=self.dynamic_loss_scale_args(),
                clip_grad=self.gradient_clipping(),
                all_gather_partitions=self.zero_allgather_partitions(),
                allgather_size=self.zero_allgather_bucket_size(),
                max_elements_per_comm=self.zero_reduce_bucket_size(),
                dp_process_group=self.data_parallel_group,
                mpu=self.mpu)
        elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
            optimizer = FP16_DeepSpeedZeroOptimizer(
                optimizer,
                timers=self.timers,
                static_loss_scale=self.loss_scale(),
                dynamic_loss_scale=self.dynamic_loss_scale(),
                dynamic_loss_args=self.dynamic_loss_scale_args(),
                clip_grad=self.gradient_clipping(),
                contiguous_gradients=self.zero_contiguous_gradients(),
                reduce_bucket_size=self.zero_reduce_bucket_size(),
                allgather_bucket_size=self.zero_allgather_bucket_size(),
                dp_process_group=self.data_parallel_group,
                reduce_scatter=self.zero_reduce_scatter(),
                overlap_comm=self.zero_overlap_comm(),
                cpu_offload=self.zero_cpu_offload(),
                mpu=self.mpu,
                postscale_gradients=self.postscale_gradients(),
                gradient_predivide_factor=self.gradient_predivide_factor(),
                gradient_accumulation_steps=self.gradient_accumulation_steps())
        else:
            raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))

        return optimizer

    def deepspeed_io(self,
                     dataset,
                     batch_size=None,
                     route=ROUTE_TRAIN,
                     pin_memory=True,
                     data_sampler=None,
                     collate_fn=None,
                     num_local_io_workers=None):
        if not isinstance(dataset, torch.utils.data.Dataset):
            raise ValueError("Training data must be a torch Dataset")

        if data_sampler is None and (route == ROUTE_PREDICT or route == ROUTE_EVAL):
            data_sampler = torch.utils.data.SequentialSampler(dataset)

        if batch_size is None:
            batch_size = self.train_micro_batch_size_per_gpu()

        if collate_fn is None:
            collate_fn = self.collate_fn

        # Currently we only use timer in train route
        deepspeed_io_timer = None
        if route == ROUTE_TRAIN:
            deepspeed_io_timer = self.tput_timer

        # If mpu is provided, forward world size and parallel rank to sampler.
        data_parallel_world_size = None
        data_parallel_rank = None
        if self.mpu is not None:
            data_parallel_world_size = self.mpu.get_data_parallel_world_size()
            data_parallel_rank = self.mpu.get_data_parallel_rank()

        return DeepSpeedDataLoader(dataset=dataset,
                                   batch_size=batch_size,
                                   pin_memory=pin_memory,
                                   collate_fn=collate_fn,
                                   local_rank=self.local_rank,
                                   tput_timer=deepspeed_io_timer,
                                   num_local_io_workers=num_local_io_workers,
                                   data_sampler=data_sampler,
                                   data_parallel_world_size=data_parallel_world_size,
                                   data_parallel_rank=data_parallel_rank)
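    # Illustrative use (a sketch; `my_dataset` is a hypothetical torch Dataset):
    #   loader = engine.deepspeed_io(my_dataset)
    # builds a DeepSpeedDataLoader with the configured micro batch size per GPU and,
    # when an mpu is present, the data-parallel rank/world size for sharding.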

    def train(self):
        r"""Put the module in training mode.
        """

        self.warn_unscaled_loss = True
        self.module.train()

    def eval(self):
        r"""Put the module in evaluation mode.
        """

        self.warn_unscaled_loss = True
        self.module.train(False)

    def _scale_loss(self, prescaled_loss):
        if isinstance(prescaled_loss, torch.Tensor):
            scaled_loss = prescaled_loss / self.gradient_accumulation_steps()
        elif isinstance(prescaled_loss, tuple) or isinstance(prescaled_loss, list):
            scaled_loss = []
            for l in prescaled_loss:
                if isinstance(l, torch.Tensor):
                    scaled_loss.append(l / self.gradient_accumulation_steps())
                else:
                    scaled_loss.append(l)
        else:
            scaled_loss = prescaled_loss
            if self.warn_unscaled_loss:
                logger.warning(
                    f'DeepSpeed unable to scale loss because of type: {type(prescaled_loss)}'
                )
                self.warn_unscaled_loss = False

        return scaled_loss
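    # Worked example for _scale_loss: with gradient_accumulation_steps() == 4, a
    # micro-batch loss of 2.0 is scaled to 0.5 before backward(), so gradients
    # accumulated over 4 micro-steps average correctly over the effective batch.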

    def forward(self, *inputs, **kwargs):
        r"""Execute forward propagation

        Arguments:
            *inputs: Variable length input list
            **kwargs: variable length keyword arguments
        """

        if self.wall_clock_breakdown():
            self.timers('forward_microstep').start()
            self.timers('forward').start()

        if self.training_dataloader is None:
            self.tput_timer.start()
        loss = self.module(*inputs, **kwargs)

        if self.wall_clock_breakdown():
            self.timers('forward').stop()
            self.timers('forward_microstep').stop()

        return loss

    def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
        # ZeRO stage 2 communicates even outside gradient accumulation boundaries
        if self.zero_optimization_partition_gradients():
            self.optimizer.overlapping_partition_gradients_reduce_epilogue()

        # Communicate only at gradient accumulation boundaries
        elif self.is_gradient_accumulation_boundary():
            if self.zero_optimization_stage() == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
                assert self.zero_reduce_scatter()
                self.optimizer.reduce_scatter_gradients(
                    postscale_gradients=self.postscale_gradients(),
                    gradient_predivide_factor=self.gradient_predivide_factor(),
                    gradient_average=self.gradient_average)
            else:
                self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)

    def backward(self, loss, allreduce_gradients=True, release_loss=False):
        r"""Execute backward pass on the loss

        Arguments:
            loss: Torch tensor on which to execute backward propagation
            allreduce_gradients: If this is False, then gradient averaging will be skipped. Default is True.
        """

        # scale loss w.r.t. gradient accumulation if needed
        if self.gradient_accumulation_steps() > 1:
            loss = self._scale_loss(loss.float())

        # Log training Loss
        if self.tensorboard_enabled():
            if self.is_gradient_accumulation_boundary():
                if self.global_rank == 0:
                    self.sample_count += (self.train_micro_batch_size_per_gpu() *
                                          self.dp_world_size *
                                          self.gradient_accumulation_steps())
                    self.summary_events = [
                        (f'Train/Samples/train_loss',
                         loss.mean().item() * self.gradient_accumulation_steps(),
                         self.sample_count)
                    ]
                    for event in self.summary_events:  # write_summary_events
                        self.summary_writer.add_scalar(event[0], event[1], event[2])
                    self.summary_writer.flush()

        if self.wall_clock_breakdown():
            self.timers('backward_microstep').start()
            self.timers('backward').start()

        assert self.optimizer is not None, "must provide optimizer during " \
                                           "init in order to use backward"

        if self.wall_clock_breakdown():
            self.timers('backward_inner_microstep').start()
            self.timers('backward_inner').start()

        if self.zero_optimization():
            self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary(
            )
            self.optimizer.backward(loss)
        elif self.amp_enabled():
            # AMP requires delaying unscale when inside gradient accumulation boundaries
            # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
            delay_unscale = not self.is_gradient_accumulation_boundary()
            with amp.scale_loss(loss,
                                self.optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
        elif self.fp16_enabled():
            self.optimizer.backward(loss)
        else:
            loss.backward()

        if self.wall_clock_breakdown():
            self.timers('backward_inner').stop()
            self.timers('backward_inner_microstep').stop()

        if self.wall_clock_breakdown():
            self.timers('backward_allreduce_microstep').start()
            self.timers('backward_allreduce').start()

        if allreduce_gradients and self.enable_backward_allreduce:
            self.allreduce_gradients()

        if self.wall_clock_breakdown():
            self.timers('backward_allreduce').stop()
            self.timers('backward_allreduce_microstep').stop()
            self.timers('backward').stop()
            self.timers('backward_microstep').stop()

        if release_loss:
            # loss.data = None
            pass

        return loss

    def is_gradient_accumulation_boundary(self):
        return (self.micro_steps + 1) % \
            self.gradient_accumulation_steps() == 0
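    # Example: with gradient_accumulation_steps() == 4, micro_steps 3, 7, 11, ...
    # are boundaries (the test is (micro_steps + 1) % steps == 0), so the optimizer
    # only steps on every 4th micro-step.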

    def zero_grad(self):
        """
        Zero parameter grads.
        """
        for param_name, param in self.module.named_parameters():
            param.grad = None

    def clip_fp32_gradients(self):
        torch.nn.utils.clip_grad_norm_(parameters=self.module.parameters(),
                                       max_norm=self.gradient_clipping())

    def step(self):
        r"""Execute the weight update step after forward and backward propagation on effective_train_batch
        """
        if self.wall_clock_breakdown():
            self.timers('step_microstep').start()
            self.timers('step').start()

        assert self.optimizer is not None, "must provide optimizer during " \
                                           "init in order to use step"
        report_progress = self.global_rank == 0 if self.global_rank else True

        if self.is_gradient_accumulation_boundary():

            if self.gradient_clipping() > 0.0:
                if not self.fp16_enabled() and not self.amp_enabled():
                    self.clip_fp32_gradients()
                elif self.amp_enabled():
                    # AMP's recommended way of doing clipping
                    # https://nvidia.github.io/apex/advanced.html#gradient-clipping
                    master_params = amp.master_params(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(parameters=master_params,
                                                   max_norm=self.gradient_clipping())
            self.optimizer.step()

            # Zeroing grads through the basic optimizer can be unreliable and may
            # not exhibit the behaviour that we want
            if not self.zero_optimization() and not self.fp16_enabled(
            ) and not self.amp_enabled():
                self.zero_grad()
            else:
                self.optimizer.zero_grad()

        # Check for overflow here since in the DS fp16 optimizer, overflow is updated in the above step() function.
            overflow = False
            if hasattr(self.optimizer, 'overflow'):
                overflow = self.optimizer.overflow

            if overflow:
                self.skipped_steps += 1
            else:
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()
                if report_progress and (self.global_steps +
                                        1) % self.steps_per_print() == 0:
                    self._report_progress(self.global_steps + 1)

            self.global_steps += 1

        self.tput_timer.stop(report_progress)

        # Log learning rate
        if self.tensorboard_enabled():
            if self.is_gradient_accumulation_boundary():
                if self.global_rank == 0:
                    self.summary_events = [(f'Train/Samples/lr',
                                            self.get_lr()[0],
                                            self.sample_count)]
                    for event in self.summary_events:  # write_summary_events
                        self.summary_writer.add_scalar(event[0], event[1], event[2])
                    self.summary_writer.flush()

        if self.wall_clock_breakdown():
            self.timers('step').stop()
            self.timers('step_microstep').stop()
            timer_names = [
                'forward_microstep',
                'backward_microstep',
                'backward_inner_microstep',
                'backward_allreduce_microstep',
                'step_microstep'
            ]
            self.timers.log(names=timer_names, memory_breakdown=self.memory_breakdown())

            # Log timing
            if self.is_gradient_accumulation_boundary():
                if self.tensorboard_enabled():
                    if self.global_rank == 0:
                        self.summary_events = [
                            (f'Train/Samples/elapsed_time_ms_forward',
                             self.timers('forward').elapsed(reset=False) * 1000.0,
                             self.sample_count),
                            (f'Train/Samples/elapsed_time_ms_backward',
                             self.timers('backward').elapsed(reset=False) * 1000.0,
                             self.sample_count),
                            (f'Train/Samples/elapsed_time_ms_backward_inner',
                             self.timers('backward_inner').elapsed(reset=False) * 1000.0,
                             self.sample_count),
                            (f'Train/Samples/elapsed_time_ms_backward_allreduce',
                             self.timers('backward_allreduce').elapsed(reset=False) *
                             1000.0,
                             self.sample_count),
                            (f'Train/Samples/elapsed_time_ms_step',
                             self.timers('step').elapsed(reset=False) * 1000.0,
                             self.sample_count)
                        ]
                        for event in self.summary_events:  # write_summary_events
                            self.summary_writer.add_scalar(event[0], event[1], event[2])
                        self.summary_writer.flush()

            if self.wall_clock_breakdown():
                self.timers.log([
                    'forward',
                    'backward',
                    'backward_inner',
                    'backward_allreduce',
                    'step'
                ])

        self.micro_steps += 1

    def _get_optimizer_param(self, param_name):
        result = []
        if not self.optimizer:
            return result
        for group in self.optimizer.param_groups:
            if param_name in group:
                result.append(group[param_name])
            else:
                result.append(0.0)
        return result
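    # Example: _get_optimizer_param('lr') returns one value per optimizer param
    # group (e.g. [0.001, 0.001] for two groups) and 0.0 for groups missing the key.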

    def get_lr(self):
        return self._get_optimizer_param('lr')

    def get_type(self):
        return self._get_optimizer_param('type')

    def get_mom(self):
        return self._get_optimizer_param('betas')

    def _report_progress(self, step):
        lr = self.get_lr()
        mom = self.get_mom()
        logger.info('rank:{} step={}, skipped={}, lr={}, mom={}'.format(
            self.global_rank,
            step,
            self.skipped_steps,
            lr,
            mom))

    def allreduce_bucket(self, bucket):
        tensor = flatten(bucket)

        tensor_to_allreduce = tensor

        if self.allreduce_always_fp32():
            tensor_to_allreduce = tensor.float()

        if self.postscale_gradients():
            if self.gradient_predivide_factor() != 1.0:
                tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor())

            dist.all_reduce(tensor_to_allreduce, group=self.data_parallel_group)

            if self.gradient_average:
                if self.gradient_predivide_factor() != self.dp_world_size:
                    tensor_to_allreduce.mul_(self.gradient_predivide_factor() /
                                             self.dp_world_size)
        else:
            tensor_to_allreduce.div_(self.dp_world_size)
            dist.all_reduce(tensor_to_allreduce, group=self.data_parallel_group)

        if self.allreduce_always_fp32() and tensor is not tensor_to_allreduce:
            tensor.copy_(tensor_to_allreduce)

        return tensor
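    # Averaging math for the postscale path above: gradients are multiplied by
    # 1/gradient_predivide_factor, summed via all_reduce, then multiplied by
    # gradient_predivide_factor/dp_world_size, so the net result is still the mean
    # over dp_world_size while keeping intermediate fp16 values smaller.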

    def allreduce_and_copy(self, small_bucket):
        allreduced = self.allreduce_bucket(small_bucket)
        for buf, synced in zip(small_bucket, unflatten(allreduced, small_bucket)):
            buf.copy_(synced)

    def allreduce_no_retain(self, bucket, numel_per_bucket=500000000):
        small_bucket = []
        numel = 0
        for tensor in bucket:
            small_bucket.append(tensor)
            numel = numel + tensor.numel()
            if numel > numel_per_bucket:
                self.allreduce_and_copy(small_bucket)
                small_bucket = []
                numel = 0
        if len(small_bucket) > 0:
            self.allreduce_and_copy(small_bucket)

    def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):
        grads = []
        for param_name, param in self.module.named_parameters():
            if param.grad is None:
                # In cases where there is an imbalance of empty grads across
                # ranks we must create empty grads, this will ensure that every
                # rank is reducing the same size. In some cases it may make
                # sense in the future to support the ability to average not
                # w.r.t. world size but with a different value.
                param.grad = torch.zeros(param.size(),
                                         dtype=param.dtype,
                                         device=param.device)
                grads.append(param.grad.data)
            else:
                grad_data = param.grad.data
                if self.sparse_gradients_enabled(
                ) and param_name in self.csr_tensor_module_names:
                    grads.append(CSRTensor(grad_data))
                else:
                    grads.append(grad_data)

        split_buckets = split_half_float_double_csr(grads)

        for i, bucket_tuple in enumerate(split_buckets):
            bucket_type, bucket = bucket_tuple
            if bucket_type == CSRTensor.type():
                self.csr_allreduce_no_retain(bucket)
            else:
                self.allreduce_no_retain(bucket, numel_per_bucket=elements_per_buffer)

    def csr_allreduce_no_retain(self, bucket):
        allreduced_csrs = self.csr_allreduce_bucket(bucket)
        # Densify csr tensor and copy back to original location
        for csr in allreduced_csrs:
            dense_tensor = csr.to_dense()
            csr.orig_dense_tensor.copy_(dense_tensor)

    def csr_allreduce_bucket(self, bucket):
        csr_list = []
        for csr in bucket:
            csr_list.append(self.csr_allreduce(csr))
        return csr_list

    def csr_allreduce(self, csr):
        # Pre-divide for fp16 stability
        csr.values.div_(self.dp_world_size)

        indices_device_list = self.csr_all_gather(csr.indices)
        values_device_list = self.csr_all_gather(csr.values)

        csr.indices = torch.cat(indices_device_list)
        csr.values = torch.cat(values_device_list)
        return csr

    def csr_all_gather(self, value):
        my_size = torch.LongTensor([value.size()[0]]).to(self.device)
        all_sizes = self.all_gather_scalar(my_size)
        max_size = torch.cat(all_sizes).max()
        fill_size = (max_size - my_size)

        assert value.dim() in [1, 2]
        if value.dim() == 1:
            if fill_size > 0:
                value = torch.cat([value, value.new_zeros(fill_size)])
            tensor_list = [value.new_zeros(max_size) for _ in range(self.dp_world_size)]
        else:
            if fill_size > 0:
                value = torch.cat([value, value.new_zeros(fill_size, value.size()[1])])
            tensor_list = [
                value.new_zeros(max_size,
                                value.size()[1]) for _ in range(self.dp_world_size)
            ]

        dist.all_gather(tensor_list, value, group=self.data_parallel_group)
        tensors = []
        for dev_idx, t in enumerate(tensor_list):
            size = all_sizes[dev_idx][0]
            tensors.append(
                t.index_select(0,
                               torch.LongTensor(range(size)).to(self.device)))

        return tensors
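    # csr_all_gather pads each rank's tensor with zeros up to the largest
    # first-dimension size, all-gathers the padded tensors, then trims every
    # gathered tensor back to its true size using the sizes exchanged via
    # all_gather_scalar.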

    def all_gather_scalar(self, value):
        tensor_list = [value.new_zeros(value.size()) for _ in range(self.dp_world_size)]
        dist.all_gather(tensor_list, value, group=self.data_parallel_group)
        return tensor_list

    def module_state_dict(self, destination=None, prefix='', keep_vars=False):
        sd = self.module.state_dict(destination, prefix, keep_vars)
        return sd

    def load_module_state_dict(self, state_dict, strict=True):
        self.module.load_state_dict(state_dict, strict=strict)

    def _get_rank_zero_ckpt_name(self, checkpoints_path, tag, mp_rank, dp_rank):
        filename = 'zero_pp_rank_{}'.format(dp_rank)
        zero_ckpt_name = os.path.join(
            checkpoints_path,
            str(tag),
            filename + '_mp_rank_{:02d}'.format(mp_rank) + 'optim_states.pt')
        return zero_ckpt_name

    def _get_zero_ckpt_name(self, checkpoints_path, tag):
        mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
        pp_rank = torch.distributed.get_rank(group=self.optimizer.dp_process_group)
        return self._get_rank_zero_ckpt_name(checkpoints_path, tag, mp_rank, pp_rank)

    def _get_ckpt_name(self, checkpoints_path, tag):

        mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
        ckpt_name = os.path.join(checkpoints_path,
                                 str(tag),
                                 'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt')
        return ckpt_name

    def _ensure_directory_exists(self, filename):
        dirname = os.path.dirname(filename)
        if not os.path.exists(dirname):
            os.makedirs(dirname)

    def load_checkpoint(self,
                        load_dir,
                        tag,
                        load_module_strict=True,
                        load_optimizer_states=True,
                        load_lr_scheduler_states=True):
O
Olatunji Ruwase 已提交
1159 1160 1161 1162 1163
        r"""Load training checkpoint

        Arguments:
            load_dir: Required. Directory to load the checkpoint from
            tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step.
J
Jeff Rasley 已提交
1164
            load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match.
1165
            load_optimizer_states: Optional. Boolean to load the training optimizer states from Checkpoint. Ex. ADAM's momentum and variance
J
Jeff Rasley 已提交
1166
            load_lr_scheduler_states: Optional. Boolean to add the learning rate scheduler states from Checkpoint.
        Return:
            load_path: Path of the loaded checkpoint. None if loading the checkpoint failed
            client_state: State dictionary used for loading required training states in the client code.
        """

        load_path, client_states = self._load_checkpoint(load_dir,
                                                         tag,
                                                         load_module_strict=load_module_strict,
                                                         load_optimizer_states=load_optimizer_states,
                                                         load_lr_scheduler_states=load_lr_scheduler_states)

        if self.zero_optimization() and load_path is not None:
            self._load_zero_checkpoint(load_dir,
                                       tag,
                                       load_optimizer_states=load_optimizer_states)

        return load_path, client_states
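
    # Illustrative usage sketch, not part of the original source. `model_engine`,
    # the directory and the 'step' key are hypothetical; the engine is assumed to
    # come from deepspeed.initialize():
    #
    #   load_path, client_state = model_engine.load_checkpoint('/data/ckpts',
    #                                                          tag='global_step1000')
    #   if load_path is None:
    #       logger.warning('no checkpoint found, starting from scratch')
    #   else:
    #       start_step = client_state.get('step', 0)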

    def _load_checkpoint(self,
                         load_dir,
                         tag,
                         load_module_strict=True,
                         load_optimizer_states=True,
                         load_lr_scheduler_states=True):

        load_path = self._get_ckpt_name(load_dir, tag)

        if not os.path.exists(load_path):
            logger.warning(
                'Client provided checkpoint load path: {} does not exist ... skipping checkpoint load'
                .format(load_path))
            return None, None

        logger.info('Loading checkpoint: {}'.format(load_path))
        checkpoint = torch.load(load_path, map_location=lambda storage, loc: storage)

        self.load_module_state_dict(state_dict=checkpoint['module'],
                                    strict=load_module_strict)
        if not self.zero_optimization():
            if self.fp16_enabled():
                self.optimizer.load_state_dict(
                    checkpoint['optimizer'],
                    load_optimizer_states=load_optimizer_states)
            else:
                self.optimizer.load_state_dict(checkpoint['optimizer'])

        if load_lr_scheduler_states and self.lr_scheduler is not None:
            self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

        self.csr_tensor_module_names = checkpoint['csr_tensor_module_names']
        self.global_steps = checkpoint['global_steps']
        self.skipped_steps = checkpoint['skipped_steps']
        self.loaded_checkpoint_mp_world_size = checkpoint['mp_world_size']
        self.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size']
        deepspeed_states = [
            'module',
            'optimizer',
            'lr_scheduler',
            'csr_tensor_module_names',
            'skipped_steps',
            'global_steps',
            'dp_world_size',
            'mp_world_size'
        ]
        client_state = {
            key: value
            for key, value in checkpoint.items() if key not in deepspeed_states
        }

        return load_path, client_state
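
    # Note (illustrative): only the keys listed in `deepspeed_states` are
    # stripped from the checkpoint, so a value the client saved via
    # save_checkpoint(..., client_state={'epoch': 3}) comes back unchanged,
    # i.e. client_state == {'epoch': 3} after a successful load.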

    def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
        zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
        if zero_sd_list is None:
            return

        self.optimizer.load_state_dict(
            state_dict_list=zero_sd_list,
            load_optimizer_states=load_optimizer_states,
            load_from_fp32_weights=self.zero_load_from_fp32_weights())
        logger.info(
            f'loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}'
        )

    def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size):
        zero_ckpt_names = []
        for dp_rank in range(dp_world_size):
            ckpt_name = self._get_rank_zero_ckpt_name(checkpoints_path=load_dir,
                                                      tag=tag,
                                                      mp_rank=mp_rank,
                                                      dp_rank=dp_rank)
            zero_ckpt_names.append(ckpt_name)

        return zero_ckpt_names

    def _get_all_zero_checkpoint_names(self,
                                       load_dir,
                                       tag,
                                       mp_world_size,
                                       dp_world_size):
        zero_ckpt_names = []
        for mp_rank in range(mp_world_size):
            mp_rank_ckpt_names = self._get_mp_rank_zero_checkpoint_names(
                load_dir=load_dir,
                tag=tag,
                mp_rank=mp_rank,
                dp_world_size=dp_world_size)
            zero_ckpt_names += mp_rank_ckpt_names

        return zero_ckpt_names

    def _get_all_zero_checkpoints(self, load_dir, tag):
        mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
        zero_ckpt_names = self._get_mp_rank_zero_checkpoint_names(
            load_dir=load_dir,
            tag=tag,
            mp_rank=mp_rank,
            dp_world_size=self.loaded_checkpoint_dp_world_size)
        invalid_zero_ckpt_paths = []
        for ckpt_name in zero_ckpt_names:
            if not os.path.exists(ckpt_name):
                invalid_zero_ckpt_paths.append(ckpt_name)

        if len(invalid_zero_ckpt_paths) > 0:
            logger.warning(
                f"Client provided zero checkpoint load paths: {invalid_zero_ckpt_paths} do not exist"
            )
            return None

        zero_sd_list = []
        for ckpt_name in zero_ckpt_names:
            zero_sd_list.append(torch.load(ckpt_name, map_location='cpu'))

        zero_optimizer_sd = [sd['optimizer_state_dict'] for sd in zero_sd_list]
        logger.info(
            f"successfully loaded {len(zero_optimizer_sd)} ZeRO state_dicts for rank {self.global_rank}"
        )
        return zero_optimizer_sd
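
    # Note (illustrative): each entry of zero_sd_list is the dict written by
    # _save_zero_checkpoint below, i.e. {'optimizer_state_dict': ...}, so the
    # returned list holds one ZeRO optimizer state dict per data-parallel rank,
    # in rank order.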

    def save_checkpoint(self, save_dir, tag, client_state={}):
        r"""Save training checkpoint

        Arguments:
            save_dir: Required. Directory for saving the checkpoint
            tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step.
            client_state: Optional. State dictionary used for saving required training states in the client code.
        """

        # This is to make sure the checkpoint names are created without collision
        # There seems to be an issue creating them in parallel

        if self.save_non_zero_checkpoint:
            self._create_checkpoint_file(save_dir, tag, False)
            self._save_checkpoint(save_dir, tag, client_state=client_state)

        if self.save_zero_checkpoint:
            self._create_zero_checkpoint_files(save_dir, tag)
            self._save_zero_checkpoint(save_dir, tag)

        return True
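
    # Illustrative usage sketch, not part of the original source. The tag string
    # and client_state keys are hypothetical:
    #
    #   model_engine.save_checkpoint('/data/ckpts',
    #                                tag='global_step{}'.format(step),
    #                                client_state={'epoch': epoch, 'step': step})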

    def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint):
        name_function = self._get_zero_ckpt_name if zero_checkpoint else self._get_ckpt_name
        try:
            checkpoint_name = name_function(save_dir, tag)
            self._ensure_directory_exists(checkpoint_name)
        except Exception:
            logger.error(f'Failed saving model checkpoint to {save_dir} with tag {tag}')
            return False

        return True

    def _create_zero_checkpoint_files(self, save_dir, tag):
        success = True
        # zero checkpoint files are created sequentially
        for rank in range(self.world_size):
            if rank == self.global_rank:
                success = self._create_checkpoint_file(save_dir, tag, True)

            dist.barrier()

        return success

    def _save_checkpoint(self, save_dir, tag, client_state={}):

        save_path = self._get_ckpt_name(save_dir, tag)
        # self._ensure_directory_exists(save_path)

        state = {
            'module':
            self.module_state_dict(),
            'optimizer':
            self.optimizer.state_dict()
            if self.optimizer and not self.zero_optimization() else None,
            'lr_scheduler':
            self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
            'csr_tensor_module_names':
            self.csr_tensor_module_names,
            'skipped_steps':
            self.skipped_steps,
            'global_steps':
            self.global_steps,
            'dp_world_size':
            self.dp_world_size,
            'mp_world_size':
            self.mp_world_size
        }
        state.update(client_state)

        logger.info('Saving model checkpoint: {}'.format(save_path))
        torch.save(state, save_path)

    def _save_zero_checkpoint(self, save_path, tag):
        zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag)
        # self._ensure_directory_exists(zero_checkpoint_name)
        zero_sd = {'optimizer_state_dict': self.optimizer.state_dict()}
        torch.save(zero_sd, zero_checkpoint_name)
        logger.info('zero checkpoint saved {}'.format(zero_checkpoint_name))