engine.py 52.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os
16
import time
17 18
import copy
import logging
19 20
import random
import numpy as np
21 22 23
from collections import defaultdict

import paddle
24
import paddle.utils as utils
25

26
from paddle import fluid, profiler, static
27
from paddle.jit import to_static
28
from paddle.metric import Metric
29
from paddle.static import InputSpec
30
from paddle.fluid import core
31
from paddle.fluid import Variable
32
from paddle.fluid.layers.utils import flatten
33
from paddle.fluid.executor import global_scope, _to_name_str
34
from paddle.fluid.framework import Operator, Parameter, _non_static_mode
35 36
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph.parallel import ParallelEnv
37
from paddle.distributed import fleet
38

39
from .converter import Converter
40
from .helper import ProgramHelper
41
from .cluster import Cluster, get_default_cluster
42 43
from .planner_v2 import Planner
from .parallelizer_v2 import Parallelizer
44 45 46 47
from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver
from .dist_loader import NonIterableGeneratorLoader
from .utils import print_program_with_dist_attr, to_list
48 49
from .utils import get_logger, get_dist_attr
from .process_group import new_process_group, get_all_process_groups
50
from .dist_context import DistributedContext, get_default_distributed_context
51
from .strategy import Strategy
52
from .interface import CollectionNames, get_collection
53 54 55


class Engine:
56
    """
57 58
    An Engine object can provide the full power of auto parallel to users.
    With the help of it, users can easily obtain the abilities of the
59 60 61 62 63 64 65
    distributed training and inference. It also support the dynamic graph and
    static graph at the same time.

    Args:
        model (paddle.nn.Layer, optional): The model is an instance of
            paddle.nn.Layer.
        loss (Loss|Callable|None, optional): The loss can be a `paddle.nn.Layer`
66 67
            instance or any callable function taken the predicted values and
            ground truth values as input. It can be None when there is no loss.
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
            Default: None.
        optimizer (Optimizer|None, optional): The optimizer need to be set in training
            and should be None in eval and predict mode. Default: None.
        metrics (Metric|list[Metric]|None, optional): If metrics is set, all
            metrics will be calculated and output in train/eval mode. Default: None.
        cluster (Cluster|None, optional): The cluster represents the topology information
            about the used physical devices. Default: None. (Unused for now)
        strategy (Strategy|None, optional): The strategy is used to configure the
        parallelization and optimization behaviors. Default: None.

    Examples:

        .. code-block:: python

            import paddle
            import paddle.vision.transforms as T
84
            from paddle.distributed.fleet import auto
85 86 87 88 89 90 91 92 93 94
            from paddle.vision.datasets import MNIST

            transform = T.Compose([
                T.Transpose(),
                T.Normalize([127.5], [127.5])
            ])
            train_dataset = MNIST(mode='train', transform=transform)
            valid_dataset = MNIST(mode='test', transform=transform)

            model = paddle.vision.models.LeNet()
95
            loss = paddle.nn.CrossEntropyLoss()
96 97 98 99
            optimizer = paddle.optimizer.Adam(
                learning_rate=0.001, parameters=model.parameters())
            metrics = paddle.metric.Accuracy(topk=(1, 2))

100 101
            engine = auto.Engine(model, loss, optimizer, metrics)
            # fit
102 103 104
            engine.fit(train_dataset,
                       epochs=2,
                       batch_size=64)
105
            # evaluate
106 107 108 109 110 111 112
            engine.evaluate(valid_dataset,
                            batch_size=64)
            # predict
            engine.predict(valid_dataset,
                           batch_size=64)
            # save
            engine.save("./my_model")
113
            # load
114 115 116
            engine.load("./my_model")

    """
117

118 119
    def __init__(self,
                 model=None,
                 loss=None,
                 optimizer=None,
                 metrics=None,
                 cluster=None,
                 strategy=None):
        """Validate the user-supplied components and set up per-mode state.

        Args:
            model (paddle.nn.Layer|callable|None): The network to parallelize.
            loss (paddle.nn.Layer|callable|None): The loss; called with
                ``outputs + labels`` when building non-predict programs.
            optimizer (paddle.optimizer.Optimizer|paddle.fluid.optimizer.Optimizer|None):
                The optimizer; stored after passing through ``_validate_opt``.
            metrics (Metric|list[Metric]|None): Metrics to compute; always
                stored as a list.
            cluster (Cluster|None): Device topology; defaults to
                ``get_default_cluster()``.
            strategy (Strategy|None): Parallelization/optimization settings;
                defaults to ``Strategy()``.

        Raises:
            TypeError: If ``model``, ``loss``, ``optimizer``, ``cluster`` or
                ``strategy`` has an unsupported type.
            AssertionError: If an element of ``metrics`` is not a ``Metric``.
        """
        if model and not isinstance(model,
                                    paddle.nn.Layer) and not callable(model):
            raise TypeError(
                "'model' must be sub classes of `paddle.nn.Layer` or any callable function."
            )
        self._model = model

        if loss and not isinstance(loss,
                                   paddle.nn.Layer) and not callable(loss):
            raise TypeError(
                "'loss' must be sub classes of `paddle.nn.Layer` or any callable function."
            )
        self._loss = loss

        if optimizer and not isinstance(
                optimizer,
            (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)):
            raise TypeError(
                "'optimizer' must be object of class `paddle.optimizer.Optimizer`"
                " or `paddle.fluid.optimizer.Optimizer`.")
        self._optimizer = self._validate_opt(optimizer)

        metrics = metrics or []
        for metric in to_list(metrics):
            assert isinstance(metric, Metric), \
                "{} is not sub class of Metric".format(
                    metric.__class__.__name__)
        self._metrics = to_list(metrics)

        if cluster and not isinstance(cluster, Cluster):
            raise TypeError(
                "'cluster' must be object of class `paddle.distributed.auto_parallel.Cluster`"
            )
        self._cluster = cluster or get_default_cluster()

        if strategy and not isinstance(strategy, Strategy):
            raise TypeError(
                "'strategy' must be object of class `paddle.distributed.auto_parallel.Strategy`"
            )
        self._strategy = strategy or Strategy()

        # POD_NAME is treated as the signal that the job was started by
        # `paddle.distributed.launch`; fleet is then set up for collective mode.
        if os.getenv("POD_NAME"):
            print("Distribute training by paddle.distributed.launch",
                  flush=True)
            fleet.init(is_collective=True)

        self._executor = None
        self._cur_rank = paddle.distributed.get_rank()
        self._nranks = paddle.distributed.get_world_size()
        self._saver = DistributedSaver()

        self._logger = get_logger(logging.INFO)

        # Snapshot of the global default programs/context at construction
        # time; static-mode builds clone their per-mode programs from these.
        self._orig_main_prog = static.default_main_program()
        self._orig_startup_prog = static.default_startup_program()
        self._orig_dist_context = get_default_distributed_context()
        # Per-mode ("train"/"eval"/"predict") state, filled by _build/_initialize.
        self._dist_contexts = {}
        self._serial_main_progs = {}
        self._serial_startup_progs = {}
        self._dist_main_progs = defaultdict(dict)  # dist main programs
        self._dist_startup_progs = defaultdict(dict)  # dist startup programs
        self._feed_vars = {}
        self._fetch_vars = {}
        self._planners = {}
        # Whether `_prepare_program` has completed for each mode.
        self._mode_init_states = {
            "train": False,
            "eval": False,
            "predict": False
        }

        # First mode that gets planned; later modes copy its dist attributes.
        self._planned_mode = None
        self._dygraph_mode = False
        self._tuning = self._strategy.tuning
