# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import logging
import numbers
import os
import random

import numpy as np

import paddle
import paddle.distributed.auto_parallel.static.utils as auto_utils
from paddle import static, utils
from paddle.distributed import fleet
from paddle.fluid.executor import _to_name_str
from paddle.framework import IrGraph
from paddle.framework import _current_expected_place as _get_device
from paddle.framework import core, in_dynamic_mode
from paddle.metric import Metric
from paddle.static import InputSpec, Operator, Variable, global_scope

from ...utils.log_utils import get_logger
from ..interface import CollectionNames, fetch, get_collection
from ..strategy import Strategy
from .callbacks import config_callbacks
from .cluster import Cluster, get_default_cluster
from .converter import Converter
from .cost.estimate_cost import get_cost_from_engine
from .dist_context import DistributedContext, get_default_distributed_context
from .dist_loader import (
    DistributedDataLoader,
    DistributedDataLoaderFromGenerator,
)
from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver
from .helper import ProgramHelper
from .parallelizer_v2 import Parallelizer
from .planner_v2 import Planner
from .process_group import get_all_process_groups, new_process_group


class Engine:
    """
    A high-level API for auto parallel, which can be used for distributed training (engine.fit) and inference (engine.predict).
    Static graph mode is supported natively. Dynamic graph mode is also supported via `@to_static <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/jit/to_static_cn.html#to-static>`_ .

    Args:
        model (paddle.nn.Layer, optional): The model is an instance of
            paddle.nn.Layer.
        loss (Loss|Callable|None, optional): The loss can be a `paddle.nn.Layer`
            instance or any callable function that takes the predicted values and
            ground truth values as input. It can be None when there is no loss.
            Default: None.
        optimizer (Optimizer|None, optional): The optimizer needs to be set for training
            and should be None in eval and predict mode. Default: None.
        metrics (Metric|list[Metric]|None, optional): If metrics is set, all
            metrics will be calculated and output in train/eval mode. Default: None.
        cluster (Cluster|None, optional): The cluster represents the topology information
            about the used physical devices. Default: None. (Unused for now)
        strategy (Strategy|None, optional): The strategy is used to configure the
            parallelization and optimization behaviors. Default: None.

    Examples:

        .. code-block:: python

            >>> import paddle
            >>> import paddle.vision.transforms as T
            >>> from paddle.distributed.fleet import auto
            >>> from paddle.vision.datasets import MNIST

            >>> transform = T.Compose([
            ...     T.Transpose(),
            ...     T.Normalize([127.5], [127.5])
            ... ])
            >>> train_dataset = MNIST(mode='train', transform=transform)
            >>> valid_dataset = MNIST(mode='test', transform=transform)

            >>> model = paddle.vision.models.LeNet()
            >>> loss = paddle.nn.CrossEntropyLoss()
            >>> optimizer = paddle.optimizer.Adam(
            ...     learning_rate=0.001, parameters=model.parameters())
            >>> metrics = paddle.metric.Accuracy(topk=(1, 2))

            >>> engine = auto.Engine(model, loss, optimizer, metrics)
            >>> # fit
            >>> engine.fit(train_dataset,
            ...            epochs=2,
            ...            batch_size=64)
            >>> # evaluate
            >>> engine.evaluate(valid_dataset,
            ...                 batch_size=64)
            >>> # predict
            >>> engine.predict(valid_dataset,
            ...                batch_size=64)
            >>> # save
            >>> engine.save("./my_model")
            >>> # load
            >>> engine.load("./my_model")
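
            >>> # A Strategy configures parallel/optimization behaviors; a
            >>> # minimal sketch using the gradient_merge fields this engine
            >>> # reads (illustrative values):
            >>> strategy = auto.Strategy()
            >>> strategy.gradient_merge.enable = True
            >>> strategy.gradient_merge.k_steps = 4
            >>> engine = auto.Engine(model, loss, optimizer, strategy=strategy)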

    """

    def __init__(
        self,
        model=None,
        loss=None,
        optimizer=None,
        metrics=None,
        cluster=None,
        strategy=None,
    ):
        if (
            model
            and not isinstance(model, paddle.nn.Layer)
            and not callable(model)
        ):
            raise TypeError(
                "'model' must be a subclass of `paddle.nn.Layer` or a callable function."
            )
        self._model = model
        self._parameter_list = (
            None if not model else [p.name for p in model.parameters()]
        )

        if (
            loss
            and not isinstance(loss, (paddle.nn.Layer, Variable))
            and not callable(loss)
        ):
            raise TypeError(
                "'loss' must be a subclass of `paddle.nn.Layer`, a callable function or a Variable."
            )
        self._loss = loss

        if optimizer and not isinstance(
            optimizer,
            (paddle.optimizer.Optimizer),
        ):
            raise TypeError(
                "'optimizer' must be an object of class `paddle.optimizer.Optimizer`"
            )
        self._optimizer = auto_utils.validate_opt(optimizer)

        metrics = metrics or []
        for metric in auto_utils.to_list(metrics):
            if metric and not isinstance(metric, Metric):
                raise TypeError(
                    "{} is not a subclass of Metric".format(
                        metric.__class__.__name__
                    )
                )
        self._metrics = auto_utils.to_list(metrics)

        if cluster and not isinstance(cluster, Cluster):
            raise TypeError(
                "'cluster' must be an object of class `paddle.distributed.auto_parallel.Cluster`"
            )
        self._cluster = cluster or get_default_cluster()

        if strategy and not isinstance(strategy, Strategy):
            raise TypeError(
                "'strategy' must be an object of class `paddle.distributed.auto_parallel.Strategy`"
            )
        self._strategy = strategy or Strategy()

        self._logger = get_logger(logging.INFO)

        self._json_config = None
        if cluster:
            self._cluster = cluster
        else:
            if os.getenv("PADDLE_AUTO_PARALLEL_CONFIG"):
                try:
                    path = os.getenv("PADDLE_AUTO_PARALLEL_CONFIG")
                    with open(path, "r") as f:
                        self._json_config = json.load(f)
                except Exception:
                    self._logger.info(
                        "Failed to load the JSON config file; please check it. The engine will run with the default config."
                    )
                    self._json_config = None
            self._cluster = get_default_cluster(self._json_config)
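        # Note: the parsed JSON config (if any) is forwarded to
        # get_default_cluster above and again to DistributedContext in _build.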

        if os.getenv("POD_NAME"):
            self._logger.info(
                "Distributed training launched by paddle.distributed.launch"
            )
            fleet.init(is_collective=True)

        # for compute cost
        # TODO: remove _fwd_main_progs and _orig_optimizer
        self._fwd_dist_contexts = {}
        self._fwd_main_progs = {}
        self._orig_optimizer = copy.deepcopy(self._optimizer)

        self._executor = None
        self._cur_rank = paddle.distributed.get_rank()
        self._nranks = paddle.distributed.get_world_size()
        self._saver = DistributedSaver()

        self._orig_main_prog = static.default_main_program()
        self._orig_startup_prog = static.default_startup_program()
        self._orig_dist_context = get_default_distributed_context()
        self._dist_contexts = {}
        self._planners = {}
        self._has_prepared = {"train": False, "eval": False, "predict": False}
        self._has_prepared_reader = {
            "train": False,
            "eval": False,
            "predict": False,
        }
        self._inputs_spec = []
        self._labels_spec = []
        self._inputs = []
        self._labels = []
        self._losses = []

        self._mode = None
        self._skip_build = False
        self._outside_dataloader = False
        self._planned_mode = None
        self._dygraph_mode = False
        self._tuning = self._strategy.tuning
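        # One optimizer update consumes `_acc_steps` micro-batches: either
        # gradient_merge.k_steps or pipeline.accumulate_steps when enabled
        # (e.g. k_steps=4 with micro-batch size 16 means one update per 64
        # samples).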
        self._acc_steps = 1
        if self._strategy.gradient_merge.enable:
            self._acc_steps = self._strategy.gradient_merge.k_steps
        elif self._strategy.pipeline.enable:
            self._acc_steps = self._strategy.pipeline.accumulate_steps

        if (
            self._strategy.pipeline.enable
            and self._strategy.pipeline.schedule_mode == "1F1B"
        ):
            assert (
                os.getenv("CUDA_MODULE_LOADING") != "LAZY"
            ), "EXP_CUDA_MODULE_LOADING_LAZY not supported in 1F1B pipeline."

        self.history = None

        paddle.framework.set_flags({'FLAGS_new_executor_sequential_run': 1})
        paddle.framework.set_flags({'FLAGS_new_executor_static_build': 1})

    def _prepare_data_spec(self, data, split, batch_size):
        inputs_spec = []
        labels_spec = []
        if isinstance(data, paddle.io.IterableDataset):
            if split is None:
                inputs, labels = next(iter(data))
            else:
                sample = next(iter(data))
                inputs = sample[:split]
                labels = sample[split:]
        elif isinstance(data, paddle.io.Dataset):
            if split is None:
                inputs, labels = data[0]
            else:
                sample = data[0]
                inputs = sample[:split]
                labels = sample[split:]
        else:
            raise TypeError(
                "Data should be a Dataset or IterableDataset, but received {}.".format(
                    type(data).__name__
                )
            )
        inputs = auto_utils.to_list(inputs)
        labels = auto_utils.to_list(labels)

        num_shards = self._strategy.dataset.num_shards

        def _adjust_item_spec(num_shards, spec):
            if num_shards > 1 and len(spec.shape) > 1:
                spec.shape[0] = spec.shape[0] * num_shards

        def _infer_item_spec(item, name, batch_size, specs):
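            # Illustrative (assumed shapes): a float32 ndarray item of shape
            # (1, 28, 28) with batch_size=64 yields
            # InputSpec(shape=(64, 1, 28, 28), dtype='float32', name='input0').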
            if isinstance(item, np.ndarray):
                spec = InputSpec.from_numpy(item, name)
                if batch_size is None:
                    _adjust_item_spec(num_shards, spec)
                    specs.append(spec)
                else:
                    specs.append(spec.batch(batch_size))
            elif isinstance(item, (Variable, core.eager.Tensor)):
                spec = InputSpec.from_tensor(item, name)
                _adjust_item_spec(num_shards, spec)
                if batch_size is None:
                    specs.append(spec)
                else:
                    specs.append(spec.batch(batch_size))
            elif isinstance(item, numbers.Number):
                specs.append(InputSpec([batch_size], type(item), name))
            else:
                raise TypeError(
                    "The sample returned by the dataset should be a number, np.ndarray or Tensor, but got {}".format(
                        type(item).__name__
                    )
                )

        if inputs is not None:
            for i, item in enumerate(inputs):
                assert item is not None, "Received a None input."
                name = "input" + str(i)
                _infer_item_spec(item, name, batch_size, inputs_spec)
        if labels is not None:
            for i, item in enumerate(labels):
                assert item is not None, "Received a None label."
                name = "label" + str(i)
                _infer_item_spec(item, name, batch_size, labels_spec)

        inputs_spec = self._validate_spec(inputs_spec)
        labels_spec = self._validate_spec(labels_spec)
        return inputs_spec, labels_spec

    def _prepare_data_tensor(self, inputs_spec, labels_spec, inputs, labels):
        if in_dynamic_mode() or self._dygraph_mode:
            raise ValueError("Only static graph mode is supported.")

        if inputs_spec:
            assert isinstance(
                inputs_spec, list
            ), "inputs_spec should be a list, but received {}".format(
                type(inputs_spec)
            )
            assert isinstance(
                inputs, list
            ), f"inputs should be a list, but received {type(inputs)}"
            assert len(inputs_spec) == len(
                inputs
            ), "the number of `inputs_spec` should be equal to `inputs`'s."
            for input_spec, input in zip(inputs_spec, inputs):
                if input_spec.shape != input.shape:
                    input.desc.set_shape(input_spec.shape)
        if labels_spec:
            assert isinstance(
                labels_spec, list
            ), "labels_spec should be a list, but received {}".format(
                type(labels_spec)
            )
            assert isinstance(
                labels, list
            ), f"labels should be a list, but received {type(labels)}"
            assert len(labels_spec) == len(
                labels
            ), "the number of `labels_spec` should be equal to `labels`'s."
            for label_spec, label in zip(labels_spec, labels):
                if label_spec.shape != label.shape:
                    label.desc.set_shape(label_spec.shape)

        return inputs, labels

    def _prepare_reader(self, feed_list=[]):
        dist_context = self._dist_contexts[self._mode]
        dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
        dist_main_block = dist_main_prog.global_block()

        # NOTE: this list may be changed if Paddle changes the existing rules.
        related_reader_ops = [
            "create_py_reader",
            "create_double_buffer_reader",
            "read",
        ]
        # remove the first three reader ops if fit/evaluate/predict is run multiple times
        if dist_main_block.ops[0].type == 'create_py_reader':
            for i in range(len(related_reader_ops)):
                if dist_main_block.ops[0].type in related_reader_ops:
                    dist_main_block._remove_op(0, sync=False)
        dist_main_block._sync_with_cpp()
        # Step 1: find the reader ops
        reader_op_indices = []
        for idx, op in enumerate(dist_main_block.ops):
            if op.type in related_reader_ops:
                reader_op_indices.append(idx)
        # Step 2: insert the new reader ops to cpp
387 388
        # record the read ops' desc to insert to program of forward task_node
        read_ops_desc = []
        new_reader_ops = []
        for idx in reversed(reader_op_indices):
            new_op_desc = dist_main_block.desc._prepend_op()
            new_op_desc.copy_from(dist_main_block.ops[idx].desc)
            read_ops_desc.append(new_op_desc)
            new_op = Operator(
                dist_main_block, new_op_desc, type=new_op_desc.type()
            )
            new_reader_ops.append(new_op)
            dist_op = DistributedOperator(new_op)
            dist_context.add_dist_op_for_program(dist_op)
        # Step 3: insert the new reader ops to python
        for new_op in new_reader_ops:
            dist_main_block.ops.insert(0, new_op)
        for i in range(len(reader_op_indices)):
            reader_op_indices[i] += len(reader_op_indices)
        # Step 4: remove the old reader ops from python and cpp
        for idx in reversed(reader_op_indices):
            op = dist_main_block.ops.pop(idx)
            dist_main_block.desc._remove_op(idx, idx + 1)
        dist_main_block._sync_with_cpp()
        self._has_prepared_reader[self._mode] = True

        # Insert read ops into the forward TaskNode for the fleet executor if the 1F1B pass is set
        if (
            self.main_program._pipeline_opt
            and not auto_utils.use_new_executor()
        ):
            assert "tasks" in self.main_program._pipeline_opt["fleet_opt"]
            fleet_opt = self.main_program._pipeline_opt["fleet_opt"]
            fwd_task = None
            if self._strategy.pipeline.schedule_mode == "1F1B":
                fwd_task = fleet_opt["tasks"][1]
            elif self._strategy.pipeline.schedule_mode == "stream":
                fwd_task = fleet_opt["tasks"][0]
            assert fwd_task is not None
            fwd_prog = fwd_task.get_program()
            fwd_block = fwd_prog.global_block()

            for var in feed_list:
                if var.name not in fwd_block.vars:
                    fwd_block._clone_variable(var)

            for op_desc in read_ops_desc:
                new_op_desc = fwd_block.desc._prepend_op()
                new_op_desc.copy_from(op_desc)
                new_op = Operator(
                    fwd_block, new_op_desc, type=new_op_desc.type()
                )
                fwd_block.ops.insert(0, new_op)

            fwd_block._sync_with_cpp()
            fwd_task.set_program(fwd_prog)

    def _prepare_feed(self, data, user_feeds, mode):
        feeds = {}
        if data is not None:
            if isinstance(data, (list, tuple)):
                if len(data) == 1 and isinstance(data[0], dict):
                    for name, value in data[0].items():
                        feeds[name] = value
                else:
                    raise ValueError(f"Unsupported data {data}")
            elif isinstance(data, dict):
                for name, value in data.items():
                    feeds[name] = value
            else:
                raise ValueError(f"Unsupported data {data}")
        if user_feeds is not None:
            assert isinstance(
                user_feeds, dict
            ), "user_feeds must be a dict, but received {}".format(
                type(user_feeds).__name__
            )
            for name, data in user_feeds.items():
                feeds[name] = data
        return feeds

    def _prepare_fetch(self, user_fetches, mode):
        if user_fetches is not None:
            assert isinstance(
                user_fetches, list
            ), "user_fetches must be a list, but received {}".format(
                type(user_fetches).__name__
            )
        fetch_names = []
        fetch_indices = []
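        # Each appended group (loss, metrics_i, outputs, fetches) holds
        # indices into the de-duplicated fetch_names list; e.g. roughly
        # [[0], [1, 2], []] for one loss and two metric variables.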

        def _process_fetch_group(group_name, var_list):
            group_indices = []
            for var in var_list:
                # Remove duplicate var_names
                if self._is_local_var(var):
                    var_name = _to_name_str(var)
                    if var_name not in fetch_names:
                        fetch_names.append(var_name)
                    group_indices.append(fetch_names.index(var_name))
            fetch_indices.append(group_indices)

        dist_context = self._dist_contexts[mode]
        fetch_vars = dist_context.serial_fetch_vars
        if mode != "predict":
491
            _process_fetch_group("loss", fetch_vars["loss"])
492
        if mode != "predict":
493
            metrics = fetch_vars["metrics"]
494 495 496
            for i, var_list in enumerate(metrics):
                _process_fetch_group("metrics_" + str(i), var_list)
        if mode == "predict":
            _process_fetch_group("outputs", fetch_vars["outputs"])
        for usr_fetch in user_fetches or []:
            var_name = _to_name_str(usr_fetch)
            fetch(var_name)
        user_fetches_collection = [
            item[1] for item in get_collection(CollectionNames.FETCHES)
        ]
        var_list = user_fetches_collection or []
        _process_fetch_group("fetches", var_list)
        return fetch_names, fetch_indices

    def _prepare_logger(
        self,
        outs,
        epoch=None,
        step=None,
        lr=None,
        fetch_names=None,
        fetch_indices=None,
        mode=None,
    ):
        logs = {}
        if epoch is not None:
            logs["epoch"] = epoch
        if step is not None:
            logs["step"] = step + 1
        if lr is not None:
            logs["lr"] = lr
        group_idx = 0
        if mode != "predict":
            # logging loss
            loss_indices = fetch_indices[group_idx]
            assert len(loss_indices) <= 1
            for idx in loss_indices:
                logs["loss"] = outs[idx]
            group_idx += 1
            # logging metrics
            dist_context = self._dist_contexts[mode]
            metric_vars = dist_context.serial_fetch_vars["metrics"]
            if metric_vars:
                for metric in self._metrics:
                    metrics_indices = fetch_indices[group_idx]
                    metric_out = []
                    for idx in metrics_indices:
                        metric_out.append(outs[idx])
                    if metric_out:
                        metric.update(*metric_out)
                        results = metric.accumulate()
                        for i, res in enumerate(auto_utils.to_list(results)):
                            logs[metric.name()[i]] = res
                    group_idx += 1
        # logging outputs
        elif mode == "predict":
            outputs_indices = fetch_indices[group_idx]
            logs_out = {}
            for idx in outputs_indices:
                logs_out["out%d" % (idx)] = outs[idx]
            logs["outputs"] = logs_out
            group_idx += 1
        # logging user fetches
        collect_fetches = get_collection(CollectionNames.FETCHES)
        logs_fetch = {}
        for name, var_name in collect_fetches:
            if var_name in fetch_names:
                idx = fetch_names.index(var_name)
                logs_fetch[name or var_name] = outs[idx]
        logs["fetches"] = logs_fetch
        return logs

    def _prepare_program(self, mode, init_parameters=True):
        # Do the build process
        self._build(mode)
        # Do the planning process
        self._plan(mode)
        # Do the parallel process
        self._parallel(mode)
        # Init comm
        self._init_comm()
        if init_parameters:
            # startup program
            self._initialize(mode)
        self._has_prepared[mode] = True

    def _build(self, mode):
        if in_dynamic_mode() or self._dygraph_mode:
            paddle.disable_static()
            self._dygraph_mode = True
            self._logger.info("Building model with 'to_static' method.")

            self.program_helper = ProgramHelper(
                self._model,
                self._loss,
                self._metrics,
                self._inputs_spec,
                self._labels_spec,
            )
            # build forward main program
            with utils.unique_name.guard():
                self.program_helper.build_program(mode)

            self.concrete_program = self.program_helper.concrete_program
            serial_main_prog = self.program_helper.main_program
            serial_startup_prog = self.program_helper.startup_program

            self._inputs = self.program_helper.input_vars
            self._labels = self.program_helper.label_vars
            outputs = self.program_helper.output_vars
            self._losses = self.program_helper.loss_vars
            metrics = self.program_helper.metric_vars

            paddle.enable_static()
        else:
            # build program in static mode
            dist_context = self._dist_contexts.get(mode, None)
            if dist_context is not None:
                return

            outputs = []
            metrics = []
            self._losses = []
            serial_main_prog = self._orig_main_prog.clone()
            serial_startup_prog = self._orig_startup_prog.clone()
            if not self._skip_build:
                with static.program_guard(
                    serial_main_prog, serial_startup_prog
                ), utils.unique_name.guard():
                    self._inputs = [
                        s._create_feed_layer() for s in self._inputs_spec
                    ]
                    self._labels = [
                        s._create_feed_layer() for s in self._labels_spec
                    ]

                    outputs = auto_utils.to_list(self._model(*self._inputs))

                    if mode != "predict" and self._loss:
                        assert isinstance(
                            self._loss, paddle.nn.Layer
                        ) or callable(
                            self._loss
                        ), "the `loss` argument of Engine should be a subclass of `paddle.nn.Layer` or a callable function."
                        self._losses = auto_utils.to_list(
                            self._loss(*(outputs + self._labels))
                        )

                    if mode != "predict" and (outputs or self._labels):
                        for metric in self._metrics:
                            metrics.append(
                                auto_utils.to_list(
                                    metric.compute(*(outputs + self._labels))
                                )
                            )
            elif mode == "train":
                assert isinstance(
                    self._loss, Variable
                ), "the `loss` argument of Engine should be a Variable."
                self._losses = auto_utils.to_list(self._loss)

        default_ctx = get_default_distributed_context()
        if not default_ctx.has_annotation:
            # We build the world process group because the data parallel
            # needs all ranks by default.
            new_process_group(list(range(self._nranks)))
            default_ctx.data_parallel = True
            self._inputs = [
                auto_utils.set_data_parallel(var) for var in self._inputs
            ]
            self._labels = [
                auto_utils.set_data_parallel(var) for var in self._labels
            ]

        feed_vars = {"inputs": self._inputs, "labels": self._labels}

        fetch_vars = {
            "outputs": paddle.utils.flatten(outputs),
            "loss": self._losses,
            "metrics": metrics,
        }

        if mode != "train":
            serial_main_prog = serial_main_prog.clone(for_test=True)

        auto_utils.set_recompute_segments(
            self._model, self._losses, self._strategy, serial_main_prog
        )
        self._dist_contexts[mode] = DistributedContext(
            serial_main_prog,
            serial_startup_prog,
            self._optimizer,
            self._losses,
            feed_vars,
            fetch_vars,
            self._cluster,
            self._strategy,
            self._json_config,
        )
        self._fwd_dist_contexts[mode] = DistributedContext(
            serial_main_prog,
            serial_startup_prog,
            self._optimizer,
            self._losses,
            feed_vars,
            fetch_vars,
            self._cluster,
            self._strategy,
            self._json_config,
        )
        self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale
        self._fwd_main_progs[mode] = serial_main_prog.clone()

    def _optimization_tuning(self, mode, dataset, batch_size):
        if not self._tuning.enable:
            raise ValueError("Please set `tuning.enable=True`.")

        assert mode == "train"
        # Do the build process
        self._build(mode)
        # Do the planning process
        self._plan(mode)

        dataset.dp_world_size = self._dp_world_sizes
        dataset.dp_rank = self._dp_ranks

        from .tuner.optimization_tuner import OptimizationTuner

        self._optimization_tuner = OptimizationTuner(
            self._dist_contexts[mode],
            dataset,
            self._inputs_spec,
            self._labels_spec,
            batch_size=batch_size,
            rank=self._cur_rank,
        )

        self._optimization_tuner.tune()

        if self._tuning.run_after_tuning:
            # update the strategy
            self._dist_contexts[
                mode
            ]._strategy = self._optimization_tuner.get_best_config()

    def _plan(self, mode):
        if self._planned_mode is None:
            self._planned_mode = mode
        else:
            self._init_dist_context(mode)

        self._planners[mode] = Planner(mode, self._dist_contexts[mode])
        self._planners[mode].plan()

        # infer data parallel info
        inputs_var = self._dist_contexts[mode].serial_feed_vars["inputs"]
        labels_var = self._dist_contexts[mode].serial_feed_vars["labels"]
        block = self._dist_contexts[mode].serial_main_program.global_block()
        # TODO: check this feed_list
        feed_list = []
        for var in inputs_var + labels_var:
            if var.name in block.vars:
                feed_list.append(block.vars[var.name])

        self._dp_world_sizes = []
        self._dp_ranks = []
        for feed_var in feed_list:
            dp_world_size, dp_rank = auto_utils.get_input_split_info(
                self._cur_rank, feed_var, self._dist_contexts[mode]
            )
            self._dp_world_sizes.append(dp_world_size)
            self._dp_ranks.append(dp_rank)
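        # Per feed var: the data-parallel group size and this rank's index in
        # it, consumed e.g. by seeding in _initialize and dataset sharding in
        # _optimization_tuning.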

    def _parallel(self, mode, all_ranks=False):
        # Parallelize program based on the planner's results
        # For now, the completer has to be passed to the Parallelizer,
        # because we may use it to complete the annotation of the backward and update.
        parallelizer = Parallelizer(
            mode,
            self._planners[mode].completer,
            self._dist_contexts[mode],
        )
        if not all_ranks:
            parallelizer.parallel(self._cur_rank, self._parameter_list)
        else:
            parallelizer.parallel_all(self._parameter_list)

    def _init_dist_context(self, mode):
        # Init dist_context['mode'] with the first planned dist_context
        # to guarantee that train/eval/predict mode have same parallel strategy
        dist_context = self._dist_contexts[mode]
        origin_main_prog = dist_context._original_serial_main_program
        ref_mode = self._planned_mode
        ref_dist_context = self._dist_contexts[ref_mode]
        ref_origin_main_prog = ref_dist_context._original_serial_main_program
        ref_blocks = ref_origin_main_prog.blocks
        for ib, block in enumerate(origin_main_prog.blocks):
            for iop, op in enumerate(block.ops):
                ref_op = ref_blocks[ib].ops[iop]
                assert (
                    op.type == ref_op.type
                ), "'{}' mode op '{}' is different with '{}' op '{}'. ".format(
                    mode, op.type, ref_mode, ref_op.type
                )
                ref_op_dist_attr = (
                    ref_dist_context.get_op_dist_attr_for_program(ref_op)
                )
                dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr)

    def _init_comm(self):
        if self._nranks > 1:
            # Traverse each rank's program and its ops, and instantiate
            # the communication groups given by the process mapping.
            all_process_groups = get_all_process_groups()

            if self._strategy.auto_mode == "full_random":
                auto_utils.initialize_pg_in_full_mode(
                    all_process_groups, self._cur_rank
                )
            else:
                for process_group in all_process_groups:
                    process_group.instantiate()

    def _initialize(self, mode):
        self._place = _get_device()
        if isinstance(self._place, paddle.framework.CUDAPlace):
            self._place = paddle.framework.CUDAPlace(
                paddle.distributed.ParallelEnv().dev_id
            )

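        # The seed is offset by this rank's data-parallel rank: runs stay
        # deterministic for a fixed strategy.seed while each DP rank draws
        # its own random stream.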
        if self._strategy.seed:
            paddle.seed(self._strategy.seed + self._dp_ranks[0])
            np.random.seed(self._strategy.seed + self._dp_ranks[0])
            random.seed(self._strategy.seed + self._dp_ranks[0])

        dist_context = self._dist_contexts[mode]
        if self._dygraph_mode:
            dist_main_program = dist_context.dist_main_programs[self._cur_rank]
            self.program_helper.init(
                dist_main_program, self._place, dist_context
            )

        # NOTE(zhaoyinglia): Skip startup program when use new ir temporarily.
        use_new_ir = False
        if auto_utils.use_new_ir():
            use_new_ir = True
            paddle.framework.set_flags(
                {"FLAGS_enable_new_ir_in_executor": False}
            )

        if self._executor is None:
            self._executor = paddle.static.Executor(self._place)
            uninitialized = []
            dist_startup_prog = dist_context.dist_startup_programs[
                self._cur_rank
            ]
            for var in dist_startup_prog.list_vars():
                scope_var = global_scope().find_var(var.name)
                if scope_var and scope_var.get_tensor()._is_initialized():
                    continue
                uninitialized.append(var)
            if uninitialized:
                prune_startup_prog = dist_startup_prog._prune(uninitialized)
                self._executor.run(prune_startup_prog)

            if hasattr(self, "_state_dict") and hasattr(self, "_dist_attr"):
                self._set_state_dict(
                    mode, self._strict, self._state_dict, self._dist_attr
                )

        if self._strategy.reinit:
            self._logger.info("NOTE: parameters will be re-initialized.")
            dist_startup_prog = dist_context.dist_startup_programs[
                self._cur_rank
            ]
            self._executor.run(dist_startup_prog)

        if use_new_ir:
            paddle.framework.set_flags(
                {"FLAGS_enable_new_ir_in_executor": True}
            )

    def fit(
        self,
        train_data,
        train_sample_split=None,
        batch_size=1,
        epochs=1,
        steps_per_epoch=None,
        log_freq=10,
        save_dir=None,
        save_freq=1,
        valid_data=None,
        valid_sample_split=None,
        valid_freq=1,
        valid_steps=None,
        collate_fn=None,
        callbacks=None,
        verbose=2,
        nvprof_range=[-1, -1],
    ):
        """
        Trains the model for a fixed number of epochs. If `valid_data` is set,
        evaluation will be done at the end of each epoch.

        Args:
            train_data (Dataset): An instance of paddle.io.Dataset. Default: None.
            train_sample_split (int, optional): Each sample of the train dataset is assumed
                to be a (input, label) pair by default and has two items. If each sample has
                more than two items, train_sample_split specifies how to split these items into
                input and label. The items before it are input and the left are label. Default: None.
            batch_size (int, optional): The batch size of train_data and valid_data if provided.
                The user's data will be used directly without batching if set to None. Default: 1.
            epochs (int, optional): The number of epochs to train the model. Default: 1.
            steps_per_epoch (int, optional): The total number of steps (batches of samples)
                executed in one epoch before starting the next one. If None, it is equal to
                the number of samples in your dataset divided by the batch size. Default: None.
            valid_data (Dataset, optional): An instance of paddle.io.Dataset used for
                evaluation at the end of epoch. No evaluation will be done if set to None.
                Default: None. (Unsupported for now)
            valid_freq (int, optional): Only relevant if valid_data is provided. This specifies
                how many training epochs before a new evaluation is performed. Default: 1.
            valid_sample_split (int, optional): Only relevant if valid_data is provided.
                Each sample of the valid dataset is assumed to be a (input, label) pair
                by default and has two items. If each sample has more than two items,
                valid_sample_split specifies how to split these items into input and label.
                The items before it are input and the left are label. Default: None.
            valid_steps (int, optional): Only relevant if valid_data is provided.
                It is the total number of steps (batches of samples) to draw before
                stopping validation at the end of every epoch. If None, validation will run until the
                `valid_data` dataset is exhausted. The validation will start from the
                beginning of the dataset at each epoch. Default: None.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list, None for only stack each fields of sample in axis
                0. Default None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during training. Default: None. (Unused for now)
            nvprof_range(list, optional): A list of two integers [start_step, end_step]
                specifying the nvprof profiling range. Note that if start_step >= end_step,
                nvprof will not be applied.

        Returns:
            The `history` object recorded by the training callbacks.

        Examples:

            .. code-block:: python

                >>> import paddle
                >>> import paddle.vision.transforms as T
                >>> from paddle.distributed.fleet import auto
                >>> from paddle.vision.datasets import MNIST

                >>> transform = T.Compose([
                ...     T.Transpose(),
                ...     T.Normalize([127.5], [127.5])
                ... ])
                >>> train_dataset = MNIST(mode='train', transform=transform)

                >>> model = paddle.vision.models.LeNet()
                >>> loss = paddle.nn.CrossEntropyLoss()
                >>> optimizer = paddle.optimizer.Adam(
                ...     learning_rate=0.001, parameters=model.parameters())
                >>> metrics = paddle.metric.Accuracy(topk=(1, 2))

                >>> engine = auto.Engine(model, loss, optimizer, metrics)
                >>> engine.fit(train_dataset,
                ...             epochs=2,
                ...             batch_size=64)
        """
        self._mode = 'train'

        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            train_data, train_sample_split, batch_size
        )

        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        if auto_utils.use_new_executor():
            local_batch_size = self._validate_batch_size(batch_size)
            train_dataloader = self._prepare_dataloader(
                train_data,
                return_list=False,
                batch_size=local_batch_size,
                epochs=epochs,
                collate_fn=collate_fn,
            )
            steps_per_epoch = (
                len(train_dataloader)
                if steps_per_epoch is None
                else steps_per_epoch
            )
        else:
            micro_batch_size = self._validate_batch_size(batch_size)
            train_dataloader = self._prepare_dataloader_from_generator(
                dataset=train_data,
                capacity=70,
                iterable=False,
                batch_size=micro_batch_size,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                collate_fn=collate_fn,
            )
            steps_per_epoch = train_dataloader._steps
            local_batch_size = micro_batch_size
            if self._strategy.pipeline.enable:
                local_batch_size = micro_batch_size * self._acc_steps
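            # Under pipeline parallelism, callbacks and logging see the local
            # batch size (micro-batch size * accumulation steps), not the
            # micro-batch size.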

        fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)

        cbks = config_callbacks(
            callbacks,
            engine=self,
            batch_size=local_batch_size,
            epochs=epochs,
            steps=steps_per_epoch,
            log_freq=log_freq,
            save_freq=save_freq,
            save_dir=save_dir,
            verbose=verbose,
            metrics=self._metrics_name(),
            acc_step=1
            if self._strategy.pipeline.enable
            else self._acc_steps,  # lr update once every local batch
        )

        cbks.on_begin('train')
        for epoch in range(epochs):
            logs = {}
            cbks.on_epoch_begin(epoch)

            for step, batch in enumerate(train_dataloader):
                if auto_utils.use_new_executor():
                    batches = self._validate_batch(batch)
                else:
                    batches = [{}]

                try:
                    for micro_batch in batches:
                        with paddle.profiler.utils._nvprof_range(
                            iter_id=step,
                            start=nvprof_range[0],
                            end=nvprof_range[1],
                        ):
                            cbks.on_batch_begin('train', step, logs)
                            outs = self._executor.run(
                                self.main_program,
1042
                                feed=micro_batch,
1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064
                                fetch_list=fetch_names,
                                use_program_cache=self._strategy.use_cache,
                                return_numpy=self._strategy.return_numpy,
                            )
                            lr = auto_utils.get_lr(self.optimizer)
                            logs = self._prepare_logger(
                                outs,
                                epoch,
                                step,
                                lr,
                                fetch_names,
                                fetch_indices,
                                self._mode,
                            )
                            cbks.on_batch_end('train', step, logs)
                except core.EOFException:
                    break

                if steps_per_epoch and step >= steps_per_epoch:
                    if not auto_utils.use_new_executor():
                        train_dataloader._reset()
                    break

            if valid_data and (epoch + 1) % valid_freq == 0:
                val_logs = self.evaluate(
                    valid_data,
                    valid_sample_split,
                    batch_size,
                    valid_steps,
                    log_freq,
                    collate_fn,
                    callbacks,
                    verbose,
                )
                val_logs = {
                    "val_" + name: val for name, val in val_logs.items()
                }
                logs.update(val_logs)
                self._switch_mode("train")
            else:
                self._reset_metrics()

            cbks.on_epoch_end(epoch, logs)

        cbks.on_end('train', logs)
        return self.history

    def evaluate(
        self,
        valid_data,
        valid_sample_split=None,
        batch_size=1,
        steps=None,
        log_freq=10,
        collate_fn=None,
        callbacks=None,
        verbose=2,
    ):
        """
        Evaluate the loss and metrics of the model on evaluation data.

        Args:
            valid_data (Dataset): An instance of paddle.io.Dataset. Default: None.
            valid_sample_split (int, optional): Each sample of the eval dataset is assumed
                to be a (input, label) pair by default and has two items. If each sample has
                more than two items, valid_sample_split specifies how to split these items into
                input and label. The items before it are input and the left are label. Default: None.
            batch_size (int, optional): The batch size of valid_data. The user's data will
                be used directly without batching if set to None. Default: 1.
            steps (int, optional): It is the total number of steps (batches of samples) to draw before
                stopping evaluation. If None, evaluation will run until the `valid_data` dataset is exhausted.
                The evaluation will start from the beginning of the dataset in each run. Default: None.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list, None for only stack each fields of sample in axis
                0. Default None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during evaluating. Default: None. (Unused for now)

        Returns:
            A dict of logs containing the evaluation loss and metrics.

        Examples:

            .. code-block:: python

                >>> import paddle
                >>> import paddle.vision.transforms as T
                >>> from paddle.distributed.fleet import auto
                >>> from paddle.vision.datasets import MNIST

                >>> transform = T.Compose([
                ...     T.Transpose(),
                ...     T.Normalize([127.5], [127.5])
                ... ])
                >>> valid_dataset = MNIST(mode='test', transform=transform)

                >>> model = paddle.vision.models.LeNet()
                >>> loss = paddle.nn.CrossEntropyLoss()
                >>> metrics = paddle.metric.Accuracy(topk=(1, 2))

                >>> engine = auto.Engine(model, loss, metrics=metrics)
                >>> engine.evaluate(valid_dataset, batch_size=64)

        """
        self._mode = 'eval'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            valid_data, valid_sample_split, batch_size
        )

        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        micro_batch_size = self._validate_batch_size(batch_size)
        valid_dataloader = self._prepare_dataloader_from_generator(
            dataset=valid_data,
            capacity=70,
            iterable=False,
            batch_size=micro_batch_size,
            steps_per_epoch=steps,
            collate_fn=collate_fn,
        )

        fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)

        cbks = config_callbacks(
            callbacks,
            engine=self,
            batch_size=micro_batch_size,
            log_freq=log_freq,
            verbose=verbose,
            metrics=self._metrics_name(),
        )

        eval_steps = valid_dataloader._steps
        cbks.on_begin(
            'eval', {'steps': eval_steps, 'metrics': self._metrics_name()}
        )
        logs = {}
        for step, _ in enumerate(valid_dataloader):
            cbks.on_batch_begin('eval', step, logs)
            try:
                outs = self._executor.run(
                    self.main_program,
                    fetch_list=fetch_names,
                    use_program_cache=self._strategy.use_cache,
                    return_numpy=self._strategy.return_numpy,
                )
            except core.EOFException:
                break
            logs = self._prepare_logger(
                outs, None, step, None, fetch_names, fetch_indices, self._mode
            )
            cbks.on_batch_end('eval', step, logs)
        cbks.on_end('eval', logs)
        self._reset_metrics()
        return logs

    def predict(
        self,
        test_data,
        test_sample_split=None,
        batch_size=1,
        steps=None,
        collate_fn=None,
        callbacks=None,
        verbose=2,
    ):
        """
        Compute the output predictions on testing data.

        Args:
            test_data (Dataset): An instance of paddle.io.Dataset. Default: None.
            test_sample_split (int, optional): Each sample of the test dataset is assumed
                to be a (input, label) pair by default and has two items. If each sample has
                more than two items, test_sample_split specifies how to split these items into
                input and label. The items before it are input and the left are label. Default: None.
            batch_size (int, optional): The batch size of test_data. The user's data will
                be used directly without batching if set to None. Default: 1.
            steps (int, optional): It is the total number of steps (batches of samples) to draw before
                stopping predict. If None, predict will run until the `test_data` dataset is exhausted.
                The predict will start from the beginning of the dataset in each run. Default: None.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list, None for only stack each fields of sample in axis
                0. Default None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during testing. Default: None. (Unused for now)

        Returns:
            A list of the predicted outputs.

        Examples:

            .. code-block:: python

                >>> import paddle
                >>> import paddle.vision.transforms as T
                >>> from paddle.distributed.fleet import auto
                >>> from paddle.vision.datasets import MNIST

                >>> transform = T.Compose([
                ...     T.Transpose(),
                ...     T.Normalize([127.5], [127.5])
                ... ])
                >>> valid_dataset = MNIST(mode='test', transform=transform)

                >>> model = paddle.vision.models.LeNet()

                >>> engine = auto.Engine(model)
                >>> engine.predict(valid_dataset, batch_size=64)
        """
        self._mode = 'predict'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            test_data, test_sample_split, batch_size
        )

        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        micro_batch_size = self._validate_batch_size(batch_size)
        test_dataloader = self._prepare_dataloader_from_generator(
            dataset=test_data,
            capacity=70,
            iterable=False,
            batch_size=micro_batch_size,
            steps_per_epoch=steps,
            collate_fn=collate_fn,
        )

        fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)
1276

Z
zhaoyingli 已提交
1277 1278 1279 1280 1281
        outputs = []
        cbks = config_callbacks(callbacks, engine=self, verbose=verbose)
        test_steps = test_dataloader._steps
        cbks.on_begin('predict', {'steps': test_steps})
        logs = {}
1282
        for step, _ in enumerate(test_dataloader):
Z
zhaoyingli 已提交
1283
            cbks.on_batch_begin('predict', step, logs)
1284
            try:
1285 1286
                outs = self._executor.run(
                    self.main_program,
1287
                    fetch_list=fetch_names,
1288
                    use_program_cache=self._strategy.use_cache,
1289 1290
                    return_numpy=self._strategy.return_numpy,
                )
1291
            except core.EOFException:
1292
                break
1293 1294 1295
            logs = self._prepare_logger(
                outs, None, step, None, fetch_names, fetch_indices, self._mode
            )
Z
zhaoyingli 已提交
1296 1297 1298 1299 1300
            cbks.on_batch_end('predict', step, logs)
            outputs.append(list(logs["outputs"].values()))
        cbks.on_end('predict', logs)
        return outputs

1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316
    def dataloader(
        self,
        dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        collate_fn=None,
        num_workers=0,
        use_buffer_reader=True,
        use_shared_memory=True,
        timeout=0,
        worker_init_fn=None,
        epochs=1,
        steps_per_epoch=None,
        sample_split=1,
        mode=None,
        places=None,
    ):
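        """
        Build a distributed dataloader over the given dataset, preparing the
        parallel program for the current mode if it has not been prepared.

        A minimal usage sketch (assuming an `engine` and a `train_dataset`
        built as in the `fit` example; each batch is then fed to `run`):

            .. code-block:: python

                >>> train_dataloader = engine.dataloader(
                ...     train_dataset, batch_size=64, mode='train')
                >>> for batch in train_dataloader:
                ...     outs = engine.run(batch, mode='train')
        """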
        if mode is not None:
            self.to_mode(mode)

        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            dataset, sample_split, batch_size
        )

        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        batch_size = self._validate_batch_size(batch_size)
        dataloader = self._prepare_dataloader(
            dataset,
            return_list=False,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            collate_fn=collate_fn,
            num_workers=num_workers,
            use_buffer_reader=use_buffer_reader,
            use_shared_memory=use_shared_memory,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            places=places,
        )
        return dataloader

    def dataloader_from_generator(
        self,
        dataset,
        capacity=70,
        use_double_buffer=True,
        iterable=True,
        use_multiprocess=False,
        drop_last=True,
        batch_size=1,
        epochs=1,
        steps_per_epoch=None,
        collate_fn=None,
        sample_split=1,
        mode=None,
    ):
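        """
        Build a generator-based distributed dataloader over the given dataset,
        which feeds the program through a buffered queue of size `capacity`.

        A minimal usage sketch (assuming an `engine` and a `train_dataset`
        built as in the `fit` example; the inserted reader op feeds the data,
        so no feed dict is passed to `run`):

            .. code-block:: python

                >>> train_dataloader = engine.dataloader_from_generator(
                ...     train_dataset, batch_size=64, mode='train')
                >>> for _ in train_dataloader:
                ...     outs = engine.run(mode='train')
        """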
        if mode is not None:
            self.to_mode(mode)
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            dataset, sample_split, batch_size
        )

        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        micro_batch_size = self._validate_batch_size(batch_size)
        dataloader = self._prepare_dataloader_from_generator(
            dataset=dataset,
            capacity=capacity,
            use_double_buffer=use_double_buffer,
            iterable=iterable,
            return_list=False,
            use_multiprocess=use_multiprocess,
            drop_last=drop_last,
            batch_size=micro_batch_size,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            collate_fn=collate_fn,
        )
        return dataloader

    def prepare(
        self,
        inputs_spec=None,
        labels_spec=None,
        inputs=None,
        labels=None,
        main_program=None,
        startup_program=None,
        mode=None,
        init_parameters=True,
    ):
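        """
        Prepare the distributed program, parameters and context of the given
        mode without running it, so that `run(...)` can be called afterwards.

        A minimal usage sketch (assuming the specs below match the model's
        real inputs and labels):

            .. code-block:: python

                >>> import paddle
                >>> input_spec = paddle.static.InputSpec(
                ...     [64, 1, 28, 28], 'float32', 'image')
                >>> label_spec = paddle.static.InputSpec(
                ...     [64, 1], 'int64', 'label')
                >>> engine.prepare(
                ...     inputs_spec=[input_spec],
                ...     labels_spec=[label_spec],
                ...     mode='train')
        """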
        if mode is not None:
            self.to_mode(mode)

        if not self._mode:
            raise ValueError(
                "Please set mode to be prepared with `prepare(mode=...)`"
            )

        if self._has_prepared[self._mode]:
            return

        inputs_spec = self._validate_spec(inputs_spec)
        labels_spec = self._validate_spec(labels_spec)
        inputs = self._validate_vars(inputs)
        labels = self._validate_vars(labels)

        self._orig_main_prog = main_program
        self._orig_startup_prog = startup_program
        if inputs or labels:
            self._skip_build = True
            inputs, labels = self._prepare_data_tensor(
                inputs_spec, labels_spec, inputs, labels
            )
            if self._orig_main_prog is None:
                self._orig_main_prog = static.default_main_program()
            if self._orig_startup_prog is None:
                self._orig_startup_prog = static.default_startup_program()
        elif inputs_spec or labels_spec:
            self._outside_dataloader = True
            if self._orig_main_prog is None:
                self._orig_main_prog = static.default_main_program()
            if self._orig_startup_prog is None:
                self._orig_startup_prog = static.default_startup_program()
        else:
            assert (
                self._inputs_spec and self._labels_spec
            ), "Please call the dataloader(...) before calling prepare(...)"

        self._inputs_spec, self._labels_spec = inputs_spec, labels_spec
        self._inputs, self._labels = inputs, labels
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode, init_parameters)
        else:
            self._switch_mode(self._mode)

    def run(self, data=None, feed=None, fetch_list=None, mode=None):
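        """
        Run the prepared program once and return the logs, which contain the
        fetched loss, metrics and outputs of the current mode.

        A minimal usage sketch (assuming `prepare(...)` has been called with
        matching specs, and that `image_data` and `label_data` are arrays of
        the corresponding shapes):

            .. code-block:: python

                >>> logs = engine.run(
                ...     feed={'image': image_data, 'label': label_data},
                ...     mode='train')
        """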
        if mode is not None:
            self.to_mode(mode)
        feed_dict = self._prepare_feed(data, feed, self._mode)
        fetch_names, fetch_indices = self._prepare_fetch(fetch_list, self._mode)
        if (
            self._outside_dataloader
            and not self._has_prepared_reader[self._mode]
        ):
            self._prepare_reader()
        outs = self._executor.run(
            self.main_program,
            feed=feed_dict,
            fetch_list=fetch_names,
            use_program_cache=self._strategy.use_cache,
            return_numpy=self._strategy.return_numpy,
        )
        logs = self._prepare_logger(
            outs, None, None, None, fetch_names, fetch_indices, self._mode
        )
        return logs

    def _prepare_dataloader(
        self,
        dataset,
        return_list=True,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        collate_fn=None,
        num_workers=0,
        use_buffer_reader=True,
        use_shared_memory=True,
        timeout=0,
        worker_init_fn=None,
        epochs=1,
        steps_per_epoch=None,
        places=None,
    ):
        dist_context = self._dist_contexts[self._mode]
        dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
        dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
        dist_main_block = dist_main_prog.global_block()

        # NOTE: Get feed_list, then insert the dataloader op with the sharded var shape.
        # Because the predict_program does not contain the labels var, we add the
        # labels var from the serial_program to the dist_program, which keeps the
        # length of feed_list equal to the number of values in the dataset.
        inputs_var = dist_context.serial_feed_vars["inputs"]
        labels_var = dist_context.serial_feed_vars["labels"]
        feed_list = []
        for var in inputs_var + labels_var:
            if var.name in dist_main_block.vars:
                feed_list.append(dist_main_block.vars[var.name])
            else:
                copy_var = dist_main_block._clone_variable(var, var.persistable)
                copy_var.desc.set_original_id(var.desc.original_id())
                feed_list.append(copy_var)

        # insert the read op at the end of the program
        with static.program_guard(dist_main_prog, dist_startup_prog):
            dataloader = DistributedDataLoader(
                dataset,
                feed_list=feed_list,
                places=places,
                return_list=return_list,
                batch_size=batch_size,
                shuffle=shuffle,
                drop_last=drop_last,
                collate_fn=collate_fn,
                num_workers=num_workers,
                use_buffer_reader=use_buffer_reader,
                use_shared_memory=use_shared_memory,
                timeout=timeout,
                worker_init_fn=worker_init_fn,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                split_data=self._strategy.split_data,
                data_parallel_world_size=self._dp_world_sizes,
                data_parallel_rank=self._dp_ranks,
            )

        return dataloader

    def _prepare_dataloader_from_generator(
        self,
        dataset,
        capacity=None,
        use_double_buffer=True,
        iterable=True,
        return_list=False,
        use_multiprocess=False,
        drop_last=True,
        batch_size=1,
        epochs=1,
        steps_per_epoch=None,
        collate_fn=None,
    ):
        dist_context = self._dist_contexts[self._mode]
        dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
        dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
        dist_main_block = dist_main_prog.global_block()

        # NOTE: Get feed_list, then insert the dataloader op with the sharded var shape.
        # Because the predict_program does not contain the labels var, we add the
        # labels var from the serial_program to the dist_program, which keeps the
        # length of feed_list equal to the number of values in the dataset.
        inputs_var = dist_context.serial_feed_vars["inputs"]
        labels_var = dist_context.serial_feed_vars["labels"]
        feed_list = []
        for var in inputs_var + labels_var:
            if var.name in dist_main_block.vars:
                feed_list.append(dist_main_block.vars[var.name])
            else:
                copy_var = dist_main_block._clone_variable(var, var.persistable)
                copy_var.desc.set_original_id(var.desc.original_id())
                feed_list.append(copy_var)

        places = paddle.static.cuda_places()
        with static.program_guard(dist_main_prog, dist_startup_prog):
            dataloader = DistributedDataLoaderFromGenerator(
                dataset=dataset,
                feed_list=feed_list,
                capacity=capacity,
                use_double_buffer=use_double_buffer,
                iterable=iterable,
                return_list=return_list,
                use_multiprocess=use_multiprocess,
                drop_last=drop_last,
                places=places,
                batch_size=batch_size,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                collate_fn=collate_fn,
                split_data=self._strategy.split_data,
                data_parallel_world_size=self._dp_world_sizes,
                data_parallel_rank=self._dp_ranks,
                acc_steps=1
                if not self._strategy.pipeline.enable
                else self._acc_steps,
            )
        self._prepare_reader(feed_list)
        return dataloader

    def _tune(self, tune_data, tune_sample_split=None, batch_size=1):
        self._mode = 'train'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            tune_data, tune_sample_split, batch_size
        )
        self._optimization_tuning(self._mode, tune_data, batch_size)

    def _validate_batch_size(self, batch_size):
        if batch_size is None:
            return None

        if auto_utils.use_new_executor():
            assert (
                len(set(self._dp_world_sizes)) == 1
            ), "DistributedBatchSampler only supports one data parallel group, but got [{}] different data parallel groups".format(
                len(set(self._dp_world_sizes))
            )
            assert (
                batch_size % self._dp_world_sizes[0] == 0
            ), "batch_size [{}] is not divisible by dp_world_size [{}]".format(
                str(batch_size), str(self._dp_world_sizes[0])
            )
            return batch_size // self._dp_world_sizes[0]
        else:
            assert (
                batch_size % self._acc_steps == 0
            ), "Requires batch_size [{}] to be divisible by acc_steps [{}].".format(
                batch_size, self._acc_steps
            )
            return batch_size // self._acc_steps

    def _validate_batch(self, batch):
        if batch is None:
            return [None]

        if self._strategy.pipeline.enable or self._acc_steps == 1:
            # pp with schedule or naive-pp
            return batch
        else:
            # split the feed data into micro-batches with gradient_merge k_steps
            feed_names = []
            split_batches = []
            for feed_name, cur_feed in batch[0].items():
                feed_names.append(feed_name)
                split_batches.append(
                    np.split(np.array(cur_feed), self._acc_steps, 0)
                )
            batches = []
            for i in range(self._acc_steps):
                micro_batch = [split_batch[i] for split_batch in split_batches]
                batches.append(dict(zip(feed_names, micro_batch)))
            return batches

    def _validate_spec(self, specs):
        specs = auto_utils.to_list(specs)
        if specs is not None:
            for i, spec in enumerate(specs):
                if not isinstance(spec, InputSpec):
                    raise TypeError(
                        "'spec' must be object of class `paddle.static.InputSpec`."
                    )
                if spec.name is None:
                    raise ValueError(
                        "Requires Input[{}].name != None, but received `None` with {}.".format(
                            i, spec
                        )
                    )
                if self._acc_steps > 1:
                    shape = list(spec.shape)
                    assert (
                        shape[0] % self._acc_steps == 0
                    ), "Requires batch_size [{}] to be divisible by k_steps [{}].".format(
                        spec.shape[0], self._acc_steps
                    )
                    shape[0] //= self._acc_steps
                    spec.shape = shape
        return specs or []

    def _validate_vars(self, vars):
        vars = auto_utils.to_list(vars)
        if vars is not None:
            for i, var in enumerate(vars):
                if not isinstance(var, Variable):
                    raise TypeError("'var' must be a `Variable`.")
        return vars or []

    def _is_local_var(self, var):
        var_name = _to_name_str(var)
        return var_name in self.main_program.global_block().vars

    def _reset_metrics(self):
        for metric in self._metrics:
            metric.reset()

    def _metrics_name(self):
        metrics_name = ['loss'] if self._loss else []
        for m in self._metrics:
            metrics_name.extend(auto_utils.to_list(m.name()))
        return metrics_name

    def _switch_mode(self, mode):
        assert (
            mode in self._dist_contexts
        ), f"{mode} model is not ready, please call `prepare()` first."
        self.to_mode(mode)

    def to_mode(self, mode):
        assert mode in [
            "train",
            "eval",
            "predict",
        ], f"mode {mode} should be one of ['train', 'eval', 'predict']"
        self._mode = mode

    def _set_state_dict(self, mode, strict, state_dict, dist_attr):
        dist_context = self._dist_contexts[mode]
        program = dist_context.dist_main_programs[self._cur_rank]
        cur_dist_attr = auto_utils.get_dist_attr(program, dist_context)
        converter = Converter(state_dict, dist_attr, cur_dist_attr)
        state_dict = converter.convert(strict=strict)
        for name, param in program.state_dict().items():
            param_array = np.array(param)
            if name not in state_dict:
                continue
            if param_array.dtype != state_dict[name].dtype:
                self._logger.info(
                    "cast {}'s dtype from '{}' to '{}'".format(
                        name,
                        str(state_dict[name].dtype),
                        str(param_array.dtype),
                    )
                )
                state_dict[name] = state_dict[name].astype(param_array.dtype)
        program.set_state_dict(state_dict)

    def save(self, path, training=True):
        """
        Saves the model, parameters, optimizer state to path.
1730 1731 1732 1733 1734 1735 1736
        If `training` is set to False, only inference model will be saved.

        Args:
            path (str): The file prefix to save model. The format
                is 'dirname/file_prefix' or 'file_prefix'. if empty str.
                A exception will be raised.
            training (bool, optional): Whether to save for training. If not, save
1737
                for inference only. If `training` is set to True, the optimizer state
1738 1739 1740 1741 1742 1743 1744 1745 1746 1747
                will be saved. Otherwise, only the model and parameters are saved.
                This function will silently overwrite existing file at the target
                location. Default: True.

        Returns:
            None

        Examples:

            .. code-block:: python
1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770

                >>> import paddle
                >>> import paddle.vision.transforms as T
                >>> from paddle.distributed.fleet import auto
                >>> from paddle.vision.datasets import MNIST

                >>> transform = T.Compose([
                ...     T.Transpose(),
                ...     T.Normalize([127.5], [127.5])
                ... ])
                >>> train_dataset = MNIST(mode='train', transform=transform)

                >>> model = paddle.vision.models.LeNet()
                >>> loss = paddle.nn.CrossEntropyLoss()
                >>> optimizer = paddle.optimizer.Adam(
                ...     learning_rate=0.001, parameters=model.parameters())
                >>> metrics = paddle.metric.Accuracy(topk=(1, 2))

                >>> engine = auto.Engine(model, loss, optimizer, metrics)
                >>> engine.fit(train_dataset,
                ...             epochs=1,
                ...             batch_size=64)
                >>> engine.save("./my_model")

        """
        if training:
            assert self._mode in self._dist_contexts
            dist_context = self._dist_contexts[self._mode]
            serial_program = dist_context.serial_main_program
            dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
            self._saver.save(
                path,
                serial_program=serial_program,
                dist_main_program=dist_main_prog,
                dist_context=dist_context,
            )
        else:
            assert "predict" in self._dist_contexts
            dist_context = self._dist_contexts["predict"]
            feed_vars = dist_context.serial_feed_vars['inputs']
            fetch_vars = dist_context.serial_fetch_vars['outputs']
            dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
            if self._strategy.qat.enable and self._strategy.qat.onnx_format:
                from paddle.static.quantization import QuantWeightPass

                self._logger.info("export quantized model.")
                self._logger.info(
                    f"convert config {self._strategy.qat.to_dict()}"
                )
                test_graph = IrGraph(
                    core.Graph(dist_main_prog.desc), for_test=True
                )
                quant_weight_pass = QuantWeightPass(global_scope(), self._place)
                for sub_graph in test_graph.all_sub_graphs():
                    quant_weight_pass.apply(sub_graph)
                dist_main_prog = test_graph.to_program()
            self._saver.save_inference_model(
                path,
                feed_vars,
                fetch_vars,
                self._executor,
                program=dist_main_prog,
            )

    def load(self, path, strict=True, load_optimizer=True):
        """
        Load the stored model, parameters and optimizer states.

        Args:
            path (str): The prefix of files storing the model states and
                optimizer states.
            strict (bool, optional): Whether to skip loading mismatched parameters
                or raise an error when a mismatch happens (a parameter is not found
                in the file storing the model states, or its shape does not match).
                Default: True.
            load_optimizer (bool, optional): If True, the stored optimizer
                states are restored. Otherwise, the optimizer states are initialized
                from scratch. Default: True.

        Returns:
            The loaded state dict and the distributed attributes.

        Examples:

            .. code-block:: python

                >>> import paddle
                >>> import paddle.vision.transforms as T
                >>> from paddle.distributed.fleet import auto
                >>> from paddle.vision.datasets import MNIST

                >>> transform = T.Compose([
                ...     T.Transpose(),
                ...     T.Normalize([127.5], [127.5])
                ... ])
                >>> train_dataset = MNIST(mode='train', transform=transform)

                >>> model = paddle.vision.models.LeNet()
                >>> loss = paddle.nn.CrossEntropyLoss()
                >>> optimizer = paddle.optimizer.Adam(
                ...     learning_rate=0.001, parameters=model.parameters())
                >>> metrics = paddle.metric.Accuracy(topk=(1, 2))

                >>> engine = auto.Engine(model, loss, optimizer, metrics)
                >>> engine.fit(train_dataset,
                ...             epochs=1,
                ...             batch_size=64)
                >>> engine.save("./my_model")
                >>> engine.load("./my_model")
        """
        self._strict = strict
        self._state_dict, self._dist_attr = self._saver.load(
            path, load_optimizer
        )
        return self._state_dict, self._dist_attr

    def cost(self, inputs_spec=None, labels_spec=None, mode=None):
        """
        Get and Print cost, including memory of every rank,
        max memory among all ranks, and the global cost of one step based on
        communication cost(computation cost is 0 by default).
        In the future, the flops information of every rank and global cost including
        computation cost will be added.

        Args:
            inputs_spec(InputSpec): The specification of inputs. Default: None.
            labels_spec(InputSpec): The specification of labels. Default: None.
1876
            mode (str): The engine mode must be in ["train", "predict", "eval"]. Default: None.
1877 1878 1879 1880 1881 1882 1883

        Returns:
            Return the global execution time (ms) and max memory (B).
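
        Examples:

            .. code-block:: python

                >>> # A minimal sketch: the specs below are illustrative and
                >>> # must match the model's real inputs and labels.
                >>> import paddle
                >>> from paddle.distributed.fleet import auto

                >>> model = paddle.vision.models.LeNet()
                >>> loss = paddle.nn.CrossEntropyLoss()
                >>> optimizer = paddle.optimizer.Adam(
                ...     learning_rate=0.001, parameters=model.parameters())

                >>> engine = auto.Engine(model, loss, optimizer)
                >>> inputs_spec = [paddle.static.InputSpec(
                ...     [64, 1, 28, 28], 'float32', 'image')]
                >>> labels_spec = [paddle.static.InputSpec(
                ...     [64, 1], 'int64', 'label')]
                >>> time_cost, max_memory = engine.cost(
                ...     inputs_spec, labels_spec, mode='train')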

        """
        # Check parallel mode
        if self._strategy.auto_mode == "full":
            self._logger.info(
                "The cost will be calculated in the search process when the auto mode is full."
            )
            return

        # Check mode
        mode = mode if mode is not None else self._mode
        assert mode is not None, "Please set mode."
        if mode not in self._has_prepared:
            raise ValueError(
                "The mode {} is not in accepted modes {}".format(
                    mode, list(self._has_prepared.keys())
                )
            )
        self.to_mode(mode)

        if inputs_spec is not None and not self._has_prepared[mode]:
            self._inputs_spec = self._validate_spec(inputs_spec)
            self._labels_spec = self._validate_spec(labels_spec)
            self._build(mode)
            self._plan(mode)
        else:
            if in_dynamic_mode() or self._dygraph_mode:
                raise ValueError(
                    "Please call `prepare()` or `fit()` or `evaluate()` or `predict()` before calling `cost()`."
                )
            else:
                self._logger.info(
                    "The program whose cost is to be estimated must be the static default program. Otherwise, please call `prepare()` before calling `cost()`."
                )
                program = paddle.static.default_main_program()
                if (
                    not program.global_block().ops
                    and not self._has_prepared[mode]
                ):
                    raise ValueError(
                        "Please call `prepare()` or `fit()` or `evaluate()` or `predict()` before calling `cost()`."
                    )

        # Estimate the exec cost and max memory
        global_cost, max_memory = get_cost_from_engine(self, mode)

        return global_cost.time, max_memory

    @property
    def main_program(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.dist_main_programs[self._cur_rank]

    @property
    def startup_program(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.dist_startup_programs[self._cur_rank]

    @property
    def dist_context(self):
        return self._dist_contexts[self._mode]

    @property
    def serial_main_program(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.serial_main_program

    @property
    def serial_startup_program(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.serial_startup_program

    @property
    def feed_vars(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.serial_feed_vars

    @property
    def fetch_vars(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.serial_fetch_vars

    @property
    def optimizer(self):
        dist_context = self._dist_contexts[self._mode]
        if dist_context._serial_optimizer:
            return dist_context._serial_optimizer
        return self._optimizer

    @property
    def inputs(self):
        return self._inputs

    @property
    def labels(self):
        return self._labels