engine.py 32.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import time
16 17 18 19 20
import copy
import logging
from collections import defaultdict

import paddle
21
import paddle.utils as utils
22

23
from paddle import fluid, static
24
from paddle.io import Dataset
25
from paddle.jit import to_static
26
from paddle.metric import Metric
27
from paddle.static import InputSpec
28
from paddle.fluid import core
29
from paddle.fluid import program_guard
30
from paddle.fluid.layers.utils import flatten
31
from paddle.fluid.executor import global_scope, _to_name_str
32
from paddle.fluid.backward import append_backward
33
from paddle.fluid.framework import Operator, Parameter, _non_static_mode
34 35
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph.parallel import ParallelEnv
36
from paddle.distributed import fleet
37
from paddle.distributed.passes import new_pass, PassContext
38

39
from .hepler import ProgramHelper
40 41
from ..collective import _get_global_env
from .cluster import Cluster, get_default_cluster
42 43
from .planner_v2 import Planner
from .parallelizer_v2 import Parallelizer
44 45 46 47 48
from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver
from .dist_loader import NonIterableGeneratorLoader
from .utils import make_data_unshard, set_grad_var_shape
from .utils import print_program_with_dist_attr, to_list
49
from .process_group import new_process_group, get_all_process_groups, get_world_process_group
50
from .dist_context import DistributedContext, get_default_distributed_context
51 52 53


class Engine:
54

55 56 57 58 59
    def __init__(self,
                 model=None,
                 inputs_spec=None,
                 labels_spec=None,
                 cluster=None,
                 strategy=None,
                 user_tuning_config=None):
        """Create an auto-parallel engine around ``model``.

        Args:
            model: the network to be parallelized.
            inputs_spec: ``InputSpec`` (or a list of them) describing the
                model inputs; every spec must carry a name.
            labels_spec: ``InputSpec`` (or a list of them) describing the
                labels; every spec must carry a name.
            cluster: target ``Cluster``; ``get_default_cluster()`` is used
                when it is None.
            strategy: ``fleet.DistributedStrategy``; a fresh default one is
                created when it is None.
            user_tuning_config: optional config; when truthy, optimization
                tuning runs while preparing the "train" mode.
        """
        self.model = model
        self.inputs_spec = self._validate_spec(inputs_spec)
        self.labels_spec = self._validate_spec(labels_spec)
        self.cluster = get_default_cluster() if cluster is None else cluster
        self.strategy = (fleet.DistributedStrategy()
                         if strategy is None else strategy)
        self._user_tuning_config = user_tuning_config

        self._executor = None
        self._cur_rank = paddle.distributed.get_rank()
        self._nranks = paddle.distributed.get_world_size()
        self._saver = DistributedSaver()

        # TODO: add logger module
        # NOTE(review): logging.getLogger() without a name returns the root
        # logger, so the level and handler configured here are process-wide.
        root_logger = logging.getLogger()
        self._logger = root_logger
        root_logger.propagate = False
        if not root_logger.handlers:
            root_logger.setLevel(logging.INFO)
            stream_handler = logging.StreamHandler()
            stream_handler.setFormatter(
                logging.Formatter(
                    '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
                ))
            root_logger.addHandler(stream_handler)

        # Snapshot of the user's static programs and default dist context.
        self._orig_main_prog = static.default_main_program()
        self._orig_startup_prog = static.default_startup_program()
        self._orig_dist_context = get_default_distributed_context()

        # Per-mode ("train" / "eval" / "predict") bookkeeping.
        self._dist_contexts = {}
        self._serial_main_progs = {}
        self._serial_startup_progs = {}
        self._dist_main_progs = defaultdict(dict)  # dist main programs
        self._dist_startup_progs = defaultdict(dict)  # dist startup programs
        self._feed_vars = {}
        self._fetch_vars = {}
        self._planners = {}
        self._mode_init_states = dict.fromkeys(("train", "eval", "predict"),
                                               False)
        self._dygraph_mode = False

    def prepare(self,
                optimizer=None,
                loss=None,
                gradient_scale=True,
                metrics=None,
                all_ranks=False):
        """Validate the training components and prepare the "train" mode.

        Args:
            optimizer: ``paddle.optimizer.Optimizer`` or
                ``paddle.fluid.optimizer.Optimizer`` (optional).
            loss: ``paddle.nn.Layer`` or any callable (optional).
            gradient_scale: stored on every ``DistributedContext`` built.
            metrics: a ``Metric`` or a list of them (optional).
            all_ranks: parallelize the programs of every rank instead of
                only the current one.

        Raises:
            TypeError: ``optimizer`` or ``loss`` has an unsupported type.
            AssertionError: an entry of ``metrics`` is not a ``Metric``.
        """
        if optimizer and not isinstance(
                optimizer,
            (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)):
            raise TypeError(
                "'optimizer' must be object of class `paddle.optimizer.Optimizer`"
                " or `paddle.fluid.optimizer.Optimizer`.")
        self._optimizer = optimizer
        self._all_ranks = all_ranks

        # A loss is acceptable when it is either a Layer or callable.
        if loss and not (isinstance(loss, paddle.nn.Layer) or callable(loss)):
            raise TypeError(
                "'loss' must be sub classes of `paddle.nn.Layer` or any callable function."
            )
        self._loss = loss

        metric_list = to_list(metrics or [])
        for metric in metric_list:
            assert isinstance(metric, Metric), \
                "{} is not sub class of Metric".format(
                    metric.__class__.__name__)
        self._metrics = metric_list

        self._gradient_scale = gradient_scale
        self._planned_mode = None
        self._prepare_single_mode("train")

    def _prepare_single_mode(self, mode):
