#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.fluid as fluid
from .meta_parallel_base import MetaParallelBase
from .pp_utils.utils import is_float_tensor, _initialize_recompute_hcg
from .parallel_layers.pp_layers import PipelineLayer

from ..utils.hybrid_parallel_util import broadcast_mp_parameters
from ..utils.hybrid_parallel_util import broadcast_dp_parameters
from ..utils.hybrid_parallel_util import broadcast_sharding_parameters
from ..utils.log_util import logger
from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer, HybridParallelGradScaler
import paddle.fluid.framework as framework
from .pp_utils import p2p_communication as p2p
import paddle.fluid.core as core

__all__ = []


class PipelineParallel(MetaParallelBase):

    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer.")
        super(PipelineParallel, self).__init__(layers, hcg, strategy)
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
        self.use_sharding_parallel = self._hcg.get_sharding_parallel_world_size(
        ) > 1

        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size']
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps']

        self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.pp_group = self._hcg.get_pipe_parallel_group()

        self._virtual_pp_world_size = None
        self._virtual_pp_rank = None
        self._real_pp_world_size = self.num_stages
        self._real_pp_rank = self.stage_id

        p2p.initialize_p2p_groups(hcg, self._using_cache)

        _initialize_recompute_hcg(hcg)

        self.global_rank = self._hcg.get_global_rank()
        self.micro_batch_id = 0

        self._compute_loss = True

        logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format(
            self.num_stages, self.stage_id))

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_sharding_parallel:
            logger.info("start broadcast sharding parameters")
            broadcast_sharding_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

    def is_pipeline_first_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != 0:
                    return False
        assert self._real_pp_rank is not None
        return self._real_pp_rank == 0

    def is_pipeline_last_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != (self._virtual_pp_world_size - 1):
                    return False
        assert self._real_pp_rank is not None
        assert self._real_pp_world_size is not None
        return self._real_pp_rank == (self._real_pp_world_size - 1)

    def set_virtual_pipeline_rank(self, rank):
        self._virtual_pp_rank = rank

    def forward_backward_pipeline(self, data, scaler=None):
        # use the 1f1b scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py

        self.scaler = scaler

        # store data for train
        self.data = data

        # store total loss of entire batch
        self.total_loss = None

        # store data id for micro_batch
        self.micro_batch_id = 0

        startup_steps = (self.num_stages - self.stage_id - 1)
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps
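        # 1F1B schedule: each stage runs (num_stages - stage_id - 1) warm-up forward
        # steps (capped by accumulate_steps) before entering the steady 1F1B phase.
        # e.g. num_stages=4, stage_id=1, accumulate_steps=8 -> startup_steps=2, steady_steps=6.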

        input_buffers = []
        output_buffers = []

        for step_id in range(startup_steps):
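            # warm-up: receive the activation from the previous stage (the first
            # stage reads input data instead), run forward, and send the result
            # downstream; no backward passes are run yet.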
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
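            # steady 1F1B phase: each iteration runs one forward for a new micro
            # batch and one backward for the oldest in-flight micro batch, using
            # combined p2p ops to exchange activations and gradients.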
            last_iter = (i == (steady_steps - 1))

            output_tensor = self._forward_step(input_tensor)

            output_tensor_grad = p2p.send_forward_recv_backward(
                output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            input_tensor, output_tensor = input_buffers.pop(
                0), output_buffers.pop(0)

            input_tensor_grad = self._backward_step(input_tensor, output_tensor,
                                                    output_tensor_grad)

            if last_iter:
                input_tensor = None
                p2p.send_backward(input_tensor_grad,
                                  self.is_pipeline_first_stage())
            else:
                input_tensor = p2p.send_backward_recv_forward(
                    input_tensor_grad, self.is_pipeline_first_stage())

        for i in range(startup_steps):
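            # cool-down: drain the pending backward passes for the micro batches
            # whose forwards ran during warm-up.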
            input_tensor = input_buffers.pop(0)
            output_tensor = output_buffers.pop(0)

            output_tensor_grad = p2p.recv_backward(
                self.is_pipeline_last_stage())

            input_tensor_grad = self._backward_step(input_tensor, output_tensor,
                                                    output_tensor_grad)
            p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())

        self._layers.allreduce_shared_weight_gradients()
        with paddle.amp.auto_cast(enable=False):
            train_loss = self._broadcast_final_loss()
        return train_loss

    def _prepare_training(self, data, optimizer, lr_scheduler):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        assert isinstance(optimizer, HybridParallelOptimizer), (
            'optimizer should be HybridParallelOptimizer subclass.')

        assert fluid.framework._dygraph_tracer()._has_grad, (
            'Please enable the generation of gradients.')

        if self.is_pipeline_first_stage(
                ignore_virtual=True) or self.is_pipeline_last_stage(
                    ignore_virtual=True):
            assert data is not None, (
                "For the first and the last stage, the data must be set.")
        else:
            data = None

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self._layers.train()

        return data

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # 1f1b scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        # save data for eval
        self.data = data
        # store data id for micro_batch
        self.micro_batch_id = 0

        # store total loss of entire batch
        self.total_loss = None

        startup_steps = (self.num_stages - self.stage_id - 1)
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

        input_buffers = []
        output_buffers = []

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = (i == (steady_steps - 1))

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not last_iter:
                input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        if self._compute_loss:
            self.train_loss = self._broadcast_final_loss()
        else:
            self.train_loss = output_buffers

        return self.train_loss

    def _forward_step(self, input_tensor, chunk_id=None):
        if self.is_pipeline_first_stage():
            input_tensor = self._load_micro_batch(self.micro_batch_id)

        assert chunk_id is None or isinstance(chunk_id, int)

        output_tensor = self._layers.forward(input_tensor, chunk_id=chunk_id)

        if self.is_pipeline_last_stage():
            # calculate the loss on the last stage when loss computation is enabled
            if self._compute_loss:
                assert self._layers._loss_fn is not None, "loss function should exist to compute loss"
                labels = self._load_micro_batch(self.micro_batch_id)
                output_tensor = self._layers._loss_fn(output_tensor, labels)
                assert isinstance(
                    output_tensor,
                    (paddle.Tensor, core.eager.Tensor
                     )), "Currently, loss_fn should return a paddle.Tensor or core.eager.Tensor"

                with paddle.amp.auto_cast(enable=False):
                    if self.accumulate_steps > 1:
                        output_tensor = output_tensor / self.accumulate_steps

                    if self.total_loss is None:
                        self.total_loss = paddle.zeros_like(output_tensor)
                    self.total_loss += output_tensor.detach()

        if self.is_pipeline_first_stage() or self.is_pipeline_last_stage():
            # Only increase the micro batch id at the virtual first/last pp stage.
            # The micro batch id is used to load data, so only increase it when loading data.
            self.micro_batch_id += 1
        return output_tensor

    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
        with paddle.amp.auto_cast(enable=False):
            if self.is_pipeline_last_stage():
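                # on the last stage, output_tensor is the micro-batch loss (already
                # scaled by 1/accumulate_steps in _forward_step when accumulate_steps > 1);
                # backward starts from it, scaled by the AMP loss scaler if one is provided.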
                assert output_tensor_grad is None
                if self.scaler:
                    paddle.autograd.backward(self.scaler.scale(output_tensor))
                else:
                    paddle.autograd.backward(output_tensor)
            else:
                if isinstance(output_tensor, tuple):
                    outputs = [t for t in output_tensor if not t.stop_gradient]
                    assert len(outputs) == len(output_tensor_grad)
                    paddle.autograd.backward(
                        tensors=outputs,
                        grad_tensors=[t for t in output_tensor_grad])
                else:
                    paddle.autograd.backward(tensors=[output_tensor],
                                             grad_tensors=[output_tensor_grad])

            input_tensor_grad = None
            if input_tensor is not None:
                if isinstance(input_tensor, tuple):
                    input_tensor_grad = tuple(
                        [t.grad for t in input_tensor if not t.stop_gradient])
                else:
                    input_tensor_grad = input_tensor.grad
            return input_tensor_grad

    def _load_micro_batch(self, cache_id):
        inputs = self.data
        begin = cache_id * self.micro_batch_size
        end = begin + self.micro_batch_size
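        # slice one micro batch out of the full batch along dim 0,
        # e.g. micro_batch_size=2, cache_id=3 -> rows [6:8)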

        # Only the virtual first and last pipeline stages need data; the other stages do not.
        if self.is_pipeline_first_stage():
            assert len(inputs) == 2, "length of input should be 2"
            if isinstance(inputs[0], tuple):
                assert len(
                    inputs[0]
                ) > 1, "If you use tuple for input data, it should have at least two inputs."
                batch_size = inputs[0][0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size, (
                    "batch_size needs to be divisible by micro_batch_size. Currently, "
                    "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
                    %
                    (batch_size, self.micro_batch_size, self.accumulate_steps))
                data = [input[begin:end, :].detach() for input in inputs[0]]
                return tuple(data)
            else:
                batch_size = inputs[0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                return inputs[0][begin:end, :].detach()
        elif self.is_pipeline_last_stage():
            assert len(inputs) == 2, "length of input should be 2"
            if isinstance(inputs[1], tuple):
                batch_size = inputs[1][0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                data = [input[begin:end, :].detach() for input in inputs[1]]
                return tuple(data)
            else:
                batch_size = inputs[1].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                return inputs[1][begin:end, :].detach()
        else:
            # No data input is required for other stages
            inputs = None

    def _broadcast_final_loss(self):
        # Since the last backward run in interleave will set the virtual rank to 0,
        # here we need to check last stage ignoring virtual stage.
        if self.is_pipeline_last_stage(ignore_virtual=True):
            assert self.total_loss is not None, "train_batch() in last stage should obtain valid loss"
            loss = self.total_loss.detach()
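            # broadcast the loss dtype first so the other stages can allocate a
            # receive buffer with the matching dtype before the loss broadcast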
            is_fp32 = paddle.to_tensor(
                1) if loss.dtype == paddle.float32 else paddle.to_tensor(0)
            paddle.distributed.broadcast(is_fp32,
                                         src=self.global_rank,
                                         use_calc_stream=True,
                                         group=self.pp_group)
            paddle.distributed.broadcast(loss,
                                         src=self.global_rank,
                                         use_calc_stream=True,
                                         group=self.pp_group)
        else:
            is_fp32 = paddle.to_tensor(1)
            paddle.distributed.broadcast(
                is_fp32,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                use_calc_stream=True,
                group=self.pp_group)
            loss = paddle.zeros(shape=[
                1
            ], dtype="float32") if is_fp32.numpy()[0] else paddle.zeros(
                shape=[1], dtype="float16")
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                use_calc_stream=True,
                group=self.pp_group)
        return loss

    def _optimizer_step(self):
        if self.scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()


class PipelineParallelWithInterleave(PipelineParallel):
    # pipeline parallel with interleave scheduler

    def __init__(self, layers, hcg, strategy):
        super(PipelineParallelWithInterleave, self).__init__(layers=layers,
                                                             hcg=hcg,
                                                             strategy=strategy)
        assert layers.get_num_virtual_stages() > 1
        assert framework.in_dygraph_mode(
        ), "virtual pipeline stage with interleave only supports eager dygraph mode"
        # setup for interleave scheduler
        self.num_model_chunks = layers.get_num_virtual_stages()
        self.model_chunks = layers.get_model_chunks()
        assert self.model_chunks is not None
        assert len(self.model_chunks) == self.num_model_chunks
        self._virtual_pp_world_size = self.num_model_chunks
        self._virtual_pp_rank = 0

    def _get_virtual_pp_rank(self, micro_step, forward):
        virtual_pp_stage = micro_step % (self.num_stages *
                                         self.num_model_chunks)
        virtual_pp_stage = virtual_pp_stage // self.num_stages
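        # micro steps are mapped to model chunks in blocks of num_stages steps,
        # e.g. num_stages=4, num_model_chunks=2: forward steps 0..7 -> chunks
        # [0, 0, 0, 0, 1, 1, 1, 1], after which the pattern repeats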
        if not forward:
            virtual_pp_stage = (self.num_model_chunks - virtual_pp_stage - 1)
        return virtual_pp_stage

    def _forward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        if not self._forward_only:
            assert hasattr(self, 'output_tensor_grads')

        if self.is_pipeline_first_stage():
            if len(self.input_tensors[virtual_pp_rank]) == len(
                    self.output_tensors[virtual_pp_rank]):
                self.input_tensors[virtual_pp_rank].append(None)
        input_tensor = self.input_tensors[virtual_pp_rank][-1]
        output_tensor = self._forward_step(input_tensor, virtual_pp_rank)
        self.output_tensors[virtual_pp_rank].append(output_tensor)

        if self._forward_only:
            # no need to store tensor for backward
            self.input_tensors[virtual_pp_rank].pop()
            self.output_tensors[virtual_pp_rank].pop()

        return output_tensor

    def _backward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        assert hasattr(self, 'output_tensor_grads')

        if self.is_pipeline_last_stage():
            if len(self.output_tensor_grads[virtual_pp_rank]) == 0:
                self.output_tensor_grads[virtual_pp_rank].append(None)

        input_tensor = self.input_tensors[virtual_pp_rank].pop(0)
        output_tensor = self.output_tensors[virtual_pp_rank].pop(0)
        output_tensor_grad = self.output_tensor_grads[virtual_pp_rank].pop(0)
        input_tensor_grad = self._backward_step(input_tensor, output_tensor,
                                                output_tensor_grad)

        return input_tensor_grad

    def interleave_pipeline(self,
                            data,
                            scaler,
                            forward_only=False,
                            compute_loss=True):
        # use interleave scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
        if not compute_loss:
            assert forward_only, "compute_loss can only be set to False when forward_only is set to True"

        # init some attributes for this batch run
        self.scaler = scaler
        self.data = data
        self.total_loss = None
        self.micro_batch_id = 0
        self._forward_only = forward_only

        # init some data buffers for interleave scheduler
        self.input_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]

        num_steps = self.accumulate_steps * self.num_model_chunks
        all_startup_steps = False
        if forward_only:
            # forward-only: there is no backward pass, so all steps are startup steps
            startup_steps = num_steps
        else:
            if self.accumulate_steps == self.num_stages:
                startup_steps = num_steps
                all_startup_steps = True
            else:
                startup_steps = (self.num_stages - self.stage_id - 1) * 2
                startup_steps += (self.num_model_chunks - 1) * self.num_stages
                startup_steps = min(startup_steps, num_steps)

        steady_steps = num_steps - startup_steps
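        # e.g. (illustrative, forward_only=False) num_stages=4, stage_id=1,
        # num_model_chunks=2, accumulate_steps=8: num_steps=16,
        # startup_steps = min((4 - 1 - 1) * 2 + (2 - 1) * 4, 16) = 8, steady_steps = 8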

        self.set_virtual_pipeline_rank(0)
        self.input_tensors[0].append(
            p2p.recv_forward(self.is_pipeline_first_stage()))

        # run startup steps
        for micro_step in range(startup_steps):
            output_tensor = self._forward_step_helper(micro_step)

            # determine whether recv forward tensor or not
            next_virtual_pp_rank = self._get_virtual_pp_rank(micro_step + 1,
                                                             forward=True)
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                if next_virtual_pp_rank == 0:
                    # next chunk is the first chunk, no need to pre-recv an input tensor
                    recv_prev = False
            # last micro step, no next run
            if micro_step == (num_steps - 1):
                recv_prev = False

            # last stage shouldn't send tensor to downstream
            if self.is_pipeline_last_stage():
                output_tensor = None

            if micro_step == (startup_steps -
                              1) and not forward_only and not all_startup_steps:
                input_tensor_grad = None
                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False

                # the last startup step needs four-direction comm to set up for steady 1f1b
                input_tensor, output_tensor_grad = p2p.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next)
                self.output_tensor_grads[self.num_model_chunks -
                                         1].append(output_tensor_grad)
            else:
                input_tensor = p2p.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev)
            self.input_tensors[next_virtual_pp_rank].append(input_tensor)

        # run 1f1b steady steps
        for micro_step in range(steady_steps):
            # forward
            forward_micro_step_id = micro_step + startup_steps
            output_tensor = self._forward_step_helper(forward_micro_step_id)

            # backward
            backward_micro_step_id = micro_step
            input_tensor_grad = self._backward_step_helper(
                backward_micro_step_id)

            # four directions comm
            # send output tensor to downstream
            # send input tensor grad to upstream
            # recv input tensor from upstream
            # recv output tensor grad from downstream

            # last stage doesn't send result to downstream
            forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id, forward=True)
            self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
            if self.is_pipeline_last_stage():
                output_tensor = None

            # first stage doesn't send grad to upstream
            backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id, forward=False)
            self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
            if self.is_pipeline_first_stage():
                input_tensor_grad = None

            # determine whether to recv input tensor from upstream
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id - (self.num_stages - 1), forward=True)
                if next_forward_virtual_pp_rank == (self.num_model_chunks - 1):
                    # first pp stage and first virtual stage
                    recv_prev = False
                next_forward_virtual_pp_rank += 1
            else:
                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id + 1, forward=True)

            # last iteration doesn't need recv from upstream
            if micro_step == (steady_steps - 1):
                recv_prev = False

            # determine whether to recv grad from downstream
            recv_next = True
            if self.is_pipeline_last_stage(ignore_virtual=True):
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id - (self.num_stages - 1),
                    forward=False)
                if next_backward_virtual_pp_rank == 0:
                    # last pp stage and last virtual stage
                    recv_next = False
                next_backward_virtual_pp_rank -= 1
            else:
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id + 1, forward=False)

            input_tensor, output_tensor_grad = p2p.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next)

            if recv_prev:
                self.input_tensors[next_forward_virtual_pp_rank].append(
                    input_tensor)
            if recv_next:
                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    output_tensor_grad)

        # remaining backward steps
        if not forward_only:
            if all_startup_steps:
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    p2p.recv_backward(self.is_pipeline_last_stage()))

            for micro_step in range(steady_steps, num_steps):
                # cooldown loop
                input_tensor_grad = self._backward_step_helper(micro_step)
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    micro_step + 1, forward=False)

                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    if next_backward_virtual_pp_rank == (self.num_model_chunks -
                                                         1):
                        recv_next = False

                if micro_step == (num_steps - 1):
                    recv_next = False

                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    p2p.send_backward_recv_backward(input_tensor_grad,
                                                    recv_next=recv_next))

            self._layers.allreduce_shared_weight_gradients()

        if compute_loss:
            # return loss if compute loss
            with paddle.amp.auto_cast(enable=False):
                train_loss = self._broadcast_final_loss()
        else:
            # else just return all intermediate output tensor for all micro steps
            train_loss = self.output_tensors

        return train_loss

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # interleave scheduler for pipeline parallel
        train_loss = self.interleave_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        return self.interleave_pipeline(data, None, forward_only=True)
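

# Illustrative usage sketch (not part of this module; assumes the standard fleet
# hybrid-parallel setup). PipelineParallel / PipelineParallelWithInterleave are
# normally created for the user by fleet.distributed_model() when pp_degree > 1:
#
#     from paddle.distributed import fleet
#     strategy = fleet.DistributedStrategy()
#     strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 1, "pp_degree": 2}
#     strategy.pipeline_configs = {"accumulate_steps": 4, "micro_batch_size": 2}
#     fleet.init(is_collective=True, strategy=strategy)
#     model = fleet.distributed_model(model)          # model is a PipelineLayer
#     optimizer = fleet.distributed_optimizer(optimizer)
#     loss = model.train_batch([inputs, labels], optimizer)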