pipeline_parallel.py 28.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.fluid as fluid
from .meta_parallel_base import MetaParallelBase
17
from .pp_utils.utils import is_float_tensor, _initialize_recompute_hcg
18
from .parallel_layers.pp_layers import PipelineLayer
19 20 21

from ..utils.hybrid_parallel_util import broadcast_mp_parameters
from ..utils.hybrid_parallel_util import broadcast_dp_parameters
22
from ..utils.hybrid_parallel_util import broadcast_sharding_parameters
23
from ..utils.log_util import logger
24
from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer, HybridParallelGradScaler
25
import paddle.fluid.framework as framework
S
ShenLiang 已提交
26
from .pp_utils import p2p_communication as p2p
S
ShenLiang 已提交
27
import paddle.fluid.core as core
28

29 30
__all__ = []

31 32

class PipelineParallel(MetaParallelBase):
33

34
    def __init__(self, layers, hcg, strategy):
        """Set up pipeline-parallel state and synchronise initial parameters.

        Args:
            layers: model to run; must be a ``PipelineLayer`` instance.
            hcg: hybrid communicate group, queried for world sizes, ranks
                and the pipeline communication group.
            strategy: strategy object; its ``pipeline_configs`` mapping
                supplies ``micro_batch_size``, ``accumulate_steps`` and
                ``p2p_cache_shape``.

        Raises:
            TypeError: if ``layers`` is not a ``PipelineLayer``.
        """
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer.")
        super(PipelineParallel, self).__init__(layers, hcg, strategy)

        # NOTE(review): self._hcg / self._strategy / self._layers are read
        # below and are presumably populated by the base initialiser.
        group_info = self._hcg
        self.use_data_parallel = group_info.get_data_parallel_world_size() > 1
        self.use_model_parallel = (
            group_info.get_model_parallel_world_size() > 1)
        self.use_sharding_parallel = (
            group_info.get_sharding_parallel_world_size() > 1)

        # Accumulated loss of the batch currently being processed.
        self.total_loss = None

        pipeline_cfg = self._strategy.pipeline_configs
        self.micro_batch_size = pipeline_cfg['micro_batch_size']
        self.accumulate_steps = pipeline_cfg['accumulate_steps']
        self._using_cache = pipeline_cfg['p2p_cache_shape']

        self.num_stages = group_info.get_pipe_parallel_world_size()
        self.stage_id = group_info.get_stage_id()
        self.pp_group = group_info.get_pipe_parallel_group()

        # Virtual-stage bookkeeping is left unset here; the "real" values
        # mirror the physical pipeline layout.
        self._virtual_pp_world_size = None
        self._virtual_pp_rank = None
        self._real_pp_world_size = self.num_stages
        self._real_pp_rank = self.stage_id

        p2p.initialize_p2p_groups(hcg, self._using_cache)
        _initialize_recompute_hcg(hcg)

        self.global_rank = group_info.get_global_rank()
        # Index of the next micro batch to load from the stored batch data.
        self.micro_batch_id = 0
        self._compute_loss = True

        logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format(
            self.num_stages, self.stage_id))

        # Parameter broadcasts run in a fixed order: mp, sharding, dp.
        broadcast_plan = (
            (self.use_model_parallel, "start broadcast mp parameters",
             broadcast_mp_parameters),
            (self.use_sharding_parallel,
             "start broadcast sharding parameters",
             broadcast_sharding_parameters),
            (self.use_data_parallel, "start broadcast dp parameters",
             broadcast_dp_parameters),
        )
        for enabled, message, broadcast_fn in broadcast_plan:
            if enabled:
                logger.info(message)
                broadcast_fn(self._layers, self._hcg)
85

