#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import framework

from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
from ..utils.hybrid_parallel_util import (
    broadcast_dp_parameters,
    broadcast_mp_parameters,
    broadcast_sharding_parameters,
)
from ..utils.log_util import logger
from .meta_parallel_base import MetaParallelBase
from .parallel_layers.pp_layers import PipelineLayer
from .pp_utils import p2p_communication as p2p

__all__ = []


class PipelineParallel(MetaParallelBase):
    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer."
            )
        super().__init__(layers, hcg, strategy)
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
        self.use_sharding_parallel = (
            self._hcg.get_sharding_parallel_world_size() > 1
        )

        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size'
        ]
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps'
        ]
        # If the tensors to be sent are not the same across different hosts,
        # they should not be sent partially and then concatenated into a whole tensor.
        self._enable_partial_send_recv = self._strategy.pipeline_configs[
            'enable_partial_send_recv'
        ]
        self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.pp_group = self._hcg.get_pipe_parallel_group()

        self._virtual_pp_world_size = None
        self._virtual_pp_rank = None
        self._real_pp_world_size = self.num_stages
        self._real_pp_rank = self.stage_id

        p2p.initialize_p2p_groups(
            hcg, self._using_cache, self._enable_partial_send_recv
        )

        self.global_rank = self._hcg.get_global_rank()
        self.micro_batch_id = 0

        self._compute_loss = True

        logger.info(
            "Pipeline Info -- num_stages: {}, stage_id: {}".format(
                self.num_stages, self.stage_id
            )
        )

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_sharding_parallel:
            logger.info("start broadcast sharding parameters")
            broadcast_sharding_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

    def is_pipeline_first_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != 0:
                    return False
        assert self._real_pp_rank is not None
        return self._real_pp_rank == 0

    def is_pipeline_last_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != (self._virtual_pp_world_size - 1):
                    return False
        assert self._real_pp_rank is not None
        assert self._real_pp_world_size is not None
        return self._real_pp_rank == (self._real_pp_world_size - 1)

    def set_virtual_pipeline_rank(self, rank):
        self._virtual_pp_rank = rank

    def forward_backward_pipeline(self, data, scaler=None):
        # use the 1f1b scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
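        #
        # Sketch of the schedule with illustrative numbers (not taken from a
        # real config): each stage runs `startup_steps` forward-only micro
        # steps (warmup), then `steady_steps` alternating 1F1B steps, and
        # finally drains the remaining backward steps. For example, assuming
        # num_stages=4, stage_id=1 and accumulate_steps=8:
        #   startup_steps = 4 - 1 - 1 = 2, steady_steps = 8 - 2 = 6.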

        self.scaler = scaler

        # store data for train
        self.data = data

        # store total loss of entire batch
        self.total_loss = None

        # store data id for micro_batch
        self.micro_batch_id = 0

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

        input_buffers = []
        output_buffers = []

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor)

            output_tensor_grad = p2p.send_forward_recv_backward(
                output_tensor, self.is_pipeline_last_stage()
            )

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            input_tensor, output_tensor = input_buffers.pop(
                0
            ), output_buffers.pop(0)

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )

            if last_iter:
                input_tensor = None
                p2p.send_backward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )
            else:
                input_tensor = p2p.send_backward_recv_forward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )

        for i in range(startup_steps):
            input_tensor = input_buffers.pop(0)
            output_tensor = output_buffers.pop(0)

            output_tensor_grad = p2p.recv_backward(
                self.is_pipeline_last_stage()
            )

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )
            p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())

        self._layers.allreduce_shared_weight_gradients()
        with paddle.amp.auto_cast(enable=False):
            train_loss = self._broadcast_final_loss()
        return train_loss

    def _prepare_training(self, data, optimizer, lr_scheduler):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        assert isinstance(
            optimizer, HybridParallelOptimizer
        ), 'optimizer should be HybridParallelOptimizer subclass.'

        assert (
            framework._dygraph_tracer()._has_grad
        ), 'Please enable the generation of gradients.'

        if self.is_pipeline_first_stage(
            ignore_virtual=True
        ) or self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                data is not None
            ), "For the first and the last stage, the data must be set."
        else:
            data = None

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self._layers.train()

        return data

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # 1f1b scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        # save data for eval
        self.data = data
        # store data id for micro_batch
        self.micro_batch_id = 0

        # store total loss of entire batch
        self.total_loss = None

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

        input_buffers = []
        output_buffers = []

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not last_iter:
                input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        if self._compute_loss:
            self.train_loss = self._broadcast_final_loss()
        else:
            self.train_loss = output_buffers

        return self.train_loss

    def _forward_step(self, input_tensor, chunk_id=None):
        if self.is_pipeline_first_stage():
            input_tensor = self._load_micro_batch(self.micro_batch_id)

        assert chunk_id is None or isinstance(chunk_id, int)

        output_tensor = self._layers.forward(input_tensor, chunk_id=chunk_id)

        if self.is_pipeline_last_stage():
            # calculate loss for training
            if self._compute_loss:
                assert (
                    self._layers._loss_fn is not None
                ), "loss function should exist to compute loss"
                labels = self._load_micro_batch(self.micro_batch_id)
                output_tensor = self._layers._loss_fn(output_tensor, labels)
                assert isinstance(
                    output_tensor, (paddle.Tensor, framework.core.eager.Tensor)
                ), "Currently, loss_fn should obtain Paddle.Tensor dtype"

                with paddle.amp.auto_cast(enable=False):
                    if self.accumulate_steps > 1:
                        output_tensor = output_tensor / self.accumulate_steps

                    if self.total_loss is None:
                        self.total_loss = paddle.zeros_like(output_tensor)
                    self.total_loss += output_tensor.detach()

        if self.is_pipeline_first_stage() or self.is_pipeline_last_stage():
            # Only increase micro batch id at virtual first/last pp stage.
            # The micro batch id is used to load data, therefore, only increase it when loading data.
            self.micro_batch_id += 1
        return output_tensor

    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
        with paddle.amp.auto_cast(enable=False):
            if self.is_pipeline_last_stage():
                assert output_tensor_grad is None
                if self.scaler:
                    paddle.autograd.backward(self.scaler.scale(output_tensor))
                else:
                    paddle.autograd.backward(output_tensor)
            else:
                if isinstance(output_tensor, tuple):
                    outputs = [t for t in output_tensor if not t.stop_gradient]
                    assert len(outputs) == len(output_tensor_grad)
                    paddle.autograd.backward(
                        tensors=outputs,
                        grad_tensors=list(output_tensor_grad),
                    )
                else:
                    paddle.autograd.backward(
                        tensors=[output_tensor],
                        grad_tensors=[output_tensor_grad],
                    )

            input_tensor_grad = None
            if input_tensor is not None:
                if isinstance(input_tensor, tuple):
                    input_tensor_grad = tuple(
                        [t.grad for t in input_tensor if not t.stop_gradient]
                    )
                else:
                    input_tensor_grad = input_tensor.grad
            return input_tensor_grad

    def _check_data_vaild(self, data):
        batch_size = data.shape[0]
        assert self.micro_batch_size * self.accumulate_steps == batch_size, (
            "batch_size needs to be divisible by micro_batch_size. Currently, "
            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
            % (batch_size, self.micro_batch_size, self.accumulate_steps)
        )

    def _load_micro_batch_impl(self, inputs, cache_id):
        begin = cache_id * self.micro_batch_size
        end = begin + self.micro_batch_size
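        # Illustrative example (assumed values, not from this file): with
        # micro_batch_size=2, accumulate_steps=4 and a tensor input of
        # batch_size 8, cache_id=1 selects rows [2:4]; list inputs are instead
        # expected to be pre-split into accumulate_steps entries and are
        # indexed directly by cache_id below.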

        if isinstance(inputs, tuple):
            output = []
            for data in inputs:
                if isinstance(data, list):
                    assert (
                        len(data) == self.accumulate_steps
                    ), "length of data should be %d, but it is %d" % (
                        self.accumulate_steps,
                        len(data),
                    )
                    output.append(data[cache_id].detach())
                else:
                    self._check_data_vaild(data)
                    output.append(data[begin:end, :].detach())
            return tuple(output)

        elif isinstance(inputs, list):
            assert (
                len(inputs) == self.accumulate_steps
            ), "length of data should be %d, but it is %d" % (
                self.accumulate_steps,
                len(inputs),
            )
            return inputs[cache_id].detach()
        else:
            self._check_data_vaild(inputs)
            return inputs[begin:end, :].detach()

    def _load_micro_batch(self, cache_id):
        inputs = self.data
        if self.is_pipeline_first_stage():
            assert len(inputs) == 2, "length of input should be 2"
            return self._load_micro_batch_impl(inputs[0], cache_id)
        elif self.is_pipeline_last_stage():
            assert len(inputs) == 2, "length of input should be 2"
            return self._load_micro_batch_impl(inputs[1], cache_id)
        else:
            inputs = None

    def _broadcast_final_loss(self):
        # Since the last backward run in interleave will set the virtual rank to 0,
        # here we need to check last stage ignoring virtual stage.
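        # The loss dtype may be float32 or float16 under AMP, so the last
        # stage broadcasts a dtype flag first; the other stages use it to
        # allocate a receive buffer of the matching dtype before the loss
        # broadcast.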
        if self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                self.total_loss is not None
            ), "train_batch() in last stage should obtain vaild loss"
            loss = self.total_loss.detach()
            is_fp32 = (
                paddle.full([], 1, 'int64')
                if loss.dtype == paddle.float32
                else paddle.full([], 0, 'int64')
            )
            paddle.distributed.broadcast(
                is_fp32, src=self.global_rank, sync_op=True, group=self.pp_group
            )
            paddle.distributed.broadcast(
                loss, src=self.global_rank, sync_op=True, group=self.pp_group
            )
        else:
            is_fp32 = paddle.full([], 1, 'int64')
            paddle.distributed.broadcast(
                is_fp32,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
            loss = (
                paddle.zeros(shape=[1], dtype="float32")
                if is_fp32.item()
                else paddle.zeros(shape=[1], dtype="float16")
            )
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
        return loss

    def _optimizer_step(self):
        if self.scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()


class PipelineParallelWithInterleave(PipelineParallel):
    # pipeline parallel with interleave scheduler

    def __init__(self, layers, hcg, strategy):
        super().__init__(layers=layers, hcg=hcg, strategy=strategy)
        assert layers.get_num_virtual_stages() > 1
        assert (
            framework.in_dygraph_mode()
        ), "virtual pipeline stage with interleave only support eager dygraph mode"
        # setup for interleave scheduler
        self.num_model_chunks = layers.get_num_virtual_stages()
        self.model_chunks = layers.get_model_chunks()
        assert self.model_chunks is not None
        assert len(self.model_chunks) == self.num_model_chunks
        self._virtual_pp_world_size = self.num_model_chunks
        self._virtual_pp_rank = 0

    def _get_virtual_pp_rank(self, micro_step, forward):
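        # Maps a micro step to the model chunk (virtual stage) it belongs to.
        # Illustrative case (assumed values): with num_stages=2 and
        # num_model_chunks=2, forward micro steps 0-1 run chunk 0 and steps
        # 2-3 run chunk 1; for backward the chunk order is reversed.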
        virtual_pp_stage = micro_step % (
            self.num_stages * self.num_model_chunks
        )
        virtual_pp_stage = virtual_pp_stage // self.num_stages
        if not forward:
            virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
        return virtual_pp_stage

    def _forward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        if not self._forward_only:
            assert hasattr(self, 'output_tensor_grads')

        if self.is_pipeline_first_stage():
            if len(self.input_tensors[virtual_pp_rank]) == len(
                self.output_tensors[virtual_pp_rank]
            ):
                self.input_tensors[virtual_pp_rank].append(None)
        input_tensor = self.input_tensors[virtual_pp_rank][-1]
        output_tensor = self._forward_step(input_tensor, virtual_pp_rank)
        self.output_tensors[virtual_pp_rank].append(output_tensor)

        if self._forward_only:
            # no need to store tensor for backward
            self.input_tensors[virtual_pp_rank].pop()
            self.output_tensors[virtual_pp_rank].pop()

        return output_tensor

    def _backward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        assert hasattr(self, 'output_tensor_grads')

        if self.is_pipeline_last_stage():
            if len(self.output_tensor_grads[virtual_pp_rank]) == 0:
                self.output_tensor_grads[virtual_pp_rank].append(None)

        input_tensor = self.input_tensors[virtual_pp_rank].pop(0)
        output_tensor = self.output_tensors[virtual_pp_rank].pop(0)
        output_tensor_grad = self.output_tensor_grads[virtual_pp_rank].pop(0)
        input_tensor_grad = self._backward_step(
            input_tensor, output_tensor, output_tensor_grad
        )

        return input_tensor_grad

    def forward_backward_pipeline(
        self, data, scaler, forward_only=False, compute_loss=True
    ):
        # use interleave scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
        if not compute_loss:
            assert (
                forward_only
            ), "compute_loss can only be set to False when forward_only is set to True"

        # init some attributes for this batch run
        self.scaler = scaler
        self.data = data
        self.total_loss = None
        self.micro_batch_id = 0
        self._forward_only = forward_only

        # init some data buffers for interleave scheduler
        self.input_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]

        num_steps = self.accumulate_steps * self.num_model_chunks
        all_startup_steps = False
        if forward_only:
            # Forward only: there is no backward pass, so all steps are treated as startup steps
            startup_steps = num_steps
        else:
            if self.accumulate_steps == self.num_stages:
                startup_steps = num_steps
                all_startup_steps = True
            else:
                startup_steps = (self.num_stages - self.stage_id - 1) * 2
                startup_steps += (self.num_model_chunks - 1) * self.num_stages
                startup_steps = min(startup_steps, num_steps)

        steady_steps = num_steps - startup_steps
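        # Worked example under assumed values (not from this file): with
        # num_stages=4, stage_id=0, num_model_chunks=2 and accumulate_steps=8,
        # num_steps = 16, startup_steps = (4 - 0 - 1) * 2 + (2 - 1) * 4 = 10,
        # steady_steps = 6; the remaining backward steps run in the cooldown
        # loop below.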

        self.set_virtual_pipeline_rank(0)
        self.input_tensors[0].append(
            p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False)
        )

        # run startup steps
        for micro_step in range(startup_steps):
            output_tensor = self._forward_step_helper(micro_step)

            # determine whether recv forward tensor or not
            next_virtual_pp_rank = self._get_virtual_pp_rank(
                micro_step + 1, forward=True
            )
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                if next_virtual_pp_rank == 0:
                    # next chunk is the first chunk, no need to pre-recv an input tensor
                    recv_prev = False
            # last micro step, no next run
            if micro_step == (num_steps - 1):
                recv_prev = False

            # last stage shouldn't send tensor to downstream
            if self.is_pipeline_last_stage():
                output_tensor = None

            if (
                micro_step == (startup_steps - 1)
                and not forward_only
                and not all_startup_steps
            ):
                input_tensor_grad = None
                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False

                # the last startup step needs four-direction comm to set up for the steady 1f1b phase
                (
                    input_tensor,
                    output_tensor_grad,
                ) = p2p.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next,
                )
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    output_tensor_grad
                )
            else:
                input_tensor = p2p.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev
                )
            self.input_tensors[next_virtual_pp_rank].append(input_tensor)

        # run 1f1b steady steps
        for micro_step in range(steady_steps):
            # forward
            forward_micro_step_id = micro_step + startup_steps
            output_tensor = self._forward_step_helper(forward_micro_step_id)

            # backward
            backward_micro_step_id = micro_step
            input_tensor_grad = self._backward_step_helper(
                backward_micro_step_id
            )

            # four directions comm
            # send output tensor to downstream
            # send input tensor grad to upstream
            # recv input tensor from upstream
            # recv output tensor grad from downstream

            # last stage doesn't send its result to downstream
            forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id, forward=True
            )
            self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
            if self.is_pipeline_last_stage():
                output_tensor = None

            # first stage doesn't send grad to upstream
            backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id, forward=False
            )
            self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
            if self.is_pipeline_first_stage():
                input_tensor_grad = None

            # determine whether to recv input tensor from upstream
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id - (self.num_stages - 1), forward=True
                )
                if next_forward_virtual_pp_rank == (self.num_model_chunks - 1):
                    # first pp stage and first virtual stage
                    recv_prev = False
                next_forward_virtual_pp_rank += 1
            else:
                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id + 1, forward=True
                )

            # last iteration doesn't need recv from upstream
            if micro_step == (steady_steps - 1):
                recv_prev = False

            # determine whether to recv grad from downstream
            recv_next = True
            if self.is_pipeline_last_stage(ignore_virtual=True):
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id - (self.num_stages - 1),
                    forward=False,
                )
                if next_backward_virtual_pp_rank == 0:
                    # last pp stage and last virtual stage
                    recv_next = False
                next_backward_virtual_pp_rank -= 1
            else:
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id + 1, forward=False
                )

            (
                input_tensor,
                output_tensor_grad,
            ) = p2p.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
            )

            if recv_prev:
                self.input_tensors[next_forward_virtual_pp_rank].append(
                    input_tensor
                )
            if recv_next:
                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    output_tensor_grad
                )

        # remaining backward steps
        if not forward_only:
            if all_startup_steps:
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    p2p.recv_backward(
                        self.is_pipeline_last_stage(), sync_recv=False
                    )
                )

            for micro_step in range(steady_steps, num_steps):
                # cooldown loop
                input_tensor_grad = self._backward_step_helper(micro_step)
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    micro_step + 1, forward=False
                )

                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    if next_backward_virtual_pp_rank == (
                        self.num_model_chunks - 1
                    ):
                        recv_next = False

                if micro_step == (num_steps - 1):
                    recv_next = False

                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    p2p.send_backward_recv_backward(
                        input_tensor_grad, recv_next=recv_next
                    )
                )

            self._layers.allreduce_shared_weight_gradients()

        if compute_loss:
            # return loss if compute loss
            with paddle.amp.auto_cast(enable=False):
                train_loss = self._broadcast_final_loss()
        else:
            # else just return all intermediate output tensor for all micro steps
            train_loss = self.output_tensors

        return train_loss

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # interleave scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        return self.forward_backward_pipeline(data, None, forward_only=True)
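

# A minimal usage sketch (illustrative only; `layers`, `hcg`, `strategy`,
# `optimizer`, `lr_scheduler`, `scaler` and `dataloader` are placeholders
# expected to come from the surrounding hybrid-parallel setup, e.g. via
# paddle.distributed.fleet):
#
#   model = PipelineParallel(layers, hcg, strategy)
#   for inputs, labels in dataloader:
#       # the first stage consumes inputs, the last stage consumes labels
#       loss = model.train_batch((inputs, labels), optimizer, lr_scheduler, scaler)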