199

200
    def _prepare_program(self, mode):
        """Take `mode` through the whole preparation pipeline.

        The stages run in this order: build the serial program, plan the
        distributed attributes, parallelize for the current rank, and
        initialize communication plus the startup program. Afterwards the mode
        is marked as ready in ``self._mode_init_states``.
        """
        stages = (self._build, self._plan, self._parallel, self._initialize)
        for stage in stages:
            stage(mode)
        self._mode_init_states[mode] = True
210

211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266
    def _prepare_feed(self, user_feeds=None, mode="train"):
        """Assemble the feed dict for one executor run.

        Feeds registered under ``CollectionNames.FEEDS`` are added first.
        ``user_feeds`` (if given) is then merged on top, so user entries win
        on name clashes. ``mode`` is currently unused.
        """
        has_user_feeds = user_feeds is not None
        if has_user_feeds:
            assert isinstance(user_feeds, dict), \
                "user_feeds must be a dict, but receive {}".format(type(user_feeds).__name__)

        # TODO: add inputs and labels feed dict
        feed_dict = {}
        for feed_name, feed_var in get_collection(CollectionNames.FEEDS):
            assert feed_name is not None, "No name defined for feed var"
            feed_dict[feed_name] = feed_var

        if has_user_feeds:
            feed_dict.update(user_feeds)
        return feed_dict

    def _prepare_fetch(self, user_fetches=None, mode="train"):
        """Collect the variables to fetch for `mode` and group them into sections.

        Args:
            user_fetches (list|None): Extra variables, or ``(new_name, var)``
                tuples, to fetch after the collection fetches.
            mode (str): "train", "eval" or "predict". Loss and metrics are
                fetched outside predict mode; model outputs only in predict.

        Returns:
            tuple: ``(fetch_names, fetch_new_names, fetch_sections)``.
            ``fetch_names`` are the variable names passed to the executor.
            ``fetch_new_names`` are the display names used for logging, in the
            same order. ``fetch_sections`` maps a section name ("loss",
            "metrics", "outputs", "user_fetches") to the ``(start, end)``
            slice of that section in the fetched results.
        """
        if user_fetches is not None:
            assert isinstance(user_fetches, list), \
                "user_fetches must be a list, but receive {}".format(type(user_fetches).__name__)
        fetch_names = []
        fetch_new_names = []
        fetch_sections = {}
        cnt = 0

        def _process_section(section_name, var_list):
            nonlocal cnt
            section_start = cnt
            for var in var_list:
                new_name = None
                # Rename the loss
                if section_name == "loss":
                    new_name = "loss"
                # A (new_name, var) tuple overrides the display name.
                if isinstance(var, tuple):
                    assert len(var) == 2, "Length of tuple {} must be 2".format(
                        var)
                    new_name, var = var
                # Only vars accepted by `_is_local_var` are fetched, and each
                # var is fetched at most once across all sections.
                if self._is_local_var(var) and var.name not in fetch_names:
                    fetch_names.append(var.name)
                    fetch_new_names.append(var.name)
                    cnt += 1
                if self._is_local_var(var) and new_name is not None:
                    fetch_new_names[fetch_names.index(var.name)] = new_name
            section_end = cnt
            fetch_sections[section_name] = (section_start, section_end)

        for name, var_list in self._fetch_vars[mode].items():
            if name == "loss" and mode != "predict":
                _process_section("loss", var_list)
            if name == "metrics" and mode != "predict":
                _process_section("metrics", var_list)
            if name == "outputs" and mode == "predict":
                # Outputs get their own section. `_print_log` feeds the
                # "metrics" section into `Metric.update`, which must not
                # receive raw model outputs.
                _process_section("outputs", var_list)
        var_list = (get_collection(CollectionNames.FETCHES)
                    or []) + (user_fetches or [])
        _process_section("user_fetches", var_list)
        return fetch_names, fetch_new_names, fetch_sections

267
    def _build(self, mode):
        """Build the serial (non-distributed) program for `mode`.

        In dynamic-graph mode (or once the engine has switched to it), the model
        is converted with ``ProgramHelper`` and static mode is re-enabled
        afterwards. Otherwise the program is traced in static mode inside
        clones of the original default programs. In static mode this is done
        only once per mode: if ``self._serial_main_progs`` already has `mode`,
        the method returns immediately.

        The result is stored as a new ``DistributedContext`` in
        ``self._dist_contexts[mode]``. For non-train modes the main program is
        first cloned with ``for_test=True``.
        """
        if not (_non_static_mode() or self._dygraph_mode):
            # build program in static mode (only once per mode)
            if self._serial_main_progs.get(mode, None) is not None:
                return

            serial_main_prog = self._orig_main_prog.clone()
            serial_startup_prog = self._orig_startup_prog.clone()
            with static.program_guard(serial_main_prog, serial_startup_prog), \
                utils.unique_name.guard():
                inputs_spec = self.inputs_spec
                labels_spec = self.labels_spec or []
                inputs = [spec._create_feed_layer() for spec in inputs_spec]
                labels = [spec._create_feed_layer() for spec in labels_spec]
                outputs = to_list(self._model(*inputs))

                losses = []
                metrics = []
                if mode != "predict":
                    if self._loss:
                        losses = to_list(self._loss(*(outputs + labels)))
                    for metric in self._metrics:
                        metrics += to_list(metric.compute(*(outputs + labels)))
        else:
            paddle.disable_static()
            self._dygraph_mode = True
            self._logger.info("Building model with 'to_static' method.")

            helper = ProgramHelper(self._model, self._loss, self._metrics,
                                   self.inputs_spec, self.labels_spec or [])
            self.program_helper = helper
            # build forward main program
            helper.build_program(mode)

            self.concrete_program = helper.concrete_program
            serial_main_prog = helper.main_program
            serial_startup_prog = helper.startup_program

            inputs = helper.input_vars
            outputs = helper.output_vars
            labels = helper.label_vars
            losses = helper.loss_vars
            metrics = helper.metric_vars

            paddle.enable_static()

        default_ctx = get_default_distributed_context()
        if not default_ctx.has_annotation:
            # We build the world process group because the data parallel
            # needs all ranks by default.
            new_process_group(list(range(self._nranks)))
            default_ctx.data_parallel = True

        feed_vars = {"inputs": inputs, "labels": labels}
        fetch_vars = {
            "outputs": flatten(outputs),
            "loss": losses,
            "metrics": metrics,
        }

        if mode != "train":
            serial_main_prog = serial_main_prog.clone(for_test=True)

        self._set_recompute_ckpts()
        dist_context = DistributedContext(serial_main_prog,
                                          serial_startup_prog,
                                          self._optimizer, losses, feed_vars,
                                          fetch_vars, self._cluster,
                                          self._strategy)
        self._dist_contexts[mode] = dist_context
        dist_context.gradient_scale = self._strategy.gradient_scale
