#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.fluid as fluid
from .meta_parallel_base import MetaParallelBase
from .parallel_layers.pp_layers import PipelineLayer

from ..utils.hybrid_parallel_util import broadcast_mp_parameters
from ..utils.hybrid_parallel_util import broadcast_dp_parameters
from ..utils.hybrid_parallel_util import broadcast_sharding_parameters
from ..utils.log_util import logger
from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer, HybridParallelGradScaler
import paddle.fluid.framework as framework
from .pp_utils import p2p_communication as p2p
import paddle.fluid.core as core

# This module exposes no public names; pipeline classes are re-exported
# by the package-level __init__.
__all__ = []

class PipelineParallel(MetaParallelBase):
    """Meta-parallel engine that trains a ``PipelineLayer`` with the
    1F1B (one-forward-one-backward) pipeline schedule.

    Combines pipeline parallelism with optional data / model (tensor) /
    sharding parallelism, broadcasting parameters across the relevant
    communication groups at construction time.
    """

    def __init__(self, layers, hcg, strategy):
        # Pipeline scheduling only works on a model that has been split
        # into pipeline stages.
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer.")
        super(PipelineParallel, self).__init__(layers, hcg, strategy)
        # Which extra parallel dimensions are active (world size > 1).
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
        self.use_sharding_parallel = self._hcg.get_sharding_parallel_world_size(
        ) > 1

        # Accumulated loss over all micro batches of the current batch.
        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size']
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps']

        # Shape cache for p2p communication buffers (see p2p module).
        self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.pp_group = self._hcg.get_pipe_parallel_group()

        # Virtual (interleaved) stage bookkeeping; stays None/unused in the
        # plain 1F1B scheduler and is populated by the interleave subclass.
        self._virtual_pp_world_size = None
        self._virtual_pp_rank = None
        self._real_pp_world_size = self.num_stages
        self._real_pp_rank = self.stage_id

        p2p.initialize_p2p_groups(hcg, self._using_cache)

        self.global_rank = self._hcg.get_global_rank()
        # Index of the micro batch currently being loaded from self.data.
        self.micro_batch_id = 0

        # Whether _forward_step computes the loss on the last stage.
        self._compute_loss = True

        logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format(
            self.num_stages, self.stage_id))

        # Sync initial parameters across each active parallel dimension so
        # all replicas start from identical weights.
        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_sharding_parallel:
            logger.info("start broadcast sharding parameters")
            broadcast_sharding_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

    def is_pipeline_first_stage(self, ignore_virtual=False):
        """Return True if this rank runs the first (virtual) pipeline stage."""
        if not ignore_virtual and self._virtual_pp_world_size is not None:
            # With interleaving, only model chunk 0 counts as "first".
            assert self._virtual_pp_rank is not None
            if self._virtual_pp_rank != 0:
                return False
        assert self._real_pp_rank is not None
        return self._real_pp_rank == 0

    def is_pipeline_last_stage(self, ignore_virtual=False):
        """Return True if this rank runs the last (virtual) pipeline stage."""
        if not ignore_virtual and self._virtual_pp_world_size is not None:
            # With interleaving, only the final model chunk counts as "last".
            assert self._virtual_pp_rank is not None
            if self._virtual_pp_rank != (self._virtual_pp_world_size - 1):
                return False
        assert self._real_pp_rank is not None
        assert self._real_pp_world_size is not None
        return self._real_pp_rank == (self._real_pp_world_size - 1)

    def set_virtual_pipeline_rank(self, rank):
        # Record which virtual (interleaved) model chunk this rank is
        # currently acting as; consulted by is_pipeline_first/last_stage.
        self._virtual_pp_rank = rank

    def forward_backward_pipeline(self, data, scaler=None):
        """Run forward and backward over one batch with the 1F1B schedule
        and return the broadcast batch loss.

        The schedule has three phases: a warmup of forward-only steps, a
        steady phase alternating one forward with one backward, and a
        cooldown of backward-only steps draining the buffers.
        """
        # use the 1f1b scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py

        self.scaler = scaler

        # store data for train
        self.data = data

        # store total loss of entire batch
        self.total_loss = None

        # store data id for micro_batch
        self.micro_batch_id = 0

        # Earlier stages need more in-flight micro batches before the first
        # backward can arrive from downstream.
        startup_steps = (self.num_stages - self.stage_id - 1)
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

        # FIFO buffers of in-flight micro-batch activations, consumed in
        # order by the backward passes.
        input_buffers = []
        output_buffers = []

        # Warmup: forward-only.
        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        # Steady 1F1B phase: each iteration does one forward then one
        # backward on the oldest buffered micro batch.
        for i in range(steady_steps):
            last_iter = (i == (steady_steps - 1))

            output_tensor = self._forward_step(input_tensor)

            output_tensor_grad = p2p.send_forward_recv_backward(
                output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            input_tensor, output_tensor = input_buffers.pop(
                0), output_buffers.pop(0)

            input_tensor_grad = self._backward_step(input_tensor, output_tensor,
                                                    output_tensor_grad)

            if last_iter:
                # No further forward step follows, so only send the grad.
                input_tensor = None
                p2p.send_backward(input_tensor_grad,
                                  self.is_pipeline_first_stage())
            else:
                input_tensor = p2p.send_backward_recv_forward(
                    input_tensor_grad, self.is_pipeline_first_stage())

        # Cooldown: backward-only, draining the warmup buffers.
        for i in range(startup_steps):
            input_tensor = input_buffers.pop(0)
            output_tensor = output_buffers.pop(0)

            output_tensor_grad = p2p.recv_backward(
                self.is_pipeline_last_stage())

            input_tensor_grad = self._backward_step(input_tensor, output_tensor,
                                                    output_tensor_grad)
            p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())

        # Sync grads of weights shared between stages (e.g. tied embeddings).
        self._layers.allreduce_shared_weight_gradients()
        with paddle.amp.auto_cast(enable=False):
            train_loss = self._broadcast_final_loss()
        return train_loss

    def _prepare_training(self, data, optimizer, lr_scheduler):
        """Validate inputs, bind optimizer/lr_scheduler, switch to train mode.

        Returns the batch data for the first and last pipeline stages and
        None for all intermediate stages (they receive activations via p2p).
        """
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        assert isinstance(optimizer, HybridParallelOptimizer), (
            'optimizer should be HybridParallelOptimizer subclass.')

        # Training requires the dygraph tracer to record gradients.
        assert fluid.framework._dygraph_tracer()._has_grad, (
            'Please enable the generation of gradients.')

        if self.is_pipeline_first_stage(
                ignore_virtual=True) or self.is_pipeline_last_stage(
                    ignore_virtual=True):
            assert data is not None, (
                "For the first and the last stage, the data must be set.")
        else:
            data = None

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self._layers.train()

        return data

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        """Train on one batch: run the 1F1B pipeline schedule, then apply
        the optimizer step, and return the batch loss."""
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # 1f1b scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        """Run a forward-only pass over one batch.

        Returns the broadcast loss when ``compute_loss`` is True, otherwise
        the list of per-micro-batch output tensors.
        """
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        # save data for eval
        self.data = data
        # store data id for micro_batch
        self.micro_batch_id = 0

        # store total loss of entire batch
        self.total_loss = None

        # Same phase split as training, but no backward steps are issued.
        startup_steps = (self.num_stages - self.stage_id - 1)
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

        input_buffers = []
        output_buffers = []

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = (i == (steady_steps - 1))

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not last_iter:
                input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        if self._compute_loss:
            self.train_loss = self._broadcast_final_loss()
        else:
            self.train_loss = output_buffers

        return self.train_loss

    def _forward_step(self, input_tensor, chunk_id=None):
        """Run the forward pass of one micro batch through this stage's
        layers (optionally a specific virtual chunk) and return its output.

        On the last stage, when loss computation is enabled, the output is
        the (scaled) micro-batch loss and is accumulated into total_loss.
        """
        if self.is_pipeline_first_stage():
            # The first stage reads its input from the raw batch rather
            # than from an upstream stage.
            input_tensor = self._load_micro_batch(self.micro_batch_id)

        assert chunk_id is None or isinstance(chunk_id, int)

        output_tensor = self._layers.forward(input_tensor, chunk_id=chunk_id)

        if self.is_pipeline_last_stage():
            # train calculate loss for train
            if self._compute_loss:
                assert self._layers._loss_fn is not None, "loss function should exist to compute loss"
                labels = self._load_micro_batch(self.micro_batch_id)
                output_tensor = self._layers._loss_fn(output_tensor, labels)
                assert isinstance(
                    output_tensor,
                    (paddle.Tensor, core.eager.Tensor
                     )), "Currently, loss_fn should obtain Paddle.Tensor dtype"

                with paddle.amp.auto_cast(enable=False):
                    # Average over accumulation steps so the summed grads
                    # match a single large-batch step.
                    if self.accumulate_steps > 1:
                        output_tensor = output_tensor / self.accumulate_steps

                    if self.total_loss is None:
                        self.total_loss = paddle.zeros_like(output_tensor)
                    self.total_loss += output_tensor.detach()

        if self.is_pipeline_first_stage() or self.is_pipeline_last_stage():
            # Only increase micro batch id at virtual first/last pp stage.
            # The micro batch id is used to load data, therefore, only increase it when load data.
            self.micro_batch_id += 1
        return output_tensor

    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
        """Run backward for one micro batch and return the gradient(s) of
        its input tensor(s) to be sent upstream.

        AMP autocast is disabled: gradients are computed in the tensors'
        existing dtypes.
        """
        with paddle.amp.auto_cast(enable=False):
            if self.is_pipeline_last_stage():
                # The last stage starts backward from the loss itself, so
                # there is no incoming output gradient.
                assert output_tensor_grad is None
                if self.scaler:
                    paddle.autograd.backward(self.scaler.scale(output_tensor))
                else:
                    paddle.autograd.backward(output_tensor)
            else:
                if isinstance(output_tensor, tuple):
                    # Only tensors that require grad participate; the grads
                    # received from downstream must line up with them.
                    outputs = [t for t in output_tensor if not t.stop_gradient]
                    assert len(outputs) == len(output_tensor_grad)
                    paddle.autograd.backward(
                        tensors=outputs,
                        grad_tensors=[t for t in output_tensor_grad])
                else:
                    paddle.autograd.backward(tensors=[output_tensor],
                                             grad_tensors=[output_tensor_grad])

            # Collect the input grads produced by backward; None for the
            # first stage (input_tensor is None there).
            input_tensor_grad = None
            if input_tensor is not None:
                if isinstance(input_tensor, tuple):
                    input_tensor_grad = tuple(
                        [t.grad for t in input_tensor if not t.stop_gradient])
                else:
                    input_tensor_grad = input_tensor.grad
            return input_tensor_grad

    def _load_micro_batch(self, cache_id):
        """Slice micro batch ``cache_id`` out of the batch in ``self.data``.

        ``self.data`` is expected to be a two-element (inputs, labels)
        sequence: the first pipeline stage reads element 0, the last stage
        reads element 1, and every other stage returns None (it consumes
        activations received via p2p instead). Each element may itself be a
        tuple of tensors, in which case a tuple of detached slices is
        returned.
        """
        inputs = self.data
        begin = cache_id * self.micro_batch_size
        end = begin + self.micro_batch_size

        # The virtual first and last pipeline stage need data, all others don't need.
        if self.is_pipeline_first_stage():
            assert len(inputs) == 2, "length of input should be 2"
            if isinstance(inputs[0], tuple):
                assert len(
                    inputs[0]
                ) > 1, "If you use tuple for input data, it should have at least two inputs."
                batch_size = inputs[0][0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size, (
                    "batch_size needs to be divisible by micro_batch_size. Currently, "
                    "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
                    %
                    (batch_size, self.micro_batch_size, self.accumulate_steps))
                # Detach: micro-batch inputs are graph leaves, the raw batch
                # must not accumulate gradients.
                data = [tensor[begin:end, :].detach() for tensor in inputs[0]]
                return tuple(data)
            else:
                batch_size = inputs[0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                return inputs[0][begin:end, :].detach()
        elif self.is_pipeline_last_stage():
            assert len(inputs) == 2, "length of input should be 2"
            if isinstance(inputs[1], tuple):
                batch_size = inputs[1][0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                data = [tensor[begin:end, :].detach() for tensor in inputs[1]]
                return tuple(data)
            else:
                batch_size = inputs[1].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                return inputs[1][begin:end, :].detach()
        else:
            # Intermediate stages never touch the raw batch; make the
            # implicit None return explicit (was a dead `inputs = None`).
            return None

    def _broadcast_final_loss(self):
        """Broadcast the accumulated batch loss from the last pipeline
        stage to all ranks in the pipeline group and return it."""
        # Since the last backward run in interleave will set the virtual rank to 0,
        # here we need to check last stage ignoring virtual stage.
        if self.is_pipeline_last_stage(ignore_virtual=True):
            assert self.total_loss is not None, "train_batch() in last stage should obtain vaild loss"
            loss = self.total_loss.detach()
            # Broadcast the dtype flag first so receivers can allocate a
            # buffer of the matching dtype before receiving the loss.
            is_fp32 = paddle.to_tensor(
                1) if loss.dtype == paddle.float32 else paddle.to_tensor(0)
            paddle.distributed.broadcast(is_fp32,
                                         src=self.global_rank,
                                         sync_op=True,
                                         group=self.pp_group)
            paddle.distributed.broadcast(loss,
                                         src=self.global_rank,
                                         sync_op=True,
                                         group=self.pp_group)
        else:
            # Non-last stages receive: first the dtype flag, then the loss.
            is_fp32 = paddle.to_tensor(1)
            paddle.distributed.broadcast(
                is_fp32,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group)
            loss = paddle.zeros(shape=[
                1
            ], dtype="float32") if is_fp32.numpy()[0] else paddle.zeros(
                shape=[1], dtype="float16")
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group)
        return loss

    def _optimizer_step(self):
        """Apply one optimizer update (via the AMP grad scaler when one was
        passed to train_batch), clear gradients, and advance the lr
        scheduler if present."""
        if self.scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()



class PipelineParallelWithInterleave(PipelineParallel):
    # pipeline parallel with interleave scheduler

    def __init__(self, layers, hcg, strategy):
        """Set up the interleave (virtual pipeline) scheduler.

        Requires the model to be split into more than one virtual stage
        (model chunk) per rank, and eager dygraph mode.
        """
        super(PipelineParallelWithInterleave, self).__init__(layers=layers,
                                                             hcg=hcg,
                                                             strategy=strategy)
        assert layers.get_num_virtual_stages() > 1
        assert framework.in_dygraph_mode(
        ), "virtual pipeline stage with interleave only support eager dygraph mode"
        # setup for interleave scheduler
        self.num_model_chunks = layers.get_num_virtual_stages()
        self.model_chunks = layers.get_model_chunks()
        assert self.model_chunks is not None
        assert len(self.model_chunks) == self.num_model_chunks
        self._virtual_pp_world_size = self.num_model_chunks
        self._virtual_pp_rank = 0

    def _get_virtual_pp_rank(self, micro_step, forward):
        """Map a global micro step to the model-chunk (virtual stage) index
        it belongs to, for the forward or the backward direction."""
        virtual_pp_stage = micro_step % (self.num_stages *
                                         self.num_model_chunks)
        virtual_pp_stage = virtual_pp_stage // self.num_stages
        if not forward:
            # backward consumes the chunks in reverse order
            virtual_pp_stage = (self.num_model_chunks - virtual_pp_stage - 1)
        return virtual_pp_stage

    def _forward_step_helper(self, micro_step):
        """Run the forward pass for the chunk owning ``micro_step``,
        buffering its input/output for the later backward (unless running
        forward-only)."""
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        if not self._forward_only:
            assert hasattr(self, 'output_tensor_grads')

        if self.is_pipeline_first_stage():
            # First virtual stage reads from the raw batch, so its input
            # slot is a placeholder None.
            if len(self.input_tensors[virtual_pp_rank]) == len(
                    self.output_tensors[virtual_pp_rank]):
                self.input_tensors[virtual_pp_rank].append(None)
        input_tensor = self.input_tensors[virtual_pp_rank][-1]
        output_tensor = self._forward_step(input_tensor, virtual_pp_rank)
        self.output_tensors[virtual_pp_rank].append(output_tensor)

        if self._forward_only:
            # no need to store tensor for backward
            self.input_tensors[virtual_pp_rank].pop()
            self.output_tensors[virtual_pp_rank].pop()

        return output_tensor

    def _backward_step_helper(self, micro_step):
        """Run the backward pass for the chunk owning ``micro_step`` using
        the oldest buffered forward tensors, returning the input grad."""
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        assert hasattr(self, 'output_tensor_grads')

        if self.is_pipeline_last_stage():
            # Last virtual stage starts backward from the loss: no incoming
            # output grad, keep a placeholder None.
            if len(self.output_tensor_grads[virtual_pp_rank]) == 0:
                self.output_tensor_grads[virtual_pp_rank].append(None)

        input_tensor = self.input_tensors[virtual_pp_rank].pop(0)
        output_tensor = self.output_tensors[virtual_pp_rank].pop(0)
        output_tensor_grad = self.output_tensor_grads[virtual_pp_rank].pop(0)
        input_tensor_grad = self._backward_step(input_tensor, output_tensor,
                                                output_tensor_grad)

        return input_tensor_grad

    def interleave_pipeline(self,
                            data,
                            scaler,
                            forward_only=False,
                            compute_loss=True):
        """Run one batch with the interleaved 1F1B schedule.

        Phases: startup (forward-only with pre-fetching of inputs), steady
        1F1B (paired forward/backward with four-direction p2p comm), and a
        backward-only cooldown. Returns the broadcast loss when
        ``compute_loss`` is True, otherwise all buffered output tensors.
        """
        # use interleave scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
        if not compute_loss:
            assert not forward_only, "compute_loss can only be set to False when forward_only is set to True"

        # init some attributes for this batch run
        self.scaler = scaler
        self.data = data
        self.total_loss = None
        self.micro_batch_id = 0
        self._forward_only = forward_only

        # init some data buffers for interleave scheduler
        self.input_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]

        num_steps = self.accumulate_steps * self.num_model_chunks
        all_startup_steps = False
        if forward_only:
            # If only forward, since there is no backward during running, all steps are startup steps
            startup_steps = num_steps
        else:
            if self.accumulate_steps == self.num_stages:
                startup_steps = num_steps
                all_startup_steps = True
            else:
                startup_steps = (self.num_stages - self.stage_id - 1) * 2
                startup_steps += (self.num_model_chunks - 1) * self.num_stages
                startup_steps = min(startup_steps, num_steps)

        steady_steps = num_steps - startup_steps

        self.set_virtual_pipeline_rank(0)
        self.input_tensors[0].append(
            p2p.recv_forward(self.is_pipeline_first_stage()))

        # run startup steps
        for micro_step in range(startup_steps):
            output_tensor = self._forward_step_helper(micro_step)

            # determine whether recv forward tensor or not
            next_virtual_pp_rank = self._get_virtual_pp_rank(micro_step + 1,
                                                             forward=True)
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                if next_virtual_pp_rank == 0:
                    # next chunk is the first chunk, not need to pre recv an input tensor
                    recv_prev = False
            # last micro step, no next run
            if micro_step == (num_steps - 1):
                recv_prev = False

            # last stage shouldn't send tensor to downstream
            if self.is_pipeline_last_stage():
                output_tensor = None

            if micro_step == (startup_steps -
                              1) and not forward_only and not all_startup_steps:
                input_tensor_grad = None
                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False

                # the last startup step needs on four direction comm to set up for steady 1f1b
                input_tensor, output_tensor_grad = p2p.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next)
                self.output_tensor_grads[self.num_model_chunks -
                                         1].append(output_tensor_grad)
            else:
                input_tensor = p2p.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev)
            self.input_tensors[next_virtual_pp_rank].append(input_tensor)

        # run 1f1b steady steps
        for micro_step in range(steady_steps):
            # forward
            forward_micro_step_id = micro_step + startup_steps
            output_tensor = self._forward_step_helper(forward_micro_step_id)

            # backward
            backward_micro_step_id = micro_step
            input_tensor_grad = self._backward_step_helper(
                backward_micro_step_id)

            # four directions comm
            # send output tensor to downstream
            # send input tensor grad to upstream
            # recv input tensor from upstream
            # recv output tensor grad from downstream

            # last stage doesn't send rst to downstream
            forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id, forward=True)
            self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
            if self.is_pipeline_last_stage():
                output_tensor = None

            # first stage doesn't send grad to upstream
            backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id, forward=False)
            self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
            if self.is_pipeline_first_stage():
                input_tensor_grad = None

            # determine whether to recv input tensor from upstream
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id - (self.num_stages - 1), forward=True)
                if next_forward_virtual_pp_rank == (self.num_model_chunks - 1):
                    # first pp stage and first virtual stage
                    recv_prev = False
                next_forward_virtual_pp_rank += 1
            else:
                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id + 1, forward=True)

            # last iteration doesn't need recv from upstream
            if micro_step == (steady_steps - 1):
                recv_prev = False

            # determine whether to recv grad from downstream
            recv_next = True
            if self.is_pipeline_last_stage(ignore_virtual=True):
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id - (self.num_stages - 1),
                    forward=False)
                if next_backward_virtual_pp_rank == 0:
                    # last pp stage and last virtual stage
                    recv_next = False
                next_backward_virtual_pp_rank -= 1
            else:
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id + 1, forward=False)

            input_tensor, output_tensor_grad = p2p.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next)

            if recv_prev:
                self.input_tensors[next_forward_virtual_pp_rank].append(
                    input_tensor)
            if recv_next:
                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    output_tensor_grad)

        # remaining backward steps
        if not forward_only:
            if all_startup_steps:
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    p2p.recv_backward(self.is_pipeline_last_stage()))

            for micro_step in range(steady_steps, num_steps):
                # cooldown loop
                input_tensor_grad = self._backward_step_helper(micro_step)
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    micro_step + 1, forward=False)

                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    if next_backward_virtual_pp_rank == (self.num_model_chunks -
                                                         1):
                        recv_next = False

                if micro_step == (num_steps - 1):
                    recv_next = False

                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    p2p.send_backward_recv_backward(input_tensor_grad,
                                                    recv_next=recv_next))

            self._layers.allreduce_shared_weight_gradients()

        if compute_loss:
            # return loss if compute loss
            with paddle.amp.auto_cast(enable=False):
                train_loss = self._broadcast_final_loss()
        else:
            # else just return all intermediate output tensor for all micro steps
            train_loss = self.output_tensors

        return train_loss

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        """Train on one batch with the interleave schedule, then apply the
        optimizer step, and return the batch loss."""
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # interleave scheduler for pipeline parallel
        train_loss = self.interleave_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        """Run a forward-only interleave pass over one batch."""
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        return self.interleave_pipeline(data, None, forward_only=True)