#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import framework

from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
from ..utils import timer_helper as timer
from ..utils.hybrid_parallel_util import (
    broadcast_dp_parameters,
    broadcast_mp_parameters,
    broadcast_sharding_parameters,
)
from ..utils.log_util import logger
from .meta_parallel_base import MetaParallelBase
from .parallel_layers.pp_layers import PipelineLayer
from .pp_utils import p2p_communication as p2p
from .pp_utils.utils import FusedAllReduceBuffer, assign_group_by_size

__all__ = []


class PipelineParallel(MetaParallelBase):
    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer."
            )
        super().__init__(layers, hcg, strategy)
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
        self.use_sharding_parallel = (
            self._hcg.get_sharding_parallel_world_size() > 1
        )

        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size'
        ]
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps'
        ]
        # If the tensors sent from different hosts are not identical,
        # they should not be sent partially and then concatenated into a whole tensor.
        self._enable_partial_send_recv = self._strategy.pipeline_configs[
            'enable_partial_send_recv'
        ]
        self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.pp_group = self._hcg.get_pipe_parallel_group()
        self.dp_group = self._hcg.get_data_parallel_group()

        self._virtual_pp_world_size = None
        self._virtual_pp_rank = None
        self._real_pp_world_size = self.num_stages
        self._real_pp_rank = self.stage_id

        self._delay_scale_loss = self._strategy.hybrid_configs[
            "pp_configs"
        ].delay_scale_loss
        self._dp_comm_overlap = self._strategy.hybrid_configs[
            "pp_configs"
        ].dp_comm_overlap
        self._enable_timer = self._strategy.hybrid_configs[
            "pp_configs"
        ].enable_timer
        self._dp_comm_buffers = []

        if self._dp_comm_overlap:
            assert self.use_data_parallel and self.num_stages > 1

        if self._enable_timer:
            if not timer.is_timer_initialized():
                timer.set_timers()
            self.timers = timer.get_timers()

        p2p.initialize_p2p_groups(
            hcg,
            self._using_cache,
            self._enable_partial_send_recv,
            self._enable_timer,
        )

        self.global_rank = self._hcg.get_global_rank()
        self.micro_batch_id = 0

        self._compute_loss = True

        logger.info(
            "Pipeline Info -- num_stages: {}, stage_id: {}".format(
                self.num_stages, self.stage_id
            )
        )

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_sharding_parallel:
            logger.info("start broadcast sharding parameters")
            broadcast_sharding_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

        if self._dp_comm_overlap:
            self.register_allreduce_overlap_hook(
                self._layers, self.dp_group, self.accumulate_steps
            )

    def is_pipeline_first_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != 0:
                    return False
        assert self._real_pp_rank is not None
        return self._real_pp_rank == 0

    def is_pipeline_last_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != (self._virtual_pp_world_size - 1):
                    return False
        assert self._real_pp_rank is not None
        assert self._real_pp_world_size is not None
        return self._real_pp_rank == (self._real_pp_world_size - 1)

    def set_virtual_pipeline_rank(self, rank):
        self._virtual_pp_rank = rank

    def bw_hook_func(self, buffer, param):
        @paddle.autograd.no_grad()
        def fused_allreduce(*_):
            buffer.add_grad(param)

        return fused_allreduce

    def register_allreduce_overlap_hook(self, model, comm_group, acc_steps):
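        # Group trainable parameters into size-based buckets (assign_group_by_size)
        # and register a backward hook per parameter, so the data-parallel gradient
        # allreduce can overlap with the pipeline's backward computation.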
        parameter_list = [p for p in model.parameters() if not p.stop_gradient]
        if len(parameter_list) < 1:
            return

        var_groups = assign_group_by_size(parameter_list)
        for group_idx, parameters in var_groups.items():
            buffer = FusedAllReduceBuffer(
                group_idx, parameters, comm_group, acc_steps
            )
            self._dp_comm_buffers.append(buffer)
            for param in parameters:
                param._register_backward_hook(self.bw_hook_func(buffer, param))

    def timer_printer(self):
        if not self._enable_timer:
            return
        all_flag_names = self.timers.timers.keys()
        self.timers.log(all_flag_names)

    def forward_backward_pipeline(self, data, scaler=None):
        # use the 1f1b scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
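        #
        # Rough shape of the 1F1B schedule on each rank (illustrative numbers,
        # e.g. num_stages=4, stage_id=1, accumulate_steps=8):
        #   warmup   : num_stages - stage_id - 1 = 2 forward-only micro steps
        #   steady   : accumulate_steps - warmup = 6 alternating 1F1B steps
        #   cooldown : the remaining 2 backward-only steps drain the pipeline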

        self.scaler = scaler

        # store data for train
        self.data = data

        # store total loss of entire batch
        self.total_loss = None

        # store data id for micro_batch
        self.micro_batch_id = 0

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

        input_buffers = []
        output_buffers = []

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor)

            output_tensor_grad = p2p.send_forward_recv_backward(
                output_tensor, self.is_pipeline_last_stage()
            )

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            input_tensor, output_tensor = input_buffers.pop(
                0
            ), output_buffers.pop(0)

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )

            if last_iter:
                input_tensor = None
                p2p.send_backward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )
            else:
                input_tensor = p2p.send_backward_recv_forward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )

        for i in range(startup_steps):
            input_tensor = input_buffers.pop(0)
            output_tensor = output_buffers.pop(0)

            output_tensor_grad = p2p.recv_backward(
                self.is_pipeline_last_stage()
            )

            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )
            p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())

        if self._dp_comm_overlap:
            assert len(self._dp_comm_buffers) > 0
            for buffer in self._dp_comm_buffers:
                buffer.scale_and_split_grads()

        if self._enable_timer:
            self.timers("allreduce_shared_weight_gradients").start()
        self._layers.allreduce_shared_weight_gradients()
        if self._enable_timer:
            self.timers("allreduce_shared_weight_gradients").stop()
            self.timers("broadcast_final_loss").start()
        with paddle.amp.auto_cast(enable=False):
            train_loss = self._broadcast_final_loss()
        if self._enable_timer:
            self.timers("broadcast_final_loss").stop()
        self.timer_printer()
        return train_loss

    def _prepare_training(self, data, optimizer, lr_scheduler):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        assert isinstance(
            optimizer, HybridParallelOptimizer
        ), 'optimizer should be HybridParallelOptimizer subclass.'

        assert (
            framework._dygraph_tracer()._has_grad
        ), 'Please enable the generation of gradients.'

        if self.is_pipeline_first_stage(
            ignore_virtual=True
        ) or self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                data is not None
            ), "For the first and the last stage, the data must be set."
        else:
            data = None

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self._layers.train()

        return data

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # 1f1b scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        # save data for eval
        self.data = data
        # store data id for micro_batch
        self.micro_batch_id = 0

        # store total loss of entire batch
        self.total_loss = None

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

        input_buffers = []
        output_buffers = []

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not last_iter:
                input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        if self._compute_loss:
            self.train_loss = self._broadcast_final_loss()
        else:
            self.train_loss = output_buffers

        return self.train_loss

    def _forward_step(self, input_tensor, chunk_id=None):
        if self._enable_timer:
            self.timers("forward_step").start()
        if self.is_pipeline_first_stage():
            input_tensor = self._load_micro_batch(self.micro_batch_id)

        assert chunk_id is None or isinstance(chunk_id, int)

        output_tensor = self._layers.forward(input_tensor, chunk_id=chunk_id)

        if self.is_pipeline_last_stage():
            # calculate loss when training
            if self._compute_loss:
                assert (
                    self._layers._loss_fn is not None
                ), "loss function should exist to compute loss"
                labels = self._load_micro_batch(self.micro_batch_id)
                output_tensor = self._layers._loss_fn(output_tensor, labels)
                assert isinstance(
                    output_tensor, (paddle.Tensor, framework.core.eager.Tensor)
                ), "Currently, loss_fn should return a paddle.Tensor"

                with paddle.amp.auto_cast(enable=False):
                    if self.accumulate_steps > 1 and not self._delay_scale_loss:
                        output_tensor = output_tensor / self.accumulate_steps

                    if self.total_loss is None:
                        self.total_loss = paddle.zeros_like(output_tensor)
                    self.total_loss += output_tensor.detach()

        if self.is_pipeline_first_stage() or self.is_pipeline_last_stage():
            # Only increase micro batch id at virtual first/last pp stage.
            # The micro batch id is used to load data, therefore, only increase it when data is loaded.
            self.micro_batch_id += 1
        if self._enable_timer:
            self.timers("forward_step").stop()
        return output_tensor

    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
        if self._enable_timer:
            self.timers("backward_step").start()
        with paddle.amp.auto_cast(enable=False):
            if self.is_pipeline_last_stage():
                assert output_tensor_grad is None
                if self.scaler:
                    paddle.autograd.backward(self.scaler.scale(output_tensor))
                else:
                    paddle.autograd.backward(output_tensor)
            else:
                if isinstance(output_tensor, tuple):
                    outputs = [t for t in output_tensor if not t.stop_gradient]
                    assert len(outputs) == len(output_tensor_grad)
                    paddle.autograd.backward(
                        tensors=outputs,
                        grad_tensors=list(output_tensor_grad),
                    )
                else:
                    paddle.autograd.backward(
                        tensors=[output_tensor],
                        grad_tensors=[output_tensor_grad],
                    )

            input_tensor_grad = None
            if input_tensor is not None:
                if isinstance(input_tensor, tuple):
                    input_tensor_grad = tuple(
                        [t.grad for t in input_tensor if not t.stop_gradient]
                    )
                else:
                    input_tensor_grad = input_tensor.grad
            if self._enable_timer:
                self.timers("backward_step").stop()
            return input_tensor_grad

    def _check_data_valid(self, data):
        batch_size = data.shape[0]
        assert self.micro_batch_size * self.accumulate_steps == batch_size, (
            "batch_size needs to equal micro_batch_size * accumulate_steps. Currently, "
            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
            % (batch_size, self.micro_batch_size, self.accumulate_steps)
        )

    def _load_micro_batch_impl(self, inputs, cache_id):
        begin = cache_id * self.micro_batch_size
        end = begin + self.micro_batch_size

        if isinstance(inputs, tuple):
            output = []
            for data in inputs:
                if isinstance(data, list):
                    assert (
                        len(data) == self.accumulate_steps
                    ), "length of data should be %d, but it is %d" % (
                        self.accumulate_steps,
                        len(data),
                    )
                    output.append(data[cache_id].detach())
                else:
                    self._check_data_valid(data)
                    output.append(data[begin:end, :].detach())
            return tuple(output)

        elif isinstance(inputs, list):
            assert (
                len(inputs) == self.accumulate_steps
            ), "length of data should be %d, but it is %d" % (
                self.accumulate_steps,
                len(inputs),
            )
            return inputs[cache_id].detach()
        else:
            self._check_data_valid(inputs)
            return inputs[begin:end, :].detach()

    def _load_micro_batch(self, cache_id):
        inputs = self.data
        if self.is_pipeline_first_stage():
            assert len(inputs) == 2, "length of input should be 2"
            return self._load_micro_batch_impl(inputs[0], cache_id)
        elif self.is_pipeline_last_stage():
            assert len(inputs) == 2, "length of input should be 2"
            return self._load_micro_batch_impl(inputs[1], cache_id)
        else:
            inputs = None

    def _broadcast_final_loss(self):
        # Since the last backward run in interleave will set the virtual rank to 0,
        # here we need to check last stage ignoring virtual stage.
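        # Protocol: the last stage first broadcasts a flag indicating the loss
        # dtype (fp32 or fp16), then the loss itself; other stages allocate a
        # matching buffer and receive both over the pipeline group.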
        if self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                self.total_loss is not None
            ), "train_batch() in last stage should obtain valid loss"
            loss = (
                self.total_loss.detach()
                if not self._delay_scale_loss
                else self.total_loss / self.accumulate_steps
            )
            is_fp32 = (
                paddle.full([], 1, 'int64')
                if loss.dtype == paddle.float32
                else paddle.full([], 0, 'int64')
            )
            paddle.distributed.broadcast(
                is_fp32, src=self.global_rank, sync_op=True, group=self.pp_group
            )
            paddle.distributed.broadcast(
                loss, src=self.global_rank, sync_op=True, group=self.pp_group
            )
        else:
            is_fp32 = paddle.full([], 1, 'int64')
            paddle.distributed.broadcast(
                is_fp32,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
            loss = (
                paddle.zeros(shape=[1], dtype="float32")
                if is_fp32.item()
                else paddle.zeros(shape=[1], dtype="float16")
            )
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
        return loss

    def _optimizer_step(self):
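        # With delayed loss scaling, gradients still hold the sum over all micro
        # batches, so scale them by 1 / accumulate_steps before the optimizer step.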
        if self._delay_scale_loss:
            for p in self._layers.parameters():
                if hasattr(p, "main_grad") and p.main_grad is not None:
                    assert p.grad is None
                    p.main_grad = p.main_grad.scale(1.0 / self.accumulate_steps)
                elif p.grad is not None:
                    p.grad = p.grad.scale(1.0 / self.accumulate_steps)

        if self.scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()


class PipelineParallelWithInterleave(PipelineParallel):
    # pipeline parallel with interleave scheduler
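    # Each rank hosts num_model_chunks virtual stages (model chunks); the
    # scheduler alternates between them so more micro batches are in flight,
    # following Megatron-LM's interleaved 1F1B schedule.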

    def __init__(self, layers, hcg, strategy):
        super().__init__(layers=layers, hcg=hcg, strategy=strategy)
        assert layers.get_num_virtual_stages() > 1
        assert (
            framework.in_dygraph_mode()
        ), "virtual pipeline stage with interleave only supports eager dygraph mode"
        # setup for interleave scheduler
        self.num_model_chunks = layers.get_num_virtual_stages()
        self.model_chunks = layers.get_model_chunks()
        assert self.model_chunks is not None
        assert len(self.model_chunks) == self.num_model_chunks
        self._virtual_pp_world_size = self.num_model_chunks
        self._virtual_pp_rank = 0

    def _get_virtual_pp_rank(self, micro_step, forward):
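        # Map a global micro step to the model chunk it belongs to on this rank.
        # Illustrative example: with num_stages=2 and num_model_chunks=2, forward
        # micro steps 0-1 run chunk 0 and steps 2-3 run chunk 1; the backward
        # direction walks the chunks in reverse order.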
        virtual_pp_stage = micro_step % (
            self.num_stages * self.num_model_chunks
        )
        virtual_pp_stage = virtual_pp_stage // self.num_stages
        if not forward:
            virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
        return virtual_pp_stage

    def _forward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        if not self._forward_only:
            assert hasattr(self, 'output_tensor_grads')

        if self.is_pipeline_first_stage():
            if len(self.input_tensors[virtual_pp_rank]) == len(
                self.output_tensors[virtual_pp_rank]
            ):
                self.input_tensors[virtual_pp_rank].append(None)
        input_tensor = self.input_tensors[virtual_pp_rank][-1]
        output_tensor = self._forward_step(input_tensor, virtual_pp_rank)
        self.output_tensors[virtual_pp_rank].append(output_tensor)

        if self._forward_only:
            # no need to store tensor for backward
            self.input_tensors[virtual_pp_rank].pop()
            self.output_tensors[virtual_pp_rank].pop()

        return output_tensor

    def _backward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        assert hasattr(self, 'output_tensor_grads')

        if self.is_pipeline_last_stage():
            if len(self.output_tensor_grads[virtual_pp_rank]) == 0:
                self.output_tensor_grads[virtual_pp_rank].append(None)

        input_tensor = self.input_tensors[virtual_pp_rank].pop(0)
        output_tensor = self.output_tensors[virtual_pp_rank].pop(0)
        output_tensor_grad = self.output_tensor_grads[virtual_pp_rank].pop(0)
        input_tensor_grad = self._backward_step(
            input_tensor, output_tensor, output_tensor_grad
        )

        return input_tensor_grad

    def forward_backward_pipeline(
        self, data, scaler, forward_only=False, compute_loss=True
    ):
        # use interleave scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
        if not compute_loss:
            assert (
                forward_only
            ), "compute_loss can only be set to False when forward_only is set to True"

        # init some attributes for this batch run
        self.scaler = scaler
        self.data = data
        self.total_loss = None
        self.micro_batch_id = 0
        self._forward_only = forward_only

        # init some data buffers for interleave scheduler
        self.input_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]

        num_steps = self.accumulate_steps * self.num_model_chunks
        all_startup_steps = False
        if forward_only:
            # If only forward, since there is no backward during running, all steps are startup steps
            startup_steps = num_steps
        else:
            if self.accumulate_steps == self.num_stages:
                startup_steps = num_steps
                all_startup_steps = True
            else:
                startup_steps = (self.num_stages - self.stage_id - 1) * 2
                startup_steps += (self.num_model_chunks - 1) * self.num_stages
                startup_steps = min(startup_steps, num_steps)

        steady_steps = num_steps - startup_steps
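        # Illustrative example (not from the config): with num_stages=4,
        # num_model_chunks=2, stage_id=0 and accumulate_steps=8, num_steps is 16,
        # startup_steps is min(6 + 4, 16) = 10 and steady_steps is 6.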

        self.set_virtual_pipeline_rank(0)
        self.input_tensors[0].append(
            p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False)
        )

        # run startup steps
        for micro_step in range(startup_steps):
            output_tensor = self._forward_step_helper(micro_step)

            # determine whether recv forward tensor or not
            next_virtual_pp_rank = self._get_virtual_pp_rank(
                micro_step + 1, forward=True
            )
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                if next_virtual_pp_rank == 0:
                    # next chunk is the first chunk, no need to pre-recv an input tensor
                    recv_prev = False
            # last micro step, no next run
            if micro_step == (num_steps - 1):
                recv_prev = False

            # last stage shouldn't send tensor to downstream
            if self.is_pipeline_last_stage():
                output_tensor = None

            if (
                micro_step == (startup_steps - 1)
                and not forward_only
                and not all_startup_steps
            ):
                input_tensor_grad = None
                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False

                # the last startup step needs four-direction comm to set up for steady 1f1b
                (
                    input_tensor,
                    output_tensor_grad,
                ) = p2p.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next,
                )
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    output_tensor_grad
                )
            else:
                input_tensor = p2p.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev
                )
            self.input_tensors[next_virtual_pp_rank].append(input_tensor)

        # run 1f1b steady steps
        for micro_step in range(steady_steps):
            # forward
            forward_micro_step_id = micro_step + startup_steps
            output_tensor = self._forward_step_helper(forward_micro_step_id)

            # backward
            backward_micro_step_id = micro_step
            input_tensor_grad = self._backward_step_helper(
                backward_micro_step_id
            )

            # four directions comm
            # send output tensor to downstream
            # send input tensor grad to upstream
            # recv input tensor from upstream
            # recv output tensor grad from downstream

            # last stage doesn't send result to downstream
            forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id, forward=True
            )
            self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
            if self.is_pipeline_last_stage():
                output_tensor = None

            # first stage doesn't send grad to upstream
            backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id, forward=False
            )
            self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
            if self.is_pipeline_first_stage():
                input_tensor_grad = None

            # determine whether to recv input tensor from upstream
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id - (self.num_stages - 1), forward=True
                )
                if next_forward_virtual_pp_rank == (self.num_model_chunks - 1):
                    # first pp stage and first virtual stage
                    recv_prev = False
                next_forward_virtual_pp_rank += 1
            else:
                next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id + 1, forward=True
                )

            # last iteration doesn't need recv from upstream
            if micro_step == (steady_steps - 1):
                recv_prev = False

            # determine whether to recv grad from downstream
            recv_next = True
            if self.is_pipeline_last_stage(ignore_virtual=True):
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id - (self.num_stages - 1),
                    forward=False,
                )
                if next_backward_virtual_pp_rank == 0:
                    # last pp stage and last virtual stage
                    recv_next = False
                next_backward_virtual_pp_rank -= 1
            else:
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id + 1, forward=False
                )

            (
                input_tensor,
                output_tensor_grad,
            ) = p2p.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
            )

            if recv_prev:
                self.input_tensors[next_forward_virtual_pp_rank].append(
                    input_tensor
                )
            if recv_next:
                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    output_tensor_grad
                )

        # remaining backward steps
        if not forward_only:
            if all_startup_steps:
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    p2p.recv_backward(
                        self.is_pipeline_last_stage(), sync_recv=False
                    )
                )

            for micro_step in range(steady_steps, num_steps):
                # cooldown loop
                input_tensor_grad = self._backward_step_helper(micro_step)
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    micro_step + 1, forward=False
                )

                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    if next_backward_virtual_pp_rank == (
                        self.num_model_chunks - 1
                    ):
                        recv_next = False

                if micro_step == (num_steps - 1):
                    recv_next = False

                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    p2p.send_backward_recv_backward(
                        input_tensor_grad, recv_next=recv_next
                    )
                )

            if self._dp_comm_overlap:
                assert len(self._dp_comm_buffers) > 0
                for buffer in self._dp_comm_buffers:
                    buffer.scale_and_split_grads()

            if self._enable_timer:
                self.timers("allreduce_shared_weight_gradients").start()
            self._layers.allreduce_shared_weight_gradients()
            if self._enable_timer:
                self.timers("allreduce_shared_weight_gradients").stop()

        if compute_loss:
            # return loss if compute loss
            if self._enable_timer:
                self.timers("broadcast_final_loss").start()
            with paddle.amp.auto_cast(enable=False):
                train_loss = self._broadcast_final_loss()
            if self._enable_timer:
                self.timers("broadcast_final_loss").stop()
        else:
            # else just return all intermediate output tensor for all micro steps
            train_loss = self.output_tensors

        self.timer_printer()
        return train_loss

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # interleave scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        return self.forward_backward_pipeline(data, None, forward_only=True)