340

341 342 343
    def _optimization_tuning(self, mode, dataset, batch_size):
        """Search for a better optimization config for train mode.

        Builds and plans `mode` and attaches the data-parallel layout to
        `dataset`. It then runs an ``OptimizationTuner``, stored on
        ``self._optimization_tuner``. If ``tuning.run_after_tuning`` is set,
        the best config found replaces the dist context's ``_strategy``.

        Raises:
            ValueError: If tuning is not enabled in the strategy.
            AssertionError: If `mode` is not "train".
        """
        tuning = self._tuning
        if not tuning.enable:
            raise ValueError("Please set `tuning.enable=True`.")

        assert mode == "train"
        self._build(mode)
        self._plan(mode)

        # Data-parallel layout computed by `_plan`, one entry per feed var.
        dataset.dp_world_size = self._dp_world_sizes
        dataset.dp_rank = self._dp_ranks

        from .tuner.optimization_tuner import OptimizationTuner
        tuner = OptimizationTuner(tuning.to_dict(),
                                  self._dist_contexts[mode],
                                  dataset,
                                  self.inputs_spec,
                                  self.labels_spec,
                                  batch_size=batch_size,
                                  rank=self._cur_rank)
        self._optimization_tuner = tuner
        tuner.tune()

        if tuning.run_after_tuning:
            # update the strategy
            self._dist_contexts[mode]._strategy = tuner.get_best_config()

370 371 372 373 374 375
    def _plan(self, mode):
        """Plan the distributed attributes of `mode` and derive its data-parallel layout.

        The first planned mode is remembered in ``self._planned_mode``. Any
        later mode first copies that mode's op dist attributes through
        ``_init_dist_context``. After planning, every input/label var found in
        the serial main program's global block contributes one entry to
        ``self._dp_world_sizes`` and ``self._dp_ranks``.
        """
        if self._planned_mode is not None:
            self._init_dist_context(mode)
        else:
            self._planned_mode = mode

        dist_context = self._dist_contexts[mode]
        planner = Planner(mode, dist_context)
        self._planners[mode] = planner
        planner.plan()

        # infer data parallel info
        feed_vars = dist_context.serial_feed_vars
        candidates = feed_vars["inputs"] + feed_vars["labels"]
        global_block = dist_context.serial_main_program.global_block()
        feed_list = [
            global_block.vars[var.name] for var in candidates
            if var.name in global_block.vars
        ]

        self._dp_world_sizes = []
        self._dp_ranks = []
        for feed_var in feed_list:
            world_size, rank = self._get_input_split_info(
                feed_var, dist_context)
            self._dp_world_sizes.append(world_size)
            self._dp_ranks.append(rank)
395

396
    def _parallel(self, mode, all_ranks=False):
        """Parallelize the planned program of `mode`.

        Only the current rank is processed unless `all_ranks` is true, in
        which case every rank is parallelized.
        """
        # The planner's completer is handed over because it may still be
        # needed to complete the annotation of the backward and update parts.
        completer = self._planners[mode].completer
        parallelizer = Parallelizer(mode, completer, self._dist_contexts[mode])
        if all_ranks:
            parallelizer.parallel_all()
        else:
            parallelizer.parallel(self._cur_rank)
406 407

    def _init_dist_context(self, mode):
        """Seed `mode`'s dist context with the op attributes of the first planned mode.

        This keeps train/eval/predict on the same parallel strategy. Ops are
        matched by position (same block index and op index) in the two
        original serial main programs. A type mismatch raises AssertionError.
        """
        ref_mode = self._planned_mode
        ref_dist_context = self._dist_contexts[ref_mode]
        ref_blocks = ref_dist_context._original_serial_main_program.blocks

        dist_context = self._dist_contexts[mode]
        own_blocks = dist_context._original_serial_main_program.blocks

        for block_idx, block in enumerate(own_blocks):
            for op_idx, op in enumerate(block.ops):
                ref_op = ref_blocks[block_idx].ops[op_idx]
                assert op.type == ref_op.type, \
                    "'{}' mode op '{}' is different with '{}' op '{}'. ".format(mode, op.type, ref_mode, ref_op.type)
                shared_attr = ref_dist_context.get_op_dist_attr_for_program(
                    ref_op)
                dist_context.set_op_dist_attr_for_program(op, shared_attr)

    def _initialize(self, mode):
        """Make the parallelized program of `mode` ready to run on this rank.

        The steps are:

        * cache the programs and feed/fetch vars held by
          ``self._dist_contexts[mode]``;
        * instantiate the process groups this rank belongs to (multi-rank only);
        * seed paddle/numpy/random when ``strategy.seed`` is set;
        * in dygraph mode, hand the program to ``self.program_helper``;
        * create the executor on first use and run the not-yet-initialized part
          of the startup program;
        * re-run the full startup program when ``strategy.reinit`` is set.
        """
        dist_context = self._dist_contexts[mode]

        # Get the current content from the distributed context
        self._serial_main_progs[mode] = dist_context.serial_main_program
        self._serial_startup_progs[mode] = dist_context.serial_startup_program
        self._dist_main_progs[mode] = dist_context.dist_main_programs
        self._dist_startup_progs[mode] = dist_context.dist_startup_programs
        self._feed_vars[mode] = dist_context.serial_feed_vars
        self._fetch_vars[mode] = dist_context.serial_fetch_vars
        self._lr_optimizer = dist_context._lr_optimizer

        if self._nranks > 1:
            # Traverse different rank programs and traverse each op of them,
            # instantiate communication by process_mapping.
            all_process_groups = get_all_process_groups()

            # NOTE: add the comm init control in the future for auto search
            for process_group in all_process_groups:
                if self._cur_rank not in process_group.ranks:
                    continue
                process_group.instantiate()

        place = _get_device()
        if isinstance(place, fluid.CUDAPlace):
            # Use the device id that ParallelEnv reports for this process.
            place = fluid.CUDAPlace(ParallelEnv().dev_id)

        if self._strategy.seed:
            # Offset the seed by the data-parallel rank computed in `_plan` so
            # different data-parallel ranks use different seeds. Fall back to
            # offset 0 when no feed var was found (previously an IndexError).
            dp_rank = self._dp_ranks[0] if self._dp_ranks else 0
            seed = self._strategy.seed + dp_rank
            paddle.seed(seed)
            np.random.seed(seed)
            random.seed(seed)

        if self._dygraph_mode:
            dist_main_program = self._dist_main_progs[mode][self._cur_rank]
            self.program_helper.init(dist_main_program, place, dist_context)

        if self._executor is None:
            self._executor = paddle.static.Executor(place)
            # Run only the startup ops of variables that are not already
            # initialized in the global scope.
            uninitialized = []
            dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
            for var in dist_startup_prog.list_vars():
                scope_var = global_scope().find_var(var.name)
                if scope_var and scope_var.get_tensor()._is_initialized():
                    continue
                uninitialized.append(var)
            if uninitialized:
                prune_startup_prog = dist_startup_prog._prune(uninitialized)
                self._executor.run(prune_startup_prog)

            # Apply a previously stored state dict, if any. These attributes
            # are set elsewhere (presumably by `load` - TODO confirm).
            if hasattr(self, "_state_dict") and hasattr(self, "_dist_attr"):
                self._set_state_dict(mode, self._strict, self._state_dict,
                                     self._dist_attr)

        if self._strategy.reinit:
            self._logger.info("NOTE: parameters will be re-initialized.")
            dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
            self._executor.run(dist_startup_prog)

