#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import warnings

import paddle
from paddle import framework

from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
from ..utils import timer_helper as timer
from ..utils.hybrid_parallel_util import (
    broadcast_dp_parameters,
    broadcast_mp_parameters,
    broadcast_sharding_parameters,
)
from ..utils.log_util import logger
from .meta_parallel_base import MetaParallelBase
from .parallel_layers.pp_layers import PipelineLayer

_use_four_directions = os.environ.get(
    'PADDLE_USE_FOUR_DIRECTIONS_P2P', paddle.fluid.core.is_compiled_with_xpu()
)
if _use_four_directions:
    from .pp_utils import four_directions_p2p_communication as p2p
else:
    from .pp_utils import p2p_communication as p2p

from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size

__all__ = []


# assume only the first stage and the last stage need data, and data consumption is ordered
# to be replaced by real micro dataset from reader
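# A minimal worked example of the slicing below (illustrative numbers only):
# with micro_batch_size=2 and acc_steps=4, a global batch tensor of shape
# [8, ...] yields micro batch i as rows [2*i : 2*i + 2], while a list input
# must instead carry exactly acc_steps pre-split micro batches, one per step.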
class FakeMicroDataset:
    def __init__(
        self, data, is_first_stage, is_last_stage, acc_steps, micro_batch_size
    ):
        self._data = data
        self._index = 0
        self._acc_steps = acc_steps
        self._is_first_stage = is_first_stage
        self._is_last_stage = is_last_stage
        self._micro_batch_size = micro_batch_size

    def __iter__(self):
        return self

    def __next__(self):
        if self._index >= self._acc_steps:
            raise StopIteration
        assert self._is_first_stage or self._is_last_stage
        micro_batch_data = self._load_micro_batch(self._index)
        self._index += 1
        return micro_batch_data

    def _load_micro_batch(self, micro_step):
        inputs = self._data

        if self._is_first_stage or self._is_last_stage:
            assert len(inputs) == 2, "length of input should be 2"
            data = self._load_micro_batch_impl(inputs[0], micro_step)
            label = self._load_micro_batch_impl(inputs[1], micro_step)
            return (data, label)
        else:
            return (None, None)

    def _load_micro_batch_impl(self, inputs, micro_step):
        begin = micro_step * self._micro_batch_size
        end = begin + self._micro_batch_size

        if isinstance(inputs, tuple):
            output = []
            for data in inputs:
                if isinstance(data, list):
                    assert (
                        len(data) == self._acc_steps
                    ), "length of data should be %d, but it is %d" % (
                        self._acc_steps,
                        len(data),
                    )
                    output.append(data[micro_step].detach())
                elif data is not None:
                    self._check_data_valid(data)
                    output.append(data[begin:end, :].detach())
                else:
                    output.append(None)
            return tuple(output)

        elif isinstance(inputs, list):
            assert (
                len(inputs) == self._acc_steps
            ), "length of data should be %d, but it is %d" % (
                self._acc_steps,
                len(inputs),
            )
            return inputs[micro_step].detach()
        elif inputs is not None:
            self._check_data_valid(inputs)
            return inputs[begin:end, :].detach()
        else:
            return None

    def _check_data_valid(self, data):
        batch_size = data.shape[0]
        assert self._micro_batch_size * self._acc_steps == batch_size, (
            "batch_size should equal micro_batch_size * accumulate_steps. Currently, "
            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
            % (batch_size, self._micro_batch_size, self._acc_steps)
        )


class PipelineParallel(MetaParallelBase):
    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer."
            )
        super().__init__(layers, hcg, strategy)
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
        self.use_sharding_parallel = (
            self._hcg.get_sharding_parallel_world_size() > 1
        )

        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size'
        ]
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps'
        ]
        # If the tensors sent from different hosts are not the same,
        # they shouldn't be sent partially and then concatenated into a whole tensor.
        self._enable_partial_send_recv = self._strategy.pipeline_configs[
            'enable_partial_send_recv'
        ]
        self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.global_rank = self._hcg.get_global_rank()
        self.pp_group = self._hcg.get_pipe_parallel_group()
        self.dp_group = self._hcg.get_data_parallel_group()
        self.sharding_group = self._hcg.get_sharding_parallel_group()

        self._virtual_pp_world_size = None
        self._virtual_pp_rank = None
        self._real_pp_world_size = self.num_stages
        self._real_pp_rank = self.stage_id

        self._delay_scale_loss = self._strategy.hybrid_configs[
            "pp_configs"
        ].delay_scale_loss
        # TODO(PP Dev): support dp_comm_overlap without use_main_grad training.
        # This combination will trigger an inplace check error during `reshape_` in the function `_split_tensors`.
        self._dp_comm_overlap = self._strategy.hybrid_configs[
            "pp_configs"
        ].dp_comm_overlap
        self._sharding_comm_overlap = self._strategy.hybrid_configs[
            "pp_configs"
        ].sharding_comm_overlap
        self._enable_timer = self._strategy.hybrid_configs[
            "pp_configs"
        ].enable_timer

        self._profiling = self._strategy.hybrid_configs["pp_configs"].profiling
        self._records = []
        self._record_format = (
            '"name": "{}{}", "cat": "pipeline timeline", "ph": {}, "pid": 0, "tid": '
            + str(self.stage_id + 1)
            + ', "ts": {}, "cname": "{}"'
        )
        self._forward_color = "thread_state_running"  # RGB: 126, 200, 148
        self._backward_color = "rail_idle"  # RGB: 238, 142, 0
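        # Each record becomes one Chrome-trace event when flushed, e.g. for
        # stage 0 (timestamp illustrative):
        # {"name": "F0", "cat": "pipeline timeline", "ph": "B", "pid": 0,
        #  "tid": 1, "ts": 1700000000000, "cname": "thread_state_running"}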
        if self._profiling:
            logger.info(
                "If pp profiling is enabled, the max training steps should be restricted "
                "to a reasonable value (such as 5) to avoid generating large profile files. "
                "The profiler will generate a profile file 'profile_record_tmp_file_for_rank_*' "
                "for each rank. Users should gather all profile files of one entire pipeline "
                "onto one node (rank 0 is recommended) to get the full view of the pipeline profile. "
                "[DO NOT CHANGE THE NAME OF THE PROFILE FILES!] "
                "Then get the profile parser from this url: "
                "https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/distributed/fleet/meta_parallel/pp_utils/profiler_helper.py "
                "and save the script to the same directory as all profile files. "
                "Parse those files with this command: `python profiler_helper.py`. "
                "After parsing, a new file 'pipeline_profile.json' will be generated. "
                "Users can inspect this file with the chrome://tracing website."
            )

        if self._dp_comm_overlap:
            assert self.use_data_parallel and self.num_stages > 1

        if self._sharding_comm_overlap:
            assert self.use_sharding_parallel and self.num_stages > 1

        assert not (
            self._dp_comm_overlap and self._sharding_comm_overlap
        ), "Cannot use dp pp overlap and sharding pp overlap at the same time."

        self._comm_buffers = []
        self._comm_overlap = (
            self._dp_comm_overlap or self._sharding_comm_overlap
        )

        if self._enable_timer:
            if not timer.is_timer_initialized():
                timer.set_timers()
            self.timers = timer.get_timers()

        p2p.initialize_p2p_groups(
            hcg,
            self._using_cache,
            self._enable_partial_send_recv,
            self._enable_timer,
        )

        self.global_rank = self._hcg.get_global_rank()
        self.micro_batch_id = 0

        self._compute_loss = True

        logger.info(
            "Pipeline Info -- num_stages: {}, stage_id: {}".format(
                self.num_stages, self.stage_id
            )
        )

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_sharding_parallel:
            logger.info("start broadcast sharding parameters")
            broadcast_sharding_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

        if self._dp_comm_overlap:
            self.register_allreduce_overlap_hook(
                self._layers, self.dp_group, self.accumulate_steps, True
            )

    def is_pipeline_first_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != 0:
                    return False
        assert self._real_pp_rank is not None
        return self._real_pp_rank == 0

    def is_pipeline_last_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != (self._virtual_pp_world_size - 1):
                    return False
        assert self._real_pp_rank is not None
        assert self._real_pp_world_size is not None
        return self._real_pp_rank == (self._real_pp_world_size - 1)

    def set_virtual_pipeline_rank(self, rank):
        self._virtual_pp_rank = rank

    def bw_hook_func(self, buffer, param):
        @paddle.autograd.no_grad()
        def fused_allreduce(*_):
            buffer.add_grad(param)

        return fused_allreduce

    def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
        if model.get_num_virtual_stages() > 1:
            models = model.get_model_chunks()
        else:
            models = [model]

        if not dp:
            assert hasattr(self, "optimizer")
            assert hasattr(self.optimizer, "_param2rank")
            _param2rank = self.optimizer._param2rank

        act = HOOK_ACTION.ALL_REDUCE if dp else HOOK_ACTION.REDUCE

        fused_parameter_group = {}

        for model in models:
            # For virtual pipeline: separate the parameters of different chunks into
            # different groups to get the best performance.

            parameter_list = [
                p for p in model.parameters() if not p.stop_gradient
            ]
            if len(parameter_list) < 1:
                return

            if dp:
                fused_parameter_group[-1] = parameter_list
            else:
                # Group parameters by dst rank for sharding, since their gradients are reduced to different ranks
                for p in parameter_list:
                    assert p.name in _param2rank
                    dst_rank = _param2rank[p.name]
                    if dst_rank in fused_parameter_group:
                        fused_parameter_group[dst_rank].append(p)
                    else:
                        fused_parameter_group[dst_rank] = [p]

            for dst in fused_parameter_group:
                parameter_list = fused_parameter_group[dst]
                if not dp:
                    # parse the relative dst rank to absolute dst rank for sharding
                    dst = comm_group.ranks[dst]
                var_groups = assign_group_by_size(parameter_list)
                for group_idx, parameters in var_groups.items():
                    buffer = FusedCommBuffer(
                        group_idx, parameters, comm_group, acc_steps, act, dst
                    )
                    self._comm_buffers.append(buffer)
                    for param in parameters:
                        param._register_backward_hook(
                            self.bw_hook_func(buffer, param)
                        )

    def timer_printer(self):
        if not self._enable_timer:
            return
        all_flag_names = self.timers.timers.keys()
        self.timers.log(all_flag_names)

    def _record_stamp(self, name, step, phase, color):
        if self._profiling:
            paddle.device.synchronize()
            self._records.append(
                '{'
                + self._record_format.format(
                    name,
                    step,
                    phase,
                    int(time.time() * 1000),
                    color,
                )
                + '}'
            )

    def _flush_records(self):
        if self._profiling:
            with open(
                f'./profile_record_tmp_file_for_rank_{self.global_rank}',
                'a+',
            ) as f:
                for record in self._records:
                    f.write(record + '\n')
            self._records = []

    def forward_backward_pipeline(
        self, data, scaler=None, static_scheduler=False
    ):
        # use the 1f1b scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py

        if static_scheduler:
            assert (
                not self._profiling
            ), "The static scheduler is not available while profiling"
            if data is not None:
                warnings.warn(
                    "A static scheduler run won't really run the model, but data has been provided"
                )
            logger.info(
                "enabling static_scheduler will return the pp schedule instead of the loss"
            )
            schedule = ""

        self.scaler = scaler

        # store total loss of entire batch
        self.total_loss = None

        # store data id for micro_batch
        self.micro_batch_id = 0

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps
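        # e.g. with num_stages=4 and accumulate_steps=8: stage 0 runs 3 warmup
        # forwards, 5 steady 1F1B iterations and 3 cooldown backwards, while
        # the last stage runs all 8 micro batches directly in the steady loop.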

        input_buffers = []
        output_buffers = []

        micro_dataset = self._wrap_data(data)

        for step_id in range(startup_steps):
            if static_scheduler:
                schedule += f"f{step_id};"
                logger.info(f"forward step for micro step {step_id}")
                continue
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            self._record_stamp("F", step_id, '"B"', self._forward_color)
            output_tensor = self._forward_step(input_tensor, micro_dataset)
            self._record_stamp("F", step_id, '"E"', self._forward_color)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not self.is_pipeline_last_stage():
                self._release_output(output_tensor)

        if steady_steps > 0 and not static_scheduler:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            if static_scheduler:
                schedule += f"f{startup_steps + i};"
                schedule += f"b{i};"
                logger.info(f"forward step for micro step {startup_steps + i}")
                logger.info(f"backward step for micro step {i}")
                continue
            last_iter = i == (steady_steps - 1)

            self._record_stamp(
                "F", startup_steps + i, '"B"', self._forward_color
            )
            output_tensor = self._forward_step(input_tensor, micro_dataset)
            self._record_stamp(
                "F", startup_steps + i, '"E"', self._forward_color
            )

            output_tensor_grad = p2p.send_forward_recv_backward(
                output_tensor, self.is_pipeline_last_stage()
            )

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not self.is_pipeline_last_stage():
                self._release_output(output_tensor)

            input_tensor, output_tensor = input_buffers.pop(
                0
            ), output_buffers.pop(0)

            self._record_stamp("B", i, '"B"', self._backward_color)
            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )
            self._record_stamp("B", i, '"E"', self._backward_color)

            if last_iter:
                input_tensor = None
                p2p.send_backward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )
            else:
                input_tensor = p2p.send_backward_recv_forward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )

        for i in range(startup_steps):
            if static_scheduler:
                schedule += f"b{steady_steps + i};"
                logger.info(f"backward step for micro step {steady_steps + i}")
                continue
            input_tensor = input_buffers.pop(0)
            output_tensor = output_buffers.pop(0)

            output_tensor_grad = p2p.recv_backward(
                self.is_pipeline_last_stage()
            )

            self._record_stamp(
                "B", steady_steps + i, '"B"', self._backward_color
            )
            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )
            self._record_stamp(
                "B", steady_steps + i, '"E"', self._backward_color
            )
            p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())

        if static_scheduler:
            return schedule

        self._flush_records()

        if self._comm_overlap:
            assert len(self._comm_buffers) > 0
            for buffer in self._comm_buffers:
                buffer.scale_and_split_grads()

        if self._enable_timer:
            self.timers("allreduce_shared_weight_gradients").start()
        self._layers.allreduce_shared_weight_gradients()
        if self._enable_timer:
            self.timers("allreduce_shared_weight_gradients").stop()
            self.timers("broadcast_final_loss").start()
        with paddle.amp.auto_cast(enable=False):
            train_loss = self._broadcast_final_loss()
        if self._enable_timer:
            self.timers("broadcast_final_loss").stop()
        self.timer_printer()
        return train_loss

    def _prepare_training(self, data, optimizer, lr_scheduler):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        assert isinstance(
            optimizer, HybridParallelOptimizer
        ), 'optimizer should be HybridParallelOptimizer subclass.'

        assert (
            framework._dygraph_tracer()._has_grad
        ), 'Please enable the generation of gradients.'

        if self.is_pipeline_first_stage(
            ignore_virtual=True
        ) or self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                data is not None
            ), "For the first and the last stage, the data must be set."
        else:
            data = None

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self._layers.train()

        if self._sharding_comm_overlap and len(self._comm_buffers) == 0:
            self.register_allreduce_overlap_hook(
                self._layers, self.sharding_group, self.accumulate_steps, False
            )

        return data

    def _wrap_data(self, data):
        """
        For backward compatibility, wrap data into a FakeMicroDataset if it is of type list or tuple.
        """
        if (not isinstance(data, tuple)) and (not isinstance(data, list)):
            return data

        micro_dataset = FakeMicroDataset(
            data,
            self.is_pipeline_first_stage(ignore_virtual=True),
            self.is_pipeline_last_stage(ignore_virtual=True),
            self.accumulate_steps,
            self.micro_batch_size,
        )
        return micro_dataset

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # 1f1b scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        # store data id for micro_batch
        self.micro_batch_id = 0

        # store total loss of entire batch
        self.total_loss = None

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

        input_buffers = []
        output_buffers = []

        # convert to micro dataset
        micro_dataset = self._wrap_data(data)

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor, micro_dataset)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor, micro_dataset)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not last_iter:
                input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        if self._compute_loss:
            self.train_loss = self._broadcast_final_loss()
        else:
            self.train_loss = output_buffers

        return self.train_loss

    def _forward_step(self, input_tensor, micro_dataset, chunk_id=None):
        if self._enable_timer:
            self.timers("forward_step").start()
        if self.is_pipeline_first_stage():
            input_tensor = next(micro_dataset)[0]
            self._check_micro_batch_data_valid(input_tensor)

        assert chunk_id is None or isinstance(chunk_id, int)

        output_tensor = self._layers.forward(input_tensor, chunk_id=chunk_id)

        if self.is_pipeline_last_stage():
            # calculate loss during training
            if self._compute_loss:
                assert (
                    self._layers._loss_fn is not None
                ), "loss function should exist to compute loss"
                labels = next(micro_dataset)[1]
                self._check_micro_batch_data_valid(labels)
                output_tensor = self._layers._loss_fn(output_tensor, labels)
                assert isinstance(
                    output_tensor, (paddle.Tensor, framework.core.eager.Tensor)
                ), "Currently, loss_fn should return a paddle.Tensor"

                with paddle.amp.auto_cast(enable=False):
                    if self.accumulate_steps > 1 and not self._delay_scale_loss:
                        output_tensor = output_tensor / self.accumulate_steps

                    if self.total_loss is None:
                        self.total_loss = paddle.zeros_like(output_tensor)
                    self.total_loss += output_tensor.detach()

        if self.is_pipeline_first_stage() or self.is_pipeline_last_stage():
            # Only increase micro batch id at virtual first/last pp stage.
            # The micro batch id is used to load data, therefore, only increase it when loading data.
            self.micro_batch_id += 1
        if self._enable_timer:
            self.timers("forward_step").stop()
        return output_tensor

    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
        if self._enable_timer:
            self.timers("backward_step").start()
        with paddle.amp.auto_cast(enable=False):
            if self.is_pipeline_last_stage():
                assert output_tensor_grad is None
                if self.scaler:
                    paddle.autograd.backward(self.scaler.scale(output_tensor))
                else:
                    paddle.autograd.backward(output_tensor)
            else:
                if isinstance(output_tensor, tuple):
                    outputs = [t for t in output_tensor if not t.stop_gradient]
                    assert len(outputs) == len(output_tensor_grad)
                    paddle.autograd.backward(
                        tensors=outputs,
                        grad_tensors=list(output_tensor_grad),
                    )
                else:
                    paddle.autograd.backward(
                        tensors=[output_tensor],
                        grad_tensors=[output_tensor_grad],
                    )

            input_tensor_grad = None
            if input_tensor is not None:
                if isinstance(input_tensor, tuple):
                    input_tensor_grad = tuple(
                        [t.grad for t in input_tensor if not t.stop_gradient]
                    )
                else:
                    input_tensor_grad = input_tensor.grad
            if self._enable_timer:
                self.timers("backward_step").stop()
            return input_tensor_grad

    def _check_micro_batch_data_valid(self, micro_batch_data):
        if isinstance(micro_batch_data, (tuple, list)):
            for data in micro_batch_data:
                self._check_micro_batch_data_valid(data)
        elif micro_batch_data is not None:
            assert isinstance(micro_batch_data, paddle.Tensor)

    def _broadcast_final_loss(self):
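        # The last pipeline stage owns the loss: it first broadcasts an int64
        # flag telling the other stages whether the loss is fp32 or fp16, then
        # broadcasts the loss itself, so every rank returns the same value.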
        # Since the last backward run in interleave will set the virtual rank to 0,
        # here we need to check last stage ignoring virtual stage.
        if self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                self.total_loss is not None
            ), "train_batch() in last stage should obtain valid loss"
            loss = (
                self.total_loss.detach()
                if not self._delay_scale_loss
                else self.total_loss / self.accumulate_steps
            )
            is_fp32 = (
                paddle.full([], 1, 'int64')
                if loss.dtype == paddle.float32
                else paddle.full([], 0, 'int64')
            )
            paddle.distributed.broadcast(
                is_fp32, src=self.global_rank, sync_op=True, group=self.pp_group
            )
            paddle.distributed.broadcast(
                loss, src=self.global_rank, sync_op=True, group=self.pp_group
            )
        else:
            is_fp32 = paddle.full([], 1, 'int64')
            paddle.distributed.broadcast(
                is_fp32,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
            loss = (
                paddle.zeros(shape=[1], dtype="float32")
                if is_fp32.item()
                else paddle.zeros(shape=[1], dtype="float16")
            )
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
        return loss

    def _optimizer_step(self):
        if self._delay_scale_loss:
            for p in self._layers.parameters():
                if hasattr(p, "main_grad") and p.main_grad is not None:
                    assert p.grad is None
                    p.main_grad = p.main_grad.scale(1.0 / self.accumulate_steps)
                elif p.grad is not None:
                    p.grad = p.grad.scale(1.0 / self.accumulate_steps)

        if self.scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()

    def _release_output(self, output):
        def can_free(t):
            return (
                t is not None
                and isinstance(t, paddle.Tensor)
                and t._is_initialized()
                and t.inplace_version == 0
            )

        if isinstance(output, (tuple, list)):
            for t in output:
                if can_free(t):
                    t._clear_dataptr()

        elif can_free(output):
            output._clear_dataptr()

    def get_static_scheduler(self):
        return self.forward_backward_pipeline(data=None, static_scheduler=True)


class PipelineParallelWithInterleave(PipelineParallel):
    # pipeline parallel with interleave scheduler

    def __init__(self, layers, hcg, strategy):
        super().__init__(layers=layers, hcg=hcg, strategy=strategy)
        self._record_format = (
            '"name": "{}{}_VP{}", "cat": "virtual pipeline timeline", "ph": {}, "pid": 0, "tid": '
            + str(self.stage_id + 1)
            + ', "ts": {}, "cname": "{}"'
        )
        self._forward_colors = [
            "thread_state_running",  # RGB: 126, 200, 148
            "thread_state_unknown",  # RGB: 199, 155, 125
        ]
        self._backward_colors = [
            "rail_load",  # RGB: 13, 168, 97
            "rail_idle",  # RGB: 238, 142, 0
        ]
        # Structures to record the micro step for each layer chunk
        self._forward_micro_step_counter = {}
        self._backward_micro_step_counter = {}

        assert layers.get_num_virtual_stages() > 1
        assert (
            self.num_stages > 2
        ), "virtual pipeline must run under pp degree > 2"
        assert (
            framework.in_dynamic_mode()
        ), "virtual pipeline stage with interleave only supports eager dygraph mode"
        assert (
            self.accumulate_steps % self.num_stages == 0
        ), "accumulate_steps should be evenly divisible by num_stages for pipeline with interleave"
        # setup for interleave scheduler
        self.num_model_chunks = layers.get_num_virtual_stages()
        self.model_chunks = layers.get_model_chunks()
        assert self.model_chunks is not None
        assert len(self.model_chunks) == self.num_model_chunks
        self._virtual_pp_world_size = self.num_model_chunks
        self._virtual_pp_rank = 0
        self._reset_counter()

    def _reset_counter(self):
        for i in range(self.num_model_chunks):
            self._forward_micro_step_counter[i] = 0
            self._backward_micro_step_counter[i] = 0

    def _record_stamp(self, name, step, phase, forward=True):
        if self._profiling:
            paddle.device.synchronize()
            virtual_pp_rank = self._get_virtual_pp_rank(step, forward=forward)
            color_idx = virtual_pp_rank % 2
            # Get the profile color and micro step for current layer chunk
            if forward:
                color = self._forward_colors[color_idx]
                micro_step = self._forward_micro_step_counter[virtual_pp_rank]
                if phase == '"E"':
                    self._forward_micro_step_counter[virtual_pp_rank] += 1
            else:
                color = self._backward_colors[color_idx]
                micro_step = self._backward_micro_step_counter[virtual_pp_rank]
                if phase == '"E"':
                    self._backward_micro_step_counter[virtual_pp_rank] += 1
            self._records.append(
                '{'
                + self._record_format.format(
                    name,
                    micro_step,
                    virtual_pp_rank,
                    phase,
                    int(time.time() * 1000),
                    color,
                )
                + '}'
            )

    def _flush_records(self):
        if self._profiling:
            with open(
                f'./profile_record_tmp_file_for_rank_{self.global_rank}',
                'a+',
            ) as f:
                for record in self._records:
                    f.write(record + '\n')
            self._records = []
            self._reset_counter()

    def _get_virtual_pp_rank(self, micro_step, forward):
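        # Map a micro step to the model chunk (virtual pp rank) it runs on,
        # e.g. with num_stages=4 and num_model_chunks=2, forward micro steps
        # 0-3 run chunk 0, steps 4-7 run chunk 1, then the pattern repeats;
        # the backward pass visits the chunks in the reverse order.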
        virtual_pp_stage = micro_step % (
            self.num_stages * self.num_model_chunks
        )
        virtual_pp_stage = virtual_pp_stage // self.num_stages
        if not forward:
            virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
        return virtual_pp_stage

    def _forward_step_helper(self, micro_dataset, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        if not self._forward_only:
            assert hasattr(self, 'output_tensor_grads')
        assert len(self.input_tensors[virtual_pp_rank]) == (
            len(self.output_tensors[virtual_pp_rank]) + 1
        )
        input_tensor = self.input_tensors[virtual_pp_rank][-1]
        output_tensor = self._forward_step(
            input_tensor, micro_dataset, virtual_pp_rank
        )
        self.output_tensors[virtual_pp_rank].append(output_tensor)

        if self._forward_only:
            # no need to store tensor for backward
            self.input_tensors[virtual_pp_rank].pop()
            self.output_tensors[virtual_pp_rank].pop()

        return output_tensor

    def _backward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        assert hasattr(self, 'output_tensor_grads')

        assert (
            len(self.output_tensor_grads[virtual_pp_rank]) == 1
        ), f"output_tensor_grads is empty for virtual_pp_rank {virtual_pp_rank}"

        assert len(self.input_tensors[virtual_pp_rank]) > 0
        assert len(self.output_tensors[virtual_pp_rank]) > 0

        input_tensor = self.input_tensors[virtual_pp_rank].pop(0)
        output_tensor = self.output_tensors[virtual_pp_rank].pop(0)
        output_tensor_grad = self.output_tensor_grads[virtual_pp_rank].pop(0)
        input_tensor_grad = self._backward_step(
            input_tensor, output_tensor, output_tensor_grad
        )

        return input_tensor_grad

    def forward_backward_pipeline(
        self,
        data,
        scaler,
        forward_only=False,
        compute_loss=True,
        static_scheduler=False,
    ):
        # use interleave scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
        if not compute_loss:
            assert (
                forward_only
            ), "compute_loss can only be set to False when forward_only is set to True"

        if static_scheduler:
            assert (
                not forward_only
            ), "static_scheduler is only for training, not for eval"
            assert (
                not self._profiling
            ), "The static scheduler is not available while profiling"
            if data is not None:
                warnings.warn(
                    "A static scheduler run won't really run the model, but data has been provided"
                )
            logger.info(
                "enabling static_scheduler will return the pp schedule instead of the loss"
            )
            schedule = ""

        # init some attributes for this batch run
        self.scaler = scaler
        self.total_loss = None
        self.micro_batch_id = 0
        self._forward_only = forward_only

        # init some data buffers for interleave scheduler
        self.input_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]

        micro_dataset = self._wrap_data(data)

        num_steps = self.accumulate_steps * self.num_model_chunks
        if forward_only:
            # If only forward, since there is no backward during running, all steps are startup steps
            startup_steps = num_steps
        else:
            # actually startup_steps is calculated from two numbers:
            # first_forward_cross_to_end = (self.num_stages - self.stage_id - 1) + (self.num_model_chunks - 1) * self.num_stages
            # end_to_first_backward_cross = (self.num_stages - self.stage_id - 1)
            # startup_steps = first_forward_cross_to_end + end_to_first_backward_cross
            startup_steps = (self.num_stages - self.stage_id - 1) * 2
            startup_steps += (self.num_model_chunks - 1) * self.num_stages
            startup_steps = min(startup_steps, num_steps)
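            # e.g. with num_stages=4, num_model_chunks=2 and stage_id=0:
            # startup_steps = 3 * 2 + 1 * 4 = 10 (capped by num_steps).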

        steady_steps = num_steps - startup_steps

        self.set_virtual_pipeline_rank(0)
        if not static_scheduler:
            self.input_tensors[0].append(
                p2p.recv_forward(
                    self.is_pipeline_first_stage(), sync_recv=False
                )
            )

        # run startup steps
        for micro_step in range(startup_steps):
            if static_scheduler:
                virtual_pp_rank = self._get_virtual_pp_rank(
                    micro_step, forward=True
                )
                real_micro_step = self._forward_micro_step_counter[
                    virtual_pp_rank
                ]
                self._forward_micro_step_counter[virtual_pp_rank] += 1
                schedule += f"f{real_micro_step}_vp{virtual_pp_rank};"
                logger.info(
                    f"forward step for {real_micro_step} with virtual pp rank {virtual_pp_rank}"
                )
                continue

            self._record_stamp("F", micro_step, '"B"', forward=True)
            output_tensor = self._forward_step_helper(micro_dataset, micro_step)
            self._record_stamp("F", micro_step, '"E"', forward=True)

            # determine whether recv forward tensor or not
            next_virtual_pp_rank = self._get_virtual_pp_rank(
                micro_step + 1, forward=True
            )
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                if next_virtual_pp_rank == 0:
                    # next chunk is the first chunk, no need to pre-recv an input tensor
                    recv_prev = False
            # last micro step, no next run
            if micro_step == (num_steps - 1):
                recv_prev = False

            # last stage shouldn't send tensor to downstream
            if self.is_pipeline_last_stage():
                output_tensor = None

            if micro_step == (startup_steps - 1) and not forward_only:
                input_tensor_grad = None
                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False

                # the last startup step needs four-direction comm to set up for steady 1f1b
                (
                    input_tensor,
                    output_tensor_grad,
                ) = p2p.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next,
                )
                # output_tensor_grad is not None if recv_next
                # append output_tensor_grad whether it is None or not
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    output_tensor_grad
                )
            else:
                input_tensor = p2p.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev
                )
            # append input_tensor whether it is None or not
            self.input_tensors[next_virtual_pp_rank].append(input_tensor)

            self._release_output(output_tensor)

        # run 1f1b steady steps
        for micro_step in range(steady_steps):
            if static_scheduler:
                forward_micro_step_id = micro_step + startup_steps
                forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id, forward=True
                )
                backward_micro_step_id = micro_step
                backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id, forward=False
                )
                real_forward_micro_step = self._forward_micro_step_counter[
                    forward_virtual_pp_rank
                ]
                self._forward_micro_step_counter[forward_virtual_pp_rank] += 1
                real_backward_micro_step = self._backward_micro_step_counter[
                    backward_virtual_pp_rank
                ]
                self._backward_micro_step_counter[backward_virtual_pp_rank] += 1
                schedule += (
                    f"f{real_forward_micro_step}_vp{forward_virtual_pp_rank};"
                )
                schedule += (
                    f"b{real_backward_micro_step}_vp{backward_virtual_pp_rank};"
                )
                logger.info(
                    f"forward step for {real_forward_micro_step} with virtual pp rank {forward_virtual_pp_rank}"
                )
                logger.info(
                    f"backward step for {real_backward_micro_step} with virtual pp rank {backward_virtual_pp_rank}"
                )
                continue
            # forward
            forward_micro_step_id = micro_step + startup_steps
            self._record_stamp("F", forward_micro_step_id, '"B"', forward=True)
            output_tensor = self._forward_step_helper(
                micro_dataset, forward_micro_step_id
            )
            self._record_stamp("F", forward_micro_step_id, '"E"', forward=True)

            # backward
            backward_micro_step_id = micro_step
            self._record_stamp(
                "B", backward_micro_step_id, '"B"', forward=False
            )
            input_tensor_grad = self._backward_step_helper(
                backward_micro_step_id
            )
            self._record_stamp(
                "B", backward_micro_step_id, '"E"', forward=False
            )

            # four directions comm
            # send output tensor to downstream
            # send input tensor grad to upstream
            # recv input tensor from upstream
            # recv output tensor grad from downstream

            # last stage doesn't send result to downstream
            forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id, forward=True
            )
            self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
            if self.is_pipeline_last_stage():
                output_tensor = None

            # first stage doesn't send grad to upstream
            backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id, forward=False
            )
            self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
            if self.is_pipeline_first_stage():
                input_tensor_grad = None

            # determine whether to recv input tensor from upstream
            recv_prev = True
            next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id + 1, forward=True
            )
            if self.is_pipeline_first_stage(ignore_virtual=True) and (
                next_forward_virtual_pp_rank == 0
            ):
                # first pp stage and first virtual stage
                recv_prev = False

            # last iteration doesn't need recv from upstream
            if micro_step == (steady_steps - 1):
                recv_prev = False

            # determine whether to recv grad from downstream
            recv_next = True
            next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id + 1, forward=False
            )
            if self.is_pipeline_last_stage(ignore_virtual=True) and (
                next_backward_virtual_pp_rank == (self.num_model_chunks - 1)
            ):
                # last pp stage and last virtual stage
                recv_next = False

            (
                input_tensor,
                output_tensor_grad,
            ) = p2p.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
            )
            # append input_tensor whether it is None or not
            self.input_tensors[next_forward_virtual_pp_rank].append(
                input_tensor
            )
            # append output_tensor_grad whether it is None or not
            self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                output_tensor_grad
            )
            self._release_output(output_tensor)

        if not static_scheduler:
            self._release_output(output_tensor)

        # remaining backward steps
        if not forward_only:
            for micro_step in range(steady_steps, num_steps):
                if static_scheduler:
                    virtual_pp_rank = self._get_virtual_pp_rank(
                        micro_step, forward=False
                    )
                    real_micro_step = self._backward_micro_step_counter[
                        virtual_pp_rank
                    ]
                    self._backward_micro_step_counter[virtual_pp_rank] += 1
                    schedule += f"b{real_micro_step}_vp{virtual_pp_rank};"
                    logger.info(
                        f"backward step for {real_micro_step} with virtual pp rank {virtual_pp_rank}"
                    )
                    continue
                # cooldown loop
                self._record_stamp("B", micro_step, '"B"', forward=False)
                input_tensor_grad = self._backward_step_helper(micro_step)
                self._record_stamp("B", micro_step, '"E"', forward=False)
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    micro_step + 1, forward=False
                )

                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    if next_backward_virtual_pp_rank == (
                        self.num_model_chunks - 1
                    ):
                        recv_next = False

                if micro_step == (num_steps - 1):
                    recv_next = False
                # append output_tensor_grad whether it is None or not
                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    p2p.send_backward_recv_backward(
                        input_tensor_grad, recv_next=recv_next
                    )
                )

            if self._comm_overlap:
                assert len(self._comm_buffers) > 0
                for buffer in self._comm_buffers:
                    buffer.scale_and_split_grads()

            if static_scheduler:
                self._reset_counter()
                return schedule

            if self._enable_timer:
                self.timers("allreduce_shared_weight_gradients").start()
            self._layers.allreduce_shared_weight_gradients()
            if self._enable_timer:
                self.timers("allreduce_shared_weight_gradients").stop()

        self._flush_records()

        if compute_loss:
            # return loss if compute loss
            if self._enable_timer:
                self.timers("broadcast_final_loss").start()
            with paddle.amp.auto_cast(enable=False):
                train_loss = self._broadcast_final_loss()
            if self._enable_timer:
                self.timers("broadcast_final_loss").stop()
        else:
            # else just return all intermediate output tensors for all micro steps
            train_loss = self.output_tensors

        self.timer_printer()
        return train_loss

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # interleave scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        return self.forward_backward_pipeline(data, None, forward_only=True)

    def get_static_scheduler(self):
        return self.forward_backward_pipeline(
            data=None, scaler=None, static_scheduler=True
        )