142 143 144 145 146 147 148 149 150 151 152 153 154 155 156

        self._build(mode)
        # Do the planning process
        self._plan(mode)

        # Do the Optimization tuning
        if self._user_tuning_config and mode == "train":
            self._optimization_tuning(mode)

        # Do the parallel process
        self._parallel(mode, self._all_ranks)

        # Init comm and startup program
        self._initialize(mode)
        self._mode_init_states[mode] = True
157

158
    def _build(self, mode):
        """Build the serial programs of ``mode`` and its ``DistributedContext``.

        In dygraph mode (or once the engine has switched to it), the model
        is converted through ``ProgramHelper`` (``to_static``) and paddle is
        switched back to static mode afterwards, even if building fails.
        Otherwise the model, loss and metrics are called inside clones of the
        original static main/startup programs. In that static path a mode
        whose serial main program is already recorded is not rebuilt.

        The resulting context is stored in ``self._dist_contexts[mode]``.
        """
        if _non_static_mode() or self._dygraph_mode:
            paddle.disable_static()
            # Restore static mode on every exit path so an exception raised
            # while tracing the model cannot leave paddle in dygraph mode.
            try:
                self._dygraph_mode = True
                self._logger.info("Building model with 'to_static' method.")

                program_helper = ProgramHelper(self.model, self._loss,
                                               self._metrics, self.inputs_spec,
                                               self.labels_spec)
                # build forward main program
                program_helper.build_program(mode)

                self.concrete_program = program_helper.concrete_program
                serial_main_prog = program_helper.main_program
                serial_startup_prog = program_helper.startup_program

                inputs = program_helper.input_vars
                outputs = program_helper.output_vars
                labels = program_helper.label_vars
                losses = program_helper.loss_vars
                metrics = program_helper.metric_vars
            finally:
                paddle.enable_static()
        else:
            # build program in static mode
            if self._serial_main_progs.get(mode, None) is not None:
                return

            losses = []
            metrics = []
            serial_main_prog = self._orig_main_prog.clone()
            serial_startup_prog = self._orig_startup_prog.clone()
            # FIXME to support grad clip
            with static.program_guard(serial_main_prog, serial_startup_prog), \
                    utils.unique_name.guard():
                inputs_spec = self.inputs_spec
                labels_spec = self.labels_spec if self.labels_spec else []
                inputs = [s._create_feed_layer() for s in inputs_spec]
                labels = [s._create_feed_layer() for s in labels_spec]
                outputs = to_list(self.model(*inputs))
                if mode != "predict" and self._loss:
                    losses = to_list(self._loss(*(outputs + labels)))

                if mode != "predict":
                    for metric in self._metrics:
                        metrics.extend(
                            to_list(metric.compute(*(outputs + labels))))

        default_ctx = get_default_distributed_context()
        if not default_ctx.has_annotation:
            # We build the world process group because the data parallel
            # needs all ranks by default.
            new_process_group(list(range(self._nranks)))
            default_ctx.data_parallel = True

        feed_vars = {"inputs": inputs, "labels": labels}

        fetch_vars = {
            "outputs": flatten(outputs),
            "loss": losses,
            "metrics": metrics
        }

        self._set_recompute_ckpts()
        self._dist_contexts[mode] = DistributedContext(
            serial_main_prog, serial_startup_prog, self._optimizer, losses,
            feed_vars, fetch_vars, self.cluster, self.strategy)
        self._dist_contexts[mode].gradient_scale = self._gradient_scale
        self._dist_contexts[mode]._dygraph_mode = self._dygraph_mode

    def _optimization_tuning(self, mode):

        self.mode = mode
        assert "batch_size" in self._user_tuning_config, "Optimization Tuning should provide with batch size."
        assert "dataset" in self._user_tuning_config, "Optimization Tuning should provide with dataset."
        batch_size = self._user_tuning_config["batch_size"]
        dataset = self._user_tuning_config["dataset"]
