#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import framework

from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
from ..utils.hybrid_parallel_util import (
    broadcast_dp_parameters,
    broadcast_mp_parameters,
    broadcast_sharding_parameters,
)
from ..utils.log_util import logger
from .meta_parallel_base import MetaParallelBase
from .parallel_layers.pp_layers import PipelineLayer
from .pp_utils import p2p_communication as p2p
from .pp_utils.utils import FusedAllReduceBuffer, assign_group_by_size

__all__ = []

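# Typical usage (a minimal sketch only; the exact fleet setup depends on the
# user's hybrid-parallel configuration, and the names below are illustrative):
#
#     model = PipelineLayer(layers=layer_desc, num_stages=pp_degree)
#     model = fleet.distributed_model(model)       # wrapped as PipelineParallel
#     optimizer = fleet.distributed_optimizer(optimizer)
#     for batch in train_loader:
#         # batch is (inputs, labels); needed on the first and last stages
#         loss = model.train_batch(batch, optimizer, lr_scheduler, scaler)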

class PipelineParallel(MetaParallelBase):
    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer."
            )
        super().__init__(layers, hcg, strategy)
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
        self.use_sharding_parallel = (
            self._hcg.get_sharding_parallel_world_size() > 1
        )

        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size'
        ]
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps'
        ]
        # If the tensors being sent are not the same across hosts,
        # they should not be sent partially and then concatenated into a whole tensor.
        self._enable_partial_send_recv = self._strategy.pipeline_configs[
            'enable_partial_send_recv'
        ]
        self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.pp_group = self._hcg.get_pipe_parallel_group()
        self.dp_group = self._hcg.get_data_parallel_group()

        self._virtual_pp_world_size = None
        self._virtual_pp_rank = None
        self._real_pp_world_size = self.num_stages
        self._real_pp_rank = self.stage_id

        self._delay_scale_loss = self._strategy.hybrid_configs[
            "pp_configs"
        ].delay_scale_loss
        self._dp_comm_overlap = self._strategy.hybrid_configs[
            "pp_configs"
        ].dp_comm_overlap
        self._dp_comm_buffers = []

        if self._dp_comm_overlap:
            assert self.use_data_parallel and self.num_stages > 1

        p2p.initialize_p2p_groups(
            hcg, self._using_cache, self._enable_partial_send_recv
        )

        self.global_rank = self._hcg.get_global_rank()
        self.micro_batch_id = 0

        self._compute_loss = True

        logger.info(
            "Pipeline Info -- num_stages: {}, stage_id: {}".format(
                self.num_stages, self.stage_id
            )
        )

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_sharding_parallel:
            logger.info("start broadcast sharding parameters")
            broadcast_sharding_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

        if self._dp_comm_overlap:
            self.register_allreduce_overlap_hook(
                self._layers, self.dp_group, self.accumulate_steps
            )

    def is_pipeline_first_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != 0:
                    return False
        assert self._real_pp_rank is not None
        return self._real_pp_rank == 0

    def is_pipeline_last_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != (self._virtual_pp_world_size - 1):
                    return False
        assert self._real_pp_rank is not None
        assert self._real_pp_world_size is not None
        return self._real_pp_rank == (self._real_pp_world_size - 1)

    def set_virtual_pipeline_rank(self, rank):
        self._virtual_pp_rank = rank

    def bw_hook_func(self, buffer, param):
        @paddle.autograd.no_grad()
        def fused_allreduce(*_):
            buffer.add_grad(param)

        return fused_allreduce

    def register_allreduce_overlap_hook(self, model, comm_group, acc_steps):
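        # Bucket the trainable parameters by size and register a backward hook
        # on each parameter, so the data-parallel gradient all-reduce can
        # overlap with the remaining pipeline backward computation.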
        parameter_list = [p for p in model.parameters() if not p.stop_gradient]
        if len(parameter_list) < 1:
            return

        var_groups = assign_group_by_size(parameter_list)
        for group_idx, parameters in var_groups.items():
            buffer = FusedAllReduceBuffer(
                group_idx, parameters, comm_group, acc_steps
            )
            self._dp_comm_buffers.append(buffer)
            for param in parameters:
                param._register_backward_hook(self.bw_hook_func(buffer, param))

    def forward_backward_pipeline(self, data, scaler=None):
        # use the 1f1b scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
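        # Schedule sketch: each stage runs (num_stages - stage_id - 1) warm-up
        # forward steps, then alternates one forward and one backward per step
        # in the steady phase, and finally drains the remaining backward steps.
        # E.g. with num_stages=4, stage_id=1 and accumulate_steps=8 this gives
        # startup_steps=2 and steady_steps=6.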

        self.scaler = scaler

        # store data for train
        self.data = data

        # store total loss of entire batch
        self.total_loss = None

        # store data id for micro_batch
        self.micro_batch_id = 0

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

        input_buffers = []
        output_buffers = []

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor)

            output_tensor_grad = p2p.send_forward_recv_backward(
                output_tensor, self.is_pipeline_last_stage()
            )

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            input_tensor, output_tensor = input_buffers.pop(
                0
            ), output_buffers.pop(0)

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )

            if last_iter:
                input_tensor = None
                p2p.send_backward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )
            else:
                input_tensor = p2p.send_backward_recv_forward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )

        for i in range(startup_steps):
            input_tensor = input_buffers.pop(0)
            output_tensor = output_buffers.pop(0)

            output_tensor_grad = p2p.recv_backward(
                self.is_pipeline_last_stage()
            )

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )
            p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())

        if self._dp_comm_overlap:
            assert len(self._dp_comm_buffers) > 0
            for buffer in self._dp_comm_buffers:
                buffer.scale_and_split_grads()

        self._layers.allreduce_shared_weight_gradients()
        with paddle.amp.auto_cast(enable=False):
            train_loss = self._broadcast_final_loss()
        return train_loss

    def _prepare_training(self, data, optimizer, lr_scheduler):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        assert isinstance(
            optimizer, HybridParallelOptimizer
        ), 'optimizer should be HybridParallelOptimizer subclass.'

        assert (
            framework._dygraph_tracer()._has_grad
        ), 'Please enable the generation of gradients.'

        if self.is_pipeline_first_stage(
            ignore_virtual=True
        ) or self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                data is not None
            ), "For the first and the last stage, the data must be set."
        else:
            data = None

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self._layers.train()

        return data

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # 1f1b scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
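        # Forward-only pass over the micro batches: the startup/steady split is
        # the same as in training, but no backward steps are issued. Returns
        # the broadcast loss if compute_loss is True, otherwise the raw
        # per-micro-batch outputs.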
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        # save data for eval
        self.data = data
        # store data id for micro_batch
        self.micro_batch_id = 0

        # store total loss of entire batch
        self.total_loss = None

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

        input_buffers = []
        output_buffers = []

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not last_iter:
                input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        if self._compute_loss:
            self.train_loss = self._broadcast_final_loss()
        else:
            self.train_loss = output_buffers

        return self.train_loss

    def _forward_step(self, input_tensor, chunk_id=None):
        if self.is_pipeline_first_stage():
            input_tensor = self._load_micro_batch(self.micro_batch_id)

        assert chunk_id is None or isinstance(chunk_id, int)

        output_tensor = self._layers.forward(input_tensor, chunk_id=chunk_id)

        if self.is_pipeline_last_stage():
            # calculate loss for training
            if self._compute_loss:
                assert (
                    self._layers._loss_fn is not None
                ), "loss function should exist to compute loss"
                labels = self._load_micro_batch(self.micro_batch_id)
                output_tensor = self._layers._loss_fn(output_tensor, labels)
                assert isinstance(
                    output_tensor, (paddle.Tensor, framework.core.eager.Tensor)
                ), "Currently, loss_fn should return a paddle.Tensor"

                with paddle.amp.auto_cast(enable=False):
                    if self.accumulate_steps > 1 and not self._delay_scale_loss:
                        output_tensor = output_tensor / self.accumulate_steps

                    if self.total_loss is None:
                        self.total_loss = paddle.zeros_like(output_tensor)
                    self.total_loss += output_tensor.detach()

        if self.is_pipeline_first_stage() or self.is_pipeline_last_stage():
            # Only increase the micro batch id at the virtual first/last pp stage.
            # The micro batch id is used to load data, so only increase it when loading data.
            self.micro_batch_id += 1
        return output_tensor

    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
        with paddle.amp.auto_cast(enable=False):
            if self.is_pipeline_last_stage():
                assert output_tensor_grad is None
                if self.scaler:
                    paddle.autograd.backward(self.scaler.scale(output_tensor))
                else:
                    paddle.autograd.backward(output_tensor)
            else:
                if isinstance(output_tensor, tuple):
                    outputs = [t for t in output_tensor if not t.stop_gradient]
                    assert len(outputs) == len(output_tensor_grad)
                    paddle.autograd.backward(
                        tensors=outputs,
                        grad_tensors=list(output_tensor_grad),
                    )
                else:
                    paddle.autograd.backward(
                        tensors=[output_tensor],
                        grad_tensors=[output_tensor_grad],
                    )

            input_tensor_grad = None
            if input_tensor is not None:
                if isinstance(input_tensor, tuple):
                    input_tensor_grad = tuple(
                        [t.grad for t in input_tensor if not t.stop_gradient]
                    )
                else:
                    input_tensor_grad = input_tensor.grad
            return input_tensor_grad

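    # Micro-batch loading: self.data is an (inputs, labels) pair on the first
    # and last stages. Each element may be a tensor whose leading dimension is
    # micro_batch_size * accumulate_steps (sliced per micro step), a list of
    # length accumulate_steps (indexed per micro step), or a tuple of such
    # objects.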
    def _check_data_valid(self, data):
        batch_size = data.shape[0]
        assert self.micro_batch_size * self.accumulate_steps == batch_size, (
            "batch_size needs to equal micro_batch_size * accumulate_steps. Currently, "
            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
            % (batch_size, self.micro_batch_size, self.accumulate_steps)
        )

    def _load_micro_batch_impl(self, inputs, cache_id):
        begin = cache_id * self.micro_batch_size
        end = begin + self.micro_batch_size

        if isinstance(inputs, tuple):
            output = []
            for data in inputs:
                if isinstance(data, list):
                    assert (
                        len(data) == self.accumulate_steps
                    ), "length of data should be %d, but it is %d" % (
                        self.accumulate_steps,
                        len(data),
                    )
                    output.append(data[cache_id].detach())
                else:
                    self._check_data_valid(data)
                    output.append(data[begin:end, :].detach())
            return tuple(output)

        elif isinstance(inputs, list):
            assert (
                len(inputs) == self.accumulate_steps
            ), "length of data should be %d, but it is %d" % (
                self.accumulate_steps,
                len(inputs),
            )
            return inputs[cache_id].detach()
        else:
            self._check_data_valid(inputs)
            return inputs[begin:end, :].detach()

    def _load_micro_batch(self, cache_id):
        inputs = self.data
        if self.is_pipeline_first_stage():
            assert len(inputs) == 2, "length of input should be 2"
            return self._load_micro_batch_impl(inputs[0], cache_id)
        elif self.is_pipeline_last_stage():
            assert len(inputs) == 2, "length of input should be 2"
            return self._load_micro_batch_impl(inputs[1], cache_id)
        else:
            inputs = None

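    # The last pipeline stage owns the loss. It first broadcasts a flag telling
    # the other stages whether the loss is fp32 (so they can allocate a receive
    # buffer of the matching dtype), then broadcasts the loss itself over the
    # pipeline group.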
    def _broadcast_final_loss(self):
        # Since the last backward run in interleave will set the virtual rank to 0,
        # here we need to check the last stage ignoring the virtual stage.
        if self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                self.total_loss is not None
            ), "train_batch() in last stage should obtain valid loss"
            loss = (
                self.total_loss.detach()
                if not self._delay_scale_loss
                else self.total_loss / self.accumulate_steps
            )
            is_fp32 = (
                paddle.full([], 1, 'int64')
                if loss.dtype == paddle.float32
                else paddle.full([], 0, 'int64')
            )
            paddle.distributed.broadcast(
                is_fp32, src=self.global_rank, sync_op=True, group=self.pp_group
            )
            paddle.distributed.broadcast(
                loss, src=self.global_rank, sync_op=True, group=self.pp_group
            )
        else:
            is_fp32 = paddle.full([], 1, 'int64')
            paddle.distributed.broadcast(
                is_fp32,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
            loss = (
                paddle.zeros(shape=[1], dtype="float32")
                if is_fp32.item()
                else paddle.zeros(shape=[1], dtype="float16")
            )
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
        return loss

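    # When delay_scale_loss is enabled, gradients are scaled by
    # 1 / accumulate_steps here, right before the optimizer update, instead of
    # scaling each micro batch's loss during the forward pass.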
    def _optimizer_step(self):
        if self._delay_scale_loss:
            for p in self._layers.parameters():
                if hasattr(p, "main_grad") and p.main_grad is not None:
                    assert p.grad is None
                    p.main_grad = p.main_grad.scale(1.0 / self.accumulate_steps)
                elif p.grad is not None:
                    p.grad = p.grad.scale(1.0 / self.accumulate_steps)

        if self.scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()


class PipelineParallelWithInterleave(PipelineParallel):
    # pipeline parallel with interleave scheduler

    def __init__(self, layers, hcg, strategy):
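        # Interleaved (virtual pipeline) schedule: each pipeline rank holds
        # num_virtual_stages model chunks, so a single rank alternates between
        # several non-contiguous parts of the model, which shrinks the pipeline
        # bubble at the cost of more communication.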
        super().__init__(layers=layers, hcg=hcg, strategy=strategy)
        assert layers.get_num_virtual_stages() > 1
        assert (
            framework.in_dygraph_mode()
        ), "virtual pipeline stage with interleave only supports eager dygraph mode"
        # setup for interleave scheduler
        self.num_model_chunks = layers.get_num_virtual_stages()
        self.model_chunks = layers.get_model_chunks()
        assert self.model_chunks is not None
        assert len(self.model_chunks) == self.num_model_chunks
        self._virtual_pp_world_size = self.num_model_chunks
        self._virtual_pp_rank = 0

    def _get_virtual_pp_rank(self, micro_step, forward):
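        # Map a global micro step to the virtual chunk it runs. For example,
        # with num_stages=2 and num_model_chunks=2, forward micro steps 0-1 run
        # chunk 0 and micro steps 2-3 run chunk 1; the backward order is the
        # reverse.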
        virtual_pp_stage = micro_step % (
            self.num_stages * self.num_model_chunks
        )
        virtual_pp_stage = virtual_pp_stage // self.num_stages
        if not forward:
            virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
        return virtual_pp_stage

    def _forward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        if not self._forward_only:
            assert hasattr(self, 'output_tensor_grads')

        if self.is_pipeline_first_stage():
            if len(self.input_tensors[virtual_pp_rank]) == len(
                self.output_tensors[virtual_pp_rank]
            ):
                self.input_tensors[virtual_pp_rank].append(None)
        input_tensor = self.input_tensors[virtual_pp_rank][-1]
        output_tensor = self._forward_step(input_tensor, virtual_pp_rank)
        self.output_tensors[virtual_pp_rank].append(output_tensor)

        if self._forward_only:
            # no need to store tensor for backward
            self.input_tensors[virtual_pp_rank].pop()
            self.output_tensors[virtual_pp_rank].pop()

        return output_tensor

    def _backward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        assert hasattr(self, 'output_tensor_grads')

        if self.is_pipeline_last_stage():
            if len(self.output_tensor_grads[virtual_pp_rank]) == 0:
                self.output_tensor_grads[virtual_pp_rank].append(None)

        input_tensor = self.input_tensors[virtual_pp_rank].pop(0)
        output_tensor = self.output_tensors[virtual_pp_rank].pop(0)
        output_tensor_grad = self.output_tensor_grads[virtual_pp_rank].pop(0)
        input_tensor_grad = self._backward_step(
            input_tensor, output_tensor, output_tensor_grad
        )

        return input_tensor_grad

    def forward_backward_pipeline(
        self, data, scaler, forward_only=False, compute_loss=True
    ):
        # use interleave scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
        if not compute_loss:
            assert (
                forward_only
            ), "compute_loss can only be set to False when forward_only is set to True"

        # init some attributes for this batch run
        self.scaler = scaler
        self.data = data
        self.total_loss = None
        self.micro_batch_id = 0
        self._forward_only = forward_only

        # init some data buffers for interleave scheduler
        self.input_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]

        num_steps = self.accumulate_steps * self.num_model_chunks
        all_startup_steps = False
        if forward_only:
            # If only forward, since there is no backward during running, all steps are startup steps
            startup_steps = num_steps
        else:
            if self.accumulate_steps == self.num_stages:
                startup_steps = num_steps
                all_startup_steps = True
            else:
                startup_steps = (self.num_stages - self.stage_id - 1) * 2
                startup_steps += (self.num_model_chunks - 1) * self.num_stages
                startup_steps = min(startup_steps, num_steps)

        steady_steps = num_steps - startup_steps

        self.set_virtual_pipeline_rank(0)
        self.input_tensors[0].append(
            p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False)
        )
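        # The warm-up loop below runs startup_steps forward micro steps, where
        # startup_steps = (num_stages - stage_id - 1) * 2
        #                 + (num_model_chunks - 1) * num_stages, capped at
        # num_steps (see above). With illustrative numbers num_stages=4,
        # stage_id=1, num_model_chunks=2, accumulate_steps=8: num_steps=16,
        # startup_steps=8, steady_steps=8.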

        # run startup steps
        for micro_step in range(startup_steps):
            output_tensor = self._forward_step_helper(micro_step)

            # determine whether recv forward tensor or not
            next_virtual_pp_rank = self._get_virtual_pp_rank(
                micro_step + 1, forward=True
            )
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                if next_virtual_pp_rank == 0:
                    # next chunk is the first chunk, no need to pre-recv an input tensor
                    recv_prev = False
            # last micro step, no next run
            if micro_step == (num_steps - 1):
                recv_prev = False

            # last stage shouldn't send tensor to downstream
            if self.is_pipeline_last_stage():
                output_tensor = None

            if (
                micro_step == (startup_steps - 1)
                and not forward_only
                and not all_startup_steps
            ):
                input_tensor_grad = None
                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False

                # the last startup step needs four-direction comm to set up for steady 1f1b
                (
                    input_tensor,
                    output_tensor_grad,
                ) = p2p.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next,
                )
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    output_tensor_grad
                )
            else:
                input_tensor = p2p.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev
                )
            self.input_tensors[next_virtual_pp_rank].append(input_tensor)

        # run 1f1b steady steps
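        # Each steady step pairs the forward of micro step
        # (micro_step + startup_steps) with the backward of micro step
        # micro_step, then exchanges tensors in all four directions in one call.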
        for micro_step in range(steady_steps):
            # forward
            forward_micro_step_id = micro_step + startup_steps
            output_tensor = self._forward_step_helper(forward_micro_step_id)

            # backward
            backward_micro_step_id = micro_step
            input_tensor_grad = self._backward_step_helper(
                backward_micro_step_id
            )

            # four directions comm
            # send output tensor to downstream
            # send input tensor grad to upstream
            # recv input tensor from upstream
            # recv output tensor grad from downstream

            # last stage doesn't send result to downstream
            forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id, forward=True
            )
            self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
            if self.is_pipeline_last_stage():
                output_tensor = None

            # first stage doesn't send grad to upstream
            backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id, forward=False
            )
            self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
            if self.is_pipeline_first_stage():
                input_tensor_grad = None

            # determine whether to recv input tensor from upstream
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id - (self.num_stages - 1), forward=True
                )
                if next_forward_virtual_pp_rank == (self.num_model_chunks - 1):
                    # first pp stage and first virtual stage
                    recv_prev = False
                next_forward_virtual_pp_rank += 1
            else:
                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id + 1, forward=True
                )

            # last iteration doesn't need recv from upstream
            if micro_step == (steady_steps - 1):
                recv_prev = False

            # determine whether to recv grad from downstream
            recv_next = True
            if self.is_pipeline_last_stage(ignore_virtual=True):
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id - (self.num_stages - 1),
                    forward=False,
                )
                if next_backward_virtual_pp_rank == 0:
                    # last pp stage and last virtual stage
                    recv_next = False
                next_backward_virtual_pp_rank -= 1
            else:
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id + 1, forward=False
                )

            (
                input_tensor,
                output_tensor_grad,
            ) = p2p.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
            )

            if recv_prev:
                self.input_tensors[next_forward_virtual_pp_rank].append(
                    input_tensor
                )
            if recv_next:
                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    output_tensor_grad
                )

        # remaining backward steps
        if not forward_only:
            if all_startup_steps:
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    p2p.recv_backward(
                        self.is_pipeline_last_stage(), sync_recv=False
                    )
                )

            for micro_step in range(steady_steps, num_steps):
                # cooldown loop
                input_tensor_grad = self._backward_step_helper(micro_step)
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    micro_step + 1, forward=False
                )

                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    if next_backward_virtual_pp_rank == (
                        self.num_model_chunks - 1
                    ):
                        recv_next = False

                if micro_step == (num_steps - 1):
                    recv_next = False

                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    p2p.send_backward_recv_backward(
                        input_tensor_grad, recv_next=recv_next
                    )
                )

            if self._dp_comm_overlap:
                assert len(self._dp_comm_buffers) > 0
                for buffer in self._dp_comm_buffers:
                    buffer.scale_and_split_grads()

            self._layers.allreduce_shared_weight_gradients()

        if compute_loss:
            # return loss if compute loss
            with paddle.amp.auto_cast(enable=False):
                train_loss = self._broadcast_final_loss()
        else:
            # else just return all intermediate output tensors for all micro steps
            train_loss = self.output_tensors

        return train_loss

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # interleave scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        return self.forward_backward_pipeline(data, None, forward_only=True)