# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import logging
import random
import numpy as np
from collections import defaultdict

import paddle
import paddle.utils as utils
from paddle import fluid, profiler, static
from paddle.metric import Metric
from paddle.static import InputSpec
from paddle.fluid import core
from paddle.fluid import Variable
from paddle.fluid.layers.utils import flatten
from paddle.fluid.executor import global_scope, _to_name_str
from paddle.fluid.framework import Operator, _non_static_mode
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed import fleet

from .converter import Converter
from .helper import ProgramHelper
from .cluster import Cluster, get_default_cluster
from .planner_v2 import Planner
from .parallelizer_v2 import Parallelizer
from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver
from .dist_loader import NonIterableGeneratorLoader
from .utils import to_list
from .utils import get_logger, get_dist_attr
from .process_group import new_process_group, get_all_process_groups
from .dist_context import DistributedContext, get_default_distributed_context
from .strategy import Strategy
from .interface import CollectionNames, get_collection


class Engine:
    """
    An Engine object provides users with the full power of auto parallel.
    With its help, users can easily obtain the abilities of distributed
    training and inference. It supports both dynamic graph and static
    graph modes.

    Args:
        model (paddle.nn.Layer, optional): The model is an instance of
            paddle.nn.Layer.
        loss (Loss|Callable|None, optional): The loss can be a `paddle.nn.Layer`
            instance or any callable function that takes the predicted values and
            ground truth values as input. It can be None when there is no loss.
            Default: None.
        optimizer (Optimizer|None, optional): The optimizer needs to be set in training
            and should be None in eval and predict mode. Default: None.
        metrics (Metric|list[Metric]|None, optional): If metrics is set, all
            metrics will be calculated and output in train/eval mode. Default: None.
        cluster (Cluster|None, optional): The cluster represents the topology information
            about the used physical devices. Default: None. (Unused for now)
        strategy (Strategy|None, optional): The strategy is used to configure the
            parallelization and optimization behaviors. Default: None.

    Examples:

        .. code-block:: python

            import paddle
            import paddle.vision.transforms as T
            from paddle.distributed.fleet import auto
            from paddle.vision.datasets import MNIST

            transform = T.Compose([
                T.Transpose(),
                T.Normalize([127.5], [127.5])
            ])
            train_dataset = MNIST(mode='train', transform=transform)
            valid_dataset = MNIST(mode='test', transform=transform)

            model = paddle.vision.models.LeNet()
            loss = paddle.nn.CrossEntropyLoss()
            optimizer = paddle.optimizer.Adam(
                learning_rate=0.001, parameters=model.parameters())
            metrics = paddle.metric.Accuracy(topk=(1, 2))

            engine = auto.Engine(model, loss, optimizer, metrics)
            # fit
            engine.fit(train_dataset,
                       epochs=2,
                       batch_size=64)
            # evaluate
            engine.evaluate(valid_dataset,
                            batch_size=64)
            # predict
            engine.predict(valid_dataset,
                           batch_size=64)
            # save
            engine.save("./my_model")
            # load
            engine.load("./my_model")

    """

    def __init__(self,
                 model=None,
                 loss=None,
                 optimizer=None,
                 metrics=None,
                 cluster=None,
                 strategy=None):

        if model and not isinstance(model,
                                    paddle.nn.Layer) and not callable(model):
            raise TypeError(
                "'model' must be a subclass of `paddle.nn.Layer` or any callable function."
            )
        self._model = model

        if loss and not isinstance(loss,
                                   paddle.nn.Layer) and not callable(loss):
            raise TypeError(
                "'loss' must be a subclass of `paddle.nn.Layer` or any callable function."
            )
        self._loss = loss

        if optimizer and not isinstance(
                optimizer,
            (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)):
            raise TypeError(
                "'optimizer' must be an object of class `paddle.optimizer.Optimizer`"
                " or `paddle.fluid.optimizer.Optimizer`.")
        self._optimizer = self._validate_opt(optimizer)

        metrics = metrics or []
        for metric in to_list(metrics):
            assert isinstance(metric, Metric), \
                "{} is not a subclass of Metric".format(
                    metric.__class__.__name__)
        self._metrics = to_list(metrics)

        if cluster and not isinstance(cluster, Cluster):
            raise TypeError(
                "'cluster' must be an object of class `paddle.distributed.auto_parallel.Cluster`"
            )
        self._cluster = cluster or get_default_cluster()

        if strategy and not isinstance(strategy, Strategy):
            raise TypeError(
                "'strategy' must be an object of class `paddle.distributed.auto_parallel.Strategy`"
            )
        self._strategy = strategy or Strategy()

        if os.getenv("POD_NAME"):
            print("Distributed training launched by paddle.distributed.launch",
                  flush=True)
            fleet.init(is_collective=True)
        self._executor = None
        self._cur_rank = paddle.distributed.get_rank()
        self._nranks = paddle.distributed.get_world_size()
        self._saver = DistributedSaver()

        self._logger = get_logger(logging.INFO)

        self._orig_main_prog = static.default_main_program()
        self._orig_startup_prog = static.default_startup_program()
        self._orig_dist_context = get_default_distributed_context()
        self._dist_contexts = {}
        self._serial_main_progs = {}
        self._serial_startup_progs = {}
        self._dist_main_progs = defaultdict(dict)  # dist main programs
        self._dist_startup_progs = defaultdict(dict)  # dist startup programs
        self._feed_vars = {}
        self._fetch_vars = {}
        self._planners = {}
        self._mode_init_states = {
            "train": False,
            "eval": False,
            "predict": False
        }

        self._planned_mode = None
        self._dygraph_mode = False
        self._tuning = self._strategy.tuning

    def _prepare_program(self, mode):
        # Do the build process
        self._build(mode)
        # Do the planning process
        self._plan(mode)
        # Do the parallel process
        self._parallel(mode)
        # Init comm and startup program
        self._initialize(mode)
        self._mode_init_states[mode] = True

    def _prepare_feed(self, user_feeds=None, mode="train"):
        if user_feeds is not None:
            assert isinstance(user_feeds, dict), \
                "user_feeds must be a dict, but received {}".format(type(user_feeds).__name__)
        feeds = {}
        # TODO: add inputs and labels feed dict
        if user_feeds is not None:
            for name, var in user_feeds.items():
                feeds[name] = var
        return feeds

    def _prepare_fetch(self, user_fetches=None, mode="train"):
        if user_fetches is not None:
            assert isinstance(user_fetches, list), \
                "user_fetches must be a list, but received {}".format(type(user_fetches).__name__)
        fetch_names = []
        fetch_indices = []

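        # fetch_names is a deduplicated flat list of variable names handed to
        # executor.run(); fetch_indices records, for each logical group (loss,
        # each metric, outputs, user fetches), the positions of its variables
        # within fetch_names so results can be regrouped after the run.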
        def _process_fetch_group(group_name, var_list):
            group_indices = []
            for var in var_list:
                # Remove duplicate var_names
                if self._is_local_var(var):
                    var_name = _to_name_str(var)
                    if var_name not in fetch_names:
                        fetch_names.append(var_name)
                    group_indices.append(fetch_names.index(var_name))
            fetch_indices.append(group_indices)

        if mode != "predict":
            _process_fetch_group("loss", self._fetch_vars[mode]["loss"])
            metrics = self._fetch_vars[mode]["metrics"]
            for i, var_list in enumerate(metrics):
                _process_fetch_group("metrics_" + str(i), var_list)
        if mode == "predict":
            _process_fetch_group("outputs", self._fetch_vars[mode]["outputs"])
        user_fetches_collection = [
            item[1] for item in get_collection(CollectionNames.FETCHES)
        ]
        var_list = (user_fetches_collection or []) + (user_fetches or [])
        _process_fetch_group("fetches", var_list)
        return fetch_names, fetch_indices

    def _prepare_logger(self,
                        outs,
                        mode="train",
                        epoch=None,
                        step=None,
                        lr=None,
                        fetch_names=None,
                        fetch_indices=None,
                        profiler_log=""):
        logs = "[{}] ".format(mode)
        if epoch is not None:
            logs += "epoch: {:d} ".format(epoch)
        if step is not None:
            logs += "step: {:d} ".format(step)
        if lr is not None:
            logs += "lr: {:5e} ".format(lr)
        group_idx = 0
        # logging loss
        if mode != "predict":
            loss_indices = fetch_indices[group_idx]
            for idx in loss_indices:
                logs += "loss: {:8f} ".format(outs[idx][0])
            group_idx += 1
        # logging metrics
        if mode != "predict":
            for metric in self._metrics:
                metrics_indices = fetch_indices[group_idx]
                metric_out = []
                for idx in metrics_indices:
                    metric_out.append(outs[idx])
                if metric_out:
                    metric.update(*metric_out)
                    results = metric.accumulate()
                    for i, res in enumerate(to_list(results)):
                        logs += "{}: {:8f} ".format(metric.name()[i], res)
                group_idx += 1
        # Skip logging outputs
        if mode == "predict":
            group_idx += 1
        # logging user fetches
        fetches_logging = get_collection(CollectionNames.LOGGING)
        for name, var in fetches_logging:
            if var.name in fetch_names:
                idx = fetch_names.index(var.name)
                # Use the user defined name for logging
                logs += "{}: {} ".format(name, outs[idx])
        logs += profiler_log
        self._logger.info(logs)

    def _prepare_history(self, outs, mode="train", fetch_indices=None):
        history = {}
        group_idx = 0
        # store loss
        if mode != "predict":
            loss_indices = fetch_indices[group_idx]
            loss_values = []
            for idx in loss_indices:
                loss_values.append(outs[idx][0])
            history["loss"] = loss_values
            group_idx += 1
        # store metrics
        if mode != "predict":
            for metric in self._metrics:
                metrics_indices = fetch_indices[group_idx]
                metric_out = []
                for idx in metrics_indices:
                    metric_out.append(outs[idx])
                if metric_out:
                    metric.update(*metric_out)
                    results = metric.accumulate()
                    history[tuple(metric.name())] = to_list(results)
                group_idx += 1
        # store outputs
        if mode == "predict":
            outputs_indices = fetch_indices[group_idx]
            outputs_values = []
            for idx in outputs_indices:
                outputs_values.append(outs[idx])
            history["outputs"] = outputs_values
            group_idx += 1
        # store user fetches
        fetches_indices = fetch_indices[group_idx]
        fetches_values = []
        for idx in fetches_indices:
            fetches_values.append(outs[idx])
        history["fetches"] = fetches_values
        return history

    def _build(self, mode):
        if _non_static_mode() or self._dygraph_mode:
            paddle.disable_static()
            self._dygraph_mode = True
            self._logger.info("Building model with 'to_static' method.")

            inputs_spec = self.inputs_spec
            labels_spec = self.labels_spec if self.labels_spec else []
            self.program_helper = ProgramHelper(self._model, self._loss,
                                                self._metrics, inputs_spec,
                                                labels_spec)
            # build forward main program
            self.program_helper.build_program(mode)

            self.concrete_program = self.program_helper.concrete_program
            serial_main_prog = self.program_helper.main_program
            serial_startup_prog = self.program_helper.startup_program

            inputs = self.program_helper.input_vars
            outputs = self.program_helper.output_vars
            labels = self.program_helper.label_vars
            losses = self.program_helper.loss_vars
            metrics = self.program_helper.metric_vars

            paddle.enable_static()
        else:
            # build program in static mode
            serial_main_prog = self._serial_main_progs.get(mode, None)
            if serial_main_prog is not None:
                return

            losses = []
            metrics = []
            serial_main_prog = self._orig_main_prog.clone()
            serial_startup_prog = self._orig_startup_prog.clone()
            with static.program_guard(serial_main_prog, serial_startup_prog), \
                utils.unique_name.guard():
                inputs_spec = self.inputs_spec
                labels_spec = self.labels_spec if self.labels_spec else []
                inputs = [s._create_feed_layer() for s in inputs_spec]
                labels = [s._create_feed_layer() for s in labels_spec]
                outputs = to_list(self._model(*inputs))
                if mode != "predict" and self._loss:
                    losses = to_list(self._loss(*(outputs + labels)))

                if mode != "predict":
                    for metric in self._metrics:
                        metrics.append(
                            to_list(metric.compute(*(outputs + labels))))

        default_ctx = get_default_distributed_context()
        if not default_ctx.has_annotation:
            # We build the world process group because data parallelism
            # needs all ranks by default.
            new_process_group(list(range(self._nranks)))
            default_ctx.data_parallel = True

        feed_vars = {"inputs": inputs, "labels": labels}

        fetch_vars = {
            "outputs": flatten(outputs),
            "loss": losses,
            "metrics": metrics
        }

        if mode != "train":
            serial_main_prog = serial_main_prog.clone(for_test=True)

        self._set_recompute_ckpts()
        self._dist_contexts[mode] = DistributedContext(
            serial_main_prog, serial_startup_prog, self._optimizer, losses,
            feed_vars, fetch_vars, self._cluster, self._strategy)
        self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale

    def _optimization_tuning(self, mode, dataset, batch_size):
        if not self._tuning.enable:
            raise ValueError("Please set `tuning.enable=True`.")

        assert mode == "train"
        # Do the build process
        self._build(mode)
        # Do the planning process
        self._plan(mode)

        dataset.dp_world_size = self._dp_world_sizes
        dataset.dp_rank = self._dp_ranks

        from .tuner.optimization_tuner import OptimizationTuner
        self._optimization_tuner = OptimizationTuner(self._tuning.to_dict(),
                                                     self._dist_contexts[mode],
                                                     dataset,
                                                     self.inputs_spec,
                                                     self.labels_spec,
                                                     batch_size=batch_size,
                                                     rank=self._cur_rank)

        self._optimization_tuner.tune()

        if self._tuning.run_after_tuning:
            # update the strategy
            self._dist_contexts[
                mode]._strategy = self._optimization_tuner.get_best_config()

    def _plan(self, mode):
        if self._planned_mode is None:
            self._planned_mode = mode
        else:
            self._init_dist_context(mode)

        self._planners[mode] = Planner(mode, self._dist_contexts[mode])
        self._planners[mode].plan()

        # infer data parallel info
        inputs_var = self._dist_contexts[mode].serial_feed_vars["inputs"]
        labels_var = self._dist_contexts[mode].serial_feed_vars["labels"]
        block = self._dist_contexts[mode].serial_main_program.global_block()
        feed_list = []
        for var in inputs_var + labels_var:
            if var.name in block.vars:
                feed_list.append(block.vars[var.name])

        self._dp_world_sizes = []
        self._dp_ranks = []
        for feed_var in feed_list:
            dp_world_size, dp_rank = self._get_input_split_info(
                feed_var, self._dist_contexts[mode])
            self._dp_world_sizes.append(dp_world_size)
            self._dp_ranks.append(dp_rank)

    def _parallel(self, mode, all_ranks=False):
        # Parallelize the program based on the planner's results.
        # For now, the completer has to be passed to the planner because we may
        # use it to complete the annotation of the backward and update phases.
        parallelizer = Parallelizer(mode, self._planners[mode].completer,
                                    self._dist_contexts[mode])
        if not all_ranks:
            parallelizer.parallel(self._cur_rank)
        else:
            parallelizer.parallel_all()

    def _init_dist_context(self, mode):
        # Init dist_context['mode'] with the first planned dist_context
        # to guarantee that train/eval/predict modes have the same parallel strategy.
        dist_context = self._dist_contexts[mode]
        origin_main_prog = dist_context._original_serial_main_program
        ref_mode = self._planned_mode
        ref_dist_context = self._dist_contexts[ref_mode]
        ref_origin_main_prog = ref_dist_context._original_serial_main_program
        ref_blocks = ref_origin_main_prog.blocks
        for ib, block in enumerate(origin_main_prog.blocks):
            for iop, op in enumerate(block.ops):
                ref_op = ref_blocks[ib].ops[iop]
                assert op.type == ref_op.type, \
                    "'{}' mode op '{}' differs from '{}' mode op '{}'.".format(mode, op.type, ref_mode, ref_op.type)
                ref_op_dist_attr = ref_dist_context.get_op_dist_attr_for_program(
                    ref_op)
                dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr)

    def _initialize(self, mode):
        # Get the current content from the distributed context
        self._serial_main_progs[mode] = self._dist_contexts[
            mode].serial_main_program
        self._serial_startup_progs[mode] = self._dist_contexts[
            mode].serial_startup_program
        self._dist_main_progs[mode] = self._dist_contexts[
            mode].dist_main_programs
        self._dist_startup_progs[mode] = self._dist_contexts[
            mode].dist_startup_programs
        self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars
        self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars
        self._lr_optimizer = self._dist_contexts[mode]._lr_optimizer

        if self._nranks > 1:
            # Traverse the programs of different ranks and each of their ops,
            # and instantiate the communication groups by process_mapping.
            all_process_groups = get_all_process_groups()

            # NOTE: add the comm init control in the future for auto search
            for process_group in all_process_groups:
                if self._cur_rank not in process_group.ranks:
                    continue
                process_group.instantiate()

        place = _get_device()
        if isinstance(place, fluid.CUDAPlace):
            place = fluid.CUDAPlace(ParallelEnv().dev_id)

        if self._strategy.seed:
            paddle.seed(self._strategy.seed + self._dp_ranks[0])
            np.random.seed(self._strategy.seed + self._dp_ranks[0])
            random.seed(self._strategy.seed + self._dp_ranks[0])

        if self._dygraph_mode:
            dist_context = self._dist_contexts[mode]
            dist_main_program = self._dist_main_progs[mode][self._cur_rank]
            self.program_helper.init(dist_main_program, place, dist_context)

        if self._executor is None:
            self._executor = paddle.static.Executor(place)
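            # Only run startup ops for vars not yet initialized in the global
            # scope, so repeated initialization does not overwrite parameters
            # that are already set (e.g. loaded from a checkpoint).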
            uninitialized = []
            dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
            for var in dist_startup_prog.list_vars():
                scope_var = global_scope().find_var(var.name)
                if scope_var and scope_var.get_tensor()._is_initialized():
                    continue
                uninitialized.append(var)
            if uninitialized:
                prune_startup_prog = dist_startup_prog._prune(uninitialized)
                self._executor.run(prune_startup_prog)

            if hasattr(self, "_state_dict") and hasattr(self, "_dist_attr"):
                self._set_state_dict(mode, self._strict, self._state_dict,
                                     self._dist_attr)

        if self._strategy.reinit:
            self._logger.info("NOTE: parameters wiil be re-initialized.")
            dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
            self._executor.run(dist_startup_prog)

    def _split_sample_item(self, data, split):
        if isinstance(data, paddle.io.IterableDataset):
            if split is None:
                inputs, labels = next(iter(data))
            else:
                sample = next(iter(data))
                inputs = sample[:split]
                labels = sample[split:]
        elif isinstance(data, paddle.io.Dataset):
            if split is None:
                inputs, labels = data[0]
            else:
                sample = data[0]
                inputs = sample[:split]
                labels = sample[split:]
        else:
            raise ValueError(
                "Data should be a Dataset or IterableDataset, but received {}.".
                format(type(data).__name__))
        inputs = to_list(inputs)
        labels = to_list(labels)
        return inputs, labels

    def _infer_sample_spec(self, inputs, labels, batch_size):
        self.inputs_spec = []
        self.labels_spec = []

        def _infer_item_spec(item, name, batch_size, specs):
            if isinstance(item, np.ndarray):
                spec = InputSpec.from_numpy(item, name)
                if batch_size is None:
                    specs.append(spec)
                else:
                    specs.append(spec.batch(batch_size))
            elif isinstance(item, (Variable, core.VarBase, core.eager.Tensor)):
                spec = InputSpec.from_tensor(item, name)
                if batch_size is None:
                    specs.append(spec)
                else:
                    specs.append(spec.batch(batch_size))
            else:
                specs.append(InputSpec([batch_size], type(item), name))

        if inputs is not None:
            for i, item in enumerate(inputs):
                assert item is not None, "Received None input."
                name = "input" + str(i)
                _infer_item_spec(item, name, batch_size, self.inputs_spec)
        if labels is not None:
            for i, item in enumerate(labels):
                assert item is not None, "Received None label."
                name = "label" + str(i)
                _infer_item_spec(item, name, batch_size, self.labels_spec)

        self.inputs_spec = self._validate_spec(self.inputs_spec)
        self.labels_spec = self._validate_spec(self.labels_spec)

    def __call__(self,
                 inputs=None,
                 labels=None,
                 feeds=None,
                 fetches=None,
                 mode="train"):
        feed_dict = self._prepare_feed(feeds, mode)
        fetch_names, fetch_indices = self._prepare_fetch(fetches, mode)
        try:
            outs = self._executor.run(
                self.main_program,
                feed=feed_dict,
                fetch_list=fetch_names,
                use_program_cache=self._strategy.use_cache,
                return_numpy=self._strategy.return_numpy)
        except core.EOFException:
            pass
        self._prepare_logger(outs, self.mode, None, None, None, fetch_names,
                             fetch_indices)
        history = self._prepare_history(outs, self.mode, fetch_indices)
        return history

    def fit(self,
            train_data,
            train_sample_split=None,
            batch_size=1,
            epochs=1,
            steps_per_epoch=None,
            valid_data=None,
            valid_sample_split=None,
            valid_freq=1,
            valid_steps=None,
            collate_fn=None,
            callbacks=None):
        """
        Trains the model for a fixed number of epochs. If `valid_data` is set,
        evaluation will be done at the end of each epoch.

        Args:
            train_data (Dataset): An instance of paddle.io.Dataset. Default: None.
            train_sample_split (int, optional): Each sample of the train dataset is assumed
                to be a (input, label) pair by default and has two items. If each sample has
                more than two items, train_sample_split specifies how to split these items into
                input and label. The items before it are input and the rest are label. Default: None.
            batch_size (int, optional): The batch size of train_data and valid_data if provided.
                The user's data will be used directly without batching if set to None. Default: 1.
            epochs (int, optional): The number of epochs to train the model. Default: 1.
            steps_per_epoch (int, optional): The total number of steps (batches of samples)
                executed in one epoch before starting the next one. If None, it is equal to
                the number of samples in your dataset divided by the batch size. Default: None.
            valid_data (Dataset, optional): An instance of paddle.io.Dataset used for
                evaluation at the end of each epoch. No evaluation will be done if set to None.
                Default: None. (Unsupported for now)
            valid_sample_split (int, optional): Only relevant if valid_data is provided.
                Each sample of the valid dataset is assumed to be a (input, label) pair
                by default and has two items. If each sample has more than two items,
                valid_sample_split specifies how to split these items into input and label.
                The items before it are input and the rest are label. Default: None.
            valid_freq (int, optional): Only relevant if valid_data is provided. This specifies
                how many training epochs to run before a new evaluation is performed. Default: 1.
            valid_steps (int, optional): Only relevant if valid_data is provided.
                It is the total number of steps (batches of samples) to draw before
                stopping validation at the end of every epoch. If None, validation will run until the
                `valid_data` dataset is exhausted. The validation will start from the
                beginning of the dataset at each epoch. Default: None.
            collate_fn (callable, optional): Function to generate mini-batch data by merging
                the sample list; None to only stack each field of the samples along axis
                0. Default: None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during training. Default: None. (Unused for now)

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                train_dataset = MNIST(mode='train', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                optimizer = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=model.parameters())
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, optimizer, metrics)
                engine.fit(train_dataset,
                           epochs=2,
                           batch_size=64)
        """
        self.mode = 'train'
        inputs, labels = self._split_sample_item(train_data, train_sample_split)
        self._infer_sample_spec(inputs, labels, batch_size)
        if not self._mode_init_states[self.mode]:
            self._prepare_program(self.mode)
        else:
            self._switch_mode("train")

        assert self.mode in self._dist_main_progs, \
            "train model is not ready, please call `engine._prepare_program('train')` first."
        train_dataloader = self._prepare_dataloader(train_data, batch_size,
                                                    epochs, steps_per_epoch,
                                                    collate_fn)

        fetch_names, fetch_indices = self._prepare_fetch(mode=self.mode)
        lr_scheduler = self._get_lr_scheduler(self.main_program)

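        # timer_only=True collects step timing without full profiling overhead;
        # prof.step_info() below reports the measured throughput.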
        with profiler.Profiler(timer_only=True) as prof:
            for epoch in range(epochs):
                for step, _ in enumerate(train_dataloader):
                    try:
                        outs = self._executor.run(
                            self.main_program,
                            fetch_list=fetch_names,
                            use_program_cache=self._strategy.use_cache,
                            return_numpy=self._strategy.return_numpy)
                    except core.EOFException:
                        break
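                    # With gradient merge, the optimizer applies one update per
                    # k_steps micro-batches, so the lr schedule advances at the
                    # same cadence.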
                    if lr_scheduler and step % self._k_steps == 0:
                        lr_scheduler.step()
                    lr = self._get_lr(self._lr_optimizer)

                    prof.step()

                    self._prepare_logger(outs, self.mode, epoch, step, lr,
                                         fetch_names, fetch_indices,
                                         prof.step_info())
                    history = self._prepare_history(outs, self.mode,
                                                    fetch_indices)

                if valid_data and epoch % valid_freq == 0:
                    self.evaluate(valid_data, valid_sample_split, batch_size,
                                  valid_steps, collate_fn, callbacks)
                    self._switch_mode("train")
                else:
                    self._reset_metrics()
            return history

    def evaluate(self,
                 valid_data,
                 valid_sample_split=None,
                 batch_size=1,
                 steps=None,
                 collate_fn=None,
                 callbacks=None):
        """
        Evaluate the loss and metrics of the model on evaluation data.

        Args:
            valid_data (Dataset): An instance of paddle.io.Dataset. Default: None.
            valid_sample_split (int, optional): Each sample of the eval dataset is assumed
                to be a (input, label) pair by default and has two items. If each sample has
                more than two items, valid_sample_split specifies how to split these items into
                input and label. The items before it are input and the rest are label. Default: None.
            batch_size (int, optional): The batch size of valid_data. The user's data will
                be used directly without batching if set to None. Default: 1.
            steps (int, optional): The total number of steps (batches of samples) to draw before
                stopping evaluation. If None, evaluation will run until the `valid_data` dataset is exhausted.
                The evaluation will start from the beginning of the dataset in each run. Default: None.
            collate_fn (callable, optional): Function to generate mini-batch data by merging
                the sample list; None to only stack each field of the samples along axis
                0. Default: None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during evaluating. Default: None. (Unused for now)

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                valid_dataset = MNIST(mode='test', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, metrics=metrics)
                engine.evaluate(valid_dataset, batch_size=64)

        """
        self.mode = 'eval'
        inputs, labels = self._split_sample_item(valid_data, valid_sample_split)
        self._infer_sample_spec(inputs, labels, batch_size)
        if not self._mode_init_states[self.mode]:
            self._prepare_program(self.mode)
        else:
            self._switch_mode("eval")

        assert self.mode in self._dist_main_progs, \
            "eval model is not ready, please call `engine._prepare_program('eval')` first."
        valid_dataloader = self._prepare_dataloader(valid_data,
                                                    batch_size,
                                                    steps_per_epoch=steps,
                                                    collate_fn=collate_fn)

        fetch_names, fetch_indices = self._prepare_fetch(mode=self.mode)

        for step, _ in enumerate(valid_dataloader):
            try:
                outs = self._executor.run(
                    self.main_program,
                    fetch_list=fetch_names,
                    use_program_cache=self._strategy.use_cache,
                    return_numpy=self._strategy.return_numpy)
            except core.EOFException:
                break
            self._prepare_logger(outs, self.mode, None, step, None, fetch_names,
                                 fetch_indices)
            history = self._prepare_history(outs, self.mode, fetch_indices)
        self._reset_metrics()
        return history

    def predict(self,
                test_data,
                test_sample_split=None,
                batch_size=1,
                steps=None,
                collate_fn=None,
                callbacks=None):
        """
        Compute the output predictions on testing data.

        Args:
            test_data (Dataset): An instance of paddle.io.Dataset. Default: None.
            test_sample_split (int, optional): Each sample of the test dataset is assumed
                to be a (input, label) pair by default and has two items. If each sample has
                more than two items, test_sample_split specifies how to split these items into
                input and label. The items before it are input and the rest are label. Default: None.
            batch_size (int, optional): The batch size of test_data. The user's data will
                be used directly without batching if set to None. Default: 1.
            steps (int, optional): The total number of steps (batches of samples) to draw before
                stopping prediction. If None, prediction will run until the `test_data` dataset is exhausted.
                The prediction will start from the beginning of the dataset in each run. Default: None.
            collate_fn (callable, optional): Function to generate mini-batch data by merging
                the sample list; None to only stack each field of the samples along axis
                0. Default: None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during testing. Default: None. (Unused for now)

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                valid_dataset = MNIST(mode='test', transform=transform)

                model = paddle.vision.models.LeNet()

                engine = auto.Engine(model)
                engine.predict(valid_dataset, batch_size=64)
        """
        self.mode = 'predict'
        inputs, labels = self._split_sample_item(test_data, test_sample_split)
        self._infer_sample_spec(inputs, labels, batch_size)
        if not self._mode_init_states[self.mode]:
            self._prepare_program(self.mode)
        else:
            self._switch_mode("predict")

        assert self.mode in self._dist_main_progs, \
            "predict model is not ready, please call `engine._prepare_program('predict')` first."
        test_dataloader = self._prepare_dataloader(test_data,
                                                   batch_size,
                                                   steps_per_epoch=steps,
                                                   collate_fn=collate_fn)

        fetch_names, fetch_indices = self._prepare_fetch(mode=self.mode)

        for step, _ in enumerate(test_dataloader):
            try:
                outs = self._executor.run(
                    self.main_program,
                    fetch_list=fetch_names,
                    use_program_cache=self._strategy.use_cache,
                    return_numpy=self._strategy.return_numpy)
            except core.EOFException:
                break
            self._prepare_logger(outs, self.mode, None, step, None, fetch_names,
                                 fetch_indices)
            history = self._prepare_history(outs, self.mode, fetch_indices)

        return history

    def _tune(self, tune_data, tune_sample_split=None, batch_size=1):
        self.mode = 'train'
        inputs, labels = self._split_sample_item(tune_data, tune_sample_split)
        self._infer_sample_spec(inputs, labels, batch_size)
        self._optimization_tuning(self.mode, tune_data, batch_size)

    def dataloader(self,
                   dataset,
                   sample_split=1,
                   batch_size=1,
                   epochs=1,
                   steps_per_epoch=None,
                   collate_fn=None,
                   mode="train",
                   from_generator=True):
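        """Create a distributed dataloader for `dataset` under the given mode,
        building and initializing the program for that mode first if needed.
        Only generator-based loading is supported for now."""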
        assert from_generator, "Only from_generator is supported for now"
        self.mode = mode
        inputs, labels = self._split_sample_item(dataset, sample_split)
        self._infer_sample_spec(inputs, labels, batch_size)
        if not self._mode_init_states[self.mode]:
            self._prepare_program(self.mode)
        else:
            self._switch_mode("train")
        dataloader = self._prepare_dataloader(dataset, batch_size, epochs,
                                              steps_per_epoch, collate_fn)
        return dataloader

    def _prepare_dataloader(self,
                            dataset,
                            batch_size,
                            epochs=1,
                            steps_per_epoch=None,
                            collate_fn=None):

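        # With gradient merge, each optimizer update consumes k_steps
        # micro-batches, so the user-facing batch_size is split into k_steps
        # smaller batches before being fed to the program.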
        if self._strategy.gradient_merge and batch_size is not None:
            assert batch_size % self._k_steps == 0, \
                "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self._k_steps)
            batch_size //= self._k_steps

        dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
        dist_startup_prog = self._dist_startup_progs[self.mode][self._cur_rank]
        dist_context = self._dist_contexts[self.mode]
        dist_main_block = dist_main_prog.global_block()

        # NOTE: Get feed_list, then insert the dataloader op with sharded var shapes.
        # Because the predict program does not contain the labels var, we add the
        # labels var from the serial program to the dist program, which keeps the
        # length of feed_list equal to the number of fields in the dataset items.
        inputs_var = self._feed_vars[self.mode]["inputs"]
        labels_var = self._feed_vars[self.mode]["labels"]
        feed_list = []
        for var in inputs_var + labels_var:
            if var.name in dist_main_block.vars:
                feed_list.append(dist_main_block.vars[var.name])
            else:
                copy_var = dist_main_block._clone_variable(var, var.persistable)
                copy_var.desc.set_original_id(var.desc.original_id())
                feed_list.append(copy_var)

        # Remove the first three reader ops if fit/evaluate/predict is run more
        # than once, since each call inserts a fresh dataloader.
        op_size = len(dist_main_block.ops)
        if dist_main_block.ops[0].type == 'create_py_reader':
            op_size -= 3
            for _ in range(3):
                dist_main_block._remove_op(0, sync=False)

        # Insert the read op at the end of the program
        places = paddle.static.cuda_places()
        with static.program_guard(dist_main_prog, dist_startup_prog):
            dataloader = NonIterableGeneratorLoader(
                dataset,
                feed_list,
                places,
                batch_size,
                epochs,
                steps_per_epoch,
                collate_fn,
                data_parallel_world_size=self._dp_world_sizes,
                data_parallel_rank=self._dp_ranks,
                split_data=self._strategy.split_data)

        # Move the read op from the end of the program to its start
        new_op_size = len(dist_main_block.ops)
        for _ in range(new_op_size - 1, op_size - 1, -1):
            op = dist_main_block.ops[new_op_size - 1]
            new_op_desc = dist_main_block.desc._prepend_op()
            new_op_desc.copy_from(op.desc)
            new_op = Operator(dist_main_block,
                              new_op_desc,
                              type=new_op_desc.type())
            dist_main_block.ops.insert(0, new_op)
            dist_op = DistributedOperator(new_op)
            dist_context.add_dist_op_for_program(dist_op)
        for _ in range(new_op_size - op_size):
            dist_main_block._remove_op(new_op_size, sync=False)
        dist_main_block._sync_with_cpp()
        return dataloader

    def _validate_spec(self, specs):
        specs = to_list(specs)
        self._k_steps = self._strategy.gradient_merge.k_steps
        if specs is not None:
            for i, spec in enumerate(specs):
                assert isinstance(spec, InputSpec)
                if spec.name is None:
                    raise ValueError(
                        "Requires Input[{}].name != None, but received `None` with {}."
                        .format(i, spec))
                if self._k_steps > 1:
                    shape = list(spec.shape)
                    assert shape[0] % self._k_steps == 0, \
                        "Requires batch_size[{}] to be divisible by k_steps[{}].".format(spec.shape[0], self._k_steps)
                    shape[0] //= self._k_steps
                    spec.shape = shape
        return specs

    def _is_local_var(self, var):
        var_name = _to_name_str(var)
        return var_name in self.main_program.global_block().vars

    def _get_input_split_info(self, var, dist_context):
        # deduce how the input data is split among the cluster
        from .utils import _get_comm_group, _get_corresponding_rank

        tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
        process_mesh = tensor_dist_attr.process_mesh
        dims_mapping = tensor_dist_attr.dims_mapping

        if self._cur_rank not in process_mesh.processes:
            rank_id = _get_corresponding_rank(dist_context, process_mesh,
                                              self._cur_rank)
        else:
            rank_id = self._cur_rank

        batch_size_axis = dims_mapping[0]
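        # A batch_size_axis of -1 means the batch dimension is replicated, so
        # there is no data-parallel split for this input.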
        if batch_size_axis > -1 and process_mesh.topology[batch_size_axis] > 1:
            group_ranks = _get_comm_group(process_mesh.processes,
                                          process_mesh.topology,
                                          batch_size_axis, rank_id)
            return len(group_ranks), group_ranks.index(rank_id)

        return 1, 0

    def _set_recompute_ckpts(self):
        # NOTE: hack to enable recompute in the engine api for GPT-3
        # TODO: support more PaddleNLP/CV models here

        recompute = self._strategy.recompute

        # extract ckpts by specific model
        if isinstance(self._model, paddle.nn.Layer):
            if hasattr(self._model,
                       "gpt") and self._model.__class__.__name__ in [
                           'GPTForPretraining', 'GPTForPretrainingAuto'
                       ]:
                exact_ckpts = self._model.gpt.checkpoints
            else:
                exact_ckpts = recompute.checkpoints
        else:
            exact_ckpts = recompute.checkpoints

        # modify strategy
        if recompute.enable:
            recompute.checkpoints = exact_ckpts[:]
            logs = {
                'Model Class': self._model.__class__.__name__,
                'Applied Recompute ckpts': exact_ckpts
            }
            self._logger.info(logs)

    def _validate_opt(self, optimizer):
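        # Drop the parameter references captured at construction time so the
        # optimizer can be re-applied cleanly over the parameters of the
        # parallelized program.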
        if optimizer is not None:
            optimizer._parameter_list = None
            optimizer._param_groups = None
        return optimizer

    def _reset_metrics(self):
        for metric in self._metrics:
            metric.reset()

    def _switch_mode(self, mode):
        self.mode = mode
        self._initialize(mode)

    def _set_state_dict(self, mode, strict, state_dict, dist_attr):
        program = self._dist_main_progs[mode][self._cur_rank]
        dist_context = self._dist_contexts[mode]
        cur_dist_attr = get_dist_attr(program, dist_context)
        converter = Converter(state_dict, dist_attr, cur_dist_attr)
        state_dict = converter.convert(strict=strict)
        program.set_state_dict(state_dict)

    def save(self, path, training=True):
        """
        Saves the model, parameters, and optimizer state to the given path.
        If `training` is set to False, only the inference model will be saved.

        Args:
            path (str): The file prefix to save the model. The format
                is 'dirname/file_prefix' or 'file_prefix'. An exception will
                be raised if it is an empty string.
            training (bool, optional): Whether to save for training. If not, save
                for inference only. If `training` is set to True, the optimizer state
                will be saved. Otherwise, only the model and parameters are saved.
                This function will silently overwrite existing files at the target
                location. Default: True.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                train_dataset = MNIST(mode='train', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                optimizer = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=model.parameters())
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, optimizer, metrics)
                engine.fit(train_dataset,
                           epochs=1,
                           batch_size=64)
                engine.save("./my_model")

        """
        if training:
            assert 'train' in self._serial_main_progs, \
                "training model is not ready, please call `engine._prepare_program('train')` first."
            serial_program = self._serial_main_progs["train"]
            dist_main_prog = self._dist_main_progs["train"][self._cur_rank]
            dist_context = self._dist_contexts["train"]
            self._saver.save(path,
                             serial_program=serial_program,
                             dist_main_program=dist_main_prog,
                             dist_context=dist_context)
        else:
            mode = "predict"
            feed_vars = self._feed_vars[mode]['inputs']
            fetch_vars = self._fetch_vars[mode]['outputs']
            dist_main_prog = self._dist_main_progs[mode][self._cur_rank]
            self._saver.save_inference_model(path,
                                             feed_vars,
                                             fetch_vars,
                                             self._executor,
                                             program=dist_main_prog)

    def load(self, path, strict=True, load_optimizer=True):
        """
        Load the stored model, parameters and optimizer states.

        Args:
            path (str): The prefix of files storing the model states and
                optimizer states.
            strict (bool, optional): Whether to skip loading a parameter on
                mismatch, or to raise an error when a mismatch happens (the
                parameter is not found in the file storing the model states,
                or its shape does not match). Default: True.
            load_optimizer (bool, optional): If True, the stored optimizer
                states are restored. Otherwise, the optimizer states are
                initialized from scratch. Default: True.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                train_dataset = MNIST(mode='train', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                optimizer = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=model.parameters())
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, optimizer, metrics)
                engine.fit(train_dataset,
                           epochs=1,
                           batch_size=64)
                engine.save("./my_model")
                engine.load("./my_model")

        """
        self._strict = strict
        self._state_dict, self._dist_attr = self._saver.load(
            path, load_optimizer)
        return self._state_dict, self._dist_attr

    @staticmethod
    def _get_lr_scheduler(program):
        lr_scheduler = None
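        # NOTE: 'lr_sheduler' (sic) is the attribute name set on the Program by
        # Paddle's optimizer when an LRScheduler is used as the learning rate.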
        if hasattr(program, 'lr_sheduler'):
            from paddle.optimizer.lr import LRScheduler
            lr_scheduler = program.lr_sheduler
            assert isinstance(lr_scheduler, LRScheduler), "must be LRScheduler"
        return lr_scheduler

    def _get_lr(self, optimizer):
        if isinstance(optimizer, paddle.optimizer.Optimizer):
            return optimizer.get_lr()
        elif isinstance(optimizer, paddle.fluid.optimizer.Optimizer):
            if isinstance(optimizer._learning_rate, float):
                return optimizer._learning_rate
            else:
                return optimizer._learning_rate()
        else:
            raise TypeError(
                "'optimizer' must be object of class `paddle.optimizer.Optimizer`"
                " or `paddle.fluid.optimizer.Optimizer`, but got {}.".format(
                    type(optimizer)))

    @property
    def mode(self):
        return self._mode

    @mode.setter
    def mode(self, mode):
        self._mode = mode

    @property
    def main_program(self):
        return self._dist_main_progs[self.mode][self._cur_rank]

    @property
    def startup_program(self):
        return self._dist_startup_progs[self.mode][self._cur_rank]

    @property
    def dist_context(self):
        return self._dist_contexts[self.mode]

    @property
    def serial_main_program(self):
        return self._serial_main_progs[self.mode]

    @property
    def serial_startup_program(self):
        return self._serial_startup_progs[self.mode]

    @property
    def fetch_vars(self):
        return self._fetch_vars[self.mode]

    @property
    def inputs(self):
        return self.inputs_spec

    @property
    def labels(self):
        return self.labels_spec