486
    def _split_sample_item(self, data, split):
        """Take the first sample of `data` and split it into inputs and labels.

        Args:
            data (paddle.io.Dataset|paddle.io.IterableDataset): Source dataset.
            split (int|None): If None, a sample must unpack into exactly
                ``(inputs, labels)``. Otherwise items before index `split` are
                inputs and the remaining items are labels.

        Returns:
            tuple(list, list): ``inputs`` and ``labels``, each normalized with
            ``to_list``.

        Raises:
            ValueError: If `data` is neither an IterableDataset nor a Dataset.
        """
        # Iterable datasets are checked first so they are always read through
        # an iterator rather than by index.
        if isinstance(data, paddle.io.IterableDataset):
            sample = next(iter(data))
        elif isinstance(data, paddle.io.Dataset):
            sample = data[0]
        else:
            raise ValueError(
                "Data should be a Dataset or IterableDataset, but received {}.".
                format(type(data).__name__))

        if split is None:
            inputs, labels = sample
        else:
            inputs = sample[:split]
            labels = sample[split:]
        return to_list(inputs), to_list(labels)
508

509
    def _infer_sample_spec(self, inputs, labels, batch_size):
        """Derive ``InputSpec``s for one sample and store them on the engine.

        Items are named ``input{i}`` / ``label{i}``.

        * numpy arrays use ``InputSpec.from_numpy``.
        * ``Variable``/``VarBase``/eager tensors use ``InputSpec.from_tensor``.
        * Both of the above are batched with `batch_size` unless it is None.
        * Any other item becomes ``InputSpec([batch_size], type(item), name)``.

        The results are passed through ``_validate_spec`` and stored in
        ``self.inputs_spec`` / ``self.labels_spec``.
        """
        self.inputs_spec = []
        self.labels_spec = []

        def _make_spec(item, spec_name):
            if isinstance(item, np.ndarray):
                spec = InputSpec.from_numpy(item, spec_name)
            elif isinstance(item, (Variable, core.VarBase, core.eager.Tensor)):
                spec = InputSpec.from_tensor(item, spec_name)
            else:
                # Not array-like: record only the batch dim and Python type.
                return InputSpec([batch_size], type(item), spec_name)
            return spec if batch_size is None else spec.batch(batch_size)

        groups = (("input", inputs, self.inputs_spec),
                  ("label", labels, self.labels_spec))
        for prefix, items, target in groups:
            if items is None:
                continue
            for idx, item in enumerate(items):
                assert item is not None, "Receive None input."
                target.append(_make_spec(item, prefix + str(idx)))

        self.inputs_spec = self._validate_spec(self.inputs_spec)
        self.labels_spec = self._validate_spec(self.labels_spec)

543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572
    def __call__(self,
                 inputs=None,
                 labels=None,
                 feeds=None,
                 fetches=None,
                 mode="train"):
        """Run the current main program once and log the fetched results.

        Args:
            inputs, labels: Currently unused.
            feeds (dict|None): Extra feed entries merged over the collected feeds.
            fetches (list|None): Extra variables to fetch.
            mode (str): Selects which fetch vars are collected.

        Returns:
            The executor results, or None when the executor raised
            ``core.EOFException``.
        """
        feed_dict = self._prepare_feed(feeds, mode)
        fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch(
            fetches, mode)
        try:
            outs = self._executor.run(
                self.main_program,
                feed=feed_dict,
                fetch_list=fetch_list,
                use_program_cache=self._strategy.use_cache,
                return_numpy=self._strategy.return_numpy)
        except core.EOFException:
            # Nothing was fetched, so there is nothing to log. Previously
            # `outs` stayed unassigned here and the log call raised
            # UnboundLocalError.
            return None
        self._print_log(outs, self.mode, None, None, None, fetch_new_names,
                        fetch_sections)
        return outs

    # TODO: need a better to print the log
    def _print_log(self,
                   outs,
                   mode="train",
                   epoch=None,
                   step=None,
                   lr=None,
                   fetch_new_names=None,
                   fetch_sections=None,
                   profiler_log=""):
        """Format one log line from fetched results and emit it via ``self._logger.info``.

        Args:
            outs (list): Executor results, indexed by the slices in `fetch_sections`.
            mode (str): Shown as the ``[mode]`` prefix.
            epoch (int|None): Logged when not None.
            step (int|None): Logged when not None.
            lr (float|None): Logged when not None.
            fetch_new_names (list[str]|None): Display name of each entry in
                `outs`. Required when `fetch_sections` is given.
            fetch_sections (dict|None): Section name -> ``(start, end)`` slice
                of `outs`. How each section is handled:

                * "metrics": the slice is passed to ``update`` of every metric
                  in ``self._metrics`` (this mutates metric state), and the
                  accumulated values are logged.
                * "loss": the first element of each loss is logged.
                * Any other section is logged as-is.
            profiler_log (str): Appended to the line. It is appended before
                ``str.format`` runs, so it must not contain braces.
        """
        prefix = "[{}] ".format(mode)
        logs = {}
        if epoch is not None:
            logs["epoch: {:d} "] = epoch
        if step is not None:
            logs["step: {:d} "] = step
        if lr is not None:
            logs["lr: {:5e} "] = lr
        if fetch_sections is not None:
            assert fetch_new_names is not None
            for section_name, section in fetch_sections.items():
                section_start, section_end = section
                if section_name == "metrics" and section_start < section_end:
                    metric_out = outs[section_start:section_end]
                    for metric in self._metrics:
                        metric.update(*metric_out)
                        results = metric.accumulate()
                        # `Metric.name()` may return a list of names or a
                        # single string. Indexing a string would yield single
                        # characters, so normalize it to a list first.
                        metric_names = to_list(metric.name())
                        for i, res in enumerate(to_list(results)):
                            logs[metric_names[i] + ": {:8f} "] = res
                elif section_name == "loss" and section_start < section_end:
                    for i in range(section_start, section_end):
                        logs[fetch_new_names[i] + ": {:8f} "] = outs[i][0]
                else:
                    for i in range(section_start, section_end):
                        logs[fetch_new_names[i] + ": {} "] = outs[i]
        string = prefix + ''.join(list(logs.keys())) + profiler_log
        self._logger.info(string.format(*list(logs.values())))

