# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import framework

from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
from ..utils import timer_helper as timer
from ..utils.hybrid_parallel_util import (
    broadcast_dp_parameters,
    broadcast_mp_parameters,
    broadcast_sharding_parameters,
)
from ..utils.log_util import logger
from .meta_parallel_base import MetaParallelBase
from .parallel_layers.pp_layers import PipelineLayer
from .pp_utils import p2p_communication as p2p
from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size

__all__ = []


class PipelineParallel(MetaParallelBase):
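    """Pipeline-parallel runner for a ``PipelineLayer`` model.

    Implements the 1F1B (one-forward-one-backward) schedule: micro batches are
    pushed through the pipeline stages with p2p send/recv between stages, the
    loss is accumulated over ``accumulate_steps`` micro batches, and gradients
    are synchronized across the data-parallel / sharding groups afterwards.
    """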
    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer."
            )
        super().__init__(layers, hcg, strategy)
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
        self.use_sharding_parallel = (
            self._hcg.get_sharding_parallel_world_size() > 1
        )

        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size'
        ]
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps'
        ]
        # If the tensors being sent differ across hosts, they should not be
        # sent partially and then concatenated into a whole tensor.
        self._enable_partial_send_recv = self._strategy.pipeline_configs[
            'enable_partial_send_recv'
        ]
        self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.pp_group = self._hcg.get_pipe_parallel_group()
        self.dp_group = self._hcg.get_data_parallel_group()
        self.sharding_group = self._hcg.get_sharding_parallel_group()

        self._virtual_pp_world_size = None
        self._virtual_pp_rank = None
        self._real_pp_world_size = self.num_stages
        self._real_pp_rank = self.stage_id

        self._delay_scale_loss = self._strategy.hybrid_configs[
            "pp_configs"
        ].delay_scale_loss
        # TODO(PP Dev): support dp_comm_overlap without use_main_grad training.
        # This combination triggers an inplace check error during `reshape_` in the function `_split_tensors`.
        self._dp_comm_overlap = self._strategy.hybrid_configs[
            "pp_configs"
        ].dp_comm_overlap
        self._sharding_comm_overlap = self._strategy.hybrid_configs[
            "pp_configs"
        ].sharding_comm_overlap
        self._enable_timer = self._strategy.hybrid_configs[
            "pp_configs"
        ].enable_timer

        if self._dp_comm_overlap:
            assert self.use_data_parallel and self.num_stages > 1

        if self._sharding_comm_overlap:
            assert self.use_sharding_parallel and self.num_stages > 1

        assert not (
            self._dp_comm_overlap and self._sharding_comm_overlap
        ), "Cannot use dp pp overlap and sharding pp overlap at the same time."

        self._comm_buffers = []
        self._comm_overlap = (
            self._dp_comm_overlap or self._sharding_comm_overlap
        )

        if self._enable_timer:
            if not timer.is_timer_initialized():
                timer.set_timers()
            self.timers = timer.get_timers()

        p2p.initialize_p2p_groups(
            hcg,
            self._using_cache,
            self._enable_partial_send_recv,
            self._enable_timer,
        )

        self.global_rank = self._hcg.get_global_rank()
        self.micro_batch_id = 0

        self._compute_loss = True

        logger.info(
            "Pipeline Info -- num_stages: {}, stage_id: {}".format(
                self.num_stages, self.stage_id
            )
        )

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_sharding_parallel:
            logger.info("start broadcast sharding parameters")
            broadcast_sharding_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

        if self._dp_comm_overlap:
            self.register_allreduce_overlap_hook(
                self._layers, self.dp_group, self.accumulate_steps, True
            )

    def is_pipeline_first_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != 0:
                    return False
        assert self._real_pp_rank is not None
        return self._real_pp_rank == 0

    def is_pipeline_last_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != (self._virtual_pp_world_size - 1):
                    return False
        assert self._real_pp_rank is not None
        assert self._real_pp_world_size is not None
        return self._real_pp_rank == (self._real_pp_world_size - 1)

    def set_virtual_pipeline_rank(self, rank):
        self._virtual_pp_rank = rank

    def bw_hook_func(self, buffer, param):
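        # The returned hook runs after this parameter's backward and pushes its
        # gradient into the shared FusedCommBuffer, so grouped gradients can be
        # synchronized together as part of the dp/sharding comm overlap.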
        @paddle.autograd.no_grad()
        def fused_allreduce(*_):
            buffer.add_grad(param)

        return fused_allreduce

    def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
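        # Collects trainable parameters (per model chunk for virtual pipeline),
        # fuses them by size with assign_group_by_size, wraps each group in a
        # FusedCommBuffer, and registers a backward hook per parameter.
        # With dp=True a single ALL_REDUCE group over comm_group is used; with
        # dp=False (sharding) parameters are bucketed by their destination rank
        # and reduced (REDUCE) to that rank.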
        if model.get_num_virtual_stages() > 1:
            models = model.get_model_chunks()
        else:
            models = [model]

        if not dp:
            assert hasattr(self, "optimizer")
            assert hasattr(self.optimizer, "_param2rank")
            _param2rank = self.optimizer._param2rank

        act = HOOK_ACTION.ALL_REDUCE if dp else HOOK_ACTION.REDUCE

        fused_parameter_group = {}

        for model in models:
            # For virtual pipeline. Separate parameters of different chunks into
            # different groups to get the best performance.
            parameter_list = [
                p for p in model.parameters() if not p.stop_gradient
            ]
            if len(parameter_list) < 1:
                return

            if dp:
                fused_parameter_group[-1] = parameter_list
            else:
                # Sort parameters for sharding, since they have different dst ranks
                for p in parameter_list:
                    assert p.name in _param2rank
                    dst_rank = _param2rank[p.name]
                    if dst_rank in fused_parameter_group:
                        fused_parameter_group[dst_rank].append(p)
                    else:
                        fused_parameter_group[dst_rank] = [p]

            for dst in fused_parameter_group:
                parameter_list = fused_parameter_group[dst]
                if not dp:
                    # parse the relative dst rank to absolute dst rank for sharding
                    dst = comm_group.ranks[dst]
                var_groups = assign_group_by_size(parameter_list)
                for group_idx, parameters in var_groups.items():
                    buffer = FusedCommBuffer(
                        group_idx, parameters, comm_group, acc_steps, act, dst
                    )
                    self._comm_buffers.append(buffer)
                    for param in parameters:
                        param._register_backward_hook(
                            self.bw_hook_func(buffer, param)
                        )

    def timer_printer(self):
        if not self._enable_timer:
            return
        all_flag_names = self.timers.timers.keys()
        self.timers.log(all_flag_names)

    def forward_backward_pipeline(self, data, scaler=None):
        # use the 1f1b scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py

        self.scaler = scaler

        # store data for train
        self.data = data

        # store total loss of entire batch
        self.total_loss = None

        # store data id for micro_batch
        self.micro_batch_id = 0

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps
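        # Illustrative example (values assumed for exposition): with num_stages=4,
        # stage_id=1 and accumulate_steps=8, startup_steps = min(4 - 1 - 1, 8) = 2
        # warmup forward passes, followed by steady_steps = 8 - 2 = 6 1F1B steps;
        # the matching 2 cooldown backward passes run in the final loop below.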

        input_buffers = []
        output_buffers = []

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not self.is_pipeline_last_stage():
                self._release_output(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor)

            output_tensor_grad = p2p.send_forward_recv_backward(
                output_tensor, self.is_pipeline_last_stage()
            )

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not self.is_pipeline_last_stage():
                self._release_output(output_tensor)

            input_tensor, output_tensor = input_buffers.pop(
                0
            ), output_buffers.pop(0)

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )

            if last_iter:
                input_tensor = None
                p2p.send_backward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )
            else:
                input_tensor = p2p.send_backward_recv_forward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )

        for i in range(startup_steps):
            input_tensor = input_buffers.pop(0)
            output_tensor = output_buffers.pop(0)

            output_tensor_grad = p2p.recv_backward(
                self.is_pipeline_last_stage()
            )

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )
            p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())

        if self._comm_overlap:
            assert len(self._comm_buffers) > 0
            for buffer in self._comm_buffers:
                buffer.scale_and_split_grads()

        if self._enable_timer:
            self.timers("allreduce_shared_weight_gradients").start()
        self._layers.allreduce_shared_weight_gradients()
        if self._enable_timer:
            self.timers("allreduce_shared_weight_gradients").stop()
            self.timers("broadcast_final_loss").start()
        with paddle.amp.auto_cast(enable=False):
            train_loss = self._broadcast_final_loss()
        if self._enable_timer:
            self.timers("broadcast_final_loss").stop()
        self.timer_printer()
        return train_loss

    def _prepare_training(self, data, optimizer, lr_scheduler):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        assert isinstance(
            optimizer, HybridParallelOptimizer
        ), 'optimizer should be HybridParallelOptimizer subclass.'

        assert (
            framework._dygraph_tracer()._has_grad
        ), 'Please enable the generation of gradients.'

        if self.is_pipeline_first_stage(
            ignore_virtual=True
        ) or self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                data is not None
            ), "For the first and the last stage, the data must be set."
        else:
            data = None

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self._layers.train()

        if self._sharding_comm_overlap and len(self._comm_buffers) == 0:
            self.register_allreduce_overlap_hook(
                self._layers, self.sharding_group, self.accumulate_steps, False
            )

        return data

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # 1f1b scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        # save data for eval
        self.data = data
        # store data id for micro_batch
        self.micro_batch_id = 0

        # store total loss of entire batch
        self.total_loss = None

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

        input_buffers = []
        output_buffers = []

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not last_iter:
                input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        if self._compute_loss:
            self.train_loss = self._broadcast_final_loss()
        else:
            self.train_loss = output_buffers

        return self.train_loss

    def _forward_step(self, input_tensor, chunk_id=None):
        if self._enable_timer:
            self.timers("forward_step").start()
        if self.is_pipeline_first_stage():
            input_tensor = self._load_micro_batch(self.micro_batch_id)

        assert chunk_id is None or isinstance(chunk_id, int)

        output_tensor = self._layers.forward(input_tensor, chunk_id=chunk_id)

        if self.is_pipeline_last_stage():
            # calculate loss when training
            if self._compute_loss:
                assert (
                    self._layers._loss_fn is not None
                ), "loss function should exist to compute loss"
                labels = self._load_micro_batch(self.micro_batch_id)
                output_tensor = self._layers._loss_fn(output_tensor, labels)
                assert isinstance(
                    output_tensor, (paddle.Tensor, framework.core.eager.Tensor)
                ), "Currently, loss_fn should return a paddle.Tensor"

                with paddle.amp.auto_cast(enable=False):
                    if self.accumulate_steps > 1 and not self._delay_scale_loss:
                        output_tensor = output_tensor / self.accumulate_steps

                    if self.total_loss is None:
                        self.total_loss = paddle.zeros_like(output_tensor)
                    self.total_loss += output_tensor.detach()

        if self.is_pipeline_first_stage() or self.is_pipeline_last_stage():
            # Only increase the micro batch id at the virtual first/last pp stage.
            # The micro batch id is used to load data, so only increase it when loading data.
            self.micro_batch_id += 1
        if self._enable_timer:
            self.timers("forward_step").stop()
        return output_tensor

    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
        if self._enable_timer:
            self.timers("backward_step").start()
        with paddle.amp.auto_cast(enable=False):
            if self.is_pipeline_last_stage():
                assert output_tensor_grad is None
                if self.scaler:
                    paddle.autograd.backward(self.scaler.scale(output_tensor))
                else:
                    paddle.autograd.backward(output_tensor)
            else:
                if isinstance(output_tensor, tuple):
                    outputs = [t for t in output_tensor if not t.stop_gradient]
                    assert len(outputs) == len(output_tensor_grad)
                    paddle.autograd.backward(
                        tensors=outputs,
                        grad_tensors=list(output_tensor_grad),
                    )
                else:
                    paddle.autograd.backward(
                        tensors=[output_tensor],
                        grad_tensors=[output_tensor_grad],
                    )

            input_tensor_grad = None
            if input_tensor is not None:
                if isinstance(input_tensor, tuple):
                    input_tensor_grad = tuple(
                        [t.grad for t in input_tensor if not t.stop_gradient]
                    )
                else:
                    input_tensor_grad = input_tensor.grad
            if self._enable_timer:
                self.timers("backward_step").stop()
            return input_tensor_grad

    def _check_data_vaild(self, data):
        batch_size = data.shape[0]
        assert self.micro_batch_size * self.accumulate_steps == batch_size, (
            "batch_size needs to equal micro_batch_size * accumulate_steps. Currently, "
            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
            % (batch_size, self.micro_batch_size, self.accumulate_steps)
        )

    def _load_micro_batch_impl(self, inputs, cache_id):
        begin = cache_id * self.micro_batch_size
        end = begin + self.micro_batch_size
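        # For a single batched tensor, one micro batch is sliced out along dim 0.
        # Illustrative example (assumed values): micro_batch_size=2, cache_id=3
        # selects rows [6:8). Inputs given as lists are indexed by cache_id instead.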

        if isinstance(inputs, tuple):
            output = []
            for data in inputs:
                if isinstance(data, list):
                    assert (
                        len(data) == self.accumulate_steps
                    ), "length of data should be %d, but it is %d" % (
                        self.accumulate_steps,
                        len(data),
                    )
                    output.append(data[cache_id].detach())
                else:
                    self._check_data_vaild(data)
                    output.append(data[begin:end, :].detach())
            return tuple(output)

        elif isinstance(inputs, list):
            assert (
                len(inputs) == self.accumulate_steps
            ), "length of data should be %d, but it is %d" % (
                self.accumulate_steps,
                len(inputs),
            )
            return inputs[cache_id].detach()
        else:
            self._check_data_vaild(inputs)
            return inputs[begin:end, :].detach()

    def _load_micro_batch(self, cache_id):
        inputs = self.data
        if self.is_pipeline_first_stage():
            assert len(inputs) == 2, "length of input should be 2"
            return self._load_micro_batch_impl(inputs[0], cache_id)
        elif self.is_pipeline_last_stage():
            assert len(inputs) == 2, "length of input should be 2"
            return self._load_micro_batch_impl(inputs[1], cache_id)
        else:
            inputs = None

    def _broadcast_final_loss(self):
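        # Two-step handshake: the last stage first broadcasts an int64 flag telling
        # whether the loss is fp32, so the other stages can allocate a receive
        # buffer of the matching dtype before the loss itself is broadcast.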
        # Since the last backward run in interleave will set the virtual rank to 0,
        # here we need to check the last stage ignoring the virtual stage.
        if self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                self.total_loss is not None
            ), "train_batch() in last stage should obtain valid loss"
            loss = (
                self.total_loss.detach()
                if not self._delay_scale_loss
                else self.total_loss / self.accumulate_steps
            )
            is_fp32 = (
                paddle.full([], 1, 'int64')
                if loss.dtype == paddle.float32
                else paddle.full([], 0, 'int64')
            )
            paddle.distributed.broadcast(
                is_fp32, src=self.global_rank, sync_op=True, group=self.pp_group
            )
            paddle.distributed.broadcast(
                loss, src=self.global_rank, sync_op=True, group=self.pp_group
            )
        else:
            is_fp32 = paddle.full([], 1, 'int64')
            paddle.distributed.broadcast(
                is_fp32,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
            loss = (
                paddle.zeros(shape=[1], dtype="float32")
                if is_fp32.item()
                else paddle.zeros(shape=[1], dtype="float16")
            )
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
        return loss

    def _optimizer_step(self):
        if self._delay_scale_loss:
            for p in self._layers.parameters():
                if hasattr(p, "main_grad") and p.main_grad is not None:
                    assert p.grad is None
                    p.main_grad = p.main_grad.scale(1.0 / self.accumulate_steps)
                elif p.grad is not None:
                    p.grad = p.grad.scale(1.0 / self.accumulate_steps)

        if self.scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()

    def _release_output(self, output):
        if isinstance(output, (tuple, list)):
            for t in output:
                if t is not None and isinstance(t, paddle.Tensor):
                    t._clear_dataptr()
        elif output is not None and isinstance(output, paddle.Tensor):
            output._clear_dataptr()


class PipelineParallelWithInterleave(PipelineParallel):
    # pipeline parallel with interleave scheduler
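    # Each pipeline rank holds num_model_chunks non-adjacent model chunks
    # (virtual stages), which shrinks the pipeline bubble compared with the
    # plain 1F1B schedule at the cost of more p2p communication.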

    def __init__(self, layers, hcg, strategy):
        super().__init__(layers=layers, hcg=hcg, strategy=strategy)
        assert layers.get_num_virtual_stages() > 1
        assert (
            self.num_stages > 2
        ), "virtual pipeline must run under pp degree > 2"
        assert (
            framework.in_dynamic_mode()
        ), "virtual pipeline stage with interleave only supports eager dygraph mode"
        assert (
            self.accumulate_steps % self.num_stages == 0
        ), "accumulate_steps should be evenly divisible by num_stages for pipeline with interleave"
        # setup for interleave scheduler
        self.num_model_chunks = layers.get_num_virtual_stages()
        self.model_chunks = layers.get_model_chunks()
        assert self.model_chunks is not None
        assert len(self.model_chunks) == self.num_model_chunks
        self._virtual_pp_world_size = self.num_model_chunks
        self._virtual_pp_rank = 0

    def _get_virtual_pp_rank(self, micro_step, forward):
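        # Maps a micro step to the model chunk it belongs to. Illustrative example
        # (assumed values): with num_stages=4 and num_model_chunks=2, forward micro
        # steps 0..7 map to chunks [0, 0, 0, 0, 1, 1, 1, 1] and then repeat; for the
        # backward pass the chunk order is reversed ([1, 1, 1, 1, 0, 0, 0, 0]).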
        virtual_pp_stage = micro_step % (
            self.num_stages * self.num_model_chunks
        )
        virtual_pp_stage = virtual_pp_stage // self.num_stages
        if not forward:
            virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
        return virtual_pp_stage

    def _forward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        if not self._forward_only:
            assert hasattr(self, 'output_tensor_grads')
        assert len(self.input_tensors[virtual_pp_rank]) == (
            len(self.output_tensors[virtual_pp_rank]) + 1
        )
        input_tensor = self.input_tensors[virtual_pp_rank][-1]
        output_tensor = self._forward_step(input_tensor, virtual_pp_rank)
        self.output_tensors[virtual_pp_rank].append(output_tensor)

        if self._forward_only:
            # no need to store tensor for backward
            self.input_tensors[virtual_pp_rank].pop()
            self.output_tensors[virtual_pp_rank].pop()

        return output_tensor

    def _backward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        assert hasattr(self, 'output_tensor_grads')

        assert (
            len(self.output_tensor_grads[virtual_pp_rank]) == 1
        ), f"output_tensor_grads is empty for virtual_pp_rank {virtual_pp_rank}"

        assert len(self.input_tensors[virtual_pp_rank]) > 0
        assert len(self.output_tensors[virtual_pp_rank]) > 0

        input_tensor = self.input_tensors[virtual_pp_rank].pop(0)
        output_tensor = self.output_tensors[virtual_pp_rank].pop(0)
        output_tensor_grad = self.output_tensor_grads[virtual_pp_rank].pop(0)
        input_tensor_grad = self._backward_step(
            input_tensor, output_tensor, output_tensor_grad
        )

        return input_tensor_grad

    def forward_backward_pipeline(
        self, data, scaler, forward_only=False, compute_loss=True
    ):
        # use interleave scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
        if not compute_loss:
            assert (
                forward_only
            ), "compute_loss can only be set to False when forward_only is set to True"

        # init some attributes for this batch run
        self.scaler = scaler
        self.data = data
        self.total_loss = None
        self.micro_batch_id = 0
        self._forward_only = forward_only

        # init some data buffers for interleave scheduler
        self.input_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]

        num_steps = self.accumulate_steps * self.num_model_chunks
        if forward_only:
            # If only forward, since there is no backward during running, all steps are startup steps
            startup_steps = num_steps
        else:
            # actually startup_steps is calculated from two numbers:
            # first_forward_cross_to_end = (self.num_stages - self.stage_id - 1) + (self.num_model_chunks - 1) * self.num_stages
            # end_to_first_backward_cross = (self.num_stages - self.stage_id - 1)
            # startup_steps = first_forward_cross_to_end + end_to_first_backward_cross
            startup_steps = (self.num_stages - self.stage_id - 1) * 2
            startup_steps += (self.num_model_chunks - 1) * self.num_stages
            startup_steps = min(startup_steps, num_steps)
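            # Illustrative example (assumed values): with num_stages=4, stage_id=0,
            # num_model_chunks=2 and accumulate_steps=8, num_steps = 16 and
            # startup_steps = (4 - 0 - 1) * 2 + (2 - 1) * 4 = 10, so steady_steps = 6.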

        steady_steps = num_steps - startup_steps

        self.set_virtual_pipeline_rank(0)
        self.input_tensors[0].append(
            p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False)
        )

        # run startup steps
        for micro_step in range(startup_steps):
            output_tensor = self._forward_step_helper(micro_step)

            # determine whether to recv forward tensor or not
            next_virtual_pp_rank = self._get_virtual_pp_rank(
                micro_step + 1, forward=True
            )
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                if next_virtual_pp_rank == 0:
                    # next chunk is the first chunk, no need to pre-recv an input tensor
                    recv_prev = False
            # last micro step, no next run
            if micro_step == (num_steps - 1):
                recv_prev = False

            # last stage shouldn't send tensor to downstream
            if self.is_pipeline_last_stage():
                output_tensor = None

            if micro_step == (startup_steps - 1) and not forward_only:
                input_tensor_grad = None
                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False

                # the last startup step needs four-direction comm to set up for steady 1f1b
                (
                    input_tensor,
                    output_tensor_grad,
                ) = p2p.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next,
                )
                # output_tensor_grad is not none if recv_next
                # append output_tensor_grad no matter none or not
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    output_tensor_grad
                )
            else:
                input_tensor = p2p.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev
                )
            # append input_tensor no matter none or not
            self.input_tensors[next_virtual_pp_rank].append(input_tensor)

            self._release_output(output_tensor)

        # run 1f1b steady steps
        for micro_step in range(steady_steps):
            # forward
            forward_micro_step_id = micro_step + startup_steps
            output_tensor = self._forward_step_helper(forward_micro_step_id)

            # backward
            backward_micro_step_id = micro_step
            input_tensor_grad = self._backward_step_helper(
                backward_micro_step_id
            )

            # four directions comm
            # send output tensor to downstream
            # send input tensor grad to upstream
            # recv input tensor from upstream
            # recv output tensor grad from downstream

            # last stage doesn't send result to downstream
            forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id, forward=True
            )
            self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
            if self.is_pipeline_last_stage():
                output_tensor = None

            # first stage doesn't send grad to upstream
            backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id, forward=False
            )
            self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
            if self.is_pipeline_first_stage():
                input_tensor_grad = None

            # determine whether to recv input tensor from upstream
            recv_prev = True
            next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id + 1, forward=True
            )
            if self.is_pipeline_first_stage(ignore_virtual=True) and (
                next_forward_virtual_pp_rank == 0
            ):
                # first pp stage and first virtual stage
                recv_prev = False

            # last iteration doesn't need recv from upstream
            if micro_step == (steady_steps - 1):
                recv_prev = False

            # determine whether to recv grad from downstream
            recv_next = True
            next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id + 1, forward=False
            )
            if self.is_pipeline_last_stage(ignore_virtual=True) and (
                next_backward_virtual_pp_rank == (self.num_model_chunks - 1)
            ):
                # last pp stage and last virtual stage
                recv_next = False

            (
                input_tensor,
                output_tensor_grad,
            ) = p2p.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
            )
            # append input_tensor no matter none or not
            self.input_tensors[next_forward_virtual_pp_rank].append(
                input_tensor
            )
            # append output_tensor_grad no matter none or not
            self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                output_tensor_grad
            )
            self._release_output(output_tensor)

        self._release_output(output_tensor)

        # remaining backward steps
        if not forward_only:
            for micro_step in range(steady_steps, num_steps):
                # cooldown loop
                input_tensor_grad = self._backward_step_helper(micro_step)
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    micro_step + 1, forward=False
                )

                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    if next_backward_virtual_pp_rank == (
                        self.num_model_chunks - 1
                    ):
                        recv_next = False

                if micro_step == (num_steps - 1):
                    recv_next = False
                # append output_tensor_grad no matter none or not
                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    p2p.send_backward_recv_backward(
                        input_tensor_grad, recv_next=recv_next
                    )
                )

            if self._comm_overlap:
                assert len(self._comm_buffers) > 0
                for buffer in self._comm_buffers:
                    buffer.scale_and_split_grads()

            if self._enable_timer:
                self.timers("allreduce_shared_weight_gradients").start()
            self._layers.allreduce_shared_weight_gradients()
            if self._enable_timer:
                self.timers("allreduce_shared_weight_gradients").stop()

        if compute_loss:
            # return loss if compute loss
            if self._enable_timer:
                self.timers("broadcast_final_loss").start()
            with paddle.amp.auto_cast(enable=False):
                train_loss = self._broadcast_final_loss()
            if self._enable_timer:
                self.timers("broadcast_final_loss").stop()
        else:
            # else just return all intermediate output tensors for all micro steps
            train_loss = self.output_tensors

        self.timer_printer()
        return train_loss

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # interleave scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        return self.forward_backward_pipeline(data, None, forward_only=True)