86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
    def is_pipeline_first_stage(self, ignore_virtual=False):
        """Return whether this rank currently acts as the first stage.

        Args:
            ignore_virtual: when False and a virtual pipeline world size is
                configured, the virtual rank must also be 0; when True only
                the real pipeline rank is considered.

        Returns:
            bool: True if the real rank is 0 (and, unless ignored, the
            virtual rank is 0 as well).
        """
        if not ignore_virtual and self._virtual_pp_world_size is not None:
            assert self._virtual_pp_rank is not None
            # A non-zero virtual rank can never be the first stage.
            if self._virtual_pp_rank != 0:
                return False
        assert self._real_pp_rank is not None
        return self._real_pp_rank == 0

    def is_pipeline_last_stage(self, ignore_virtual=False):
        """Return whether this rank currently acts as the last stage.

        Args:
            ignore_virtual: when False and a virtual pipeline world size is
                configured, the virtual rank must also be the last virtual
                rank; when True only the real pipeline rank is considered.

        Returns:
            bool: True if the real rank is the last real rank (and, unless
            ignored, the virtual rank is the last virtual rank).
        """
        if not ignore_virtual and self._virtual_pp_world_size is not None:
            assert self._virtual_pp_rank is not None
            final_virtual_rank = self._virtual_pp_world_size - 1
            if self._virtual_pp_rank != final_virtual_rank:
                return False
        assert self._real_pp_rank is not None
        assert self._real_pp_world_size is not None
        final_real_rank = self._real_pp_world_size - 1
        return self._real_pp_rank == final_real_rank

    def set_virtual_pipeline_rank(self, rank):
        """Record ``rank`` as the current virtual pipeline rank.

        The stored value is consulted by ``is_pipeline_first_stage`` and
        ``is_pipeline_last_stage`` when virtual stages are not ignored.
        """
        self._virtual_pp_rank = rank

108 109 110 111
    def forward_backward_pipeline(self, data, scaler=None):
        """Run one batch through the 1F1B pipeline schedule.

        The schedule has three phases: warm-up forwards, a steady phase that
        alternates one forward with one backward, and a cool-down that drains
        the remaining backwards. It is inspired by
        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py

        Args:
            data: batch data, stored on ``self.data`` for micro-batch loading.
            scaler: optional gradient scaler, stored on ``self.scaler`` and
                used by ``_backward_step``.

        Returns:
            The value produced by ``_broadcast_final_loss``.
        """
        # Per-batch state consumed by the forward/backward step helpers.
        self.scaler = scaler
        self.data = data
        self.total_loss = None
        self.micro_batch_id = 0

        warmup_count = min(self.num_stages - self.stage_id - 1,
                           self.accumulate_steps)
        steady_count = self.accumulate_steps - warmup_count

        # FIFO queues of activations whose backward pass is still pending.
        pending_inputs = []
        pending_outputs = []

        # Phase 1: warm-up, forward passes only.
        for _ in range(warmup_count):
            received = p2p.recv_forward(self.is_pipeline_first_stage())

            produced = self._forward_step(received)
            p2p.send_forward(produced, self.is_pipeline_last_stage())

            pending_inputs.append(received)
            pending_outputs.append(produced)

        # Prime the steady phase with its first input.
        if steady_count > 0:
            received = p2p.recv_forward(self.is_pipeline_first_stage())

        # Phase 2: one forward followed by one backward per iteration.
        for step in range(steady_count):
            produced = self._forward_step(received)

            grad_of_output = p2p.send_forward_recv_backward(
                produced, self.is_pipeline_last_stage())

            pending_inputs.append(received)
            pending_outputs.append(produced)

            # Backward always runs on the oldest outstanding micro batch.
            oldest_input = pending_inputs.pop(0)
            oldest_output = pending_outputs.pop(0)

            grad_of_input = self._backward_step(oldest_input, oldest_output,
                                                grad_of_output)

            if step == steady_count - 1:
                # No further forward input is expected after the last step.
                received = None
                p2p.send_backward(grad_of_input,
                                  self.is_pipeline_first_stage())
            else:
                received = p2p.send_backward_recv_forward(
                    grad_of_input, self.is_pipeline_first_stage())

        # Phase 3: cool-down, drain the backwards left from warm-up.
        for _ in range(warmup_count):
            oldest_input = pending_inputs.pop(0)
            oldest_output = pending_outputs.pop(0)

            grad_of_output = p2p.recv_backward(self.is_pipeline_last_stage())

            grad_of_input = self._backward_step(oldest_input, oldest_output,
                                                grad_of_output)
            p2p.send_backward(grad_of_input, self.is_pipeline_first_stage())

        self._layers.allreduce_shared_weight_gradients()
        with paddle.amp.auto_cast(enable=False):
            train_loss = self._broadcast_final_loss()
        return train_loss

