#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.fluid as fluid
from .meta_parallel_base import MetaParallelBase
from .parallel_layers.pp_layers import PipelineLayer

from ..utils.hybrid_parallel_util import broadcast_mp_parameters
from ..utils.hybrid_parallel_util import broadcast_dp_parameters
from ..utils.hybrid_parallel_util import broadcast_sharding_parameters
from ..utils.log_util import logger
from ..meta_optimizers.dygraph_optimizer import (
    HybridParallelOptimizer,
    HybridParallelGradScaler,
)
import paddle.fluid.framework as framework
from .pp_utils import p2p_communication as p2p
import paddle.fluid.core as core

__all__ = []


class PipelineParallel(MetaParallelBase):
    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer."
            )
        super(PipelineParallel, self).__init__(layers, hcg, strategy)
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
        self.use_sharding_parallel = (
            self._hcg.get_sharding_parallel_world_size() > 1
        )

        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size'
        ]
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps'
        ]
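        # A single train_batch() call consumes micro_batch_size * accumulate_steps
        # samples; _load_micro_batch extracts one micro batch per micro step.
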
        # If the tensors being sent differ across hosts,
        # they shouldn't be sent partially and then concatenated into a whole tensor.
        self._enable_partial_send_recv = self._strategy.pipeline_configs[
            'enable_partial_send_recv'
        ]
        self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.pp_group = self._hcg.get_pipe_parallel_group()

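        # Virtual (interleaved) pipeline fields; the plain 1F1B scheduler in this
        # class keeps them as None, PipelineParallelWithInterleave fills them in.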
        self._virtual_pp_world_size = None
        self._virtual_pp_rank = None
        self._real_pp_world_size = self.num_stages
        self._real_pp_rank = self.stage_id

        p2p.initialize_p2p_groups(
            hcg, self._using_cache, self._enable_partial_send_recv
        )

        self.global_rank = self._hcg.get_global_rank()
        self.micro_batch_id = 0

        self._compute_loss = True

        logger.info(
            "Pipeline Info -- num_stages: {}, stage_id: {}".format(
                self.num_stages, self.stage_id
            )
        )

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_sharding_parallel:
            logger.info("start broadcast sharding parameters")
            broadcast_sharding_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

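    # Under interleaving, a rank counts as the first/last stage only when it sits
    # at the real pipeline boundary and is running the first/last virtual chunk;
    # pass ignore_virtual=True to check the real boundary alone.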
    def is_pipeline_first_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != 0:
                    return False
        assert self._real_pp_rank is not None
        return self._real_pp_rank == 0

    def is_pipeline_last_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != (self._virtual_pp_world_size - 1):
                    return False
        assert self._real_pp_rank is not None
        assert self._real_pp_world_size is not None
        return self._real_pp_rank == (self._real_pp_world_size - 1)

    def set_virtual_pipeline_rank(self, rank):
        self._virtual_pp_rank = rank

    def forward_backward_pipeline(self, data, scaler=None):
        # use the 1f1b scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
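        # The schedule has three phases: warm-up forwards, steady 1F1B pairs and
        # cooldown backwards. E.g. with num_stages=4 and accumulate_steps=8,
        # stage 0 runs 3 warm-up forwards, 5 steady steps and 3 cooldown backwards.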

        self.scaler = scaler

        # store data for train
        self.data = data

        # store total loss of entire batch
        self.total_loss = None

        # store data id for micro_batch
        self.micro_batch_id = 0

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

        input_buffers = []
        output_buffers = []

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor)

            output_tensor_grad = p2p.send_forward_recv_backward(
                output_tensor, self.is_pipeline_last_stage()
            )

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            input_tensor, output_tensor = input_buffers.pop(
                0
            ), output_buffers.pop(0)

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )

            if last_iter:
                input_tensor = None
                p2p.send_backward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )
            else:
                input_tensor = p2p.send_backward_recv_forward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )

        for i in range(startup_steps):
            input_tensor = input_buffers.pop(0)
            output_tensor = output_buffers.pop(0)

            output_tensor_grad = p2p.recv_backward(
                self.is_pipeline_last_stage()
            )

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )
            p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())

        self._layers.allreduce_shared_weight_gradients()
        with paddle.amp.auto_cast(enable=False):
            train_loss = self._broadcast_final_loss()
        return train_loss

    def _prepare_training(self, data, optimizer, lr_scheduler):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        assert isinstance(
            optimizer, HybridParallelOptimizer
        ), 'optimizer should be HybridParallelOptimizer subclass.'

        assert (
            fluid.framework._dygraph_tracer()._has_grad
        ), 'Please enable the generation of gradients.'

        if self.is_pipeline_first_stage(
            ignore_virtual=True
        ) or self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                data is not None
            ), "For the first and the last stage, the data must be set."
        else:
            data = None

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self._layers.train()

        return data

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # 1f1b scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        # save data for eval
        self.data = data
        # store data id for micro_batch
        self.micro_batch_id = 0

        # store total loss of entire batch
        self.total_loss = None

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

        input_buffers = []
        output_buffers = []

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not last_iter:
                input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        if self._compute_loss:
            self.train_loss = self._broadcast_final_loss()
        else:
            self.train_loss = output_buffers

        return self.train_loss

    def _forward_step(self, input_tensor, chunk_id=None):
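        # chunk_id selects the virtual model chunk under the interleave scheduler;
        # it stays None for the plain 1F1B scheduler.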
        if self.is_pipeline_first_stage():
            input_tensor = self._load_micro_batch(self.micro_batch_id)

        assert chunk_id is None or isinstance(chunk_id, int)

        output_tensor = self._layers.forward(input_tensor, chunk_id=chunk_id)

        if self.is_pipeline_last_stage():
            # calculate loss for training
            if self._compute_loss:
                assert (
                    self._layers._loss_fn is not None
                ), "loss function should exist to compute loss"
                labels = self._load_micro_batch(self.micro_batch_id)
                output_tensor = self._layers._loss_fn(output_tensor, labels)
                assert isinstance(
                    output_tensor, (paddle.Tensor, core.eager.Tensor)
                ), "Currently, loss_fn should return a paddle.Tensor or core.eager.Tensor"

                with paddle.amp.auto_cast(enable=False):
                    if self.accumulate_steps > 1:
                        output_tensor = output_tensor / self.accumulate_steps

                    if self.total_loss is None:
                        self.total_loss = paddle.zeros_like(output_tensor)
                    self.total_loss += output_tensor.detach()

        if self.is_pipeline_first_stage() or self.is_pipeline_last_stage():
            # Only increase micro batch id at virtual first/last pp stage.
            # The micro batch id is used to load data, so only increase it when loading data.
            self.micro_batch_id += 1
        return output_tensor

    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
        with paddle.amp.auto_cast(enable=False):
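            # The last stage starts backward from the (optionally scaler-scaled)
            # loss; other stages seed backward with the gradient received from
            # the next stage.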
            if self.is_pipeline_last_stage():
                assert output_tensor_grad is None
                if self.scaler:
                    paddle.autograd.backward(self.scaler.scale(output_tensor))
                else:
                    paddle.autograd.backward(output_tensor)
            else:
                if isinstance(output_tensor, tuple):
                    outputs = [t for t in output_tensor if not t.stop_gradient]
                    assert len(outputs) == len(output_tensor_grad)
                    paddle.autograd.backward(
                        tensors=outputs,
                        grad_tensors=[t for t in output_tensor_grad],
                    )
                else:
                    paddle.autograd.backward(
                        tensors=[output_tensor],
                        grad_tensors=[output_tensor_grad],
                    )

            input_tensor_grad = None
            if input_tensor is not None:
                if isinstance(input_tensor, tuple):
                    input_tensor_grad = tuple(
                        [t.grad for t in input_tensor if not t.stop_gradient]
                    )
                else:
                    input_tensor_grad = input_tensor.grad
            return input_tensor_grad

    def _check_data_vaild(self, data):
        batch_size = data.shape[0]
        assert self.micro_batch_size * self.accumulate_steps == batch_size, (
            "batch_size needs to be divisible by micro_batch_size. Currently, "
            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
            % (batch_size, self.micro_batch_size, self.accumulate_steps)
        )

    def _load_micro_batch_impl(self, inputs, cache_id):
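        # Inputs are either a list pre-split into accumulate_steps micro batches
        # (indexed by cache_id) or one large tensor that is sliced to
        # [cache_id * micro_batch_size : (cache_id + 1) * micro_batch_size].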
        begin = cache_id * self.micro_batch_size
        end = begin + self.micro_batch_size

        if isinstance(inputs, tuple):
            output = []
            for data in inputs:
                if isinstance(data, list):
                    assert (
                        len(data) == self.accumulate_steps
                    ), "length of data should be %d, but it is %d" % (
                        self.accumulate_steps,
                        len(data),
                    )
                    output.append(data[cache_id].detach())
                else:
                    self._check_data_vaild(data)
                    output.append(data[begin:end, :].detach())
            return tuple(output)

        elif isinstance(inputs, list):
            assert (
                len(inputs) == self.accumulate_steps
            ), "length of data should be %d, but it is %d" % (
                self.accumulate_steps,
                len(inputs),
            )
            return inputs[cache_id].detach()
        else:
            self._check_data_vaild(inputs)
            return inputs[begin:end, :].detach()

    def _load_micro_batch(self, cache_id):
        inputs = self.data
        if self.is_pipeline_first_stage():
            assert len(inputs) == 2, "length of input should be 2"
            return self._load_micro_batch_impl(inputs[0], cache_id)
        elif self.is_pipeline_last_stage():
            assert len(inputs) == 2, "length of input should be 2"
            return self._load_micro_batch_impl(inputs[1], cache_id)
        else:
            inputs = None

    def _broadcast_final_loss(self):
        # Since the last backward run in interleave will set the virtual rank to 0,
        # here we need to check last stage ignoring virtual stage.
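        # The last stage first broadcasts whether the loss is fp32 so receivers
        # can allocate a buffer of the matching dtype, then broadcasts the loss
        # itself to every rank in the pipeline group.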
        if self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                self.total_loss is not None
            ), "train_batch() in last stage should obtain vaild loss"
            loss = self.total_loss.detach()
            is_fp32 = (
                paddle.to_tensor(1)
                if loss.dtype == paddle.float32
                else paddle.to_tensor(0)
            )
            paddle.distributed.broadcast(
                is_fp32, src=self.global_rank, sync_op=True, group=self.pp_group
            )
            paddle.distributed.broadcast(
                loss, src=self.global_rank, sync_op=True, group=self.pp_group
            )
        else:
            is_fp32 = paddle.to_tensor(1)
            paddle.distributed.broadcast(
                is_fp32,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
            loss = (
                paddle.zeros(shape=[1], dtype="float32")
                if is_fp32.numpy()[0]
                else paddle.zeros(shape=[1], dtype="float16")
            )
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
        return loss

    def _optimizer_step(self):
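        # With AMP, step through the GradScaler so gradients are unscaled and the
        # update is skipped on inf/nan; otherwise step the optimizer directly.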
        if self.scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()


class PipelineParallelWithInterleave(PipelineParallel):
    # pipeline parallel with interleave scheduler

    def __init__(self, layers, hcg, strategy):
        super(PipelineParallelWithInterleave, self).__init__(
            layers=layers, hcg=hcg, strategy=strategy
        )
        assert layers.get_num_virtual_stages() > 1
        assert (
            framework.in_dygraph_mode()
        ), "virtual pipeline stage with interleave only support eager dygraph mode"
        # setup for interleave scheduler
        self.num_model_chunks = layers.get_num_virtual_stages()
        self.model_chunks = layers.get_model_chunks()
        assert self.model_chunks is not None
        assert len(self.model_chunks) == self.num_model_chunks
        self._virtual_pp_world_size = self.num_model_chunks
        self._virtual_pp_rank = 0

    def _get_virtual_pp_rank(self, micro_step, forward):
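        # Map a micro step to the model chunk handled at that step, e.g. with
        # num_stages=4 and num_model_chunks=2 the forward order of chunks is
        # 0,0,0,0,1,1,1,1 repeating; the backward order is mirrored.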
        virtual_pp_stage = micro_step % (
            self.num_stages * self.num_model_chunks
        )
        virtual_pp_stage = virtual_pp_stage // self.num_stages
        if not forward:
            virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
        return virtual_pp_stage

    def _forward_step_helper(self, micro_step):
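        # Run one forward micro step on the chunk owned at this step; input and
        # output tensors are queued per chunk so the matching backward step can
        # pop them later (unless running forward-only).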
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        if not self._forward_only:
            assert hasattr(self, 'output_tensor_grads')

        if self.is_pipeline_first_stage():
            if len(self.input_tensors[virtual_pp_rank]) == len(
                self.output_tensors[virtual_pp_rank]
            ):
                self.input_tensors[virtual_pp_rank].append(None)
        input_tensor = self.input_tensors[virtual_pp_rank][-1]
        output_tensor = self._forward_step(input_tensor, virtual_pp_rank)
        self.output_tensors[virtual_pp_rank].append(output_tensor)

        if self._forward_only:
            # no need to store tensor for backward
            self.input_tensors[virtual_pp_rank].pop()
            self.output_tensors[virtual_pp_rank].pop()

        return output_tensor

    def _backward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        assert hasattr(self, 'output_tensor_grads')

        if self.is_pipeline_last_stage():
            if len(self.output_tensor_grads[virtual_pp_rank]) == 0:
                self.output_tensor_grads[virtual_pp_rank].append(None)

        input_tensor = self.input_tensors[virtual_pp_rank].pop(0)
        output_tensor = self.output_tensors[virtual_pp_rank].pop(0)
        output_tensor_grad = self.output_tensor_grads[virtual_pp_rank].pop(0)
        input_tensor_grad = self._backward_step(
            input_tensor, output_tensor, output_tensor_grad
        )

        return input_tensor_grad

    def interleave_pipeline(
        self, data, scaler, forward_only=False, compute_loss=True
    ):
        # use interleave scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
        if not compute_loss:
            assert (
                forward_only
            ), "compute_loss can only be set to False when forward_only is set to True"

        # init some attributes for this batch run
        self.scaler = scaler
        self.data = data
        self.total_loss = None
        self.micro_batch_id = 0
        self._forward_only = forward_only

        # init some data buffers for interleave scheduler
        self.input_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]

        num_steps = self.accumulate_steps * self.num_model_chunks
        all_startup_steps = False
        if forward_only:
            # If only forward, since there is no backward during running, all steps are startup steps
            startup_steps = num_steps
        else:
            if self.accumulate_steps == self.num_stages:
                startup_steps = num_steps
                all_startup_steps = True
            else:
                startup_steps = (self.num_stages - self.stage_id - 1) * 2
                startup_steps += (self.num_model_chunks - 1) * self.num_stages
                startup_steps = min(startup_steps, num_steps)

        steady_steps = num_steps - startup_steps
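        # E.g. num_stages=4, num_model_chunks=2, accumulate_steps=8 on stage 0:
        # num_steps=16, startup_steps=(4-0-1)*2 + (2-1)*4=10, steady_steps=6,
        # followed by the remaining backward (cooldown) steps.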

        self.set_virtual_pipeline_rank(0)
        self.input_tensors[0].append(
            p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False)
        )

        # run startup steps
        for micro_step in range(startup_steps):
            output_tensor = self._forward_step_helper(micro_step)

            # determine whether recv forward tensor or not
            next_virtual_pp_rank = self._get_virtual_pp_rank(
                micro_step + 1, forward=True
            )
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                if next_virtual_pp_rank == 0:
                    # next chunk is the first chunk; no need to pre-recv an input tensor
                    recv_prev = False
            # last micro step, no next run
            if micro_step == (num_steps - 1):
                recv_prev = False

            # last stage shouldn't send tensor to downstream
            if self.is_pipeline_last_stage():
                output_tensor = None

            if (
                micro_step == (startup_steps - 1)
                and not forward_only
                and not all_startup_steps
            ):
                input_tensor_grad = None
                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False

                # the last startup step needs four-direction comm to set up for the steady 1f1b phase
                (
                    input_tensor,
                    output_tensor_grad,
                ) = p2p.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next,
                )
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    output_tensor_grad
                )
            else:
                input_tensor = p2p.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev
                )
            self.input_tensors[next_virtual_pp_rank].append(input_tensor)

        # run 1f1b steady steps
        for micro_step in range(steady_steps):
            # forward
            forward_micro_step_id = micro_step + startup_steps
            output_tensor = self._forward_step_helper(forward_micro_step_id)

            # backward
            backward_micro_step_id = micro_step
            input_tensor_grad = self._backward_step_helper(
                backward_micro_step_id
            )

            # four directions comm
            # send output tensor to downstream
            # send input tensor grad to upstream
            # recv input tensor from upstream
            # recv output tensor grad from downstream

            # last stage doesn't send its result to downstream
            forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id, forward=True
            )
            self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
            if self.is_pipeline_last_stage():
                output_tensor = None

            # first stage doesn't send grad to upstream
            backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id, forward=False
            )
            self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
            if self.is_pipeline_first_stage():
                input_tensor_grad = None

            # determine whether to recv input tensor from upstream
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id - (self.num_stages - 1), forward=True
                )
                if next_forward_virtual_pp_rank == (self.num_model_chunks - 1):
                    # first pp stage and first virtual stage
                    recv_prev = False
                next_forward_virtual_pp_rank += 1
            else:
                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id + 1, forward=True
                )

            # last iteration doesn't need recv from upstream
            if micro_step == (steady_steps - 1):
                recv_prev = False

            # determine whether to recv grad from downstream
            recv_next = True
            if self.is_pipeline_last_stage(ignore_virtual=True):
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id - (self.num_stages - 1),
                    forward=False,
                )
                if next_backward_virtual_pp_rank == 0:
                    # last pp stage and last virtual stage
                    recv_next = False
                next_backward_virtual_pp_rank -= 1
            else:
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id + 1, forward=False
                )

            (
                input_tensor,
                output_tensor_grad,
            ) = p2p.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
            )

            if recv_prev:
                self.input_tensors[next_forward_virtual_pp_rank].append(
                    input_tensor
                )
            if recv_next:
                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    output_tensor_grad
                )

        # remaining backward steps
        if not forward_only:
            if all_startup_steps:
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    p2p.recv_backward(
                        self.is_pipeline_last_stage(), sync_recv=False
                    )
                )

            for micro_step in range(steady_steps, num_steps):
                # cooldown loop
                input_tensor_grad = self._backward_step_helper(micro_step)
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    micro_step + 1, forward=False
                )

                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    if next_backward_virtual_pp_rank == (
                        self.num_model_chunks - 1
                    ):
                        recv_next = False

                if micro_step == (num_steps - 1):
                    recv_next = False

                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    p2p.send_backward_recv_backward(
                        input_tensor_grad, recv_next=recv_next
                    )
                )

            self._layers.allreduce_shared_weight_gradients()

        if compute_loss:
            # return loss if compute loss
            with paddle.amp.auto_cast(enable=False):
                train_loss = self._broadcast_final_loss()
        else:
            # else just return all intermediate output tensors for all micro steps
            train_loss = self.output_tensors

        return train_loss

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # interleave scheduler for pipeline parallel
        train_loss = self.interleave_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        return self.interleave_pipeline(data, None, forward_only=True)