236 237
        dataset.dp_world_size = self._input_split_size
        dataset.dp_rank = self._input_split_rank
238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256

        from .tuner.optimization_tuner import OptimizationTuner
        self._optimization_tuner = OptimizationTuner(self._user_tuning_config,
                                                     self._dist_contexts[mode],
                                                     dataset,
                                                     self.inputs_spec,
                                                     self.labels_spec,
                                                     batch_size=batch_size,
                                                     rank=self._cur_rank)

        self._optimization_tuner.tune()

        if self._user_tuning_config["run_after_tuning"]:
            # update the strategy
            self._dist_contexts[
                mode]._strategy = self._optimization_tuner.get_best_config()
        else:
            return

257 258 259 260 261 262
    def _plan(self, mode):
        """Plan the distributed attributes of ``mode`` and infer how the
        input data is split for data parallelism.

        The first mode planned is the reference; any later mode first copies
        the reference op dist attributes (``_init_dist_context``) so all
        modes share one parallel strategy. The split info is derived from the
        first feed variable present in the serial main program and stored in
        ``self._input_split_size`` / ``self._input_split_rank``.
        """
        if self._planned_mode is not None:
            self._init_dist_context(mode)
        else:
            self._planned_mode = mode

        self._planners[mode] = Planner(mode, self._dist_contexts[mode])
        self._planners[mode].plan()

        # infer data parallel info
        dist_context = self._dist_contexts[mode]
        serial_feeds = dist_context.serial_feed_vars
        block_vars = dist_context.serial_main_program.global_block().vars
        feed_list = [
            block_vars[feed.name]
            for feed in serial_feeds["inputs"] + serial_feeds["labels"]
            if feed.name in block_vars
        ]

        split_size, split_rank = self._get_input_split_info(
            feed_list[0], dist_context)
        self._input_split_size, self._input_split_rank = split_size, split_rank

    def _parallel(self, mode, all_ranks):
        """Parallelize the programs of ``mode`` from the planner's results.

        The planner's completer is handed to the ``Parallelizer`` because it
        may be needed to complete the annotation of backward and update ops.
        Only the current rank is parallelized unless ``all_ranks`` is truthy.
        """
        completer = self._planners[mode].completer
        parallelizer = Parallelizer(mode, completer, self._dist_contexts[mode])
        if all_ranks:
            parallelizer.parallel_all()
        else:
            parallelizer.parallel(self._cur_rank)

    def _init_dist_context(self, mode):
290
        # Init dist_context['mode'] with the first planned dist_context