184 185 186 187
    def _prepare_training(self, data, optimizer, lr_scheduler):
        """Validate inputs and switch the model into training mode.

        Args:
            data: batch data; must not be None on the first and last real
                pipeline stages. It is discarded on every other stage.
            optimizer: a ``HybridParallelOptimizer`` instance.
            lr_scheduler: learning-rate scheduler, stored for
                ``_optimizer_step``.

        Returns:
            ``data`` on the first/last real stage, otherwise None.
        """
        # Every run starts from virtual pipeline rank 0.
        self.set_virtual_pipeline_rank(0)

        assert isinstance(optimizer, HybridParallelOptimizer), (
            'optimizer should be HybridParallelOptimizer subclass.')

        assert fluid.framework._dygraph_tracer()._has_grad, (
            'Please enable the generation of gradients.')

        on_boundary_stage = (
            self.is_pipeline_first_stage(ignore_virtual=True)
            or self.is_pipeline_last_stage(ignore_virtual=True))
        if not on_boundary_stage:
            # Stages in the middle of the pipeline never read the batch.
            data = None
        else:
            assert data is not None, (
                "For the first and the last stage, the data must be set.")

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self._layers.train()

        return data

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        """Train on one batch with the 1F1B schedule and step the optimizer.

        Args:
            data: batch data (see ``_prepare_training`` for requirements).
            optimizer: a ``HybridParallelOptimizer`` instance.
            lr_scheduler: optional learning-rate scheduler.
            scaler: optional gradient scaler.

        Returns:
            The loss returned by ``forward_backward_pipeline``.
        """
        stage_data = self._prepare_training(data, optimizer, lr_scheduler)

        # 1f1b scheduler for pipeline parallel
        loss = self.forward_backward_pipeline(stage_data, scaler)

        # The optimizer update runs with auto-cast disabled.
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return loss
219

220
    def eval_batch(self, data, compute_loss=False):
        """Run forward-only evaluation over one batch.

        Args:
            data: batch data, stored on ``self.data`` for micro-batch loading.
            compute_loss: when True the accumulated loss is broadcast and
                returned; otherwise the per-micro-batch outputs are returned.

        Returns:
            The broadcast loss when ``compute_loss`` is True, else the list
            of outputs from each forward step. The result is also stored on
            ``self.train_loss``.
        """
        # Every run starts from virtual pipeline rank 0.
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        # Per-batch state consumed by _forward_step.
        self.data = data
        self.micro_batch_id = 0
        self.total_loss = None

        kept_inputs = []
        kept_outputs = []

        # With no backward pass, the warm-up and steady phases issue the same
        # recv -> forward -> send sequence, so a single loop over all
        # accumulate steps performs exactly the same calls in the same order.
        for _ in range(self.accumulate_steps):
            received = p2p.recv_forward(self.is_pipeline_first_stage())

            produced = self._forward_step(received)
            p2p.send_forward(produced, self.is_pipeline_last_stage())

            kept_inputs.append(received)
            kept_outputs.append(produced)

        if self._compute_loss:
            self.train_loss = self._broadcast_final_loss()
        else:
            self.train_loss = kept_outputs

        return self.train_loss
272

