pipeline_parallel.py
#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.fluid as fluid
from .meta_parallel_base import MetaParallelBase
from .parallel_layers.pp_layers import PipelineLayer

from ..utils.hybrid_parallel_util import broadcast_mp_parameters
from ..utils.hybrid_parallel_util import broadcast_dp_parameters
from ..utils.hybrid_parallel_util import broadcast_sharding_parameters
from ..utils.log_util import logger
from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
import paddle.fluid.framework as framework
from .pp_utils import p2p_communication as p2p
import paddle.fluid.core as core

__all__ = []


class PipelineParallel(MetaParallelBase):
    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer."
            )
        super(PipelineParallel, self).__init__(layers, hcg, strategy)
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
        self.use_sharding_parallel = (
            self._hcg.get_sharding_parallel_world_size() > 1
        )

        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size'
        ]
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps'
        ]
        # If the tensors being sent differ across hosts,
        # they shouldn't be sent partially and then concatenated into a whole tensor.
        self._enable_partial_send_recv = self._strategy.pipeline_configs[
            'enable_partial_send_recv'
        ]
        self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']
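        # These values come from the fleet DistributedStrategy passed in as
        # `strategy`.  A typical configuration (illustrative values) might be:
        #
        #   import paddle.distributed.fleet as fleet
        #   strategy = fleet.DistributedStrategy()
        #   strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 1, "pp_degree": 4}
        #   strategy.pipeline_configs = {"micro_batch_size": 2, "accumulate_steps": 4}
        #   fleet.init(is_collective=True, strategy=strategy)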

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.pp_group = self._hcg.get_pipe_parallel_group()

        self._virtual_pp_world_size = None
        self._virtual_pp_rank = None
        self._real_pp_world_size = self.num_stages
        self._real_pp_rank = self.stage_id

        p2p.initialize_p2p_groups(
            hcg, self._using_cache, self._enable_partial_send_recv
        )

        self.global_rank = self._hcg.get_global_rank()
        self.micro_batch_id = 0

        self._compute_loss = True

        logger.info(
            "Pipeline Info -- num_stages: {}, stage_id: {}".format(
                self.num_stages, self.stage_id
            )
        )

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_sharding_parallel:
            logger.info("start broadcast sharding parameters")
            broadcast_sharding_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

    def is_pipeline_first_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != 0:
                    return False
        assert self._real_pp_rank is not None
        return self._real_pp_rank == 0

    def is_pipeline_last_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != (self._virtual_pp_world_size - 1):
                    return False
        assert self._real_pp_rank is not None
        assert self._real_pp_world_size is not None
        return self._real_pp_rank == (self._real_pp_world_size - 1)

    def set_virtual_pipeline_rank(self, rank):
        self._virtual_pp_rank = rank
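    # Note: the virtual rank is only meaningful for the interleaved scheduler
    # (PipelineParallelWithInterleave below).  With interleaving, one physical
    # stage owns several model chunks; only the chunk with virtual rank 0 acts
    # as the logical first stage, and only the chunk with the largest virtual
    # rank acts as the logical last stage in the checks above.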

    def forward_backward_pipeline(self, data, scaler=None):
        # use the 1f1b scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py

        self.scaler = scaler

        # store data for train
        self.data = data

        # store total loss of entire batch
        self.total_loss = None

        # store data id for micro_batch
        self.micro_batch_id = 0

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps
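        # 1F1B runs in three phases on every stage: `startup_steps` warm-up
        # forwards, `steady_steps` alternating forward/backward steps, and
        # finally `startup_steps` cool-down backwards.  For example (illustrative
        # numbers), with num_stages = 4 and accumulate_steps = 8, stage 0 runs
        # 3 warm-up forwards, 5 steady steps and 3 cool-down backwards, while
        # stage 3 has no warm-up and runs all 8 micro batches in the steady phase.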

        input_buffers = []
        output_buffers = []

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor)

            output_tensor_grad = p2p.send_forward_recv_backward(
                output_tensor, self.is_pipeline_last_stage()
            )

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            input_tensor, output_tensor = input_buffers.pop(
                0
            ), output_buffers.pop(0)

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )

            if last_iter:
                input_tensor = None
                p2p.send_backward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )
            else:
                input_tensor = p2p.send_backward_recv_forward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )

        for i in range(startup_steps):
            input_tensor = input_buffers.pop(0)
            output_tensor = output_buffers.pop(0)

            output_tensor_grad = p2p.recv_backward(
                self.is_pipeline_last_stage()
            )

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )
            p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())

        self._layers.allreduce_shared_weight_gradients()
        with paddle.amp.auto_cast(enable=False):
            train_loss = self._broadcast_final_loss()
        return train_loss

    def _prepare_training(self, data, optimizer, lr_scheduler):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        assert isinstance(
            optimizer, HybridParallelOptimizer
        ), 'optimizer should be HybridParallelOptimizer subclass.'

        assert (
            fluid.framework._dygraph_tracer()._has_grad
        ), 'Please enable the generation of gradients.'

        if self.is_pipeline_first_stage(
            ignore_virtual=True
        ) or self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                data is not None
            ), "For the first and the last stage, the data must be set."
        else:
            data = None

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self._layers.train()

        return data

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # 1f1b scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss
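    # Usage sketch (illustrative, assuming the usual fleet hybrid-parallel setup
    # where fleet.distributed_model returns this PipelineParallel wrapper):
    #
    #   model = fleet.distributed_model(model)
    #   optimizer = fleet.distributed_optimizer(optimizer)
    #   for x, y in train_loader:
    #       loss = model.train_batch([x, y], optimizer, lr_scheduler, scaler)
    #
    # [x, y] is the full accumulated batch; it is sliced into micro batches by
    # _load_micro_batch according to micro_batch_size and accumulate_steps.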

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        # save data for eval
        self.data = data
        # store data id for micro_batch
        self.micro_batch_id = 0

        # store total loss of entire batch
        self.total_loss = None

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

        input_buffers = []
        output_buffers = []

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not last_iter:
                input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        if self._compute_loss:
            self.train_loss = self._broadcast_final_loss()
        else:
            self.train_loss = output_buffers

        return self.train_loss

    def _forward_step(self, input_tensor, chunk_id=None):
        if self.is_pipeline_first_stage():
            input_tensor = self._load_micro_batch(self.micro_batch_id)

        assert chunk_id is None or isinstance(chunk_id, int)

        output_tensor = self._layers.forward(input_tensor, chunk_id=chunk_id)

        if self.is_pipeline_last_stage():
            # calculate the loss during training
            if self._compute_loss:
                assert (
                    self._layers._loss_fn is not None
                ), "loss function should exist to compute loss"
                labels = self._load_micro_batch(self.micro_batch_id)
                output_tensor = self._layers._loss_fn(output_tensor, labels)
                assert isinstance(
                    output_tensor, (paddle.Tensor, core.eager.Tensor)
                ), "Currently, loss_fn should obtain Paddle.Tensor dtype"

                with paddle.amp.auto_cast(enable=False):
                    if self.accumulate_steps > 1:
                        output_tensor = output_tensor / self.accumulate_steps

                    if self.total_loss is None:
                        self.total_loss = paddle.zeros_like(output_tensor)
                    self.total_loss += output_tensor.detach()

        if self.is_pipeline_first_stage() or self.is_pipeline_last_stage():
            # Only increase micro batch id at virtual first/last pp stage.
            # The micro batch id is used to load data, therefore, only increase it when loading data.
            self.micro_batch_id += 1
        return output_tensor

    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
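        # On the last stage, output_tensor is the (already averaged) loss and
        # has no incoming gradient, so backward starts from it directly; on the
        # other stages, the gradient received from the downstream stage is
        # back-propagated through this stage's forward outputs.  The gradients
        # accumulated on input_tensor are returned so the caller can send them
        # to the upstream stage.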
        with paddle.amp.auto_cast(enable=False):
            if self.is_pipeline_last_stage():
                assert output_tensor_grad is None
                if self.scaler:
                    paddle.autograd.backward(self.scaler.scale(output_tensor))
                else:
                    paddle.autograd.backward(output_tensor)
            else:
                if isinstance(output_tensor, tuple):
                    outputs = [t for t in output_tensor if not t.stop_gradient]
                    assert len(outputs) == len(output_tensor_grad)
                    paddle.autograd.backward(
                        tensors=outputs,
                        grad_tensors=[t for t in output_tensor_grad],
                    )
                else:
                    paddle.autograd.backward(
                        tensors=[output_tensor],
                        grad_tensors=[output_tensor_grad],
                    )

            input_tensor_grad = None
            if input_tensor is not None:
                if isinstance(input_tensor, tuple):
                    input_tensor_grad = tuple(
                        [t.grad for t in input_tensor if not t.stop_gradient]
                    )
                else:
                    input_tensor_grad = input_tensor.grad
            return input_tensor_grad

    def _load_micro_batch(self, cache_id):
        inputs = self.data
        begin = cache_id * self.micro_batch_size
        end = begin + self.micro_batch_size
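        # Illustrative example: self.data is expected to be a pair
        # (inputs, labels) covering the whole accumulated batch, e.g. with
        # micro_batch_size = 2 and accumulate_steps = 4:
        #   x = paddle.randn([8, 1024]); y = paddle.randint(0, 10, [8, 1])
        #   data = [x, y]
        # cache_id = 3 then selects rows 6:8 of x on the first stage and
        # rows 6:8 of y on the last stage.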

        # Only the virtual first and last pipeline stages need data; all other stages don't.
        if self.is_pipeline_first_stage():
            assert len(inputs) == 2, "length of input should be 2"
            if isinstance(inputs[0], tuple):
                assert (
                    len(inputs[0]) > 1
                ), "If you use tuple for input data, it should have at least two inputs."
                batch_size = inputs[0][0].shape[0]
                assert (
                    self.micro_batch_size * self.accumulate_steps == batch_size
                ), (
                    "batch_size should equal micro_batch_size * accumulate_steps. Currently, "
                    "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
                    % (batch_size, self.micro_batch_size, self.accumulate_steps)
                )
                data = [input[begin:end, :].detach() for input in inputs[0]]
                return tuple(data)
            else:
                batch_size = inputs[0].shape[0]
                assert (
                    self.micro_batch_size * self.accumulate_steps == batch_size
                )
                return inputs[0][begin:end, :].detach()
        elif self.is_pipeline_last_stage():
            assert len(inputs) == 2, "length of input should be 2"
            if isinstance(inputs[1], tuple):
                batch_size = inputs[1][0].shape[0]
                assert (
                    self.micro_batch_size * self.accumulate_steps == batch_size
                )
                data = [input[begin:end, :].detach() for input in inputs[1]]
                return tuple(data)
            else:
                batch_size = inputs[1].shape[0]
                assert (
                    self.micro_batch_size * self.accumulate_steps == batch_size
                )
                return inputs[1][begin:end, :].detach()
        else:
            # No data input is required for other stages
            inputs = None

    def _broadcast_final_loss(self):
        # Since the last backward run in interleave will set the virtual rank to 0,
        # here we need to check last stage ignoring virtual stage.
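        # The broadcast is done in two steps: first a flag telling the other
        # ranks whether the loss is float32 or float16 (so they can allocate a
        # receive buffer of the right dtype), then the loss tensor itself.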
        if self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                self.total_loss is not None
            ), "train_batch() in last stage should obtain vaild loss"
            loss = self.total_loss.detach()
            is_fp32 = (
                paddle.to_tensor(1)
                if loss.dtype == paddle.float32
                else paddle.to_tensor(0)
            )
            paddle.distributed.broadcast(
                is_fp32, src=self.global_rank, sync_op=True, group=self.pp_group
            )
            paddle.distributed.broadcast(
                loss, src=self.global_rank, sync_op=True, group=self.pp_group
            )
        else:
            is_fp32 = paddle.to_tensor(1)
            paddle.distributed.broadcast(
                is_fp32,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
            loss = (
                paddle.zeros(shape=[1], dtype="float32")
                if is_fp32.numpy()[0]
                else paddle.zeros(shape=[1], dtype="float16")
            )
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
        return loss

    def _optimizer_step(self):
        if self.scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()


class PipelineParallelWithInterleave(PipelineParallel):
    # pipeline parallel with interleave scheduler

    def __init__(self, layers, hcg, strategy):
        super(PipelineParallelWithInterleave, self).__init__(
            layers=layers, hcg=hcg, strategy=strategy
        )
        assert layers.get_num_virtual_stages() > 1
        assert (
            framework.in_dygraph_mode()
        ), "virtual pipeline stage with interleave only support eager dygraph mode"
        # setup for interleave scheduler
        self.num_model_chunks = layers.get_num_virtual_stages()
        self.model_chunks = layers.get_model_chunks()
        assert self.model_chunks is not None
        assert len(self.model_chunks) == self.num_model_chunks
        self._virtual_pp_world_size = self.num_model_chunks
        self._virtual_pp_rank = 0

    def _get_virtual_pp_rank(self, micro_step, forward):
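        # Worked example: with num_stages = 4 and num_model_chunks = 2, forward
        # micro steps 0-3 map to chunk 0 and steps 4-7 to chunk 1, after which
        # the pattern repeats; for backward the chunk order is reversed.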
        virtual_pp_stage = micro_step % (
            self.num_stages * self.num_model_chunks
        )
        virtual_pp_stage = virtual_pp_stage // self.num_stages
        if not forward:
            virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
        return virtual_pp_stage

    def _forward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        if not self._forward_only:
            assert hasattr(self, 'output_tensor_grads')

        if self.is_pipeline_first_stage():
            if len(self.input_tensors[virtual_pp_rank]) == len(
                self.output_tensors[virtual_pp_rank]
            ):
                self.input_tensors[virtual_pp_rank].append(None)
        input_tensor = self.input_tensors[virtual_pp_rank][-1]
        output_tensor = self._forward_step(input_tensor, virtual_pp_rank)
        self.output_tensors[virtual_pp_rank].append(output_tensor)

        if self._forward_only:
            # no need to store tensor for backward
            self.input_tensors[virtual_pp_rank].pop()
            self.output_tensors[virtual_pp_rank].pop()

        return output_tensor

    def _backward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        assert hasattr(self, 'output_tensor_grads')

        if self.is_pipeline_last_stage():
            if len(self.output_tensor_grads[virtual_pp_rank]) == 0:
                self.output_tensor_grads[virtual_pp_rank].append(None)

        input_tensor = self.input_tensors[virtual_pp_rank].pop(0)
        output_tensor = self.output_tensors[virtual_pp_rank].pop(0)
        output_tensor_grad = self.output_tensor_grads[virtual_pp_rank].pop(0)
        input_tensor_grad = self._backward_step(
            input_tensor, output_tensor, output_tensor_grad
        )

        return input_tensor_grad

    def interleave_pipeline(
        self, data, scaler, forward_only=False, compute_loss=True
    ):
        # use interleave scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
        if not compute_loss:
            assert (
                forward_only
            ), "compute_loss can only be set to False when forward_only is set to True"

        # init some attributes for this batch run
        self.scaler = scaler
        self.data = data
        self.total_loss = None
        self.micro_batch_id = 0
        self._forward_only = forward_only

        # init some data buffers for interleave scheduler
        self.input_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]

        num_steps = self.accumulate_steps * self.num_model_chunks
        all_startup_steps = False
        if forward_only:
            # forward only: there is no backward pass, so all steps are startup steps
            startup_steps = num_steps
        else:
            if self.accumulate_steps == self.num_stages:
                startup_steps = num_steps
                all_startup_steps = True
            else:
                startup_steps = (self.num_stages - self.stage_id - 1) * 2
                startup_steps += (self.num_model_chunks - 1) * self.num_stages
                startup_steps = min(startup_steps, num_steps)

        steady_steps = num_steps - startup_steps
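        # Illustrative numbers: with num_stages = 4, num_model_chunks = 2 and
        # accumulate_steps = 8, num_steps = 16 and stage 0 gets
        # startup_steps = (4 - 0 - 1) * 2 + (2 - 1) * 4 = 10 warm-up forwards
        # followed by 6 steady 1F1B steps; when accumulate_steps == num_stages,
        # every step is a startup step and all backwards run in the cool-down loop.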

        self.set_virtual_pipeline_rank(0)
        self.input_tensors[0].append(
            p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False)
        )

        # run startup steps
        for micro_step in range(startup_steps):
            output_tensor = self._forward_step_helper(micro_step)

            # determine whether recv forward tensor or not
            next_virtual_pp_rank = self._get_virtual_pp_rank(
                micro_step + 1, forward=True
            )
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                if next_virtual_pp_rank == 0:
                    # next chunk is the first chunk, no need to pre-recv an input tensor
                    recv_prev = False
            # last micro step, no next run
            if micro_step == (num_steps - 1):
                recv_prev = False

            # last stage shouldn't send tensor to downstream
            if self.is_pipeline_last_stage():
                output_tensor = None

            if (
                micro_step == (startup_steps - 1)
                and not forward_only
                and not all_startup_steps
            ):
                input_tensor_grad = None
                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False

                # the last startup step needs four-direction comm to set up for steady 1f1b
                (
                    input_tensor,
                    output_tensor_grad,
                ) = p2p.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next,
                )
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    output_tensor_grad
                )
            else:
                input_tensor = p2p.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev
                )
            self.input_tensors[next_virtual_pp_rank].append(input_tensor)

        # run 1f1b steady steps
        for micro_step in range(steady_steps):
            # forward
            forward_micro_step_id = micro_step + startup_steps
            output_tensor = self._forward_step_helper(forward_micro_step_id)

            # backward
            backward_micro_step_id = micro_step
            input_tensor_grad = self._backward_step_helper(
                backward_micro_step_id
            )

            # four directions comm
            # send output tensor to downstream
            # send input tensor grad to upstream
            # recv input tensor from upstream
            # recv output tensor grad from downstream

            # last stage doesn't send its result to downstream
            forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id, forward=True
            )
            self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
            if self.is_pipeline_last_stage():
                output_tensor = None

            # first stage doesn't send grad to upstream
            backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id, forward=False
            )
            self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
            if self.is_pipeline_first_stage():
                input_tensor_grad = None

            # determine whether to recv input tensor from upstream
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id - (self.num_stages - 1), forward=True
                )
                if next_forward_virtual_pp_rank == (self.num_model_chunks - 1):
                    # first pp stage and first virtual stage
                    recv_prev = False
                next_forward_virtual_pp_rank += 1
            else:
                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id + 1, forward=True
                )

            # last iteration doesn't need recv from upstream
            if micro_step == (steady_steps - 1):
                recv_prev = False

            # determine whether to recv grad from downstream
            recv_next = True
            if self.is_pipeline_last_stage(ignore_virtual=True):
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id - (self.num_stages - 1),
                    forward=False,
                )
                if next_backward_virtual_pp_rank == 0:
                    # last pp stage and last virtual stage
                    recv_next = False
                next_backward_virtual_pp_rank -= 1
            else:
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id + 1, forward=False
                )

            (
                input_tensor,
                output_tensor_grad,
            ) = p2p.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
            )

            if recv_prev:
                self.input_tensors[next_forward_virtual_pp_rank].append(
                    input_tensor
                )
            if recv_next:
                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    output_tensor_grad
                )

        # remaining backward steps
        if not forward_only:
            if all_startup_steps:
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    p2p.recv_backward(
                        self.is_pipeline_last_stage(), sync_recv=False
                    )
                )

            for micro_step in range(steady_steps, num_steps):
                # cooldown loop
                input_tensor_grad = self._backward_step_helper(micro_step)
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    micro_step + 1, forward=False
                )

                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    if next_backward_virtual_pp_rank == (
                        self.num_model_chunks - 1
                    ):
                        recv_next = False

                if micro_step == (num_steps - 1):
                    recv_next = False

                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    p2p.send_backward_recv_backward(
                        input_tensor_grad, recv_next=recv_next
                    )
                )

            self._layers.allreduce_shared_weight_gradients()

        if compute_loss:
            # return loss if compute loss
            with paddle.amp.auto_cast(enable=False):
                train_loss = self._broadcast_final_loss()
        else:
            # otherwise just return all intermediate output tensors for all micro steps
            train_loss = self.output_tensors

        return train_loss

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # interleave scheduler for pipeline parallel
        train_loss = self.interleave_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        return self.interleave_pipeline(data, None, forward_only=True)