# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from types import MethodType

import torch
from deepspeed import comm as dist

from deepspeed.utils import logger
from deepspeed.utils.timer import ThroughputTimer
from deepspeed.accelerator import get_accelerator

from ..engine import DeepSpeedEngine, MEMORY_OPT_ALLREDUCE_SIZE
from deepspeed.utils.timer import FORWARD_MICRO_TIMER, FORWARD_GLOBAL_TIMER, BACKWARD_MICRO_TIMER, \
    BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_INNER_GLOBAL_TIMER, \
    BACKWARD_REDUCE_MICRO_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER, \
    STEP_MICRO_TIMER, STEP_GLOBAL_TIMER

from ..utils import PartitionedTensor
from ..dataloader import RepeatingLoader
from ..zero.config import ZeroStageEnum
from ..activation_checkpointing import checkpointing as ds_checkpointing

from .module import PipelineModule, PipelineError
from . import p2p
from . import schedule

TARGET_ID = -2
LOG_STAGE = -2
DATA_PARALLEL_ID = -2

BATCH_INPUT_TIMER = 'batch_input'
TRAIN_BATCH_TIMER = 'train_batch'
PIPE_SEND_OUTPUT_TIMER = 'pipe_send_output'
PIPE_SEND_GRAD_TIMER = 'pipe_send_grad'
PIPE_RECV_INPUT_TIMER = 'pipe_recv_input'
PIPE_RECV_GRAD_TIMER = 'pipe_recv_grad'


def is_even(number):
    return number % 2 == 0


mem_alloced = 0
mem_cached = 0


def _tensor_bytes(tensor):
    return tensor.numel() * tensor.element_size()


class PipelineEngine(DeepSpeedEngine):
    """ A training engine hybrid pipeline, data, and model parallel training.

    This engine is created by ``deepspeed.initialize()`` when a :class:`PipelineModule`
    is provided.
    """
    ID_TO_DTYPE = [
        torch.float32, torch.float64, torch.complex64, torch.complex128, torch.float16, torch.bfloat16, torch.uint8,
        torch.int8, torch.int16, torch.int32, torch.int64, torch.bool
    ]
    DTYPE_TO_ID = {dtype: id_ for id_, dtype in enumerate(ID_TO_DTYPE)}

    def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
        super().__init__(*super_args, **super_kwargs)
        assert isinstance(self.module, PipelineModule), "model must be a PipelineModule"

        assert self.zero_optimization_stage(
        ) < ZeroStageEnum.gradients, "ZeRO-2 and ZeRO-3 are incompatible with pipeline parallelism"

        # We schedule the all-reduces, so disable it in super().backward()
        self.enable_backward_allreduce = False
        self.has_bool_tensors = has_bool_tensors
        self.eval_return_logits = False
        self.outputs = None

        # used to disable the pipeline all-reduce when used with 1-bit Adam/1-bit LAMB
        self.pipeline_enable_backward_allreduce = True

        if self.elasticity_enabled():
            if not self.is_elastic_model_parallel_supported():
                assert not self.elasticity_enabled(), "Elasticity is not currently supported" \
                " with pipeline parallelism."

        # pipeline step for logging
        self.log_batch_step_id = -1

        self.micro_batch_size = self.train_micro_batch_size_per_gpu()
        self.micro_batches = self.gradient_accumulation_steps()

        # Set Grid and Communication Groups
        self.grid = self.module._grid
        if self.grid.get_global_rank() == 0:
            logger.info(f'CONFIG: micro_batches={self.micro_batches} '
                        f'micro_batch_size={self.micro_batch_size}')

        self.global_rank = self.grid.get_global_rank()

        assert self.dp_world_size == self.grid.data_parallel_size
        assert self.train_batch_size() == \
            self.micro_batch_size * self.micro_batches * self.grid.data_parallel_size

        # Set Stage Info
        self.num_stages = self.grid.pipe_parallel_size
        self.stage_id = self.grid.get_stage_id()
        self.prev_stage = self.stage_id - 1
        self.next_stage = self.stage_id + 1

        self.data_iterator = None
        self.batch_fn = None

        self._force_grad_boundary = False

        self.batch_timer = ThroughputTimer(batch_size=self.train_batch_size(),
                                           logging_fn=self.tput_log,
                                           monitor_memory=False,
                                           steps_per_output=self.steps_per_print())

        # PipelineEngine needs to handle data loading specially because only the first
        # and last stages load inputs/labels. We construct a distributed sampler so that
        # each data-parallel rank reads its own shard of the training data.
        if self.training_data:
            self._build_data_iter(self.training_data)

        self.is_pipe_parallel = self.grid.pipe_parallel_size > 1
        self.is_data_parallel = self.grid.data_parallel_size > 1
        self.is_model_parallel = self.grid.model_parallel_size > 1

        # Partition input/output buffers
        # XXX temporarily disable while I revert some partition hacks.
        self.is_pipe_partitioned = self.is_model_parallel
        self.is_grad_partitioned = self.is_model_parallel

        model_parameters = filter(lambda p: p.requires_grad, self.module.parameters())
        num_params = sum([p.numel() for p in model_parameters])
        unique_params = num_params
        # Subtract tied parameters if we don't own them
        if self.module.tied_comms:
            tied_params = 0
            for key, d in self.module.tied_comms.items():
                if self.global_rank != min(d['ranks']):
                    tied_params += sum(p.numel() for p in d['module'].parameters())
            unique_params -= tied_params
        params_tensor = torch.LongTensor(data=[num_params, unique_params]).to(self.device)
        dist.all_reduce(params_tensor, group=self.grid.get_model_parallel_group())
        params_tensor = params_tensor.tolist()
        total_params = params_tensor[0]
        unique_params = params_tensor[1]
        if self.grid.data_parallel_id == 0:
            logger.info(f'RANK={self.global_rank} '
                        f'STAGE={self.stage_id} '
                        f'LAYERS={self.module._local_stop - self.module._local_start} '
                        f'[{self.module._local_start}, {self.module._local_stop}) '
                        f'STAGE_PARAMS={num_params} ({num_params/1e6:0.3f}M) '
                        f'TOTAL_PARAMS={total_params} ({total_params/1e6:0.3f}M) '
                        f'UNIQUE_PARAMS={unique_params} ({unique_params/1e6:0.3f}M)')

        # Initialize peer-to-peer communication and allreduce groups
        if self.is_pipe_parallel:
            p2p.init_process_groups(self.grid)

        # Pipeline buffers
        self.num_pipe_buffers = 0
        self.pipe_buffers = {
            'inputs': [],  # batch input and received activations
            'labels': [],  # labels from batch input
            'outputs': [],  # activations
            'output_tensors': [],  # tensor object to preserve backward graph
        }
        self.pipe_recv_buf = None
        self.grad_layer = None

        self.meta_buffer = None

        self.first_output_send = True
        self.first_gradient_send = True

        #stores the loss for the current micro batch being processed
        self.loss = torch.tensor(0.0).to(self.device)

        #stores the loss for the entire batch
        self.total_loss = None
        self.agg_loss = torch.tensor(0.0, requires_grad=False).to(self.device)
        self.dp_group_loss = torch.tensor(0.0, requires_grad=False).to(self.device)

        if self._config.pipeline['activation_checkpoint_interval'] > 0:
            self.module.activation_checkpoint_interval = self._config.pipeline['activation_checkpoint_interval']

        self.module.checkpoint_parallel_write_pipeline = self._config.checkpoint_parallel_write_pipeline

        if self.is_last_stage():
            self.loss_model = self.module.loss_fn

        self.has_attention_mask = self.module.__class__.__name__ == 'GPT2ModelPipe'
        # Initialize pipeline communicators. Just send a 0.
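        # Even stages send first and then receive; odd stages receive first and then
        # send. Adjacent stages always have opposite parity, so each send is paired
        # with a receive on the neighboring stage and the warmup cannot deadlock.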
        if is_even(self.stage_id):
            if not self.is_last_stage():
                p2p.send(self.loss, self.next_stage)
            if not self.is_first_stage():
                p2p.recv(self.loss, self.prev_stage)
        else:
            if not self.is_first_stage():
                p2p.recv(self.loss, self.prev_stage)
            if not self.is_last_stage():
                p2p.send(self.loss, self.next_stage)

        # XXX look into timer reporting timing
        # Initialize some timers because of early weirdness.
        if self.wall_clock_breakdown():
            self.timers(FORWARD_MICRO_TIMER).start()
            self.timers(FORWARD_MICRO_TIMER).stop()
            self.timers(BACKWARD_MICRO_TIMER).start()
            self.timers(BACKWARD_MICRO_TIMER).stop()
            self.timers(BACKWARD_INNER_MICRO_TIMER).start()
            self.timers(BACKWARD_INNER_MICRO_TIMER).stop()
            self.timers(BACKWARD_REDUCE_MICRO_TIMER).start()
            self.timers(BACKWARD_REDUCE_MICRO_TIMER).stop()
            self.timers(BACKWARD_REDUCE_GLOBAL_TIMER).start()
            self.timers(BACKWARD_REDUCE_GLOBAL_TIMER).stop()
            self.timers(STEP_MICRO_TIMER).start()
            self.timers(STEP_MICRO_TIMER).stop()

    def set_has_attention_mask(self, value):
        assert isinstance(value, bool)
        self.has_attention_mask = value

    def _build_data_iter(self, dataset):
        sampler = torch.utils.data.distributed.DistributedSampler(dataset,
                                                                  num_replicas=self.dp_world_size,
                                                                  rank=self.mpu.get_data_parallel_rank(),
                                                                  shuffle=False)
        # Build a loader and make it repeating.
        pipe_dataloader = self.deepspeed_io(dataset, data_sampler=sampler)
        pipe_dataloader = RepeatingLoader(pipe_dataloader)
        self.set_dataloader(pipe_dataloader)

    def _exec_reduce_tied_grads(self):
        # We need to run this first to write to self.averaged_gradients;
        # since this class turns `enable_backward_allreduce` off,
        # `self.overlapping_partition_gradients_reduce_epilogue()` defined in the DeepSpeedEngine
        # never actually runs. I suspect this is because of efficiency problems; get_flat_partition in
        # stage2.py might do something expensive; someone will have to look into that later. But
        # in the meantime, this fixes ZeRO2 + Pipelining enough to run a demo. Further profiling
        # needed to decide if it actually breaks everything.
        # (see https://github.com/EleutherAI/gpt-neox/issues/62#issuecomment-761471944)
        if self.zero_optimization_partition_gradients():
            self.optimizer.overlapping_partition_gradients_reduce_epilogue()

        weight_group_list = self.module.get_tied_weights_and_groups()
        for weight, group in weight_group_list:
            grad = weight._hp_grad if self.bfloat16_enabled() else weight.grad
            dist.all_reduce(grad, group=group)

    def _exec_reduce_grads(self):
        self._force_grad_boundary = True
        if self.pipeline_enable_backward_allreduce:
            if self.bfloat16_enabled():
                # PP+BF16 works with ZeRO Stage 1
                self._bf16_reduce_grads()
            else:
                self.allreduce_gradients(bucket_size=MEMORY_OPT_ALLREDUCE_SIZE)
        self._force_grad_boundary = False

    def _bf16_reduce_grads(self):
        # Make our own list of gradients from the optimizer's FP32 grads
        grads = []
        self.buffered_allreduce_fallback(grads=self.optimizer.get_grads_for_reduction(),
                                         elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE)

    def _reserve_pipe_buffers(self, num_buffers):
        """Ensure that each pipeline buffer has at least ``num_buffers`` slots.

        This method only reserves slots and does not allocate tensors.

        Args:
            num_buffers (int): The number of buffers to reserve.
        """
        if self.num_pipe_buffers >= num_buffers:
            return

        num_added = num_buffers - self.num_pipe_buffers
        for key in self.pipe_buffers:
            self.pipe_buffers[key].extend([None] * num_added)
        self.num_pipe_buffers = num_buffers

    def reset_activation_shape(self):
        """Reset the buffers when the shape of activation and gradient change.
        For example, for curriculum learning that changes the seqlen of each
        sample, we need to call this whenever the seqlen is going to change.
        """
        self.first_output_send = True
        self.pipe_recv_buf = None
        self.grad_layer = None
        self.meta_buffer = None

    def train_batch(self, data_iter=None):
        """Progress the pipeline to train the next batch of data. The engine will ingest
        ``self.train_batch_size()`` total samples collectively across all workers.


        An iterator over the training data should be provided as an argument
        unless ``deepspeed.initialize()`` was provided a training set. In that event,
        the training data will automatically be read.


        .. warning::
            A total of ``self.gradient_accumulation_steps()`` entries will be pulled
            from ``data_iter`` by each pipeline. There must be sufficient
            data left in ``data_iter`` or else a ``StopIteration`` will halt training.

            DeepSpeed provides a convenience class :class:`deepspeed.utils.RepeatingLoader`
            that wraps data loaders to automatically restart upon a ``StopIteration``.

        Args:
            data_iter (Iterator, optional): Iterator of training data.

        Returns:
            The arithmetic mean of the losses computed this batch.
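
        Example (a minimal sketch; assumes ``engine`` was returned by
        ``deepspeed.initialize()`` with a :class:`PipelineModule` and that
        ``train_loader`` yields batches of ``(inputs, labels)``):

        .. code-block:: python

            data_iter = iter(deepspeed.utils.RepeatingLoader(train_loader))
            for _ in range(1000):
                loss = engine.train_batch(data_iter=data_iter)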
        """
        if not torch._C.is_grad_enabled():
            raise RuntimeError('train_batch() requires gradients enabled. Use eval_batch() instead.')

        # Curriculum learning could change activation shape
        if self.curriculum_enabled_legacy():
            new_difficulty = self.curriculum_scheduler_legacy.update_difficulty( \
                self.global_steps + 1)
            if self.global_steps == 0 or self.curriculum_scheduler_legacy.first_step:
                self.reset_activation_shape()
                self.curriculum_scheduler_legacy.first_step = False
            elif new_difficulty != self.curriculum_scheduler_legacy.get_difficulty( \
                self.global_steps):
                self.reset_activation_shape()

        if data_iter:
            self.set_dataiterator(data_iter)

        self.module.train()
        self.total_loss = None
        self._compute_loss = True

        # Do the work
        self.timers(TRAIN_BATCH_TIMER).start()
        sched = schedule.TrainSchedule(micro_batches=self.micro_batches,
                                       stages=self.num_stages,
                                       stage_id=self.stage_id)
        self._exec_schedule(sched)
        self.agg_train_loss = self._aggregate_total_loss()

        self.timers(TRAIN_BATCH_TIMER).stop()

        if self.global_steps % self.steps_per_print() == 0:
            if self.global_rank == 0:
                elapsed = self.timers(TRAIN_BATCH_TIMER).elapsed(reset=True) / 1000.0
                iter_time = elapsed / self.steps_per_print()
                tput = self.train_batch_size() / iter_time
                print(f'steps: {self.global_steps} '
                      f'loss: {self.agg_train_loss:0.4f} '
                      f'iter time (s): {iter_time:0.3f} '
                      f'samples/sec: {tput:0.3f}')

        # Monitoring
        if self.global_rank == 0 and self.monitor.enabled:
            self.summary_events = [('Train/Samples/train_loss', self.agg_train_loss.mean().item(),
                                    self.global_samples)]
            self.monitor.write_events(self.summary_events)

        if self.wall_clock_breakdown() and self.global_steps % self.steps_per_print() == 0:
            self.timers.log([
                PIPE_SEND_OUTPUT_TIMER,
                PIPE_SEND_GRAD_TIMER,
                PIPE_RECV_INPUT_TIMER,
                PIPE_RECV_GRAD_TIMER,
            ])

        # TODO: should return precisely what loss returned and allow others to be queried?
        return self.agg_train_loss

    def eval_batch(self, data_iter, return_logits=False, compute_loss=True, reduce_output='avg'):
        """Evaluate the pipeline on a batch of data from ``data_iter``. The
        engine will evaluate ``self.train_batch_size()`` total samples
        collectively across all workers.

        This method is equivalent to:

        .. code-block:: python

            module.eval()
            with torch.no_grad():
                output = module(batch)

        .. warning::
            A total of ``self.gradient_accumulation_steps()`` entries will be pulled
            from ``data_iter`` by each pipeline. There must be sufficient
            data left in ``data_iter`` or else a ``StopIteration`` will halt training.

            DeepSpeed provides a convenience class :class:`deepspeed.utils.RepeatingLoader`
            that wraps data loaders to automatically restart upon a ``StopIteration``.

        Args:
            data_iter (Iterator): Iterator of data to evaluate.

        Returns:
            The arithmetic mean of the losses computed this batch.
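
        Example (a minimal sketch; assumes ``engine`` was returned by
        ``deepspeed.initialize()`` with a :class:`PipelineModule` and that
        ``val_loader`` yields batches of ``(inputs, labels)``):

        .. code-block:: python

            val_iter = iter(deepspeed.utils.RepeatingLoader(val_loader))
            loss = engine.eval_batch(data_iter=val_iter)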
        """
        self.eval_return_logits = return_logits
        self.module.eval()

        # Curriculum learning could change activation shape
        if self.curriculum_enabled_legacy():
            new_difficulty = self.curriculum_scheduler_legacy.update_difficulty( \
                self.global_steps + 1)
            if self.global_steps == 0 or self.curriculum_scheduler_legacy.first_step:
                self.reset_activation_shape()
                self.curriculum_scheduler_legacy.first_step = False
            elif new_difficulty != self.curriculum_scheduler_legacy.get_difficulty( \
                self.global_steps):
                self.reset_activation_shape()

        eval_output = None

        self._compute_loss = compute_loss

        # Use the provided data iterator
        train_iterator = self.data_iterator
        self.set_dataiterator(data_iter)

        # Do the work
        sched = schedule.InferenceSchedule(micro_batches=self.micro_batches,
                                           stages=self.num_stages,
                                           stage_id=self.stage_id)

        # Prevent deadlock when running multiple evaluations in sequence
        dist.barrier()

        with torch.no_grad():
            self._exec_schedule(sched)

        if self.is_last_stage():
            eval_output = self._reduce_outputs(self.fwd_outputs, reduce=reduce_output)

        if compute_loss:
            eval_output = self._bcast_pipe_scalar(eval_output)

        if self.global_rank == 0 and self.monitor.enabled:
            self.summary_events = [('Train/Samples/eval_loss', eval_output.mean().item(), self.global_samples)]
            self.monitor.write_events(self.summary_events)

        # Restore the training iterator
        self.set_dataiterator(train_iterator)

        # Reset any buffers that may have been populated during the forward passes.
        #ds_checkpointing.reset()
        self.eval_return_logits = False
        if return_logits:
            outputs = self.outputs
            self.outputs = None
            return eval_output, outputs
        return eval_output

    def set_train_batch_size(self, train_batch_size):
        """Adjust the global batch size by increasing or decreasing the number of
        micro-batches (i.e., gradient accumulation steps). The size of each micro-batch
        (i.e., ``train_micro_batch_size_per_gpu``) is not changed.
        Args:
            train_batch_size (int): The new global batch size for training.
        Raises:
            ValueError: if ``train_batch_size`` is not divisible by the
                configured micro-batch size and data parallelism.
        """
        super().set_train_batch_size(train_batch_size)
        self.micro_batches = self.gradient_accumulation_steps()

    def is_first_stage(self):
        """True if this process is in the first stage in the pipeline."""
        return self.stage_id == 0

    def is_last_stage(self):
        """True if this process is in the last stage in the pipeline."""
        return self.stage_id == self.num_stages - 1

    def _reduce_outputs(self, outputs, reduce='avg', reduce_dp=True):
        if reduce is None:
            return outputs

        if reduce.lower() == 'avg':
            # first sum over all microbatches
            if torch.is_tensor(outputs[0]):
                reduced = sum(outputs)
            else:
                assert isinstance(outputs, (list, tuple))
                reduced = [torch.zeros_like(o) for o in outputs[0]]
                for out in outputs:
                    for idx, o in enumerate(out):
                        reduced[idx] += o

            # Average over the microbatches
            reduced = self._scale_loss_by_gas(reduced)

            # Average over DP groups
            if reduce_dp and self.is_data_parallel:
                if torch.is_tensor(reduced):
                    dist.all_reduce(reduced, group=self.mpu.get_data_parallel_group())
                    reduced /= self.dp_world_size
                else:
                    for idx in range(len(reduced)):
                        dist.all_reduce(reduced[idx], group=self.mpu.get_data_parallel_group())
                        reduced[idx] /= self.dp_world_size

            return reduced
        else:
            raise NotImplementedError(f'reduction type {reduce} not supported.')

    def _bcast_pipe_scalar(self, data, src_rank=None, dtype=torch.float32):
        # Default to last stage (e.g., for broadcasting loss)
        if src_rank is None:
            src_rank = self.grid.stage_to_global(self.num_stages - 1)
        assert src_rank in self.grid.pp_group

        if self.global_rank == src_rank:
            result = data.clone().detach().type(dtype).to(self.device)
        else:
            result = torch.Tensor([0.]).type(dtype).to(self.device)

        dist.broadcast(tensor=result, src=src_rank, group=self.mpu.get_pipe_parallel_group())

        return result

    def _aggregate_total_loss(self):
        # Scale loss, average among DP ranks, and bcast loss to the rest of my DP group
        if self.is_last_stage():
            loss = self._scale_loss_by_gas(self.total_loss)
            self.dp_group_loss = loss.clone().detach()

            ## Average loss across all data-parallel groups
            agg_loss = self.dp_group_loss.clone().detach()
            #print(f'RANK={self.global_rank} bcast SENDER src={self.global_rank} group={self.grid.pp_group}', flush=True)
            if self.is_data_parallel:
                dist.all_reduce(agg_loss, group=self.mpu.get_data_parallel_group())
                agg_loss /= self.dp_world_size

            assert self.global_rank in self.grid.pp_group
            losses = torch.Tensor([self.dp_group_loss, agg_loss]).to(self.device)
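            # Pack [this DP group's loss, DP-averaged loss] into one tensor so both
            # values can be broadcast from the last stage to the rest of the pipe group.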
            if self.is_pipe_parallel:
                dist.broadcast(tensor=losses, src=self.global_rank, group=self.mpu.get_pipe_parallel_group())
        else:
            # Get loss from last stage
            src_rank = self.grid.stage_to_global(self.num_stages - 1)
            assert src_rank in self.grid.pp_group
            losses = torch.Tensor([0., 0.]).to(self.device)
            dist.broadcast(tensor=losses, src=src_rank, group=self.grid.get_pipe_parallel_group())
            self.dp_group_loss = losses[0].clone().detach()
            agg_loss = losses[1].clone().detach()

        return agg_loss

    def set_dataloader(self, loader):
        """"""
        if self.is_first_stage() or self.is_last_stage():
            self.training_dataloader = loader
            self.data_iterator = iter(self.training_dataloader)

    def set_dataiterator(self, iterator):
        """ Store an iterator to sample for training data. """
        if self.is_first_stage() or self.is_last_stage():
            self.training_dataloader = None
            self.data_iterator = iterator

    def set_batch_fn(self, fn):
        """Execute a post-processing function on input data.

        Args:
            fn (function): The function to run.
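
        Example (a hypothetical post-processing step that casts the labels of an
        ``(inputs, labels)`` batch to ``long``):

        .. code-block:: python

            engine.set_batch_fn(lambda batch: (batch[0], batch[1].long()))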
        """
        self.batch_fn = fn

    def is_gradient_accumulation_boundary(self):
        """True if the engine is executing a gradient reduction or optimizer step instruction.

        This is overridden from :class:`DeepSpeedEngine` to force reductions
        and steps when the pipeline engine is instructed to do so.

        Returns:
            bool: whether reductions and optimizer steps should occur.
        """
        return self._force_grad_boundary

    def log_for_device(self, *msg):
        if LOG_STAGE == self.stage_id or LOG_STAGE == -1:
            if DATA_PARALLEL_ID == self.grid.data_parallel_id or DATA_PARALLEL_ID == -1:
                print(
                    f'RANK={dist.get_rank()} '
                    f'PIPE-ID={self.stage_id} '
                    f'DATA-ID={self.grid.data_parallel_id} '
                    f'MBATCH-ID={self.microbatch_id} '
                    f'STEP-ID={self.log_batch_step_id} '
                    '::',
                    *msg,
                    flush=True)

    def tput_log(self, *msg):
        if self.global_rank == 0 and self.global_steps % self.steps_per_print() == 0:
            print(*msg)

    def _next_batch(self):
        # If using 3D parallelism, only some first-stage ranks may do IO
        batch = None
        if self.data_iterator is not None:
            batch = next(self.data_iterator)

        # Any post-processing, like broadcasting across a slice-parallel group.
        if self.batch_fn:
            batch = self.batch_fn(batch)

        return batch

    def _exec_forward_pass(self, buffer_id):
        self.tput_timer.start()
        self.mem_status('BEFORE FWD', reset_max=True)

        if isinstance(self.pipe_buffers['inputs'][buffer_id], tuple):
            inputs = tuple(t.clone() for t in self.pipe_buffers['inputs'][buffer_id])
        else:
            inputs = self.pipe_buffers['inputs'][buffer_id].clone()

        # collect the partitioned input from the previous stage
        if self.is_pipe_partitioned and not self.is_first_stage():
            part_input = PartitionedTensor.from_meta(meta=inputs[0],
                                                     local_part=inputs[1],
                                                     group=self.grid.get_slice_parallel_group())

            inputs = (part_input.full(), *inputs[2:])
            inputs[0].requires_grad = True
            # skip mask
            #inputs[1].requires_grad = True
            part_input = None
            inputs = inputs[0] if len(inputs) == 1 else inputs
            self.pipe_buffers['inputs'][buffer_id] = inputs

        # Zero out the gradients each time we use the tensor because only the data in
        # the tensor changes across batches
        self._zero_grads(inputs)

        outputs = super().forward(inputs)

        # Reset activation checkpointing buffers.
        # Need to call this between evaluation iterations
        if not self.module.training:
            ds_checkpointing.reset()

        # Partition the outputs if we are not the last stage
        if self.is_pipe_partitioned and not self.is_last_stage():
            if isinstance(outputs, tuple):
                first_output = outputs[0]
                # TODO: Improve pipe partitioning to pass multiple tensors that require grads
                assert all([torch.is_tensor(elt) and elt.requires_grad is False for elt in outputs[1:]])
                outputs_tail = outputs[1:]
            elif torch.is_tensor(outputs):
                first_output = outputs
                outputs_tail = []
            else:
                raise ValueError("expecting a tensor or a tuple of tensors")
            part = PartitionedTensor(tensor=first_output, group=self.grid.get_slice_parallel_group())
            # Clear the large output data, but save the computation graph
            first_output.data = torch.zeros(1)
            self.pipe_buffers['output_tensors'][buffer_id] = first_output
            # Inject the partitioned tensor into the output before sending
            outputs = (part.to_meta(), part.data(), *outputs_tail)
            part = None

        self.pipe_buffers['outputs'][buffer_id] = outputs

        # Optionally compute loss on the last device
        if self.is_last_stage():
            if self._compute_loss and self.module.loss_fn is not None:
                labels = self.pipe_buffers['labels'][buffer_id]
                self.loss = self.module.loss_fn(outputs, labels)
            else:
                # Some models just return loss from forward()
                self.loss = outputs
            if self.eval_return_logits:
                self.outputs = outputs
            if isinstance(self.loss, torch.Tensor):
                self.fwd_outputs.append(self.loss.detach())

                if self.total_loss is None:
                    self.total_loss = torch.zeros_like(self.loss)
                self.total_loss += self.loss.detach()
            else:
                self.fwd_outputs.append([l.detach() for l in self.loss])

                if self.total_loss is None:
                    self.total_loss = [torch.zeros_like(l) for l in self.loss]
                for idx, l in enumerate(self.loss):
                    self.total_loss[idx] += l.detach()

    def _exec_backward_pass(self, buffer_id):
        assert self.optimizer is not None, "must provide optimizer during " \
                                           "init in order to use backward"

        self.mem_status('BEFORE BWD', reset_max=True)

        # The last stage just runs backward on the loss using DeepSpeed's typical
        # mechanisms.
        if self.is_last_stage():
            super().backward(self.loss)
            self.mem_status('AFTER BWD')
            return

        outputs = self.pipe_buffers['outputs'][buffer_id]

        if self.wall_clock_breakdown():
            self.timers(BACKWARD_MICRO_TIMER).start()
            self.timers(BACKWARD_GLOBAL_TIMER).start()
            self.timers(BACKWARD_INNER_MICRO_TIMER).start()
            self.timers(BACKWARD_INNER_GLOBAL_TIMER).start()

        # Reconstruct if we previously partitioned the output. We must be
        # careful to also restore the computational graph of the tensors we partitioned.
        if self.is_pipe_partitioned:
            if self.is_grad_partitioned:
                part_output = PartitionedTensor.from_meta(meta=outputs[0],
                                                          local_part=outputs[1],
                                                          group=self.grid.get_slice_parallel_group())
                self.pipe_buffers['output_tensors'][buffer_id].data = part_output.full()
                outputs = (self.pipe_buffers['output_tensors'][buffer_id], *outputs[2:])
            else:
                # Already restored from partition
                self.pipe_buffers['output_tensors'][buffer_id].data = outputs[0]
                outputs = (self.pipe_buffers['output_tensors'][buffer_id], *outputs[1:])

        grad_tensors = self.grad_layer
        if self.is_grad_partitioned:
            #print(f'RANK={self.global_rank} BEFORE-BWD restoring grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}')
            part_grad = PartitionedTensor.from_meta(meta=self.grad_layer[0],
                                                    local_part=self.grad_layer[1],
                                                    group=self.grid.get_slice_parallel_group())
            grad_tensors = (part_grad.full(), *grad_tensors[2:])
            part_grad = None
            #print(f'RANK={self.global_rank} BEFORE-BWD restored grad={self.grad_layer[0].size()} {self.grad_layer[1].size()}')

        if self.bfloat16_enabled() and not self.is_last_stage():
            # manually call because we don't call optimizer.backward()
            self.optimizer.clear_lp_grads()

        # This handles either a single tensor or tuple of tensors.
        if isinstance(outputs, tuple):
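            # Only floating-point outputs carry gradients (masks and partition
            # metadata do not), so filter them to line up with grad_tensors.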
            out_tensors = [t for t in outputs if t.is_floating_point()]
            assert len(out_tensors) == len(grad_tensors)
            torch.autograd.backward(tensors=out_tensors, grad_tensors=grad_tensors)
        else:
            torch.autograd.backward(tensors=(outputs, ), grad_tensors=(grad_tensors, ))

        if self.bfloat16_enabled() and not self.is_last_stage():
            # manually call because we don't call optimizer.backward()
            self.optimizer.update_hp_grads(clear_lp_grads=False)

        # Free up the memory from the output of forward()
        self.pipe_buffers['output_tensors'][buffer_id] = None
        self.pipe_buffers['outputs'][buffer_id] = None
        grad_tensors = None

        if self.wall_clock_breakdown():
            self.timers(BACKWARD_INNER_MICRO_TIMER).stop()
            self.timers(BACKWARD_INNER_GLOBAL_TIMER).stop()
            self.timers(BACKWARD_MICRO_TIMER).stop()
            self.timers(BACKWARD_GLOBAL_TIMER).stop()

        self.mem_status('AFTER BWD')

    def _exec_load_micro_batch(self, buffer_id):
        if self.wall_clock_breakdown():
            self.timers(BATCH_INPUT_TIMER).start()

        batch = self._next_batch()
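        # Only the first stage consumes the inputs (batch[0]) and only the last stage
        # consumes the labels (batch[1]); stages in the middle load neither.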

        if self.is_first_stage():
            loaded = None
            if torch.is_tensor(batch[0]):
                loaded = batch[0].clone().to(self.device).detach()
                loaded.requires_grad = loaded.is_floating_point()
            else:
                assert isinstance(batch[0], (tuple, list))
                # Assume list or tuple
                loaded = []
                for x in batch[0]:
                    assert torch.is_tensor(x)
                    mine = x.clone().detach().to(self.device)
                    mine.requires_grad = mine.is_floating_point()
                    loaded.append(mine)
                loaded = tuple(loaded)

            self.pipe_buffers['inputs'][buffer_id] = loaded

        if self.is_last_stage():
            loaded = batch[1]
            if torch.is_tensor(batch[1]):
                loaded = batch[1].to(self.device)
            # XXX: torch 1.6.0 DataLoader will auto convert tuple to list
            elif isinstance(batch[1], (tuple, list)):
                loaded = []
                for x in batch[1]:
                    assert torch.is_tensor(x)
                    x = x.to(self.device).detach()
                    loaded.append(x)
                loaded = tuple(loaded)

            self.pipe_buffers['labels'][buffer_id] = loaded

        if self.wall_clock_breakdown():
            self.timers(BATCH_INPUT_TIMER).stop()

    def _send_tensor_meta(self, buffer, recv_stage):
        """ Communicate metadata about upcoming p2p transfers.

        Metadata is communicated in this order:
            * type (0: tensor, 1: list, 2: tuple)
            * num_tensors if type=list or tuple
            foreach tensor in buffer:
                * dtype (tuples only)
                * ndims
                * shape
        """
        send_bytes = 0
        if isinstance(buffer, torch.Tensor):
            type_tensor = torch.LongTensor(data=[0]).to(self.device)
            p2p.send(type_tensor, recv_stage)
            send_shape = torch.LongTensor(data=buffer.size()).to(self.device)
            send_ndims = torch.LongTensor(data=[len(buffer.size())]).to(self.device)
            p2p.send(send_ndims, recv_stage)
            p2p.send(send_shape, recv_stage)
            send_bytes += _tensor_bytes(buffer)
        elif isinstance(buffer, list):
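            # Plain Python lists are not currently supported as p2p payloads; only
            # single tensors and tuples of tensors are sent between stages.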
            assert (False)
            type_tensor = torch.LongTensor(data=[1]).to(self.device)
            p2p.send(type_tensor, recv_stage)
            count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device)
            p2p.send(count_tensor, recv_stage)
            for tensor in buffer:
                assert isinstance(tensor, torch.Tensor)
                send_shape = torch.LongTensor(data=tensor.size()).to(self.device)
                send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)
                p2p.send(send_ndims, recv_stage)
                p2p.send(send_shape, recv_stage)
                send_bytes += _tensor_bytes(tensor)
        elif isinstance(buffer, tuple):
            type_tensor = torch.LongTensor(data=[2]).to(self.device)
            p2p.send(type_tensor, recv_stage)
            count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device)
            p2p.send(count_tensor, recv_stage)
            for idx, tensor in enumerate(buffer):
                assert isinstance(tensor, torch.Tensor)
                send_shape = torch.LongTensor(data=tensor.size()).to(self.device)
                send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)
                send_dtype = torch.LongTensor(data=[self.DTYPE_TO_ID[tensor.dtype]]).to(self.device)
                p2p.send(send_dtype, recv_stage)
                p2p.send(send_ndims, recv_stage)
                p2p.send(send_shape, recv_stage)
                # Useful for performance debugging.
                '''
                new_bytes = _tensor_bytes(tensor)
                send_bytes += _tensor_bytes(tensor)
                # Useful for performance debugging.
                if self.grid.data_parallel_id == 0:
                    print(
                        f'STAGE={self.stage_id} pipe-send-volume[{idx}]: shape={send_shape} {new_bytes/1024**2:0.2f}MB'
                    )
                '''
        else:
            raise NotImplementedError(f'Could not send meta type {type(buffer)}')

        # Useful for performance debugging.
        '''
        if self.grid.data_parallel_id == 0:
            print(f'STAGE={self.stage_id} pipe-send-volume: {send_bytes/1024**2:0.2f}MB')
        '''

    def _recv_tensor_meta(self, send_stage):
        """Receive metadata about upcoming p2p transfers and return allocated buffers.

        Metadata is communicated in this order:
            * type (0: tensor, 1: list, 2: tuple)
            * num_tensors if type=list or tuple
            foreach tensor in buffer:
                * dtype (tuples only)
                * ndims
                * shape

        Returns:
            Allocated buffer for receiving from send_stage.
        """

        type_tensor = torch.LongTensor(data=[0]).to(self.device)
        p2p.recv(type_tensor, send_stage)
        recv_type = type_tensor.item()

        # A single tensor will be sent.
        if recv_type == 0:
            recv_ndims = torch.LongTensor(data=[0]).to(self.device)
            p2p.recv(recv_ndims, send_stage)
            recv_ndims = recv_ndims.item()
            recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device)
            p2p.recv(recv_shape, send_stage)
            recv_shape = recv_shape.tolist()
            return self._allocate_buffer(recv_shape, num_buffers=1)[0]

        # List or tuple of tensors
        elif recv_type == 1 or recv_type == 2:
            count_tensor = torch.LongTensor(data=[0]).to(self.device)
            p2p.recv(count_tensor, send_stage)
            num_tensors = count_tensor.item()
            recv_shapes_and_dtypes = []
            for idx in range(num_tensors):
                recv_dtype = torch.LongTensor(data=[0]).to(self.device)
                p2p.recv(recv_dtype, send_stage)
                recv_dtype = self.ID_TO_DTYPE[recv_dtype.item()]
                recv_ndims = torch.LongTensor(data=[0]).to(self.device)
                p2p.recv(recv_ndims, send_stage)
                recv_ndims = recv_ndims.item()
                recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device)
                p2p.recv(recv_shape, send_stage)
                recv_shapes_and_dtypes.append((recv_shape.tolist(), recv_dtype))

            buffers = self._allocate_buffers(recv_shapes_and_dtypes, num_buffers=1)[0]
            # Convert to tuples if requested.
            if recv_type == 2:
                buffers = tuple(buffers)
            return buffers

        else:
            raise NotImplementedError(f'Could not receive type {recv_type}')

    def _exec_send_activations(self, buffer_id):
        if self.wall_clock_breakdown():
            self.timers(PIPE_SEND_OUTPUT_TIMER).start()

        outputs = self.pipe_buffers['outputs'][buffer_id]

        # NCCL does not like to send torch.BoolTensor types, so cast the mask to half().
        # We could do char, but with half() we can eventually flatten with other fp16
        # messages (TODO)
        if self.has_attention_mask or self.has_bool_tensors:
            outputs = list(outputs)
            outputs[-1] = outputs[-1].half()
            outputs = tuple(outputs)

        if self.first_output_send:
            self.first_output_send = False
            self._send_tensor_meta(outputs, self.next_stage)

        if isinstance(outputs, torch.Tensor):
            p2p.send(outputs, self.next_stage)
        elif isinstance(outputs, tuple):
            for idx, buffer in enumerate(outputs):
                p2p.send(buffer, self.next_stage)
        else:
            raise NotImplementedError('Could not send output of type '
                                      f'{type(outputs)}')

        # Restore the boolean tensor
        if self.has_attention_mask or self.has_bool_tensors:
            outputs = list(outputs)
            outputs[-1] = outputs[-1].bool()
            outputs = tuple(outputs)

        if self.wall_clock_breakdown():
            self.timers(PIPE_SEND_OUTPUT_TIMER).stop()

    def _exec_send_grads(self, buffer_id):
        if self.wall_clock_breakdown():
            self.timers(PIPE_SEND_GRAD_TIMER).start()

        inputs = self.pipe_buffers['inputs'][buffer_id]

        # Partition the gradient
        if self.is_grad_partitioned:
            if isinstance(inputs, tuple):
                first_input = inputs[0]
                assert all([torch.is_tensor(elt) for elt in inputs[1:]])
                inputs_grad_tail = [elt.grad for elt in inputs[1:] if elt.grad is not None]
            elif torch.is_tensor(inputs):
                first_input = inputs
                inputs_grad_tail = []
            else:
                raise ValueError("expecting a tensor or a tuple of tensors")
            assert torch.is_tensor(first_input)
            part = PartitionedTensor(tensor=first_input.grad, group=self.grid.get_slice_parallel_group())

            inputs = (part.to_meta(), part.data(), *inputs_grad_tail)

        # XXX Terrible hack
        # Drop the attention mask from the input buffer here. It does not have
        # a grad that needs to be communicated. We free the buffer immediately
        # after, so no need to restore it. The receiver also has a hack that skips
        # the recv. This is because NCCL does not let us send torch.BoolTensor :-(.
        if self.has_attention_mask or self.has_bool_tensors:
            inputs = list(inputs)
            inputs.pop()
            inputs = tuple(inputs)

        if isinstance(inputs, torch.Tensor):
            assert inputs.grad is not None
            p2p.send(inputs.grad, self.prev_stage)
        else:
            # XXX terrible hacky branch
            if self.is_grad_partitioned:
                # First two sends are partitioned gradient
                p2p.send(inputs[0], self.prev_stage)
                p2p.send(inputs[1], self.prev_stage)
            else:
                for idx, buffer in enumerate(inputs):
                    # Skip tensors that will not produce a grad
                    if not buffer.is_floating_point():
                        assert buffer.grad is None
                        continue
                    assert buffer.grad is not None
                    p2p.send(buffer.grad, self.prev_stage)

        # We can free up the input buffer now
        self.pipe_buffers['inputs'][buffer_id] = None

        if self.wall_clock_breakdown():
            self.timers(PIPE_SEND_GRAD_TIMER).stop()

    def _exec_recv_activations(self, buffer_id):
        if self.wall_clock_breakdown():
            self.timers(PIPE_RECV_INPUT_TIMER).start()

        recvd = None

        # Allocate the buffer if necessary
        if self.pipe_recv_buf is None:
            self.pipe_recv_buf = self._recv_tensor_meta(self.prev_stage)

        if isinstance(self.pipe_recv_buf, torch.Tensor):
            p2p.recv(self.pipe_recv_buf, self.prev_stage)
            recvd = self.pipe_recv_buf.clone().detach()
            recvd.requires_grad = recvd.is_floating_point()
        else:
            assert isinstance(self.pipe_recv_buf, tuple)
            recvd = [None] * len(self.pipe_recv_buf)
            for idx, buffer in enumerate(self.pipe_recv_buf):
                assert torch.is_tensor(buffer)
                # XXX hardcoded meta type: the first tensor of a partitioned activation
                # is PartitionedTensor metadata and must be received as torch.long
                if self.is_pipe_partitioned and idx == 0 and buffer.dtype != torch.long:
                    if self.meta_buffer is None:
                        self.meta_buffer = torch.zeros(buffer.size(), dtype=torch.long, device=self.device)
                    buffer = self.meta_buffer

                p2p.recv(buffer, self.prev_stage)
                recvd[idx] = buffer.clone().detach()

            # NCCL does not like to send torch.BoolTensor types, so un-cast the
            # attention mask
            if self.has_attention_mask or self.has_bool_tensors:
                recvd[-1] = recvd[-1].bool()

            recvd = tuple(recvd)

            for buffer in recvd:
                buffer.requires_grad = buffer.is_floating_point()

        self.pipe_buffers['inputs'][buffer_id] = recvd

        if self.wall_clock_breakdown():
            self.timers(PIPE_RECV_INPUT_TIMER).stop()

    def _exec_recv_grads(self, buffer_id):
        if self.wall_clock_breakdown():
            self.timers(PIPE_RECV_GRAD_TIMER).start()

        outputs = self.pipe_buffers['outputs'][buffer_id]
        # XXX these shapes are hardcoded for Megatron
        # Restore partitioned output if it was partitioned and we are sending full gradients
        if self.is_pipe_partitioned and not self.is_grad_partitioned:
            part_output = PartitionedTensor.from_meta(meta=outputs[0],
                                                      local_part=outputs[1],
                                                      group=self.grid.get_slice_parallel_group())
            outputs[0].data = part_output.full()
            outputs = (outputs[0], *outputs[2:])
            # save for backward
            self.pipe_buffers['outputs'][buffer_id] = outputs

        # Allocate gradient if necessary
        if self.grad_layer is None:
            if isinstance(outputs, torch.Tensor):
                s = list(outputs.size())
                self.grad_layer = self._allocate_buffer(s, dtype=outputs.dtype, num_buffers=1)[0]
            else:
                # XXX This is a HACK
                # When we exchange activations/gradients, the two pipe stages
                # need to issue the send/recv with the same buffer sizes or
                # else there is a deadlock. The is_floating_point() filter is
                # used to avoid sending gradients for tensors that do not
                # produce gradients. When TP>1, we partition the first
                # activations/gradients across TP ranks to save communication
                # volume and memory. That partitioned tensor is represented as
                # two tensors: a 1/TPth chunk of the original data and also a
                # small LongTensor storing the metadata used to reconstruct on
                # the other side. When combined, the floating point filter also
                # filtered out the metadata tensor. This quick (hacky) fix just
                # branches on is_grad_partitioned so we don't filter out the
                # metadata tensor.
                if self.is_grad_partitioned:
                    sizes_and_dtypes = [(list(t.size()), t.dtype)
                                        for t in outputs[:2]] + [(list(t.size()), t.dtype)
                                                                 for t in outputs[2:] if t.is_floating_point()]
                else:
                    sizes_and_dtypes = [(list(t.size()), t.dtype) for t in outputs if t.is_floating_point()]
                self.grad_layer = self._allocate_buffers(sizes_and_dtypes, num_buffers=1)[0]

        if isinstance(self.grad_layer, torch.Tensor):
            p2p.recv(self.grad_layer, self.next_stage)
        else:
            assert isinstance(outputs, tuple)
            for idx, buffer in enumerate(self.grad_layer):
                # XXX GPT-2 hack
                if self.is_grad_partitioned and idx == 0 and buffer.dtype != torch.long:
                    buffer.data = torch.zeros(buffer.size(), dtype=torch.long, device=self.device)
                p2p.recv(buffer, self.next_stage)

        if self.wall_clock_breakdown():
            self.timers(PIPE_RECV_GRAD_TIMER).stop()

    def _exec_optimizer_step(self, lr_kwargs=None):
        if self.wall_clock_breakdown():
            self.timers(STEP_MICRO_TIMER).start()
            self.timers(STEP_GLOBAL_TIMER).start()
        self.mem_status('BEFORE STEP', reset_max=True)

        self._force_grad_boundary = True
        self._take_model_step(lr_kwargs)
        self._force_grad_boundary = False

        self.mem_status('AFTER STEP')

        if self.global_rank == 0 and self.monitor.enabled:
            self.summary_events = [('Train/Samples/lr', self.get_lr()[0], self.global_samples)]
            if self.fp16_enabled() and hasattr(self.optimizer, 'cur_scale'):
                self.summary_events.append(
                    ('Train/Samples/loss_scale', self.optimizer.cur_scale, self.global_samples))
            self.monitor.write_events(self.summary_events)

        if self.wall_clock_breakdown():
            self.timers(STEP_MICRO_TIMER).stop()
            self.timers(STEP_GLOBAL_TIMER).stop()
            if self.global_steps % self.steps_per_print() == 0:
                self.timers.log([
                    BATCH_INPUT_TIMER,
                    FORWARD_MICRO_TIMER,
                    BACKWARD_MICRO_TIMER,
                    BACKWARD_INNER_MICRO_TIMER,
                    BACKWARD_REDUCE_MICRO_TIMER,
                    STEP_MICRO_TIMER,
                ])
            if self.global_steps % self.steps_per_print() == 0:
                self.timers.log([
                    FORWARD_GLOBAL_TIMER,
                    BACKWARD_GLOBAL_TIMER,
                    BACKWARD_INNER_GLOBAL_TIMER,
                    BACKWARD_REDUCE_GLOBAL_TIMER,
                    STEP_GLOBAL_TIMER,
                ])

    def _zero_grads(self, inputs):
        if isinstance(inputs, torch.Tensor):
            if inputs.grad is not None:
                inputs.grad.data.zero_()
        else:
            for t in inputs:
                if t.grad is not None:
                    t.grad.data.zero_()

    def _allocate_zeros(self, shape, **kwargs):
        """ Allocate a tensor of zeros on the engine's device.

        Arguments:
            shape: the shape of the tensor to allocate
            kwargs: passed to torch.zeros()

        Returns:
            A tensor from torch.zeros() allocated on self.device.
        """
        if "dtype" not in kwargs:
            if self.fp16_enabled():
                kwargs["dtype"] = torch.half
            if self.bfloat16_enabled():
                kwargs["dtype"] = torch.bfloat16

        return torch.zeros(shape, device=self.device, **kwargs)
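    # Illustrative sketch (not engine code): how _allocate_zeros resolves its dtype.
    # Assuming an engine configured with bf16 enabled, a call such as
    #
    #     buf = self._allocate_zeros([4, 1024])
    #
    # yields a torch.bfloat16 tensor of zeros on self.device, while an explicit
    # ``dtype=torch.float32`` keyword bypasses the fp16/bf16 defaults above.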

    def _allocate_buffer(self, shape, num_buffers=-1, **kwargs):
        buffers = []
        if num_buffers == -1:
            num_buffers = self.num_pipe_buffers
        for count in range(num_buffers):
            buffers.append(self._allocate_zeros(shape, **kwargs))
        return buffers

    def _allocate_buffers(self, shapes_and_dtypes, requires_grad=False, num_buffers=-1):
        buffers = []
        if num_buffers == -1:
            num_buffers = self.num_pipe_buffers
        for count in range(num_buffers):
            buffer = []
            for shape, dtype in shapes_and_dtypes:
                buffer.append(self._allocate_zeros(shape, dtype=dtype, requires_grad=requires_grad))
            buffers.append(buffer)
        return buffers
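    # Minimal sketch of how the gradient receive path uses this helper (the shapes
    # below are made up for illustration; _exec_recv_grads derives sizes_and_dtypes
    # from the outputs previously sent to the next stage):
    #
    #     sizes_and_dtypes = [([16, 1024], torch.float16), ([16, 1024], torch.float16)]
    #     grad_buffers = self._allocate_buffers(sizes_and_dtypes, num_buffers=1)[0]
    #     # grad_buffers is a list of zeroed tensors, one per (shape, dtype) pair,
    #     # later filled in place by p2p.recv() from the next stage.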

    def forward(self, *args, **kwargs):
        """Disabled for pipeline parallel training. See ``train_batch()``. """
        raise PipelineError("Only train_batch() is accessible in pipeline mode.")

    def backward(self, *args, **kwargs):
        """Disabled for pipeline parallel training. See ``train_batch()``. """
        raise PipelineError("Only train_batch() is accessible in pipeline mode.")

    def step(self, *args, **kwargs):
        """Disabled for pipeline parallel training. See ``train_batch()``. """
        raise PipelineError("Only train_batch() is accessible in pipeline mode.")
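    # Hedged usage sketch: with pipeline parallelism, the training loop drives the
    # engine through train_batch()/eval_batch() rather than forward/backward/step.
    # ``engine`` and ``train_iter`` are hypothetical objects created elsewhere via
    # deepspeed.initialize() and a data loader:
    #
    #     for _ in range(num_steps):
    #         loss = engine.train_batch(data_iter=train_iter)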

    def mem_status(self, msg, print_rank=-1, reset_max=False):
        # Memory reporting is currently disabled; this early return skips the
        # instrumentation below.
        return
        global mem_alloced, mem_cached
        if not self.global_steps == 0 or not self.global_steps == 9:
            #return
            pass
        if self.mpu.get_data_parallel_rank() != 0:
            return

        if self.global_rank != 0:
            return

        rank = self.global_rank
        if print_rank != -1 and rank != print_rank:
            return

        get_accelerator().synchronize()

        if reset_max:
            get_accelerator().reset_max_memory_cached()
            get_accelerator().reset_max_memory_allocated()

        new_alloced = get_accelerator().memory_allocated()
        new_cached = get_accelerator().memory_cached()

        delta_alloced = new_alloced - mem_alloced
        delta_cached = new_cached - mem_cached

        mem_cached = new_cached
        mem_alloced = new_alloced

        max_alloced = get_accelerator().max_memory_allocated()
        max_cached = get_accelerator().max_memory_cached()

        # convert to GB for printing
        new_alloced /= 1024**3
        new_cached /= 1024**3
        delta_alloced /= 1024**3
        delta_cached /= 1024**3
        max_alloced /= 1024**3
        max_cached /= 1024**3

        print(
            f'RANK={rank} STAGE={self.stage_id} STEP={self.global_steps} MEMSTATS', msg,
            f'current alloc={new_alloced:0.4f}GB (delta={delta_alloced:0.4f}GB max={max_alloced:0.4f}GB) '
            f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)')

    def module_state_dict(self, exclude_frozen_parameters=False):
        """Override hack to save a pipe model and return the directory path of the save.

        This method should only be called by DeepSpeed's ``save_checkpoint()``. The
        recommended way of saving a ``PipelineModule`` outside of ``save_checkpoint()``
        is ``save_state_dict()``.

        Returns:
            None
        """
        assert isinstance(self.module, PipelineModule)
        assert self._curr_ckpt_path is not None, \
            "PipelineEngine expects module_state_dict() to be called from save_checkpoint()"

        self.module.save_state_dict(self._curr_ckpt_path,
                                    checkpoint_engine=self.checkpoint_engine,
                                    exclude_frozen_params=exclude_frozen_parameters)
        return None
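    # Illustrative note: this method is reached indirectly. A caller invokes, e.g.,
    #
    #     engine.save_checkpoint(save_dir, tag="global_step1000")
    #
    # and the checkpoint path recorded there becomes self._curr_ckpt_path, so the
    # per-layer files land alongside the rank-level checkpoint data.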

    def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None, fetch_z3_params=False):
        """Override hack to instead use a directory path.

        This is important because pipeline models checkpoint by layer instead of rank.

        If ``state_dict`` is neither ``None`` nor a ``str``, we revert to ``super()``, which expects a ``dict``.

        Args:
            checkpoint (dict): loaded checkpoint; its ``'module'`` entry is used as the state dict.
            strict (bool, optional): Strict state loading. Defaults to True.
        """
        assert custom_load_fn is None, "custom_load_fn not supported w. pipeline parallelism"
        state_dict = checkpoint['module']
        if (state_dict is not None) and (not isinstance(state_dict, str)):
            super().load_module_state_dict(state_dict, strict)
            return

        self.module.load_state_dir(load_dir=self._curr_ckpt_path,
                                   strict=strict,
                                   checkpoint_engine=self.checkpoint_engine)
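    # Hedged usage sketch: loading likewise goes through the engine-level API, e.g.
    #
    #     load_path, client_state = engine.load_checkpoint(load_dir, tag="global_step1000")
    #
    # which ends up here and delegates the per-layer files to
    # PipelineModule.load_state_dir().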

    # A map of PipeInstruction types to methods. Each method will be executed with the
    # kwargs provided to the PipeInstruction from the scheduler.
    _INSTRUCTION_MAP = {
        schedule.OptimizerStep: _exec_optimizer_step,
        schedule.ReduceGrads: _exec_reduce_grads,
        schedule.ReduceTiedGrads: _exec_reduce_tied_grads,
        schedule.LoadMicroBatch: _exec_load_micro_batch,
        schedule.ForwardPass: _exec_forward_pass,
        schedule.BackwardPass: _exec_backward_pass,
        schedule.SendActivation: _exec_send_activations,
        schedule.RecvActivation: _exec_recv_activations,
        schedule.SendGrad: _exec_send_grads,
        schedule.RecvGrad: _exec_recv_grads,
    }

    def _exec_schedule(self, pipe_schedule):
        # Reserve and reset buffers.
        self._reserve_pipe_buffers(pipe_schedule.num_pipe_buffers())
        self.fwd_outputs = []

        # For each step in the schedule
        for step_cmds in pipe_schedule:
            # For each instruction in the step
            for cmd in step_cmds:
                if type(cmd) not in self._INSTRUCTION_MAP:
                    raise RuntimeError(f'{self.__class__.__name__} does not understand instruction {repr(cmd)}')

                # Equivalent to: self._exec_forward_pass(buffer_id=0)
                self._exec_instr = MethodType(self._INSTRUCTION_MAP[type(cmd)], self)
                self._exec_instr(**cmd.kwargs)
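    # Minimal sketch of how a schedule drives this dispatcher. train_batch() builds a
    # schedule.TrainSchedule and hands it to _exec_schedule(); the values below mirror
    # that call:
    #
    #     sched = schedule.TrainSchedule(micro_batches=self.micro_batches,
    #                                    stages=self.num_stages,
    #                                    stage_id=self.stage_id)
    #     self._exec_schedule(sched)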