#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import framework

from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
from ..utils import timer_helper as timer
from ..utils.hybrid_parallel_util import (
    broadcast_dp_parameters,
    broadcast_mp_parameters,
    broadcast_sharding_parameters,
)
from ..utils.log_util import logger
from .meta_parallel_base import MetaParallelBase
from .parallel_layers.pp_layers import PipelineLayer
from .pp_utils import p2p_communication as p2p
from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size

__all__ = []


class PipelineParallel(MetaParallelBase):
    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer."
            )
        super().__init__(layers, hcg, strategy)
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
        self.use_sharding_parallel = (
            self._hcg.get_sharding_parallel_world_size() > 1
        )

        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size'
        ]
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps'
        ]
        # If the tensors being sent differ across hosts, they should not be
        # sent partially and then concatenated into a whole tensor.
        self._enable_partial_send_recv = self._strategy.pipeline_configs[
            'enable_partial_send_recv'
        ]
        self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.pp_group = self._hcg.get_pipe_parallel_group()
        self.dp_group = self._hcg.get_data_parallel_group()
        self.sharding_group = self._hcg.get_sharding_parallel_group()

        self._virtual_pp_world_size = None
        self._virtual_pp_rank = None
        self._real_pp_world_size = self.num_stages
        self._real_pp_rank = self.stage_id
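        # The "real" rank/world size describe the physical pipeline stage; the
        # virtual counterparts are only populated by the interleaved scheduler
        # (see PipelineParallelWithInterleave below).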

        self._delay_scale_loss = self._strategy.hybrid_configs[
            "pp_configs"
        ].delay_scale_loss
        # TODO(PP Dev): support dp_comm_overlap without use_main_grad training.
        # This combination will trigger an inplace check error during `reshape_`
        # in the function `_split_tensors`.
        self._dp_comm_overlap = self._strategy.hybrid_configs[
            "pp_configs"
        ].dp_comm_overlap
        self._sharding_comm_overlap = self._strategy.hybrid_configs[
            "pp_configs"
        ].sharding_comm_overlap
        self._enable_timer = self._strategy.hybrid_configs[
            "pp_configs"
        ].enable_timer

        if self._dp_comm_overlap:
            assert self.use_data_parallel and self.num_stages > 1

        if self._sharding_comm_overlap:
            assert self.use_sharding_parallel and self.num_stages > 1

        assert not (
            self._dp_comm_overlap and self._sharding_comm_overlap
        ), "Cannot use dp pp overlap and sharding pp overlap at the same time."

        self._comm_buffers = []
        self._comm_overlap = (
            self._dp_comm_overlap or self._sharding_comm_overlap
        )

        if self._enable_timer:
            if not timer.is_timer_initialized():
                timer.set_timers()
            self.timers = timer.get_timers()

        p2p.initialize_p2p_groups(
            hcg,
            self._using_cache,
            self._enable_partial_send_recv,
            self._enable_timer,
        )

        self.global_rank = self._hcg.get_global_rank()
        self.micro_batch_id = 0

        self._compute_loss = True

        logger.info(
            "Pipeline Info -- num_stages: {}, stage_id: {}".format(
                self.num_stages, self.stage_id
            )
        )

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_sharding_parallel:
            logger.info("start broadcast sharding parameters")
            broadcast_sharding_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

        if self._dp_comm_overlap:
            self.register_allreduce_overlap_hook(
                self._layers, self.dp_group, self.accumulate_steps, True
            )

    def is_pipeline_first_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != 0:
                    return False
        assert self._real_pp_rank is not None
        return self._real_pp_rank == 0

    def is_pipeline_last_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != (self._virtual_pp_world_size - 1):
                    return False
        assert self._real_pp_rank is not None
        assert self._real_pp_world_size is not None
        return self._real_pp_rank == (self._real_pp_world_size - 1)

    def set_virtual_pipeline_rank(self, rank):
        self._virtual_pp_rank = rank

164 165 166 167 168 169 170
    def bw_hook_func(self, buffer, param):
        @paddle.autograd.no_grad()
        def fused_allreduce(*_):
            buffer.add_grad(param)

        return fused_allreduce

    def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
        if model.get_num_virtual_stages() > 1:
            models = model.get_model_chunks()
        else:
            models = [model]

        if not dp:
            assert hasattr(self, "optimizer")
            assert hasattr(self.optimizer, "_param2rank")
            _param2rank = self.optimizer._param2rank

        act = HOOK_ACTION.ALL_REDUCE if dp else HOOK_ACTION.REDUCE
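        # HOOK_ACTION.ALL_REDUCE all-reduces dp gradients across the whole comm
        # group, while HOOK_ACTION.REDUCE reduces sharding gradients to the
        # single rank that owns each parameter (looked up via _param2rank).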

        fused_parameter_group = {}

        for model in models:
            # For virtual pipeline: separate the parameters of different chunks
            # into different groups to get the best performance.
            parameter_list = [
                p for p in model.parameters() if not p.stop_gradient
            ]
            if len(parameter_list) < 1:
                return

            if dp:
                fused_parameter_group[-1] = parameter_list
            else:
                # Group parameters by dst rank for sharding, since they have different dst ranks
                for p in parameter_list:
                    assert p.name in _param2rank
                    dst_rank = _param2rank[p.name]
                    if dst_rank in fused_parameter_group:
                        fused_parameter_group[dst_rank].append(p)
                    else:
                        fused_parameter_group[dst_rank] = [p]

            for dst in fused_parameter_group:
                parameter_list = fused_parameter_group[dst]
                if not dp:
                    # convert the relative dst rank to the absolute dst rank for sharding
                    dst = comm_group.ranks[dst]
                var_groups = assign_group_by_size(parameter_list)
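                # each size-based group gets its own FusedCommBuffer, so the
                # gradients in a group are communicated in one fused call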
                for group_idx, parameters in var_groups.items():
                    buffer = FusedCommBuffer(
                        group_idx, parameters, comm_group, acc_steps, act, dst
                    )
                    self._comm_buffers.append(buffer)
                    for param in parameters:
                        param._register_backward_hook(
                            self.bw_hook_func(buffer, param)
                        )

    def timer_printer(self):
        if not self._enable_timer:
            return
        all_flag_names = self.timers.timers.keys()
        self.timers.log(all_flag_names)

    def forward_backward_pipeline(self, data, scaler=None):
        # use the 1f1b scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py

        self.scaler = scaler

        # store data for train
        self.data = data

        # store total loss of entire batch
        self.total_loss = None

        # store data id for micro_batch
        self.micro_batch_id = 0

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps
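        # e.g. with num_stages=4 and accumulate_steps=8, stage 0 warms up with
        # 3 forward steps and then runs 5 steady 1F1B steps, while the last
        # stage enters the steady phase immediately (0 warm-up steps).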

        input_buffers = []
        output_buffers = []

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor)

            output_tensor_grad = p2p.send_forward_recv_backward(
                output_tensor, self.is_pipeline_last_stage()
            )

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            input_tensor, output_tensor = input_buffers.pop(
                0
            ), output_buffers.pop(0)

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )

            if last_iter:
                input_tensor = None
                p2p.send_backward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )
            else:
                input_tensor = p2p.send_backward_recv_forward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )

        for i in range(startup_steps):
            input_tensor = input_buffers.pop(0)
            output_tensor = output_buffers.pop(0)

            output_tensor_grad = p2p.recv_backward(
                self.is_pipeline_last_stage()
            )

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )
            p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())

        if self._comm_overlap:
            assert len(self._comm_buffers) > 0
            for buffer in self._comm_buffers:
                buffer.scale_and_split_grads()

        if self._enable_timer:
            self.timers("allreduce_shared_weight_gradients").start()
        self._layers.allreduce_shared_weight_gradients()
        if self._enable_timer:
            self.timers("allreduce_shared_weight_gradients").stop()
            self.timers("broadcast_final_loss").start()
        with paddle.amp.auto_cast(enable=False):
            train_loss = self._broadcast_final_loss()
        if self._enable_timer:
            self.timers("broadcast_final_loss").stop()
        self.timer_printer()
        return train_loss

    def _prepare_training(self, data, optimizer, lr_scheduler):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        assert isinstance(
            optimizer, HybridParallelOptimizer
        ), 'optimizer should be HybridParallelOptimizer subclass.'

        assert (
            framework._dygraph_tracer()._has_grad
        ), 'Please enable the generation of gradients.'

        if self.is_pipeline_first_stage(
            ignore_virtual=True
        ) or self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                data is not None
            ), "For the first and the last stage, the data must be set."
        else:
            data = None

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self._layers.train()

        if self._sharding_comm_overlap and len(self._comm_buffers) == 0:
            self.register_allreduce_overlap_hook(
                self._layers, self.sharding_group, self.accumulate_steps, False
            )

        return data

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # 1f1b scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        # save data for eval
        self.data = data
        # store data id for micro_batch
        self.micro_batch_id = 0

        # store total loss of entire batch
        self.total_loss = None

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

        input_buffers = []
        output_buffers = []

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not last_iter:
                input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        if self._compute_loss:
            self.train_loss = self._broadcast_final_loss()
        else:
            self.train_loss = output_buffers

        return self.train_loss

    def _forward_step(self, input_tensor, chunk_id=None):
        if self._enable_timer:
            self.timers("forward_step").start()
        if self.is_pipeline_first_stage():
            input_tensor = self._load_micro_batch(self.micro_batch_id)

        assert chunk_id is None or isinstance(chunk_id, int)

        output_tensor = self._layers.forward(input_tensor, chunk_id=chunk_id)

        if self.is_pipeline_last_stage():
            # calculate loss during training
            if self._compute_loss:
                assert (
                    self._layers._loss_fn is not None
                ), "loss function should exist to compute loss"
                labels = self._load_micro_batch(self.micro_batch_id)
                output_tensor = self._layers._loss_fn(output_tensor, labels)
                assert isinstance(
                    output_tensor, (paddle.Tensor, framework.core.eager.Tensor)
                ), "Currently, loss_fn should obtain Paddle.Tensor dtype"

                with paddle.amp.auto_cast(enable=False):
                    if self.accumulate_steps > 1 and not self._delay_scale_loss:
                        output_tensor = output_tensor / self.accumulate_steps

                    if self.total_loss is None:
                        self.total_loss = paddle.zeros_like(output_tensor)
                    self.total_loss += output_tensor.detach()

        if self.is_pipeline_first_stage() or self.is_pipeline_last_stage():
            # Only increase the micro batch id at the virtual first/last pp stage.
            # The micro batch id is used to load data, so only increase it when data is loaded.
            self.micro_batch_id += 1
        if self._enable_timer:
            self.timers("forward_step").stop()
        return output_tensor

    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
        if self._enable_timer:
            self.timers("backward_step").start()
        with paddle.amp.auto_cast(enable=False):
            if self.is_pipeline_last_stage():
                assert output_tensor_grad is None
                if self.scaler:
                    paddle.autograd.backward(self.scaler.scale(output_tensor))
                else:
                    paddle.autograd.backward(output_tensor)
            else:
                if isinstance(output_tensor, tuple):
                    outputs = [t for t in output_tensor if not t.stop_gradient]
                    assert len(outputs) == len(output_tensor_grad)
                    paddle.autograd.backward(
                        tensors=outputs,
                        grad_tensors=list(output_tensor_grad),
                    )
                else:
                    paddle.autograd.backward(
                        tensors=[output_tensor],
                        grad_tensors=[output_tensor_grad],
                    )

            input_tensor_grad = None
            if input_tensor is not None:
                if isinstance(input_tensor, tuple):
                    input_tensor_grad = tuple(
                        [t.grad for t in input_tensor if not t.stop_gradient]
                    )
                else:
                    input_tensor_grad = input_tensor.grad
            if self._enable_timer:
                self.timers("backward_step").stop()
            return input_tensor_grad

    def _check_data_valid(self, data):
        batch_size = data.shape[0]
        assert self.micro_batch_size * self.accumulate_steps == batch_size, (
            "batch_size should equal micro_batch_size * accumulate_steps. Currently, "
            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
            % (batch_size, self.micro_batch_size, self.accumulate_steps)
        )

    def _load_micro_batch_impl(self, inputs, cache_id):
        begin = cache_id * self.micro_batch_size
        end = begin + self.micro_batch_size
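        # each micro batch is a contiguous slice of the global batch along dim 0,
        # e.g. with micro_batch_size=2, cache_id=3 selects rows [6:8)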

        if isinstance(inputs, tuple):
            output = []
            for data in inputs:
                if isinstance(data, list):
                    assert (
                        len(data) == self.accumulate_steps
                    ), "length of data should be %d, but it is %d" % (
                        self.accumulate_steps,
                        len(data),
                    )
                    output.append(data[cache_id].detach())
                else:
                    self._check_data_valid(data)
                    output.append(data[begin:end, :].detach())
            return tuple(output)

        elif isinstance(inputs, list):
            assert (
                len(inputs) == self.accumulate_steps
            ), "length of data should be %d, but it is %d" % (
                self.accumulate_steps,
                len(inputs),
            )
            return inputs[cache_id].detach()
        else:
            self._check_data_valid(inputs)
            return inputs[begin:end, :].detach()

    def _load_micro_batch(self, cache_id):
        inputs = self.data
        if self.is_pipeline_first_stage():
            assert len(inputs) == 2, "length of input should be 2"
            return self._load_micro_batch_impl(inputs[0], cache_id)
        elif self.is_pipeline_last_stage():
            assert len(inputs) == 2, "length of input should be 2"
            return self._load_micro_batch_impl(inputs[1], cache_id)
        else:
            inputs = None

    def _broadcast_final_loss(self):
        # Since the last backward run in interleave will set the virtual rank to 0,
        # here we need to check last stage ignoring virtual stage.
        if self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                self.total_loss is not None
            ), "train_batch() in last stage should obtain valid loss"
            loss = (
                self.total_loss.detach()
                if not self._delay_scale_loss
                else self.total_loss / self.accumulate_steps
            )
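            # broadcast the loss dtype as an int flag first so that receiving
            # ranks can allocate a buffer of the matching dtype before the loss
            # itself is broadcast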
            is_fp32 = (
                paddle.full([], 1, 'int64')
                if loss.dtype == paddle.float32
                else paddle.full([], 0, 'int64')
            )
            paddle.distributed.broadcast(
                is_fp32, src=self.global_rank, sync_op=True, group=self.pp_group
            )
            paddle.distributed.broadcast(
                loss, src=self.global_rank, sync_op=True, group=self.pp_group
            )
        else:
            is_fp32 = paddle.full([], 1, 'int64')
            paddle.distributed.broadcast(
                is_fp32,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
            loss = (
                paddle.zeros(shape=[1], dtype="float32")
580
                if is_fp32.item()
581 582
                else paddle.zeros(shape=[1], dtype="float16")
            )
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
        return loss

    def _optimizer_step(self):
        if self._delay_scale_loss:
            for p in self._layers.parameters():
                if hasattr(p, "main_grad") and p.main_grad is not None:
                    assert p.grad is None
                    p.main_grad = p.main_grad.scale(1.0 / self.accumulate_steps)
                elif p.grad is not None:
                    p.grad = p.grad.scale(1.0 / self.accumulate_steps)

        if self.scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()


class PipelineParallelWithInterleave(PipelineParallel):
    # pipeline parallel with interleave scheduler

    def __init__(self, layers, hcg, strategy):
        super().__init__(layers=layers, hcg=hcg, strategy=strategy)
        assert layers.get_num_virtual_stages() > 1
        assert (
            framework.in_dynamic_mode()
        ), "virtual pipeline stage with interleave only support eager dygraph mode"
        assert (
            self.accumulate_steps % self.num_stages == 0
        ), "accumulate_steps should be evenly divisible by num_stages for pipeline with interleave"
        # setup for interleave scheduler
        self.num_model_chunks = layers.get_num_virtual_stages()
        self.model_chunks = layers.get_model_chunks()
        assert self.model_chunks is not None
        assert len(self.model_chunks) == self.num_model_chunks
        self._virtual_pp_world_size = self.num_model_chunks
        self._virtual_pp_rank = 0

    def _get_virtual_pp_rank(self, micro_step, forward):
        virtual_pp_stage = micro_step % (
            self.num_stages * self.num_model_chunks
        )
        virtual_pp_stage = virtual_pp_stage // self.num_stages
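        # e.g. with num_stages=2 and num_model_chunks=2, forward micro steps
        # 0, 1, 2, 3 map to virtual stages 0, 0, 1, 1 (and then wrap around)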
        if not forward:
            virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
        return virtual_pp_stage

    def _forward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        if not self._forward_only:
            assert hasattr(self, 'output_tensor_grads')
        assert len(self.input_tensors[virtual_pp_rank]) == (
            len(self.output_tensors[virtual_pp_rank]) + 1
        )
        input_tensor = self.input_tensors[virtual_pp_rank][-1]
        output_tensor = self._forward_step(input_tensor, virtual_pp_rank)
        self.output_tensors[virtual_pp_rank].append(output_tensor)

        if self._forward_only:
            # no need to store tensor for backward
            self.input_tensors[virtual_pp_rank].pop()
            self.output_tensors[virtual_pp_rank].pop()

        return output_tensor

    def _backward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        assert hasattr(self, 'output_tensor_grads')

        assert (
            len(self.output_tensor_grads[virtual_pp_rank]) == 1
        ), f"output_tensor_grads is empty for virtual_pp_rank {virtual_pp_rank}"

        assert len(self.input_tensors[virtual_pp_rank]) > 0
        assert len(self.output_tensors[virtual_pp_rank]) > 0

        input_tensor = self.input_tensors[virtual_pp_rank].pop(0)
        output_tensor = self.output_tensors[virtual_pp_rank].pop(0)
        output_tensor_grad = self.output_tensor_grads[virtual_pp_rank].pop(0)
        input_tensor_grad = self._backward_step(
            input_tensor, output_tensor, output_tensor_grad
        )

        return input_tensor_grad

    def forward_backward_pipeline(
        self, data, scaler, forward_only=False, compute_loss=True
    ):
        # use interleave scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
        if not compute_loss:
            assert (
                forward_only
            ), "compute_loss can only be set to False when forward_only is set to True"

        # init some attributes for this batch run
        self.scaler = scaler
        self.data = data
        self.total_loss = None
        self.micro_batch_id = 0
        self._forward_only = forward_only

        # init some data buffers for interleave scheduler
        self.input_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]

        num_steps = self.accumulate_steps * self.num_model_chunks
        if forward_only:
            # forward only: there is no backward pass, so all steps are startup steps
            startup_steps = num_steps
        else:
            # startup_steps is actually the sum of two numbers:
            # first_forward_cross_to_end = (self.num_stages - self.stage_id - 1) + (self.num_model_chunks - 1) * self.num_stages
            # end_to_first_backward_cross = (self.num_stages - self.stage_id - 1)
            # startup_steps = first_forward_cross_to_end + end_to_first_backward_cross
            startup_steps = (self.num_stages - self.stage_id - 1) * 2
            startup_steps += (self.num_model_chunks - 1) * self.num_stages
            startup_steps = min(startup_steps, num_steps)
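            # e.g. with num_stages=4, stage_id=0 and num_model_chunks=2:
            # startup_steps = (4 - 0 - 1) * 2 + (2 - 1) * 4 = 10 (capped by num_steps)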

        steady_steps = num_steps - startup_steps

        self.set_virtual_pipeline_rank(0)
        self.input_tensors[0].append(
            p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False)
        )

        # run startup steps
        for micro_step in range(startup_steps):
            output_tensor = self._forward_step_helper(micro_step)

            # determine whether to recv the forward tensor or not
            next_virtual_pp_rank = self._get_virtual_pp_rank(
                micro_step + 1, forward=True
            )
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                if next_virtual_pp_rank == 0:
                    # next chunk is the first chunk, no need to pre-recv an input tensor
                    recv_prev = False
            # last micro step, no next run
            if micro_step == (num_steps - 1):
                recv_prev = False

            # last stage shouldn't send tensor to downstream
            if self.is_pipeline_last_stage():
                output_tensor = None

            if micro_step == (startup_steps - 1) and not forward_only:
                input_tensor_grad = None
                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False

                # the last startup step needs four-direction comm to set up for steady 1f1b
                (
                    input_tensor,
                    output_tensor_grad,
                ) = p2p.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next,
                )
                # output_tensor_grad is not None only when recv_next is True;
                # append it regardless, even if it is None
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    output_tensor_grad
                )
            else:
                input_tensor = p2p.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev
                )
            # append input_tensor regardless of whether it is None
            self.input_tensors[next_virtual_pp_rank].append(input_tensor)

        # run 1f1b steady steps
        for micro_step in range(steady_steps):
            # forward
            forward_micro_step_id = micro_step + startup_steps
            output_tensor = self._forward_step_helper(forward_micro_step_id)

            # backward
            backward_micro_step_id = micro_step
            input_tensor_grad = self._backward_step_helper(
                backward_micro_step_id
            )

            # four directions comm
            # send output tensor to downstream
            # send input tensor grad to upstream
            # recv input tensor from upstream
            # recv output tensor grad from downstream

            # last stage doesn't send its result to downstream
            forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id, forward=True
            )
            self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
            if self.is_pipeline_last_stage():
                output_tensor = None

            # first stage doesn't send grad to upstream
            backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id, forward=False
            )
            self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
            if self.is_pipeline_first_stage():
                input_tensor_grad = None

            # determine whether to recv input tensor from upstream
            recv_prev = True
816 817 818 819 820 821 822 823
            next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id + 1, forward=True
            )
            if self.is_pipeline_first_stage(ignore_virtual=True) and (
                next_forward_virtual_pp_rank == 0
            ):
                # first pp stage and first virtual stage
                recv_prev = False

            # last iteration doesn't need recv from upstream
            if micro_step == (steady_steps - 1):
                recv_prev = False

            # determine whether to recv grad from downstream
            recv_next = True
831 832 833 834 835 836 837 838
            next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id + 1, forward=False
            )
            if self.is_pipeline_last_stage(ignore_virtual=True) and (
                next_backward_virtual_pp_rank == (self.num_model_chunks - 1)
            ):
                # last pp stage and last virtual stage
                recv_next = False

            (
                input_tensor,
                output_tensor_grad,
            ) = p2p.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
            )
            # append input_tensor regardless of whether it is None
            self.input_tensors[next_forward_virtual_pp_rank].append(
                input_tensor
            )
            # append output_tensor_grad regardless of whether it is None
            self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                output_tensor_grad
            )

        # remaining backward steps
        if not forward_only:
            for micro_step in range(steady_steps, num_steps):
                # cooldown loop
                input_tensor_grad = self._backward_step_helper(micro_step)
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    micro_step + 1, forward=False
                )

                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    if next_backward_virtual_pp_rank == (
                        self.num_model_chunks - 1
                    ):
                        recv_next = False

                if micro_step == (num_steps - 1):
                    recv_next = False
                # append output_tensor_grad regardless of whether it is None
                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    p2p.send_backward_recv_backward(
                        input_tensor_grad, recv_next=recv_next
                    )
                )

            if self._comm_overlap:
                assert len(self._comm_buffers) > 0
                for buffer in self._comm_buffers:
                    buffer.scale_and_split_grads()

            if self._enable_timer:
                self.timers("allreduce_shared_weight_gradients").start()
            self._layers.allreduce_shared_weight_gradients()
            if self._enable_timer:
                self.timers("allreduce_shared_weight_gradients").stop()

        if compute_loss:
            # return loss if compute loss
            if self._enable_timer:
                self.timers("broadcast_final_loss").start()
            with paddle.amp.auto_cast(enable=False):
                train_loss = self._broadcast_final_loss()
            if self._enable_timer:
                self.timers("broadcast_final_loss").stop()
        else:
            # otherwise just return all intermediate output tensors for all micro steps
            train_loss = self.output_tensors

        self.timer_printer()
        return train_loss

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # interleave scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        return self.forward_backward_pipeline(data, None, forward_only=True)