291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307
        # to guarantee that train/eval/predict mode have same parallel strategy
        dist_context = self._dist_contexts[mode]
        origin_main_prog = dist_context._original_serial_main_program
        ref_mode = self._planned_mode
        ref_dist_context = self._dist_contexts[ref_mode]
        ref_origin_main_prog = ref_dist_context._original_serial_main_program
        ref_blocks = ref_origin_main_prog.blocks
        for ib, block in enumerate(origin_main_prog.blocks):
            for iop, op in enumerate(block.ops):
                ref_op = ref_blocks[ib].ops[iop]
                assert op.type == ref_op.type, \
                    "'{}' mode op '{}' is different with '{}' op '{}'. ".format(mode, op.type, ref_mode, ref_op.type)
                ref_op_dist_attr = ref_dist_context.get_op_dist_attr_for_program(
                    ref_op)
                dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr)

    def _initialize(self, mode):
        """Finish preparing ``mode`` once its programs are parallelized.

        Steps:
          1. cache the serial/distributed programs, feed/fetch vars and the
             lr optimizer held by ``self._dist_contexts[mode]``;
          2. with more than one rank, instantiate every process group that
             contains the current rank;
          3. resolve the execution place (``CUDAPlace(ParallelEnv().dev_id)``
             when the current place is a CUDA place);
          4. in dygraph mode, share the sliced dygraph parameters into the
             global scope;
          5. on the first call only, create the executor, run the startup
             ops of still-uninitialized variables and, for pure-FP16 AMP,
             cast the FP16 parameters that were just initialized.
        """
        dist_context = self._dist_contexts[mode]
        # Get the current content from the distributed context
        self._serial_main_progs[mode] = dist_context.serial_main_program
        self._serial_startup_progs[mode] = dist_context.serial_startup_program
        self._dist_main_progs[mode] = dist_context.dist_main_programs
        self._dist_startup_progs[mode] = dist_context.dist_startup_programs
        self._feed_vars[mode] = dist_context.serial_feed_vars
        self._fetch_vars[mode] = dist_context.serial_fetch_vars
        self._lr_optimizer = dist_context._lr_optimizer

        if self._nranks > 1:
            # Traverse different rank programs and traverse each op of them,
            # instantiate communication by process_mapping.
            # NOTE: add the comm init control in the future for auto search
            for process_group in get_all_process_groups():
                if self._cur_rank in process_group.ranks:
                    process_group.instantiate()

        self._place = _get_device()
        if isinstance(self._place, fluid.CUDAPlace):
            self._place = fluid.CUDAPlace(ParallelEnv().dev_id)

        if self._dygraph_mode:
            self._share_dygraph_params_to_scope(mode)

        if self._executor is None:
            self._executor = paddle.static.Executor(self._place)
            prune_startup_prog = self._run_uninitialized_startup(mode)
            # ``prune_startup_prog`` is None when every variable was already
            # initialized: no startup op ran, so nothing fresh needs casting
            # (referencing the pruned program here would be a NameError).
            if (self.strategy.amp
                    and self.strategy.amp_configs['use_pure_fp16']
                    and prune_startup_prog is not None):
                self._cast_parameters_to_fp16(self._place, prune_startup_prog)

    def _share_dygraph_params_to_scope(self, mode):
        """Share this rank's slice of each dygraph parameter with the global
        scope tensor of the same name.

        Only parameters present in the current rank's distributed main
        program of ``mode`` are shared; each is sliced with that variable's
        dist attr (dims_mapping, process mesh topology and processes). Paddle
        is switched to dygraph mode for the duration and back to static mode
        afterwards, also when an exception is raised.
        """
        from .converter import Converter

        paddle.disable_static()
        try:
            dist_context = self._dist_contexts[mode]
            main_block = self._dist_main_progs[mode][
                self._cur_rank].global_block()
            for param in self.concrete_program.parameters:
                # create var in scope and share parameters to scope
                if param.name not in main_block.vars:
                    continue
                # get param_var's dist_attr
                var = main_block.vars[param.name]
                var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
                    var)
                dist_attr = {
                    "dims_mapping": var_dist_attr.dims_mapping,
                    "process_shape": var_dist_attr.process_mesh.topology,
                    "process_group": var_dist_attr.process_mesh.processes
                }
                # slice param_value with dist_attr
                # share sliced_param_value with param_tensor in global_scope
                param_tensor = global_scope().var(param.name).get_tensor()
                sliced_param = Converter.slice_with_dist_attr(
                    param.numpy(), dist_attr)
                shared_tensor = paddle.to_tensor(sliced_param,
                                                 place=self._place)
                param_tensor._share_data_with(
                    shared_tensor.value().get_tensor())
        finally:
            paddle.enable_static()

    def _run_uninitialized_startup(self, mode):
        """Run the startup ops of the current rank's startup program of
        ``mode`` for variables that are not yet initialized in the global
        scope.

        Returns:
            The pruned startup program that was run, or None when every
            variable was already initialized (nothing is run then).
        """
        dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
        uninitialized = []
        for var in dist_startup_prog.list_vars():
            scope_var = global_scope().find_var(var.name)
            if scope_var and scope_var.get_tensor()._is_initialized():
                continue
            uninitialized.append(var)
        if not uninitialized:
            return None
        prune_startup_prog = dist_startup_prog._prune(uninitialized)
        self._executor.run(prune_startup_prog)
        return prune_startup_prog

    @staticmethod
    def _cast_parameters_to_fp16(place, program, scope=None):
        """Re-store the scope data of every FP16-typed parameter as FP16.

        Args:
            place: place the converted tensors are stored on.
            program: program whose parameters (all blocks) are traversed.
            scope: scope holding the parameter tensors; the global scope is
                used when it is falsy.

        Parameters whose declared dtype is not FP16 are left untouched.
        """
        import numpy as np

        var_scope = scope if scope else paddle.static.global_scope()
        for block in program.blocks:
            for param in block.all_parameters():
                if param.dtype != core.VarDesc.VarType.FP16:
                    continue
                param_t = var_scope.find_var(param.name).get_tensor()
                data = np.array(param_t)
                param_t.set(np.float16(data), place)

    def fit(self,
            train_data,
            batch_size=1,
            epochs=1,
415
            fetches=None,
416
            steps_per_epoch=None,
417 418
            collate_fn=None,
            use_cache=False,
419
            return_numpy=True):
