#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import framework

from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
from ..utils import timer_helper as timer
from ..utils.hybrid_parallel_util import (
    broadcast_dp_parameters,
    broadcast_mp_parameters,
    broadcast_sharding_parameters,
)
from ..utils.log_util import logger
from .meta_parallel_base import MetaParallelBase
from .parallel_layers.pp_layers import PipelineLayer
from .pp_utils import p2p_communication as p2p
from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size

__all__ = []

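# Hedged usage sketch (an illustration assumed from the public fleet API, not taken
# from this file): the pipeline-parallel model is usually obtained through fleet and
# then driven batch by batch, roughly like:
#
#     model = fleet.distributed_model(pipeline_layer)  # wrapped as PipelineParallel
#     opt = fleet.distributed_optimizer(optimizer)     # HybridParallelOptimizer
#     for inputs, labels in loader:
#         loss = model.train_batch([inputs, labels], opt, lr_scheduler)
#
# `pipeline_layer`, `optimizer`, `lr_scheduler`, and `loader` are placeholders
# created by the caller.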

class PipelineParallel(MetaParallelBase):
    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer."
            )
        super().__init__(layers, hcg, strategy)
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
        self.use_sharding_parallel = (
            self._hcg.get_sharding_parallel_world_size() > 1
        )

        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size'
        ]
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps'
        ]
        # If the sent tensors are not the same across hosts,
        # they should not be sent partially and then concatenated into a whole tensor.
        self._enable_partial_send_recv = self._strategy.pipeline_configs[
            'enable_partial_send_recv'
        ]
        self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.pp_group = self._hcg.get_pipe_parallel_group()
        self.dp_group = self._hcg.get_data_parallel_group()
        self.sharding_group = self._hcg.get_sharding_parallel_group()

        self._virtual_pp_world_size = None
        self._virtual_pp_rank = None
        self._real_pp_world_size = self.num_stages
        self._real_pp_rank = self.stage_id

        self._delay_scale_loss = self._strategy.hybrid_configs[
            "pp_configs"
        ].delay_scale_loss
        # TODO(PP Dev): support dp_comm_overlap without use_main_grad training.
        # This combination triggers an inplace check error during `reshape_` in the function `_split_tensors`.
        self._dp_comm_overlap = self._strategy.hybrid_configs[
            "pp_configs"
        ].dp_comm_overlap
        self._sharding_comm_overlap = self._strategy.hybrid_configs[
            "pp_configs"
        ].sharding_comm_overlap
        self._enable_timer = self._strategy.hybrid_configs[
            "pp_configs"
        ].enable_timer

        if self._dp_comm_overlap:
            assert self.use_data_parallel and self.num_stages > 1

        if self._sharding_comm_overlap:
            assert self.use_sharding_parallel and self.num_stages > 1

        assert not (
            self._dp_comm_overlap and self._sharding_comm_overlap
        ), "Cannot use dp pp overlap and sharding pp overlap at the same time."

        self._comm_buffers = []
        self._comm_overlap = (
            self._dp_comm_overlap or self._sharding_comm_overlap
        )

        if self._enable_timer:
            if not timer.is_timer_initialized():
                timer.set_timers()
            self.timers = timer.get_timers()

        p2p.initialize_p2p_groups(
            hcg,
            self._using_cache,
            self._enable_partial_send_recv,
            self._enable_timer,
        )

        self.global_rank = self._hcg.get_global_rank()
        self.micro_batch_id = 0

        self._compute_loss = True

        logger.info(
            "Pipeline Info -- num_stages: {}, stage_id: {}".format(
                self.num_stages, self.stage_id
            )
        )

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_sharding_parallel:
            logger.info("start broadcast sharding parameters")
            broadcast_sharding_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

        if self._dp_comm_overlap:
            self.register_allreduce_overlap_hook(
                self._layers, self.dp_group, self.accumulate_steps, True
            )

    def is_pipeline_first_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != 0:
                    return False
        assert self._real_pp_rank is not None
        return self._real_pp_rank == 0

    def is_pipeline_last_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != (self._virtual_pp_world_size - 1):
                    return False
        assert self._real_pp_rank is not None
        assert self._real_pp_world_size is not None
        return self._real_pp_rank == (self._real_pp_world_size - 1)

    def set_virtual_pipeline_rank(self, rank):
        self._virtual_pp_rank = rank

    def bw_hook_func(self, buffer, param):
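        # Builds a backward hook that, when the gradient of `param` becomes ready,
        # pushes it into the fused communication buffer used for overlapped
        # allreduce/reduce.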
        @paddle.autograd.no_grad()
        def fused_allreduce(*_):
            buffer.add_grad(param)

        return fused_allreduce

    def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
        if model.get_num_virtual_stages() > 1:
            models = model.get_model_chunks()
        else:
            models = [model]

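        # For data parallel (dp=True), gradients are all-reduced across the whole
        # comm group; for sharding, each gradient is reduced onto the rank that owns
        # the parameter (looked up via the optimizer's _param2rank map).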
        if not dp:
            assert hasattr(self, "optimizer")
            assert hasattr(self.optimizer, "_param2rank")
            _param2rank = self.optimizer._param2rank

        act = HOOK_ACTION.ALL_REDUCE if dp else HOOK_ACTION.REDUCE

        fused_parameter_group = {}

        for model in models:
            # For virtual pipeline, separate parameters of different chunks into
            # different groups to get the best performance.
            parameter_list = [
                p for p in model.parameters() if not p.stop_gradient
            ]
            if len(parameter_list) < 1:
                return

            if dp:
                fused_parameter_group[-1] = parameter_list
            else:
                # Group parameters by dst rank for sharding, since each has a different dst rank
                for p in parameter_list:
                    assert p.name in _param2rank
                    dst_rank = _param2rank[p.name]
                    if dst_rank in fused_parameter_group:
                        fused_parameter_group[dst_rank].append(p)
                    else:
                        fused_parameter_group[dst_rank] = [p]

            for dst in fused_parameter_group:
                parameter_list = fused_parameter_group[dst]
                if not dp:
                    # convert the relative dst rank to the absolute dst rank for sharding
                    dst = comm_group.ranks[dst]
                var_groups = assign_group_by_size(parameter_list)
                for group_idx, parameters in var_groups.items():
                    buffer = FusedCommBuffer(
                        group_idx, parameters, comm_group, acc_steps, act, dst
                    )
                    self._comm_buffers.append(buffer)
                    for param in parameters:
                        param._register_backward_hook(
                            self.bw_hook_func(buffer, param)
                        )

    def timer_printer(self):
        if not self._enable_timer:
            return
        all_flag_names = self.timers.timers.keys()
        self.timers.log(all_flag_names)

    def forward_backward_pipeline(self, data, scaler=None):
        # use the 1f1b scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py

        self.scaler = scaler

        # store data for training
        self.data = data

        # store total loss of entire batch
        self.total_loss = None

        # store data id for micro_batch
        self.micro_batch_id = 0

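        # 1F1B schedule: run `startup_steps` warm-up forward passes (later stages
        # warm up less), then alternate one forward with one backward for
        # `steady_steps`, and finally drain the remaining backward passes.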
        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

        input_buffers = []
        output_buffers = []

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not self.is_pipeline_last_stage():
                self._release_output(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor)

            output_tensor_grad = p2p.send_forward_recv_backward(
                output_tensor, self.is_pipeline_last_stage()
            )

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not self.is_pipeline_last_stage():
                self._release_output(output_tensor)

            input_tensor, output_tensor = input_buffers.pop(
                0
            ), output_buffers.pop(0)

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )

            if last_iter:
                input_tensor = None
                p2p.send_backward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )
            else:
                input_tensor = p2p.send_backward_recv_forward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )

        for i in range(startup_steps):
            input_tensor = input_buffers.pop(0)
            output_tensor = output_buffers.pop(0)

            output_tensor_grad = p2p.recv_backward(
                self.is_pipeline_last_stage()
            )

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )
            p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())

        if self._comm_overlap:
            assert len(self._comm_buffers) > 0
            for buffer in self._comm_buffers:
                buffer.scale_and_split_grads()

        if self._enable_timer:
            self.timers("allreduce_shared_weight_gradients").start()
        self._layers.allreduce_shared_weight_gradients()
        if self._enable_timer:
            self.timers("allreduce_shared_weight_gradients").stop()
            self.timers("broadcast_final_loss").start()
        with paddle.amp.auto_cast(enable=False):
            train_loss = self._broadcast_final_loss()
        if self._enable_timer:
            self.timers("broadcast_final_loss").stop()
        self.timer_printer()
        return train_loss

    def _prepare_training(self, data, optimizer, lr_scheduler):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        assert isinstance(
            optimizer, HybridParallelOptimizer
        ), 'optimizer should be a subclass of HybridParallelOptimizer.'

        assert (
            framework._dygraph_tracer()._has_grad
        ), 'Please enable the generation of gradients.'

        if self.is_pipeline_first_stage(
            ignore_virtual=True
        ) or self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                data is not None
            ), "For the first and the last stage, the data must be set."
        else:
            data = None

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self._layers.train()

        if self._sharding_comm_overlap and len(self._comm_buffers) == 0:
            self.register_allreduce_overlap_hook(
                self._layers, self.sharding_group, self.accumulate_steps, False
            )

        return data

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # 1f1b scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        # save data for eval
        self.data = data
        # store data id for micro_batch
        self.micro_batch_id = 0

        # store total loss of entire batch
        self.total_loss = None

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

        input_buffers = []
        output_buffers = []

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not last_iter:
                input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        if self._compute_loss:
            self.train_loss = self._broadcast_final_loss()
        else:
            self.train_loss = output_buffers

        return self.train_loss

    def _forward_step(self, input_tensor, chunk_id=None):
        if self._enable_timer:
            self.timers("forward_step").start()
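        # Only the (virtual) first stage reads model inputs from self.data; other
        # stages consume the activation received from the upstream stage.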
        if self.is_pipeline_first_stage():
            input_tensor = self._load_micro_batch(self.micro_batch_id)

        assert chunk_id is None or isinstance(chunk_id, int)

        output_tensor = self._layers.forward(input_tensor, chunk_id=chunk_id)

        if self.is_pipeline_last_stage():
            # calculate loss for training
            if self._compute_loss:
                assert (
                    self._layers._loss_fn is not None
                ), "loss function should exist to compute loss"
                labels = self._load_micro_batch(self.micro_batch_id)
                output_tensor = self._layers._loss_fn(output_tensor, labels)
                assert isinstance(
                    output_tensor, (paddle.Tensor, framework.core.eager.Tensor)
                ), "Currently, loss_fn should return a paddle.Tensor"

                with paddle.amp.auto_cast(enable=False):
                    if self.accumulate_steps > 1 and not self._delay_scale_loss:
                        output_tensor = output_tensor / self.accumulate_steps

                    if self.total_loss is None:
                        self.total_loss = paddle.zeros_like(output_tensor)
                    self.total_loss += output_tensor.detach()

        if self.is_pipeline_first_stage() or self.is_pipeline_last_stage():
            # Only increase the micro batch id at the virtual first/last pp stage.
            # The micro batch id is used to load data, so only increase it when loading data.
            self.micro_batch_id += 1
        if self._enable_timer:
            self.timers("forward_step").stop()
        return output_tensor

    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
        if self._enable_timer:
            self.timers("backward_step").start()
        with paddle.amp.auto_cast(enable=False):
            if self.is_pipeline_last_stage():
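                # The last stage starts backward from the loss itself (scaled first
                # if AMP dynamic loss scaling is used), so it receives no output
                # gradient from downstream.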
                assert output_tensor_grad is None
                if self.scaler:
                    paddle.autograd.backward(self.scaler.scale(output_tensor))
                else:
                    paddle.autograd.backward(output_tensor)
            else:
                if isinstance(output_tensor, tuple):
                    outputs = [t for t in output_tensor if not t.stop_gradient]
                    assert len(outputs) == len(output_tensor_grad)
                    paddle.autograd.backward(
                        tensors=outputs,
                        grad_tensors=list(output_tensor_grad),
                    )
                else:
                    paddle.autograd.backward(
                        tensors=[output_tensor],
                        grad_tensors=[output_tensor_grad],
                    )

            input_tensor_grad = None
            if input_tensor is not None:
                if isinstance(input_tensor, tuple):
                    input_tensor_grad = tuple(
                        [t.grad for t in input_tensor if not t.stop_gradient]
                    )
                else:
                    input_tensor_grad = input_tensor.grad
            if self._enable_timer:
                self.timers("backward_step").stop()
            return input_tensor_grad

    def _check_data_vaild(self, data):
        batch_size = data.shape[0]
        assert self.micro_batch_size * self.accumulate_steps == batch_size, (
            "batch_size should equal micro_batch_size * accumulate_steps. Currently, "
            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
            % (batch_size, self.micro_batch_size, self.accumulate_steps)
        )

    def _load_micro_batch_impl(self, inputs, cache_id):
        begin = cache_id * self.micro_batch_size
        end = begin + self.micro_batch_size

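        # Two input layouts are supported: a list with one entry per accumulate
        # step (indexed by cache_id), or a single batched tensor that is sliced
        # into micro batches of size micro_batch_size.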
        if isinstance(inputs, tuple):
            output = []
            for data in inputs:
                if isinstance(data, list):
                    assert (
                        len(data) == self.accumulate_steps
                    ), "length of data should be %d, but it is %d" % (
                        self.accumulate_steps,
                        len(data),
                    )
                    output.append(data[cache_id].detach())
                else:
                    self._check_data_vaild(data)
                    output.append(data[begin:end, :].detach())
            return tuple(output)

        elif isinstance(inputs, list):
            assert (
                len(inputs) == self.accumulate_steps
            ), "length of data should be %d, but it is %d" % (
                self.accumulate_steps,
                len(inputs),
            )
            return inputs[cache_id].detach()
        else:
            self._check_data_vaild(inputs)
            return inputs[begin:end, :].detach()

    def _load_micro_batch(self, cache_id):
        inputs = self.data
        if self.is_pipeline_first_stage():
            assert len(inputs) == 2, "length of input should be 2"
            return self._load_micro_batch_impl(inputs[0], cache_id)
        elif self.is_pipeline_last_stage():
            assert len(inputs) == 2, "length of input should be 2"
            return self._load_micro_batch_impl(inputs[1], cache_id)
        else:
            inputs = None

    def _broadcast_final_loss(self):
        # Since the last backward run in interleave will set the virtual rank to 0,
        # here we need to check last stage ignoring virtual stage.
        if self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                self.total_loss is not None
            ), "train_batch() in last stage should obtain valid loss"
            loss = (
                self.total_loss.detach()
                if not self._delay_scale_loss
                else self.total_loss / self.accumulate_steps
            )
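            # Broadcast an int64 flag describing the loss dtype first, so the other
            # stages can allocate a receive buffer of matching dtype before the
            # actual loss broadcast.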
            is_fp32 = (
                paddle.full([], 1, 'int64')
                if loss.dtype == paddle.float32
                else paddle.full([], 0, 'int64')
            )
            paddle.distributed.broadcast(
                is_fp32, src=self.global_rank, sync_op=True, group=self.pp_group
            )
            paddle.distributed.broadcast(
                loss, src=self.global_rank, sync_op=True, group=self.pp_group
            )
        else:
            is_fp32 = paddle.full([], 1, 'int64')
            paddle.distributed.broadcast(
                is_fp32,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
            loss = (
                paddle.zeros(shape=[1], dtype="float32")
                if is_fp32.item()
                else paddle.zeros(shape=[1], dtype="float16")
            )
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
        return loss

    def _optimizer_step(self):
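        # With delay_scale_loss, gradients were accumulated unscaled, so divide them
        # by accumulate_steps once here, right before the optimizer update.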
        if self._delay_scale_loss:
            for p in self._layers.parameters():
                if hasattr(p, "main_grad") and p.main_grad is not None:
                    assert p.grad is None
                    p.main_grad = p.main_grad.scale(1.0 / self.accumulate_steps)
                elif p.grad is not None:
                    p.grad = p.grad.scale(1.0 / self.accumulate_steps)

        if self.scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()

    def _release_output(self, output):
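        # Release the storage of outputs that have already been sent downstream to
        # save activation memory; tensors modified in place (inplace_version != 0)
        # are left untouched.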
        def can_free(t):
            return (
                t is not None
                and isinstance(t, paddle.Tensor)
                and t.inplace_version == 0
            )

        if isinstance(output, (tuple, list)):
            for t in output:
                if can_free(t):
                    t._clear_dataptr()

        elif can_free(output):
            output._clear_dataptr()


class PipelineParallelWithInterleave(PipelineParallel):
    # pipeline parallel with interleave scheduler
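    # Each pipeline rank holds several non-adjacent model chunks (virtual stages);
    # micro batches are interleaved across these chunks to shrink the pipeline
    # bubble, following the interleaved 1F1B schedule of Megatron-LM.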

    def __init__(self, layers, hcg, strategy):
        super().__init__(layers=layers, hcg=hcg, strategy=strategy)
        assert layers.get_num_virtual_stages() > 1
        assert (
            framework.in_dynamic_mode()
        ), "virtual pipeline stage with interleave only supports eager dygraph mode"
        assert (
            self.accumulate_steps % self.num_stages == 0
        ), "accumulate_steps should be evenly divisible by num_stages for pipeline with interleave"
        # setup for interleave scheduler
        self.num_model_chunks = layers.get_num_virtual_stages()
        self.model_chunks = layers.get_model_chunks()
        assert self.model_chunks is not None
        assert len(self.model_chunks) == self.num_model_chunks
        self._virtual_pp_world_size = self.num_model_chunks
        self._virtual_pp_rank = 0

    def _get_virtual_pp_rank(self, micro_step, forward):
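        # Micro steps cycle over num_stages * num_model_chunks; each consecutive
        # block of num_stages micro steps maps to one chunk, and the backward pass
        # visits the chunks in reverse order.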
        virtual_pp_stage = micro_step % (
            self.num_stages * self.num_model_chunks
        )
        virtual_pp_stage = virtual_pp_stage // self.num_stages
        if not forward:
            virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
        return virtual_pp_stage

    def _forward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        if not self._forward_only:
            assert hasattr(self, 'output_tensor_grads')
        assert len(self.input_tensors[virtual_pp_rank]) == (
            len(self.output_tensors[virtual_pp_rank]) + 1
        )
        input_tensor = self.input_tensors[virtual_pp_rank][-1]
        output_tensor = self._forward_step(input_tensor, virtual_pp_rank)
        self.output_tensors[virtual_pp_rank].append(output_tensor)

        if self._forward_only:
            # no need to store tensor for backward
            self.input_tensors[virtual_pp_rank].pop()
            self.output_tensors[virtual_pp_rank].pop()

        return output_tensor

    def _backward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        assert hasattr(self, 'output_tensor_grads')

        assert (
            len(self.output_tensor_grads[virtual_pp_rank]) == 1
        ), f"output_tensor_grads should hold exactly one grad for virtual_pp_rank {virtual_pp_rank}"

        assert len(self.input_tensors[virtual_pp_rank]) > 0
        assert len(self.output_tensors[virtual_pp_rank]) > 0

        input_tensor = self.input_tensors[virtual_pp_rank].pop(0)
        output_tensor = self.output_tensors[virtual_pp_rank].pop(0)
        output_tensor_grad = self.output_tensor_grads[virtual_pp_rank].pop(0)
        input_tensor_grad = self._backward_step(
            input_tensor, output_tensor, output_tensor_grad
        )

        return input_tensor_grad

    def forward_backward_pipeline(
        self, data, scaler, forward_only=False, compute_loss=True
    ):
        # use interleave scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
        if not compute_loss:
            assert (
                forward_only
            ), "compute_loss can only be set to False when forward_only is set to True"

        # init some attributes for this batch run
        self.scaler = scaler
        self.data = data
        self.total_loss = None
        self.micro_batch_id = 0
        self._forward_only = forward_only

        # init some data buffers for interleave scheduler
        self.input_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]

        num_steps = self.accumulate_steps * self.num_model_chunks
        if forward_only:
            # Forward only: there is no backward while running, so all steps are startup steps
            startup_steps = num_steps
        else:
            # actually startup_steps is calculated from two numbers:
            # first_forward_cross_to_end = (self.num_stages - self.stage_id - 1) + (self.num_model_chunks - 1) * self.num_stages
            # end_to_first_backward_cross = (self.num_stages - self.stage_id - 1)
            # startup_steps = first_forward_cross_to_end + end_to_first_backward_cross
            startup_steps = (self.num_stages - self.stage_id - 1) * 2
            startup_steps += (self.num_model_chunks - 1) * self.num_stages
            startup_steps = min(startup_steps, num_steps)

        steady_steps = num_steps - startup_steps

        self.set_virtual_pipeline_rank(0)
        self.input_tensors[0].append(
            p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False)
        )

        # run startup steps
        for micro_step in range(startup_steps):
            output_tensor = self._forward_step_helper(micro_step)

            # determine whether recv forward tensor or not
            next_virtual_pp_rank = self._get_virtual_pp_rank(
                micro_step + 1, forward=True
            )
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                if next_virtual_pp_rank == 0:
                    # next chunk is the first chunk, no need to pre-recv an input tensor
                    recv_prev = False
            # last micro step, no next run
            if micro_step == (num_steps - 1):
                recv_prev = False

            # last stage shouldn't send tensor to downstream
            if self.is_pipeline_last_stage():
                output_tensor = None

            if micro_step == (startup_steps - 1) and not forward_only:
                input_tensor_grad = None
                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False

                # the last startup step needs four-direction comm to set up for steady 1f1b
                (
                    input_tensor,
                    output_tensor_grad,
                ) = p2p.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next,
                )
                # output_tensor_grad is not None if recv_next
                # append output_tensor_grad whether it is None or not
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    output_tensor_grad
                )
            else:
                input_tensor = p2p.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev
                )
            # append input_tensor whether it is None or not
            self.input_tensors[next_virtual_pp_rank].append(input_tensor)

            self._release_output(output_tensor)

        # run 1f1b steady steps
        for micro_step in range(steady_steps):
            # forward
            forward_micro_step_id = micro_step + startup_steps
            output_tensor = self._forward_step_helper(forward_micro_step_id)

            # backward
            backward_micro_step_id = micro_step
            input_tensor_grad = self._backward_step_helper(
                backward_micro_step_id
            )

            # four directions comm
            # send output tensor to downstream
            # send input tensor grad to upstream
            # recv input tensor from upstream
            # recv output tensor grad from downstream

            # last stage doesn't send result to downstream
            forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id, forward=True
            )
            self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
            if self.is_pipeline_last_stage():
                output_tensor = None

            # first stage doesn't send grad to upstream
            backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id, forward=False
            )
            self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
            if self.is_pipeline_first_stage():
                input_tensor_grad = None

            # determine whether to recv input tensor from upstream
            recv_prev = True
            next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id + 1, forward=True
            )
            if self.is_pipeline_first_stage(ignore_virtual=True) and (
                next_forward_virtual_pp_rank == 0
            ):
                # first pp stage and first virtual stage
                recv_prev = False

            # last iteration doesn't need recv from upstream
            if micro_step == (steady_steps - 1):
                recv_prev = False

            # determine whether to recv grad from downstream
            recv_next = True
            next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id + 1, forward=False
            )
            if self.is_pipeline_last_stage(ignore_virtual=True) and (
                next_backward_virtual_pp_rank == (self.num_model_chunks - 1)
            ):
                # last pp stage and last virtual stage
                recv_next = False

            (
                input_tensor,
                output_tensor_grad,
            ) = p2p.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
            )
            # append input_tensor whether it is None or not
            self.input_tensors[next_forward_virtual_pp_rank].append(
                input_tensor
            )
            # append output_tensor_grad whether it is None or not
            self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                output_tensor_grad
            )
            self._release_output(output_tensor)

        self._release_output(output_tensor)

        # remaining backward steps
        if not forward_only:
            for micro_step in range(steady_steps, num_steps):
                # cooldown loop
                input_tensor_grad = self._backward_step_helper(micro_step)
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    micro_step + 1, forward=False
                )

                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    if next_backward_virtual_pp_rank == (
                        self.num_model_chunks - 1
                    ):
                        recv_next = False

                if micro_step == (num_steps - 1):
                    recv_next = False
                # append output_tensor_grad whether it is None or not
                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    p2p.send_backward_recv_backward(
                        input_tensor_grad, recv_next=recv_next
                    )
                )

            if self._comm_overlap:
                assert len(self._comm_buffers) > 0
                for buffer in self._comm_buffers:
                    buffer.scale_and_split_grads()

            if self._enable_timer:
                self.timers("allreduce_shared_weight_gradients").start()
            self._layers.allreduce_shared_weight_gradients()
            if self._enable_timer:
                self.timers("allreduce_shared_weight_gradients").stop()

        if compute_loss:
            # return loss if compute loss
            if self._enable_timer:
                self.timers("broadcast_final_loss").start()
            with paddle.amp.auto_cast(enable=False):
                train_loss = self._broadcast_final_loss()
            if self._enable_timer:
                self.timers("broadcast_final_loss").stop()
        else:
            # else just return all intermediate output tensors for all micro steps
            train_loss = self.output_tensors

        self.timer_printer()
        return train_loss

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # interleave scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        return self.forward_backward_pipeline(data, None, forward_only=True)