273 274
    def _forward_step(self, input_tensor, chunk_id=None):
        """Run the forward pass for one micro batch on this stage.

        Args:
            input_tensor: activation received from upstream; replaced by the
                loaded micro batch when this is the first stage.
            chunk_id: optional int forwarded to the layers' ``forward``.

        Returns:
            The stage output, or the (scaled) loss when this is the last
            stage and loss computation is enabled.
        """
        if self.is_pipeline_first_stage():
            # The first stage reads its input from the stored batch data.
            input_tensor = self._load_micro_batch(self.micro_batch_id)

        assert chunk_id is None or isinstance(chunk_id, int)

        result = self._layers.forward(input_tensor, chunk_id=chunk_id)

        if self.is_pipeline_last_stage() and self._compute_loss:
            # train calculate loss for train
            assert self._layers._loss_fn is not None, "loss function should exist to compute loss"
            labels = self._load_micro_batch(self.micro_batch_id)
            result = self._layers._loss_fn(result, labels)
            assert isinstance(
                result,
                (paddle.Tensor, core.eager.Tensor
                 )), "Currently, loss_fn should obtain Paddle.Tensor dtype"

            with paddle.amp.auto_cast(enable=False):
                # Each micro batch contributes 1/accumulate_steps of the loss.
                if self.accumulate_steps > 1:
                    result = result / self.accumulate_steps

                if self.total_loss is None:
                    self.total_loss = paddle.zeros_like(result)
                self.total_loss += result.detach()

        if self.is_pipeline_first_stage() or self.is_pipeline_last_stage():
            # Only increase micro batch id at virtual first/last pp stage.
            # The id is used to load data, so it only advances on stages
            # that actually load data.
            self.micro_batch_id += 1
        return result

    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
        """Run the backward pass for one micro batch on this stage.

        Args:
            input_tensor: the stage input (tensor, tuple of tensors or None).
            output_tensor: the stage output (tensor or tuple of tensors); on
                the last stage this is the value backward starts from.
            output_tensor_grad: gradient(s) received from downstream; must be
                None on the last stage.

        Returns:
            The gradient(s) of ``input_tensor``: None when there is no input,
            a tuple of ``.grad`` for tuple inputs (skipping tensors with
            ``stop_gradient`` set), or ``input_tensor.grad`` otherwise.
        """
        with paddle.amp.auto_cast(enable=False):
            if self.is_pipeline_last_stage():
                assert output_tensor_grad is None
                # Backward starts here; scale first when a scaler is set.
                start = (self.scaler.scale(output_tensor)
                         if self.scaler else output_tensor)
                paddle.autograd.backward(start)
            elif isinstance(output_tensor, tuple):
                # Only outputs that take part in autograd get a gradient.
                differentiable = [
                    t for t in output_tensor if not t.stop_gradient
                ]
                assert len(differentiable) == len(output_tensor_grad)
                paddle.autograd.backward(
                    tensors=differentiable,
                    grad_tensors=list(output_tensor_grad))
            else:
                paddle.autograd.backward(tensors=[output_tensor],
                                         grad_tensors=[output_tensor_grad])

            if input_tensor is None:
                return None
            if isinstance(input_tensor, tuple):
                return tuple(
                    t.grad for t in input_tensor if not t.stop_gradient)
            return input_tensor.grad