420 421
        # TODO: callbacks
        # TODO: evaluate after training
422 423 424 425 426 427

        if not self._mode_init_states['train']:
            raise Exception(
                "train program is not initialized yet, please call engine.prepare() before calling fit() funtion."
            )

428
        self.mode = 'train'
429
        assert self.mode in self._dist_main_progs, \
430
            "train model is not ready, please call `engine.prepare()` first."
431
        train_dataloader = self._create_dataloader(train_data, batch_size,
432 433
                                                   epochs, steps_per_epoch,
                                                   collate_fn)
434

435 436
        usr_fetch = self._validate_fetches(fetches)
        fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
437
        fetch_list, fetch_map = self._fetch_map(fetch_loss, usr_fetch)
438 439
        lr_scheduler = self.get_lr_scheduler(self.main_program)

440
        for epoch in range(epochs):
441
            train_logs = {"epoch: {:d} ": epoch}
442
            for step, _ in enumerate(train_dataloader):
443

444 445
                outs = self._executor.run(self.main_program,
                                          fetch_list=fetch_list,
446
                                          use_program_cache=use_cache,
447
                                          return_numpy=return_numpy)
448
                train_logs["step: {:d} "] = step
449 450
                if lr_scheduler is not None:
                    lr_scheduler.step()
451
                    train_logs["lr: {:5e} "] = self._lr_optimizer.get_lr()
452 453
                # inner fetches
                if fetch_loss:
454
                    train_logs["loss: {:9f} "] = outs[0][0]
455 456 457 458
                # user fetches
                user_outs = outs[len(fetch_loss):]
                user_fetch_list = fetch_list[len(fetch_loss):]
                for i, out in enumerate(user_outs):
459 460 461 462
                    train_logs[fetch_map[user_fetch_list[i]] + ": {}"] = out
                # logger
                string = '[train] ' + ''.join(list(train_logs.keys()))
                self._logger.info(string.format(*list(train_logs.values())))
463

