# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import logging
import numbers
import os
import random

import numpy as np

import paddle
import paddle.distributed.auto_parallel.static.utils as auto_utils
from paddle import static, utils
from paddle.distributed import fleet
from paddle.fluid.executor import _to_name_str
from paddle.framework import IrGraph
from paddle.framework import _current_expected_place as _get_device
from paddle.framework import core, in_dynamic_mode
from paddle.metric import Metric
from paddle.static import InputSpec, Operator, Variable, global_scope

from ...utils.log_utils import get_logger
from ..interface import CollectionNames, fetch, get_collection
from ..strategy import Strategy
from .callbacks import config_callbacks
from .cluster import Cluster, get_default_cluster
from .converter import Converter
from .cost.estimate_cost import get_cost_from_engine
from .dist_context import DistributedContext, get_default_distributed_context
from .dist_loader import (
    DistributedDataLoader,
    DistributedDataLoaderFromGenerator,
)
from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver
from .helper import ProgramHelper
from .parallelizer_v2 import Parallelizer
from .planner_v2 import Planner
from .process_group import get_all_process_groups, new_process_group


class Engine:
    """
    An Engine object can provide the full power of auto parallel to users.
    With its help, users can easily obtain the abilities of both
    distributed training and inference. It also supports dynamic graph and
    static graph at the same time.

    Args:
        model (paddle.nn.Layer, optional): The model is an instance of
            paddle.nn.Layer.
        loss (Loss|Callable|None, optional): The loss can be a `paddle.nn.Layer`
            instance or any callable function that takes the predicted values and
            ground truth values as input. It can be None when there is no loss.
            Default: None.
        optimizer (Optimizer|None, optional): The optimizer needs to be set in training
            and should be None in eval and predict mode. Default: None.
        metrics (Metric|list[Metric]|None, optional): If metrics is set, all
            metrics will be calculated and output in train/eval mode. Default: None.
        cluster (Cluster|None, optional): The cluster represents the topology information
            about the used physical devices. Default: None. (Unused for now)
        strategy (Strategy|None, optional): The strategy is used to configure the
            parallelization and optimization behaviors. Default: None.

    Examples:

        .. code-block:: python

            >>> import paddle
            >>> import paddle.vision.transforms as T
            >>> from paddle.distributed.fleet import auto
            >>> from paddle.vision.datasets import MNIST

            >>> transform = T.Compose([
            ...     T.Transpose(),
            ...     T.Normalize([127.5], [127.5])
            ... ])
            >>> train_dataset = MNIST(mode='train', transform=transform)
            >>> valid_dataset = MNIST(mode='test', transform=transform)

            >>> model = paddle.vision.models.LeNet()
            >>> loss = paddle.nn.CrossEntropyLoss()
            >>> optimizer = paddle.optimizer.Adam(
            ...     learning_rate=0.001, parameters=model.parameters())
            >>> metrics = paddle.metric.Accuracy(topk=(1, 2))

            >>> engine = auto.Engine(model, loss, optimizer, metrics)
            >>> # fit
            >>> engine.fit(train_dataset,
            ...            epochs=2,
            ...            batch_size=64)
            >>> # evaluate
            >>> engine.evaluate(valid_dataset,
            ...                 batch_size=64)
            >>> # predict
            >>> engine.predict(valid_dataset,
            ...                batch_size=64)
            >>> # save
            >>> engine.save("./my_model")
            >>> # load
            >>> engine.load("./my_model")

    """

    def __init__(
        self,
        model=None,
        loss=None,
        optimizer=None,
        metrics=None,
        cluster=None,
        strategy=None,
    ):
        if (
            model
            and not isinstance(model, paddle.nn.Layer)
            and not callable(model)
        ):
            raise TypeError(
                "'model' must be a subclass of `paddle.nn.Layer` or any callable function."
            )
        self._model = model
        self._parameter_list = (
            None if not model else [p.name for p in model.parameters()]
        )

        if (
            loss
            and not isinstance(loss, (paddle.nn.Layer, Variable))
            and not callable(loss)
        ):
            raise TypeError(
                "'loss' must be a subclass of `paddle.nn.Layer`, a callable function, or a Variable."
            )
        self._loss = loss

        if optimizer and not isinstance(
            optimizer,
            (paddle.optimizer.Optimizer),
        ):
            raise TypeError(
                "'optimizer' must be an object of class `paddle.optimizer.Optimizer`"
            )
        self._optimizer = auto_utils.validate_opt(optimizer)

        metrics = metrics or []
        for metric in auto_utils.to_list(metrics):
            if metric and not isinstance(metric, Metric):
                raise TypeError(
                    "{} is not a subclass of Metric".format(
                        metric.__class__.__name__
                    )
                )
        self._metrics = auto_utils.to_list(metrics)

        if cluster and not isinstance(cluster, Cluster):
            raise TypeError(
                "'cluster' must be an object of class `paddle.distributed.auto_parallel.Cluster`"
            )
        self._cluster = cluster or get_default_cluster()

        if strategy and not isinstance(strategy, Strategy):
            raise TypeError(
                "'strategy' must be an object of class `paddle.distributed.auto_parallel.Strategy`"
            )
        self._strategy = strategy or Strategy()

        self._logger = get_logger(logging.INFO)

        self._json_config = None
        if cluster:
            self._cluster = cluster
        else:
            if os.getenv("PADDLE_AUTO_PARALLEL_CONFIG"):
                try:
                    path = os.getenv("PADDLE_AUTO_PARALLEL_CONFIG")
                    with open(path, "r") as f:
                        self._json_config = json.load(f)
                except Exception:
                    self._logger.info(
                        "Failed to load the JSON config file; please check it. The engine will run with the default config."
                    )
                    self._json_config = None
            self._cluster = get_default_cluster(self._json_config)

        if os.getenv("POD_NAME"):
            self._logger.info(
                "Distributed training launched by paddle.distributed.launch"
            )
            fleet.init(is_collective=True)

        # for cost estimation
        # TODO: remove _fwd_main_progs and _orig_optimizer
        self._fwd_dist_contexts = {}
        self._fwd_main_progs = {}
        self._orig_optimizer = copy.deepcopy(self._optimizer)

        self._executor = None
        self._cur_rank = paddle.distributed.get_rank()
        self._nranks = paddle.distributed.get_world_size()
        self._saver = DistributedSaver()

        self._orig_main_prog = static.default_main_program()
        self._orig_startup_prog = static.default_startup_program()
        self._orig_dist_context = get_default_distributed_context()
        self._dist_contexts = {}
        self._planners = {}
        self._has_prepared = {"train": False, "eval": False, "predict": False}
        self._has_prepared_reader = {
            "train": False,
            "eval": False,
            "predict": False,
        }
        self._inputs_spec = []
        self._labels_spec = []
        self._inputs = []
        self._labels = []
        self._losses = []

        self._mode = None
        self._skip_build = False
        self._outside_dataloader = False
        self._planned_mode = None
        self._dygraph_mode = False
        self._tuning = self._strategy.tuning
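        # Number of micro-batches accumulated before each parameter update,
        # e.g. gradient_merge with k_steps=4 (or a pipeline with
        # accumulate_steps=4) applies one optimizer step every 4 micro-batches.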
        self._acc_steps = 1
        if self._strategy.gradient_merge.enable:
            self._acc_steps = self._strategy.gradient_merge.k_steps
        elif self._strategy.pipeline.enable:
            self._acc_steps = self._strategy.pipeline.accumulate_steps

        if (
            self._strategy.pipeline.enable
            and self._strategy.pipeline.schedule_mode == "1F1B"
        ):
            assert (
                os.getenv("CUDA_MODULE_LOADING") != "LAZY"
            ), "EXP_CUDA_MODULE_LOADING_LAZY not supported in 1F1B pipeline."

        self.history = None

        paddle.framework.set_flags({'FLAGS_new_executor_sequential_run': 1})
        paddle.framework.set_flags({'FLAGS_new_executor_static_build': 1})

    def _prepare_data_spec(self, data, split, batch_size):
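        """Infer the ``InputSpec``s of inputs and labels from the first sample of ``data``.

        A sketch of the expected result (shapes and dtypes here are hypothetical
        and depend on the dataset; the names follow the ``input{i}``/``label{i}``
        naming used below):

        .. code-block:: python

            >>> inputs_spec, labels_spec = engine._prepare_data_spec(
            ...     train_dataset, split=None, batch_size=64)
            >>> # inputs_spec ~ [InputSpec(shape=(64, 1, 28, 28), dtype='float32', name='input0')]
            >>> # labels_spec ~ [InputSpec(shape=(64, 1), dtype='int64', name='label0')]
        """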
        inputs_spec = []
        labels_spec = []
        if isinstance(data, paddle.io.IterableDataset):
            if split is None:
                inputs, labels = next(iter(data))
            else:
                sample = next(iter(data))
                inputs = sample[:split]
                labels = sample[split:]
        elif isinstance(data, paddle.io.Dataset):
            if split is None:
                inputs, labels = data[0]
            else:
                sample = data[0]
                inputs = sample[:split]
                labels = sample[split:]
        else:
            raise TypeError(
                "Data should be a Dataset or IterableDataset, but received {}.".format(
                    type(data).__name__
                )
            )
        inputs = auto_utils.to_list(inputs)
        labels = auto_utils.to_list(labels)

        num_shards = self._strategy.dataset.num_shards

        def _adjust_item_spec(num_shards, spec):
            if num_shards > 1 and len(spec.shape) > 1:
                spec.shape[0] = spec.shape[0] * num_shards

        def _infer_item_spec(item, name, batch_size, specs):
            if isinstance(item, np.ndarray):
                spec = InputSpec.from_numpy(item, name)
                if batch_size is None:
                    _adjust_item_spec(num_shards, spec)
                    specs.append(spec)
                else:
                    specs.append(spec.batch(batch_size))
            elif isinstance(item, (Variable, core.eager.Tensor)):
                spec = InputSpec.from_tensor(item, name)
                _adjust_item_spec(num_shards, spec)
                if batch_size is None:
                    specs.append(spec)
                else:
                    specs.append(spec.batch(batch_size))
            elif isinstance(item, numbers.Number):
                specs.append(InputSpec([batch_size], type(item), name))
            else:
                raise TypeError(
                    "Each sample item returned by the dataset should be a number, np.ndarray or Tensor, but got {}".format(
                        type(item).__name__
                    )
                )

        if inputs is not None:
            for i, item in enumerate(inputs):
                assert item is not None, "Received a None input."
                name = "input" + str(i)
                _infer_item_spec(item, name, batch_size, inputs_spec)
        if labels is not None:
            for i, item in enumerate(labels):
                assert item is not None, "Received a None label."
                name = "label" + str(i)
                _infer_item_spec(item, name, batch_size, labels_spec)

        inputs_spec = self._validate_spec(inputs_spec)
        labels_spec = self._validate_spec(labels_spec)
        return inputs_spec, labels_spec

    def _prepare_data_tensor(self, inputs_spec, labels_spec, inputs, labels):
        if in_dynamic_mode() or self._dygraph_mode:
            raise ValueError("Only static graph mode is supported.")

        if inputs_spec:
            assert isinstance(
                inputs_spec, list
            ), "inputs_spec should be a list, but received {}".format(
                type(inputs_spec)
            )
            assert isinstance(
                inputs, list
            ), f"inputs should be a list, but received {type(inputs)}"
            assert len(inputs_spec) == len(
                inputs
            ), "the number of `inputs_spec` should be equal to `inputs`'s."
            for input_spec, input in zip(inputs_spec, inputs):
                if input_spec.shape != input.shape:
                    input.desc.set_shape(input_spec.shape)
        if labels_spec:
            assert isinstance(
                labels_spec, list
            ), "labels_spec should be a list, but received {}".format(
                type(labels_spec)
            )
            assert isinstance(
                labels, list
            ), f"labels should be a list, but received {type(labels)}"
            assert len(labels_spec) == len(
                labels
            ), "the number of `labels_spec` should be equal to `labels`'s."
            for label_spec, label in zip(labels_spec, labels):
                if label_spec.shape != label.shape:
                    label.desc.set_shape(label_spec.shape)

        return inputs, labels

    def _prepare_reader(self, feed_list=[]):
        dist_context = self._dist_contexts[self._mode]
        dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
        dist_main_block = dist_main_prog.global_block()

        # NOTE: this list may be changed if Paddle changes the existing rules.
        related_reader_ops = [
            "create_py_reader",
            "create_double_buffer_reader",
            "read",
        ]
        # remove the first three ops if fit/evaluate/predict is run multiple times
        if dist_main_block.ops[0].type == 'create_py_reader':
            for i in range(len(related_reader_ops)):
                if dist_main_block.ops[0].type in related_reader_ops:
                    dist_main_block._remove_op(0, sync=False)
        dist_main_block._sync_with_cpp()
        # Step 1: find the reader ops
        reader_op_indices = []
        for idx, op in enumerate(dist_main_block.ops):
            if op.type in related_reader_ops:
                reader_op_indices.append(idx)
        # Step 2: insert the new reader ops to cpp
        # record the read ops' desc to insert to program of forward task_node
        read_ops_desc = []
        new_reader_ops = []
        for idx in reversed(reader_op_indices):
            new_op_desc = dist_main_block.desc._prepend_op()
            new_op_desc.copy_from(dist_main_block.ops[idx].desc)
            read_ops_desc.append(new_op_desc)
            new_op = Operator(
                dist_main_block, new_op_desc, type=new_op_desc.type()
            )
            new_reader_ops.append(new_op)
            dist_op = DistributedOperator(new_op)
            dist_context.add_dist_op_for_program(dist_op)
        # Step 3: insert the new reader ops to python
        for new_op in new_reader_ops:
            dist_main_block.ops.insert(0, new_op)
        for i in range(len(reader_op_indices)):
            reader_op_indices[i] += len(reader_op_indices)
        # Step 4: remove the old reader ops from python and cpp
        for idx in reversed(reader_op_indices):
            op = dist_main_block.ops.pop(idx)
            dist_main_block.desc._remove_op(idx, idx + 1)
        dist_main_block._sync_with_cpp()
        self._has_prepared_reader[self._mode] = True

        # Insert read ops into the forward TaskNode for the fleet executor if the 1F1B pass is set
        if (
            self.main_program._pipeline_opt
            and not auto_utils.use_new_executor()
        ):
            assert "tasks" in self.main_program._pipeline_opt["fleet_opt"]
            fleet_opt = self.main_program._pipeline_opt["fleet_opt"]
            fwd_task = None
            if self._strategy.pipeline.schedule_mode == "1F1B":
                fwd_task = fleet_opt["tasks"][1]
            elif self._strategy.pipeline.schedule_mode == "stream":
                fwd_task = fleet_opt["tasks"][0]
            assert fwd_task is not None
            fwd_prog = fwd_task.get_program()
            fwd_block = fwd_prog.global_block()

            for var in feed_list:
                if var.name not in fwd_block.vars:
                    fwd_block._clone_variable(var)

            for op_desc in read_ops_desc:
                new_op_desc = fwd_block.desc._prepend_op()
                new_op_desc.copy_from(op_desc)
                new_op = Operator(
                    fwd_block, new_op_desc, type=new_op_desc.type()
                )
                fwd_block.ops.insert(0, new_op)

            fwd_block._sync_with_cpp()
            fwd_task.set_program(fwd_prog)

    def _prepare_feed(self, data, user_feeds, mode):
        feeds = {}
        if data is not None:
            if isinstance(data, (list, tuple)):
                if len(data) == 1 and isinstance(data[0], dict):
                    for name, value in data[0].items():
                        feeds[name] = value
                else:
                    raise ValueError(f"Unsupported data {data}")
            elif isinstance(data, dict):
                for name, value in data.items():
                    feeds[name] = value
            else:
                raise ValueError(f"Unsupported data {data}")
        if user_feeds is not None:
            assert isinstance(
                user_feeds, dict
            ), "user_feeds must be a dict, but received {}".format(
                type(user_feeds).__name__
            )
            for name, data in user_feeds.items():
                feeds[name] = data
        return feeds

    def _prepare_fetch(self, user_fetches, mode):
        if user_fetches is not None:
            assert isinstance(
                user_fetches, list
            ), "user_fetches must be a list, but received {}".format(
                type(user_fetches).__name__
            )
        fetch_names = []
        fetch_indices = []

        def _process_fetch_group(group_name, var_list):
            group_indices = []
            for var in var_list:
                # Remove duplicate var_names
                if self._is_local_var(var):
                    var_name = _to_name_str(var)
                    if var_name not in fetch_names:
                        fetch_names.append(var_name)
                    group_indices.append(fetch_names.index(var_name))
            fetch_indices.append(group_indices)

        dist_context = self._dist_contexts[mode]
        fetch_vars = dist_context.serial_fetch_vars
        if mode != "predict":
            _process_fetch_group("loss", fetch_vars["loss"])
            metrics = fetch_vars["metrics"]
            for i, var_list in enumerate(metrics):
                _process_fetch_group("metrics_" + str(i), var_list)
        if mode == "predict":
            _process_fetch_group("outputs", fetch_vars["outputs"])
        for usr_fetch in user_fetches or []:
            var_name = _to_name_str(usr_fetch)
            fetch(var_name)
        user_fetches_collection = [
            item[1] for item in get_collection(CollectionNames.FETCHES)
        ]
        var_list = user_fetches_collection or []
        _process_fetch_group("fetches", var_list)
        return fetch_names, fetch_indices

    def _prepare_logger(
        self,
        outs,
        epoch=None,
        step=None,
        lr=None,
        fetch_names=None,
        fetch_indices=None,
        mode=None,
    ):
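        """Assemble a ``logs`` dict from the executor outputs ``outs``.

        A sketch of the resulting structure (the exact keys depend on ``mode``
        and the configured metrics; the values shown are hypothetical):
        ``{"epoch": 0, "step": 1, "lr": 0.001, "loss": ..., "fetches": {...}}``
        """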
        logs = {}
        if epoch is not None:
            logs["epoch"] = epoch
        if step is not None:
            logs["step"] = step + 1
        if lr is not None:
            logs["lr"] = lr
        group_idx = 0
        if mode != "predict":
            # logging loss
            loss_indices = fetch_indices[group_idx]
            assert len(loss_indices) <= 1
            for idx in loss_indices:
                logs["loss"] = outs[idx]
            group_idx += 1
            # logging metrics
            dist_context = self._dist_contexts[mode]
            metric_vars = dist_context.serial_fetch_vars["metrics"]
            if metric_vars:
                for metric in self._metrics:
                    metrics_indices = fetch_indices[group_idx]
                    metric_out = []
                    for idx in metrics_indices:
                        metric_out.append(outs[idx])
                    if metric_out:
                        metric.update(*metric_out)
                        results = metric.accumulate()
                        for i, res in enumerate(auto_utils.to_list(results)):
                            logs[metric.name()[i]] = res
                    group_idx += 1
        # logging outputs
        elif mode == "predict":
            outputs_indices = fetch_indices[group_idx]
            logs_out = {}
            for idx in outputs_indices:
                logs_out["out%d" % (idx)] = outs[idx]
            logs["outputs"] = logs_out
            group_idx += 1
        # logging user fetches
        collect_fetches = get_collection(CollectionNames.FETCHES)
        logs_fetch = {}
        for name, var_name in collect_fetches:
            if var_name in fetch_names:
                idx = fetch_names.index(var_name)
                logs_fetch[name or var_name] = outs[idx]
        logs["fetches"] = logs_fetch
        return logs

    def _prepare_program(self, mode, init_parameters=True):
        # Do the build process
        self._build(mode)
        # Do the planning process
        self._plan(mode)
        # Do the parallel process
        self._parallel(mode)
        # Init comm
        self._init_comm()
        if init_parameters:
            # run the startup program to initialize parameters
            self._initialize(mode)
        self._has_prepared[mode] = True

    def _build(self, mode):
        if in_dynamic_mode() or self._dygraph_mode:
            paddle.disable_static()
            self._dygraph_mode = True
            self._logger.info("Building model with 'to_static' method.")

            self.program_helper = ProgramHelper(
                self._model,
                self._loss,
                self._metrics,
                self._inputs_spec,
                self._labels_spec,
            )
            # build forward main program
            with utils.unique_name.guard():
                self.program_helper.build_program(mode)

            self.concrete_program = self.program_helper.concrete_program
            serial_main_prog = self.program_helper.main_program
            serial_startup_prog = self.program_helper.startup_program

            self._inputs = self.program_helper.input_vars
            self._labels = self.program_helper.label_vars
            outputs = self.program_helper.output_vars
            self._losses = self.program_helper.loss_vars
            metrics = self.program_helper.metric_vars

            paddle.enable_static()
        else:
            # build program in static mode
            dist_context = self._dist_contexts.get(mode, None)
            if dist_context is not None:
                return

            outputs = []
            metrics = []
            self._losses = []
            serial_main_prog = self._orig_main_prog.clone()
            serial_startup_prog = self._orig_startup_prog.clone()
            if not self._skip_build:
                with static.program_guard(
                    serial_main_prog, serial_startup_prog
                ), utils.unique_name.guard():
                    self._inputs = [
                        s._create_feed_layer() for s in self._inputs_spec
                    ]
                    self._labels = [
                        s._create_feed_layer() for s in self._labels_spec
                    ]

                    outputs = auto_utils.to_list(self._model(*self._inputs))

                    if mode != "predict" and self._loss:
                        assert isinstance(
                            self._loss, paddle.nn.Layer
                        ) or callable(
                            self._loss
                        ), "the type of `loss` of the Engine arguments should be a subclass of `paddle.nn.Layer` or any callable function."
                        self._losses = auto_utils.to_list(
                            self._loss(*(outputs + self._labels))
                        )

                    if mode != "predict" and (outputs or self._labels):
                        for metric in self._metrics:
                            metrics.append(
                                auto_utils.to_list(
                                    metric.compute(*(outputs + self._labels))
                                )
                            )
            elif mode == "train":
                assert isinstance(
                    self._loss, Variable
                ), "the type of `loss` of the Engine arguments should be a Variable."
                self._losses = auto_utils.to_list(self._loss)

        default_ctx = get_default_distributed_context()
        if not default_ctx.has_annotation:
            # We build the world process group because the data parallel
            # needs all ranks by default.
            new_process_group(list(range(self._nranks)))
            default_ctx.data_parallel = True
            self._inputs = [
                auto_utils.set_data_parallel(var) for var in self._inputs
            ]
            self._labels = [
                auto_utils.set_data_parallel(var) for var in self._labels
            ]

        feed_vars = {"inputs": self._inputs, "labels": self._labels}

        fetch_vars = {
            "outputs": paddle.utils.flatten(outputs),
            "loss": self._losses,
            "metrics": metrics,
        }

        if mode != "train":
            serial_main_prog = serial_main_prog.clone(for_test=True)

        auto_utils.set_recompute_segments(
            self._model, self._losses, self._strategy, serial_main_prog
        )
        self._dist_contexts[mode] = DistributedContext(
            serial_main_prog,
            serial_startup_prog,
            self._optimizer,
            self._losses,
            feed_vars,
            fetch_vars,
            self._cluster,
            self._strategy,
693
            self._json_config,
694 695 696 697 698 699
        )
        self._fwd_dist_contexts[mode] = DistributedContext(
            serial_main_prog,
            serial_startup_prog,
            self._optimizer,
            self._losses,
700 701 702 703
            feed_vars,
            fetch_vars,
            self._cluster,
            self._strategy,
            self._json_config,
        )
        self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale
        self._fwd_main_progs[mode] = serial_main_prog.clone()

    def _optimization_tuning(self, mode, dataset, batch_size):
        if not self._tuning.enable:
            raise ValueError("Please set `tuning.enable=True`.")

        assert mode == "train"
        # Do the build process
        self._build(mode)
        # Do the planning process
        self._plan(mode)

        dataset.dp_world_size = self._dp_world_sizes
        dataset.dp_rank = self._dp_ranks

        from .tuner.optimization_tuner import OptimizationTuner

        self._optimization_tuner = OptimizationTuner(
            self._dist_contexts[mode],
            dataset,
            self._inputs_spec,
            self._labels_spec,
            batch_size=batch_size,
            rank=self._cur_rank,
        )

        self._optimization_tuner.tune()

        if self._tuning.run_after_tuning:
            # update the strategy
            self._dist_contexts[
                mode
            ]._strategy = self._optimization_tuner.get_best_config()

    def _plan(self, mode):
        if self._planned_mode is None:
            self._planned_mode = mode
        else:
            self._init_dist_context(mode)

        self._planners[mode] = Planner(mode, self._dist_contexts[mode])
        self._planners[mode].plan()

        # infer data parallel info
        inputs_var = self._dist_contexts[mode].serial_feed_vars["inputs"]
        labels_var = self._dist_contexts[mode].serial_feed_vars["labels"]
        block = self._dist_contexts[mode].serial_main_program.global_block()
        # TODO: check this feed_list
        feed_list = []
        for var in inputs_var + labels_var:
            if var.name in block.vars:
                feed_list.append(block.vars[var.name])

        self._dp_world_sizes = []
        self._dp_ranks = []
        for feed_var in feed_list:
            dp_world_size, dp_rank = auto_utils.get_input_split_info(
                self._cur_rank, feed_var, self._dist_contexts[mode]
            )
            self._dp_world_sizes.append(dp_world_size)
            self._dp_ranks.append(dp_rank)

    def _parallel(self, mode, all_ranks=False):
        # Parallelize program based on the planner's results
        # For now, the completer has to be passed to the Parallelizer,
        # because we may use it to complete the annotation of the backward and update.
        parallelizer = Parallelizer(
            mode,
            self._planners[mode].completer,
            self._dist_contexts[mode],
        )
        if not all_ranks:
            parallelizer.parallel(self._cur_rank, self._parameter_list)
        else:
            parallelizer.parallel_all(self._parameter_list)

    def _init_dist_context(self, mode):
        # Init dist_context['mode'] with the first planned dist_context
        # to guarantee that the train/eval/predict modes share the same parallel strategy
        dist_context = self._dist_contexts[mode]
        origin_main_prog = dist_context._original_serial_main_program
        ref_mode = self._planned_mode
        ref_dist_context = self._dist_contexts[ref_mode]
        ref_origin_main_prog = ref_dist_context._original_serial_main_program
        ref_blocks = ref_origin_main_prog.blocks
        for ib, block in enumerate(origin_main_prog.blocks):
            for iop, op in enumerate(block.ops):
                ref_op = ref_blocks[ib].ops[iop]
                assert (
                    op.type == ref_op.type
                ), "'{}' mode op '{}' is different from '{}' mode op '{}'.".format(
                    mode, op.type, ref_mode, ref_op.type
                )
                ref_op_dist_attr = (
                    ref_dist_context.get_op_dist_attr_for_program(ref_op)
                )
                dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr)

    def _init_comm(self):
        if self._nranks > 1:
            # Traverse each rank's program and each of its ops,
            # instantiating communication groups according to the process_mapping.
            all_process_groups = get_all_process_groups()

            if self._strategy.auto_mode == "full_random":
                auto_utils.initialize_pg_in_full_mode(
                    all_process_groups, self._cur_rank
                )
            else:
                for process_group in all_process_groups:
                    process_group.instantiate()

    def _initialize(self, mode):
        self._place = _get_device()
        if isinstance(self._place, paddle.framework.CUDAPlace):
            self._place = paddle.framework.CUDAPlace(
                paddle.distributed.ParallelEnv().dev_id
            )

        if self._strategy.seed:
            paddle.seed(self._strategy.seed + self._dp_ranks[0])
            np.random.seed(self._strategy.seed + self._dp_ranks[0])
            random.seed(self._strategy.seed + self._dp_ranks[0])

        dist_context = self._dist_contexts[mode]
        if self._dygraph_mode:
            dist_main_program = dist_context.dist_main_programs[self._cur_rank]
            self.program_helper.init(
                dist_main_program, self._place, dist_context
            )

        if self._executor is None:
            self._executor = paddle.static.Executor(self._place)
            uninitialized = []
            dist_startup_prog = dist_context.dist_startup_programs[
                self._cur_rank
            ]
            for var in dist_startup_prog.list_vars():
                scope_var = global_scope().find_var(var.name)
                if scope_var and scope_var.get_tensor()._is_initialized():
                    continue
                uninitialized.append(var)
            if uninitialized:
                prune_startup_prog = dist_startup_prog._prune(uninitialized)
                self._executor.run(prune_startup_prog)

            if hasattr(self, "_state_dict") and hasattr(self, "_dist_attr"):
                self._set_state_dict(
                    mode, self._strict, self._state_dict, self._dist_attr
                )

        if self._strategy.reinit:
            self._logger.info("NOTE: parameters will be re-initialized.")
            dist_startup_prog = dist_context.dist_startup_programs[
                self._cur_rank
            ]
            self._executor.run(dist_startup_prog)

    def fit(
        self,
        train_data,
        train_sample_split=None,
        batch_size=1,
        epochs=1,
        steps_per_epoch=None,
        log_freq=10,
        save_dir=None,
        save_freq=1,
        valid_data=None,
        valid_sample_split=None,
        valid_freq=1,
        valid_steps=None,
        collate_fn=None,
        callbacks=None,
        verbose=2,
        nvprof_range=[-1, -1],
    ):
        """
        Trains the model for a fixed number of epochs. If `valid_data` is set,
        evaluation will be done at the end of each epoch.

        Args:
            train_data (Dataset): An instance of paddle.io.Dataset. Default: None.
            train_sample_split (int, optional): Each sample of the train dataset is assumed
                to be an (input, label) pair by default and has two items. If each sample has
                more than two items, train_sample_split specifies how to split these items into
                input and label. The items before the split are input and the remaining ones are label. Default: None.
            batch_size (int, optional): The batch size of train_data and valid_data if provided.
                The user's data will be used directly without batching if set to None. Default: 1.
            epochs (int, optional): The number of epochs to train the model. Default: 1.
            steps_per_epoch (int, optional): The total number of steps (batches of samples)
                executed in one epoch before starting the next one. If None, it is equal to
                the number of samples in your dataset divided by the batch size. Default: None.
            valid_data (Dataset, optional): An instance of paddle.io.Dataset used for
                evaluation at the end of epoch. No evaluation will be done if set to None.
                Default: None. (Unsupported for now)
            valid_freq (int, optional): Only relevant if valid_data is provided. This specifies
                how many training epochs are run before a new evaluation is performed. Default: 1.
            valid_sample_split (int, optional): Only relevant if valid_data is provided.
                Each sample of the valid dataset is assumed to be an (input, label) pair
                by default and has two items. If each sample has more than two items,
                valid_sample_split specifies how to split these items into input and label.
                The items before the split are input and the remaining ones are label. Default: None.
            valid_steps (int, optional): Only relevant if valid_data is provided.
                It is the total number of steps (batches of samples) to draw before
                stopping validation at the end of every epoch. If None, validation will run until the
                `valid_data` dataset is exhausted. The validation will start from the
                beginning of the dataset at each epoch. Default: None.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list. None means only stacking each field of the samples along
                axis 0. Default: None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during training. Default: None. (Unused for now)
            nvprof_range(list, optional): A list of two integers indicating the nvprof range, in the form of [start_step, end_step]. Note that if start_step >= end_step, nvprof will not be applied.

        Returns:
            None

        Examples:

            .. code-block:: python

                >>> import paddle
                >>> import paddle.vision.transforms as T
                >>> from paddle.distributed.fleet import auto
                >>> from paddle.vision.datasets import MNIST

                >>> transform = T.Compose([
                ...     T.Transpose(),
                ...     T.Normalize([127.5], [127.5])
                ... ])
                >>> train_dataset = MNIST(mode='train', transform=transform)

                >>> model = paddle.vision.models.LeNet()
                >>> loss = paddle.nn.CrossEntropyLoss()
                >>> optimizer = paddle.optimizer.Adam(
                ...     learning_rate=0.001, parameters=model.parameters())
                >>> metrics = paddle.metric.Accuracy(topk=(1, 2))

                >>> engine = auto.Engine(model, loss, optimizer, metrics)
                >>> engine.fit(train_dataset,
                ...             epochs=2,
                ...             batch_size=64)
        """
        self._mode = 'train'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            train_data, train_sample_split, batch_size
        )

        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        if auto_utils.use_new_executor():
            local_batch_size = self._validate_batch_size(batch_size)
            train_dataloader = self._prepare_dataloader(
                train_data,
                return_list=False,
                batch_size=local_batch_size,
                epochs=epochs,
                collate_fn=collate_fn,
            )
            steps_per_epoch = (
                len(train_dataloader)
                if steps_per_epoch is None
                else steps_per_epoch
            )
        else:
            micro_batch_size = self._validate_batch_size(batch_size)
            train_dataloader = self._prepare_dataloader_from_generator(
                dataset=train_data,
                capacity=70,
                iterable=False,
                batch_size=micro_batch_size,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                collate_fn=collate_fn,
            )
            steps_per_epoch = train_dataloader._steps
            local_batch_size = micro_batch_size
            if self._strategy.pipeline.enable:
                local_batch_size = micro_batch_size * self._acc_steps

        fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)

        cbks = config_callbacks(
            callbacks,
            engine=self,
            batch_size=local_batch_size,
            epochs=epochs,
            steps=steps_per_epoch,
            log_freq=log_freq,
            save_freq=save_freq,
            save_dir=save_dir,
            verbose=verbose,
            metrics=self._metrics_name(),
            acc_step=1
            if self._strategy.pipeline.enable
            else self._acc_steps,  # lr update once every local batch
        )

        cbks.on_begin('train')
        for epoch in range(epochs):
            logs = {}
            cbks.on_epoch_begin(epoch)

            for step, data in enumerate(train_dataloader):
                if auto_utils.use_new_executor():
                    feeds = self._validate_feed(data)
                else:
                    feeds = [{}]

                try:
                    for micro_feed in feeds:
                        with paddle.profiler.utils._nvprof_range(
                            iter_id=step,
                            start=nvprof_range[0],
                            end=nvprof_range[1],
                        ):
                            cbks.on_batch_begin('train', step, logs)
                            outs = self._executor.run(
                                self.main_program,
                                feed=micro_feed,
                                fetch_list=fetch_names,
                                use_program_cache=self._strategy.use_cache,
                                return_numpy=self._strategy.return_numpy,
                            )
                            lr = auto_utils.get_lr(self.optimizer)
                            logs = self._prepare_logger(
                                outs,
                                epoch,
                                step,
                                lr,
                                fetch_names,
                                fetch_indices,
                                self._mode,
                            )
                            cbks.on_batch_end('train', step, logs)
                except core.EOFException:
                    break

                if steps_per_epoch and step >= steps_per_epoch:
                    if not auto_utils.use_new_executor():
                        train_dataloader._reset()
                    break

            if valid_data and (epoch + 1) % valid_freq == 0:
                val_logs = self.evaluate(
                    valid_data,
                    valid_sample_split,
                    batch_size,
                    valid_steps,
                    log_freq,
                    collate_fn,
                    callbacks,
                    verbose,
                )
                val_logs = {
                    "val_" + name: val for name, val in val_logs.items()
                }
                logs.update(val_logs)
                self._switch_mode("train")
            else:
                self._reset_metrics()

            cbks.on_epoch_end(epoch, logs)

        cbks.on_end('train', logs)
        return self.history

    def evaluate(
        self,
        valid_data,
        valid_sample_split=None,
        batch_size=1,
        steps=None,
        log_freq=10,
        collate_fn=None,
        callbacks=None,
        verbose=2,
    ):
        """
        Evaluate the loss and metrics of the model on evaluation data.

        Args:
            valid_data (Dataset): An instance of paddle.io.Dataset. Default: None.
            valid_sample_split (int, optional): Each sample of the eval dataset is assumed
                to be an (input, label) pair by default and has two items. If each sample has
                more than two items, valid_sample_split specifies how to split these items into
                input and label. The items before the split are input and the remaining ones are label. Default: None.
            batch_size (int, optional): The batch size of valid_data. The user's data will
                be used directly without batching if set to None. Default: 1.
            steps (int, optional): It is the total number of steps (batches of samples) to draw before
                stopping evaluation. If None, evaluation will run until the `valid_data` dataset is exhausted.
                The evaluation will start from the beginning of the dataset in each run. Default: None.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list. None means only stacking each field of the samples along
                axis 0. Default: None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during evaluation. Default: None. (Unused for now)

        Returns:
            A dict that contains the loss and metric results of the last evaluation step.

        Examples:

            .. code-block:: python

                >>> import paddle
                >>> import paddle.vision.transforms as T
                >>> from paddle.distributed.fleet import auto
                >>> from paddle.vision.datasets import MNIST

                >>> transform = T.Compose([
                ...     T.Transpose(),
                ...     T.Normalize([127.5], [127.5])
                ... ])
                >>> valid_dataset = MNIST(mode='test', transform=transform)

                >>> model = paddle.vision.models.LeNet()
                >>> loss = paddle.nn.CrossEntropyLoss()
                >>> metrics = paddle.metric.Accuracy(topk=(1, 2))

                >>> engine = auto.Engine(model, loss, metrics=metrics)
                >>> engine.evaluate(valid_dataset, batch_size=64)

        """
        self._mode = 'eval'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            valid_data, valid_sample_split, batch_size
        )
        micro_batch_size = self._validate_batch_size(batch_size)
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        valid_dataloader = self._prepare_dataloader_from_generator(
            dataset=valid_data,
            capacity=70,
            iterable=False,
            batch_size=micro_batch_size,
            steps_per_epoch=steps,
            collate_fn=collate_fn,
        )

        fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)

        cbks = config_callbacks(
            callbacks,
            engine=self,
            batch_size=micro_batch_size,
            log_freq=log_freq,
            verbose=verbose,
            metrics=self._metrics_name(),
        )

        eval_steps = valid_dataloader._steps
        cbks.on_begin(
            'eval', {'steps': eval_steps, 'metrics': self._metrics_name()}
        )
        logs = {}
        for step, _ in enumerate(valid_dataloader):
            cbks.on_batch_begin('eval', step, logs)
            try:
                outs = self._executor.run(
                    self.main_program,
                    fetch_list=fetch_names,
                    use_program_cache=self._strategy.use_cache,
                    return_numpy=self._strategy.return_numpy,
                )
            except core.EOFException:
                break
            logs = self._prepare_logger(
                outs, None, step, None, fetch_names, fetch_indices, self._mode
            )
            cbks.on_batch_end('eval', step, logs)
        cbks.on_end('eval', logs)
        self._reset_metrics()
        return logs

    def predict(
        self,
        test_data,
        test_sample_split=None,
        batch_size=1,
        steps=None,
        collate_fn=None,
        callbacks=None,
        verbose=2,
    ):
        """
        Compute the output predictions on testing data.

        Args:
            test_data (Dataset): An instance of paddle.io.Dataset. Default: None.
            test_sample_split (int, optional): Each sample of the test dataset is assumed
                to be an (input, label) pair by default and has two items. If each sample has
                more than two items, test_sample_split specifies how to split these items into
                input and label. The items before the split are input and the remaining ones are label. Default: None.
            batch_size (int, optional): The batch size of test_data. The user's data will
                be used directly without batching if set to None. Default: 1.
            steps (int, optional): It is the total number of steps (batches of samples) to draw before
                stopping prediction. If None, prediction will run until the `test_data` dataset is exhausted.
                The prediction will start from the beginning of the dataset in each run. Default: None.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list. None means only stacking each field of the samples along
                axis 0. Default: None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during testing. Default: None. (Unused for now)

        Returns:
            A list of the output predictions collected over all steps.

        Examples:

            .. code-block:: python

                >>> import paddle
                >>> import paddle.vision.transforms as T
                >>> from paddle.distributed.fleet import auto
                >>> from paddle.vision.datasets import MNIST

                >>> transform = T.Compose([
                ...     T.Transpose(),
                ...     T.Normalize([127.5], [127.5])
                ... ])
                >>> valid_dataset = MNIST(mode='test', transform=transform)

                >>> model = paddle.vision.models.LeNet()

                >>> engine = auto.Engine(model)
                >>> engine.predict(valid_dataset, batch_size=64)
        """
        self._mode = 'predict'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            test_data, test_sample_split, batch_size
        )
        micro_batch_size = self._validate_batch_size(batch_size)
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        test_dataloader = self._prepare_dataloader_from_generator(
            dataset=test_data,
            capacity=70,
            iterable=False,
            batch_size=micro_batch_size,
            steps_per_epoch=steps,
            collate_fn=collate_fn,
        )

        fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)

        outputs = []
        cbks = config_callbacks(callbacks, engine=self, verbose=verbose)
        test_steps = test_dataloader._steps
        cbks.on_begin('predict', {'steps': test_steps})
        logs = {}
        for step, _ in enumerate(test_dataloader):
            cbks.on_batch_begin('predict', step, logs)
            try:
                outs = self._executor.run(
                    self.main_program,
1273
                    fetch_list=fetch_names,
1274
                    use_program_cache=self._strategy.use_cache,
1275 1276
                    return_numpy=self._strategy.return_numpy,
                )
1277
            except core.EOFException:
1278
                break
1279 1280 1281
            logs = self._prepare_logger(
                outs, None, step, None, fetch_names, fetch_indices, self._mode
            )
Z
zhaoyingli 已提交
1282 1283 1284 1285 1286
            cbks.on_batch_end('predict', step, logs)
            outputs.append(list(logs["outputs"].values()))
        cbks.on_end('predict', logs)
        return outputs

    def dataloader(
        self,
        dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        collate_fn=None,
        num_workers=0,
        use_buffer_reader=True,
        use_shared_memory=True,
        timeout=0,
        worker_init_fn=None,
        epochs=1,
        steps_per_epoch=None,
        sample_split=1,
        mode=None,
        places=None,
    ):
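        """
        Build a distributed dataloader for the given dataset under the
        current (or the given) mode.

        Examples:

            .. code-block:: python

                >>> # A minimal sketch; assumes `engine` and `train_dataset`
                >>> # are built as in the `save` example below.
                >>> train_dataloader = engine.dataloader(train_dataset,
                ...                                      batch_size=64,
                ...                                      mode="train")
                >>> for batch in train_dataloader:
                ...     outs = engine.run(batch, mode="train")
        """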
        if mode is not None:
            self.to_mode(mode)
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            dataset, sample_split, batch_size
        )
        micro_batch_size = self._validate_batch_size(batch_size)
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        dataloader = self._prepare_dataloader(
            dataset,
            return_list=False,
            batch_size=micro_batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            collate_fn=collate_fn,
            num_workers=num_workers,
            use_buffer_reader=use_buffer_reader,
            use_shared_memory=use_shared_memory,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            places=places,
        )
        return dataloader

    def dataloader_from_generator(
        self,
        dataset,
        capacity=70,
        use_double_buffer=True,
        iterable=True,
        use_multiprocess=False,
        drop_last=True,
        batch_size=1,
        epochs=1,
        steps_per_epoch=None,
        collate_fn=None,
        sample_split=1,
        mode=None,
    ):
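        """
        Build a distributed dataloader backed by a generator-based reader
        for the given dataset under the current (or the given) mode.

        Examples:

            .. code-block:: python

                >>> # A minimal sketch; assumes `engine` and `valid_dataset`
                >>> # are built as in the `predict` example above.
                >>> valid_dataloader = engine.dataloader_from_generator(
                ...     valid_dataset, batch_size=64, mode="predict")
                >>> for _ in valid_dataloader:
                ...     outs = engine.run(mode="predict")
        """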
        if mode is not None:
            self.to_mode(mode)
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            dataset, sample_split, batch_size
        )
        micro_batch_size = self._validate_batch_size(batch_size)
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        dataloader = self._prepare_dataloader_from_generator(
            dataset=dataset,
            capacity=capacity,
            use_double_buffer=use_double_buffer,
            iterable=iterable,
            return_list=False,
            use_multiprocess=use_multiprocess,
            drop_last=drop_last,
            batch_size=micro_batch_size,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            collate_fn=collate_fn,
        )
        return dataloader

    def prepare(
        self,
        inputs_spec=None,
        labels_spec=None,
        inputs=None,
        labels=None,
        main_program=None,
        startup_program=None,
        mode=None,
        init_parameters=True,
    ):
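        """
        Build and plan the distributed program for the given mode from
        input/label specs or variables, without feeding any data.

        Examples:

            .. code-block:: python

                >>> # A minimal sketch; assumes `engine` is built as in the
                >>> # `predict` example above.
                >>> from paddle.static import InputSpec
                >>> image_spec = InputSpec([None, 1, 28, 28], 'float32', 'image')
                >>> engine.prepare(inputs_spec=[image_spec], mode="predict")
        """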
        if mode is not None:
            self.to_mode(mode)

        if not self._mode:
            raise ValueError(
                "Please set mode to be prepared with `prepare(mode=...)`"
            )

        if self._has_prepared[self._mode]:
            return

        inputs_spec = self._validate_spec(inputs_spec)
        labels_spec = self._validate_spec(labels_spec)
        inputs = self._validate_vars(inputs)
        labels = self._validate_vars(labels)

        self._orig_main_prog = main_program
        self._orig_startup_prog = startup_program
        if inputs or labels:
            self._skip_build = True
            inputs, labels = self._prepare_data_tensor(
                inputs_spec, labels_spec, inputs, labels
            )
            if self._orig_main_prog is None:
                self._orig_main_prog = static.default_main_program()
            if self._orig_startup_prog is None:
                self._orig_startup_prog = static.default_startup_program()
        elif inputs_spec or labels_spec:
            self._outside_dataloader = True
            if self._orig_main_prog is None:
                self._orig_main_prog = static.default_main_program()
            if self._orig_startup_prog is None:
                self._orig_startup_prog = static.default_startup_program()
        else:
            assert (
                self._inputs_spec and self._labels_spec
            ), "Please call the dataloader(...) before calling prepare(...)"

        self._inputs_spec, self._labels_spec = inputs_spec, labels_spec
        self._inputs, self._labels = inputs, labels
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode, init_parameters)
        else:
            self._switch_mode(self._mode)

    def run(self, data=None, feed=None, fetch_list=None, mode=None):
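        """
        Run the prepared program of the given mode once and return the
        fetched results as logs.

        Examples:

            .. code-block:: python

                >>> # A minimal sketch; assumes `engine` has been prepared and
                >>> # `train_dataloader` comes from `engine.dataloader(...)`.
                >>> for batch in train_dataloader:
                ...     logs = engine.run(batch, mode="train")
        """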
        if mode is not None:
            self.to_mode(mode)
        feed_dict = self._prepare_feed(data, feed, self._mode)
        fetch_names, fetch_indices = self._prepare_fetch(fetch_list, self._mode)
        if (
            self._outside_dataloader
            and not self._has_prepared_reader[self._mode]
        ):
            self._prepare_reader()
        outs = self._executor.run(
            self.main_program,
            feed=feed_dict,
            fetch_list=fetch_names,
            use_program_cache=self._strategy.use_cache,
            return_numpy=self._strategy.return_numpy,
        )
        logs = self._prepare_logger(
            outs, None, None, None, fetch_names, fetch_indices, self._mode
        )
        return logs

    def _prepare_dataloader(
        self,
        dataset,
        return_list=True,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        collate_fn=None,
        num_workers=0,
        use_buffer_reader=True,
        use_shared_memory=True,
        timeout=0,
        worker_init_fn=None,
        epochs=1,
        steps_per_epoch=None,
        places=None,
    ):
        dist_context = self._dist_contexts[self._mode]
        dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
        dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
        dist_main_block = dist_main_prog.global_block()

        # NOTE: Get feed_list, then insert the dataloader op with sharded var shapes.
        # Because predict_program does not contain the labels var, we add the
        # labels var from serial_program to dist_program, which keeps the length
        # of feed_list equal to the number of fields in the dataset's samples.
        inputs_var = dist_context.serial_feed_vars["inputs"]
        labels_var = dist_context.serial_feed_vars["labels"]
        feed_list = []
        for var in inputs_var + labels_var:
            if var.name in dist_main_block.vars:
                feed_list.append(dist_main_block.vars[var.name])
            else:
                copy_var = dist_main_block._clone_variable(var, var.persistable)
                copy_var.desc.set_original_id(var.desc.original_id())
                feed_list.append(copy_var)

        # insert the read op at the end of the program
        with static.program_guard(dist_main_prog, dist_startup_prog):
            dataloader = DistributedDataLoader(
                dataset,
                feed_list=feed_list,
                places=places,
                return_list=return_list,
                batch_size=batch_size,
                shuffle=shuffle,
                drop_last=drop_last,
                collate_fn=collate_fn,
                num_workers=num_workers,
                use_buffer_reader=use_buffer_reader,
                use_shared_memory=use_shared_memory,
                timeout=timeout,
                worker_init_fn=worker_init_fn,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                split_data=self._strategy.split_data,
                data_parallel_world_size=self._dp_world_sizes,
                data_parallel_rank=self._dp_ranks,
            )

        return dataloader

    def _prepare_dataloader_from_generator(
        self,
        dataset,
        capacity=None,
        use_double_buffer=True,
        iterable=True,
        return_list=False,
        use_multiprocess=False,
        drop_last=True,
        batch_size=1,
        epochs=1,
        steps_per_epoch=None,
        collate_fn=None,
    ):
        dist_context = self._dist_contexts[self._mode]
        dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
        dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
        dist_main_block = dist_main_prog.global_block()

        # NOTE: Get feed_list, then insert the dataloader op with sharded var shapes.
        # Because predict_program does not contain the labels var, we add the
        # labels var from serial_program to dist_program, which keeps the length
        # of feed_list equal to the number of fields in the dataset's samples.
        inputs_var = dist_context.serial_feed_vars["inputs"]
        labels_var = dist_context.serial_feed_vars["labels"]
        feed_list = []
        for var in inputs_var + labels_var:
            if var.name in dist_main_block.vars:
                feed_list.append(dist_main_block.vars[var.name])
            else:
                copy_var = dist_main_block._clone_variable(var, var.persistable)
                copy_var.desc.set_original_id(var.desc.original_id())
                feed_list.append(copy_var)

        places = paddle.static.cuda_places()
        with static.program_guard(dist_main_prog, dist_startup_prog):
            dataloader = DistributedDataLoaderFromGenerator(
                dataset=dataset,
                feed_list=feed_list,
                capacity=capacity,
                use_double_buffer=use_double_buffer,
                iterable=iterable,
                return_list=return_list,
                use_multiprocess=use_multiprocess,
                drop_last=drop_last,
                places=places,
                batch_size=batch_size,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                collate_fn=collate_fn,
                split_data=self._strategy.split_data,
                data_parallel_world_size=self._dp_world_sizes,
                data_parallel_rank=self._dp_ranks,
                acc_steps=1
                if not self._strategy.pipeline.enable
                else self._acc_steps,
            )
        self._prepare_reader(feed_list)
        return dataloader

    def _tune(self, tune_data, tune_sample_split=None, batch_size=1):
        self._mode = 'train'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            tune_data, tune_sample_split, batch_size
        )
        self._optimization_tuning(self._mode, tune_data, batch_size)

    def _validate_batch_size(self, batch_size):
        if batch_size is None:
            return None
        if self._strategy.pipeline.enable and auto_utils.use_new_executor():
            return batch_size
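        # With gradient accumulation, the dataloader feeds one micro-batch per
        # step; e.g. (an illustrative case) batch_size=64 with acc_steps=4
        # yields a micro batch size of 16.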
        assert (
            batch_size % self._acc_steps == 0
        ), "Requires batch_size:[{}] to be divisible by acc_steps:[{}].".format(
            batch_size, self._acc_steps
        )
        return batch_size // self._acc_steps

    def _validate_feed(self, feed):
        if feed is None:
            return [None]
        # pp with schedule or naive-pp
        if self._strategy.pipeline.enable or self._acc_steps == 1:
            return feed

        # split feed data with gradient_merge k_steps
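        # e.g. (illustrative) with acc_steps=2, [{'x': array of shape (8, ...)}]
        # becomes [{'x': shape (4, ...)}, {'x': shape (4, ...)}], one feed dict
        # per micro-step.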
        feed_names = []
        split_feeds = []
        for feed_name, cur_feed in feed[0].items():
            feed_names.append(feed_name)
            split_feeds.append(np.split(np.array(cur_feed), self._acc_steps, 0))
        micro_feeds = []
        for i in range(self._acc_steps):
            split_feed = [sf[i] for sf in split_feeds]
            micro_feeds.append(dict(zip(feed_names, split_feed)))
        return micro_feeds

    def _validate_spec(self, specs):
        specs = auto_utils.to_list(specs)
        if specs is not None:
            for i, spec in enumerate(specs):
                if not isinstance(spec, InputSpec):
                    raise TypeError(
                        "'spec' must be object of class `paddle.static.InputSpec`."
                    )
                if spec.name is None:
                    raise ValueError(
                        "Requires Input[{}].name != None, but receive `None` with {}.".format(
                            i, spec
                        )
                    )
                if self._acc_steps > 1:
                    shape = list(spec.shape)
                    assert (
                        shape[0] % self._acc_steps == 0
                    ), "Requires batch_size[{}] to be divisible by k_steps[{}].".format(
                        spec.shape[0], self._acc_steps
                    )
                    shape[0] //= self._acc_steps
                    spec.shape = shape
        return specs or []

    def _validate_vars(self, vars):
        vars = auto_utils.to_list(vars)
        if vars is not None:
            for i, var in enumerate(vars):
                if not isinstance(var, Variable):
                    raise TypeError("'var' must be a `Variable`.")
        return vars or []

    def _is_local_var(self, var):
        var_name = _to_name_str(var)
        return var_name in self.main_program.global_block().vars

    def _reset_metrics(self):
        for metric in self._metrics:
            metric.reset()

    def _metrics_name(self):
        metrics_name = ['loss'] if self._loss else []
        for m in self._metrics:
            metrics_name.extend(auto_utils.to_list(m.name()))
        return metrics_name

    def _switch_mode(self, mode):
        assert (
            mode in self._dist_contexts
        ), f"{mode} model is not ready, please call `prepare()` first."
        self.to_mode(mode)

    def to_mode(self, mode):
        assert mode in [
            "train",
            "eval",
            "predict",
        ], f"mode {mode} should be one of ['train', 'eval', 'predict']"
        self._mode = mode

    def _set_state_dict(self, mode, strict, state_dict, dist_attr):
        dist_context = self._dist_contexts[mode]
        program = dist_context.dist_main_programs[self._cur_rank]
        cur_dist_attr = auto_utils.get_dist_attr(program, dist_context)
        converter = Converter(state_dict, dist_attr, cur_dist_attr)
        state_dict = converter.convert(strict=strict)
        for name, param in program.state_dict().items():
            param_array = np.array(param)
            if name not in state_dict:
                continue
            if param_array.dtype != state_dict[name].dtype:
                self._logger.info(
                    "cast {}'s dtype from '{}' to '{}'".format(
                        name,
                        str(state_dict[name].dtype),
                        str(param_array.dtype),
                    )
                )
                state_dict[name] = state_dict[name].astype(param_array.dtype)
        program.set_state_dict(state_dict)

    def save(self, path, training=True):
        """
        Saves the model, parameters, and optimizer states to path.
        If `training` is set to False, only the inference model will be saved.

        Args:
            path (str): The file prefix to save the model. The format
                is 'dirname/file_prefix' or 'file_prefix'. An exception
                will be raised if it is an empty string.
            training (bool, optional): Whether to save for training. If not, save
                for inference only. If `training` is set to True, the optimizer state
                will be saved. Otherwise, only the model and parameters are saved.
                This function will silently overwrite an existing file at the target
                location. Default: True.

        Returns:
            None

        Examples:

            .. code-block:: python

                >>> import paddle
                >>> import paddle.vision.transforms as T
                >>> from paddle.distributed.fleet import auto
                >>> from paddle.vision.datasets import MNIST

                >>> transform = T.Compose([
                ...     T.Transpose(),
                ...     T.Normalize([127.5], [127.5])
                ... ])
                >>> train_dataset = MNIST(mode='train', transform=transform)

                >>> model = paddle.vision.models.LeNet()
                >>> loss = paddle.nn.CrossEntropyLoss()
                >>> optimizer = paddle.optimizer.Adam(
                ...     learning_rate=0.001, parameters=model.parameters())
                >>> metrics = paddle.metric.Accuracy(topk=(1, 2))

                >>> engine = auto.Engine(model, loss, optimizer, metrics)
                >>> engine.fit(train_dataset,
                ...             epochs=1,
                ...             batch_size=64)
                >>> engine.save("./my_model")

        """
        if training:
            assert self._mode in self._dist_contexts
            dist_context = self._dist_contexts[self._mode]
            serial_program = dist_context.serial_main_program
            dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
            self._saver.save(
                path,
                serial_program=serial_program,
                dist_main_program=dist_main_prog,
                dist_context=dist_context,
            )
        else:
            assert "predict" in self._dist_contexts
            dist_context = self._dist_contexts["predict"]
            feed_vars = dist_context.serial_feed_vars['inputs']
            fetch_vars = dist_context.serial_fetch_vars['outputs']
            dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
            if self._strategy.qat.enable and self._strategy.qat.onnx_format:
                from paddle.static.quantization import QuantWeightPass

                self._logger.info("export quantized model.")
                self._logger.info(
                    f"convert config {self._strategy.qat.to_dict()}"
                )
                test_graph = IrGraph(
                    core.Graph(dist_main_prog.desc), for_test=True
                )
                quant_weight_pass = QuantWeightPass(global_scope(), self._place)
                for sub_graph in test_graph.all_sub_graphs():
                    quant_weight_pass.apply(sub_graph)
                dist_main_prog = test_graph.to_program()
            self._saver.save_inference_model(
                path,
                feed_vars,
                fetch_vars,
                self._executor,
                program=dist_main_prog,
            )

    def load(self, path, strict=True, load_optimizer=True):
        """
        Load the stored model, parameters and optimizer states.

        Args:
            path (str): The prefix of files storing the model states and
                optimizer states.
            strict (bool, optional): Whether to skip loading mismatched
                parameters or to raise an error when a mismatch happens (the
                parameter is not found in the file storing the model states,
                or a mismatched shape is received). Default: True.
            load_optimizer (bool, optional): If True, the stored optimizer
                states are restored. Otherwise, the optimizer states are initialized
                from scratch. Default: True.

        Returns:
            None

        Examples:

            .. code-block:: python

                >>> import paddle
                >>> import paddle.vision.transforms as T
                >>> from paddle.distributed.fleet import auto
                >>> from paddle.vision.datasets import MNIST

                >>> transform = T.Compose([
                ...     T.Transpose(),
                ...     T.Normalize([127.5], [127.5])
                ... ])
                >>> train_dataset = MNIST(mode='train', transform=transform)

                >>> model = paddle.vision.models.LeNet()
                >>> loss = paddle.nn.CrossEntropyLoss()
                >>> optimizer = paddle.optimizer.Adam(
                ...     learning_rate=0.001, parameters=model.parameters())
                >>> metrics = paddle.metric.Accuracy(topk=(1, 2))

                >>> engine = auto.Engine(model, loss, optimizer, metrics)
                >>> engine.fit(train_dataset,
                ...             epochs=1,
                ...             batch_size=64)
                >>> engine.save("./my_model")
                >>> engine.load("./my_model")

        """
        self._strict = strict
        self._state_dict, self._dist_attr = self._saver.load(
            path, load_optimizer
        )
        return self._state_dict, self._dist_attr

    def cost(self, inputs_spec=None, labels_spec=None, mode=None):
        """
        Get and print the cost, including the memory of every rank,
        the max memory among all ranks, and the global cost of one step based on
        communication cost (computation cost is 0 by default).
        In the future, the flops information of every rank and global cost including
        computation cost will be added.

        Args:
            inputs_spec(InputSpec): The specification of inputs. Default: None.
            labels_spec(InputSpec): The specification of labels. Default: None.
            mode (str): The engine mode must be in ["train", "predict", "eval"]. Default: None.

        Returns:
            The global execution time (ms) and max memory (B).
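
        Examples:

            .. code-block:: python

                >>> # A minimal sketch; assumes `engine` has been prepared,
                >>> # e.g. via `prepare()` or `fit()`.
                >>> exec_time, max_memory = engine.cost(mode="train")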

        """
        # Check parallel mode
        if self._strategy.auto_mode == "full":
            self._logger.info(
                "The cost will be calculated in the search process when the auto mode is full."
            )
            return

        # Check mode
        mode = mode if mode is not None else self._mode
        assert mode is not None, "Please set mode."
        if mode not in self._has_prepared:
            raise ValueError(
                "The mode {} is not in accepted modes {}".format(
                    mode, list(self._has_prepared.keys())
                )
            )
        self.to_mode(mode)

        if inputs_spec is not None and not self._has_prepared[mode]:
            self._inputs_spec = self._validate_spec(inputs_spec)
            self._labels_spec = self._validate_spec(labels_spec)
            self._build(mode)
            self._plan(mode)
        else:
            if in_dynamic_mode() or self._dygraph_mode:
                raise ValueError(
                    "Please call `prepare()` or `fit()` or `evaluate()` or `predict()` before calling `cost()`."
                )
            else:
                self._logger.info(
                    "The program whose cost is to be estimated must be the static default program. Otherwise, please call `prepare()` before calling `cost()`."
                )
                program = paddle.static.default_main_program()
                if (
                    not program.global_block().ops
                    and not self._has_prepared[mode]
                ):
                    raise ValueError(
                        "Please call `prepare()` or `fit()` or `evaluate()` or `predict()` before calling `cost()`."
                    )

        # Estimate the exec cost and max memory
        global_cost, max_memory = get_cost_from_engine(self, mode)

        return global_cost.time, max_memory

    @property
    def main_program(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.dist_main_programs[self._cur_rank]

    @property
    def startup_program(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.dist_startup_programs[self._cur_rank]

    @property
    def dist_context(self):
        return self._dist_contexts[self._mode]

    @property
    def serial_main_program(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.serial_main_program

    @property
    def serial_startup_program(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.serial_startup_program

    @property
    def feed_vars(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.serial_feed_vars

    @property
    def fetch_vars(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.serial_fetch_vars

    @property
    def optimizer(self):
        dist_context = self._dist_contexts[self._mode]
        if dist_context._serial_optimizer:
            return dist_context._serial_optimizer
        return self._optimizer

    @property
    def inputs(self):
        return self._inputs

    @property
    def labels(self):
        return self._labels