333 334

    def _load_micro_batch(self, cache_id):
        """Slice micro batch ``cache_id`` out of the stored batch data.

        ``self.data`` must have two entries: entry 0 is read on the first
        stage and entry 1 on the last stage. Each entry is either a single
        tensor or a tuple of tensors; the slice is taken along the first
        axis and detached.

        Args:
            cache_id: index of the micro batch to load.

        Returns:
            The detached slice (tensor or tuple of tensors) on the first or
            last stage, and None on every other stage.
        """
        batch = self.data
        lo = cache_id * self.micro_batch_size
        hi = lo + self.micro_batch_size

        # The virtual first and last pipeline stage need data, all others don't need.
        if self.is_pipeline_first_stage():
            assert len(batch) == 2, "length of input should be 2"
            features = batch[0]
            if not isinstance(features, tuple):
                batch_size = features.shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                return features[lo:hi, :].detach()

            assert len(
                features
            ) > 1, "If you use tuple for input data, it should have at least two inputs."
            batch_size = features[0].shape[0]
            assert self.micro_batch_size * self.accumulate_steps == batch_size, (
                "batch_size needs to be divisible by micro_batch_size. Currently, "
                "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
                %
                (batch_size, self.micro_batch_size, self.accumulate_steps))
            return tuple(item[lo:hi, :].detach() for item in features)

        if self.is_pipeline_last_stage():
            assert len(batch) == 2, "length of input should be 2"
            targets = batch[1]
            if not isinstance(targets, tuple):
                batch_size = targets.shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                return targets[lo:hi, :].detach()

            batch_size = targets[0].shape[0]
            assert self.micro_batch_size * self.accumulate_steps == batch_size
            return tuple(item[lo:hi, :].detach() for item in targets)

        # No data input is required for other stages
        return None
372

373
    def _broadcast_final_loss(self):
        """Broadcast the accumulated loss from the last stage to the group.

        Two broadcasts are issued over ``self.pp_group``: first a flag that
        tells receivers whether the loss is float32, then the loss itself.
        Receivers allocate a one-element float32 or float16 buffer based on
        the flag.

        Returns:
            The loss tensor (the detached total loss on the last stage, the
            received buffer elsewhere).
        """
        # Since the last backward run in interleave will set the virtual rank to 0,
        # here we need to check last stage ignoring virtual stage.
        if self.is_pipeline_last_stage(ignore_virtual=True):
            assert self.total_loss is not None, "train_batch() in last stage should obtain vaild loss"
            loss = self.total_loss.detach()
            if loss.dtype == paddle.float32:
                fp32_flag = paddle.to_tensor(1)
            else:
                fp32_flag = paddle.to_tensor(0)
            # Sender side: flag first, then the payload.
            for payload in (fp32_flag, loss):
                paddle.distributed.broadcast(payload,
                                             src=self.global_rank,
                                             use_calc_stream=True,
                                             group=self.pp_group)
            return loss

        # Receiver side: the flag buffer is overwritten by the broadcast.
        fp32_flag = paddle.to_tensor(1)
        paddle.distributed.broadcast(
            fp32_flag,
            src=self._hcg.get_rank_from_stage(self.num_stages - 1),
            use_calc_stream=True,
            group=self.pp_group)
        if fp32_flag.numpy()[0]:
            loss = paddle.zeros(shape=[1], dtype="float32")
        else:
            loss = paddle.zeros(shape=[1], dtype="float16")
        paddle.distributed.broadcast(
            loss,
            src=self._hcg.get_rank_from_stage(self.num_stages - 1),
            use_calc_stream=True,
            group=self.pp_group)
        return loss
406

407
    def _optimizer_step(self):
        """Apply one optimizer update, then clear gradients.

        When a scaler is set the step goes through ``scaler.step`` followed
        by ``scaler.update``; otherwise the optimizer steps directly. The
        learning-rate scheduler, if any, is stepped last.
        """
        if not self.scaler:
            self.optimizer.step()
        else:
            self.scaler.step(self.optimizer)
            self.scaler.update()

        self.optimizer.clear_grad()

        if self.lr_scheduler:
            self.lr_scheduler.step()