603 604
    def fit(self,
            train_data,
605
            train_sample_split=None,
606 607 608
            batch_size=1,
            epochs=1,
            steps_per_epoch=None,
609 610 611 612
            valid_data=None,
            valid_sample_split=None,
            valid_freq=1,
            valid_steps=None,
613
            collate_fn=None,
614 615 616 617 618 619 620 621 622
            callbacks=None):
        """
        Trains the model for a fixed number of epochs. If `valid_data` is set,
        evaluation will be done at the end of each epoch.

        Args:
            train_data (Dataset): An instance of paddle paddle.io.Dataset. Default: None.
            train_sample_split (int, optional): Each sample of the train dataset is assumed
                to be a (input, label) pair by default and has two items. If each sample has
623
                more than two items, train_sample_split specifies how to split these items into
624
                input and label. The items before it are input and the left are label. Default: None.
625
            batch_size (int, optional): The batch size of train_data and valid_data if provided.
626 627 628
                The user's data will be used directly without batching if set to None. Default: 1.
            epochs (int, optional): The number of epochs to train the model. Default: 1.
            steps_per_epoch (int, optional): The total number of steps (batches of samples)
629
                is executed in one epoch before stating the next one. If None, it is equal to
630 631
                the number samples in your dataset divided by the batch size. Default: None.
            valid_data (Dataset, optional): An instance of paddle paddle.io.Dataset used for
632
                evaluation at the end of epoch. No evaluation will be done if set to None.
633
                Default: None. (Unsupported for now)
634
            valid_freq (int, optional): Only relevant if valid_data is provided. This specifies
635 636
                how many training epochs before a new evaluation is performed. Default: 1.
            valid_sample_split (int, optional): Only relevant if valid_data is provided.
637 638
                Each sample of the valid dataset is assumed to be a (input, label) pair
                by default and has two items. If each sample has more than two items,
639 640 641
                valid_sample_split specifies how to split these items into input and label.
                The items before it are input and the left are label. Default: None.
            valid_steps (int, optional): Only relevant if valid_data is provided.
642 643
                It is the total number of steps (batches of samples) to draw before
                stopping validation at the end of every epoch. If None, validation will run until the
644 645 646 647
                `valid_data` dataset is exhausted. The validation will start from the
                beginning of the dataset at each epoch. Default: None.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list, None for only stack each fields of sample in axis
648
                0. Default None.
649 650 651 652 653 654 655 656 657 658 659 660
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during training. Default: None. (Unused for now)

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
661
                from paddle.distributed.fleet import auto
662 663 664 665 666 667 668 669 670
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                train_dataset = MNIST(mode='train', transform=transform)

                model = paddle.vision.models.LeNet()
671
                loss = paddle.nn.CrossEntropyLoss()
672 673 674 675
                optimizer = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=model.parameters())
                metrics = paddle.metric.Accuracy(topk=(1, 2))

676
                engine = auto.Engine(model, loss, optimizer, metrics)
677 678 679 680
                engine.fit(train_dataset,
                           epochs=2,
                           batch_size=64)
        """
681
        self.mode = 'train'
682 683
        inputs, labels = self._split_sample_item(train_data, train_sample_split)
        self._infer_sample_spec(inputs, labels, batch_size)
684
        if not self._mode_init_states[self.mode]:
685
            self._prepare_program(self.mode)
Z
zhaoyingli 已提交
686 687
        else:
            self._switch_mode("train")
688

689
        assert self.mode in self._dist_main_progs, \
690 691 692 693 694 695 696
            "train model is not ready, please call `engine._prepare_program('train')` first."
        train_dataloader = self._prepare_dataloader(train_data, batch_size,
                                                    epochs, steps_per_epoch,
                                                    collate_fn)

        fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch(
            mode=self.mode)
697
        lr_scheduler = self._get_lr_scheduler(self.main_program)
698

699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726
        with profiler.Profiler(timer_only=True) as prof:
            for epoch in range(epochs):
                for step, _ in enumerate(train_dataloader):
                    try:
                        outs = self._executor.run(
                            self.main_program,
                            fetch_list=fetch_list,
                            use_program_cache=self._strategy.use_cache,
                            return_numpy=self._strategy.return_numpy)
                    except core.EOFException:
                        break
                    if lr_scheduler and step % self._k_steps == 0:
                        lr_scheduler.step()
                    lr = self._get_lr(self._lr_optimizer)

                    prof.step()

                    self._print_log(outs, self.mode, epoch, step, lr,
                                    fetch_new_names, fetch_sections,
                                    prof.step_info())

                if valid_data and epoch % valid_freq == 0:
                    self.evaluate(valid_data, valid_sample_split, batch_size,
                                  valid_steps, collate_fn, callbacks)
                    self._switch_mode("train")
                else:
                    self._reset_metrics()
            return outs
727

728
    def evaluate(self,
                 valid_data,
                 valid_sample_split=None,
                 batch_size=1,
                 steps=None,
                 collate_fn=None,
                 callbacks=None):
        """
        Evaluate the loss and metrics of the model on evaluation data.

        Args:
            valid_data (Dataset): An instance of `paddle.io.Dataset` to evaluate on.
            valid_sample_split (int, optional): Each sample of the eval dataset is assumed
                to be an (input, label) pair with two items. If each sample has more than
                two items, valid_sample_split specifies how to split them: the items
                before it are inputs and the rest are labels. Default: None.
            batch_size (int, optional): The batch size of valid_data. The user's data will
                be used directly without batching if set to None. Default: 1.
            steps (int, optional): The total number of steps (batches of samples) to draw
                before stopping evaluation. If None, evaluation runs until `valid_data`
                is exhausted. A new dataloader is prepared on every call, so evaluation
                starts from the beginning of the dataset each time. Default: None.
            collate_fn (callable, optional): Function that builds a mini-batch by merging
                the sample list. None stacks each field of the samples along axis 0.
                Default: None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during evaluation. Currently unused. Default: None.

        Returns:
            defaultdict(list): Currently always empty. Per-step results are only
            reported through the engine's logging, not collected here.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                valid_dataset = MNIST(mode='test', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, metrics=metrics)
                engine.evaluate(valid_dataset, batch_size=64)

        """
        self.mode = 'eval'
        sample_inputs, sample_labels = self._split_sample_item(
            valid_data, valid_sample_split)
        self._infer_sample_spec(sample_inputs, sample_labels, batch_size)

        # Build the eval program on first use; afterwards just switch to it.
        if self._mode_init_states[self.mode]:
            self._switch_mode("eval")
        else:
            self._prepare_program(self.mode)

        assert self.mode in self._dist_main_progs, \
            "eval model is not ready, please call `engine._prepare_program('eval')` first."
        eval_loader = self._prepare_dataloader(valid_data,
                                               batch_size,
                                               steps_per_epoch=steps,
                                               collate_fn=collate_fn)
        fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch(
            mode=self.mode)

        # NOTE(review): nothing is ever appended to `outputs`, so callers
        # always receive an empty defaultdict; kept for compatibility.
        outputs = defaultdict(list)
        for step, _ in enumerate(eval_loader):
            try:
                step_outs = self._executor.run(
                    self.main_program,
                    fetch_list=fetch_list,
                    use_program_cache=self._strategy.use_cache,
                    return_numpy=self._strategy.return_numpy)
            except core.EOFException:
                # EOFException from the executor ends the evaluation loop.
                break
            else:
                self._print_log(step_outs, self.mode, None, step, None,
                                fetch_new_names, fetch_sections)

        self._reset_metrics()
        return outputs