464 465 466
    def evaluate(self,
                 eval_data,
                 batch_size=1,
                 fetches=None,
                 collate_fn=None,
                 use_cache=False,
                 return_numpy=True):
        """Run the "eval" program over ``eval_data`` and log loss/metrics.

        The eval mode is prepared lazily on first use. Metric states are
        reset before the first step, so each call reports results for its
        own data only.

        Args:
            eval_data: dataset wrapped by ``_create_dataloader``.
            batch_size: samples per step.
            fetches: extra variables to fetch and log every step, either a
                list or a ``{display_name: var}`` dict.
            collate_fn: optional batch collation function.
            use_cache: passed as ``use_program_cache`` to ``Executor.run``.
            return_numpy: passed to ``Executor.run``.
        """
        self.mode = 'eval'
        if not self._mode_init_states[self.mode]:
            self._prepare_single_mode(self.mode)

        assert self.mode in self._dist_main_progs, \
            "eval model is not ready, please call `engine.prepare()` first."
        eval_dataloader = self._create_dataloader(eval_data,
                                                  batch_size,
                                                  collate_fn=collate_fn)

        usr_fetch = self._validate_fetches(fetches)
        fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
        fetch_metrics = self._validate_fetches(self.fetch_vars["metrics"])
        inner_fetch = dict(fetch_loss, **fetch_metrics)
        fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch)
        num_loss = len(fetch_loss)
        num_inner = len(inner_fetch)

        # Without a reset, accumulate() would also include the samples of a
        # previous evaluate() call on the same metric objects.
        for metric in self._metrics:
            metric.reset()

        for step, _ in enumerate(eval_dataloader):
            eval_logs = {"step: {:d} ": step}
            outs = self._executor.run(self.main_program,
                                      fetch_list=fetch_list,
                                      use_program_cache=use_cache,
                                      return_numpy=return_numpy)
            # inner fetches
            if fetch_loss:
                eval_logs["loss: {:9f} "] = outs[0][0]
            # Metric
            if fetch_metrics:
                metric_out = outs[num_loss:num_inner]
                for metric in self._metrics:
                    metric.update(*metric_out)
                    results = to_list(metric.accumulate())
                    # name() may return a single string; indexing that string
                    # directly would yield single characters, so normalize it
                    # to a list of names first.
                    metric_names = to_list(metric.name())
                    for i, res in enumerate(results):
                        eval_logs[metric_names[i] + ": {:9f} "] = res
            # usr fetches
            usr_outs = outs[num_inner:]
            usr_fetch_list = fetch_list[num_inner:]
            for i, out in enumerate(usr_outs):
                eval_logs[fetch_map[usr_fetch_list[i]] + ": {}"] = out
            # logger
            string = '[eval] ' + ''.join(list(eval_logs.keys()))
            self._logger.info(string.format(*list(eval_logs.values())))

    def predict(self,
                test_data,
                batch_size=1,
                fetches=None,
                collate_fn=None,
                use_cache=False,
                return_numpy=True):
        """Run the "predict" program over ``test_data``.

        The predict mode is prepared lazily on first use, and every fetched
        value is logged at each step.

        Args:
            test_data: dataset wrapped by ``_create_dataloader``.
            batch_size: samples per step.
            fetches: extra variables to fetch, either a list or a
                ``{display_name: var}`` dict.
            collate_fn: optional batch collation function.
            use_cache: passed as ``use_program_cache`` to ``Executor.run``.
            return_numpy: passed to ``Executor.run``.

        Returns:
            list: one entry per step with the fetched model outputs (extra
            user-only fetches are not included).
        """
        self.mode = 'predict'
        if not self._mode_init_states[self.mode]:
            self._prepare_single_mode(self.mode)

        assert self.mode in self._dist_main_progs, \
            "predict model is not ready, please call `engine.prepare()` first."
        test_dataloader = self._create_dataloader(test_data,
                                                  batch_size,
                                                  collate_fn=collate_fn)

        usr_fetch = self._validate_fetches(fetches)
        fetch_outputs = self._validate_fetches(self.fetch_vars["outputs"])
        fetch_list, fetch_map = self._fetch_map(fetch_outputs, usr_fetch)
        num_outputs = len(fetch_outputs)

        collected = []
        for step, _ in enumerate(test_dataloader):
            outs = self._executor.run(self.main_program,
                                      fetch_list=fetch_list,
                                      use_program_cache=use_cache,
                                      return_numpy=return_numpy)
            collected.append(outs[:num_outputs])

            predict_logs = {"step: {:d} ": step}
            for var_name, out in zip(fetch_list, outs):
                predict_logs[fetch_map[var_name] + ": {}"] = out
            # Keys are format templates, values fill them in order.
            template = '[pred] ' + ''.join(predict_logs)
            self._logger.info(template.format(*predict_logs.values()))

        return collected

    def _create_dataloader(self,
                           dataset,
                           batch_size,
                           epochs=1,
                           steps_per_epoch=None,
                           collate_fn=None):
        """Create a ``NonIterableGeneratorLoader`` for the current mode and
        place its ops at the head of this rank's distributed main program.

        The feed list is read from the distributed main program, so the
        loader gets the sharded variable shapes. Because the predict program
        does not contain label vars, only feed vars present in that program
        are used.

        Returns:
            The created ``NonIterableGeneratorLoader``.
        """
        main_prog = self._dist_main_progs[self.mode][self._cur_rank]
        startup_prog = self._dist_startup_progs[self.mode][self._cur_rank]
        dist_context = self._dist_contexts[self.mode]
        main_block = main_prog.global_block()

        mode_feeds = self._feed_vars[self.mode]
        block_vars = main_block.vars
        feed_list = [
            block_vars[feed.name]
            for feed in mode_feeds["inputs"] + mode_feeds["labels"]
            if feed.name in block_vars
        ]

        # remove the first three ops if multi run fit/evaluate/predict
        num_orig_ops = len(main_block.ops)
        if main_block.ops[0].type == 'create_py_reader':
            num_orig_ops -= 3
            for _ in range(3):
                main_block._remove_op(0, sync=False)

        # The loader appends its ops at the end of the program.
        places = paddle.static.cuda_places()
        with static.program_guard(main_prog, startup_prog):
            dataloader = NonIterableGeneratorLoader(
                dataset,
                feed_list,
                places,
                batch_size,
                epochs,
                steps_per_epoch,
                collate_fn,
                data_parallel_world_size=self._input_split_size,
                data_parallel_rank=self._input_split_rank)

        # Move the appended ops to the start of the program. Every prepend
        # shifts the block by one, so the fixed index ``num_total_ops - 1``
        # visits the appended ops from last to first; prepending them in
        # that order keeps their relative order.
        num_total_ops = len(main_block.ops)
        num_new_ops = max(0, num_total_ops - num_orig_ops)
        for _ in range(num_new_ops):
            src_op = main_block.ops[num_total_ops - 1]
            op_desc = main_block.desc._prepend_op()
            op_desc.copy_from(src_op.desc)
            prepended_op = Operator(main_block,
                                    op_desc,
                                    type=op_desc.type())
            main_block.ops.insert(0, prepended_op)
            dist_context.add_dist_op_for_program(
                DistributedOperator(prepended_op))
        # The original copies now start at index ``num_total_ops``.
        for _ in range(num_new_ops):
            main_block._remove_op(num_total_ops, sync=False)
        main_block._sync_with_cpp()
        return dataloader

    def _validate_spec(self, specs):
        """Normalize ``specs`` with ``to_list`` and check each entry.

        Returns:
            The normalized specs; a ``None`` result of ``to_list`` is
            returned as is.

        Raises:
            AssertionError: an entry is not an ``InputSpec``.
            ValueError: an entry has no name.
        """
        spec_list = to_list(specs)
        if spec_list is None:
            return spec_list
        for idx, spec in enumerate(spec_list):
            assert isinstance(spec, InputSpec)
            if spec.name is None:
                raise ValueError(
                    "Requires Input[{}].name != None, but receive `None` with {}."
                    .format(idx, spec))
        return spec_list

    def _is_local_var(self, var):
        """Return whether the name of ``var`` (resolved with ``_to_name_str``)
        is a variable of the global block of ``self.main_program``."""
        local_vars = self.main_program.global_block().vars
        return _to_name_str(var) in local_vars

    def _validate_fetches(self, fetches):
        # 1. Check user-defined fetches type
        # 2. Prepare fetches_dict like {user_defined_name: var_name}
        if not fetches:
            return {}
        if isinstance(fetches, dict):
            fetch_var_names = list(map(_to_name_str, fetches.values()))
            fetches_dict = dict(zip(fetch_var_names, list(fetches.keys())))
        elif isinstance(fetches, list):
            fetch_var_names = list(map(_to_name_str, fetches))
            fetches_dict = dict(zip(fetch_var_names, fetch_var_names))
