#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import paddle
from paddle import framework

from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
from ..utils import timer_helper as timer
from ..utils.hybrid_parallel_util import (
    broadcast_dp_parameters,
    broadcast_mp_parameters,
    broadcast_sharding_parameters,
)
from ..utils.log_util import logger
from .meta_parallel_base import MetaParallelBase
from .parallel_layers.pp_layers import PipelineLayer
from .pp_utils import p2p_communication as p2p
from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size

__all__ = []


class PipelineParallel(MetaParallelBase):
    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer."
            )
        super().__init__(layers, hcg, strategy)
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
        self.use_sharding_parallel = (
            self._hcg.get_sharding_parallel_world_size() > 1
        )

        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size'
        ]
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps'
        ]
        # If the sent tensors are not identical across hosts, they shouldn't be
        # sent partially and then concatenated into a whole tensor.
        self._enable_partial_send_recv = self._strategy.pipeline_configs[
            'enable_partial_send_recv'
        ]
        self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.pp_group = self._hcg.get_pipe_parallel_group()
        self.dp_group = self._hcg.get_data_parallel_group()
        self.sharding_group = self._hcg.get_sharding_parallel_group()

        self._virtual_pp_world_size = None
        self._virtual_pp_rank = None
        self._real_pp_world_size = self.num_stages
        self._real_pp_rank = self.stage_id

        self._delay_scale_loss = self._strategy.hybrid_configs[
            "pp_configs"
        ].delay_scale_loss
        # TODO(PP Dev): support dp_comm_overlap without use_main_grad training.
        # This combination will trigger an inplace check error during `reshape_`
        # in the function `_split_tensors`.
        self._dp_comm_overlap = self._strategy.hybrid_configs[
            "pp_configs"
        ].dp_comm_overlap
        self._sharding_comm_overlap = self._strategy.hybrid_configs[
            "pp_configs"
        ].sharding_comm_overlap
        self._enable_timer = self._strategy.hybrid_configs[
            "pp_configs"
        ].enable_timer

        if self._dp_comm_overlap:
            assert self.use_data_parallel and self.num_stages > 1

        if self._sharding_comm_overlap:
            assert self.use_sharding_parallel and self.num_stages > 1

        assert not (
            self._dp_comm_overlap and self._sharding_comm_overlap
        ), "Cannot use dp pp overlap and sharding pp overlap at the same time."

        self._comm_buffers = []
        self._comm_overlap = (
            self._dp_comm_overlap or self._sharding_comm_overlap
        )

        if self._enable_timer:
            if not timer.is_timer_initialized():
                timer.set_timers()
            self.timers = timer.get_timers()

        p2p.initialize_p2p_groups(
            hcg,
            self._using_cache,
            self._enable_partial_send_recv,
            self._enable_timer,
        )

        self.global_rank = self._hcg.get_global_rank()
        self.micro_batch_id = 0

        self._compute_loss = True

        logger.info(
            "Pipeline Info -- num_stages: {}, stage_id: {}".format(
                self.num_stages, self.stage_id
            )
        )

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_sharding_parallel:
            logger.info("start broadcast sharding parameters")
            broadcast_sharding_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

        if self._dp_comm_overlap:
            self.register_allreduce_overlap_hook(
                self._layers, self.dp_group, self.accumulate_steps, True
            )

    def is_pipeline_first_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != 0:
                    return False
        assert self._real_pp_rank is not None
        return self._real_pp_rank == 0

    def is_pipeline_last_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != (self._virtual_pp_world_size - 1):
                    return False
        assert self._real_pp_rank is not None
        assert self._real_pp_world_size is not None
        return self._real_pp_rank == (self._real_pp_world_size - 1)

    def set_virtual_pipeline_rank(self, rank):
        self._virtual_pp_rank = rank

    def bw_hook_func(self, buffer, param):
        @paddle.autograd.no_grad()
        def fused_allreduce(*_):
            buffer.add_grad(param)

        return fused_allreduce
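
    # How the overlap machinery below fits together (a descriptive note based on
    # this file only, not on external docs): register_allreduce_overlap_hook groups
    # trainable parameters into FusedCommBuffer objects, and bw_hook_func attaches a
    # backward hook per parameter that feeds its grad into the owning buffer. Once a
    # buffer has collected all of its grads for the current accumulation window, it
    # can launch a single fused all-reduce (dp) or reduce (sharding), overlapping
    # gradient communication with the remaining backward computation.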

    def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
        if model.get_num_virtual_stages() > 1:
            models = model.get_model_chunks()
        else:
            models = [model]

        if not dp:
            assert hasattr(self, "optimizer")
            assert hasattr(self.optimizer, "_param2rank")
            _param2rank = self.optimizer._param2rank

        act = HOOK_ACTION.ALL_REDUCE if dp else HOOK_ACTION.REDUCE
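        # dp mode all-reduces each fused grad bucket across the whole comm group,
        # while sharding mode only reduces each grad to the rank that owns the
        # parameter (looked up via _param2rank), hence the two different actions.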

        fused_parameter_group = {}

        for model in models:
            # For virtual pipeline: separate the parameters of different chunks
            # into different groups to get the best performance.
            parameter_list = [
                p for p in model.parameters() if not p.stop_gradient
            ]
            if len(parameter_list) < 1:
                return

            if dp:
                fused_parameter_group[-1] = parameter_list
            else:
                # Sort parameters for sharding, since they have different dst rank
                for p in parameter_list:
                    assert p.name in _param2rank
                    dst_rank = _param2rank[p.name]
                    if dst_rank in fused_parameter_group:
                        fused_parameter_group[dst_rank].append(p)
                    else:
                        fused_parameter_group[dst_rank] = [p]

            for dst in fused_parameter_group:
                parameter_list = fused_parameter_group[dst]
                if not dp:
                    # parse the relative dst rank to absolute dst rank for sharding
                    dst = comm_group.ranks[dst]
                var_groups = assign_group_by_size(parameter_list)
                for group_idx, parameters in var_groups.items():
                    buffer = FusedCommBuffer(
                        group_idx, parameters, comm_group, acc_steps, act, dst
                    )
                    self._comm_buffers.append(buffer)
                    for param in parameters:
                        param._register_backward_hook(
                            self.bw_hook_func(buffer, param)
                        )

    def timer_printer(self):
        if not self._enable_timer:
            return
        all_flag_names = self.timers.timers.keys()
        self.timers.log(all_flag_names)

    def forward_backward_pipeline(self, data, scaler=None):
        # use the 1f1b scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py

        self.scaler = scaler

        # store data for train
        self.data = data

        # store total loss of entire batch
        self.total_loss = None

        # store data id for micro_batch
        self.micro_batch_id = 0

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps
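        # Worked example with hypothetical sizes (not taken from any real config):
        # num_stages=4, stage_id=1, accumulate_steps=8 gives startup_steps=2
        # warm-up forwards, steady_steps=6 one-forward-one-backward iterations,
        # and finally 2 cooldown backwards to drain the buffered activations.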

        input_buffers = []
        output_buffers = []

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor)

            output_tensor_grad = p2p.send_forward_recv_backward(
                output_tensor, self.is_pipeline_last_stage()
            )

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            input_tensor, output_tensor = input_buffers.pop(
                0
            ), output_buffers.pop(0)

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )

            if last_iter:
                input_tensor = None
                p2p.send_backward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )
            else:
                input_tensor = p2p.send_backward_recv_forward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )

        for i in range(startup_steps):
            input_tensor = input_buffers.pop(0)
            output_tensor = output_buffers.pop(0)

            output_tensor_grad = p2p.recv_backward(
                self.is_pipeline_last_stage()
            )

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )
            p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())

        if self._comm_overlap:
            assert len(self._comm_buffers) > 0
            for buffer in self._comm_buffers:
                buffer.scale_and_split_grads()

        if self._enable_timer:
            self.timers("allreduce_shared_weight_gradients").start()
        self._layers.allreduce_shared_weight_gradients()
        if self._enable_timer:
            self.timers("allreduce_shared_weight_gradients").stop()
            self.timers("broadcast_final_loss").start()
        with paddle.amp.auto_cast(enable=False):
            train_loss = self._broadcast_final_loss()
        if self._enable_timer:
            self.timers("broadcast_final_loss").stop()
        self.timer_printer()
        return train_loss

    def _prepare_training(self, data, optimizer, lr_scheduler):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        assert isinstance(
            optimizer, HybridParallelOptimizer
        ), 'optimizer should be HybridParallelOptimizer subclass.'

        assert (
            framework._dygraph_tracer()._has_grad
        ), 'Please enable the generation of gradients.'

        if self.is_pipeline_first_stage(
            ignore_virtual=True
        ) or self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                data is not None
            ), "For the first and the last stage, the data must be set."
        else:
            data = None

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self._layers.train()

        if self._sharding_comm_overlap and len(self._comm_buffers) == 0:
            self.register_allreduce_overlap_hook(
                self._layers, self.sharding_group, self.accumulate_steps, False
            )

        return data

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # 1f1b scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss
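
    # A rough usage sketch (illustrative only; `images`/`labels` and the concrete
    # degrees are placeholders rather than values taken from this file):
    #
    #     strategy = fleet.DistributedStrategy()
    #     strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 1, "pp_degree": 2}
    #     strategy.pipeline_configs = {"micro_batch_size": 2, "accumulate_steps": 4}
    #     fleet.init(is_collective=True, strategy=strategy)
    #     model = fleet.distributed_model(pipeline_layer)  # wrapped as PipelineParallel
    #     optimizer = fleet.distributed_optimizer(optimizer)
    #     loss = model.train_batch([images, labels], optimizer, lr_scheduler)
    #
    # The batch passed in must cover micro_batch_size * accumulate_steps samples;
    # train_batch slices it into micro batches via _load_micro_batch.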

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        # save data for eval
        self.data = data
        # store data id for micro_batch
        self.micro_batch_id = 0

        # store total loss of entire batch
        self.total_loss = None

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

        input_buffers = []
        output_buffers = []

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not last_iter:
                input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        if self._compute_loss:
            self.train_loss = self._broadcast_final_loss()
        else:
            self.train_loss = output_buffers

        return self.train_loss

    def _forward_step(self, input_tensor, chunk_id=None):
        if self._enable_timer:
            self.timers("forward_step").start()
        if self.is_pipeline_first_stage():
            input_tensor = self._load_micro_batch(self.micro_batch_id)

        assert chunk_id is None or isinstance(chunk_id, int)

        output_tensor = self._layers.forward(input_tensor, chunk_id=chunk_id)

        if self.is_pipeline_last_stage():
            # calculate loss for training
            if self._compute_loss:
                assert (
                    self._layers._loss_fn is not None
                ), "loss function should exist to compute loss"
                labels = self._load_micro_batch(self.micro_batch_id)
                output_tensor = self._layers._loss_fn(output_tensor, labels)
                assert isinstance(
                    output_tensor, (paddle.Tensor, framework.core.eager.Tensor)
                ), "Currently, loss_fn should return a paddle.Tensor"

                with paddle.amp.auto_cast(enable=False):
                    if self.accumulate_steps > 1 and not self._delay_scale_loss:
                        output_tensor = output_tensor / self.accumulate_steps

                    if self.total_loss is None:
                        self.total_loss = paddle.zeros_like(output_tensor)
                    self.total_loss += output_tensor.detach()

        if self.is_pipeline_first_stage() or self.is_pipeline_last_stage():
            # Only increase micro batch id at virtual first/last pp stage.
            # The micro batch id is used to load data, therefore, only increase it when loading data.
            self.micro_batch_id += 1
        if self._enable_timer:
            self.timers("forward_step").stop()
        return output_tensor

    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
        if self._enable_timer:
            self.timers("backward_step").start()
        with paddle.amp.auto_cast(enable=False):
            if self.is_pipeline_last_stage():
                assert output_tensor_grad is None
                if self.scaler:
                    paddle.autograd.backward(self.scaler.scale(output_tensor))
                else:
                    paddle.autograd.backward(output_tensor)
            else:
                if isinstance(output_tensor, tuple):
                    outputs = [t for t in output_tensor if not t.stop_gradient]
                    assert len(outputs) == len(output_tensor_grad)
                    paddle.autograd.backward(
                        tensors=outputs,
                        grad_tensors=list(output_tensor_grad),
                    )
                else:
                    paddle.autograd.backward(
                        tensors=[output_tensor],
                        grad_tensors=[output_tensor_grad],
                    )

            input_tensor_grad = None
            if input_tensor is not None:
                if isinstance(input_tensor, tuple):
                    input_tensor_grad = tuple(
                        [t.grad for t in input_tensor if not t.stop_gradient]
                    )
                else:
                    input_tensor_grad = input_tensor.grad
            if self._enable_timer:
                self.timers("backward_step").stop()
            return input_tensor_grad

    def _check_data_valid(self, data):
        batch_size = data.shape[0]
        assert self.micro_batch_size * self.accumulate_steps == batch_size, (
            "batch_size needs to be divisible by micro_batch_size. Currently, "
            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
            % (batch_size, self.micro_batch_size, self.accumulate_steps)
        )

    def _load_micro_batch_impl(self, inputs, cache_id):
        begin = cache_id * self.micro_batch_size
        end = begin + self.micro_batch_size
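        # Worked example with hypothetical sizes: micro_batch_size=2 and
        # accumulate_steps=4 expect a batch of 8 samples, so cache_id=1 reads
        # rows [2:4]. Inputs given as a list are instead assumed to be already
        # split into accumulate_steps pieces and are indexed directly by cache_id.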

        if isinstance(inputs, tuple):
            output = []
            for data in inputs:
                if isinstance(data, list):
                    assert (
                        len(data) == self.accumulate_steps
                    ), "length of data should be %d, but it is %d" % (
                        self.accumulate_steps,
                        len(data),
                    )
                    output.append(data[cache_id].detach())
                else:
                    self._check_data_valid(data)
                    output.append(data[begin:end, :].detach())
            return tuple(output)

        elif isinstance(inputs, list):
            assert (
                len(inputs) == self.accumulate_steps
            ), "length of data should be %d, but it is %d" % (
                self.accumulate_steps,
                len(inputs),
            )
            return inputs[cache_id].detach()
        else:
            self._check_data_valid(inputs)
            return inputs[begin:end, :].detach()

    def _load_micro_batch(self, cache_id):
        inputs = self.data
        if self.is_pipeline_first_stage():
            assert len(inputs) == 2, "length of input should be 2"
            return self._load_micro_batch_impl(inputs[0], cache_id)
        elif self.is_pipeline_last_stage():
            assert len(inputs) == 2, "length of input should be 2"
            return self._load_micro_batch_impl(inputs[1], cache_id)
        else:
            inputs = None

    def _broadcast_final_loss(self):
        # Since the last backward run in interleave will set the virtual rank to 0,
        # here we need to check last stage ignoring virtual stage.
        if self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                self.total_loss is not None
            ), "train_batch() in last stage should obtain valid loss"
            loss = (
                self.total_loss.detach()
                if not self._delay_scale_loss
                else self.total_loss / self.accumulate_steps
            )
            is_fp32 = (
                paddle.full([], 1, 'int64')
                if loss.dtype == paddle.float32
                else paddle.full([], 0, 'int64')
            )
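            # Broadcast the dtype flag before the loss itself so that receiving
            # stages can allocate a buffer of the matching dtype
            # (float32 vs float16) for the incoming tensor.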
            paddle.distributed.broadcast(
                is_fp32, src=self.global_rank, sync_op=True, group=self.pp_group
            )
            paddle.distributed.broadcast(
                loss, src=self.global_rank, sync_op=True, group=self.pp_group
            )
        else:
            is_fp32 = paddle.full([], 1, 'int64')
            paddle.distributed.broadcast(
                is_fp32,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
            loss = (
                paddle.zeros(shape=[1], dtype="float32")
                if is_fp32.item()
                else paddle.zeros(shape=[1], dtype="float16")
            )
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
        return loss

    def _optimizer_step(self):
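        # When delay_scale_loss is enabled, each micro batch contributed an
        # unscaled loss/grad, so the accumulated grads (main_grad when master
        # grads are used, otherwise grad) are divided by accumulate_steps once
        # here, just before the optimizer update.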
        if self._delay_scale_loss:
            for p in self._layers.parameters():
                if hasattr(p, "main_grad") and p.main_grad is not None:
                    assert p.grad is None
                    p.main_grad = p.main_grad.scale(1.0 / self.accumulate_steps)
                elif p.grad is not None:
                    p.grad = p.grad.scale(1.0 / self.accumulate_steps)

        if self.scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()


class PipelineParallelWithInterleave(PipelineParallel):
    # pipeline parallel with interleave scheduler

    def __init__(self, layers, hcg, strategy):
        super().__init__(layers=layers, hcg=hcg, strategy=strategy)
        assert layers.get_num_virtual_stages() > 1
        if self.num_stages <= 2:
            warnings.warn(
                "Deprecation warning: in the near future, virtual pp will only be available when pp degree > 2."
            )
        assert (
            framework.in_dynamic_mode()
        ), "virtual pipeline stage with interleave only supports eager dygraph mode"
        assert (
            self.accumulate_steps % self.num_stages == 0
        ), "accumulate_steps should be evenly divisible by num_stages for pipeline with interleave"
        # setup for interleave scheduler
        self.num_model_chunks = layers.get_num_virtual_stages()
        self.model_chunks = layers.get_model_chunks()
        assert self.model_chunks is not None
        assert len(self.model_chunks) == self.num_model_chunks
        self._virtual_pp_world_size = self.num_model_chunks
        self._virtual_pp_rank = 0

    def _get_virtual_pp_rank(self, micro_step, forward):
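        # Worked example with hypothetical sizes: num_stages=2, num_model_chunks=2
        # maps forward micro_steps 0,1,2,3,4,5,... to virtual ranks 0,0,1,1,0,0,...;
        # for backward the mapping is mirrored so the last chunk is unwound first.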
        virtual_pp_stage = micro_step % (
            self.num_stages * self.num_model_chunks
        )
        virtual_pp_stage = virtual_pp_stage // self.num_stages
        if not forward:
            virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
        return virtual_pp_stage

    def _forward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        if not self._forward_only:
            assert hasattr(self, 'output_tensor_grads')
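        # Scheduler invariant: the input for the forward we are about to run has
        # already been received and appended, so this chunk currently holds
        # exactly one more input than the outputs it has produced.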
        assert len(self.input_tensors[virtual_pp_rank]) == (
            len(self.output_tensors[virtual_pp_rank]) + 1
        )
        input_tensor = self.input_tensors[virtual_pp_rank][-1]
        output_tensor = self._forward_step(input_tensor, virtual_pp_rank)
        self.output_tensors[virtual_pp_rank].append(output_tensor)

        if self._forward_only:
            # no need to store tensor for backward
            self.input_tensors[virtual_pp_rank].pop()
            self.output_tensors[virtual_pp_rank].pop()

        return output_tensor

    def _backward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        assert hasattr(self, 'output_tensor_grads')

        assert (
            len(self.output_tensor_grads[virtual_pp_rank]) == 1
        ), f"output_tensor_grads is empty for virtual_pp_rank {virtual_pp_rank}"

        assert len(self.input_tensors[virtual_pp_rank]) > 0
        assert len(self.output_tensors[virtual_pp_rank]) > 0

        input_tensor = self.input_tensors[virtual_pp_rank].pop(0)
        output_tensor = self.output_tensors[virtual_pp_rank].pop(0)
        output_tensor_grad = self.output_tensor_grads[virtual_pp_rank].pop(0)
        input_tensor_grad = self._backward_step(
            input_tensor, output_tensor, output_tensor_grad
        )

        return input_tensor_grad

    def forward_backward_pipeline(
        self, data, scaler, forward_only=False, compute_loss=True
    ):
        # use interleave scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
        if not compute_loss:
            assert (
                forward_only
            ), "compute_loss can only be set to False when forward_only is set to True"

        # init some attributes for this batch run
        self.scaler = scaler
        self.data = data
        self.total_loss = None
        self.micro_batch_id = 0
        self._forward_only = forward_only

        # init some data buffers for interleave scheduler
        self.input_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]

        num_steps = self.accumulate_steps * self.num_model_chunks
        if forward_only:
            # In forward-only mode there is no backward pass, so all steps are startup steps
            startup_steps = num_steps
        else:
            # startup_steps is actually the sum of two numbers:
            # first_forward_cross_to_end = (self.num_stages - self.stage_id - 1) + (self.num_model_chunks - 1) * self.num_stages
            # end_to_first_backward_cross = (self.num_stages - self.stage_id - 1)
            # startup_steps = first_forward_cross_to_end + end_to_first_backward_cross
            startup_steps = (self.num_stages - self.stage_id - 1) * 2
            startup_steps += (self.num_model_chunks - 1) * self.num_stages
            startup_steps = min(startup_steps, num_steps)
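            # Worked example with hypothetical sizes: num_stages=4, stage_id=0,
            # num_model_chunks=2, accumulate_steps=8 gives num_steps=16 and
            # startup_steps = 3 * 2 + 1 * 4 = 10, leaving 6 steady 1F1B steps
            # before the remaining backwards drain in the cooldown loop.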

        steady_steps = num_steps - startup_steps

        self.set_virtual_pipeline_rank(0)
        self.input_tensors[0].append(
            p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False)
        )

        # run startup steps
        for micro_step in range(startup_steps):
            output_tensor = self._forward_step_helper(micro_step)

            # determine whether recv forward tensor or not
            next_virtual_pp_rank = self._get_virtual_pp_rank(
                micro_step + 1, forward=True
            )
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                if next_virtual_pp_rank == 0:
                    # next chunk is the first chunk, no need to pre-recv an input tensor
                    recv_prev = False
            # last micro step, no next run
            if micro_step == (num_steps - 1):
                recv_prev = False

            # last stage shouldn't send tensor to downstream
            if self.is_pipeline_last_stage():
                output_tensor = None

            if micro_step == (startup_steps - 1) and not forward_only:
                input_tensor_grad = None
                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False

                # the last startup step needs four-direction comm to set up for steady 1f1b
                (
                    input_tensor,
                    output_tensor_grad,
                ) = p2p.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next,
                )
                # output_tensor_grad is not None if recv_next;
                # append it whether it is None or not
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    output_tensor_grad
                )
            else:
                input_tensor = p2p.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev
                )
            # append input_tensor whether it is None or not
            self.input_tensors[next_virtual_pp_rank].append(input_tensor)

        # run 1f1b steady steps
        for micro_step in range(steady_steps):
            # forward
            forward_micro_step_id = micro_step + startup_steps
            output_tensor = self._forward_step_helper(forward_micro_step_id)

            # backward
            backward_micro_step_id = micro_step
            input_tensor_grad = self._backward_step_helper(
                backward_micro_step_id
            )

            # four directions comm
            # send output tensor to downstream
            # send input tensor grad to upstream
            # recv input tensor from upstream
            # recv output tensor grad from downstream

            # last stage doesn't send its result to downstream
            forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id, forward=True
            )
            self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
            if self.is_pipeline_last_stage():
                output_tensor = None

            # first stage doesn't send grad to upstream
            backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id, forward=False
            )
            self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
            if self.is_pipeline_first_stage():
                input_tensor_grad = None

            # determine whether to recv input tensor from upstream
            recv_prev = True
            next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id + 1, forward=True
            )
            if self.is_pipeline_first_stage(ignore_virtual=True) and (
                next_forward_virtual_pp_rank == 0
            ):
                # first pp stage and first virtual stage
                recv_prev = False

            # last iteration doesn't need recv from upstream
            if micro_step == (steady_steps - 1):
                recv_prev = False

            # determine whether to recv grad from downstream
            recv_next = True
            next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id + 1, forward=False
            )
            if self.is_pipeline_last_stage(ignore_virtual=True) and (
                next_backward_virtual_pp_rank == (self.num_model_chunks - 1)
            ):
                # last pp stage and last virtual stage
                recv_next = False

            (
                input_tensor,
                output_tensor_grad,
            ) = p2p.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
            )
            # append input_tensor whether it is None or not
            self.input_tensors[next_forward_virtual_pp_rank].append(
                input_tensor
            )
            # append output_tensor_grad whether it is None or not
            self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                output_tensor_grad
            )

        # remaining backward steps
        if not forward_only:
            for micro_step in range(steady_steps, num_steps):
                # cooldown loop
                input_tensor_grad = self._backward_step_helper(micro_step)
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    micro_step + 1, forward=False
                )

                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    if next_backward_virtual_pp_rank == (
                        self.num_model_chunks - 1
                    ):
                        recv_next = False

                if micro_step == (num_steps - 1):
                    recv_next = False
                # append output_tensor_grad whether it is None or not
                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    p2p.send_backward_recv_backward(
                        input_tensor_grad, recv_next=recv_next
                    )
                )

            if self._comm_overlap:
                assert len(self._comm_buffers) > 0
                for buffer in self._comm_buffers:
                    buffer.scale_and_split_grads()

            if self._enable_timer:
                self.timers("allreduce_shared_weight_gradients").start()
            self._layers.allreduce_shared_weight_gradients()
            if self._enable_timer:
                self.timers("allreduce_shared_weight_gradients").stop()

        if compute_loss:
            # return loss if compute loss
            if self._enable_timer:
                self.timers("broadcast_final_loss").start()
            with paddle.amp.auto_cast(enable=False):
                train_loss = self._broadcast_final_loss()
            if self._enable_timer:
                self.timers("broadcast_final_loss").stop()
        else:
            # otherwise just return all intermediate output tensors for all micro steps
            train_loss = self.output_tensors

        self.timer_printer()
        return train_loss

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # interleave scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        return self.forward_backward_pipeline(data, None, forward_only=True)