#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import framework

from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
from ..utils import timer_helper as timer
from ..utils.hybrid_parallel_util import (
    broadcast_dp_parameters,
    broadcast_mp_parameters,
    broadcast_sharding_parameters,
)
from ..utils.log_util import logger
from .meta_parallel_base import MetaParallelBase
from .parallel_layers.pp_layers import PipelineLayer
from .pp_utils import p2p_communication as p2p
from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size

__all__ = []


# assume only the first stage and last stage need data, and data consumption is ordered
# to be replaced by real micro dataset from reader
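# Illustrative example (not from the original source): with micro_batch_size=2
# and acc_steps=4, a first/last-stage input tensor of shape [8, ...] is consumed
# as four micro batches, i.e. rows [0:2], [2:4], [4:6] and [6:8], one per step.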
class FakeMicroDataset:
    def __init__(
        self, data, is_first_stage, is_last_stage, acc_steps, micro_batch_size
    ):
        self._data = data
        self._index = 0
        self._acc_steps = acc_steps
        self._is_first_stage = is_first_stage
        self._is_last_stage = is_last_stage
        self._micro_batch_size = micro_batch_size

    def __iter__(self):
        return self

    def __next__(self):
        if self._index >= self._acc_steps:
            raise StopIteration
        assert self._is_first_stage or self._is_last_stage
        micro_batch_data = self._load_micro_batch(self._index)
        self._index += 1
        return micro_batch_data

    def _load_micro_batch(self, micro_step):
        inputs = self._data

        if self._is_first_stage or self._is_last_stage:
            assert len(inputs) == 2, "length of input should be 2"
            data = self._load_micro_batch_impl(inputs[0], micro_step)
            label = self._load_micro_batch_impl(inputs[1], micro_step)
            return (data, label)
        else:
            return (None, None)

    def _load_micro_batch_impl(self, inputs, micro_step):
        begin = micro_step * self._micro_batch_size
        end = begin + self._micro_batch_size

        if isinstance(inputs, tuple):
            output = []
            for data in inputs:
                if isinstance(data, list):
                    assert (
                        len(data) == self._acc_steps
                    ), "length of data should be %d, but it is %d" % (
                        self._acc_steps,
                        len(data),
                    )
                    output.append(data[micro_step].detach())
                elif data is not None:
                    self._check_data_valid(data)
                    output.append(data[begin:end, :].detach())
                else:
                    output.append(None)
            return tuple(output)

        elif isinstance(inputs, list):
            assert (
                len(inputs) == self._acc_steps
            ), "length of data should be %d, but it is %d" % (
                self._acc_steps,
                len(inputs),
            )
            return inputs[micro_step].detach()
        elif inputs is not None:
            self._check_data_valid(inputs)
            return inputs[begin:end, :].detach()
        else:
            return None

    def _check_data_valid(self, data):
        batch_size = data.shape[0]
        assert self._micro_batch_size * self._acc_steps == batch_size, (
            "batch_size needs to be divisible by micro_batch_size. Currently, "
            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
            % (batch_size, self._micro_batch_size, self._acc_steps)
        )


class PipelineParallel(MetaParallelBase):
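    """Run a ``PipelineLayer`` with the 1F1B pipeline-parallel schedule.

    Minimal usage sketch (illustrative only): in practice this class is
    usually created by ``fleet.distributed_model`` when the pipeline degree
    is greater than 1, and ``opt``/``lr`` below are hypothetical objects.

        pp_model = PipelineParallel(pipeline_layer, hcg, strategy)
        loss = pp_model.train_batch([inputs, labels], opt, lr_scheduler=lr)
    """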
    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer."
            )
        super().__init__(layers, hcg, strategy)
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
        self.use_sharding_parallel = (
            self._hcg.get_sharding_parallel_world_size() > 1
        )

        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size'
        ]
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps'
        ]
        # If the tensors being sent are not the same across different hosts,
        # they shouldn't be sent partially and then concatenated into a whole tensor.
        self._enable_partial_send_recv = self._strategy.pipeline_configs[
            'enable_partial_send_recv'
        ]
        self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.pp_group = self._hcg.get_pipe_parallel_group()
        self.dp_group = self._hcg.get_data_parallel_group()
        self.sharding_group = self._hcg.get_sharding_parallel_group()

        self._virtual_pp_world_size = None
        self._virtual_pp_rank = None
        self._real_pp_world_size = self.num_stages
        self._real_pp_rank = self.stage_id

        self._delay_scale_loss = self._strategy.hybrid_configs[
            "pp_configs"
        ].delay_scale_loss
        # TODO(PP Dev): support dp_comm_overlap without use_main_grad training.
        # This combination will trigger an inplace check error during `reshape_` in the function `_split_tensors`.
        self._dp_comm_overlap = self._strategy.hybrid_configs[
            "pp_configs"
        ].dp_comm_overlap
        self._sharding_comm_overlap = self._strategy.hybrid_configs[
            "pp_configs"
        ].sharding_comm_overlap
        self._enable_timer = self._strategy.hybrid_configs[
            "pp_configs"
        ].enable_timer

        if self._dp_comm_overlap:
            assert self.use_data_parallel and self.num_stages > 1

        if self._sharding_comm_overlap:
            assert self.use_sharding_parallel and self.num_stages > 1

        assert not (
            self._dp_comm_overlap and self._sharding_comm_overlap
        ), "Cannot use dp pp overlap and sharding pp overlap at the same time."

        self._comm_buffers = []
        self._comm_overlap = (
            self._dp_comm_overlap or self._sharding_comm_overlap
        )

        if self._enable_timer:
            if not timer.is_timer_initialized():
                timer.set_timers()
            self.timers = timer.get_timers()

        p2p.initialize_p2p_groups(
            hcg,
            self._using_cache,
            self._enable_partial_send_recv,
            self._enable_timer,
        )

        self.global_rank = self._hcg.get_global_rank()
        self.micro_batch_id = 0

        self._compute_loss = True

        logger.info(
            "Pipeline Info -- num_stages: {}, stage_id: {}".format(
                self.num_stages, self.stage_id
            )
        )

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_sharding_parallel:
            logger.info("start broadcast sharding parameters")
            broadcast_sharding_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

        if self._dp_comm_overlap:
            self.register_allreduce_overlap_hook(
                self._layers, self.dp_group, self.accumulate_steps, True
            )

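    # Note: with an interleaved (virtual pipeline) schedule each rank owns
    # several model chunks, so "first/last stage" depends on the current
    # virtual stage as well; ignore_virtual=True checks only the physical
    # stage id.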
    def is_pipeline_first_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != 0:
                    return False
        assert self._real_pp_rank is not None
        return self._real_pp_rank == 0

    def is_pipeline_last_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != (self._virtual_pp_world_size - 1):
                    return False
        assert self._real_pp_rank is not None
        assert self._real_pp_world_size is not None
        return self._real_pp_rank == (self._real_pp_world_size - 1)

    def set_virtual_pipeline_rank(self, rank):
        self._virtual_pp_rank = rank

    def bw_hook_func(self, buffer, param):
        @paddle.autograd.no_grad()
        def fused_allreduce(*_):
            buffer.add_grad(param)

        return fused_allreduce
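    # bw_hook_func returns a closure that register_allreduce_overlap_hook
    # installs as a backward hook on each trainable parameter: once the
    # parameter's gradient is produced it is accumulated into its
    # FusedCommBuffer, which later issues the fused all-reduce (dp) or
    # reduce (sharding) so communication overlaps with backward compute.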

    def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
        if model.get_num_virtual_stages() > 1:
            models = model.get_model_chunks()
        else:
            models = [model]

        if not dp:
            assert hasattr(self, "optimizer")
            assert hasattr(self.optimizer, "_param2rank")
            _param2rank = self.optimizer._param2rank

        act = HOOK_ACTION.ALL_REDUCE if dp else HOOK_ACTION.REDUCE

        fused_parameter_group = {}

        for model in models:
            # For virtual pipeline: separate the parameters of different chunks
            # into different groups to get the best performance.
            parameter_list = [
                p for p in model.parameters() if not p.stop_gradient
            ]
            if len(parameter_list) < 1:
                return

            if dp:
                fused_parameter_group[-1] = parameter_list
            else:
                # Sort parameters for sharding, since they have different dst rank
                for p in parameter_list:
                    assert p.name in _param2rank
                    dst_rank = _param2rank[p.name]
                    if dst_rank in fused_parameter_group:
                        fused_parameter_group[dst_rank].append(p)
                    else:
                        fused_parameter_group[dst_rank] = [p]

            for dst in fused_parameter_group:
                parameter_list = fused_parameter_group[dst]
                if not dp:
                    # parse the relative dst rank to absolute dst rank for sharding
                    dst = comm_group.ranks[dst]
                var_groups = assign_group_by_size(parameter_list)
                for group_idx, parameters in var_groups.items():
                    buffer = FusedCommBuffer(
                        group_idx, parameters, comm_group, acc_steps, act, dst
                    )
                    self._comm_buffers.append(buffer)
                    for param in parameters:
                        param._register_backward_hook(
                            self.bw_hook_func(buffer, param)
                        )

    def timer_printer(self):
        if not self._enable_timer:
            return
        all_flag_names = self.timers.timers.keys()
        self.timers.log(all_flag_names)

    def forward_backward_pipeline(self, data, scaler=None):
        # use the 1f1b scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py

        self.scaler = scaler

        # store total loss of entire batch
        self.total_loss = None

        # store data id for micro_batch
        self.micro_batch_id = 0

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps
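        # Illustrative example (assuming num_stages=4, accumulate_steps=8):
        # stage 0 runs 3 warmup forward steps and 5 steady 1F1B steps, while
        # the last stage runs 0 warmup steps and all 8 steps in steady 1F1B.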

        input_buffers = []
        output_buffers = []

        micro_dataset = self._wrap_data(data)

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor, micro_dataset)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not self.is_pipeline_last_stage():
                self._release_output(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor, micro_dataset)

            output_tensor_grad = p2p.send_forward_recv_backward(
                output_tensor, self.is_pipeline_last_stage()
            )

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not self.is_pipeline_last_stage():
                self._release_output(output_tensor)

            input_tensor, output_tensor = input_buffers.pop(
                0
            ), output_buffers.pop(0)

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )

            if last_iter:
                input_tensor = None
                p2p.send_backward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )
            else:
                input_tensor = p2p.send_backward_recv_forward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )

        for i in range(startup_steps):
            input_tensor = input_buffers.pop(0)
            output_tensor = output_buffers.pop(0)

            output_tensor_grad = p2p.recv_backward(
                self.is_pipeline_last_stage()
            )

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )
            p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())

        if self._comm_overlap:
            assert len(self._comm_buffers) > 0
            for buffer in self._comm_buffers:
                buffer.scale_and_split_grads()

        if self._enable_timer:
            self.timers("allreduce_shared_weight_gradients").start()
        self._layers.allreduce_shared_weight_gradients()
        if self._enable_timer:
            self.timers("allreduce_shared_weight_gradients").stop()
            self.timers("broadcast_final_loss").start()
        with paddle.amp.auto_cast(enable=False):
            train_loss = self._broadcast_final_loss()
        if self._enable_timer:
            self.timers("broadcast_final_loss").stop()
        self.timer_printer()
        return train_loss

    def _prepare_training(self, data, optimizer, lr_scheduler):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        assert isinstance(
            optimizer, HybridParallelOptimizer
        ), 'optimizer should be HybridParallelOptimizer subclass.'

        assert (
            framework._dygraph_tracer()._has_grad
        ), 'Please enable the generation of gradients.'

        if self.is_pipeline_first_stage(
            ignore_virtual=True
        ) or self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                data is not None
            ), "For the first and the last stage, the data must be set."
        else:
            data = None

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self._layers.train()

        if self._sharding_comm_overlap and len(self._comm_buffers) == 0:
            self.register_allreduce_overlap_hook(
                self._layers, self.sharding_group, self.accumulate_steps, False
            )

        return data

    def _wrap_data(self, data):
        """
        For backward compatibility, wrap data into a FakeMicroDataset if it is a list or tuple.
        """
        if (not isinstance(data, tuple)) and (not isinstance(data, list)):
            return data

        micro_dataset = FakeMicroDataset(
            data,
            self.is_pipeline_first_stage(ignore_virtual=True),
            self.is_pipeline_last_stage(ignore_virtual=True),
            self.accumulate_steps,
            self.micro_batch_size,
        )
        return micro_dataset

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # 1f1b scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        # store data id for micro_batch
        self.micro_batch_id = 0

        # store total loss of entire batch
        self.total_loss = None

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

        input_buffers = []
        output_buffers = []

        # convert to micro dataset
        micro_dataset = self._wrap_data(data)

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor, micro_dataset)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor, micro_dataset)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not last_iter:
                input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        if self._compute_loss:
            self.train_loss = self._broadcast_final_loss()
        else:
            self.train_loss = output_buffers

        return self.train_loss

    def _forward_step(self, input_tensor, micro_dataset, chunk_id=None):
        if self._enable_timer:
            self.timers("forward_step").start()
        if self.is_pipeline_first_stage():
            input_tensor = next(micro_dataset)[0]
            self._check_micro_batch_data_valid(input_tensor)

        assert chunk_id is None or isinstance(chunk_id, int)

        output_tensor = self._layers.forward(input_tensor, chunk_id=chunk_id)

        if self.is_pipeline_last_stage():
            # calculate loss for training
            if self._compute_loss:
                assert (
                    self._layers._loss_fn is not None
                ), "loss function should exist to compute loss"
                labels = next(micro_dataset)[1]
                self._check_micro_batch_data_valid(labels)
                output_tensor = self._layers._loss_fn(output_tensor, labels)
                assert isinstance(
                    output_tensor, (paddle.Tensor, framework.core.eager.Tensor)
                ), "Currently, loss_fn should obtain Paddle.Tensor dtype"

                with paddle.amp.auto_cast(enable=False):
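                    # Without delay_scale_loss, each micro batch loss is divided
                    # by accumulate_steps so that the accumulated total_loss
                    # equals the mean loss over the whole batch.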
                    if self.accumulate_steps > 1 and not self._delay_scale_loss:
                        output_tensor = output_tensor / self.accumulate_steps

                    if self.total_loss is None:
                        self.total_loss = paddle.zeros_like(output_tensor)
                    self.total_loss += output_tensor.detach()

        if self.is_pipeline_first_stage() or self.is_pipeline_last_stage():
            # Only increase micro batch id at virtual first/last pp stage.
            # The micro batch id is used to load data, therefore, only increase it when loading data.
            self.micro_batch_id += 1
        if self._enable_timer:
            self.timers("forward_step").stop()
        return output_tensor

    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
        if self._enable_timer:
            self.timers("backward_step").start()
        with paddle.amp.auto_cast(enable=False):
            if self.is_pipeline_last_stage():
                assert output_tensor_grad is None
                if self.scaler:
                    paddle.autograd.backward(self.scaler.scale(output_tensor))
                else:
                    paddle.autograd.backward(output_tensor)
            else:
                if isinstance(output_tensor, tuple):
                    outputs = [t for t in output_tensor if not t.stop_gradient]
                    assert len(outputs) == len(output_tensor_grad)
                    paddle.autograd.backward(
                        tensors=outputs,
                        grad_tensors=list(output_tensor_grad),
                    )
                else:
                    paddle.autograd.backward(
                        tensors=[output_tensor],
                        grad_tensors=[output_tensor_grad],
                    )

            input_tensor_grad = None
            if input_tensor is not None:
                if isinstance(input_tensor, tuple):
                    input_tensor_grad = tuple(
                        [t.grad for t in input_tensor if not t.stop_gradient]
                    )
                else:
                    input_tensor_grad = input_tensor.grad
            if self._enable_timer:
                self.timers("backward_step").stop()
            return input_tensor_grad

    def _check_micro_batch_data_valid(self, micro_batch_data):
        if isinstance(micro_batch_data, (tuple, list)):
            for data in micro_batch_data:
                self._check_micro_batch_data_valid(data)
        elif micro_batch_data is not None:
            micro_batch_size = micro_batch_data.shape[0]
            assert (
                micro_batch_size == self.micro_batch_size
            ), f"expected micro_batch_size {self.micro_batch_size} but get {micro_batch_size}"

    def _broadcast_final_loss(self):
        # Since the last backward run in interleave will set the virtual rank to 0,
        # here we need to check last stage ignoring virtual stage.
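        # The loss dtype is only known on the last stage, so an int64 flag
        # (is_fp32) is broadcast first; the other stages use it to allocate a
        # receive buffer of the matching dtype before the loss broadcast.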
        if self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                self.total_loss is not None
            ), "train_batch() in last stage should obtain vaild loss"
            loss = (
                self.total_loss.detach()
                if not self._delay_scale_loss
                else self.total_loss / self.accumulate_steps
            )
            is_fp32 = (
                paddle.full([], 1, 'int64')
                if loss.dtype == paddle.float32
                else paddle.full([], 0, 'int64')
            )
            paddle.distributed.broadcast(
                is_fp32, src=self.global_rank, sync_op=True, group=self.pp_group
            )
            paddle.distributed.broadcast(
                loss, src=self.global_rank, sync_op=True, group=self.pp_group
            )
        else:
            is_fp32 = paddle.full([], 1, 'int64')
            paddle.distributed.broadcast(
                is_fp32,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
            loss = (
                paddle.zeros(shape=[1], dtype="float32")
                if is_fp32.item()
                else paddle.zeros(shape=[1], dtype="float16")
            )
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
        return loss

    def _optimizer_step(self):
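        # With delay_scale_loss enabled, micro batch losses were accumulated
        # unscaled, so gradients (main_grad or grad) are divided by
        # accumulate_steps once here, right before the optimizer update.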
        if self._delay_scale_loss:
            for p in self._layers.parameters():
                if hasattr(p, "main_grad") and p.main_grad is not None:
                    assert p.grad is None
                    p.main_grad = p.main_grad.scale(1.0 / self.accumulate_steps)
                elif p.grad is not None:
                    p.grad = p.grad.scale(1.0 / self.accumulate_steps)

        if self.scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()

    def _release_output(self, output):
        if isinstance(output, (tuple, list)):
            for t in output:
                if t is not None and isinstance(t, paddle.Tensor):
                    t._clear_dataptr()
        elif output is not None and isinstance(output, paddle.Tensor):
            output._clear_dataptr()


class PipelineParallelWithInterleave(PipelineParallel):
    # pipeline parallel with interleave scheduler
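    # Each rank holds num_model_chunks non-adjacent layer chunks (virtual
    # stages); micro steps cycle through the chunks to shrink the pipeline
    # bubble compared with plain 1F1B, following the Megatron-LM interleaved
    # schedule.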

    def __init__(self, layers, hcg, strategy):
        super().__init__(layers=layers, hcg=hcg, strategy=strategy)
        assert layers.get_num_virtual_stages() > 1
        assert (
            self.num_stages > 2
        ), "virtual pipeline must run under pp degree > 2"
        assert (
            framework.in_dynamic_mode()
        ), "virtual pipeline stage with interleave only support eager dygraph mode"
        assert (
            self.accumulate_steps % self.num_stages == 0
        ), "accumulate_steps should be evenly divisible by num_stages for pipeline with interleave"
        # setup for interleave scheduler
        self.num_model_chunks = layers.get_num_virtual_stages()
        self.model_chunks = layers.get_model_chunks()
        assert self.model_chunks is not None
        assert len(self.model_chunks) == self.num_model_chunks
        self._virtual_pp_world_size = self.num_model_chunks
        self._virtual_pp_rank = 0

    def _get_virtual_pp_rank(self, micro_step, forward):
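        # Illustrative mapping (assuming num_stages=4, num_model_chunks=2):
        # forward micro_steps 0..7 map to chunks [0, 0, 0, 0, 1, 1, 1, 1] and
        # then the pattern repeats; for backward the chunk order is reversed,
        # i.e. micro_steps 0..7 map to [1, 1, 1, 1, 0, 0, 0, 0].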
        virtual_pp_stage = micro_step % (
            self.num_stages * self.num_model_chunks
        )
        virtual_pp_stage = virtual_pp_stage // self.num_stages
        if not forward:
            virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
        return virtual_pp_stage

    def _forward_step_helper(self, micro_dataset, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        if not self._forward_only:
            assert hasattr(self, 'output_tensor_grads')
        assert len(self.input_tensors[virtual_pp_rank]) == (
            len(self.output_tensors[virtual_pp_rank]) + 1
        )
        input_tensor = self.input_tensors[virtual_pp_rank][-1]
        output_tensor = self._forward_step(
            input_tensor, micro_dataset, virtual_pp_rank
        )
        self.output_tensors[virtual_pp_rank].append(output_tensor)

        if self._forward_only:
            # no need to store tensor for backward
            self.input_tensors[virtual_pp_rank].pop()
            self.output_tensors[virtual_pp_rank].pop()

        return output_tensor

    def _backward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        assert hasattr(self, 'output_tensor_grads')

        assert (
            len(self.output_tensor_grads[virtual_pp_rank]) == 1
        ), f"output_tensor_grads is empty for virtual_pp_rank {virtual_pp_rank}"

        assert len(self.input_tensors[virtual_pp_rank]) > 0
        assert len(self.output_tensors[virtual_pp_rank]) > 0

        input_tensor = self.input_tensors[virtual_pp_rank].pop(0)
        output_tensor = self.output_tensors[virtual_pp_rank].pop(0)
        output_tensor_grad = self.output_tensor_grads[virtual_pp_rank].pop(0)
        input_tensor_grad = self._backward_step(
            input_tensor, output_tensor, output_tensor_grad
        )

        return input_tensor_grad

    def forward_backward_pipeline(
        self, data, scaler, forward_only=False, compute_loss=True
    ):
        # use interleave scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
        if not compute_loss:
            assert (
                forward_only
            ), "compute_loss can only be set to False when forward_only is set to True"

        # init some attributes for this batch run
        self.scaler = scaler
        self.total_loss = None
        self.micro_batch_id = 0
        self._forward_only = forward_only

        # init some data buffers for interleave scheduler
        self.input_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]

        micro_dataset = self._wrap_data(data)

        num_steps = self.accumulate_steps * self.num_model_chunks
        if forward_only:
            # If only forward, there is no backward pass, so all steps are startup steps
            startup_steps = num_steps
        else:
            # actually startup_steps is calculated from two number:
            # first_forward_cross_to_end = (self.num_stages - self.stage_id - 1) + (self.num_model_chunks - 1) * self.num_stages
            # end_to_first_backward_cross = (self.num_stages - self.stage_id - 1)
            # startup_steps = first_forward_cross_to_end + end_to_first_backward_cross
            startup_steps = (self.num_stages - self.stage_id - 1) * 2
            startup_steps += (self.num_model_chunks - 1) * self.num_stages
            startup_steps = min(startup_steps, num_steps)
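            # Illustrative example (assuming num_stages=4, num_model_chunks=2,
            # stage_id=0, accumulate_steps=8): startup_steps = 3 * 2 + 1 * 4 = 10
            # and steady_steps = 8 * 2 - 10 = 6.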

        steady_steps = num_steps - startup_steps

        self.set_virtual_pipeline_rank(0)
        self.input_tensors[0].append(
            p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False)
        )

        # run startup steps
        for micro_step in range(startup_steps):
            output_tensor = self._forward_step_helper(micro_dataset, micro_step)

            # determine whether recv forward tensor or not
            next_virtual_pp_rank = self._get_virtual_pp_rank(
                micro_step + 1, forward=True
            )
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                if next_virtual_pp_rank == 0:
                    # next chunk is the first chunk, no need to pre-recv an input tensor
                    recv_prev = False
            # last micro step, no next run
            if micro_step == (num_steps - 1):
                recv_prev = False

            # last stage shouldn't send tensor to downstream
            if self.is_pipeline_last_stage():
                output_tensor = None

            if micro_step == (startup_steps - 1) and not forward_only:
                input_tensor_grad = None
                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False

                # the last startup step needs four-direction comm to set up for steady 1f1b
                (
                    input_tensor,
                    output_tensor_grad,
                ) = p2p.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next,
                )
                # output_tensor_grad is not None if recv_next
                # append output_tensor_grad whether it is None or not
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    output_tensor_grad
                )
            else:
                input_tensor = p2p.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev
                )
            # append input_tensor whether it is None or not
            self.input_tensors[next_virtual_pp_rank].append(input_tensor)

            self._release_output(output_tensor)

        # run 1f1b steady steps
        for micro_step in range(steady_steps):
            # forward
            forward_micro_step_id = micro_step + startup_steps
            output_tensor = self._forward_step_helper(
                micro_dataset, forward_micro_step_id
            )

            # backward
            backward_micro_step_id = micro_step
            input_tensor_grad = self._backward_step_helper(
                backward_micro_step_id
            )

            # four directions comm
            # send output tensor to downstream
            # send input tensor grad to upstream
            # recv input tensor from upstream
            # recv output tensor grad from downstream

            # last stage doesn't send result to downstream
            forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id, forward=True
            )
            self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
            if self.is_pipeline_last_stage():
                output_tensor = None

            # first stage doesn't send grad to upstream
            backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id, forward=False
            )
            self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
            if self.is_pipeline_first_stage():
                input_tensor_grad = None

            # determine whether to recv input tensor from upstream
            recv_prev = True
            next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id + 1, forward=True
            )
            if self.is_pipeline_first_stage(ignore_virtual=True) and (
                next_forward_virtual_pp_rank == 0
            ):
                # first pp stage and first virtual stage
                recv_prev = False

            # last iteration doesn't need recv from upstream
            if micro_step == (steady_steps - 1):
                recv_prev = False

            # determine whether to recv grad from downstream
            recv_next = True
            next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id + 1, forward=False
            )
            if self.is_pipeline_last_stage(ignore_virtual=True) and (
                next_backward_virtual_pp_rank == (self.num_model_chunks - 1)
            ):
                # last pp stage and last virtual stage
                recv_next = False

            (
                input_tensor,
                output_tensor_grad,
            ) = p2p.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
            )
            # append input_tensor whether it is None or not
            self.input_tensors[next_forward_virtual_pp_rank].append(
                input_tensor
            )
            # append output_tensor_grad whether it is None or not
            self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                output_tensor_grad
            )
            self._release_output(output_tensor)

        self._release_output(output_tensor)

        # remaining backward steps
        if not forward_only:
            for micro_step in range(steady_steps, num_steps):
                # cooldown loop
                input_tensor_grad = self._backward_step_helper(micro_step)
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    micro_step + 1, forward=False
                )

                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    if next_backward_virtual_pp_rank == (
                        self.num_model_chunks - 1
                    ):
                        recv_next = False

                if micro_step == (num_steps - 1):
                    recv_next = False
                # append output_tensor_grad whether it is None or not
                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    p2p.send_backward_recv_backward(
                        input_tensor_grad, recv_next=recv_next
                    )
                )

            if self._comm_overlap:
                assert len(self._comm_buffers) > 0
                for buffer in self._comm_buffers:
                    buffer.scale_and_split_grads()

            if self._enable_timer:
                self.timers("allreduce_shared_weight_gradients").start()
            self._layers.allreduce_shared_weight_gradients()
            if self._enable_timer:
                self.timers("allreduce_shared_weight_gradients").stop()

        if compute_loss:
            # return loss if compute loss
            if self._enable_timer:
                self.timers("broadcast_final_loss").start()
            with paddle.amp.auto_cast(enable=False):
                train_loss = self._broadcast_final_loss()
            if self._enable_timer:
                self.timers("broadcast_final_loss").stop()
        else:
            # else just return all intermediate output tensor for all micro steps
            train_loss = self.output_tensors

        self.timer_printer()
        return train_loss

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # interleave scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        return self.forward_backward_pipeline(data, None, forward_only=True)