635
        else:
636 637 638 639 640 641 642 643 644 645 646 647 648
            raise TypeError("'fetches' only support 'dict' and 'list', "
                            "but got '{}'".format(str(type(fetches))))
        return dict(
            filter(lambda x: self._is_local_var(x[0]), fetches_dict.items()))

    def _fetch_map(self, inner_fetch, usr_fetch):
        # replace inner fetch name if usr set for it
        for iname in inner_fetch:
            if iname in usr_fetch:
                inner_fetch[iname] = usr_fetch[iname]
                usr_fetch.pop(iname)
        fetches = dict(inner_fetch, **usr_fetch)
        return list(fetches.keys()), fetches
649

650 651
    def _get_input_split_info(self, var, dist_context):
        """Deduce how the input data ``var`` is split among the cluster.

        Returns:
            ``(split_size, split_rank)``: the size of the communication group
            along the mesh axis the first tensor dim is mapped to, and this
            rank's index in that group. Returns ``(None, None)`` when that
            dim is unmapped or mapped to a mesh axis of size 1.
        """
        from .utils import _get_comm_group, _get_corresponding_rank

        dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
        mesh = dist_attr.process_mesh
        batch_axis = dist_attr.dims_mapping[0]

        rank_id = self._cur_rank
        if rank_id not in mesh.processes:
            # presumably maps this rank onto its counterpart inside ``mesh``
            rank_id = _get_corresponding_rank(dist_context, mesh,
                                              self._cur_rank)

        if batch_axis <= -1 or mesh.topology[batch_axis] <= 1:
            return None, None
        group_ranks = _get_comm_group(mesh.processes, mesh.topology,
                                      batch_axis, rank_id)
        return len(group_ranks), group_ranks.index(rank_id)

    def _set_recompute_ckpts(self):
        """Resolve recompute checkpoints and write them into the strategy.

        NOTE hack to enable recompute in engine api for GPT-3.
        TODO support more PaddleNLP/CV models here.

        For a ``paddle.nn.Layer`` named ``GPTForPretraining`` that has a
        ``gpt`` attribute, checkpoints come from ``self.model.gpt.checkpoints``.
        Every other model uses
        ``self.strategy.recompute_configs["checkpoints"]``. When
        ``self.strategy.recompute`` is enabled, a copy of the resolved
        checkpoints is stored back into ``self.strategy.recompute_configs``
        and the choice is logged.
        """
        config = self.strategy.recompute_configs
        model = self.model

        # extract ckpts by specific model; any other model (including a
        # non-GPT paddle.nn.Layer) falls back to the configured checkpoints so
        # that exact_ckpts is always bound.
        if isinstance(model, paddle.nn.Layer) and hasattr(model, "gpt") \
                and model.__class__.__name__ == 'GPTForPretraining':
            exact_ckpts = model.gpt.checkpoints
        else:
            exact_ckpts = config["checkpoints"]

        # modify strategy
        if self.strategy.recompute:
            config["checkpoints"] = exact_ckpts[:]
            self.strategy.recompute_configs = config
            logs = {
                'Model Class': model.__class__.__name__,
                'Applied Recompute ckpts': exact_ckpts
            }
            self._logger.info(logs)

    def save(self, path, training=True, mode=None):
        """Save the model to ``path``.

        With ``training=True`` (the default), the serial and current-rank
        distributed train programs and the train distributed context are
        passed to ``self._saver.save``. Otherwise an inference model for
        ``mode`` (defaulting to ``self.mode``) is exported through
        ``self._saver.save_inference_model``, using that mode's ``'inputs'``
        feed vars, ``'outputs'`` fetch vars and current-rank distributed
        main program.
        """
        mode = mode or self.mode

        if not training:
            assert mode, "Please set the 'mode' you want to save."
            inputs = self._feed_vars[mode]['inputs']
            outputs = self._fetch_vars[mode]['outputs']
            program = self._dist_main_progs[mode][self._cur_rank]
            self._saver.save_inference_model(path,
                                             inputs,
                                             outputs,
                                             self._executor,
                                             program=program)
            return

        assert 'train' in self._serial_main_progs, \
            "training model is not ready, please call `engine.prepare()` first."
        self._saver.save(
            path,
            serial_program=self._serial_main_progs["train"],
            dist_main_program=self._dist_main_progs["train"][self._cur_rank],
            dist_context=self._dist_contexts["train"])

    def load(self, path, strict=True, load_optimizer=True, mode=None):
        """Load a checkpoint from ``path`` for ``mode``.

        ``mode`` defaults to ``self.mode``. The current rank's distributed
        main program and the mode's distributed context are forwarded,
        together with ``strict`` and ``load_optimizer``, to
        ``self._saver.load``.
        """
        mode = mode or self.mode
        assert mode, "Please set the 'mode' you want to load."
        self._saver.load(path, self._dist_main_progs[mode][self._cur_rank],
                         self._dist_contexts[mode], strict, load_optimizer)

    @staticmethod
    def get_lr_scheduler(program):
        """Return the LR scheduler attached to ``program``, or ``None``.

        The scheduler is read from the ``lr_sheduler`` attribute (spelled
        exactly that way) and must be a
        ``paddle.optimizer.lr.LRScheduler``.
        """
        if not hasattr(program, 'lr_sheduler'):
            return None
        from paddle.optimizer.lr import LRScheduler
        scheduler = program.lr_sheduler
        assert isinstance(scheduler, LRScheduler), "must be LRScheduler"
        return scheduler

    @property
    def mode(self):
        """Currently active mode.

        It is the key the other per-mode properties use to pick their
        program, context or fetch vars.
        """
        return self._mode

    @mode.setter
    def mode(self, mode):
        # Stored as-is; no validation is performed here.
        self._mode = mode

    @property
    def main_program(self):
        """Current rank's distributed main program for ``self.mode``."""
        progs_by_rank = self._dist_main_progs[self.mode]
        return progs_by_rank[self._cur_rank]

    @property
    def startup_program(self):
        """Current rank's distributed startup program for ``self.mode``."""
        progs_by_rank = self._dist_startup_progs[self.mode]
        return progs_by_rank[self._cur_rank]

    @property
    def dist_context(self):
        """Distributed context recorded for ``self.mode``."""
        contexts = self._dist_contexts
        return contexts[self.mode]

    @property
    def serial_main_program(self):
        """Serial (non-partitioned) main program for ``self.mode``."""
        programs = self._serial_main_progs
        return programs[self.mode]

    @property
    def serial_startup_program(self):
        """Serial (non-partitioned) startup program for ``self.mode``."""
        programs = self._serial_startup_progs
        return programs[self.mode]

    @property
    def fetch_vars(self):
        """Fetch variables registered for ``self.mode``."""
        registry = self._fetch_vars
        return registry[self.mode]