#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import framework

from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
from ..utils import timer_helper as timer
from ..utils.hybrid_parallel_util import (
    broadcast_dp_parameters,
    broadcast_mp_parameters,
    broadcast_sharding_parameters,
)
from ..utils.log_util import logger
from .meta_parallel_base import MetaParallelBase
from .parallel_layers.pp_layers import PipelineLayer
from .pp_utils import p2p_communication as p2p
from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size

__all__ = []


# assume only the first stage and last stage need data, and data consumption is ordered;
# to be replaced by a real micro dataset from the reader
class FakeMicroDataset:
    def __init__(
        self, data, is_first_stage, is_last_stage, acc_steps, micro_batch_size
    ):
        self._data = data
        self._index = 0
        self._acc_steps = acc_steps
        self._is_first_stage = is_first_stage
        self._is_last_stage = is_last_stage
        self._micro_batch_size = micro_batch_size

    def __iter__(self):
        return self

    def __next__(self):
        assert self._index < self._acc_steps
        assert self._is_first_stage or self._is_last_stage
        micro_batch_data = self._load_micro_batch(self._index)
        self._index += 1
        return micro_batch_data

    def _load_micro_batch(self, micro_step):
        inputs = self._data

        data = None
        label = None
        if self._is_first_stage:
            assert len(inputs) == 2, "length of input should be 2"
            data = self._load_micro_batch_impl(inputs[0], micro_step)

        if self._is_last_stage:
            assert len(inputs) == 2, "length of input should be 2"
            label = self._load_micro_batch_impl(inputs[1], micro_step)

        return (data, label)

    def _load_micro_batch_impl(self, inputs, micro_step):
        begin = micro_step * self._micro_batch_size
        end = begin + self._micro_batch_size
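        # Illustrative slicing (example values, not from the original source):
        # with micro_batch_size=2, micro_step=3 selects rows [6:8) of each
        # full-batch tensor below.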

        if isinstance(inputs, tuple):
            output = []
            for data in inputs:
                if isinstance(data, list):
                    assert (
                        len(data) == self._acc_steps
                    ), "length of data should be %d, but it is %d" % (
                        self._acc_steps,
                        len(data),
                    )
                    output.append(data[micro_step].detach())
                elif data is not None:
                    self._check_data_valid(data)
                    output.append(data[begin:end, :].detach())
                else:
                    output.append(None)
            return tuple(output)

        elif isinstance(inputs, list):
            assert (
                len(inputs) == self._acc_steps
            ), "length of data should be %d, but it is %d" % (
                self._acc_steps,
                len(inputs),
            )
            return inputs[micro_step].detach()
        elif inputs is not None:
            self._check_data_valid(inputs)
            return inputs[begin:end, :].detach()
        else:
            return None

    def _check_data_valid(self, data):
        batch_size = data.shape[0]
        assert self._micro_batch_size * self._acc_steps == batch_size, (
            "batch_size must equal micro_batch_size * accumulate_steps. Currently, "
            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
            % (batch_size, self._micro_batch_size, self._acc_steps)
        )


class PipelineParallel(MetaParallelBase):
    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer."
            )
        super().__init__(layers, hcg, strategy)
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
        self.use_sharding_parallel = (
            self._hcg.get_sharding_parallel_world_size() > 1
        )

        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size'
        ]
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps'
        ]
        # If the tensors to be sent differ across hosts, they should not be
        # sent partially and then concatenated into a whole tensor.
        self._enable_partial_send_recv = self._strategy.pipeline_configs[
            'enable_partial_send_recv'
        ]
        self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.pp_group = self._hcg.get_pipe_parallel_group()
        self.dp_group = self._hcg.get_data_parallel_group()
        self.sharding_group = self._hcg.get_sharding_parallel_group()

        self._virtual_pp_world_size = None
        self._virtual_pp_rank = None
        self._real_pp_world_size = self.num_stages
        self._real_pp_rank = self.stage_id

        self._delay_scale_loss = self._strategy.hybrid_configs[
            "pp_configs"
        ].delay_scale_loss
        # TODO(PP Dev): support dp_comm_overlap without use_main_grad training.
        # This combination triggers an inplace check error during `reshape_` in
        # the function `_split_tensors`.
        self._dp_comm_overlap = self._strategy.hybrid_configs[
            "pp_configs"
        ].dp_comm_overlap
        self._sharding_comm_overlap = self._strategy.hybrid_configs[
            "pp_configs"
        ].sharding_comm_overlap
        self._enable_timer = self._strategy.hybrid_configs[
            "pp_configs"
        ].enable_timer

        if self._dp_comm_overlap:
            assert self.use_data_parallel and self.num_stages > 1

        if self._sharding_comm_overlap:
            assert self.use_sharding_parallel and self.num_stages > 1

        assert not (
            self._dp_comm_overlap and self._sharding_comm_overlap
        ), "Cannot use dp pp overlap and sharding pp overlap at the same time."

        self._comm_buffers = []
        self._comm_overlap = (
            self._dp_comm_overlap or self._sharding_comm_overlap
        )

        if self._enable_timer:
            if not timer.is_timer_initialized():
                timer.set_timers()
            self.timers = timer.get_timers()

        p2p.initialize_p2p_groups(
            hcg,
            self._using_cache,
            self._enable_partial_send_recv,
            self._enable_timer,
        )

        self.global_rank = self._hcg.get_global_rank()
        self.micro_batch_id = 0

        self._compute_loss = True

        logger.info(
            "Pipeline Info -- num_stages: {}, stage_id: {}".format(
                self.num_stages, self.stage_id
            )
        )

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_sharding_parallel:
            logger.info("start broadcast sharding parameters")
            broadcast_sharding_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

        if self._dp_comm_overlap:
            self.register_allreduce_overlap_hook(
                self._layers, self.dp_group, self.accumulate_steps, True
            )

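    # A rank counts as the first/last pipeline stage only if its real pp rank is
    # at the boundary and, when interleaving is enabled, its current virtual
    # chunk is as well; pass ignore_virtual=True to check the real pp rank only.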
    def is_pipeline_first_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != 0:
                    return False
        assert self._real_pp_rank is not None
        return self._real_pp_rank == 0

    def is_pipeline_last_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != (self._virtual_pp_world_size - 1):
                    return False
        assert self._real_pp_rank is not None
        assert self._real_pp_world_size is not None
        return self._real_pp_rank == (self._real_pp_world_size - 1)

    def set_virtual_pipeline_rank(self, rank):
        self._virtual_pp_rank = rank

    def bw_hook_func(self, buffer, param):
        @paddle.autograd.no_grad()
        def fused_allreduce(*_):
            buffer.add_grad(param)

        return fused_allreduce

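    # register_allreduce_overlap_hook groups the trainable parameters into
    # FusedCommBuffers (via assign_group_by_size) and registers a backward hook
    # on each parameter, so gradient communication (all-reduce for dp, reduce to
    # the owning rank for sharding) can overlap with the rest of the backward pass.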
    def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
        if model.get_num_virtual_stages() > 1:
            models = model.get_model_chunks()
        else:
            models = [model]

        if not dp:
            assert hasattr(self, "optimizer")
            assert hasattr(self.optimizer, "_param2rank")
            _param2rank = self.optimizer._param2rank

        act = HOOK_ACTION.ALL_REDUCE if dp else HOOK_ACTION.REDUCE

        for model in models:
            # For virtual pipeline. Separate parameters of different chunks into
            # different groups to get the best performance.

            fused_parameter_group = {}

            parameter_list = [
                p for p in model.parameters() if not p.stop_gradient
            ]
            if len(parameter_list) < 1:
                return

            if dp:
                fused_parameter_group[-1] = parameter_list
            else:
                # Sort parameters for sharding, since they have different dst rank
                for p in parameter_list:
                    assert p.name in _param2rank
                    dst_rank = _param2rank[p.name]
                    if dst_rank in fused_parameter_group:
                        fused_parameter_group[dst_rank].append(p)
                    else:
                        fused_parameter_group[dst_rank] = [p]

            for dst in fused_parameter_group:
                parameter_list = fused_parameter_group[dst]
                if not dp:
                    # parse the relative dst rank to absolute dst rank for sharding
                    dst = comm_group.ranks[dst]
                var_groups = assign_group_by_size(parameter_list)
                for group_idx, parameters in var_groups.items():
                    buffer = FusedCommBuffer(
                        group_idx, parameters, comm_group, acc_steps, act, dst
                    )
                    self._comm_buffers.append(buffer)
                    for param in parameters:
                        param._register_backward_hook(
                            self.bw_hook_func(buffer, param)
                        )

    def timer_printer(self):
        if not self._enable_timer:
            return
        all_flag_names = self.timers.timers.keys()
        self.timers.log(all_flag_names)

    def forward_backward_pipeline(self, data, scaler=None):
        # use the 1f1b scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py

        self.scaler = scaler

        # store total loss of entire batch
        self.total_loss = None

        # store data id for micro_batch
        self.micro_batch_id = 0

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps
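        # Schedule overview: `startup_steps` warm-up forwards, a steady 1F1B
        # phase alternating one forward with one backward, then a cool-down
        # phase draining the remaining `startup_steps` backwards.
        # Illustrative numbers (not from the original source): with
        # num_stages=4 and accumulate_steps=8, stage 0 runs 3 warm-up forwards
        # and 5 steady steps, while the last stage runs 0 warm-up forwards and
        # 8 steady steps.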

        input_buffers = []
        output_buffers = []

        micro_dataset = self._wrap_data(data)

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor, micro_dataset)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not self.is_pipeline_last_stage():
                self._release_output(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor, micro_dataset)

            output_tensor_grad = p2p.send_forward_recv_backward(
                output_tensor, self.is_pipeline_last_stage()
            )

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not self.is_pipeline_last_stage():
                self._release_output(output_tensor)

            input_tensor, output_tensor = input_buffers.pop(
                0
            ), output_buffers.pop(0)

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )

            if last_iter:
                input_tensor = None
                p2p.send_backward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )
            else:
                input_tensor = p2p.send_backward_recv_forward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )

        for i in range(startup_steps):
            input_tensor = input_buffers.pop(0)
            output_tensor = output_buffers.pop(0)

            output_tensor_grad = p2p.recv_backward(
                self.is_pipeline_last_stage()
            )

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )
            p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())

        if self._comm_overlap:
            assert len(self._comm_buffers) > 0
            for buffer in self._comm_buffers:
                buffer.scale_and_split_grads()

        if self._enable_timer:
            self.timers("allreduce_shared_weight_gradients").start()
        self._layers.allreduce_shared_weight_gradients()
        if self._enable_timer:
            self.timers("allreduce_shared_weight_gradients").stop()
            self.timers("broadcast_final_loss").start()
        with paddle.amp.auto_cast(enable=False):
            train_loss = self._broadcast_final_loss()
        if self._enable_timer:
            self.timers("broadcast_final_loss").stop()
        self.timer_printer()

        return train_loss

    def _prepare_training(self, data, optimizer, lr_scheduler):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        assert isinstance(
            optimizer, HybridParallelOptimizer
        ), 'optimizer should be HybridParallelOptimizer subclass.'

        assert (
            framework._dygraph_tracer()._has_grad
        ), 'Please enable the generation of gradients.'

        if self.is_pipeline_first_stage(
            ignore_virtual=True
        ) or self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                data is not None
            ), "For the first and the last stage, the data must be set."
        else:
            data = None

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self._layers.train()

        if self._sharding_comm_overlap and len(self._comm_buffers) == 0:
            self.register_allreduce_overlap_hook(
                self._layers, self.sharding_group, self.accumulate_steps, False
            )

        return data

    def _wrap_data(self, data):
        """
        For backward compatibility, wrap data into a FakeMicroDataset if it is of type list or tuple.
        """
        if (not isinstance(data, tuple)) and (not isinstance(data, list)):
            return data

        micro_dataset = FakeMicroDataset(
            data,
            self.is_pipeline_first_stage(ignore_virtual=True),
            self.is_pipeline_last_stage(ignore_virtual=True),
            self.accumulate_steps,
            self.micro_batch_size,
        )
        return micro_dataset

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # 1f1b scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        # store data id for micro_batch
        self.micro_batch_id = 0

        # store total loss of entire batch
        self.total_loss = None

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

        input_buffers = []
        output_buffers = []

        micro_dataset = self._wrap_data(data)

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor, micro_dataset)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor, micro_dataset)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not last_iter:
                input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        if self._compute_loss:
            self.train_loss = self._broadcast_final_loss()
        else:
            self.train_loss = output_buffers

        return self.train_loss

    def _forward_step(self, input_tensor, micro_dataset, chunk_id=None):
        if self._enable_timer:
            self.timers("forward_step").start()
        if self.is_pipeline_first_stage():
            input_tensor = next(micro_dataset)[0]
            self._check_micro_batch_data_valid(input_tensor)

        assert chunk_id is None or isinstance(chunk_id, int)

        output_tensor = self._layers.forward(input_tensor, chunk_id=chunk_id)

        if self.is_pipeline_last_stage():
            # calculate loss for training
            if self._compute_loss:
                assert (
                    self._layers._loss_fn is not None
                ), "loss function should exist to compute loss"
                labels = next(micro_dataset)[1]
                self._check_micro_batch_data_valid(labels)
                output_tensor = self._layers._loss_fn(output_tensor, labels)
                assert isinstance(
                    output_tensor, (paddle.Tensor, framework.core.eager.Tensor)
                ), "Currently, loss_fn should return a paddle.Tensor"

                with paddle.amp.auto_cast(enable=False):
                    if self.accumulate_steps > 1 and not self._delay_scale_loss:
                        output_tensor = output_tensor / self.accumulate_steps

                    if self.total_loss is None:
                        self.total_loss = paddle.zeros_like(output_tensor)
                    self.total_loss += output_tensor.detach()

        if self.is_pipeline_first_stage() or self.is_pipeline_last_stage():
            # Only increase the micro batch id at the virtual first/last pp stage.
            # The micro batch id is used to load data, so only increase it when data is loaded.
            self.micro_batch_id += 1
        if self._enable_timer:
            self.timers("forward_step").stop()
        return output_tensor

    def _check_micro_batch_data_valid(self, micro_batch_data):
        if isinstance(micro_batch_data, (tuple, list)):
            for data in micro_batch_data:
                self._check_micro_batch_data_valid(data)
        elif micro_batch_data is not None:
            micro_batch_size = micro_batch_data.shape[0]
            assert (
                micro_batch_size == self.micro_batch_size
            ), f"expected micro_batch_size {self.micro_batch_size} but got {micro_batch_size}"

    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
        if self._enable_timer:
            self.timers("backward_step").start()
        with paddle.amp.auto_cast(enable=False):
            if self.is_pipeline_last_stage():
                assert output_tensor_grad is None
                if self.scaler:
                    paddle.autograd.backward(self.scaler.scale(output_tensor))
                else:
                    paddle.autograd.backward(output_tensor)
            else:
                if isinstance(output_tensor, tuple):
                    outputs = [t for t in output_tensor if not t.stop_gradient]
                    assert len(outputs) == len(output_tensor_grad)
                    paddle.autograd.backward(
                        tensors=outputs,
                        grad_tensors=list(output_tensor_grad),
                    )
                else:
                    paddle.autograd.backward(
                        tensors=[output_tensor],
                        grad_tensors=[output_tensor_grad],
                    )

            input_tensor_grad = None
            if input_tensor is not None:
                if isinstance(input_tensor, tuple):
                    input_tensor_grad = tuple(
                        [t.grad for t in input_tensor if not t.stop_gradient]
                    )
                else:
                    input_tensor_grad = input_tensor.grad
            if self._enable_timer:
                self.timers("backward_step").stop()
            return input_tensor_grad

    def _broadcast_final_loss(self):
        # Since the last backward run in interleave will set the virtual rank to 0,
        # here we need to check the last stage ignoring the virtual stage.
        if self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                self.total_loss is not None
            ), "train_batch() in last stage should obtain valid loss"
            loss = (
                self.total_loss.detach()
                if not self._delay_scale_loss
                else self.total_loss / self.accumulate_steps
            )
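            # Broadcast the dtype flag first so that receiving stages can
            # allocate a loss buffer of the matching dtype (float32 vs float16)
            # before the loss value itself is broadcast.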
            is_fp32 = (
                paddle.full([], 1, 'int64')
                if loss.dtype == paddle.float32
                else paddle.full([], 0, 'int64')
            )
            paddle.distributed.broadcast(
                is_fp32, src=self.global_rank, sync_op=True, group=self.pp_group
            )
            paddle.distributed.broadcast(
                loss, src=self.global_rank, sync_op=True, group=self.pp_group
            )
        else:
            is_fp32 = paddle.full([], 1, 'int64')
            paddle.distributed.broadcast(
                is_fp32,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
            loss = (
                paddle.zeros(shape=[1], dtype="float32")
                if is_fp32.item()
                else paddle.zeros(shape=[1], dtype="float16")
            )
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
        return loss

    def _optimizer_step(self):
        if self._delay_scale_loss:
            for p in self._layers.parameters():
                if hasattr(p, "main_grad") and p.main_grad is not None:
                    assert p.grad is None
                    p.main_grad = p.main_grad.scale(1.0 / self.accumulate_steps)
                elif p.grad is not None:
                    p.grad = p.grad.scale(1.0 / self.accumulate_steps)

        if self.scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()

    def _release_output(self, output):
        def can_free(t):
            return (
                t is not None
                and isinstance(t, paddle.Tensor)
                and t._is_initialized()
                and t.inplace_version == 0
            )

        if isinstance(output, (tuple, list)):
            for t in output:
                if can_free(t):
                    t._clear_dataptr()

        elif can_free(output):
            output._clear_dataptr()


class PipelineParallelWithInterleave(PipelineParallel):
    # pipeline parallel with interleave scheduler

    def __init__(self, layers, hcg, strategy):
        super().__init__(layers=layers, hcg=hcg, strategy=strategy)
        assert layers.get_num_virtual_stages() > 1
        assert (
            framework.in_dygraph_mode()
        ), "virtual pipeline stage with interleave only supports eager dygraph mode"
        assert (
            self.num_stages > 2
        ), "virtual pipeline must run under pp degree > 2"
        assert (
            self.accumulate_steps % self.num_stages == 0
        ), "accumulate_steps should be evenly divisible by num_stages for pipeline with interleave"
        # setup for interleave scheduler
        self.num_model_chunks = layers.get_num_virtual_stages()
        self.model_chunks = layers.get_model_chunks()
        assert self.model_chunks is not None
        assert len(self.model_chunks) == self.num_model_chunks
        self._virtual_pp_world_size = self.num_model_chunks
        self._virtual_pp_rank = 0
        self._assign_vpp_info(self.model_chunks)

    def _assign_vpp_info(self, chunks):
        chunk_num = len(chunks)
        for i, chunk in enumerate(chunks):
            for p in chunk.parameters():
                p._chunk_info = {"chunk_id": i, "chunk_num": chunk_num}

    def _get_virtual_pp_rank(self, micro_step, forward):
        virtual_pp_stage = micro_step % (
            self.num_stages * self.num_model_chunks
        )
        virtual_pp_stage = virtual_pp_stage // self.num_stages
        if not forward:
            virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
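        # Illustrative mapping (example values, not from the original source):
        # with num_stages=4 and num_model_chunks=2, forward micro steps 0-3 run
        # chunk 0, steps 4-7 run chunk 1, and the pattern repeats every
        # num_stages * num_model_chunks steps; the backward direction visits
        # the chunks in reverse order.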
        return virtual_pp_stage

    def _forward_step_helper(self, micro_dataset, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        if not self._forward_only:
            assert hasattr(self, 'output_tensor_grads')

        if self.is_pipeline_first_stage():
            if len(self.input_tensors[virtual_pp_rank]) == len(
                self.output_tensors[virtual_pp_rank]
            ):
                self.input_tensors[virtual_pp_rank].append(None)
        input_tensor = self.input_tensors[virtual_pp_rank][-1]
        output_tensor = self._forward_step(
            input_tensor, micro_dataset, virtual_pp_rank
        )
        self.output_tensors[virtual_pp_rank].append(output_tensor)

        if self._forward_only:
            # no need to store tensor for backward
            self.input_tensors[virtual_pp_rank].pop()
            self.output_tensors[virtual_pp_rank].pop()

        return output_tensor

    def _backward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        assert hasattr(self, 'output_tensor_grads')

        if self.is_pipeline_last_stage():
            if len(self.output_tensor_grads[virtual_pp_rank]) == 0:
                self.output_tensor_grads[virtual_pp_rank].append(None)

        input_tensor = self.input_tensors[virtual_pp_rank].pop(0)
        output_tensor = self.output_tensors[virtual_pp_rank].pop(0)
        output_tensor_grad = self.output_tensor_grads[virtual_pp_rank].pop(0)
        input_tensor_grad = self._backward_step(
            input_tensor, output_tensor, output_tensor_grad
        )

        return input_tensor_grad

    def forward_backward_pipeline(
        self, data, scaler, forward_only=False, compute_loss=True
    ):
        # use interleave scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
        if not compute_loss:
            assert (
                forward_only
            ), "compute_loss can only be set to False when forward_only is set to True"

        # init some attributes for this batch run
        self.scaler = scaler
        self.total_loss = None
        self.micro_batch_id = 0
        self._forward_only = forward_only

        # init some data buffers for interleave scheduler
        self.input_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]
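        # Each virtual model chunk keeps its own FIFO of inputs, outputs and
        # output grads; the forward/backward step helpers append to and pop
        # from these lists in micro-step order.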

        micro_dataset = self._wrap_data(data)

        num_steps = self.accumulate_steps * self.num_model_chunks
        all_startup_steps = False
        if forward_only:
            # If only forward, since there is no backward during running, all steps are startup steps
            startup_steps = num_steps
        else:
            if self.accumulate_steps == self.num_stages:
                startup_steps = num_steps
                all_startup_steps = True
            else:
                startup_steps = (self.num_stages - self.stage_id - 1) * 2
                startup_steps += (self.num_model_chunks - 1) * self.num_stages
                startup_steps = min(startup_steps, num_steps)

        steady_steps = num_steps - startup_steps
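        # Illustrative numbers (example values, not from the original source):
        # with num_stages=4, num_model_chunks=2 and accumulate_steps=8
        # (so num_steps=16), stage 0 gets startup_steps = 3 * 2 + 1 * 4 = 10
        # and steady_steps = 6.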

        self.set_virtual_pipeline_rank(0)
        self.input_tensors[0].append(
            p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False)
        )

        # run startup steps
        for micro_step in range(startup_steps):
            output_tensor = self._forward_step_helper(micro_dataset, micro_step)

            # determine whether to recv a forward tensor or not
            next_virtual_pp_rank = self._get_virtual_pp_rank(
                micro_step + 1, forward=True
            )
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                if next_virtual_pp_rank == 0:
                    # next chunk is the first chunk, no need to pre-recv an input tensor
                    recv_prev = False
            # last micro step, no next run
            if micro_step == (num_steps - 1):
                recv_prev = False

            # last stage shouldn't send tensor to downstream
            if self.is_pipeline_last_stage():
                output_tensor = None

            if (
                micro_step == (startup_steps - 1)
                and not forward_only
                and not all_startup_steps
            ):
                input_tensor_grad = None
                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False

                # the last startup step needs four-direction comm to set up for the steady 1f1b phase
                (
                    input_tensor,
                    output_tensor_grad,
                ) = p2p.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next,
                )
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    output_tensor_grad
                )
            else:
                input_tensor = p2p.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev
                )
            self.input_tensors[next_virtual_pp_rank].append(input_tensor)

            self._release_output(output_tensor)

        # run 1f1b steady steps
        for micro_step in range(steady_steps):
            # forward
            forward_micro_step_id = micro_step + startup_steps
            output_tensor = self._forward_step_helper(
                micro_dataset, forward_micro_step_id
            )

            # backward
            backward_micro_step_id = micro_step
            input_tensor_grad = self._backward_step_helper(
                backward_micro_step_id
            )

            # four directions comm
            # send output tensor to downstream
            # send input tensor grad to upstream
            # recv input tensor from upstream
            # recv output tensor grad from downstream

            # last stage doesn't send its result to downstream
            forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id, forward=True
            )
            self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
            if self.is_pipeline_last_stage():
                output_tensor = None

            # first stage doesn't send grad to upstream
            backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id, forward=False
            )
            self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
            if self.is_pipeline_first_stage():
                input_tensor_grad = None

            # determine whether to recv input tensor from upstream
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id - (self.num_stages - 1), forward=True
                )
                if next_forward_virtual_pp_rank == (self.num_model_chunks - 1):
                    # first pp stage and first virtual stage
                    recv_prev = False
                next_forward_virtual_pp_rank += 1
            else:
                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id + 1, forward=True
                )

            # last iteration doesn't need recv from upstream
            if micro_step == (steady_steps - 1):
                recv_prev = False

            # determine whether to recv grad from downstream
            recv_next = True
            if self.is_pipeline_last_stage(ignore_virtual=True):
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id - (self.num_stages - 1),
                    forward=False,
                )
                if next_backward_virtual_pp_rank == 0:
                    # last pp stage and last virtual stage
                    recv_next = False
                next_backward_virtual_pp_rank -= 1
            else:
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id + 1, forward=False
                )

            (
                input_tensor,
                output_tensor_grad,
            ) = p2p.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
            )

            self._release_output(output_tensor)

            if recv_prev:
                self.input_tensors[next_forward_virtual_pp_rank].append(
                    input_tensor
                )
            if recv_next:
                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    output_tensor_grad
                )

        self._release_output(output_tensor)

        # remaining backward steps
        if not forward_only:
            if all_startup_steps:
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    p2p.recv_backward(
                        self.is_pipeline_last_stage(), sync_recv=False
                    )
                )

            for micro_step in range(steady_steps, num_steps):
                # cooldown loop
                input_tensor_grad = self._backward_step_helper(micro_step)
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    micro_step + 1, forward=False
                )

                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    if next_backward_virtual_pp_rank == (
                        self.num_model_chunks - 1
                    ):
                        recv_next = False

                if micro_step == (num_steps - 1):
                    recv_next = False

                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    p2p.send_backward_recv_backward(
                        input_tensor_grad, recv_next=recv_next
                    )
                )

            if self._comm_overlap:
                assert len(self._comm_buffers) > 0
                for buffer in self._comm_buffers:
                    buffer.scale_and_split_grads()

            if self._enable_timer:
                self.timers("allreduce_shared_weight_gradients").start()
            self._layers.allreduce_shared_weight_gradients()
            if self._enable_timer:
                self.timers("allreduce_shared_weight_gradients").stop()

        if compute_loss:
            # return loss if compute loss
            if self._enable_timer:
                self.timers("broadcast_final_loss").start()
            with paddle.amp.auto_cast(enable=False):
                train_loss = self._broadcast_final_loss()
            if self._enable_timer:
                self.timers("broadcast_final_loss").stop()
        else:
            # else just return all intermediate output tensors of all micro steps
            train_loss = self.output_tensors

        self.timer_printer()
        return train_loss

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # interleave scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        return self.forward_backward_pipeline(data, None, forward_only=True)