813

814 815
    def predict(self,
                test_data,
                test_sample_split=None,
                batch_size=1,
                steps=None,
                collate_fn=None,
                callbacks=None):
        """
        Compute the output predictions on testing data.

        Args:
            test_data (Dataset): An instance of `paddle.io.Dataset` to predict on.
            test_sample_split (int, optional): Each sample of the test dataset is assumed
                to be an (input, label) pair with two items. If each sample has more than
                two items, test_sample_split specifies how to split them: the items
                before it are inputs and the rest are labels. Default: None.
            batch_size (int, optional): The batch size of test_data. The user's data will
                be used directly without batching if set to None. Default: 1.
            steps (int, optional): The total number of steps (batches of samples) to draw
                before stopping prediction. If None, prediction runs until `test_data`
                is exhausted. A new dataloader is prepared on every call, so prediction
                starts from the beginning of the dataset each time. Default: None.
            collate_fn (callable, optional): Function that builds a mini-batch by merging
                the sample list. None stacks each field of the samples along axis 0.
                Default: None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during testing. Currently unused. Default: None.

        Returns:
            The fetched outputs of the last executed step (as returned by the
            executor), or None if no step was executed.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                valid_dataset = MNIST(mode='test', transform=transform)

                model = paddle.vision.models.LeNet()

                engine = auto.Engine(model)
                engine.predict(valid_dataset, batch_size=64)
        """
        self.mode = 'predict'
        inputs, labels = self._split_sample_item(test_data, test_sample_split)
        self._infer_sample_spec(inputs, labels, batch_size)
        if not self._mode_init_states[self.mode]:
            self._prepare_program(self.mode)
        else:
            self._switch_mode("predict")

        assert self.mode in self._dist_main_progs, \
            "predict model is not ready, please call `engine._prepare_program('predict')` first."
        test_dataloader = self._prepare_dataloader(test_data,
                                                   batch_size,
                                                   steps_per_epoch=steps,
                                                   collate_fn=collate_fn)

        fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch(
            mode=self.mode)

        # Initialised so that an empty loader (or EOF on the very first step)
        # returns None instead of raising UnboundLocalError.
        outs = None
        for step, _ in enumerate(test_dataloader):
            try:
                outs = self._executor.run(
                    self.main_program,
                    fetch_list=fetch_list,
                    use_program_cache=self._strategy.use_cache,
                    return_numpy=self._strategy.return_numpy)
            except core.EOFException:
                break
            self._print_log(outs, self.mode, None, step, None, fetch_new_names,
                            fetch_sections)

        return outs
895

896 897
    def _tune(self, tune_data, tune_sample_split=None, batch_size=1):
        self.mode = 'train'
898 899
        inputs, labels = self._split_sample_item(tune_data, tune_sample_split)
        self._infer_sample_spec(inputs, labels, batch_size)
900 901
        self._optimization_tuning(self.mode, tune_data, batch_size)

902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928
    def dataloader(self,
                   dataset,
                   sample_split=1,
                   batch_size=1,
                   epochs=1,
                   steps_per_epoch=None,
                   collate_fn=None,
                   mode="train",
                   from_generator=True):
        """
        Build a dataloader that feeds the distributed program of ``mode``.

        Args:
            dataset (Dataset): The dataset to load samples from.
            sample_split (int, optional): Index that splits each sample into
                inputs (items before it) and labels (the rest). Default: 1.
            batch_size (int, optional): Batch size. Default: 1.
            epochs (int, optional): Number of epochs. Default: 1.
            steps_per_epoch (int, optional): Forwarded to the dataloader.
                Default: None.
            collate_fn (callable, optional): Batch collation function.
                Default: None.
            mode (str, optional): Engine mode whose program the loader is bound
                to, e.g. "train", "eval" or "predict". Default: "train".
            from_generator (bool, optional): Only True is supported.
                Default: True.

        Returns:
            The dataloader created by ``_prepare_dataloader``.
        """
        assert from_generator, "Only support from_generator for now"
        self.mode = mode
        inputs, labels = self._split_sample_item(dataset, sample_split)
        self._infer_sample_spec(inputs, labels, batch_size)
        if not self._mode_init_states[self.mode]:
            self._prepare_program(self.mode)
        else:
            # Switch to the requested mode (not unconditionally to "train"),
            # otherwise an eval/predict loader would be built against the
            # train program once that mode is already initialised.
            self._switch_mode(self.mode)
        dataloader = self._prepare_dataloader(dataset, batch_size, epochs,
                                              steps_per_epoch, collate_fn)
        return dataloader

    def _prepare_dataloader(self,
                            dataset,
                            batch_size,
                            epochs=1,
                            steps_per_epoch=None,
                            collate_fn=None):
        """
        Create a dataloader for the current mode and splice its read ops into
        the front of this rank's distributed main program.

        Args:
            dataset (Dataset): The dataset to read from.
            batch_size (int|None): Batch size. When gradient merge is configured
                it must be divisible by ``self._k_steps`` and is divided by it.
            epochs (int, optional): Number of epochs. Default: 1.
            steps_per_epoch (int, optional): Forwarded to the loader.
                Default: None.
            collate_fn (callable, optional): Forwarded to the loader.
                Default: None.

        Returns:
            NonIterableGeneratorLoader: the loader built under
            ``static.program_guard`` of the current mode's programs.
        """
        if self._strategy.gradient_merge and batch_size is not None:
            assert batch_size % self._k_steps == 0, \
                "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self._k_steps)
            batch_size //= self._k_steps

        mode = self.mode
        rank = self._cur_rank
        main_prog = self._dist_main_progs[mode][rank]
        startup_prog = self._dist_startup_progs[mode][rank]
        dist_context = self._dist_contexts[mode]
        main_block = main_prog.global_block()

        # NOTE: The predict program does not contain the label vars, so any feed
        # var missing from the distributed block is cloned into it. This keeps
        # len(feed_list) equal to the number of values the dataset yields.
        feed_vars = self._feed_vars[mode]
        feed_list = []
        for var in feed_vars["inputs"] + feed_vars["labels"]:
            if var.name in main_block.vars:
                feed_list.append(main_block.vars[var.name])
                continue
            cloned = main_block._clone_variable(var, var.persistable)
            cloned.desc.set_original_id(var.desc.original_id())
            feed_list.append(cloned)

        # A previous fit/evaluate/predict call left three reader ops at the
        # front of the block (starting with create_py_reader); remove them.
        op_size = len(main_block.ops)
        if main_block.ops[0].type == 'create_py_reader':
            op_size -= 3
            for _ in range(3):
                main_block._remove_op(0, sync=False)

        # Creating the loader appends its read ops at the end of the block.
        places = paddle.static.cuda_places()
        with static.program_guard(main_prog, startup_prog):
            dataloader = NonIterableGeneratorLoader(
                dataset,
                feed_list,
                places,
                batch_size,
                epochs,
                steps_per_epoch,
                collate_fn,
                data_parallel_world_size=self._dp_world_sizes,
                data_parallel_rank=self._dp_ranks,
                split_data=self._strategy.split_data)

        # Move the freshly appended read ops to the front, keeping their order.
        new_op_size = len(main_block.ops)
        num_read_ops = new_op_size - op_size
        for _ in range(num_read_ops):
            # Every prepend shifts the op list right by one, so the next read op
            # still to be moved always sits at index new_op_size - 1.
            src_op = main_block.ops[new_op_size - 1]
            prepended_desc = main_block.desc._prepend_op()
            prepended_desc.copy_from(src_op.desc)
            prepended_op = Operator(main_block,
                                    prepended_desc,
                                    type=prepended_desc.type())
            main_block.ops.insert(0, prepended_op)
            dist_context.add_dist_op_for_program(
                DistributedOperator(prepended_op))
        # The original copies now occupy [new_op_size, new_op_size + num_read_ops).
        for _ in range(num_read_ops):
            main_block._remove_op(new_op_size, sync=False)
        main_block._sync_with_cpp()
        return dataloader

