#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import warnings

import paddle
from paddle import framework

from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
from ..utils import timer_helper as timer
from ..utils.hybrid_parallel_util import (
    broadcast_dp_parameters,
    broadcast_mp_parameters,
    broadcast_sharding_parameters,
)
from ..utils.log_util import logger
from .meta_parallel_base import MetaParallelBase
from .parallel_layers.pp_layers import PipelineLayer

_use_four_directions = os.environ.get(
    'PADDLE_USE_FOUR_DIRECTIONS_P2P', paddle.fluid.core.is_compiled_with_xpu()
)
if _use_four_directions:
    from .pp_utils import four_directions_p2p_communication as p2p
else:
    from .pp_utils import p2p_communication as p2p

from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size

__all__ = []


# assume only the first stage and last stage need data, and data consumption is ordered
# to be replaced by real micro dataset from reader
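# For example, with micro_batch_size=2 and acc_steps=4, a batched input tensor of
# shape [8, ...] is sliced into 4 micro batches of shape [2, ...]; micro step i reads
# rows [2 * i, 2 * i + 2). If the input is given as a list, it is expected to already
# hold acc_steps pre-split tensors, and micro step i simply returns the i-th element.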
class FakeMicroDataset:
    def __init__(
        self, data, is_first_stage, is_last_stage, acc_steps, micro_batch_size
    ):
        self._data = data
        self._index = 0
        self._acc_steps = acc_steps
        self._is_first_stage = is_first_stage
        self._is_last_stage = is_last_stage
        self._micro_batch_size = micro_batch_size

    def __iter__(self):
        return self

    def __next__(self):
        if self._index >= self._acc_steps:
            raise StopIteration
        assert self._is_first_stage or self._is_last_stage
        micro_batch_data = self._load_micro_batch(self._index)
        self._index += 1
        return micro_batch_data

    def _load_micro_batch(self, micro_step):
        inputs = self._data

        if self._is_first_stage or self._is_last_stage:
            assert len(inputs) == 2, "length of input should be 2"
            data = self._load_micro_batch_impl(inputs[0], micro_step)
            label = self._load_micro_batch_impl(inputs[1], micro_step)
            return (data, label)
        else:
            return (None, None)

    def _load_micro_batch_impl(self, inputs, micro_step):
        begin = micro_step * self._micro_batch_size
        end = begin + self._micro_batch_size

        if isinstance(inputs, tuple):
            output = []
            for data in inputs:
                if isinstance(data, list):
                    assert (
                        len(data) == self._acc_steps
                    ), "length of data should be %d, but it is %d" % (
                        self._acc_steps,
                        len(data),
                    )
                    output.append(data[micro_step].detach())
                elif data is not None:
                    self._check_data_valid(data)
                    output.append(data[begin:end, :].detach())
                else:
                    output.append(None)
            return tuple(output)

        elif isinstance(inputs, list):
            assert (
                len(inputs) == self._acc_steps
            ), "length of data should be %d, but it is %d" % (
                self._acc_steps,
                len(inputs),
            )
            return inputs[micro_step].detach()
        elif inputs is not None:
            self._check_data_valid(inputs)
            return inputs[begin:end, :].detach()
        else:
            return None

    def _check_data_valid(self, data):
        batch_size = data.shape[0]
        assert self._micro_batch_size * self._acc_steps == batch_size, (
            "batch_size needs to be divisible by micro_batch_size. Currently, "
            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
            % (batch_size, self._micro_batch_size, self._acc_steps)
        )


class PipelineParallel(MetaParallelBase):
    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer."
            )
        super().__init__(layers, hcg, strategy)
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
        self.use_sharding_parallel = (
            self._hcg.get_sharding_parallel_world_size() > 1
        )

        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size'
        ]
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps'
        ]
        # If the sent tensors are not the same across hosts,
        # they shouldn't be sent partially and then concatenated into a whole tensor.
        self._enable_partial_send_recv = self._strategy.pipeline_configs[
            'enable_partial_send_recv'
        ]
        self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.global_rank = self._hcg.get_global_rank()
        self.pp_group = self._hcg.get_pipe_parallel_group()
        self.dp_group = self._hcg.get_data_parallel_group()
        self.sharding_group = self._hcg.get_sharding_parallel_group()

        self._virtual_pp_world_size = None
        self._virtual_pp_rank = None
        self._real_pp_world_size = self.num_stages
        self._real_pp_rank = self.stage_id

        self._delay_scale_loss = self._strategy.hybrid_configs[
            "pp_configs"
        ].delay_scale_loss
        # TODO(PP Dev): support dp_comm_overlap without use_main_grad training.
        # This combination will trigger an inplace check error during `reshape_` in function `_split_tensors`.
        self._dp_comm_overlap = self._strategy.hybrid_configs[
            "pp_configs"
        ].dp_comm_overlap
        self._sharding_comm_overlap = self._strategy.hybrid_configs[
            "pp_configs"
        ].sharding_comm_overlap
        self._enable_timer = self._strategy.hybrid_configs[
            "pp_configs"
        ].enable_timer

        self._profiling = self._strategy.hybrid_configs["pp_configs"].profiling
        self._records = []
        self._record_format = (
            '"name": "{}{}", "cat": "pipeline timeline", "ph": {}, "pid": 0, "tid": '
            + str(self.stage_id + 1)
            + ', "ts": {}, "cname": "{}"'
        )
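        # Each record is one event in the Chrome trace event format (viewable via
        # chrome://tracing). For example, a forward-begin stamp on stage 0 would
        # roughly look like:
        # {"name": "F0", "cat": "pipeline timeline", "ph": "B", "pid": 0, "tid": 1,
        #  "ts": <timestamp in ms>, "cname": "thread_state_running"}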
        self._forward_color = "thread_state_running"  # RGB: 126, 200, 148
        self._backward_color = "rail_idle"  # RGB: 238, 142, 0
        if self._profiling:
            logger.info(
                "If enable pp profiling, the max training steps should be restricted "
                "to a reasonable value (such as 5) to avoid generating large profile files. "
                "The profiler will generate a profile file 'profile_record_tmp_file_for_rank_*' "
                "for each rank. Users should gather all profile files for one entire pipeline "
                "to one node (rank 0 is recommended) to get the full view of the pipeline profile. "
                "[DONT CHANGE THE NAME OF THE PROFILE FILES!]. "
                "Then get the profile parser from this url: "
                "https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/distributed/fleet/meta_parallel/pp_utils/profiler_helper.py "
                "and save the script to the same directory of all profile files."
                "Parse those files by this command: `python profiler_helper.py`. "
                "After parsing, a new file 'pipeline_profile.json' will be generated. "
                "Users can inspect this file by chrome://tracing website."
            )

        if self._dp_comm_overlap:
            assert self.use_data_parallel and self.num_stages > 1

        if self._sharding_comm_overlap:
            assert self.use_sharding_parallel and self.num_stages > 1

        assert not (
            self._dp_comm_overlap and self._sharding_comm_overlap
        ), "Cannot use dp pp overlap and sharding pp overlap at the same time."

        self._comm_buffers = []
        self._comm_overlap = (
            self._dp_comm_overlap or self._sharding_comm_overlap
        )

        if self._enable_timer:
            if not timer.is_timer_initialized():
                timer.set_timers()
            self.timers = timer.get_timers()

        p2p.initialize_p2p_groups(
            hcg,
            self._using_cache,
            self._enable_partial_send_recv,
            self._enable_timer,
        )

        self.global_rank = self._hcg.get_global_rank()
        self.micro_batch_id = 0

        self._compute_loss = True

        logger.info(
            "Pipeline Info -- num_stages: {}, stage_id: {}".format(
                self.num_stages, self.stage_id
            )
        )

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_sharding_parallel:
            logger.info("start broadcast sharding parameters")
            broadcast_sharding_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

        if self._dp_comm_overlap:
            self.register_allreduce_overlap_hook(
                self._layers, self.dp_group, self.accumulate_steps, True
            )

    def is_pipeline_first_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != 0:
                    return False
        assert self._real_pp_rank is not None
        return self._real_pp_rank == 0

    def is_pipeline_last_stage(self, ignore_virtual=False):
        if not ignore_virtual:
            if self._virtual_pp_world_size is not None:
                assert self._virtual_pp_rank is not None
                if self._virtual_pp_rank != (self._virtual_pp_world_size - 1):
                    return False
        assert self._real_pp_rank is not None
        assert self._real_pp_world_size is not None
        return self._real_pp_rank == (self._real_pp_world_size - 1)

    def set_virtual_pipeline_rank(self, rank):
        self._virtual_pp_rank = rank

    def bw_hook_func(self, buffer, param):
        @paddle.autograd.no_grad()
        def fused_allreduce(*_):
            buffer.add_grad(param)

        return fused_allreduce

    def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
        if model.get_num_virtual_stages() > 1:
            models = model.get_model_chunks()
        else:
            models = [model]

        if not dp:
            assert hasattr(self, "optimizer")
            assert hasattr(self.optimizer, "_param2rank")
            _param2rank = self.optimizer._param2rank

        act = HOOK_ACTION.ALL_REDUCE if dp else HOOK_ACTION.REDUCE
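        # With dp=True every rank keeps a full gradient copy, so gradients are
        # all-reduced over comm_group; with dp=False (sharding) each gradient only
        # needs to reach the rank that owns the parameter, so gradients are reduced
        # to the dst rank looked up from the optimizer's _param2rank mapping.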

        fused_parameter_group = {}

        for model in models:
            # For virtual pipeline: separate parameters of different chunks into
            # different groups to get the best performance.

            parameter_list = [
                p for p in model.parameters() if not p.stop_gradient
            ]
            if len(parameter_list) < 1:
                return

            if dp:
                fused_parameter_group[-1] = parameter_list
            else:
                # Sort parameters for sharding, since they have different dst rank
                for p in parameter_list:
                    assert p.name in _param2rank
                    dst_rank = _param2rank[p.name]
                    if dst_rank in fused_parameter_group:
                        fused_parameter_group[dst_rank].append(p)
                    else:
                        fused_parameter_group[dst_rank] = [p]

            for dst in fused_parameter_group:
                parameter_list = fused_parameter_group[dst]
                if not dp:
                    # convert the relative dst rank to the absolute dst rank for sharding
                    dst = comm_group.ranks[dst]
                var_groups = assign_group_by_size(parameter_list)
                for group_idx, parameters in var_groups.items():
                    buffer = FusedCommBuffer(
                        group_idx, parameters, comm_group, acc_steps, act, dst
                    )
                    self._comm_buffers.append(buffer)
                    for param in parameters:
                        param._register_backward_hook(
                            self.bw_hook_func(buffer, param)
                        )

    def timer_printer(self):
        if not self._enable_timer:
            return
        all_flag_names = self.timers.timers.keys()
        self.timers.log(all_flag_names)

    def _record_stamp(self, name, step, phase, color):
        if self._profiling:
            paddle.device.synchronize()
            self._records.append(
                '{'
                + self._record_format.format(
                    name,
                    step,
                    phase,
                    int(time.time() * 1000),
                    color,
                )
                + '}'
            )

    def _flush_records(self):
        if self._profiling:
            with open(
                f'./profile_record_tmp_file_for_rank_{self.global_rank}',
                'a+',
            ) as f:
                for record in self._records:
                    f.write(record + '\n')
            self._records = []

    def forward_backward_pipeline(
        self, data, scaler=None, static_scheduler=False
    ):
        # use the 1f1b scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
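        # For example, with num_stages=4, stage_id=1 and accumulate_steps=8:
        # startup_steps = min(4 - 1 - 1, 8) = 2 warm-up forward steps, then
        # steady_steps = 8 - 2 = 6 alternating 1F1B steps, and finally 2 cool-down
        # backward steps to drain the pipeline.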

        if static_scheduler:
            assert (
                not self._profiling
            ), "While _profiling, static scheduler is not available"
            if data is not None:
                warnings.warn(
                    "Static scheduler run won't real run the model, but data has been provided"
                )
            logger.info(
                "enable static_scheduler will return the pp schedule instead of the loss"
            )
            schedule = ""

        self.scaler = scaler

        # store total loss of entire batch
        self.total_loss = None

        # store data id for micro_batch
        self.micro_batch_id = 0

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

        input_buffers = []
        output_buffers = []

        micro_dataset = self._wrap_data(data)

        for step_id in range(startup_steps):
            if static_scheduler:
                schedule += f"f{step_id};"
                logger.info(f"forward step for micro step {step_id}")
                continue
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            self._record_stamp("F", step_id, '"B"', self._forward_color)
            output_tensor = self._forward_step(input_tensor, micro_dataset)
            self._record_stamp("F", step_id, '"E"', self._forward_color)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not self.is_pipeline_last_stage():
                self._release_output(output_tensor)

        if steady_steps > 0 and not static_scheduler:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            if static_scheduler:
                schedule += f"f{startup_steps + i};"
                schedule += f"b{i};"
                logger.info(f"forward step for micro step {startup_steps + i}")
                logger.info(f"backward step for micro step {i}")
                continue
            last_iter = i == (steady_steps - 1)

            self._record_stamp(
                "F", startup_steps + i, '"B"', self._forward_color
            )
            output_tensor = self._forward_step(input_tensor, micro_dataset)
            self._record_stamp(
                "F", startup_steps + i, '"E"', self._forward_color
            )

            output_tensor_grad = p2p.send_forward_recv_backward(
                output_tensor, self.is_pipeline_last_stage()
            )

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not self.is_pipeline_last_stage():
                self._release_output(output_tensor)

            input_tensor, output_tensor = input_buffers.pop(
                0
            ), output_buffers.pop(0)

            self._record_stamp("B", i, '"B"', self._backward_color)
            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )
            self._record_stamp("B", i, '"E"', self._backward_color)

            if last_iter:
                input_tensor = None
                p2p.send_backward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )
            else:
                input_tensor = p2p.send_backward_recv_forward(
                    input_tensor_grad, self.is_pipeline_first_stage()
                )

        for i in range(startup_steps):
            if static_scheduler:
                schedule += f"b{steady_steps + i};"
                logger.info(f"backward step for micro step {steady_steps + i}")
                continue
            input_tensor = input_buffers.pop(0)
            output_tensor = output_buffers.pop(0)

            output_tensor_grad = p2p.recv_backward(
                self.is_pipeline_last_stage()
            )

            self._record_stamp(
                "B", steady_steps + i, '"B"', self._backward_color
            )
            input_tensor_grad = self._backward_step(
                input_tensor, output_tensor, output_tensor_grad
            )
            self._record_stamp(
                "B", steady_steps + i, '"E"', self._backward_color
            )
            p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())

        if static_scheduler:
            return schedule

        self._flush_records()

        if self._comm_overlap:
            assert len(self._comm_buffers) > 0
            for buffer in self._comm_buffers:
                buffer.scale_and_split_grads()

        if self._enable_timer:
            self.timers("allreduce_shared_weight_gradients").start()
        self._layers.allreduce_shared_weight_gradients()
        if self._enable_timer:
            self.timers("allreduce_shared_weight_gradients").stop()
            self.timers("broadcast_final_loss").start()
        with paddle.amp.auto_cast(enable=False):
            train_loss = self._broadcast_final_loss()
        if self._enable_timer:
            self.timers("broadcast_final_loss").stop()
        self.timer_printer()
        return train_loss

    def _prepare_training(self, data, optimizer, lr_scheduler):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        assert isinstance(
            optimizer, HybridParallelOptimizer
        ), 'optimizer should be HybridParallelOptimizer subclass.'

        assert (
            framework._dygraph_tracer()._has_grad
        ), 'Please enable the generation of gradients.'

        if self.is_pipeline_first_stage(
            ignore_virtual=True
        ) or self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                data is not None
            ), "For the first and the last stage, the data must be set."
        else:
            data = None

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self._layers.train()

        if self._sharding_comm_overlap and len(self._comm_buffers) == 0:
            self.register_allreduce_overlap_hook(
                self._layers, self.sharding_group, self.accumulate_steps, False
            )

        return data

    def _wrap_data(self, data):
        """
        for backward compatibility, wrap data into a FakeMicroDataset if it is a list or tuple
        """
        if (not isinstance(data, tuple)) and (not isinstance(data, list)):
            return data

        micro_dataset = FakeMicroDataset(
            data,
            self.is_pipeline_first_stage(ignore_virtual=True),
            self.is_pipeline_last_stage(ignore_virtual=True),
            self.accumulate_steps,
            self.micro_batch_size,
        )
        return micro_dataset

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # 1f1b scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        # store data id for micro_batch
        self.micro_batch_id = 0

        # store total loss of entire batch
        self.total_loss = None

        startup_steps = self.num_stages - self.stage_id - 1
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

        input_buffers = []
        output_buffers = []

        # convert to micro dataset
        micro_dataset = self._wrap_data(data)

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

            output_tensor = self._forward_step(input_tensor, micro_dataset)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        for i in range(steady_steps):
            last_iter = i == (steady_steps - 1)

            output_tensor = self._forward_step(input_tensor, micro_dataset)
            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not last_iter:
                input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

        if self._compute_loss:
            self.train_loss = self._broadcast_final_loss()
        else:
            self.train_loss = output_buffers

        return self.train_loss

    def _forward_step(self, input_tensor, micro_dataset, chunk_id=None):
        if self._enable_timer:
            self.timers("forward_step").start()
        if self.is_pipeline_first_stage():
            input_tensor = next(micro_dataset)[0]
            self._check_micro_batch_data_valid(input_tensor)

        assert chunk_id is None or isinstance(chunk_id, int)

        output_tensor = self._layers.forward(input_tensor, chunk_id=chunk_id)

        if self.is_pipeline_last_stage():
            # calculate loss for training on the last stage
            if self._compute_loss:
                assert (
                    self._layers._loss_fn is not None
                ), "loss function should exist to compute loss"
                labels = next(micro_dataset)[1]
                self._check_micro_batch_data_valid(labels)
                output_tensor = self._layers._loss_fn(output_tensor, labels)
                assert isinstance(
                    output_tensor, (paddle.Tensor, framework.core.eager.Tensor)
                ), "Currently, loss_fn should obtain Paddle.Tensor dtype"

                with paddle.amp.auto_cast(enable=False):
                    if self.accumulate_steps > 1 and not self._delay_scale_loss:
                        output_tensor = output_tensor / self.accumulate_steps

                    if self.total_loss is None:
                        self.total_loss = paddle.zeros_like(output_tensor)
                    self.total_loss += output_tensor.detach()

        if self.is_pipeline_first_stage() or self.is_pipeline_last_stage():
            # Only increase micro batch id at virtual first/last pp stage.
            # The micro batch id is used to load data, therefore, only increase it when loading data.
            self.micro_batch_id += 1
        if self._enable_timer:
            self.timers("forward_step").stop()
        return output_tensor

    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
        if self._enable_timer:
            self.timers("backward_step").start()
        with paddle.amp.auto_cast(enable=False):
            if self.is_pipeline_last_stage():
                assert output_tensor_grad is None
                if self.scaler:
                    paddle.autograd.backward(self.scaler.scale(output_tensor))
                else:
                    paddle.autograd.backward(output_tensor)
            else:
                if isinstance(output_tensor, tuple):
                    outputs = [t for t in output_tensor if not t.stop_gradient]
                    assert len(outputs) == len(output_tensor_grad)
                    paddle.autograd.backward(
                        tensors=outputs,
                        grad_tensors=list(output_tensor_grad),
                    )
                else:
                    paddle.autograd.backward(
                        tensors=[output_tensor],
                        grad_tensors=[output_tensor_grad],
                    )

            input_tensor_grad = None
            if input_tensor is not None:
                if isinstance(input_tensor, tuple):
                    input_tensor_grad = tuple(
                        [t.grad for t in input_tensor if not t.stop_gradient]
                    )
                else:
                    input_tensor_grad = input_tensor.grad
            if self._enable_timer:
                self.timers("backward_step").stop()
            return input_tensor_grad

    def _check_micro_batch_data_valid(self, micro_batch_data):
        if isinstance(micro_batch_data, (tuple, list)):
            for data in micro_batch_data:
                self._check_micro_batch_data_valid(data)
        elif micro_batch_data is not None:
            micro_batch_size = micro_batch_data.shape[0]
            assert (
                micro_batch_size == self.micro_batch_size
            ), f"expected micro_batch_size {self.micro_batch_size} but get {micro_batch_size}"

    def _broadcast_final_loss(self):
        # Since the last backward run in interleave will set the virtual rank to 0,
        # here we need to check last stage ignoring virtual stage.
        if self.is_pipeline_last_stage(ignore_virtual=True):
            assert (
                self.total_loss is not None
            ), "train_batch() in last stage should obtain vaild loss"
            loss = (
                self.total_loss.detach()
                if not self._delay_scale_loss
                else self.total_loss / self.accumulate_steps
            )
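            # The loss dtype depends on the AMP settings (float32 or float16), so the
            # last stage first broadcasts an is_fp32 flag; the other stages use it to
            # allocate a receive buffer of the matching dtype before the loss broadcast.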
            is_fp32 = (
                paddle.full([], 1, 'int64')
                if loss.dtype == paddle.float32
                else paddle.full([], 0, 'int64')
            )
            paddle.distributed.broadcast(
                is_fp32, src=self.global_rank, sync_op=True, group=self.pp_group
            )
            paddle.distributed.broadcast(
                loss, src=self.global_rank, sync_op=True, group=self.pp_group
            )
        else:
            is_fp32 = paddle.full([], 1, 'int64')
            paddle.distributed.broadcast(
                is_fp32,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
            loss = (
                paddle.zeros(shape=[1], dtype="float32")
                if is_fp32.item()
                else paddle.zeros(shape=[1], dtype="float16")
            )
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                sync_op=True,
                group=self.pp_group,
            )
        return loss

    def _optimizer_step(self):
        if self._delay_scale_loss:
            for p in self._layers.parameters():
                if hasattr(p, "main_grad") and p.main_grad is not None:
                    assert p.grad is None
                    p.main_grad = p.main_grad.scale(1.0 / self.accumulate_steps)
                elif p.grad is not None:
                    p.grad = p.grad.scale(1.0 / self.accumulate_steps)

        if self.scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()

    def _release_output(self, output):
        def can_free(t):
            return (
                t is not None
                and isinstance(t, paddle.Tensor)
                and t._is_initialized()
                and t.inplace_version == 0
            )

        if isinstance(output, (tuple, list)):
            for t in output:
                if can_free(t):
                    t._clear_dataptr()

        elif can_free(output):
            output._clear_dataptr()

    def get_static_scheduler(self):
        return self.forward_backward_pipeline(data=None, static_scheduler=True)


class PipelineParallelWithInterleave(PipelineParallel):
    # pipeline parallel with interleave scheduler
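    # In the interleaved (virtual pipeline) scheme each rank holds num_model_chunks
    # non-contiguous layer chunks rather than one contiguous block; e.g. with 4 stages
    # and 2 chunks per stage, rank 0 would typically hold chunk 0 and chunk 4 of the
    # 8 chunks. This shrinks the pipeline bubble at the cost of more p2p communication.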

    def __init__(self, layers, hcg, strategy):
        super().__init__(layers=layers, hcg=hcg, strategy=strategy)
        self._record_format = (
            '"name": "{}{}_VP{}", "cat": "virtual pipeline timeline", "ph": {}, "pid": 0, "tid": '
            + str(self.stage_id + 1)
            + ', "ts": {}, "cname": "{}"'
        )
        self._forward_colors = [
            "thread_state_running",  # RGB: 126, 200, 148
            "thread_state_unknown",  # RGB: 199, 155, 125
        ]
        self._backward_colors = [
            "rail_load",  # RGB: 13, 168, 97
            "rail_idle",  # RGB: 238, 142, 0
        ]
        # Structures to record the micro step for each layer chunk
        self._forward_micro_step_counter = {}
        self._backward_micro_step_counter = {}

        assert layers.get_num_virtual_stages() > 1
        assert (
            self.num_stages > 2
        ), "virtual pipeline must run under pp degree > 2"
        assert (
            framework.in_dynamic_mode()
        ), "virtual pipeline stage with interleave only support eager dygraph mode"
        assert (
            self.accumulate_steps % self.num_stages == 0
        ), "accumulate_steps should be evenly divisible by num_stages for pipeline with interleave"
        # setup for interleave scheduler
        self.num_model_chunks = layers.get_num_virtual_stages()
        self.model_chunks = layers.get_model_chunks()
        assert self.model_chunks is not None
        assert len(self.model_chunks) == self.num_model_chunks
        self._virtual_pp_world_size = self.num_model_chunks
        self._virtual_pp_rank = 0
        self._reset_counter()

    def _reset_counter(self):
        for i in range(self.num_model_chunks):
            self._forward_micro_step_counter[i] = 0
            self._backward_micro_step_counter[i] = 0

    def _record_stamp(self, name, step, phase, forward=True):
        if self._profiling:
            paddle.device.synchronize()
            virtual_pp_rank = self._get_virtual_pp_rank(step, forward=forward)
            color_idx = virtual_pp_rank % 2
            # Get the profile color and micro step for current layer chunk
            if forward:
                color = self._forward_colors[color_idx]
                micro_step = self._forward_micro_step_counter[virtual_pp_rank]
                if phase == '"E"':
                    self._forward_micro_step_counter[virtual_pp_rank] += 1
            else:
                color = self._backward_colors[color_idx]
                micro_step = self._backward_micro_step_counter[virtual_pp_rank]
                if phase == '"E"':
                    self._backward_micro_step_counter[virtual_pp_rank] += 1
            self._records.append(
                '{'
                + self._record_format.format(
                    name,
                    micro_step,
                    virtual_pp_rank,
                    phase,
                    int(time.time() * 1000),
                    color,
                )
                + '}'
            )

    def _flush_records(self):
        if self._profiling:
            with open(
                f'./profile_record_tmp_file_for_rank_{self.global_rank}',
                'a+',
            ) as f:
                for record in self._records:
                    f.write(record + '\n')
            self._records = []
            self._reset_counter()

    def _get_virtual_pp_rank(self, micro_step, forward):
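        # For example, with num_stages=4 and num_model_chunks=2, forward micro steps
        # 0..7 map to virtual stages [0, 0, 0, 0, 1, 1, 1, 1] (and then repeat), while
        # backward micro steps map to the mirrored order [1, 1, 1, 1, 0, 0, 0, 0].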
        virtual_pp_stage = micro_step % (
            self.num_stages * self.num_model_chunks
        )
        virtual_pp_stage = virtual_pp_stage // self.num_stages
        if not forward:
            virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
        return virtual_pp_stage

    def _forward_step_helper(self, micro_dataset, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        if not self._forward_only:
            assert hasattr(self, 'output_tensor_grads')
        assert len(self.input_tensors[virtual_pp_rank]) == (
            len(self.output_tensors[virtual_pp_rank]) + 1
        )
        input_tensor = self.input_tensors[virtual_pp_rank][-1]
        output_tensor = self._forward_step(
            input_tensor, micro_dataset, virtual_pp_rank
        )
        self.output_tensors[virtual_pp_rank].append(output_tensor)

        if self._forward_only:
            # no need to store tensor for backward
            self.input_tensors[virtual_pp_rank].pop()
            self.output_tensors[virtual_pp_rank].pop()

        return output_tensor

    def _backward_step_helper(self, micro_step):
        virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
        self.set_virtual_pipeline_rank(virtual_pp_rank)

        # some checkers
        assert hasattr(self, 'input_tensors')
        assert hasattr(self, 'output_tensors')
        assert hasattr(self, 'output_tensor_grads')

        assert (
            len(self.output_tensor_grads[virtual_pp_rank]) == 1
        ), f"output_tensor_grads is empty for virtual_pp_rank {virtual_pp_rank}"

        assert len(self.input_tensors[virtual_pp_rank]) > 0
        assert len(self.output_tensors[virtual_pp_rank]) > 0

        input_tensor = self.input_tensors[virtual_pp_rank].pop(0)
        output_tensor = self.output_tensors[virtual_pp_rank].pop(0)
        output_tensor_grad = self.output_tensor_grads[virtual_pp_rank].pop(0)
        input_tensor_grad = self._backward_step(
            input_tensor, output_tensor, output_tensor_grad
        )

        return input_tensor_grad

    def forward_backward_pipeline(
        self,
        data,
        scaler,
        forward_only=False,
        compute_loss=True,
        static_scheduler=False,
    ):
        # use interleave scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
        if not compute_loss:
            assert (
                forward_only
            ), "compute_loss can only be set to False when forward_only is set to True"

        if static_scheduler:
            assert (
                not forward_only
            ), "static_scheduler only for training not for eval"
            assert (
                not self._profiling
            ), "While _profiling, static scheduler is not available"
            if data is not None:
                warnings.warn(
                    "Static scheduler run won't real run the model, but data has been provided"
                )
            logger.info(
                "enable static_scheduler will return the pp schedule instead of the loss"
            )
            schedule = ""

        # init some attributes for this batch run
        self.scaler = scaler
        self.total_loss = None
        self.micro_batch_id = 0
        self._forward_only = forward_only

        # init some data buffers for interleave scheduler
        self.input_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]

        micro_dataset = self._wrap_data(data)

        num_steps = self.accumulate_steps * self.num_model_chunks
        if forward_only:
            # In forward-only mode there is no backward pass, so all steps are startup steps
            startup_steps = num_steps
        else:
            # actually startup_steps is calculated from two numbers:
            # first_forward_cross_to_end = (self.num_stages - self.stage_id - 1) + (self.num_model_chunks - 1) * self.num_stages
            # end_to_first_backward_cross = (self.num_stages - self.stage_id - 1)
            # startup_steps = first_forward_cross_to_end + end_to_first_backward_cross
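            # For example, with num_stages=4, stage_id=1 and num_model_chunks=2:
            # startup_steps = (4 - 1 - 1) * 2 + (2 - 1) * 4 = 8, capped at
            # num_steps = accumulate_steps * num_model_chunks.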
            startup_steps = (self.num_stages - self.stage_id - 1) * 2
            startup_steps += (self.num_model_chunks - 1) * self.num_stages
            startup_steps = min(startup_steps, num_steps)

        steady_steps = num_steps - startup_steps

        self.set_virtual_pipeline_rank(0)
        if not static_scheduler:
            self.input_tensors[0].append(
                p2p.recv_forward(
                    self.is_pipeline_first_stage(), sync_recv=False
                )
            )

        # run startup steps
        for micro_step in range(startup_steps):
            if static_scheduler:
                virtual_pp_rank = self._get_virtual_pp_rank(
                    micro_step, forward=True
                )
                real_micro_step = self._forward_micro_step_counter[
                    virtual_pp_rank
                ]
                self._forward_micro_step_counter[virtual_pp_rank] += 1
                schedule += f"f{real_micro_step}_vp{virtual_pp_rank};"
                logger.info(
                    f"forward step for {real_micro_step} with virtual pp rank {virtual_pp_rank}"
                )
                continue

            self._record_stamp("F", micro_step, '"B"', forward=True)
            output_tensor = self._forward_step_helper(micro_dataset, micro_step)
            self._record_stamp("F", micro_step, '"E"', forward=True)

            # determine whether to recv the forward tensor or not
            next_virtual_pp_rank = self._get_virtual_pp_rank(
                micro_step + 1, forward=True
            )
            recv_prev = True
            if self.is_pipeline_first_stage(ignore_virtual=True):
                if next_virtual_pp_rank == 0:
                    # next chunk is the first chunk, no need to pre-recv an input tensor
                    recv_prev = False
            # last micro step, no next run
            if micro_step == (num_steps - 1):
                recv_prev = False

            # last stage shouldn't send tensor to downstream
            if self.is_pipeline_last_stage():
                output_tensor = None

            if micro_step == (startup_steps - 1) and not forward_only:
                input_tensor_grad = None
                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False

                # the last startup step needs four-direction comm to set up for steady 1f1b
                (
                    input_tensor,
                    output_tensor_grad,
                ) = p2p.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next,
                )
                # output_tensor_grad is not None only when recv_next is True;
                # append it whether it is None or not
                self.output_tensor_grads[self.num_model_chunks - 1].append(
                    output_tensor_grad
                )
            else:
                input_tensor = p2p.send_forward_recv_forward(
                    output_tensor, recv_prev=recv_prev
                )
            # append input_tensor whether it is None or not
            self.input_tensors[next_virtual_pp_rank].append(input_tensor)

            self._release_output(output_tensor)

        # run 1f1b steady steps
        for micro_step in range(steady_steps):
            if static_scheduler:
                forward_micro_step_id = micro_step + startup_steps
                forward_virtual_pp_rank = self._get_virtual_pp_rank(
                    forward_micro_step_id, forward=True
                )
                backward_micro_step_id = micro_step
                backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    backward_micro_step_id, forward=False
                )
                real_forward_micro_step = self._forward_micro_step_counter[
                    forward_virtual_pp_rank
                ]
                self._forward_micro_step_counter[forward_virtual_pp_rank] += 1
                real_backward_micro_step = self._backward_micro_step_counter[
                    backward_virtual_pp_rank
                ]
                self._backward_micro_step_counter[backward_virtual_pp_rank] += 1
                schedule += (
                    f"f{real_forward_micro_step}_vp{forward_virtual_pp_rank};"
                )
                schedule += (
                    f"b{real_backward_micro_step}_vp{backward_virtual_pp_rank};"
                )
                logger.info(
                    f"forward step for {real_forward_micro_step} with virtual pp rank {forward_virtual_pp_rank}"
                )
                logger.info(
                    f"backward step for {real_backward_micro_step} with virtual pp rank {backward_virtual_pp_rank}"
                )
                continue
            # forward
            forward_micro_step_id = micro_step + startup_steps
            self._record_stamp("F", forward_micro_step_id, '"B"', forward=True)
            output_tensor = self._forward_step_helper(
                micro_dataset, forward_micro_step_id
            )
            self._record_stamp("F", forward_micro_step_id, '"E"', forward=True)

            # backward
            backward_micro_step_id = micro_step
            self._record_stamp(
                "B", backward_micro_step_id, '"B"', forward=False
            )
            input_tensor_grad = self._backward_step_helper(
                backward_micro_step_id
            )
            self._record_stamp(
                "B", backward_micro_step_id, '"E"', forward=False
            )

            # four directions comm
            # send output tensor to downstream
            # send input tensor grad to upstream
            # recv input tensor from upstream
            # recv output tensor grad from downstream

            # last stage doesn't send its result to downstream
            forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id, forward=True
            )
            self.set_virtual_pipeline_rank(forward_virtual_pp_rank)
            if self.is_pipeline_last_stage():
                output_tensor = None

            # first stage doesn't send grad to upstream
            backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id, forward=False
            )
            self.set_virtual_pipeline_rank(backward_virtual_pp_rank)
            if self.is_pipeline_first_stage():
                input_tensor_grad = None

            # determine whether to recv input tensor from upstream
            recv_prev = True
            next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                forward_micro_step_id + 1, forward=True
            )
            if self.is_pipeline_first_stage(ignore_virtual=True) and (
                next_forward_virtual_pp_rank == 0
            ):
                # first pp stage and first virtual stage
                recv_prev = False

            # last iteration doesn't need recv from upstream
            if micro_step == (steady_steps - 1):
                recv_prev = False

            # determine whether to recv grad from downstream
            recv_next = True
            next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                backward_micro_step_id + 1, forward=False
            )
            if self.is_pipeline_last_stage(ignore_virtual=True) and (
                next_backward_virtual_pp_rank == (self.num_model_chunks - 1)
            ):
                # last pp stage and last virtual stage
                recv_next = False

            (
                input_tensor,
                output_tensor_grad,
            ) = p2p.send_forward_backward_recv_forward_backward(
                output_tensor,
                input_tensor_grad,
                recv_prev=recv_prev,
                recv_next=recv_next,
            )
            # append input_tensor whether it is None or not
            self.input_tensors[next_forward_virtual_pp_rank].append(
                input_tensor
            )
            # append output_tensor_grad whether it is None or not
            self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                output_tensor_grad
            )
            self._release_output(output_tensor)

        if not static_scheduler:
            self._release_output(output_tensor)

        # remaining backward steps
        if not forward_only:
            for micro_step in range(steady_steps, num_steps):
                if static_scheduler:
                    virtual_pp_rank = self._get_virtual_pp_rank(
                        micro_step, forward=False
                    )
                    real_micro_step = self._backward_micro_step_counter[
                        virtual_pp_rank
                    ]
                    self._backward_micro_step_counter[virtual_pp_rank] += 1
                    schedule += f"b{real_micro_step}_vp{virtual_pp_rank};"
                    logger.info(
                        f"backward step for {real_micro_step} with virtual pp rank {virtual_pp_rank}"
                    )
                    continue
                # cooldown loop
                self._record_stamp("B", micro_step, '"B"', forward=False)
                input_tensor_grad = self._backward_step_helper(micro_step)
                self._record_stamp("B", micro_step, '"E"', forward=False)
                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
                    micro_step + 1, forward=False
                )

                recv_next = True
                if self.is_pipeline_last_stage(ignore_virtual=True):
                    if next_backward_virtual_pp_rank == (
                        self.num_model_chunks - 1
                    ):
                        recv_next = False

                if micro_step == (num_steps - 1):
                    recv_next = False
                # append output_tensor_grad whether it is None or not
                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
                    p2p.send_backward_recv_backward(
                        input_tensor_grad, recv_next=recv_next
                    )
                )

            if self._comm_overlap:
                assert len(self._comm_buffers) > 0
                for buffer in self._comm_buffers:
                    buffer.scale_and_split_grads()

            if static_scheduler:
                self._reset_counter()
                return schedule

            if self._enable_timer:
                self.timers("allreduce_shared_weight_gradients").start()
            self._layers.allreduce_shared_weight_gradients()
            if self._enable_timer:
                self.timers("allreduce_shared_weight_gradients").stop()

        self._flush_records()

        if compute_loss:
            # return loss if compute loss
            if self._enable_timer:
                self.timers("broadcast_final_loss").start()
            with paddle.amp.auto_cast(enable=False):
                train_loss = self._broadcast_final_loss()
            if self._enable_timer:
                self.timers("broadcast_final_loss").stop()
        else:
            # else just return all intermediate output tensors for all micro steps
            train_loss = self.output_tensors

        self.timer_printer()
        return train_loss

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        data = self._prepare_training(data, optimizer, lr_scheduler)
        # interleave scheduler for pipeline parallel
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
        self._compute_loss = compute_loss

        return self.forward_backward_pipeline(data, None, forward_only=True)

    def get_static_scheduler(self):
        return self.forward_backward_pipeline(
            data=None, scaler=None, static_scheduler=True
        )