# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import logging
import numbers
import os
import random

import numpy as np

import paddle
import paddle.distributed.auto_parallel.static.utils as auto_utils
from paddle import static, utils
from paddle.distributed import fleet
from paddle.fluid.executor import _to_name_str
from paddle.framework import IrGraph
from paddle.framework import _current_expected_place as _get_device
from paddle.framework import core, in_dynamic_mode
from paddle.metric import Metric
from paddle.static import InputSpec, Operator, Variable, global_scope

from ...utils.log_utils import get_logger
from ..interface import CollectionNames, fetch, get_collection
from ..strategy import Strategy
from .callbacks import config_callbacks
from .cluster import Cluster, get_default_cluster
from .converter import Converter
from .cost.estimate_cost import get_cost_from_engine
from .dist_context import DistributedContext, get_default_distributed_context
from .dist_loader import (
    DistributedDataLoader,
    DistributedDataLoaderFromGenerator,
)
from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver
from .helper import ProgramHelper
from .parallelizer_v2 import Parallelizer
from .planner_v2 import Planner
from .process_group import get_all_process_groups, new_process_group


class Engine:
    """
    A high-level API for auto parallel, which can be used for distributed training (engine.fit) and inference (engine.predict).
    Static graph mode is supported natively. Dynamic graph mode is also supported under `@to_static <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/jit/to_static_cn.html#to-static>`_ .

    Args:
        model (paddle.nn.Layer, optional): The model is an instance of
            paddle.nn.Layer.
        loss (Loss|Callable|None, optional): The loss can be a `paddle.nn.Layer`
            instance or any callable function taking the predicted values and
            ground truth values as input. It can be None when there is no loss.
            Default: None.
        optimizer (Optimizer|None, optional): The optimizer needs to be set in training
            and should be None in eval and predict mode. Default: None.
        metrics (Metric|list[Metric]|None, optional): If metrics is set, all
            metrics will be calculated and output in train/eval mode. Default: None.
        cluster (Cluster|None, optional): The cluster represents the topology information
            about the used physical devices. Default: None. (Unused for now)
        strategy (Strategy|None, optional): The strategy is used to configure the
            parallelization and optimization behaviors. Default: None.

    Examples:

        .. code-block:: python

            >>> import paddle
            >>> import paddle.vision.transforms as T
            >>> from paddle.distributed.fleet import auto
            >>> from paddle.vision.datasets import MNIST

            >>> transform = T.Compose([
            ...     T.Transpose(),
            ...     T.Normalize([127.5], [127.5])
            ... ])
            >>> train_dataset = MNIST(mode='train', transform=transform)
            >>> valid_dataset = MNIST(mode='test', transform=transform)

            >>> model = paddle.vision.models.LeNet()
            >>> loss = paddle.nn.CrossEntropyLoss()
            >>> optimizer = paddle.optimizer.Adam(
            ...     learning_rate=0.001, parameters=model.parameters())
            >>> metrics = paddle.metric.Accuracy(topk=(1, 2))

            >>> engine = auto.Engine(model, loss, optimizer, metrics)
            >>> # fit
            >>> engine.fit(train_dataset,
            ...            epochs=2,
            ...            batch_size=64)
            >>> # evaluate
            >>> engine.evaluate(valid_dataset,
            ...                 batch_size=64)
            >>> # predict
            >>> engine.predict(valid_dataset,
            ...                batch_size=64)
            >>> # save
            >>> engine.save("./my_model")
            >>> # load
            >>> engine.load("./my_model")

    """

    def __init__(
        self,
        model=None,
        loss=None,
        optimizer=None,
        metrics=None,
        cluster=None,
        strategy=None,
    ):
        if (
            model
            and not isinstance(model, paddle.nn.Layer)
            and not callable(model)
        ):
            raise TypeError(
                "'model' must be a subclass of `paddle.nn.Layer` or a callable function."
            )
        self._model = model
        self._parameter_list = (
            None if not model else [p.name for p in model.parameters()]
        )

        if (
            loss
            and not isinstance(loss, (paddle.nn.Layer, Variable))
            and not callable(loss)
        ):
            raise TypeError(
                "'loss' must be a subclass of `paddle.nn.Layer`, a callable function, or a Variable."
            )
        self._loss = loss

        if optimizer and not isinstance(
            optimizer,
            (paddle.optimizer.Optimizer),
        ):
            raise TypeError(
                "'optimizer' must be an object of class `paddle.optimizer.Optimizer`"
            )
        self._optimizer = auto_utils.validate_opt(optimizer)

        metrics = metrics or []
        for metric in auto_utils.to_list(metrics):
            if metric and not isinstance(metric, Metric):
                raise TypeError(
                    "{} is not a subclass of Metric".format(
                        metric.__class__.__name__
                    )
                )
        self._metrics = auto_utils.to_list(metrics)

        if cluster and not isinstance(cluster, Cluster):
            raise TypeError(
                "'cluster' must be an object of class `paddle.distributed.auto_parallel.Cluster`"
            )
        self._cluster = cluster or get_default_cluster()

        if strategy and not isinstance(strategy, Strategy):
            raise TypeError(
                "'strategy' must be an object of class `paddle.distributed.auto_parallel.Strategy`"
            )
        self._strategy = strategy or Strategy()

        self._logger = get_logger(logging.INFO)

        self._json_config = None
        if cluster:
            self._cluster = cluster
        else:
            if os.getenv("PADDLE_AUTO_PARALLEL_CONFIG"):
                try:
                    path = os.getenv("PADDLE_AUTO_PARALLEL_CONFIG")
                    with open(path, "r") as f:
                        self._json_config = json.load(f)
                except Exception:
                    self._logger.info(
                        "Failed to load the JSON config; please check the config file. The engine will run with the default config."
                    )
                    self._json_config = None
            self._cluster = get_default_cluster(self._json_config)

        if os.getenv("POD_NAME"):
            self._logger.info(
                "Distributed training by paddle.distributed.launch"
            )
            fleet.init(is_collective=True)

        # for compute cost
        # TODO: remove _fwd_main_progs and _orig_optimizer
        self._fwd_dist_contexts = {}
        self._fwd_main_progs = {}
        self._orig_optimizer = copy.deepcopy(self._optimizer)

        self._executor = None
        self._cur_rank = paddle.distributed.get_rank()
        self._nranks = paddle.distributed.get_world_size()
        self._saver = DistributedSaver()

        self._orig_main_prog = static.default_main_program()
        self._orig_startup_prog = static.default_startup_program()
        self._orig_dist_context = get_default_distributed_context()
        self._dist_contexts = {}
        self._planners = {}
        self._has_prepared = {"train": False, "eval": False, "predict": False}
        self._has_prepared_reader = {
            "train": False,
            "eval": False,
            "predict": False,
        }
        self._inputs_spec = []
        self._labels_spec = []
        self._inputs = []
        self._labels = []
        self._losses = []

        self._mode = None
        self._skip_build = False
        self._outside_dataloader = False
        self._planned_mode = None
        self._dygraph_mode = False
        self._tuning = self._strategy.tuning
        self._acc_steps = 1
        if self._strategy.gradient_merge.enable:
            self._acc_steps = self._strategy.gradient_merge.k_steps
        elif self._strategy.pipeline.enable:
            self._acc_steps = self._strategy.pipeline.accumulate_steps
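        # NOTE: `_acc_steps` is the number of micro-batches accumulated per
        # logical batch: `k_steps` under gradient merge, `accumulate_steps`
        # under the pipeline schedule.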

        if (
            self._strategy.pipeline.enable
            and self._strategy.pipeline.schedule_mode == "1F1B"
        ):
            assert (
                os.getenv("CUDA_MODULE_LOADING") != "LAZY"
            ), "EXP_CUDA_MODULE_LOADING_LAZY not supported in 1F1B pipeline."

        self.history = None

        paddle.framework.set_flags({'FLAGS_new_executor_sequential_run': 1})
        paddle.framework.set_flags({'FLAGS_new_executor_static_build': 1})

    def _prepare_data_spec(self, data, split, batch_size):
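        # A sketch of the `split` semantics: with split=2 and a sample
        # (x1, x2, y), inputs=(x1, x2) and labels=(y,); with split=None,
        # each sample is assumed to be an (inputs, labels) pair.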
        inputs_spec = []
        labels_spec = []
        if isinstance(data, paddle.io.IterableDataset):
            if split is None:
                inputs, labels = next(iter(data))
            else:
                sample = next(iter(data))
                inputs = sample[:split]
                labels = sample[split:]
        elif isinstance(data, paddle.io.Dataset):
            if split is None:
                inputs, labels = data[0]
            else:
                sample = data[0]
                inputs = sample[:split]
                labels = sample[split:]
        else:
            raise TypeError(
                "Data should be a Dataset or IterableDataset, but received {}.".format(
                    type(data).__name__
                )
            )
        inputs = auto_utils.to_list(inputs)
        labels = auto_utils.to_list(labels)

        num_shards = self._strategy.dataset.num_shards

        def _adjust_item_spec(num_shards, spec):
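            # When the dataset is sharded (num_shards > 1), scale the leading
            # (batch) dimension so the spec reflects the global batch size
            # across all shards.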
            if num_shards > 1 and len(spec.shape) > 1:
                spec.shape[0] = spec.shape[0] * num_shards

        def _infer_item_spec(item, name, batch_size, specs):
            if isinstance(item, np.ndarray):
                spec = InputSpec.from_numpy(item, name)
                if batch_size is None:
                    _adjust_item_spec(num_shards, spec)
                    specs.append(spec)
                else:
                    specs.append(spec.batch(batch_size))
            elif isinstance(item, (Variable, core.eager.Tensor)):
                spec = InputSpec.from_tensor(item, name)
                _adjust_item_spec(num_shards, spec)
                if batch_size is None:
                    specs.append(spec)
                else:
                    specs.append(spec.batch(batch_size))
            elif isinstance(item, numbers.Number):
                specs.append(InputSpec([batch_size], type(item), name))
            else:
                raise TypeError(
                    "Each item in a dataset sample should be a number, np.ndarray or Tensor, but got {}".format(
                        type(item).__name__
                    )
                )

        if inputs is not None:
            for i, item in enumerate(inputs):
                assert item is not None, "Received a None input."
                name = "input" + str(i)
                _infer_item_spec(item, name, batch_size, inputs_spec)
        if labels is not None:
            for i, item in enumerate(labels):
                assert item is not None, "Received a None label."
                name = "label" + str(i)
                _infer_item_spec(item, name, batch_size, labels_spec)

        inputs_spec = self._validate_spec(inputs_spec)
        labels_spec = self._validate_spec(labels_spec)
        return inputs_spec, labels_spec

    def _prepare_data_tensor(self, inputs_spec, labels_spec, inputs, labels):
        if in_dynamic_mode() or self._dygraph_mode:
            raise ValueError("Only static graph mode is supported.")

        if inputs_spec:
            assert isinstance(
                inputs_spec, list
            ), "inputs_spec should be a list, but received {}".format(
                type(inputs_spec)
            )
            assert isinstance(
                inputs, list
            ), f"inputs should be a list, but received {type(inputs)}"
            assert len(inputs_spec) == len(
                inputs
            ), "the number of `inputs_spec` should be equal to `inputs`'s."
            for input_spec, input in zip(inputs_spec, inputs):
                if input_spec.shape != input.shape:
                    input.desc.set_shape(input_spec.shape)
        if labels_spec:
            assert isinstance(
                labels_spec, list
            ), "labels_spec should be a list, but received {}".format(
                type(labels_spec)
            )
            assert isinstance(
                labels, list
            ), f"labels should be a list, but received {type(labels)}"
            assert len(labels_spec) == len(
                labels
            ), "the number of `labels_spec` should be equal to `labels`'s."
            for label_spec, label in zip(labels_spec, labels):
                if label_spec.shape != label.shape:
                    label.desc.set_shape(label_spec.shape)

        return inputs, labels

    def _prepare_reader(self, feed_list=[]):
        dist_context = self._dist_contexts[self._mode]
        dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
        dist_main_block = dist_main_prog.global_block()

        # NOTE: this list may be changed if Paddle changes the existing rules.
        related_reader_ops = [
            "create_py_reader",
            "create_double_buffer_reader",
            "read",
        ]
        # remove the leading reader ops if fit/evaluate/predict is run multiple times
        if dist_main_block.ops[0].type == 'create_py_reader':
            for i in range(len(related_reader_ops)):
                if dist_main_block.ops[0].type in related_reader_ops:
                    dist_main_block._remove_op(0, sync=False)
        dist_main_block._sync_with_cpp()
        # Step 1: find the reader ops
        reader_op_indices = []
        for idx, op in enumerate(dist_main_block.ops):
            if op.type in related_reader_ops:
                reader_op_indices.append(idx)
        # Step 2: insert the new reader ops to cpp
        # record the read ops' desc to insert to program of forward task_node
        read_ops_desc = []
        new_reader_ops = []
        for idx in reversed(reader_op_indices):
            new_op_desc = dist_main_block.desc._prepend_op()
            new_op_desc.copy_from(dist_main_block.ops[idx].desc)
            read_ops_desc.append(new_op_desc)
            new_op = Operator(
                dist_main_block, new_op_desc, type=new_op_desc.type()
            )
            new_reader_ops.append(new_op)
            dist_op = DistributedOperator(new_op)
            dist_context.add_dist_op_for_program(dist_op)
        # Step 3: insert the new reader ops to python
        for new_op in new_reader_ops:
            dist_main_block.ops.insert(0, new_op)
        for i in range(len(reader_op_indices)):
            reader_op_indices[i] += len(reader_op_indices)
        # Step 4: remove the old reader ops from python and cpp
        for idx in reversed(reader_op_indices):
            op = dist_main_block.ops.pop(idx)
            dist_main_block.desc._remove_op(idx, idx + 1)
        dist_main_block._sync_with_cpp()
        self._has_prepared_reader[self._mode] = True

        # Insert read op to forward TaskNode for fleet executor if the 1F1B pass is set
        if (
            self.main_program._pipeline_opt
            and not auto_utils.use_new_executor()
        ):
            assert "tasks" in self.main_program._pipeline_opt["fleet_opt"]
            fleet_opt = self.main_program._pipeline_opt["fleet_opt"]
            fwd_task = None
            if self._strategy.pipeline.schedule_mode == "1F1B":
                fwd_task = fleet_opt["tasks"][1]
            elif self._strategy.pipeline.schedule_mode == "stream":
                fwd_task = fleet_opt["tasks"][0]
            assert fwd_task is not None
            fwd_prog = fwd_task.get_program()
            fwd_block = fwd_prog.global_block()

            for var in feed_list:
                if var.name not in fwd_block.vars:
                    fwd_block._clone_variable(var)

            for op_desc in read_ops_desc:
                new_op_desc = fwd_block.desc._prepend_op()
                new_op_desc.copy_from(op_desc)
                new_op = Operator(
                    fwd_block, new_op_desc, type=new_op_desc.type()
                )
                fwd_block.ops.insert(0, new_op)

            fwd_block._sync_with_cpp()
            fwd_task.set_program(fwd_prog)

    def _prepare_feed(self, data, user_feeds, mode):
        feeds = {}
        if data is not None:
            if isinstance(data, (list, tuple)):
                if len(data) == 1 and isinstance(data[0], dict):
                    for name, value in data[0].items():
                        feeds[name] = value
                else:
                    raise ValueError(f"Unsupported data {data}")
            elif isinstance(data, dict):
                for name, value in data.items():
                    feeds[name] = value
            else:
                raise ValueError(f"Unsupported data {data}")
        if user_feeds is not None:
            assert isinstance(
                user_feeds, dict
            ), "user_feeds must be a dict, but received {}".format(
                type(user_feeds).__name__
            )
            for name, data in user_feeds.items():
                feeds[name] = data
        return feeds

    def _prepare_fetch(self, user_fetches, mode):
        if user_fetches is not None:
            assert isinstance(
                user_fetches, list
            ), "user_fetches must be a list, but received {}".format(
                type(user_fetches).__name__
            )
        fetch_names = []
        fetch_indices = []
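        # fetch_names is a flat, de-duplicated list of fetch variable names;
        # fetch_indices records, per fetch group (loss, metrics_i, outputs,
        # user fetches), the positions of that group's names in fetch_names.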

        def _process_fetch_group(group_name, var_list):
            group_indices = []
            for var in var_list:
                # Remove duplicate var_names
                if self._is_local_var(var):
                    var_name = _to_name_str(var)
                    if var_name not in fetch_names:
                        fetch_names.append(var_name)
                    group_indices.append(fetch_names.index(var_name))
            fetch_indices.append(group_indices)

        dist_context = self._dist_contexts[mode]
        fetch_vars = dist_context.serial_fetch_vars
        if mode != "predict":
            _process_fetch_group("loss", fetch_vars["loss"])
        if mode != "predict":
            metrics = fetch_vars["metrics"]
            for i, var_list in enumerate(metrics):
                _process_fetch_group("metrics_" + str(i), var_list)
        if mode == "predict":
            _process_fetch_group("outputs", fetch_vars["outputs"])
        for usr_fetch in user_fetches or []:
            var_name = _to_name_str(usr_fetch)
            fetch(var_name)
        user_fetches_collection = [
            item[1] for item in get_collection(CollectionNames.FETCHES)
        ]
        var_list = user_fetches_collection or []
        _process_fetch_group("fetches", var_list)
        return fetch_names, fetch_indices

    def _prepare_logger(
        self,
        outs,
        epoch=None,
        step=None,
        lr=None,
        fetch_names=None,
        fetch_indices=None,
        mode=None,
    ):
        logs = {}
        if epoch is not None:
            logs["epoch"] = epoch
        if step is not None:
            logs["step"] = step + 1
        if lr is not None:
            logs["lr"] = lr
        group_idx = 0
        if mode != "predict":
            # logging loss
            loss_indices = fetch_indices[group_idx]
            assert len(loss_indices) <= 1
            for idx in loss_indices:
                logs["loss"] = outs[idx]
            group_idx += 1
            # logging metrics
            dist_context = self._dist_contexts[mode]
            metric_vars = dist_context.serial_fetch_vars["metrics"]
            if metric_vars:
                for metric in self._metrics:
                    metrics_indices = fetch_indices[group_idx]
                    metric_out = []
                    for idx in metrics_indices:
                        metric_out.append(outs[idx])
                    if metric_out:
                        metric.update(*metric_out)
                        results = metric.accumulate()
                        for i, res in enumerate(auto_utils.to_list(results)):
                            logs[metric.name()[i]] = res
                    group_idx += 1
        # logging outputs
        elif mode == "predict":
            outputs_indices = fetch_indices[group_idx]
            logs_out = {}
            for idx in outputs_indices:
                logs_out["out%d" % (idx)] = outs[idx]
            logs["outputs"] = logs_out
            group_idx += 1
        # logging user fetches
        collect_fetches = get_collection(CollectionNames.FETCHES)
        logs_fetch = {}
        for name, var_name in collect_fetches:
            if var_name in fetch_names:
                idx = fetch_names.index(var_name)
                logs_fetch[name or var_name] = outs[idx]
        logs["fetches"] = logs_fetch
        return logs

    def _prepare_program(self, mode, init_parameters=True):
        # Do the build process
        self._build(mode)
        # Do the planning process
        self._plan(mode)
        # Do the parallel process
        self._parallel(mode)
        # Init comm
        self._init_comm()
        if init_parameters:
            # startup program
            self._initialize(mode)
        self._has_prepared[mode] = True

    def _build(self, mode):
        if in_dynamic_mode() or self._dygraph_mode:
            paddle.disable_static()
            self._dygraph_mode = True
            self._logger.info("Building model with 'to_static' method.")

            self.program_helper = ProgramHelper(
                self._model,
                self._loss,
                self._metrics,
                self._inputs_spec,
                self._labels_spec,
            )
            # build forward main program
            with utils.unique_name.guard():
                self.program_helper.build_program(mode)

            self.concrete_program = self.program_helper.concrete_program
            serial_main_prog = self.program_helper.main_program
            serial_startup_prog = self.program_helper.startup_program

            self._inputs = self.program_helper.input_vars
            self._labels = self.program_helper.label_vars
            outputs = self.program_helper.output_vars
            self._losses = self.program_helper.loss_vars
            metrics = self.program_helper.metric_vars

            paddle.enable_static()
        else:
            # build program in static mode
            dist_context = self._dist_contexts.get(mode, None)
            if dist_context is not None:
                return

            outputs = []
            metrics = []
            self._losses = []
            serial_main_prog = self._orig_main_prog.clone()
            serial_startup_prog = self._orig_startup_prog.clone()
            if not self._skip_build:
                with static.program_guard(
                    serial_main_prog, serial_startup_prog
                ), utils.unique_name.guard():
                    self._inputs = [
                        s._create_feed_layer() for s in self._inputs_spec
                    ]
                    self._labels = [
                        s._create_feed_layer() for s in self._labels_spec
                    ]

                    outputs = auto_utils.to_list(self._model(*self._inputs))

                    if mode != "predict" and self._loss:
                        assert isinstance(
                            self._loss, paddle.nn.Layer
                        ) or callable(
                            self._loss
                        ), "the type of `loss` of the Engine arguments should be a subclass of `paddle.nn.Layer` or a callable function."
                        self._losses = auto_utils.to_list(
                            self._loss(*(outputs + self._labels))
                        )

                    if mode != "predict" and (outputs or self._labels):
                        for metric in self._metrics:
                            metrics.append(
                                auto_utils.to_list(
                                    metric.compute(*(outputs + self._labels))
                                )
                            )
            elif mode == "train":
                assert isinstance(
                    self._loss, Variable
                ), "the type of `loss` of the Engine arguments should be Variable."
                self._losses = auto_utils.to_list(self._loss)

        default_ctx = get_default_distributed_context()
        if not default_ctx.has_annotation:
            # We build the world process group because the data parallel
            # needs all ranks by default.
            new_process_group(list(range(self._nranks)))
            default_ctx.data_parallel = True
            self._inputs = [
                auto_utils.set_data_parallel(var) for var in self._inputs
            ]
            self._labels = [
                auto_utils.set_data_parallel(var) for var in self._labels
            ]

        feed_vars = {"inputs": self._inputs, "labels": self._labels}

        fetch_vars = {
            "outputs": paddle.utils.flatten(outputs),
            "loss": self._losses,
            "metrics": metrics,
        }

        if mode != "train":
            serial_main_prog = serial_main_prog.clone(for_test=True)

        auto_utils.set_recompute_segments(
            self._model, self._losses, self._strategy, serial_main_prog
        )
        self._dist_contexts[mode] = DistributedContext(
            serial_main_prog,
            serial_startup_prog,
            self._optimizer,
            self._losses,
            feed_vars,
            fetch_vars,
            self._cluster,
            self._strategy,
            self._json_config,
        )
        self._fwd_dist_contexts[mode] = DistributedContext(
            serial_main_prog,
            serial_startup_prog,
            self._optimizer,
            self._losses,
            feed_vars,
            fetch_vars,
            self._cluster,
            self._strategy,
            self._json_config,
        )
        self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale
        self._fwd_main_progs[mode] = serial_main_prog.clone()

    def _optimization_tuning(self, mode, dataset, batch_size):
        if not self._tuning.enable:
            raise ValueError("Please set `tuning.enable=True`.")

        assert mode == "train"
        # Do the build process
        self._build(mode)
        # Do the planning process
        self._plan(mode)

        dataset.dp_world_size = self._dp_world_sizes
        dataset.dp_rank = self._dp_ranks
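        # Attach data-parallel sharding info so that each rank's tuner only
        # reads its own shard of the dataset.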

        from .tuner.optimization_tuner import OptimizationTuner

        self._optimization_tuner = OptimizationTuner(
            self._dist_contexts[mode],
            dataset,
            self._inputs_spec,
            self._labels_spec,
            batch_size=batch_size,
            rank=self._cur_rank,
        )

        self._optimization_tuner.tune()

        if self._tuning.run_after_tuning:
            # update the strategy
            self._dist_contexts[
                mode
            ]._strategy = self._optimization_tuner.get_best_config()

    def _plan(self, mode):
        if self._planned_mode is None:
            self._planned_mode = mode
        else:
            self._init_dist_context(mode)

        self._planners[mode] = Planner(mode, self._dist_contexts[mode])
        self._planners[mode].plan()

        # infer data parallel info
        inputs_var = self._dist_contexts[mode].serial_feed_vars["inputs"]
        labels_var = self._dist_contexts[mode].serial_feed_vars["labels"]
        block = self._dist_contexts[mode].serial_main_program.global_block()
        # TODO: check this feed_list
        feed_list = []
        for var in inputs_var + labels_var:
            if var.name in block.vars:
                feed_list.append(block.vars[var.name])

        self._dp_world_sizes = []
        self._dp_ranks = []
        for feed_var in feed_list:
            dp_world_size, dp_rank = auto_utils.get_input_split_info(
                self._cur_rank, feed_var, self._dist_contexts[mode]
            )
            self._dp_world_sizes.append(dp_world_size)
            self._dp_ranks.append(dp_rank)
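        # One (dp_world_size, dp_rank) pair is recorded per feed var: the
        # data-parallel degree and this rank's shard index, inferred from the
        # var's split info.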

    def _parallel(self, mode, all_ranks=False):
        # Parallelize program based on the planner's results
        # For now, the completer has to be passed to the Parallelizer,
        # because we may use it to complete the annotation of the backward and update.
        parallelizer = Parallelizer(
            mode,
            self._planners[mode].completer,
            self._dist_contexts[mode],
        )
        if not all_ranks:
            parallelizer.parallel(self._cur_rank, self._parameter_list)
        else:
            parallelizer.parallel_all(self._parameter_list)

    def _init_dist_context(self, mode):
        # Init dist_context['mode'] with the first planned dist_context
        # to guarantee that train/eval/predict mode have same parallel strategy
        dist_context = self._dist_contexts[mode]
        origin_main_prog = dist_context._original_serial_main_program
        ref_mode = self._planned_mode
        ref_dist_context = self._dist_contexts[ref_mode]
        ref_origin_main_prog = ref_dist_context._original_serial_main_program
        ref_blocks = ref_origin_main_prog.blocks
        for ib, block in enumerate(origin_main_prog.blocks):
            for iop, op in enumerate(block.ops):
                ref_op = ref_blocks[ib].ops[iop]
                assert (
                    op.type == ref_op.type
                ), "'{}' mode op '{}' is different from '{}' mode op '{}'. ".format(
                    mode, op.type, ref_mode, ref_op.type
                )
                ref_op_dist_attr = (
                    ref_dist_context.get_op_dist_attr_for_program(ref_op)
                )
                dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr)

    def _init_comm(self):
        if self._nranks > 1:
            # Traverse different rank programs and traverse each op of them,
            # instantiate communication by process_mapping.
            all_process_groups = get_all_process_groups()

            if self._strategy.auto_mode == "full_random":
                auto_utils.initialize_pg_in_full_mode(
                    all_process_groups, self._cur_rank
                )
            else:
                for process_group in all_process_groups:
                    process_group.instantiate()

    def _initialize(self, mode):
        self._place = _get_device()
        if isinstance(self._place, paddle.framework.CUDAPlace):
            self._place = paddle.framework.CUDAPlace(
                paddle.distributed.ParallelEnv().dev_id
            )

        if self._strategy.seed:
            paddle.seed(self._strategy.seed + self._dp_ranks[0])
            np.random.seed(self._strategy.seed + self._dp_ranks[0])
            random.seed(self._strategy.seed + self._dp_ranks[0])

        dist_context = self._dist_contexts[mode]
        if self._dygraph_mode:
            dist_main_program = dist_context.dist_main_programs[self._cur_rank]
            self.program_helper.init(
                dist_main_program, self._place, dist_context
            )

        if self._executor is None:
            self._executor = paddle.static.Executor(self._place)
            uninitialized = []
            dist_startup_prog = dist_context.dist_startup_programs[
                self._cur_rank
            ]
            for var in dist_startup_prog.list_vars():
                scope_var = global_scope().find_var(var.name)
                if scope_var and scope_var.get_tensor()._is_initialized():
                    continue
                uninitialized.append(var)
            if uninitialized:
                prune_startup_prog = dist_startup_prog._prune(uninitialized)
                self._executor.run(prune_startup_prog)
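            # Only variables missing from the global scope are initialized
            # above, so repeated calls do not overwrite already-initialized
            # (e.g. loaded) parameters.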

            if hasattr(self, "_state_dict") and hasattr(self, "_dist_attr"):
                self._set_state_dict(
                    mode, self._strict, self._state_dict, self._dist_attr
                )

        if self._strategy.reinit:
            self._logger.info("NOTE: parameters will be re-initialized.")
            dist_startup_prog = dist_context.dist_startup_programs[
                self._cur_rank
            ]
            self._executor.run(dist_startup_prog)

    def fit(
        self,
        train_data,
        train_sample_split=None,
        batch_size=1,
        epochs=1,
        steps_per_epoch=None,
        log_freq=10,
        save_dir=None,
        save_freq=1,
        valid_data=None,
        valid_sample_split=None,
        valid_freq=1,
        valid_steps=None,
        collate_fn=None,
        callbacks=None,
        verbose=2,
        nvprof_range=[-1, -1],
    ):
        """
        Trains the model for a fixed number of epochs. If `valid_data` is set,
        evaluation will be done at the end of each epoch.

        Args:
            train_data (Dataset): An instance of paddle paddle.io.Dataset. Default: None.
            train_sample_split (int, optional): Each sample of the train dataset is assumed
                to be an (input, label) pair by default and has two items. If each sample has
                more than two items, train_sample_split specifies how to split these items into
                input and label. The items before the split point are input and the rest are label. Default: None.
            batch_size (int, optional): The batch size of train_data and valid_data if provided.
                The user's data will be used directly without batching if set to None. Default: 1.
            epochs (int, optional): The number of epochs to train the model. Default: 1.
            steps_per_epoch (int, optional): The total number of steps (batches of samples)
                to be executed in one epoch before starting the next one. If None, it is equal to
                the number of samples in your dataset divided by the batch size. Default: None.
            valid_data (Dataset, optional): An instance of paddle paddle.io.Dataset used for
                evaluation at the end of epoch. No evaluation will be done if set to None.
                Default: None. (Unsupported for now)
            valid_freq (int, optional): Only relevant if valid_data is provided. This specifies
                how many training epochs are run before a new evaluation is performed. Default: 1.
            valid_sample_split (int, optional): Only relevant if valid_data is provided.
                Each sample of the valid dataset is assumed to be an (input, label) pair
                by default and has two items. If each sample has more than two items,
                valid_sample_split specifies how to split these items into input and label.
                The items before the split point are input and the rest are label. Default: None.
            valid_steps (int, optional): Only relevant if valid_data is provided.
                It is the total number of steps (batches of samples) to draw before
                stopping validation at the end of every epoch. If None, validation will run until the
                `valid_data` dataset is exhausted. The validation will start from the
                beginning of the dataset at each epoch. Default: None.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list, None to only stack each field of the samples along axis
                0. Default None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during training. Default: None. (Unused for now)
            nvprof_range(list, optional): A list of two integers in the form [start_step, end_step]
                indicating the nvprof profiling range. Note that if start_step >= end_step, nvprof will not apply.

        Returns:
            The `self.history` object recorded by the callbacks during training.

        Examples:

            .. code-block:: python

                >>> import paddle
                >>> import paddle.vision.transforms as T
                >>> from paddle.distributed.fleet import auto
                >>> from paddle.vision.datasets import MNIST

                >>> transform = T.Compose([
                ...     T.Transpose(),
                ...     T.Normalize([127.5], [127.5])
                ... ])
                >>> train_dataset = MNIST(mode='train', transform=transform)

                >>> model = paddle.vision.models.LeNet()
                >>> loss = paddle.nn.CrossEntropyLoss()
                >>> optimizer = paddle.optimizer.Adam(
                ...     learning_rate=0.001, parameters=model.parameters())
                >>> metrics = paddle.metric.Accuracy(topk=(1, 2))

                >>> engine = auto.Engine(model, loss, optimizer, metrics)
                >>> engine.fit(train_dataset,
                ...             epochs=2,
                ...             batch_size=64)
        """
        self._mode = 'train'

        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            train_data, train_sample_split, batch_size
        )

        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        if auto_utils.use_new_executor():
            local_batch_size = self._validate_batch_size(batch_size)
            train_dataloader = self._prepare_dataloader(
                train_data,
                return_list=False,
                batch_size=local_batch_size,
                epochs=epochs,
                collate_fn=collate_fn,
            )
            steps_per_epoch = (
                len(train_dataloader)
                if steps_per_epoch is None
                else steps_per_epoch
            )
        else:
            micro_batch_size = self._validate_batch_size(batch_size)
            train_dataloader = self._prepare_dataloader_from_generator(
                dataset=train_data,
                capacity=70,
                iterable=False,
                batch_size=micro_batch_size,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                collate_fn=collate_fn,
            )
            steps_per_epoch = train_dataloader._steps
            local_batch_size = micro_batch_size
            if self._strategy.pipeline.enable:
                local_batch_size = micro_batch_size * self._acc_steps
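            # Under pipeline parallelism the generator dataloader yields
            # micro-batches, so one logical local batch spans `_acc_steps`
            # micro-batches.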

        fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)

        cbks = config_callbacks(
            callbacks,
            engine=self,
            batch_size=local_batch_size,
            epochs=epochs,
            steps=steps_per_epoch,
            log_freq=log_freq,
            save_freq=save_freq,
            save_dir=save_dir,
            verbose=verbose,
            metrics=self._metrics_name(),
            acc_step=1
            if self._strategy.pipeline.enable
            else self._acc_steps,  # lr update once every local batch
        )

        cbks.on_begin('train')
        for epoch in range(epochs):
            logs = {}
            cbks.on_epoch_begin(epoch)

            for step, batch in enumerate(train_dataloader):
                if auto_utils.use_new_executor():
                    batches = self._validate_batch(batch)
                else:
                    batches = [{}]
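                    # Without the new executor the reader ops feed data inside
                    # the program, so each step runs with an empty feed dict.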

                try:
                    for micro_batch in batches:
                        with paddle.profiler.utils._nvprof_range(
                            iter_id=step,
                            start=nvprof_range[0],
                            end=nvprof_range[1],
                        ):
                            cbks.on_batch_begin('train', step, logs)
                            outs = self._executor.run(
                                self.main_program,
                                feed=micro_batch,
                                fetch_list=fetch_names,
                                use_program_cache=self._strategy.use_cache,
                                return_numpy=self._strategy.return_numpy,
                            )
                            lr = auto_utils.get_lr(self.optimizer)
                            logs = self._prepare_logger(
                                outs,
                                epoch,
                                step,
                                lr,
                                fetch_names,
                                fetch_indices,
                                self._mode,
                            )
                            cbks.on_batch_end('train', step, logs)
                except core.EOFException:
                    break

                if steps_per_epoch and step >= steps_per_epoch:
                    if not auto_utils.use_new_executor():
                        train_dataloader._reset()
                    break

            if valid_data and (epoch + 1) % valid_freq == 0:
                val_logs = self.evaluate(
                    valid_data,
                    valid_sample_split,
                    batch_size,
                    valid_steps,
                    log_freq,
                    collate_fn,
                    callbacks,
                    verbose,
                )
                val_logs = {
                    "val_" + name: val for name, val in val_logs.items()
                }
                logs.update(val_logs)
                self._switch_mode("train")
            else:
                self._reset_metrics()

            cbks.on_epoch_end(epoch, logs)

        cbks.on_end('train', logs)
        return self.history

    def evaluate(
        self,
        valid_data,
        valid_sample_split=None,
        batch_size=1,
        steps=None,
        log_freq=10,
        collate_fn=None,
        callbacks=None,
        verbose=2,
    ):
        """
        Evaluate the loss and metrics of the model on evaluation data.

        Args:
            valid_data (Dataset): An instance of paddle paddle.io.Dataset. Default: None.
            valid_sample_split (int, optional): Each sample of the eval dataset is assumed
                to be an (input, label) pair by default and has two items. If each sample has
                more than two items, valid_sample_split specifies how to split these items into
                input and label. The items before the split point are input and the rest are label. Default: None.
            batch_size (int, optional): The batch size of valid_data. The user's data will
                be used directly without batching if set to None. Default: 1.
            steps (int, optional): It is the total number of steps (batches of samples) to draw before
                stopping evaluation. If None, evaluation will run until the `valid_data` dataset is exhausted.
                The evaluation will start from the beginning of the dataset in each run. Default: None.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list, None to only stack each field of the samples along axis
                0. Default None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during evaluating. Default: None. (Unused for now)

        Returns:
            A dict of evaluation results, containing the loss and metric values.

        Examples:

            .. code-block:: python

                >>> import paddle
                >>> import paddle.vision.transforms as T
                >>> from paddle.distributed.fleet import auto
                >>> from paddle.vision.datasets import MNIST

                >>> transform = T.Compose([
                ...     T.Transpose(),
                ...     T.Normalize([127.5], [127.5])
                ... ])
                >>> valid_dataset = MNIST(mode='test', transform=transform)

                >>> model = paddle.vision.models.LeNet()
                >>> loss = paddle.nn.CrossEntropyLoss()
                >>> metrics = paddle.metric.Accuracy(topk=(1, 2))

                >>> engine = auto.Engine(model, loss, metrics=metrics)
                >>> engine.evaluate(valid_dataset, batch_size=64)

        """
        self._mode = 'eval'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            valid_data, valid_sample_split, batch_size
        )

        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        micro_batch_size = self._validate_batch_size(batch_size)
        valid_dataloader = self._prepare_dataloader_from_generator(
            dataset=valid_data,
            capacity=70,
            iterable=False,
            batch_size=micro_batch_size,
            steps_per_epoch=steps,
            collate_fn=collate_fn,
        )

        fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)

        cbks = config_callbacks(
            callbacks,
            engine=self,
            batch_size=micro_batch_size,
            log_freq=log_freq,
            verbose=verbose,
            metrics=self._metrics_name(),
        )

        eval_steps = valid_dataloader._steps
        cbks.on_begin(
            'eval', {'steps': eval_steps, 'metrics': self._metrics_name()}
        )
        logs = {}
        for step, _ in enumerate(valid_dataloader):
            cbks.on_batch_begin('eval', step, logs)
            try:
                outs = self._executor.run(
                    self.main_program,
                    fetch_list=fetch_names,
                    use_program_cache=self._strategy.use_cache,
                    return_numpy=self._strategy.return_numpy,
                )
            except core.EOFException:
                break
            logs = self._prepare_logger(
                outs, None, step, None, fetch_names, fetch_indices, self._mode
            )
            cbks.on_batch_end('eval', step, logs)
        cbks.on_end('eval', logs)
        self._reset_metrics()
        return logs

    def predict(
        self,
        test_data,
        test_sample_split=None,
        batch_size=1,
        steps=None,
        collate_fn=None,
        callbacks=None,
        verbose=2,
    ):
        """
        Compute the output predictions on testing data.

        Args:
            test_data (Dataset): An instance of paddle paddle.io.Dataset. Default: None.
            test_sample_split (int, optional): Each sample of the test dataset is assumed
                to be an (input, label) pair by default and has two items. If each sample has
                more than two items, test_sample_split specifies how to split these items into
                input and label. The items before the split point are input and the rest are label. Default: None.
            batch_size (int, optional): The batch size of test_data. The user's data will
                be used directly without batching if set to None. Default: 1.
            steps (int, optional): It is the total number of steps (batches of samples) to draw before
                stopping predict. If None, predict will run until the `test_data` dataset is exhausted.
                The predict will start from the beginning of the dataset in each run. Default: None.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list, None to only stack each field of the samples along axis
                0. Default None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during testing. Default: None. (Unused for now)

        Returns:
            A list of the model outputs collected over the prediction steps.

        Examples:

            .. code-block:: python

                >>> import paddle
                >>> import paddle.vision.transforms as T
                >>> from paddle.distributed.fleet import auto
                >>> from paddle.vision.datasets import MNIST

                >>> transform = T.Compose([
                ...     T.Transpose(),
                ...     T.Normalize([127.5], [127.5])
                ... ])
                >>> valid_dataset = MNIST(mode='test', transform=transform)

                >>> model = paddle.vision.models.LeNet()

                >>> engine = auto.Engine(model)
                >>> engine.predict(valid_dataset, batch_size=64)
        """
        self._mode = 'predict'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            test_data, test_sample_split, batch_size
        )

        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        micro_batch_size = self._validate_batch_size(batch_size)
        test_dataloader = self._prepare_dataloader_from_generator(
            dataset=test_data,
            capacity=70,
            iterable=False,
            batch_size=micro_batch_size,
            steps_per_epoch=steps,
            collate_fn=collate_fn,
        )

        fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)

        outputs = []
        cbks = config_callbacks(callbacks, engine=self, verbose=verbose)
        test_steps = test_dataloader._steps
        cbks.on_begin('predict', {'steps': test_steps})
        logs = {}
        for step, _ in enumerate(test_dataloader):
            cbks.on_batch_begin('predict', step, logs)
            try:
                outs = self._executor.run(
                    self.main_program,
                    fetch_list=fetch_names,
                    use_program_cache=self._strategy.use_cache,
                    return_numpy=self._strategy.return_numpy,
                )
            except core.EOFException:
                break
            logs = self._prepare_logger(
                outs, None, step, None, fetch_names, fetch_indices, self._mode
            )
            cbks.on_batch_end('predict', step, logs)
            outputs.append(list(logs["outputs"].values()))
        cbks.on_end('predict', logs)
        return outputs

    def dataloader(
        self,
        dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        collate_fn=None,
        num_workers=0,
        use_buffer_reader=True,
        use_shared_memory=True,
        timeout=0,
        worker_init_fn=None,
        epochs=1,
        steps_per_epoch=None,
        sample_split=1,
        mode=None,
        places=None,
    ):
        if mode is not None:
            self.to_mode(mode)

        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            dataset, sample_split, batch_size
        )

        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        batch_size = self._validate_batch_size(batch_size)
        dataloader = self._prepare_dataloader(
            dataset,
            return_list=False,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            collate_fn=collate_fn,
            num_workers=num_workers,
            use_buffer_reader=use_buffer_reader,
            use_shared_memory=use_shared_memory,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            places=places,
        )
        return dataloader
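
    # NOTE: A minimal usage sketch for `dataloader` (kept as comments so the
    # module itself is unchanged; `train_dataset` is a hypothetical
    # paddle.io.Dataset and the engine is assumed to have a loss and an
    # optimizer configured):
    #
    #     train_loader = engine.dataloader(
    #         train_dataset, batch_size=64, epochs=2, mode="train"
    #     )
    #     for batch in train_loader:
    #         logs = engine.run(batch, mode="train")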

    def dataloader_from_generator(
        self,
        dataset,
        capacity=70,
        use_double_buffer=True,
        iterable=True,
        use_multiprocess=False,
        drop_last=True,
        batch_size=1,
        epochs=1,
        steps_per_epoch=None,
        collate_fn=None,
        sample_split=1,
        mode=None,
    ):
        if mode is not None:
            self.to_mode(mode)
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            dataset, sample_split, batch_size
        )

        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        micro_batch_size = self._validate_batch_size(batch_size)
        dataloader = self._prepare_dataloader_from_generator(
            dataset=dataset,
            capacity=capacity,
            use_double_buffer=use_double_buffer,
            iterable=iterable,
            return_list=False,
            use_multiprocess=use_multiprocess,
            drop_last=drop_last,
            batch_size=micro_batch_size,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            collate_fn=collate_fn,
        )
        return dataloader
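
    # NOTE: A minimal sketch of the generator-based variant (comments only;
    # `dataset` is a hypothetical placeholder). The returned loader buffers up
    # to `capacity` batches; with `iterable=False` the caller iterates it and
    # runs the program without an explicit feed, as `predict()` above does:
    #
    #     loader = engine.dataloader_from_generator(
    #         dataset, capacity=70, iterable=False, batch_size=64, mode="eval"
    #     )
    #     for _ in loader:
    #         logs = engine.run(mode="eval")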

    def prepare(
        self,
        inputs_spec=None,
        labels_spec=None,
        inputs=None,
        labels=None,
        main_program=None,
        startup_program=None,
        mode=None,
        init_parameters=True,
    ):
        if mode is not None:
            self.to_mode(mode)

        if not self._mode:
            raise ValueError(
                "Please set mode to be prepared with `prepare(mode=...)`"
            )

        if self._has_prepared[self._mode]:
            return

        inputs_spec = self._validate_spec(inputs_spec)
        labels_spec = self._validate_spec(labels_spec)
        inputs = self._validate_vars(inputs)
        labels = self._validate_vars(labels)

        self._orig_main_prog = main_program
        self._orig_startup_prog = startup_program
        if inputs or labels:
            self._skip_build = True
            inputs, labels = self._prepare_data_tensor(
                inputs_spec, labels_spec, inputs, labels
            )
            if self._orig_main_prog is None:
                self._orig_main_prog = static.default_main_program()
            if self._orig_startup_prog is None:
                self._orig_startup_prog = static.default_startup_program()
        elif inputs_spec or labels_spec:
            self._outside_dataloader = True
            if self._orig_main_prog is None:
                self._orig_main_prog = static.default_main_program()
            if self._orig_startup_prog is None:
                self._orig_startup_prog = static.default_startup_program()
        else:
            assert (
                self._inputs_spec and self._labels_spec
            ), "Please call the dataloader(...) before calling prepare(...)"

        self._inputs_spec, self._labels_spec = inputs_spec, labels_spec
        self._inputs, self._labels = inputs, labels
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode, init_parameters)
        else:
            self._switch_mode(self._mode)
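
    # NOTE: A minimal sketch of spec-driven preparation (comments only; the
    # shapes and the names 'image'/'label' are illustrative). `_validate_spec`
    # below requires every InputSpec to be named:
    #
    #     input_spec = InputSpec([64, 1, 28, 28], 'float32', 'image')
    #     label_spec = InputSpec([64, 1], 'int64', 'label')
    #     engine.prepare(
    #         inputs_spec=[input_spec], labels_spec=[label_spec], mode="train"
    #     )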

    def run(self, data=None, feed=None, fetch_list=None, mode=None):
        if mode is not None:
            self.to_mode(mode)
        feed_dict = self._prepare_feed(data, feed, self._mode)
        fetch_names, fetch_indices = self._prepare_fetch(fetch_list, self._mode)
        if (
            self._outside_dataloader
            and not self._has_prepared_reader[self._mode]
        ):
            self._prepare_reader()
        outs = self._executor.run(
            self.main_program,
            feed=feed_dict,
            fetch_list=fetch_names,
            use_program_cache=self._strategy.use_cache,
            return_numpy=self._strategy.return_numpy,
        )
        logs = self._prepare_logger(
            outs, None, None, None, fetch_names, fetch_indices, self._mode
        )
        return logs
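
    # NOTE: A minimal sketch of `run` with a manual feed after `prepare(...)`
    # (comments only; the feed keys must match the spec names, here the
    # illustrative 'image'/'label' from the sketch above). The returned `logs`
    # dict carries the fetched values:
    #
    #     logs = engine.run(
    #         feed={'image': image_batch, 'label': label_batch}, mode="train"
    #     )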

    def _prepare_dataloader(
        self,
        dataset,
        return_list=True,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        collate_fn=None,
        num_workers=0,
        use_buffer_reader=True,
        use_shared_memory=True,
        timeout=0,
        worker_init_fn=None,
        epochs=1,
        steps_per_epoch=None,
        places=None,
    ):
        dist_context = self._dist_contexts[self._mode]
        dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
        dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
        dist_main_block = dist_main_prog.global_block()

        # NOTE: Get feed_list, then insert the dataloader op with the sharded var shape.
        # Because predict_program does not contain the labels var, we add the labels
        # var from serial_program to dist_program, which keeps the length of feed_list
        # equal to the number of the dataset's values.
        inputs_var = dist_context.serial_feed_vars["inputs"]
        labels_var = dist_context.serial_feed_vars["labels"]
        feed_list = []
        for var in inputs_var + labels_var:
            if var.name in dist_main_block.vars:
                feed_list.append(dist_main_block.vars[var.name])
            else:
                copy_var = dist_main_block._clone_variable(var, var.persistable)
                copy_var.desc.set_original_id(var.desc.original_id())
                feed_list.append(copy_var)

        # Insert the read op at the end of the program.
        with static.program_guard(dist_main_prog, dist_startup_prog):
            dataloader = DistributedDataLoader(
                dataset,
                feed_list=feed_list,
                places=places,
                return_list=return_list,
                batch_size=batch_size,
                shuffle=shuffle,
                drop_last=drop_last,
                collate_fn=collate_fn,
                num_workers=num_workers,
                use_buffer_reader=use_buffer_reader,
                use_shared_memory=use_shared_memory,
                timeout=timeout,
                worker_init_fn=worker_init_fn,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                split_data=self._strategy.split_data,
                data_parallel_world_size=self._dp_world_sizes,
                data_parallel_rank=self._dp_ranks,
            )

        return dataloader

    def _prepare_dataloader_from_generator(
        self,
        dataset,
        capacity=None,
        use_double_buffer=True,
        iterable=True,
        return_list=False,
        use_multiprocess=False,
        drop_last=True,
        batch_size=1,
        epochs=1,
        steps_per_epoch=None,
        collate_fn=None,
    ):
        dist_context = self._dist_contexts[self._mode]
        dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
        dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
        dist_main_block = dist_main_prog.global_block()

        # NOTE: Get feed_list, then insert the dataloader op with the sharded var shape.
        # Because predict_program does not contain the labels var, we add the labels
        # var from serial_program to dist_program, which keeps the length of feed_list
        # equal to the number of the dataset's values.
        inputs_var = dist_context.serial_feed_vars["inputs"]
        labels_var = dist_context.serial_feed_vars["labels"]
        feed_list = []
        for var in inputs_var + labels_var:
            if var.name in dist_main_block.vars:
                feed_list.append(dist_main_block.vars[var.name])
            else:
                copy_var = dist_main_block._clone_variable(var, var.persistable)
                copy_var.desc.set_original_id(var.desc.original_id())
                feed_list.append(copy_var)

        places = paddle.static.cuda_places()
        with static.program_guard(dist_main_prog, dist_startup_prog):
            dataloader = DistributedDataLoaderFromGenerator(
                dataset=dataset,
                feed_list=feed_list,
                capacity=capacity,
                use_double_buffer=use_double_buffer,
                iterable=iterable,
                return_list=return_list,
                use_multiprocess=use_multiprocess,
                drop_last=drop_last,
                places=places,
                batch_size=batch_size,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                collate_fn=collate_fn,
                split_data=self._strategy.split_data,
                data_parallel_world_size=self._dp_world_sizes,
                data_parallel_rank=self._dp_ranks,
                acc_steps=1
                if not self._strategy.pipeline.enable
                else self._acc_steps,
            )
        self._prepare_reader(feed_list)
        return dataloader

    def _tune(self, tune_data, tune_sample_split=None, batch_size=1):
        self._mode = 'train'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            tune_data, tune_sample_split, batch_size
        )
        self._optimization_tuning(self._mode, tune_data, batch_size)

    def _validate_batch_size(self, batch_size):
        if batch_size is None:
            return None

        if auto_utils.use_new_executor():
            assert (
                len(set(self._dp_world_sizes)) == 1
            ), "DistributedBatchSampler only supports one data parallel group, but got [{}] different data parallel groups".format(
                len(set(self._dp_world_sizes))
            )
            assert (
                batch_size % self._dp_world_sizes[0] == 0
            ), "batch_size [{}] is not divisible by dp_world_size [{}]".format(
                str(batch_size), str(self._dp_world_sizes[0])
            )
            return batch_size // self._dp_world_sizes[0]
        else:
            assert (
                batch_size % self._acc_steps == 0
            ), "Requires batch_size:[{}] to be divisible by acc_steps:[{}].".format(
                batch_size, self._acc_steps
            )
            return batch_size // self._acc_steps
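
    # NOTE: Worked example of the division above (values are illustrative):
    # with the new executor, a global batch_size of 64 and a dp_world_size of
    # 8 yield a per-rank micro batch of 64 // 8 = 8; on the other path, with
    # gradient-accumulation acc_steps=4, each step consumes 64 // 4 = 16.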

    def _validate_batch(self, batch):
        if batch is None:
            return [None]

        if self._strategy.pipeline.enable or self._acc_steps == 1:
            # pp with schedule or naive-pp
            return batch
        else:
            # split feed data with gradient_merge k_steps
            feed_names = []
            split_batches = []
            for feed_name, cur_feed in batch[0].items():
                feed_names.append(feed_name)
                split_batches.append(
                    np.split(np.array(cur_feed), self._acc_steps, 0)
                )
            batches = []
            for i in range(self._acc_steps):
                micro_batch = [split_batch[i] for split_batch in split_batches]
                batches.append(dict(zip(feed_names, micro_batch)))
            return batches
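
    # NOTE: Shape sketch for the split above (illustrative): a feed of
    # {'image': array of shape (64, 1, 28, 28)} with acc_steps=4 becomes four
    # micro-batch dicts [{'image': (16, 1, 28, 28)}, ...], one per
    # accumulation step, split along axis 0 by np.split.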

    def _validate_spec(self, specs):
        specs = auto_utils.to_list(specs)
        if specs is not None:
            for i, spec in enumerate(specs):
                if not isinstance(spec, InputSpec):
                    raise TypeError(
                        "'spec' must be object of class `paddle.static.InputSpec`."
                    )
                if spec.name is None:
                    raise ValueError(
                        "Requires Input[{}].name != None, but received `None` with {}.".format(
                            i, spec
                        )
                    )
                if self._acc_steps > 1:
                    shape = list(spec.shape)
                    assert (
                        shape[0] % self._acc_steps == 0
                    ), "Requires batch_size[{}] to be divisible by k_steps[{}].".format(
                        spec.shape[0], self._acc_steps
                    )
                    shape[0] //= self._acc_steps
                    spec.shape = shape
        return specs or []
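
    # NOTE: Illustrative effect of the acc_steps rescaling above: an
    # InputSpec([64, 1, 28, 28], 'float32', 'image') with acc_steps=4 is
    # rewritten in place to shape [16, 1, 28, 28], so the compiled program
    # sees the per-accumulation-step micro-batch shape.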

    def _validate_vars(self, vars):
        vars = auto_utils.to_list(vars)
        if vars is not None:
            for i, var in enumerate(vars):
                if not isinstance(var, Variable):
                    raise TypeError("'var' must be a `Variable`.")
        return vars or []

    def _is_local_var(self, var):
        var_name = _to_name_str(var)
        return var_name in self.main_program.global_block().vars

    def _reset_metrics(self):
        for metric in self._metrics:
            metric.reset()

    def _metrics_name(self):
        metrics_name = ['loss'] if self._loss else []
        for m in self._metrics:
            metrics_name.extend(auto_utils.to_list(m.name()))
        return metrics_name

    def _switch_mode(self, mode):
        assert (
            mode in self._dist_contexts
        ), f"{mode} model is not ready, please call `prepare()` first."
        self.to_mode(mode)

    def to_mode(self, mode):
        assert mode in [
            "train",
            "eval",
            "predict",
        ], f"mode {mode} should be one of ['train', 'eval', 'predict']"
        self._mode = mode

    def _set_state_dict(self, mode, strict, state_dict, dist_attr):
        dist_context = self._dist_contexts[mode]
        program = dist_context.dist_main_programs[self._cur_rank]
        cur_dist_attr = auto_utils.get_dist_attr(program, dist_context)
        converter = Converter(state_dict, dist_attr, cur_dist_attr)
        state_dict = converter.convert(strict=strict)
        for name, param in program.state_dict().items():
            param_array = np.array(param)
            if name not in state_dict:
                continue
            if param_array.dtype != state_dict[name].dtype:
                self._logger.info(
                    "cast {}'s dtype from '{}' to '{}'".format(
                        name,
                        str(state_dict[name].dtype),
                        str(param_array.dtype),
                    )
                )
                state_dict[name] = state_dict[name].astype(param_array.dtype)
        program.set_state_dict(state_dict)

    def save(self, path, training=True):
        """
        Saves the model, parameters, and optimizer state to path.
        If `training` is set to False, only the inference model will be saved.

        Args:
            path (str): The file prefix to save the model. The format
                is 'dirname/file_prefix' or 'file_prefix'. An exception
                will be raised if it is an empty string.
            training (bool, optional): Whether to save for training. If not, save
                for inference only. If `training` is set to True, the optimizer state
                will be saved. Otherwise, only the model and parameters are saved.
                This function will silently overwrite an existing file at the target
                location. Default: True.

        Returns:
            None

        Examples:

            .. code-block:: python

                >>> import paddle
                >>> import paddle.vision.transforms as T
                >>> from paddle.distributed.fleet import auto
                >>> from paddle.vision.datasets import MNIST

                >>> transform = T.Compose([
                ...     T.Transpose(),
                ...     T.Normalize([127.5], [127.5])
                ... ])
                >>> train_dataset = MNIST(mode='train', transform=transform)

                >>> model = paddle.vision.models.LeNet()
                >>> loss = paddle.nn.CrossEntropyLoss()
                >>> optimizer = paddle.optimizer.Adam(
                ...     learning_rate=0.001, parameters=model.parameters())
                >>> metrics = paddle.metric.Accuracy(topk=(1, 2))

                >>> engine = auto.Engine(model, loss, optimizer, metrics)
                >>> engine.fit(train_dataset,
                ...             epochs=1,
                ...             batch_size=64)
                >>> engine.save("./my_model")

        """
        if training:
            assert self._mode in self._dist_contexts
            dist_context = self._dist_contexts[self._mode]
            serial_program = dist_context.serial_main_program
            dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
            self._saver.save(
                path,
                serial_program=serial_program,
                dist_main_program=dist_main_prog,
                dist_context=dist_context,
            )
        else:
            assert "predict" in self._dist_contexts
            dist_context = self._dist_contexts["predict"]
            feed_vars = dist_context.serial_feed_vars['inputs']
            fetch_vars = dist_context.serial_fetch_vars['outputs']
            dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
            if self._strategy.qat.enable and self._strategy.qat.onnx_format:
                from paddle.static.quantization import QuantWeightPass

                self._logger.info("export quantized model.")
                self._logger.info(
                    f"convert config {self._strategy.qat.to_dict()}"
                )
                test_graph = IrGraph(
                    core.Graph(dist_main_prog.desc), for_test=True
                )
                quant_weight_pass = QuantWeightPass(global_scope(), self._place)
                for sub_graph in test_graph.all_sub_graphs():
                    quant_weight_pass.apply(sub_graph)
                dist_main_prog = test_graph.to_program()
            self._saver.save_inference_model(
                path,
                feed_vars,
                fetch_vars,
                self._executor,
                program=dist_main_prog,
            )

    def load(self, path, strict=True, load_optimizer=True):
        """
        Load the stored model, parameters and optimizer states.

        Args:
            path (str): The prefix of files storing the model states and
                optimizer states.
            strict (bool, optional): Whether to skip loading mismatched
                parameters or to raise an error when a mismatch happens (the
                parameter is not found in the file storing the model states,
                or its shape does not match). Default: True.
            load_optimizer (bool, optional): If True, the stored optimizer
                states are restored. Otherwise, the optimizer states are
                initialized from scratch. Default: True.

        Returns:
            The loaded state dict and the corresponding distributed attributes.

        Examples:

            .. code-block:: python

                >>> import paddle
                >>> import paddle.vision.transforms as T
                >>> from paddle.distributed.fleet import auto
                >>> from paddle.vision.datasets import MNIST

                >>> transform = T.Compose([
                ...     T.Transpose(),
                ...     T.Normalize([127.5], [127.5])
                ... ])
                >>> train_dataset = MNIST(mode='train', transform=transform)

                >>> model = paddle.vision.models.LeNet()
                >>> loss = paddle.nn.CrossEntropyLoss()
                >>> optimizer = paddle.optimizer.Adam(
                ...     learning_rate=0.001, parameters=model.parameters())
                >>> metrics = paddle.metric.Accuracy(topk=(1, 2))

                >>> engine = auto.Engine(model, loss, optimizer, metrics)
                >>> engine.fit(train_dataset,
                ...             epochs=1,
                ...             batch_size=64)
                >>> engine.save("./my_model")
                >>> engine.load("./my_model")

        """
        self._strict = strict
        self._state_dict, self._dist_attr = self._saver.load(
            path, load_optimizer
        )
        return self._state_dict, self._dist_attr

    def cost(self, inputs_spec=None, labels_spec=None, mode=None):
        """
        Get and print the cost, including the memory of every rank,
        the max memory among all ranks, and the global cost of one step based on
        communication cost (computation cost is 0 by default).
        In the future, the flops information of every rank and the global cost,
        including computation cost, will be added.

        Args:
            inputs_spec(InputSpec): The specification of inputs. Default: None.
            labels_spec(InputSpec): The specification of labels. Default: None.
            mode (str): The engine mode must be in ["train", "predict", "eval"]. Default: None.

        Returns:
            The global execution time (ms) and max memory (B).

        """
        # Check parallel mode
        if self._strategy.auto_mode == "full":
            self._logger.info(
                "The cost will be calculated in the search process when the auto mode is full."
            )
            return

        # Check mode
        mode = mode if mode is not None else self._mode
        assert mode is not None, "Please set mode."
        if mode not in self._has_prepared:
            raise ValueError(
                "The mode {} is not in accepted modes {}".format(
                    mode, list(self._has_prepared.keys())
                )
            )
        self.to_mode(mode)

        if inputs_spec is not None and not self._has_prepared[mode]:
            self._inputs_spec = self._validate_spec(inputs_spec)
            self._labels_spec = self._validate_spec(labels_spec)
            self._build(mode)
            self._plan(mode)
        else:
            if in_dynamic_mode() or self._dygraph_mode:
                raise ValueError(
                    "Please call `prepare()` or `fit()` or `evaluate()` or `predict()` before calling `cost()`."
                )
            else:
                self._logger.info(
                    "The program whose cost is to be estimated must be the static default program. Otherwise, please call `prepare()` before calling `cost()`."
                )
                program = paddle.static.default_main_program()
                if (
                    not program.global_block().ops
                    and not self._has_prepared[mode]
                ):
                    raise ValueError(
                        "Please call `prepare()` or `fit()` or `evaluate()` or `predict()` before calling `cost()`."
                    )

        # Estimate the exec cost and max memory
        global_cost, max_memory = get_cost_from_engine(self, mode)

        return global_cost.time, max_memory
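
    # NOTE: A minimal usage sketch for `cost` (comments only; the specs are
    # illustrative and, per `_validate_spec`, must be named):
    #
    #     time_ms, max_bytes = engine.cost(
    #         inputs_spec=[InputSpec([64, 1, 28, 28], 'float32', 'image')],
    #         labels_spec=[InputSpec([64, 1], 'int64', 'label')],
    #         mode="train",
    #     )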

    @property
    def main_program(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.dist_main_programs[self._cur_rank]

    @property
    def startup_program(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.dist_startup_programs[self._cur_rank]

    @property
    def dist_context(self):
        return self._dist_contexts[self._mode]

    @property
    def serial_main_program(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.serial_main_program

    @property
    def serial_startup_program(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.serial_startup_program

    @property
    def feed_vars(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.serial_feed_vars

    @property
    def fetch_vars(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.serial_fetch_vars

    @property
    def optimizer(self):
        dist_context = self._dist_contexts[self._mode]
        if dist_context._serial_optimizer:
            return dist_context._serial_optimizer
        return self._optimizer

    @property
    def inputs(self):
        return self._inputs

    @property
    def labels(self):
        return self._labels