994 995
    def _validate_spec(self, specs):
        """
        Normalize ``specs`` with ``to_list`` and validate every ``InputSpec``.

        As a side effect, ``self._k_steps`` is set from
        ``self._strategy.gradient_merge.k_steps``. When ``k_steps > 1`` the
        leading (batch) dimension of each spec is divided by ``k_steps`` and
        written back to ``spec.shape`` in place. A dynamic leading dimension
        (None or negative, e.g. -1) and rank-0 specs are left untouched because
        they cannot be split ahead of time.

        Args:
            specs (InputSpec|list|tuple|None): The specs to validate.

        Returns:
            list|None: ``to_list(specs)``, with shapes adjusted as described.

        Raises:
            ValueError: If a spec has no name.
        """
        specs = to_list(specs)
        self._k_steps = self._strategy.gradient_merge.k_steps
        if specs is not None:
            for i, spec in enumerate(specs):
                assert isinstance(spec, InputSpec)
                if spec.name is None:
                    raise ValueError(
                        "Requires Input[{}].name != None, but receive `None` with {}."
                        .format(i, spec))
                if self._k_steps > 1:
                    shape = list(spec.shape)
                    # Previously a dynamic batch dim crashed here (TypeError on
                    # None, a bogus divisibility failure on -1, IndexError on
                    # rank 0); such specs are now left as they are.
                    if not shape or shape[0] is None or shape[0] < 0:
                        continue
                    assert shape[0] % self._k_steps == 0, \
                        "Requires batch_size[{}] to be divisible by k_steps[{}].".format(spec.shape[0], self._k_steps)
                    shape[0] //= self._k_steps
                    spec.shape = shape
        return specs

1012 1013 1014 1015
    def _is_local_var(self, var):
        """Return True if ``var``'s name (as resolved by ``_to_name_str``)
        exists in the global block of the current mode's distributed main
        program on this rank."""
        name = _to_name_str(var)
        global_vars = self.main_program.global_block().vars
        return name in global_vars

1016 1017
    def _get_input_split_info(self, var, dist_context):
        """Deduce how the input data ``var`` is split among the cluster.

        Dimension 0 of ``var`` is treated as the batch dimension. If its dims
        mapping points at a process-mesh axis with more than one process, the
        data is split across the communication group along that axis.

        Returns:
            tuple(int, int): ``(group_size, index of this rank in the group)``,
            or ``(1, 0)`` when dimension 0 is not split.
        """
        from .utils import _get_comm_group, _get_corresponding_rank

        dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
        mesh = dist_attr.process_mesh
        dims_mapping = dist_attr.dims_mapping

        # A rank outside the tensor's mesh is mapped to its corresponding rank
        # inside the mesh.
        if self._cur_rank in mesh.processes:
            rank_id = self._cur_rank
        else:
            rank_id = _get_corresponding_rank(dist_context, mesh,
                                              self._cur_rank)

        batch_axis = dims_mapping[0]
        is_split = batch_axis > -1 and mesh.topology[batch_axis] > 1
        if not is_split:
            return 1, 0

        group_ranks = _get_comm_group(mesh.processes, mesh.topology,
                                      batch_axis, rank_id)
        return len(group_ranks), group_ranks.index(rank_id)
1038

1039 1040 1041 1042
    def _set_recompute_ckpts(self):
        """Choose recompute checkpoints for the model and, if recompute is
        enabled, write them into the strategy and log them.

        For a ``paddle.nn.Layer`` named ``GPTForPretraining`` or
        ``GPTForPretrainingAuto`` that has a ``gpt`` attribute, the
        checkpoints come from ``self._model.gpt.checkpoints``; otherwise the
        strategy's own ``recompute.checkpoints`` are kept.
        """
        # NOTE hack to enable recompute in engine api for GPT-3
        # TODO support more PaddleNLP/CV models here
        recompute = self._strategy.recompute
        model = self._model

        is_known_gpt = (isinstance(model, paddle.nn.Layer)
                        and hasattr(model, "gpt")
                        and model.__class__.__name__ in [
                            'GPTForPretraining', 'GPTForPretrainingAuto'
                        ])
        exact_ckpts = (model.gpt.checkpoints
                       if is_known_gpt else recompute.checkpoints)

        if not recompute.enable:
            return

        # Store a copy so the strategy does not alias the model's list.
        recompute.checkpoints = exact_ckpts[:]
        self._logger.info({
            'Model Class': model.__class__.__name__,
            'Applied Recompute ckpts': exact_ckpts
        })

1066
    def _validate_opt(self, optimizer):
1067 1068 1069
        if optimizer is not None:
            optimizer._parameter_list = None
            optimizer._param_groups = None
1070 1071
        return optimizer

