#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
from collections import defaultdict

import paddle
from paddle import framework

from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
from ..utils import timer_helper as timer
from ..utils.hybrid_parallel_util import (
    broadcast_dp_parameters,
    broadcast_mp_parameters,
    broadcast_sharding_parameters,
)
from ..utils.log_util import logger
from .meta_parallel_base import MetaParallelBase
from .parallel_layers.pp_layers import PipelineLayer
from .pp_utils import p2p_communication as p2p
from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size

__all__ = []

g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 1))


# Assume only the first and last stages need data, and that data consumption is ordered;
# to be replaced by a real micro dataset from the reader.
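# Illustrative example (hypothetical numbers, comments only): with
# micro_batch_size=2 and accumulate_steps=4, a full first-stage input of shape
# [8, ...] is sliced into the micro batches [0:2], [2:4], [4:6] and [6:8];
# each __next__ call yields a (data, label) pair in which the element not
# consumed by the current stage is left as None.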
class FakeMicroDataset:
    def __init__(
        self, data, is_first_stage, is_last_stage, acc_steps, micro_batch_size
    ):
        self._data = data
        self._index = 0
        self._acc_steps = acc_steps
        self._is_first_stage = is_first_stage
        self._is_last_stage = is_last_stage
        self._micro_batch_size = micro_batch_size

    def __iter__(self):
        return self

    def __next__(self):
        assert self._index < self._acc_steps
        assert self._is_first_stage or self._is_last_stage
        micro_batch_data = self._load_micro_batch(self._index)
        self._index += 1
        return micro_batch_data

    def _load_micro_batch(self, micro_step):
        inputs = self._data

        data = None
        label = None
        if self._is_first_stage:
            assert len(inputs) == 2, "length of input should be 2"
            data = self._load_micro_batch_impl(inputs[0], micro_step)

        if self._is_last_stage:
            assert len(inputs) == 2, "length of input should be 2"
            label = self._load_micro_batch_impl(inputs[1], micro_step)

        return (data, label)

    def _load_micro_batch_impl(self, inputs, micro_step):
        begin = micro_step * self._micro_batch_size
        end = begin + self._micro_batch_size

        if isinstance(inputs, tuple):
            output = []
            for data in inputs:
                if isinstance(data, list):
                    assert (
                        len(data) == self._acc_steps
                    ), "length of data should be %d, but it is %d" % (
                        self._acc_steps,
                        len(data),
                    )
                    output.append(
                        data[micro_step].detach()
                        if data[micro_step] is not None
                        else None
                    )
                elif data is not None:
                    self._check_data_valid(data)
                    output.append(data[begin:end, :].detach())
                else:
                    output.append(None)
            return tuple(output)

        elif isinstance(inputs, list):
            assert (
                len(inputs) == self._acc_steps
            ), "length of data should be %d, but it is %d" % (
                self._acc_steps,
                len(inputs),
            )
            return inputs[micro_step].detach()
        elif inputs is not None:
            self._check_data_valid(inputs)
            return inputs[begin:end, :].detach()
        else:
            return None

    def _check_data_valid(self, data):
        batch_size = data.shape[0]
        assert self._micro_batch_size * self._acc_steps == batch_size, (
            "batch_size needs to be divisible by micro_batch_size. Currently, "
            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
            % (batch_size, self._micro_batch_size, self._acc_steps)
        )


class PipelineParallel(MetaParallelBase):
    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer."
            )
        super().__init__(layers, hcg, strategy)
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
        self.use_sharding_parallel = (
            self._hcg.get_sharding_parallel_world_size() > 1
        )

        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size'
        ]
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps'
        ]
        # If the sent tensors are not the same across hosts,
        # they should not be sent partially and then concatenated into a whole tensor.
        self._enable_partial_send_recv = self._strategy.pipeline_configs[
            'enable_partial_send_recv'
        ]
        self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.pp_group = self._hcg.get_pipe_parallel_group()
        self.dp_group = self._hcg.get_data_parallel_group()
        self.sharding_group = self._hcg.get_sharding_parallel_group()

        self._virtual_pp_world_size = None
        self._virtual_pp_rank = None
        self._real_pp_world_size = self.num_stages
        self._real_pp_rank = self.stage_id

        self._delay_scale_loss = self._strategy.hybrid_configs[
            "pp_configs"
        ].delay_scale_loss
        # TODO(PP Dev): support dp_comm_overlap without use_main_grad training.
        # This combination will trigger an inplace check error during `reshape_` in the function `_split_tensors`.
        self._dp_comm_overlap = self._strategy.hybrid_configs[
            "pp_configs"
        ].dp_comm_overlap
        self._sharding_comm_overlap = self._strategy.hybrid_configs[
            "pp_configs"
        ].sharding_comm_overlap
        self._enable_timer = self._strategy.hybrid_configs[
            "pp_configs"
        ].enable_timer

        if self._dp_comm_overlap:
            assert self.use_data_parallel and self.num_stages > 1

        if self._sharding_comm_overlap:
            assert self.use_sharding_parallel and self.num_stages > 1

        assert not (
            self._dp_comm_overlap and self._sharding_comm_overlap
        ), "Cannot use dp pp overlap and sharding pp overlap at the same time."

        self._chunk_2_comm_buffers = defaultdict(list)
        self._comm_overlap = (
            self._dp_comm_overlap or self._sharding_comm_overlap
        )

        if self._enable_timer:
            if not timer.is_timer_initialized():
                timer.set_timers()
            self.timers = timer.get_timers()

        p2p.initialize_p2p_groups(
            hcg,
            self._enable_partial_send_recv,
            self._enable_timer,
        )

        # construct pipeline meta info
        self._p2p_helper = p2p.P2pHelper(self._using_cache)

        self.global_rank = self._hcg.get_global_rank()
        self.micro_batch_id = 0

        self._compute_loss = True

        logger.info(
            "Pipeline Info -- num_stages: {}, stage_id: {}".format(
                self.num_stages, self.stage_id
            )
        )

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_sharding_parallel:
            logger.info("start broadcast sharding parameters")
            broadcast_sharding_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

        if self._dp_comm_overlap:
            self.register_allreduce_overlap_hook(
                self._layers, self.dp_group, self.accumulate_steps, True
            )

    def is_pipeline_first_stage(self, ignore_virtual=False):
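        """Return True if this rank is the first pipeline stage (and, unless
        ignore_virtual is set, also currently on the first virtual stage)."""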
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != 0:
                    return False
        assert self._real_pp_rank is not None
        return self._real_pp_rank == 0

    def is_pipeline_last_stage(self, ignore_virtual=False):
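        """Return True if this rank is the last pipeline stage (and, unless
        ignore_virtual is set, also currently on the last virtual stage)."""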
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != (self._virtual_pp_world_size - 1):
                    return False
        assert self._real_pp_rank is not None
        assert self._real_pp_world_size is not None
        return self._real_pp_rank == (self._real_pp_world_size - 1)

    def set_virtual_pipeline_rank(self, rank):
        self._virtual_pp_rank = rank

    def bw_hook_func(self, buffer, param):
        @paddle.autograd.no_grad()
        def fused_allreduce(*_):
            buffer.add_grad(param)

        return fused_allreduce

    def register_allreduce_overlap_hook(
        self, model, comm_group, acc_steps, dp, group_size=128 * 1024 * 1024
    ):
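        # Parameters are bucketed into FusedCommBuffers of roughly `group_size`
        # bytes (128MB by default here) so that each bucket performs one fused
        # allreduce/reduce instead of one communication per parameter.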
        if model.get_num_virtual_stages() > 1:
            models = model.get_model_chunks()
        else:
            models = [model]

        if not dp:
            assert hasattr(self, "optimizer")
            assert hasattr(self.optimizer, "_param2rank")
            _param2rank = self.optimizer._param2rank
        # Note: after sharding changes to the reduce operation, this needs to be cleaned up.
        act = (
            HOOK_ACTION.ALL_REDUCE
            if (dp or not g_shard_use_reduce)
            else HOOK_ACTION.REDUCE
        )

        for chunk_idx, model in enumerate(models):
            # For virtual pipeline: separate the parameters of different chunks into
            # different groups to get the best performance.

            fused_parameter_group = {}

            parameter_list = [
                p for p in model.parameters() if not p.stop_gradient
            ]
            if len(parameter_list) < 1:
                return

            if dp:
                fused_parameter_group[-1] = parameter_list
            else:
                # Group parameters for sharding, since they have different dst ranks
                for p in parameter_list:
                    assert p.name in _param2rank
                    dst_rank = _param2rank[p.name]
                    if dst_rank in fused_parameter_group:
                        fused_parameter_group[dst_rank].append(p)
                    else:
                        fused_parameter_group[dst_rank] = [p]

            for dst in fused_parameter_group:
                parameter_list = fused_parameter_group[dst]
                if not dp:
                    # convert the relative dst rank to the absolute dst rank for sharding
                    dst = comm_group.ranks[dst]
                var_groups = assign_group_by_size(parameter_list, group_size)
                for group_idx, parameters in var_groups.items():
                    buffer = FusedCommBuffer(
                        group_idx, parameters, comm_group, acc_steps, act, dst
                    )
                    self._chunk_2_comm_buffers[chunk_idx].append(buffer)
                    for param in parameters:
                        param._register_backward_hook(
                            self.bw_hook_func(buffer, param)
                        )

    def timer_printer(self):
        if not self._enable_timer:
            return
        all_flag_names = self.timers.timers.keys()
        self.timers.log(all_flag_names)

    def forward_backward_pipeline(self, data, scaler=None):
        # use the 1f1b scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py

        self.scaler = scaler

        # store total loss of entire batch
        self.total_loss = None

        # store data id for micro_batch
        self.micro_batch_id = 0

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps
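        # Worked example (illustrative numbers only): with num_stages=4,
        # stage_id=1 and accumulate_steps=8, this stage runs
        # startup_steps = 4 - 1 - 1 = 2 warm-up forwards, then
        # steady_steps = 8 - 2 = 6 one-forward-one-backward iterations,
        # and finally drains the remaining 2 backwards in the loop below.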

        input_buffers = []
        output_buffers = []

        micro_dataset = self._wrap_data(data)

        for step_id in range(startup_steps):
            input_tensor = self._p2p_helper.recv_forward(
                self.is_pipeline_first_stage()
            )

            output_tensor = self._forward_step(input_tensor, micro_dataset)
            self._p2p_helper.send_forward(
                output_tensor, self.is_pipeline_last_stage()
            )

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not self.is_pipeline_last_stage():
                self._release_output(output_tensor)

        if steady_steps > 0:
            input_tensor = self._p2p_helper.recv_forward(
                self.is_pipeline_first_stage()
            )

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor, micro_dataset)

            output_tensor_grad = self._p2p_helper.send_forward_recv_backward(
                output_tensor, self.is_pipeline_last_stage()
            )

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not self.is_pipeline_last_stage():
                self._release_output(output_tensor)

            input_tensor, output_tensor = input_buffers.pop(
                0
            ), output_buffers.pop(0)

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )

            if last_iter:
                input_tensor = None
                self._p2p_helper.send_backward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )
            else:
                input_tensor = self._p2p_helper.send_backward_recv_forward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )

        for i in range(startup_steps):
            input_tensor = input_buffers.pop(0)
            output_tensor = output_buffers.pop(0)

            output_tensor_grad = self._p2p_helper.recv_backward(
                self.is_pipeline_last_stage()
            )

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )
            self._p2p_helper.send_backward(
                input_tensor_grad, self.is_pipeline_first_stage()
            )

        if self._comm_overlap:
            assert (
                len(self._chunk_2_comm_buffers) > 0
            ), "comm buffers should be created"
            for _, buffers in self._chunk_2_comm_buffers.items():
                for buffer in buffers:
                    buffer.scale_and_split_grads()

        if self._enable_timer:
            self.timers("allreduce_shared_weight_gradients").start()
        self._layers.allreduce_shared_weight_gradients()
        if self._enable_timer:
            self.timers("allreduce_shared_weight_gradients").stop()
            self.timers("broadcast_final_loss").start()
        with paddle.amp.auto_cast(enable=False):
            train_loss = self._broadcast_final_loss()
        if self._enable_timer:
            self.timers("broadcast_final_loss").stop()
        self.timer_printer()
        return train_loss

    def _prepare_training(self, data, optimizer, lr_scheduler):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        assert isinstance(
            optimizer, HybridParallelOptimizer
        ), 'optimizer should be HybridParallelOptimizer subclass.'

        assert (
            framework._dygraph_tracer()._has_grad
        ), 'Please enable the generation of gradients.'

        if self.is_pipeline_first_stage(
            ignore_virtual=True
        ) or self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                data is not None
            ), "For the first and the last stage, the data must be set."
        else:
            data = None

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self._layers.train()

        if self._sharding_comm_overlap and len(self._chunk_2_comm_buffers) == 0:
            self.register_allreduce_overlap_hook(
                self._layers, self.sharding_group, self.accumulate_steps, False
            )

        return data

    def _wrap_data(self, data):
        """
        For backward compatibility, wrap data in a FakeMicroDataset if it is a list or tuple.
        """
        if (not isinstance(data, tuple)) and (not isinstance(data, list)):
            return data

        micro_dataset = FakeMicroDataset(
            data,
            self.is_pipeline_first_stage(ignore_virtual=True),
            self.is_pipeline_last_stage(ignore_virtual=True),
            self.accumulate_steps,
            self.micro_batch_size,
        )
        return micro_dataset

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # 1f1b scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss
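    # Typical call pattern (illustrative sketch only; `model`, `hcg`, `strategy`,
    # `optimizer` and `train_loader` are assumed to be built elsewhere with the
    # fleet hybrid-parallel APIs):
    #   pp_model = PipelineParallel(model, hcg, strategy)
    #   for batch in train_loader:  # batch == (inputs, labels)
    #       loss = pp_model.train_batch(batch, optimizer, lr_scheduler, scaler)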

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        # store data id for micro_batch
        self.micro_batch_id = 0

        # store total loss of entire batch
        self.total_loss = None

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

        input_buffers = []
        output_buffers = []

        micro_dataset = self._wrap_data(data)

        for step_id in range(startup_steps):
            input_tensor = self._p2p_helper.recv_forward(
                self.is_pipeline_first_stage()
            )

            output_tensor = self._forward_step(input_tensor, micro_dataset)
            self._p2p_helper.send_forward(
                output_tensor, self.is_pipeline_last_stage()
            )

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = self._p2p_helper.recv_forward(
                self.is_pipeline_first_stage()
            )

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor, micro_dataset)
            self._p2p_helper.send_forward(
                output_tensor, self.is_pipeline_last_stage()
            )

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not last_iter:
                input_tensor = self._p2p_helper.recv_forward(
                    self.is_pipeline_first_stage()
                )

        if self._compute_loss:
            self.train_loss = self._broadcast_final_loss()
        else:
            self.train_loss = output_buffers

        return self.train_loss

    def _forward_step(self, input_tensor, micro_dataset, chunk_id=None):
        if self._enable_timer:
            self.timers("forward_step").start()
        if self.is_pipeline_first_stage():
            input_tensor = next(micro_dataset)[0]
            self._check_micro_batch_data_valid(input_tensor)

        assert chunk_id is None or isinstance(chunk_id, int)

        output_tensor = self._layers.forward(input_tensor, chunk_id=chunk_id)

        if self.is_pipeline_last_stage():
            # calculate loss for training
            if self._compute_loss:
                assert (
                    self._layers._loss_fn is not None
                ), "loss function should exist to compute loss"
                labels = next(micro_dataset)[1]
                self._check_micro_batch_data_valid(labels)
                output_tensor = self._layers._loss_fn(output_tensor, labels)
                assert isinstance(
                    output_tensor, (paddle.Tensor, framework.core.eager.Tensor)
                ), "Currently, loss_fn should obtain Paddle.Tensor dtype"

                with paddle.amp.auto_cast(enable=False):
                    if self.accumulate_steps > 1 and not self._delay_scale_loss:
                        output_tensor = output_tensor / self.accumulate_steps

                    if self.total_loss is None:
                        self.total_loss = paddle.zeros_like(output_tensor)
                    self.total_loss += output_tensor.detach()

        if self.is_pipeline_first_stage() or self.is_pipeline_last_stage():
            # Only increase micro batch id at virtual first/last pp stage.
            # The micro batch id is used to load data, so only increase it when loading data.
            self.micro_batch_id += 1
        if self._enable_timer:
            self.timers("forward_step").stop()
        return output_tensor

    def _check_micro_batch_data_valid(self, micro_batch_data):
        if isinstance(micro_batch_data, (tuple, list)):
            for data in micro_batch_data:
                self._check_micro_batch_data_valid(data)
        elif micro_batch_data is not None:
            micro_batch_size = micro_batch_data.shape[0]
            assert (
                micro_batch_size == self.micro_batch_size
            ), f"expected micro_batch_size {self.micro_batch_size} but get {micro_batch_size}"

    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
        if self._enable_timer:
            self.timers("backward_step").start()
        with paddle.amp.auto_cast(enable=False):
            if self.is_pipeline_last_stage():
                assert output_tensor_grad is None
                if self.scaler:
                    paddle.autograd.backward(self.scaler.scale(output_tensor))
                else:
                    paddle.autograd.backward(output_tensor)
            else:
                if isinstance(output_tensor, tuple):
                    outputs = [t for t in output_tensor if not t.stop_gradient]
                    assert len(outputs) == len(output_tensor_grad)
                    paddle.autograd.backward(
                        tensors=outputs,
                        grad_tensors=list(output_tensor_grad),
                    )
                else:
                    paddle.autograd.backward(
                        tensors=[output_tensor],
                        grad_tensors=[output_tensor_grad],
                    )

            input_tensor_grad = None
            if input_tensor is not None:
                if isinstance(input_tensor, tuple):
                    input_tensor_grad = tuple(
                        [t.grad for t in input_tensor if not t.stop_gradient]
                    )
                else:
                    input_tensor_grad = input_tensor.grad
            if self._enable_timer:
                self.timers("backward_step").stop()
            return input_tensor_grad

    def _broadcast_final_loss(self):
        # Since the last backward run in interleave will set the virtual rank to 0,
        # here we need to check last stage ignoring virtual stage.
        if self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                self.total_loss is not None
            ), "train_batch() in last stage should obtain vaild loss"
            loss = (
                self.total_loss.detach()
                if not self._delay_scale_loss
                else self.total_loss / self.accumulate_steps
            )
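            # Broadcast the loss dtype first so that the other pipeline stages
            # can allocate a receive buffer of the matching dtype (fp32 vs fp16)
            # before the loss itself is broadcast.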
            is_fp32 = (
                paddle.full([], 1, 'int64')
                if loss.dtype == paddle.float32
                else paddle.full([], 0, 'int64')
            )
            paddle.distributed.broadcast(
                is_fp32, src=self.global_rank, sync_op=True, group=self.pp_group
            )
            paddle.distributed.broadcast(
                loss, src=self.global_rank, sync_op=True, group=self.pp_group
            )
        else:
            is_fp32 = paddle.full([], 1, 'int64')
            paddle.distributed.broadcast(
                is_fp32,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
            loss = (
                paddle.zeros(shape=[1], dtype="float32")
685
                if is_fp32.item()
686 687
                else paddle.zeros(shape=[1], dtype="float16")
            )
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
        return loss

    def _optimizer_step(self):
        if self._delay_scale_loss:
            for p in self._layers.parameters():
                if hasattr(p, "main_grad") and p.main_grad is not None:
                    assert p.grad is None
                    p.main_grad = p.main_grad.scale(1.0 / self.accumulate_steps)
                elif p.grad is not None:
                    p.grad = p.grad.scale(1.0 / self.accumulate_steps)

        if self.scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()

    def _release_output(self, output):
        def can_free(t):
            return (
                t is not None
                and isinstance(t, paddle.Tensor)
                and t._is_initialized()
                and t.inplace_version == 0
            )

        if isinstance(output, (tuple, list)):
            for t in output:
                if can_free(t):
                    t._clear_dataptr()

        elif can_free(output):
            output._clear_dataptr()


class PipelineParallelWithInterleave(PipelineParallel):
    # pipeline parallel with interleave scheduler

    def __init__(self, layers, hcg, strategy):
        super().__init__(layers=layers, hcg=hcg, strategy=strategy)
        assert layers.get_num_virtual_stages() > 1
        assert (
            framework.in_dygraph_mode()
741
        ), "virtual pipeline stage with interleave only support eager dygraph mode"
        assert (
            self.num_stages > 2
        ), "virtual pipeline must run under pp degree > 2"
        assert (
            self.accumulate_steps % self.num_stages == 0
        ), "accumulate_steps should be evenly divisible by num_stages for pipeline with interleave"
        # setup for interleave scheduler
        self.num_model_chunks = layers.get_num_virtual_stages()
        self.model_chunks = layers.get_model_chunks()
        assert self.model_chunks is not None
        assert len(self.model_chunks) == self.num_model_chunks
        self._virtual_pp_world_size = self.num_model_chunks
        self._virtual_pp_rank = 0
        self._assign_vpp_info(self.model_chunks)

    def _assign_vpp_info(self, chunks):
        chunk_num = len(chunks)
        for i, chunk in enumerate(chunks):
            for p in chunk.parameters():
                p._chunk_info = {"chunk_id": i, "chunk_num": chunk_num}

    def _get_virtual_pp_rank(self, micro_step, forward):
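        # Map a micro step to the model chunk (virtual pp stage) that owns it.
        # Worked example (illustrative numbers only): with num_stages=4 and
        # num_model_chunks=2 the period is 8, so forward micro steps 0-3 map to
        # chunk 0 and steps 4-7 to chunk 1, while backward micro steps 0-3 map
        # to chunk 1 and steps 4-7 to chunk 0.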
        virtual_pp_stage = micro_step % (
            self.num_stages * self.num_model_chunks
        )
        virtual_pp_stage = virtual_pp_stage // self.num_stages
        if not forward:
            virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
        return virtual_pp_stage

    def _forward_step_helper(self, micro_dataset, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        if not self._forward_only:
            assert hasattr(self, 'output_tensor_grads')

        if self.is_pipeline_first_stage():
            if len(self.input_tensors[virtual_pp_rank]) == len(
                self.output_tensors[virtual_pp_rank]
            ):
                self.input_tensors[virtual_pp_rank].append(None)
        input_tensor = self.input_tensors[virtual_pp_rank][-1]
        output_tensor = self._forward_step(
            input_tensor, micro_dataset, virtual_pp_rank
        )
        self.output_tensors[virtual_pp_rank].append(output_tensor)

        if self._forward_only:
            # no need to store tensor for backward
            self.input_tensors[virtual_pp_rank].pop()
            self.output_tensors[virtual_pp_rank].pop()

        return output_tensor

    def _overlap_comm_grads(self):
        if self._comm_overlap:
            self._backward_step_count += 1
            sync_step = self._backward_step_count - self.stage_id
            if sync_step > 0 and sync_step % self.num_stages == 0:
                chunk_idx = self._virtual_pp_world_size - (
                    sync_step // self.num_stages
                )
                for buffer in self._chunk_2_comm_buffers[chunk_idx]:
                    buffer.comm_grads()

            if self.stage_id != 0:
                if (
                    self._backward_step_count
                    == self.num_stages * self.num_model_chunks
                ):
                    for buffer in self._chunk_2_comm_buffers[0]:
                        buffer.comm_grads()

    def _sync_overlap_grads(self):
        if self._comm_overlap:
            assert (
                self._backward_step_count
                == self.num_stages * self.num_model_chunks
            ), (
                "backward step count should be equal to accumulate steps * virtual pp world size,"
                f" but get {self._backward_step_count}, excepted result is {self.num_stages * self.num_model_chunks}"
            )

            for _, buffers in self._chunk_2_comm_buffers.items():
                for buffer in buffers:
                    buffer.scale_and_split_grads()

    def _backward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        assert hasattr(self, 'output_tensor_grads')

        if self.is_pipeline_last_stage():
            if len(self.output_tensor_grads[virtual_pp_rank]) == 0:
                self.output_tensor_grads[virtual_pp_rank].append(None)

        input_tensor = self.input_tensors[virtual_pp_rank].pop(0)
        output_tensor = self.output_tensors[virtual_pp_rank].pop(0)
        output_tensor_grad = self.output_tensor_grads[virtual_pp_rank].pop(0)
        input_tensor_grad = self._backward_step(
            input_tensor, output_tensor, output_tensor_grad
        )

        self._overlap_comm_grads()

        return input_tensor_grad

    def bw_hook_func(self, buffer, param):
        # For pipeline with interleave, we need to add grads to the buffer without communication.
        # Communication is instead triggered by the pipeline scheduler (see _overlap_comm_grads)
        # to avoid conflicts between dp communication and pp scheduling.
        @paddle.autograd.no_grad()
        def fused_allreduce(*_):
            buffer.add_grad(param, use_comm=False)

        return fused_allreduce

    def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
        super().register_allreduce_overlap_hook(
            model, comm_group, acc_steps, dp, group_size=sys.maxsize
        )

    def forward_backward_pipeline(
        self, data, scaler, forward_only=False, compute_loss=True
    ):
        # use interleave scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
        if not compute_loss:
            assert (
                forward_only
            ), "compute_loss can only be set to False when forward_only is set to True"

        # NOTE(shenliang03): Due to ring_exchange for pipeline with interleave, cache should be enabled
        assert (
            self._using_cache
        ), "cache should be enabled for pipeline with interleave"

        # init some attributes for this batch run
        self.scaler = scaler
        self.total_loss = None
        self.micro_batch_id = 0
        self._forward_only = forward_only

        # store the number of backward steps

        assert (
            self.accumulate_steps % self.num_stages == 0
        ), "accumulate_steps({}) should be evenly divisible by num_stages({}) for pipeline with interleave".format(
            self.accumulate_steps, self.num_stages
        )
        per_stage_accumulate_steps = self.accumulate_steps // self.num_stages
        self._backward_step_count = (
            -(per_stage_accumulate_steps - 1)
            * self.num_stages
            * self.num_model_chunks
        )
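        # Worked example (illustrative numbers only): with accumulate_steps=8,
        # num_stages=4 and num_model_chunks=2, the counter starts at
        # -(2 - 1) * 4 * 2 = -8 and is incremented once per backward micro step
        # (16 in total), ending at 8 == num_stages * num_model_chunks, which is
        # what _sync_overlap_grads checks.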

        # init some data buffers for interleave scheduler
        self.input_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]

        micro_dataset = self._wrap_data(data)

        num_steps = self.accumulate_steps * self.num_model_chunks
        all_startup_steps = False
        if forward_only:
            # If only forward, there is no backward, so all steps are startup steps
            startup_steps = num_steps
        else:
            if self.accumulate_steps == self.num_stages:
                startup_steps = num_steps
                all_startup_steps = True
            else:
                startup_steps = (self.num_stages - self.stage_id - 1) * 2
                startup_steps += (self.num_model_chunks - 1) * self.num_stages
                startup_steps = min(startup_steps, num_steps)

        steady_steps = num_steps - startup_steps
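        # Worked example (illustrative numbers only): in training with
        # num_stages=4, num_model_chunks=2 and accumulate_steps=8, num_steps is
        # 16; for stage_id=1 the warm-up is (4 - 1 - 1) * 2 + (2 - 1) * 4 = 8
        # startup steps, leaving 8 steady 1f1b steps.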

        self.set_virtual_pipeline_rank(0)
        self.input_tensors[0].append(
            self._p2p_helper.recv_forward(
                self.is_pipeline_first_stage(), sync_recv=False
            )
        )

        # run startup steps
        for micro_step in range(startup_steps):
            output_tensor = self._forward_step_helper(micro_dataset, micro_step)

            # determine whether to recv the forward tensor or not
            next_virtual_pp_rank = self._get_virtual_pp_rank(
                micro_step + 1, forward=True
            )
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                if next_virtual_pp_rank == 0:
                    # next chunk is the first chunk, no need to pre-recv an input tensor
                    recv_prev = False
            # last micro step, no next run
            if micro_step == (num_steps - 1):
                recv_prev = False

            # last stage shouldn't send tensor to downstream
            if self.is_pipeline_last_stage():
                output_tensor = None

            if (
                micro_step == (startup_steps - 1)
                and not forward_only
                and not all_startup_steps
            ):
                input_tensor_grad = None
                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False

                # the last startup step needs four-direction comm to set up for steady 1f1b
                (
                    input_tensor,
                    output_tensor_grad,
972
                ) = self._p2p_helper.send_forward_backward_recv_forward_backward(
973 974 975
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
976 977 978 979 980
                    recv_next=recv_next,
                )
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    output_tensor_grad
                )
981
            else:
982
                input_tensor = self._p2p_helper.send_forward_recv_forward(
983 984
                    output_tensor, recv_prev=recv_prev
                )
985 986
            self.input_tensors[next_virtual_pp_rank].append(input_tensor)

987 988
            self._release_output(output_tensor)

989 990 991 992
        # run 1f1b steady steps
        for micro_step in range(steady_steps):
            # forward
            forward_micro_step_id = micro_step + startup_steps
993 994 995
            output_tensor = self._forward_step_helper(
                micro_dataset, forward_micro_step_id
            )
996 997 998 999

            # backward
            backward_micro_step_id = micro_step
            input_tensor_grad = self._backward_step_helper(
1000 1001
                backward_micro_step_id
            )
1002 1003 1004 1005 1006 1007 1008 1009 1010

            # four directions comm
            # send output tensor to downstream
            # send input tensor grad to upstream
            # recv input tensor from upstream
            # recv output tensor grad from downstream

            # last stage doesn't send results to downstream
            forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id, forward=True
            )
            self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
            if self.is_pipeline_last_stage():
                output_tensor = None

            # first stage doesn't send grad to upstream
            backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id, forward=False
            )
            self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
            if self.is_pipeline_first_stage():
                input_tensor_grad = None

            # determine whether to recv input tensor from upstream
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id - (self.num_stages - 1), forward=True
                )
                if next_forward_virtual_pp_rank == (self.num_model_chunks - 1):
                    # first pp stage and first virtual stage
                    recv_prev = False
                next_forward_virtual_pp_rank += 1
            else:
                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id + 1, forward=True
                )

            # last iteration doesn't need recv from upstream
            if micro_step == (steady_steps - 1):
                recv_prev = False

            # determine whether to recv grad from downstream
            recv_next = True
            if self.is_pipeline_last_stage(ignore_virtual=True):
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id - (self.num_stages - 1),
                    forward=False,
                )
                if next_backward_virtual_pp_rank == 0:
                    # last pp stage and last virtual stage
                    recv_next = False
                next_backward_virtual_pp_rank -= 1
            else:
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id + 1, forward=False
                )

            (
                input_tensor,
                output_tensor_grad,
            ) = self._p2p_helper.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
            )

            self._release_output(output_tensor)

            if recv_prev:
                self.input_tensors[next_forward_virtual_pp_rank].append(
                    input_tensor
                )
            if recv_next:
                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    output_tensor_grad
                )

        self._release_output(output_tensor)

        # remaining backward steps
        if not forward_only:
            if all_startup_steps:
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    self._p2p_helper.recv_backward(
                        self.is_pipeline_last_stage(), sync_recv=False
                    )
                )

            for micro_step in range(steady_steps, num_steps):
                # cooldown loop
                input_tensor_grad = self._backward_step_helper(micro_step)
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    micro_step + 1, forward=False
                )

                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    if next_backward_virtual_pp_rank == (
                        self.num_model_chunks - 1
                    ):
                        recv_next = False

                if micro_step == (num_steps - 1):
                    recv_next = False

                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    self._p2p_helper.send_backward_recv_backward(
                        input_tensor_grad, recv_next=recv_next
                    )
                )

            self._sync_overlap_grads()

            if self._enable_timer:
                self.timers("allreduce_shared_weight_gradients").start()
            self._layers.allreduce_shared_weight_gradients()
            if self._enable_timer:
                self.timers("allreduce_shared_weight_gradients").stop()

        if compute_loss:
            # return loss if compute loss
            if self._enable_timer:
                self.timers("broadcast_final_loss").start()
            with paddle.amp.auto_cast(enable=False):
                train_loss = self._broadcast_final_loss()
            if self._enable_timer:
                self.timers("broadcast_final_loss").stop()
        else:
            # else just return all intermediate output tensors for all micro steps
            train_loss = self.output_tensors

        self.timer_printer()
        return train_loss

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # interleave scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        return self.forward_backward_pipeline(data, None, forward_only=True)