class PipelineParallelWithInterleave(PipelineParallel):
    # pipeline parallel with interleave scheduler

    def __init__(self, layers, hcg, strategy):
        """Initialise the interleaved scheduler on top of ``PipelineParallel``.

        Args:
            layers: model exposing ``get_num_virtual_stages()`` (must be > 1)
                and ``get_model_chunks()``.
            hcg: hybrid communicate group, forwarded to the base class.
            strategy: strategy object, forwarded to the base class.
        """
        super(PipelineParallelWithInterleave, self).__init__(layers=layers,
                                                             hcg=hcg,
                                                             strategy=strategy)
        # Interleaving needs more than one virtual stage and eager dygraph.
        assert layers.get_num_virtual_stages() > 1
        assert framework.in_dygraph_mode(
        ), "virtual pipeline stage with interleave only support eager dygraph mode"

        # setup for interleave scheduler
        self.num_model_chunks = layers.get_num_virtual_stages()
        self.model_chunks = layers.get_model_chunks()
        assert self.model_chunks is not None
        assert len(self.model_chunks) == self.num_model_chunks

        # One virtual stage per model chunk; runs start at virtual rank 0.
        self._virtual_pp_world_size = self.num_model_chunks
        self._virtual_pp_rank = 0

    def _get_virtual_pp_rank(self, micro_step, forward):
        virtual_pp_stage = micro_step % (self.num_stages *
                                         self.num_model_chunks)
        virtual_pp_stage = virtual_pp_stage // self.num_stages
        if not forward:
            virtual_pp_stage = (self.num_model_chunks - virtual_pp_stage - 1)
        return virtual_pp_stage

    def _forward_step_helper(self, micro_step):
        """Run the forward pass of ``micro_step`` on the chunk it maps to.

        Selects the virtual rank for the step, takes the most recent input
        queued for that chunk, runs ``_forward_step`` and queues the output.
        In forward-only mode both queued entries are dropped again because
        no backward pass will consume them.

        Args:
            micro_step: index of the forward micro step.

        Returns:
            The output of ``_forward_step`` for this step.
        """
        chunk = self._get_virtual_pp_rank(micro_step, forward=True)
        self.set_virtual_pipeline_rank(chunk)

        # The per-chunk buffers must have been created by the scheduler.
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        if not self._forward_only:
            assert hasattr(self, 'output_tensor_grads')

        chunk_inputs = self.input_tensors[chunk]
        chunk_outputs = self.output_tensors[chunk]

        # On the first stage, equal queue lengths mean no input was queued
        # for this step, so a None placeholder is added.
        if self.is_pipeline_first_stage() and len(chunk_inputs) == len(
                chunk_outputs):
            chunk_inputs.append(None)

        result = self._forward_step(chunk_inputs[-1], chunk)
        chunk_outputs.append(result)

        if self._forward_only:
            # no need to store tensor for backward
            chunk_inputs.pop()
            chunk_outputs.pop()

        return result

    def _backward_step_helper(self, micro_step):
        """Run the backward pass of ``micro_step`` on the chunk it maps to.

        Selects the virtual rank for the step, pops the oldest queued input,
        output and output gradient for that chunk and runs ``_backward_step``.

        Args:
            micro_step: index of the backward micro step.

        Returns:
            The input gradient returned by ``_backward_step``.
        """
        chunk = self._get_virtual_pp_rank(micro_step, forward=False)
        self.set_virtual_pipeline_rank(chunk)

        # The per-chunk buffers must have been created by the scheduler.
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        assert hasattr(self, 'output_tensor_grads')

        chunk_grads = self.output_tensor_grads[chunk]
        # The last stage receives no gradient from downstream; queue a None
        # placeholder when nothing is waiting.
        if self.is_pipeline_last_stage() and len(chunk_grads) == 0:
            chunk_grads.append(None)

        oldest_input = self.input_tensors[chunk].pop(0)
        oldest_output = self.output_tensors[chunk].pop(0)
        oldest_grad = chunk_grads.pop(0)

        return self._backward_step(oldest_input, oldest_output, oldest_grad)

    def interleave_pipeline(self,
                            data,
                            scaler,
                            forward_only=False,
                            compute_loss=True):
        """Run one batch through the interleaved (virtual stage) schedule.

        The schedule has a start-up phase of forward steps, a steady 1F1B
        phase with four-direction communication, and a cool-down phase that
        drains the remaining backward steps. It is inspired by
        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py

        Args:
            data: batch data, stored on ``self.data`` for micro-batch loading.
            scaler: gradient scaler (or None), stored on ``self.scaler``.
            forward_only: when True no backward step is run and every micro
                step is treated as a start-up step.
            compute_loss: when True the broadcast loss is returned; when
                False ``self.output_tensors`` is returned instead. False is
                only allowed together with ``forward_only=True``.

        Returns:
            The value of ``_broadcast_final_loss()`` when ``compute_loss`` is
            True, otherwise ``self.output_tensors``.
            NOTE(review): in forward-only mode ``_forward_step_helper`` pops
            each output right after producing it, so the returned per-chunk
            lists look empty in that case; confirm this is intended.

        Raises:
            AssertionError: if ``compute_loss`` is False while
                ``forward_only`` is False.
        """
        if not compute_loss:
            # The original guard asserted the opposite (`not forward_only`),
            # contradicting its own message: it rejected the only valid
            # combination and accepted the invalid one.
            assert forward_only, "compute_loss can only be set to False when forward_only is set to True"

        # init some attributes for this batch run
        self.scaler = scaler
        self.data = data
        self.total_loss = None
        self.micro_batch_id = 0
        self._forward_only = forward_only

        # init some data buffers for interleave scheduler, one queue per chunk
        self.input_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]

        num_steps = self.accumulate_steps * self.num_model_chunks
        all_startup_steps = False
        if forward_only:
            # If only forward, since there is no backward during running, all steps are startup steps
            startup_steps = num_steps
        else:
            if self.accumulate_steps == self.num_stages:
                startup_steps = num_steps
                all_startup_steps = True
            else:
                startup_steps = (self.num_stages - self.stage_id - 1) * 2
                startup_steps += (self.num_model_chunks - 1) * self.num_stages
                startup_steps = min(startup_steps, num_steps)

        steady_steps = num_steps - startup_steps

        # Prime chunk 0 with its first input.
        self.set_virtual_pipeline_rank(0)
        self.input_tensors[0].append(
            p2p.recv_forward(self.is_pipeline_first_stage()))

        # run startup steps
        for micro_step in range(startup_steps):
            output_tensor = self._forward_step_helper(micro_step)

            # determine whether recv forward tensor or not
            next_virtual_pp_rank = self._get_virtual_pp_rank(micro_step + 1,
                                                             forward=True)
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                if next_virtual_pp_rank == 0:
                    # next chunk is the first chunk, no need to pre-recv an input tensor
                    recv_prev = False
            # last micro step, no next run
            if micro_step == (num_steps - 1):
                recv_prev = False

            # last stage shouldn't send tensor to downstream
            if self.is_pipeline_last_stage():
                output_tensor = None

            if micro_step == (startup_steps -
                              1) and not forward_only and not all_startup_steps:
                input_tensor_grad = None
                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False

                # the last startup step needs a four-direction comm to set up for steady 1f1b
                input_tensor, output_tensor_grad = p2p.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next)
                self.output_tensor_grads[self.num_model_chunks -
                                         1].append(output_tensor_grad)
            else:
                input_tensor = p2p.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev)
            self.input_tensors[next_virtual_pp_rank].append(input_tensor)

        # run 1f1b steady steps
        for micro_step in range(steady_steps):
            # forward
            forward_micro_step_id = micro_step + startup_steps
            output_tensor = self._forward_step_helper(forward_micro_step_id)

            # backward
            backward_micro_step_id = micro_step
            input_tensor_grad = self._backward_step_helper(
                backward_micro_step_id)

            # four directions comm
            # send output tensor to downstream
            # send input tensor grad to upstream
            # recv input tensor from upstream
            # recv output tensor grad from downstream

            # last stage doesn't send the result to downstream
            forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id, forward=True)
            self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
            if self.is_pipeline_last_stage():
                output_tensor = None

            # first stage doesn't send grad to upstream
            backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id, forward=False)
            self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
            if self.is_pipeline_first_stage():
                input_tensor_grad = None

            # determine whether to recv input tensor from upstream
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id - (self.num_stages - 1), forward=True)
                if next_forward_virtual_pp_rank == (self.num_model_chunks - 1):
                    # first pp stage and first virtual stage
                    recv_prev = False
                next_forward_virtual_pp_rank += 1
            else:
                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id + 1, forward=True)

            # last iteration doesn't need recv from upstream
            if micro_step == (steady_steps - 1):
                recv_prev = False

            # determine whether to recv grad from downstream
            recv_next = True
            if self.is_pipeline_last_stage(ignore_virtual=True):
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id - (self.num_stages - 1),
                    forward=False)
                if next_backward_virtual_pp_rank == 0:
                    # last pp stage and last virtual stage
                    recv_next = False
                next_backward_virtual_pp_rank -= 1
            else:
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id + 1, forward=False)

            input_tensor, output_tensor_grad = p2p.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next)

            # Only queue what was actually received.
            if recv_prev:
                self.input_tensors[next_forward_virtual_pp_rank].append(
                    input_tensor)
            if recv_next:
                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    output_tensor_grad)

        # remaining backward steps
        if not forward_only:
            if all_startup_steps:
                # No four-direction comm happened during start-up, so the
                # first gradient for the last chunk is fetched here.
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    p2p.recv_backward(self.is_pipeline_last_stage()))

            for micro_step in range(steady_steps, num_steps):
                # cooldown loop
                input_tensor_grad = self._backward_step_helper(micro_step)
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    micro_step + 1, forward=False)

                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    if next_backward_virtual_pp_rank == (self.num_model_chunks -
                                                         1):
                        recv_next = False

                # last backward step, nothing more to receive
                if micro_step == (num_steps - 1):
                    recv_next = False

                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    p2p.send_backward_recv_backward(input_tensor_grad,
                                                    recv_next=recv_next))

            self._layers.allreduce_shared_weight_gradients()

        if compute_loss:
            # return loss if compute loss
            with paddle.amp.auto_cast(enable=False):
                train_loss = self._broadcast_final_loss()
        else:
            # else just return all intermediate output tensor for all micro steps
            train_loss = self.output_tensors

        return train_loss

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        """Train on one batch with the interleaved schedule and step the optimizer.

        Args:
            data: batch data (see ``_prepare_training`` for requirements).
            optimizer: a ``HybridParallelOptimizer`` instance.
            lr_scheduler: optional learning-rate scheduler.
            scaler: optional gradient scaler.

        Returns:
            The loss returned by ``interleave_pipeline``.
        """
        stage_data = self._prepare_training(data, optimizer, lr_scheduler)

        # interleave scheduler for pipeline parallel
        loss = self.interleave_pipeline(stage_data, scaler)

        # The optimizer update runs with auto-cast disabled.
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return loss

    def eval_batch(self, data, compute_loss=False):
        """Run forward-only evaluation with the interleaved schedule.

        Args:
            data: batch data passed on to ``interleave_pipeline``.
            compute_loss: stored on ``self._compute_loss``, which
                ``_forward_step`` reads to decide whether to compute a loss.

        Returns:
            Whatever ``interleave_pipeline`` returns in forward-only mode.
        """
        # NOTE(review): ``compute_loss`` is not forwarded to
        # ``interleave_pipeline``, which therefore runs with its own default
        # ``compute_loss=True`` and calls ``_broadcast_final_loss``. With the
        # default ``compute_loss=False`` here no loss is accumulated, so the
        # last stage would hit the "total_loss is not None" assertion.
        # Confirm whether the flag should be passed through.

        # Every run starts from virtual pipeline rank 0.
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        result = self.interleave_pipeline(data, None, forward_only=True)
        return result