1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088
    def _reset_metrics(self):
        for metric in self._metrics:
            metric.reset()

    def _switch_mode(self, mode):
        self.mode = mode
        self._initialize(mode)

    def _set_state_dict(self, mode, strict, state_dict, dist_attr):
        """Load ``state_dict`` into this rank's distributed program of ``mode``.

        ``state_dict`` was saved under ``dist_attr``. ``Converter`` re-maps it
        onto the program's current distributed attributes before it is set;
        ``strict`` is forwarded to ``Converter.convert``.
        """
        main_prog = self._dist_main_progs[mode][self._cur_rank]
        target_dist_attr = get_dist_attr(main_prog, self._dist_contexts[mode])
        converted = Converter(state_dict, dist_attr,
                              target_dist_attr).convert(strict=strict)
        main_prog.set_state_dict(converted)

    def save(self, path, training=True):
        """
        Saves the model, parameters and optimizer state to path.
        If `training` is set to False, only the inference model will be saved.

        Args:
            path (str): The file prefix to save the model to. The format
                is 'dirname/file_prefix' or 'file_prefix'; it must not be
                an empty string.
            training (bool, optional): Whether to save for training. If True,
                the train program (including optimizer state) is saved. If
                False, only the predict program is saved as an inference
                model. Existing files at the target location may be
                overwritten. Default: True.

        Returns:
            None

        Raises:
            AssertionError: If the program required by `training` (train or
                predict) has not been prepared yet.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                train_dataset = MNIST(mode='train', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                optimizer = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=model.parameters())
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, optimizer, metrics)
                engine.fit(train_dataset,
                           epochs=1,
                           batch_size=64)
                engine.save("./my_model")

        """
        if training:
            assert 'train' in self._serial_main_progs, \
                "training model is not ready, please call `engine._prepare_program('train')` first."
            serial_program = self._serial_main_progs["train"]
            dist_main_prog = self._dist_main_progs["train"][self._cur_rank]
            dist_context = self._dist_contexts["train"]
            self._saver.save(path,
                             serial_program=serial_program,
                             dist_main_program=dist_main_prog,
                             dist_context=dist_context)
        else:
            mode = "predict"
            # Mirror the train-branch check so an unprepared predict program
            # gives a clear message instead of a bare KeyError.
            assert mode in self._dist_main_progs, \
                "predict model is not ready, please call `engine._prepare_program('predict')` first."
            feed_vars = self._feed_vars[mode]['inputs']
            fetch_vars = self._fetch_vars[mode]['outputs']
            dist_main_prog = self._dist_main_progs[mode][self._cur_rank]
            self._saver.save_inference_model(path,
                                             feed_vars,
                                             fetch_vars,
                                             self._executor,
                                             program=dist_main_prog)
1153

1154 1155 1156 1157 1158 1159
    def load(self, path, strict=True, load_optimizer=True):
        """
        Load the stored model, parameters and optimizer states.

        The loaded states are kept on the engine (``_state_dict`` and
        ``_dist_attr``) together with ``strict`` (as ``_strict``), and are
        also returned.

        Args:
            path (str): The prefix of files storing the model states and
                optimizer states.
            strict (bool, optional): Stored on the engine for use when the
                loaded states are applied. It controls how strictly mismatched
                parameters are handled (not found in the file, or with a
                mismatched shape). Default: True.
            load_optimizer (bool, optional): Whether to also load the stored
                optimizer states; forwarded to the saver. Default: True.

        Returns:
            tuple: ``(state_dict, dist_attr)`` as returned by the saver.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                train_dataset = MNIST(mode='train', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                optimizer = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=model.parameters())
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, optimizer, metrics)
                engine.fit(train_dataset,
                           epochs=1,
                           batch_size=64)
                engine.save("./my_model")
                engine.load("./my_model")

        """
        self._strict = strict
        self._state_dict, self._dist_attr = self._saver.load(
            path, load_optimizer)
        return self._state_dict, self._dist_attr
1204

1205
    @staticmethod
1206
    def _get_lr_scheduler(program):
1207 1208 1209 1210 1211 1212 1213
        lr_sheduler = None
        if hasattr(program, 'lr_sheduler'):
            from paddle.optimizer.lr import LRScheduler
            lr_sheduler = program.lr_sheduler
            assert isinstance(lr_sheduler, LRScheduler), "must be LRScheduler"
        return lr_sheduler

1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227
    def _get_lr(self, optimizer):
        """Return the current learning rate of ``optimizer``.

        Supports ``paddle.optimizer.Optimizer`` (via ``get_lr()``) and
        ``paddle.fluid.optimizer.Optimizer``. For the latter, ``_learning_rate``
        is returned when it is a float; otherwise it is called.

        Raises:
            TypeError: For any other optimizer type.
        """
        if isinstance(optimizer, paddle.optimizer.Optimizer):
            return optimizer.get_lr()
        if isinstance(optimizer, paddle.fluid.optimizer.Optimizer):
            learning_rate = optimizer._learning_rate
            if isinstance(learning_rate, float):
                return learning_rate
            return learning_rate()
        raise TypeError(
            "'optimizer' must be object of class `paddle.optimizer.Optimizer`"
            " or `paddle.fluid.optimizer.Optimizer`, but got {}.".format(
                type(optimizer)))

1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254
    @property
    def mode(self):
        """Name of the active engine mode (e.g. 'train', 'eval', 'predict')."""
        current_mode = self._mode
        return current_mode

    @mode.setter
    def mode(self, new_mode):
        # Plain storage only: no validation or re-initialisation happens here.
        self._mode = new_mode

    @property
    def main_program(self):
        """Distributed main program of the current mode on this rank."""
        progs_of_mode = self._dist_main_progs[self.mode]
        return progs_of_mode[self._cur_rank]

    @property
    def startup_program(self):
        """Distributed startup program of the current mode on this rank."""
        progs_of_mode = self._dist_startup_progs[self.mode]
        return progs_of_mode[self._cur_rank]

    @property
    def dist_context(self):
        """Distributed context associated with the current mode."""
        current_mode = self.mode
        return self._dist_contexts[current_mode]

    @property
    def serial_main_program(self):
        """Serial (non-partitioned) main program of the current mode."""
        current_mode = self.mode
        return self._serial_main_progs[current_mode]

    @property
    def serial_startup_program(self):
        """Serial (non-partitioned) startup program of the current mode."""
        current_mode = self.mode
        return self._serial_startup_progs[current_mode]
1255 1256 1257 1258

    @property
    def fetch_vars(self):
        """Fetch variables registered for the current mode."""
        current_mode = self.mode
        return self._fetch_vars[current_mode]
1259 1260 1261 1262 1263 1264 1265 1266

    @property
    def inputs(self):
        """The engine's input specs (alias of ``inputs_spec``)."""
        specs = self.inputs_spec
        return specs

    @property
    def labels(self):
        """The engine's label specs (alias of ``labels_spec``)."""
        specs = self.labels_spec
        return specs