engine.py 68.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import copy
16
import logging
17
import numbers
18 19
import os
import random
20 21
from collections import defaultdict

22 23
import numpy as np

24
import paddle
25
import paddle.distributed.auto_parallel.utils as auto_utils
26
import paddle.utils as utils
Z
zhaoyingli 已提交
27
from paddle import fluid, static
28 29 30 31
from paddle.distributed import fleet
from paddle.fluid import Variable, core
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.executor import _to_name_str, global_scope
32
from paddle.fluid.framework import IrGraph, Operator
33
from paddle.fluid.framework import _current_expected_place as _get_device
34
from paddle.fluid.framework import in_dygraph_mode
35
from paddle.fluid.layers.utils import flatten
36
from paddle.metric import Metric
37 38
from paddle.static import InputSpec

39
from ..utils.log_utils import get_logger
Z
zhaoyingli 已提交
40
from .callbacks import config_callbacks
41
from .cluster import Cluster, get_default_cluster
42 43 44
from .converter import Converter
from .cost.estimate_cost import get_cost_from_engine
from .dist_context import DistributedContext, get_default_distributed_context
45 46
from .dist_loader import (
    DistributedDataLoader,
47
    DistributedDataLoaderFromGenerator,
48
)
49 50 51
from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver
from .helper import ProgramHelper
52
from .interface import CollectionNames, get_collection
53 54 55 56
from .parallelizer_v2 import Parallelizer
from .planner_v2 import Planner
from .process_group import get_all_process_groups, new_process_group
from .strategy import Strategy
57

58 59

class Engine:
60
    """
    An Engine object can provide the full power of auto parallel to users.
    With the help of it, users can easily obtain the abilities of
    distributed training and inference. It also supports the dynamic graph and
    static graph at the same time.

    Args:
        model (paddle.nn.Layer, optional): The model is an instance of
            paddle.nn.Layer.
        loss (Loss|Callable|None, optional): The loss can be a `paddle.nn.Layer`
            instance or any callable function that takes the predicted values and
            ground truth values as input. It can be None when there is no loss.
            Default: None.
        optimizer (Optimizer|None, optional): The optimizer needs to be set in training
            and should be None in eval and predict mode. Default: None.
        metrics (Metric|list[Metric]|None, optional): If metrics is set, all
            metrics will be calculated and output in train/eval mode. Default: None.
        cluster (Cluster|None, optional): The cluster represents the topology information
            about the used physical devices. Default: None. (Unused for now)
        strategy (Strategy|None, optional): The strategy is used to configure the
            parallelization and optimization behaviors. Default: None.

    Examples:

        .. code-block:: python

            import paddle
            import paddle.vision.transforms as T
            from paddle.distributed.fleet import auto
            from paddle.vision.datasets import MNIST

            transform = T.Compose([
                T.Transpose(),
                T.Normalize([127.5], [127.5])
            ])
            train_dataset = MNIST(mode='train', transform=transform)
            valid_dataset = MNIST(mode='test', transform=transform)

            model = paddle.vision.models.LeNet()
            loss = paddle.nn.CrossEntropyLoss()
            optimizer = paddle.optimizer.Adam(
                learning_rate=0.001, parameters=model.parameters())
            metrics = paddle.metric.Accuracy(topk=(1, 2))

            engine = auto.Engine(model, loss, optimizer, metrics)
            # fit
            engine.fit(train_dataset,
                       epochs=2,
                       batch_size=64)
            # evaluate
            engine.evaluate(valid_dataset,
                            batch_size=64)
            # predict
            engine.predict(valid_dataset,
                           batch_size=64)
            # save
            engine.save("./my_model")
            # load
            engine.load("./my_model")

    """
121

122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
    def __init__(
        self,
        model=None,
        loss=None,
        optimizer=None,
        metrics=None,
        cluster=None,
        strategy=None,
    ):
        # Validate the user-facing arguments first; each check only applies
        # when the argument is truthy, so None means "not provided".
        if (
            model
            and not isinstance(model, paddle.nn.Layer)
            and not callable(model)
        ):
            raise TypeError(
                "'model' must be sub classes of `paddle.nn.Layer` or any callable function."
            )
        self._model = model

        if (
            loss
            and not isinstance(loss, (paddle.nn.Layer, Variable))
            and not callable(loss)
        ):
            raise TypeError(
                "'loss' must be sub classes of `paddle.nn.Layer` or any callable function or a Variable."
            )
        self._loss = loss

        if optimizer and not isinstance(
            optimizer,
            (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer),
        ):
            raise TypeError(
                "'optimizer' must be object of class `paddle.optimizer.Optimizer`"
                " or `paddle.fluid.optimizer.Optimizer`."
            )
        self._optimizer = auto_utils.validate_opt(optimizer)
        # Keep an untouched copy of the validated optimizer.
        self._orig_optimizer = copy.deepcopy(self._optimizer)

        # Normalize metrics to a list once and validate every entry.
        metrics = auto_utils.to_list(metrics or [])
        for metric in metrics:
            if metric and not isinstance(metric, Metric):
                raise TypeError(
                    "{} is not sub class of Metric".format(
                        metric.__class__.__name__
                    )
                )
        self._metrics = metrics

        if cluster and not isinstance(cluster, Cluster):
            raise TypeError(
                "'cluster' must be object of class `paddle.distributed.auto_parallel.Cluster`"
            )
        self._cluster = cluster or get_default_cluster()

        if strategy and not isinstance(strategy, Strategy):
            raise TypeError(
                "'strategy' must be object of class `paddle.distributed.auto_parallel.Strategy`"
            )
        self._strategy = strategy or Strategy()

        self._logger = get_logger(logging.INFO)
        if os.getenv("POD_NAME"):
            # POD_NAME is checked as the marker of a paddle.distributed.launch run.
            self._logger.info(
                "Distribute training by paddle.distributed.launch"
            )
            fleet.init(is_collective=True)

        self._executor = None
        self._cur_rank = paddle.distributed.get_rank()
        self._nranks = paddle.distributed.get_world_size()
        self._saver = DistributedSaver()

        # Snapshot of the default programs/context at construction time.
        self._orig_main_prog = static.default_main_program()
        self._orig_startup_prog = static.default_startup_program()
        self._orig_dist_context = get_default_distributed_context()

        # Per-mode ("train"/"eval"/"predict") state.
        self._dist_contexts = {}
        self._fwd_main_progs = {}
        self._fwd_dist_contexts = {}
        self._serial_main_progs = {}
        self._serial_startup_progs = {}
        self._dist_main_progs = defaultdict(dict)  # dist main programs
        self._dist_startup_progs = defaultdict(dict)  # dist startup programs
        self._feed_vars = {}
        self._fetch_vars = {}
        self._planners = {}
        self._has_prepared = {"train": False, "eval": False, "predict": False}
        self._has_prepared_reader = {
            "train": False,
            "eval": False,
            "predict": False,
        }

        self._inputs_spec = []
        self._labels_spec = []
        self._inputs = []
        self._labels = []
        self._losses = []

        self._mode = None
        self._skip_build = False
        self._outside_dataloader = False
        self._planned_mode = None
        self._dygraph_mode = False
        self._tuning = self._strategy.tuning

        self.history = None

231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248
    def _prepare_data_spec(self, data, split, batch_size):
        """Infer ``InputSpec`` lists for the inputs and labels of ``data``.

        Only the first sample of ``data`` is inspected. When ``split`` is None
        the sample is unpacked as ``(inputs, labels)``; otherwise
        ``sample[:split]`` are the inputs and ``sample[split:]`` the labels.

        Args:
            data (paddle.io.Dataset|paddle.io.IterableDataset): source dataset.
            split (int|None): index separating inputs from labels in a sample.
            batch_size (int|None): when given, each array/tensor spec is
                batched with ``InputSpec.batch(batch_size)``; numbers always
                get spec shape ``[batch_size]``.

        Returns:
            tuple: ``(inputs_spec, labels_spec)`` after ``self._validate_spec``.

        Raises:
            TypeError: if ``data`` is not a Dataset/IterableDataset, or a sample
                item is not a number, ``np.ndarray`` or Tensor.
        """
        inputs_spec = []
        labels_spec = []
        # IterableDataset must be tested first: it is checked separately from
        # (and before) the generic Dataset branch.
        if isinstance(data, paddle.io.IterableDataset):
            sample = next(iter(data))
        elif isinstance(data, paddle.io.Dataset):
            sample = data[0]
        else:
            raise TypeError(
                "Data should be a Dataset or IterableDataset, but received {}.".format(
                    type(data).__name__
                )
            )

        if split is None:
            inputs, labels = sample
        else:
            inputs = sample[:split]
            labels = sample[split:]

        inputs = auto_utils.to_list(inputs)
        labels = auto_utils.to_list(labels)

        num_shards = self._strategy.dataset.num_shards

        def _adjust_item_spec(num_shards, spec):
            # Scale the leading dim by the shard count for rank>1 specs.
            if num_shards > 1 and len(spec.shape) > 1:
                spec.shape[0] = spec.shape[0] * num_shards

        def _infer_item_spec(item, name, batch_size, specs):
            if isinstance(item, np.ndarray):
                spec = InputSpec.from_numpy(item, name)
                if batch_size is None:
                    _adjust_item_spec(num_shards, spec)
                    specs.append(spec)
                else:
                    specs.append(spec.batch(batch_size))
            elif isinstance(item, (Variable, core.VarBase, core.eager.Tensor)):
                spec = InputSpec.from_tensor(item, name)
                _adjust_item_spec(num_shards, spec)
                if batch_size is None:
                    specs.append(spec)
                else:
                    specs.append(spec.batch(batch_size))
            elif isinstance(item, numbers.Number):
                specs.append(InputSpec([batch_size], type(item), name))
            else:
                raise TypeError(
                    "The sample's dtype returned of dataset should be number, np.ndarray or Tensor, but got {}".format(
                        type(item).__name__
                    )
                )

        if inputs is not None:
            for i, item in enumerate(inputs):
                assert item is not None, "Receive None input."
                _infer_item_spec(item, "input" + str(i), batch_size, inputs_spec)
        if labels is not None:
            for i, item in enumerate(labels):
                assert item is not None, "Receive None label."
                _infer_item_spec(item, "label" + str(i), batch_size, labels_spec)

        inputs_spec = self._validate_spec(inputs_spec)
        labels_spec = self._validate_spec(labels_spec)
        return inputs_spec, labels_spec

302
    def _prepare_data_tensor(self, inputs_spec, labels_spec, inputs, labels):
        """Make feed variables' static shapes agree with the given specs.

        For each non-empty spec list, the matching variable list must be a list
        of the same length; any variable whose ``shape`` differs from its spec
        gets ``desc.set_shape(spec.shape)``.

        Returns:
            tuple: the (possibly reshaped) ``(inputs, labels)``.

        Raises:
            ValueError: when running in dygraph mode.
        """
        if in_dygraph_mode() or self._dygraph_mode:
            raise ValueError("Only support static graph mode.")

        def _sync_shapes(specs, tensors, kind):
            # kind is "inputs" or "labels"; it is only used in messages.
            if not specs:
                return
            assert isinstance(
                specs, list
            ), "{} should be list, but received {}".format(kind, type(specs))
            assert isinstance(
                tensors, list
            ), "{} should be list, but received {}".format(kind, type(tensors))
            assert len(specs) == len(
                tensors
            ), "the number of `{0}_spec` should be equal to `{0}`'s.".format(
                kind
            )
            for spec, tensor in zip(specs, tensors):
                if spec.shape != tensor.shape:
                    tensor.desc.set_shape(spec.shape)

        _sync_shapes(inputs_spec, inputs, "inputs")
        _sync_shapes(labels_spec, labels, "labels")
        return inputs, labels

    def _prepare_reader(self):
        """Move the reader ops of the current mode to the head of the program.

        Edits, in place, the global block of the dist main program for
        ``self._mode`` on ``self._cur_rank``. First, leading reader ops left at
        index 0 by a previous fit/evaluate/predict run are removed. Then every
        remaining reader op (types in ``related_reader_ops``) is copied to the
        front of the block, a ``DistributedOperator`` is registered for each
        copy, and the original ops are deleted. Finally marks the reader of
        ``self._mode`` as prepared.
        """
        dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
        dist_context = self._dist_contexts[self._mode]
        dist_main_block = dist_main_prog.global_block()

        # NOTE: this list may be changed if Paddle changes the existing rules.
        related_reader_ops = [
            "create_py_reader",
            "create_double_buffer_reader",
            "read",
        ]
        # remove the first three ops if multiple run fit/evaluate/predict;
        # guard against an empty block instead of raising IndexError.
        if (
            dist_main_block.ops
            and dist_main_block.ops[0].type == 'create_py_reader'
        ):
            for _ in range(len(related_reader_ops)):
                if (
                    dist_main_block.ops
                    and dist_main_block.ops[0].type in related_reader_ops
                ):
                    dist_main_block._remove_op(0, sync=False)
        dist_main_block._sync_with_cpp()

        # Step 1: find the reader ops
        reader_op_indices = [
            idx
            for idx, op in enumerate(dist_main_block.ops)
            if op.type in related_reader_ops
        ]

        # Step 2: insert the new reader ops to cpp (prepending in reverse
        # order keeps their original relative order at the front)
        new_reader_ops = []
        for idx in reversed(reader_op_indices):
            new_op_desc = dist_main_block.desc._prepend_op()
            new_op_desc.copy_from(dist_main_block.ops[idx].desc)
            new_op = Operator(
                dist_main_block, new_op_desc, type=new_op_desc.type()
            )
            new_reader_ops.append(new_op)
            dist_op = DistributedOperator(new_op)
            dist_context.add_dist_op_for_program(dist_op)

        # Step 3: insert the new reader ops to python
        for new_op in new_reader_ops:
            dist_main_block.ops.insert(0, new_op)

        # Step 4: remove the old reader ops from python and cpp; every old op
        # was shifted right by the number of ops prepended above.
        num_prepended = len(reader_op_indices)
        for idx in reversed(reader_op_indices):
            shifted = idx + num_prepended
            dist_main_block.ops.pop(shifted)
            dist_main_block.desc._remove_op(shifted, shifted + 1)
        dist_main_block._sync_with_cpp()
        self._has_prepared_reader[self._mode] = True

    def _prepare_feed(self, data, user_feeds, mode):
        feeds = {}
        if data is not None:
            if isinstance(data, (list, tuple)):
                if len(data) == 1 and isinstance(data[0], dict):
                    for name, data in data[0].items():
                        feeds[name] = data
                else:
                    raise ValueError("Unsupported data {}".format(data))
            elif isinstance(data, dict):
                for name, data in data.items():
                    feeds[name] = data
            else:
                raise ValueError("Unsupported data {}".format(data))
398
        if user_feeds is not None:
399 400 401 402 403
            assert isinstance(
                user_feeds, dict
            ), "user_feeds must be a dict, but receive {}".format(
                type(user_feeds).__name__
            )
404 405
            for name, data in user_feeds.items():
                feeds[name] = data
406 407
        return feeds

408
    def _prepare_fetch(self, user_fetches, mode):
        """Collect the variable names to fetch and their per-group indices.

        Groups are processed in this order. For "predict": the "outputs"
        group. For any other mode: the "loss" group, then one group per
        metric. Last comes a "fetches" group made of the
        ``CollectionNames.FETCHES`` collection followed by ``user_fetches``.
        Only vars accepted by ``self._is_local_var`` are fetched, and a
        duplicated name reuses its existing slot.

        Returns:
            tuple: ``(fetch_names, fetch_indices)``, where ``fetch_indices[g]``
            lists the positions in ``fetch_names`` used by group ``g``.
        """
        if user_fetches is not None:
            assert isinstance(
                user_fetches, list
            ), "user_fetches must be a list, but receive {}".format(
                type(user_fetches).__name__
            )

        fetch_names = []
        fetch_indices = []

        def _process_fetch_group(group_name, var_list):
            positions = []
            for var in var_list:
                if not self._is_local_var(var):
                    continue
                var_name = _to_name_str(var)
                # Remove duplicate var_names
                if var_name not in fetch_names:
                    fetch_names.append(var_name)
                positions.append(fetch_names.index(var_name))
            if not positions:
                # NOTE(review): an empty list is appended to fetch_names as a
                # placeholder for groups with no local vars; kept unchanged.
                fetch_names.append([])
            fetch_indices.append(positions)

        mode_fetch_vars = self._fetch_vars[mode]
        if mode == "predict":
            _process_fetch_group("outputs", mode_fetch_vars["outputs"])
        else:
            _process_fetch_group("loss", mode_fetch_vars["loss"])
            for i, var_list in enumerate(mode_fetch_vars["metrics"]):
                _process_fetch_group("metrics_" + str(i), var_list)

        collected = [
            item[1] for item in get_collection(CollectionNames.FETCHES)
        ]
        _process_fetch_group("fetches", collected + (user_fetches or []))
        return fetch_names, fetch_indices

446 447 448 449 450 451 452 453 454 455
    def _prepare_logger(
        self,
        outs,
        epoch=None,
        step=None,
        lr=None,
        fetch_names=None,
        fetch_indices=None,
        mode=None,
    ):
        """Build the logging dict for one step from the executor outputs.

        Args:
            outs (list): fetched values, indexed by ``fetch_indices``.
            epoch, step, lr: optional values; ``step`` is logged as ``step + 1``.
            fetch_names (list|None): names returned by ``_prepare_fetch``.
            fetch_indices (list): per-group indices returned by
                ``_prepare_fetch``, in the same group order.
            mode (str): "train", "eval" or "predict".

        Returns:
            dict: contains "epoch"/"step"/"lr" when given. For non-predict
            modes it has "loss" (``outs[idx][0]``) and one entry per metric
            result. For predict mode it has "outputs" (keys ``"out<idx>"``).
            It always has "fetches", keyed by collection name (or var name).

        Note:
            For non-predict modes with metric vars, this calls
            ``metric.update`` and ``metric.accumulate``, which mutates the
            metric's state.
        """
        logs = {}
        if epoch is not None:
            logs["epoch"] = epoch
        if step is not None:
            logs["step"] = step + 1
        if lr is not None:
            logs["lr"] = lr
        group_idx = 0
        if mode != "predict":
            # logging loss
            loss_indices = fetch_indices[group_idx]
            assert len(loss_indices) <= 1
            for idx in loss_indices:
                logs["loss"] = outs[idx][0]
            group_idx += 1
            # logging metrics
            metric_vars = self._fetch_vars[mode]["metrics"]
            if metric_vars:
                for metric in self._metrics:
                    metrics_indices = fetch_indices[group_idx]
                    metric_out = [outs[idx] for idx in metrics_indices]
                    if metric_out:
                        metric.update(*metric_out)
                        results = metric.accumulate()
                        # name() is a list for some metrics (e.g. Accuracy)
                        # but a single str for others (e.g. Precision);
                        # indexing a str would log one-character keys.
                        names = metric.name()
                        if isinstance(names, str):
                            names = [names]
                        for i, res in enumerate(auto_utils.to_list(results)):
                            logs[names[i]] = res
                    group_idx += 1
        else:
            # logging outputs
            outputs_indices = fetch_indices[group_idx]
            logs_out = {}
            for idx in outputs_indices:
                logs_out["out%d" % (idx)] = outs[idx]
            logs["outputs"] = logs_out
            group_idx += 1
        # logging user fetches; tolerate the default fetch_names=None
        known_names = fetch_names or []
        collect_fetches = get_collection(CollectionNames.FETCHES)
        logs_fetch = {}
        for name, var_name in collect_fetches:
            if var_name in known_names:
                idx = known_names.index(var_name)
                logs_fetch[name or var_name] = outs[idx]
        logs["fetches"] = logs_fetch
        return logs
502

503 504 505 506 507 508 509 510 511 512 513
    def _prepare_program(self, mode):
        # Do the build process
        self._build(mode)
        # Do the planning process
        self._plan(mode)
        # Do the parallel process
        self._parallel(mode)
        # Init comm and startup program
        self._initialize(mode)
        self._has_prepared[mode] = True

514
    def _build(self, mode):
        """Build the serial program of ``mode`` and its distributed contexts.

        In dygraph mode (or once the engine has switched to it) the model is
        traced through ``ProgramHelper`` between ``paddle.disable_static()`` and
        ``paddle.enable_static()``. Otherwise the original default programs
        are cloned and the forward/loss/metric graph is built from the input
        and label specs, unless ``self._skip_build`` is set. In that case, for
        "train", ``self._loss`` must already be a ``Variable``. A static mode
        whose serial program is already recorded is skipped entirely.

        Results are stored in ``self._dist_contexts[mode]``,
        ``self._fwd_dist_contexts[mode]`` and ``self._fwd_main_progs[mode]``.
        """
        if in_dygraph_mode() or self._dygraph_mode:
            paddle.disable_static()
            self._dygraph_mode = True
            self._logger.info("Building model with 'to_static' method.")

            helper = ProgramHelper(
                self._model,
                self._loss,
                self._metrics,
                self._inputs_spec,
                self._labels_spec,
            )
            self.program_helper = helper
            # build forward main program
            helper.build_program(mode)

            self.concrete_program = helper.concrete_program
            serial_main_prog = helper.main_program
            serial_startup_prog = helper.startup_program

            self._inputs = helper.input_vars
            self._labels = helper.label_vars
            outputs = helper.output_vars
            self._losses = helper.loss_vars
            metrics = helper.metric_vars

            paddle.enable_static()
        else:
            # build program in static graph mode; reuse an already-built mode
            if self._serial_main_progs.get(mode, None) is not None:
                return

            outputs = []
            metrics = []
            self._losses = []
            serial_main_prog = self._orig_main_prog.clone()
            serial_startup_prog = self._orig_startup_prog.clone()
            if self._skip_build:
                if mode == "train":
                    assert isinstance(
                        self._loss, Variable
                    ), "the type of `loss` of the Engine arguments should be Variable."
                    self._losses = auto_utils.to_list(self._loss)
            else:
                with static.program_guard(
                    serial_main_prog, serial_startup_prog
                ), utils.unique_name.guard():
                    self._inputs = [
                        spec._create_feed_layer() for spec in self._inputs_spec
                    ]
                    self._labels = [
                        spec._create_feed_layer() for spec in self._labels_spec
                    ]

                    outputs = auto_utils.to_list(self._model(*self._inputs))

                    if mode != "predict":
                        if self._loss:
                            assert isinstance(
                                self._loss, paddle.nn.Layer
                            ) or callable(
                                self._loss
                            ), "the type of `loss` of the Engine arguments should be sub classes of `paddle.nn.Layer` or any callable function."
                            self._losses = auto_utils.to_list(
                                self._loss(*(outputs + self._labels))
                            )
                        if outputs or self._labels:
                            metrics = [
                                auto_utils.to_list(
                                    metric.compute(*(outputs + self._labels))
                                )
                                for metric in self._metrics
                            ]

        default_ctx = get_default_distributed_context()
        if not default_ctx.has_annotation:
            # We build the world process group because the data parallel
            # needs all ranks by default.
            new_process_group(list(range(self._nranks)))
            default_ctx.data_parallel = True
            self._inputs = [
                auto_utils.set_data_parallel(var) for var in self._inputs
            ]
            self._labels = [
                auto_utils.set_data_parallel(var) for var in self._labels
            ]

        feed_vars = {"inputs": self._inputs, "labels": self._labels}
        fetch_vars = {
            "outputs": flatten(outputs),
            "loss": self._losses,
            "metrics": metrics,
        }

        if mode != "train":
            serial_main_prog = serial_main_prog.clone(for_test=True)

        auto_utils.set_recompute_segments(
            self._model, self._losses, self._strategy, serial_main_prog
        )
        # Two contexts are built from identical arguments: the main one
        # (later planned/parallelized) and a forward-only reference.
        for contexts in (self._dist_contexts, self._fwd_dist_contexts):
            contexts[mode] = DistributedContext(
                serial_main_prog,
                serial_startup_prog,
                self._optimizer,
                self._losses,
                feed_vars,
                fetch_vars,
                self._cluster,
                self._strategy,
            )
        self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale
        self._fwd_main_progs[mode] = serial_main_prog.clone()
637

638 639 640
    def _optimization_tuning(self, mode, dataset, batch_size):
        if not self._tuning.enable:
            raise ValueError("Please set `tuning.enable=True`.")
641

642 643 644 645 646 647 648 649
        assert mode == "train"
        # Do the build process
        self._build(mode)
        # Do the planning process
        self._plan(mode)

        dataset.dp_world_size = self._dp_world_sizes
        dataset.dp_rank = self._dp_ranks
650 651

        from .tuner.optimization_tuner import OptimizationTuner
652 653 654 655 656 657 658 659 660

        self._optimization_tuner = OptimizationTuner(
            self._dist_contexts[mode],
            dataset,
            self._inputs_spec,
            self._labels_spec,
            batch_size=batch_size,
            rank=self._cur_rank,
        )
661 662 663

        self._optimization_tuner.tune()

664
        if self._tuning.run_after_tuning:
665 666
            # update the strategy
            self._dist_contexts[
667 668
                mode
            ]._strategy = self._optimization_tuner.get_best_config()
669

670 671 672 673 674 675
    def _plan(self, mode):
        """Plan ``mode`` and record data-parallel info for its feed vars.

        The first planned mode becomes the reference. Later modes first copy
        its op dist attrs through ``_init_dist_context``. Afterwards
        ``self._dp_world_sizes`` / ``self._dp_ranks`` hold one entry per feed
        var found in the serial main program's global block.
        """
        if self._planned_mode is not None:
            self._init_dist_context(mode)
        else:
            self._planned_mode = mode

        dist_context = self._dist_contexts[mode]
        planner = Planner(mode, dist_context)
        self._planners[mode] = planner
        planner.plan()

        # infer data parallel info
        feed_vars = dist_context.serial_feed_vars
        inputs_var = feed_vars["inputs"]
        labels_var = feed_vars["labels"]
        block = dist_context.serial_main_program.global_block()
        # TODO: check this feed_list
        feed_list = [
            block.vars[var.name]
            for var in inputs_var + labels_var
            if var.name in block.vars
        ]

        self._dp_world_sizes = []
        self._dp_ranks = []
        for feed_var in feed_list:
            world_size, rank = auto_utils.get_input_split_info(
                self._cur_rank, feed_var, dist_context
            )
            self._dp_world_sizes.append(world_size)
            self._dp_ranks.append(rank)
697

698
    def _parallel(self, mode, all_ranks=False):
        """Parallelize the program of ``mode`` from the planner's results.

        Args:
            mode (str): the mode whose planner and dist context are used.
            all_ranks (bool): when True, call ``parallel_all()``; otherwise
                parallelize only for ``self._cur_rank``.
        """
        # The completer is passed on from the planner because it may be
        # needed to complete the annotation of the backward and update parts.
        completer = self._planners[mode].completer
        parallelizer = Parallelizer(mode, completer, self._dist_contexts[mode])
        if all_ranks:
            parallelizer.parallel_all()
        else:
            parallelizer.parallel(self._cur_rank)
709 710

    def _init_dist_context(self, mode):
711
        # Init dist_context['mode'] with the first planned dist_context
712 713 714 715 716 717 718 719 720 721
        # to guarantee that train/eval/predict mode have same parallel strategy
        dist_context = self._dist_contexts[mode]
        origin_main_prog = dist_context._original_serial_main_program
        ref_mode = self._planned_mode
        ref_dist_context = self._dist_contexts[ref_mode]
        ref_origin_main_prog = ref_dist_context._original_serial_main_program
        ref_blocks = ref_origin_main_prog.blocks
        for ib, block in enumerate(origin_main_prog.blocks):
            for iop, op in enumerate(block.ops):
                ref_op = ref_blocks[ib].ops[iop]
722 723 724 725 726 727 728 729
                assert (
                    op.type == ref_op.type
                ), "'{}' mode op '{}' is different with '{}' op '{}'. ".format(
                    mode, op.type, ref_mode, ref_op.type
                )
                ref_op_dist_attr = (
                    ref_dist_context.get_op_dist_attr_for_program(ref_op)
                )
730 731 732
                dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr)

    def _initialize(self, mode):
        """
        Load the planned programs of ``mode`` and get ready to execute them.

        Copies the serial/distributed programs, feed/fetch vars and optimizer
        out of the mode's dist_context. It then instantiates the process groups
        this rank belongs to, resolves the execution place and applies the
        strategy seed. Finally it lazily creates the executor, running only the
        part of the startup program whose variables are not yet initialized in
        the global scope.

        Args:
            mode (str): The mode whose dist_context should be loaded.
        """
        ctx = self._dist_contexts[mode]

        # Get the current content from the distributed context
        self._serial_main_progs[mode] = ctx.serial_main_program
        self._serial_startup_progs[mode] = ctx.serial_startup_program
        self._dist_main_progs[mode] = ctx.dist_main_programs
        self._dist_startup_progs[mode] = ctx.dist_startup_programs
        self._feed_vars[mode] = ctx.serial_feed_vars
        self._fetch_vars[mode] = ctx.serial_fetch_vars
        self._optimizer = ctx._serial_optimizer

        if self._nranks > 1:
            # Instantiate the communication groups used by the rank programs.
            process_groups = get_all_process_groups()
            rank = self._cur_rank
            # NOTE: After the implementation of the unified dynamic and static
            # communication group initialization mode in the future, the
            # initialization logic of full mode will be removed because port
            # occupation error may occur.
            if self._strategy.auto_mode == "full":
                auto_utils.initialize_pg_in_full_mode(process_groups, rank)
            else:
                for group in process_groups:
                    if rank in group.ranks:
                        group.instantiate()

        place = _get_device()
        if isinstance(place, fluid.CUDAPlace):
            # Re-create the CUDA place with the device id from ParallelEnv.
            place = fluid.CUDAPlace(ParallelEnv().dev_id)
        self._place = place

        if self._strategy.seed:
            # Offset by self._dp_ranks[0] (presumably this process's
            # data-parallel rank) before seeding every RNG in use.
            rank_seed = self._strategy.seed + self._dp_ranks[0]
            paddle.seed(rank_seed)
            np.random.seed(rank_seed)
            random.seed(rank_seed)

        if self._dygraph_mode:
            self.program_helper.init(
                self._dist_main_progs[mode][self._cur_rank], self._place, ctx
            )

        if self._executor is None:
            self._executor = paddle.static.Executor(self._place)
            startup_prog = self._dist_startup_progs[mode][self._cur_rank]

            def _needs_init(var):
                found = global_scope().find_var(var.name)
                return not (found and found.get_tensor()._is_initialized())

            pending = [v for v in startup_prog.list_vars() if _needs_init(v)]
            if pending:
                # Run only the startup ops for still-missing variables so that
                # values already present in the global scope are kept.
                self._executor.run(startup_prog._prune(pending))

            if hasattr(self, "_state_dict") and hasattr(self, "_dist_attr"):
                # A state dict was attached earlier (presumably by a load
                # call); apply it now that the variables exist.
                self._set_state_dict(
                    mode, self._strict, self._state_dict, self._dist_attr
                )

        if self._strategy.reinit:
            self._logger.info("NOTE: parameters will be re-initialized.")
            self._executor.run(self._dist_startup_progs[mode][self._cur_rank])

807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824
    def fit(
        self,
        train_data,
        train_sample_split=None,
        batch_size=1,
        epochs=1,
        steps_per_epoch=None,
        log_freq=10,
        save_dir=None,
        save_freq=1,
        valid_data=None,
        valid_sample_split=None,
        valid_freq=1,
        valid_steps=None,
        collate_fn=None,
        callbacks=None,
        verbose=2,
    ):
        """
        Trains the model for a fixed number of epochs. If `valid_data` is set,
        evaluation is done at the end of every `valid_freq`-th epoch.

        Args:
            train_data (Dataset): An instance of paddle.io.Dataset.
            train_sample_split (int, optional): Each sample of the train dataset is assumed
                to be a (input, label) pair by default and has two items. If each sample has
                more than two items, train_sample_split specifies how to split these items into
                input and label. The items before it are input and the left are label. Default: None.
            batch_size (int, optional): The batch size of train_data and valid_data if provided.
                The user's data will be used directly without batching if set to None. Default: 1.
            epochs (int, optional): The number of epochs to train the model. Default: 1.
            steps_per_epoch (int, optional): The total number of steps (batches of samples)
                executed in one epoch before starting the next one. If None, it is equal to
                the number of samples in your dataset divided by the batch size. Default: None.
            log_freq (int, optional): Logging frequency forwarded to the callbacks
                (presumably measured in steps). Default: 10.
            save_dir (str|None, optional): Checkpoint directory forwarded to the
                callbacks. Default: None.
            save_freq (int, optional): Saving frequency forwarded to the callbacks
                (presumably measured in epochs). Default: 1.
            valid_data (Dataset, optional): An instance of paddle.io.Dataset used for
                evaluation at the end of an epoch. No evaluation will be done if set to None.
                Default: None.
            valid_sample_split (int, optional): Only relevant if valid_data is provided.
                Each sample of the valid dataset is assumed to be a (input, label) pair
                by default and has two items. If each sample has more than two items,
                valid_sample_split specifies how to split these items into input and label.
                The items before it are input and the left are label. Default: None.
            valid_freq (int, optional): Only relevant if valid_data is provided. This specifies
                how many training epochs before a new evaluation is performed. Default: 1.
            valid_steps (int, optional): Only relevant if valid_data is provided.
                It is the total number of steps (batches of samples) to draw before
                stopping validation at the end of every epoch. If None, validation will run until the
                `valid_data` dataset is exhausted. The validation will start from the
                beginning of the dataset at each epoch. Default: None.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list, None for only stack each fields of sample in axis
                0. Default None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during training and during the evaluations triggered by `valid_data`.
                Default: None.
            verbose (int, optional): Verbosity level forwarded to the callbacks. Default: 2.

        Returns:
            The engine's ``history`` attribute once training has finished.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                train_dataset = MNIST(mode='train', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                optimizer = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=model.parameters())
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, optimizer, metrics)
                engine.fit(train_dataset,
                           epochs=2,
                           batch_size=64)
        """
        self._mode = 'train'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            train_data, train_sample_split, batch_size
        )
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        train_dataloader = self._prepare_dataloader_from_generator(
            dataset=train_data,
            capacity=70,
            iterable=False,
            batch_size=batch_size,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            collate_fn=collate_fn,
        )

        fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)

        cbks = config_callbacks(
            callbacks,
            engine=self,
            batch_size=batch_size,
            epochs=epochs,
            steps=train_dataloader._steps,
            log_freq=log_freq,
            save_freq=save_freq,
            save_dir=save_dir,
            verbose=verbose,
            metrics=self._metrics_name(),
            acc_step=self._k_steps,
        )

        # Bound before the loop so that `on_end` still receives a dict when
        # `epochs` is 0 and the loop body never runs.
        logs = {}
        cbks.on_begin('train')
        for epoch in range(epochs):
            logs = {}
            cbks.on_epoch_begin(epoch)
            # The loader is non-iterable (iterable=False): the loop item is
            # unused and exhaustion surfaces as core.EOFException from run().
            for step, _ in enumerate(train_dataloader):
                cbks.on_batch_begin('train', step, logs)
                try:
                    outs = self._executor.run(
                        self.main_program,
                        fetch_list=fetch_names,
                        use_program_cache=self._strategy.use_cache,
                        return_numpy=self._strategy.return_numpy,
                    )
                except core.EOFException:
                    break
                lr = auto_utils.get_lr(self._optimizer)
                logs = self._prepare_logger(
                    outs,
                    epoch,
                    step,
                    lr,
                    fetch_names,
                    fetch_indices,
                    self._mode,
                )
                cbks.on_batch_end('train', step, logs)

            # Compare with None explicitly: truth-testing a Dataset calls its
            # __len__, which may raise (e.g. for iterable datasets).
            if valid_data is not None and (epoch + 1) % valid_freq == 0:
                val_logs = self.evaluate(
                    valid_data,
                    valid_sample_split,
                    batch_size,
                    valid_steps,
                    log_freq,
                    collate_fn,
                    callbacks,
                    verbose,
                )
                val_logs = {
                    "val_" + name: val for name, val in val_logs.items()
                }
                logs.update(val_logs)
                # evaluate() switched the engine to 'eval'; switch back.
                self._switch_mode("train")
            else:
                self._reset_metrics()

            cbks.on_epoch_end(epoch, logs)

        cbks.on_end('train', logs)
        return self.history
976

977 978 979 980 981 982 983 984 985 986 987
    def evaluate(
        self,
        valid_data,
        valid_sample_split=None,
        batch_size=1,
        steps=None,
        log_freq=10,
        collate_fn=None,
        callbacks=None,
        verbose=2,
    ):
        """
        Evaluate the loss and metrics of the model on evaluation data.

        Args:
            valid_data (Dataset): An instance of paddle.io.Dataset.
            valid_sample_split (int, optional): Each sample of the eval dataset is assumed
                to be a (input, label) pair by default and has two items. If each sample has
                more than two items, valid_sample_split specifies how to split these items into
                input and label. The items before it are input and the left are label. Default: None.
            batch_size (int, optional): The batch size of valid_data. The user's data will
                be used directly without batching if set to None. Default: 1.
            steps (int, optional): It is the total number of steps (batches of samples) to draw before
                stopping evaluation. If None, evaluation will run until the `valid_data` dataset is exhausted.
                The evaluation will start from the beginning of the dataset in each run. Default: None.
            log_freq (int, optional): Logging frequency forwarded to the callbacks
                (presumably measured in steps). Default: 10.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list, None for only stack each fields of sample in axis
                0. Default None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during evaluating. Default: None.
            verbose (int, optional): Verbosity level forwarded to the callbacks. Default: 2.

        Returns:
            dict: The logs built by ``_prepare_logger`` for the last executed
            step, or an empty dict if no step ran.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                valid_dataset = MNIST(mode='test', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, metrics=metrics)
                engine.evaluate(valid_dataset, batch_size=64)

        """
        self._mode = 'eval'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            valid_data, valid_sample_split, batch_size
        )
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        valid_dataloader = self._prepare_dataloader_from_generator(
            dataset=valid_data,
            capacity=70,
            iterable=False,
            batch_size=batch_size,
            steps_per_epoch=steps,
            collate_fn=collate_fn,
        )

        fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)

        cbks = config_callbacks(
            callbacks,
            engine=self,
            batch_size=batch_size,
            log_freq=log_freq,
            verbose=verbose,
            metrics=self._metrics_name(),
        )

        eval_steps = valid_dataloader._steps
        cbks.on_begin(
            'eval', {'steps': eval_steps, 'metrics': self._metrics_name()}
        )
        logs = {}
        # The loader is non-iterable (iterable=False): the loop item is unused
        # and exhaustion surfaces as core.EOFException from run().
        for step, _ in enumerate(valid_dataloader):
            cbks.on_batch_begin('eval', step, logs)
            try:
                outs = self._executor.run(
                    self.main_program,
                    fetch_list=fetch_names,
                    use_program_cache=self._strategy.use_cache,
                    return_numpy=self._strategy.return_numpy,
                )
            except core.EOFException:
                break
            logs = self._prepare_logger(
                outs, None, step, None, fetch_names, fetch_indices, self._mode
            )
            cbks.on_batch_end('eval', step, logs)
        cbks.on_end('eval', logs)
        self._reset_metrics()
        return logs
1086

1087 1088 1089 1090 1091 1092 1093 1094 1095 1096
    def predict(
        self,
        test_data,
        test_sample_split=None,
        batch_size=1,
        steps=None,
        collate_fn=None,
        callbacks=None,
        verbose=2,
    ):
        """
        Compute the output predictions on testing data.

        Args:
            test_data (Dataset): An instance of paddle.io.Dataset.
            test_sample_split (int, optional): Each sample of the test dataset is assumed
                to be a (input, label) pair by default and has two items. If each sample has
                more than two items, test_sample_split specifies how to split these items into
                input and label. The items before it are input and the left are label. Default: None.
            batch_size (int, optional): The batch size of test_data. The user's data will
                be used directly without batching if set to None. Default: 1.
            steps (int, optional): It is the total number of steps (batches of samples) to draw before
                stopping predict. If None, predict will run until the `test_data` dataset is exhausted.
                The predict will start from the beginning of the dataset in each run. Default: None.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list, None for only stack each fields of sample in axis
                0. Default None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during testing. Default: None.
            verbose (int, optional): Verbosity level forwarded to the callbacks. Default: 2.

        Returns:
            list: One item per executed step, each being
            ``list(logs["outputs"].values())`` for that step's logs.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                valid_dataset = MNIST(mode='test', transform=transform)

                model = paddle.vision.models.LeNet()

                engine = auto.Engine(model)
                engine.predict(valid_dataset, batch_size=64)
        """
        self._mode = 'predict'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            test_data, test_sample_split, batch_size
        )
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        test_dataloader = self._prepare_dataloader_from_generator(
            dataset=test_data,
            capacity=70,
            iterable=False,
            batch_size=batch_size,
            steps_per_epoch=steps,
            collate_fn=collate_fn,
        )

        fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)

        outputs = []
        cbks = config_callbacks(callbacks, engine=self, verbose=verbose)
        test_steps = test_dataloader._steps
        cbks.on_begin('predict', {'steps': test_steps})
        logs = {}
        # The loader is non-iterable (iterable=False): the loop item is unused
        # and exhaustion surfaces as core.EOFException from run().
        for step, _ in enumerate(test_dataloader):
            cbks.on_batch_begin('predict', step, logs)
            try:
                outs = self._executor.run(
                    self.main_program,
                    fetch_list=fetch_names,
                    use_program_cache=self._strategy.use_cache,
                    return_numpy=self._strategy.return_numpy,
                )
            except core.EOFException:
                break
            logs = self._prepare_logger(
                outs, None, step, None, fetch_names, fetch_indices, self._mode
            )
            cbks.on_batch_end('predict', step, logs)
            outputs.append(list(logs["outputs"].values()))
        cbks.on_end('predict', logs)
        return outputs

1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200
    def dataloader(
        self,
        dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        collate_fn=None,
        num_workers=0,
        use_buffer_reader=True,
        use_shared_memory=True,
        timeout=0,
        worker_init_fn=None,
        epochs=1,
        steps_per_epoch=None,
        sample_split=1,
        mode=None,
    ):
        """
        Build a distributed dataloader over ``dataset`` for the current mode.

        The input/label specs are derived from ``dataset``. The program of the
        mode is prepared on first use, or switched to otherwise. The loader is
        then created through ``_prepare_dataloader`` with ``return_list=False``.

        Args:
            dataset: The dataset to read from.
            batch_size, shuffle, drop_last, collate_fn, num_workers,
            use_buffer_reader, use_shared_memory, timeout, worker_init_fn,
            epochs, steps_per_epoch: Forwarded unchanged to
                ``_prepare_dataloader``.
            sample_split (int, optional): Forwarded to ``_prepare_data_spec``
                to split each sample into inputs and labels. Default: 1.
            mode (str|None, optional): If given, the engine switches to this
                mode via ``to_mode`` first. Default: None.

        Returns:
            The loader returned by ``_prepare_dataloader``.
        """
        if mode is not None:
            self.to_mode(mode)

        inputs_spec, labels_spec = self._prepare_data_spec(
            dataset, sample_split, batch_size
        )
        self._inputs_spec = inputs_spec
        self._labels_spec = labels_spec

        if self._has_prepared[self._mode]:
            self._switch_mode(self._mode)
        else:
            self._prepare_program(self._mode)

        return self._prepare_dataloader(
            dataset,
            return_list=False,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            collate_fn=collate_fn,
            num_workers=num_workers,
            use_buffer_reader=use_buffer_reader,
            use_shared_memory=use_shared_memory,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
        )

1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242
    def dataloader_from_generator(
        self,
        dataset,
        capacity=70,
        use_double_buffer=True,
        iterable=True,
        use_multiprocess=False,
        drop_last=True,
        batch_size=1,
        epochs=1,
        steps_per_epoch=None,
        collate_fn=None,
        sample_split=1,
        mode=None,
    ):
        """
        Build a generator-based distributed dataloader for the current mode.

        The input/label specs are derived from ``dataset``. The program of the
        mode is prepared on first use, or switched to otherwise. The loader is
        then created through ``_prepare_dataloader_from_generator`` with
        ``return_list=False``.

        Args:
            dataset: The dataset to read from.
            capacity, use_double_buffer, iterable, use_multiprocess,
            drop_last, batch_size, epochs, steps_per_epoch, collate_fn:
                Forwarded unchanged to ``_prepare_dataloader_from_generator``.
            sample_split (int, optional): Forwarded to ``_prepare_data_spec``
                to split each sample into inputs and labels. Default: 1.
            mode (str|None, optional): If given, the engine switches to this
                mode via ``to_mode`` first. Default: None.

        Returns:
            The loader returned by ``_prepare_dataloader_from_generator``.
        """
        if mode is not None:
            self.to_mode(mode)

        inputs_spec, labels_spec = self._prepare_data_spec(
            dataset, sample_split, batch_size
        )
        self._inputs_spec = inputs_spec
        self._labels_spec = labels_spec

        if self._has_prepared[self._mode]:
            self._switch_mode(self._mode)
        else:
            self._prepare_program(self._mode)

        return self._prepare_dataloader_from_generator(
            dataset=dataset,
            capacity=capacity,
            use_double_buffer=use_double_buffer,
            iterable=iterable,
            return_list=False,
            use_multiprocess=use_multiprocess,
            drop_last=drop_last,
            batch_size=batch_size,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            collate_fn=collate_fn,
        )

1268 1269 1270 1271 1272 1273 1274 1275 1276 1277
    def prepare(
        self,
        inputs_spec=None,
        labels_spec=None,
        inputs=None,
        labels=None,
        main_program=None,
        startup_program=None,
        mode=None,
    ):
        """
        Prepare the program of the current mode so it can be executed.

        Returns immediately if the mode has already been prepared.
        Otherwise the data description is taken from one of three sources:

        * ``inputs``/``labels`` given: ``self._skip_build`` is set and the
          variables go through ``_prepare_data_tensor``.
        * only ``inputs_spec``/``labels_spec`` given:
          ``self._outside_dataloader`` is set.
        * nothing given: the specs recorded by an earlier ``dataloader(...)``
          call are reused (they must exist).

        In the first two cases the default static main/startup programs are
        used for any of ``main_program``/``startup_program`` that is None.

        Args:
            inputs_spec (InputSpec|list|None, optional): Input specs.
            labels_spec (InputSpec|list|None, optional): Label specs.
            inputs (Variable|list|None, optional): Already-built input vars.
            labels (Variable|list|None, optional): Already-built label vars.
            main_program (Program|None, optional): The original main program.
            startup_program (Program|None, optional): The original startup
                program.
            mode (str|None, optional): If given, the engine switches to this
                mode via ``to_mode`` first.

        Raises:
            ValueError: If no mode is set on the engine.
        """
        if mode is not None:
            self.to_mode(mode)

        if not self._mode:
            raise ValueError(
                "Please set mode to be prepared with `prepare(mode=...)`"
            )

        if self._has_prepared[self._mode]:
            return

        inputs_spec = self._validate_spec(inputs_spec)
        labels_spec = self._validate_spec(labels_spec)
        inputs = self._validate_vars(inputs)
        labels = self._validate_vars(labels)

        self._orig_main_prog = main_program
        self._orig_startup_prog = startup_program
        if inputs or labels:
            self._skip_build = True
            inputs, labels = self._prepare_data_tensor(
                inputs_spec, labels_spec, inputs, labels
            )
            if self._orig_main_prog is None:
                self._orig_main_prog = static.default_main_program()
            if self._orig_startup_prog is None:
                self._orig_startup_prog = static.default_startup_program()
        elif inputs_spec or labels_spec:
            self._outside_dataloader = True
            if self._orig_main_prog is None:
                self._orig_main_prog = static.default_main_program()
            if self._orig_startup_prog is None:
                self._orig_startup_prog = static.default_startup_program()
        else:
            assert (
                self._inputs_spec and self._labels_spec
            ), "Please call the dataloader(...) before calling prepare(...)"
            # Reuse the specs recorded by dataloader(...). The validated
            # arguments are empty on this path; assigning them below would
            # discard the very specs the assert just required.
            inputs_spec, labels_spec = self._inputs_spec, self._labels_spec

        self._inputs_spec, self._labels_spec = inputs_spec, labels_spec
        self._inputs, self._labels = inputs, labels
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

1323
    def run(self, data=None, feed=None, fetch_list=None, mode=None):
        """
        Execute the main program of the current mode once.

        Args:
            data: Data passed to ``_prepare_feed`` to build the feed dict.
            feed (dict|None, optional): Extra feed passed to ``_prepare_feed``.
            fetch_list (list|None, optional): Variables to fetch, resolved by
                ``_prepare_fetch``.
            mode (str|None, optional): If given, the engine switches to this
                mode via ``to_mode`` first.

        Returns:
            The logs built by ``_prepare_logger`` from the fetched outputs.
        """
        if mode is not None:
            self.to_mode(mode)
        feed_dict = self._prepare_feed(data, feed, self._mode)
        fetch_names, fetch_indices = self._prepare_fetch(fetch_list, self._mode)

        # Data fed from outside the engine's own dataloader needs a reader
        # prepared once per mode.
        reader_missing = (
            self._outside_dataloader
            and not self._has_prepared_reader[self._mode]
        )
        if reader_missing:
            self._prepare_reader()

        strategy = self._strategy
        outs = self._executor.run(
            self.main_program,
            feed=feed_dict,
            fetch_list=fetch_names,
            use_program_cache=strategy.use_cache,
            return_numpy=strategy.return_numpy,
        )
        return self._prepare_logger(
            outs, None, None, None, fetch_names, fetch_indices, self._mode
        )
1344

1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360
    def _prepare_dataloader(
        self,
        dataset,
        return_list=True,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        collate_fn=None,
        num_workers=0,
        use_buffer_reader=True,
        use_shared_memory=True,
        timeout=0,
        worker_init_fn=None,
        epochs=1,
        steps_per_epoch=None,
    ):
        """
        Build a ``DistributedDataLoader`` feeding this rank's dist main program.

        The loader is created under ``program_guard`` of the current mode's
        distributed main/startup programs for ``self._cur_rank``. Its feed
        list is built by ``_build_dist_feed_list``. Under gradient merge,
        ``batch_size`` is divided by ``self._k_steps`` first.

        Returns:
            DistributedDataLoader: The created loader.
        """
        batch_size = self._get_gradient_merge_batch_size(batch_size)

        dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
        dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank]
        feed_list = self._build_dist_feed_list(dist_main_prog)

        # insert read op at the end of program
        places = paddle.static.cuda_places()
        with static.program_guard(dist_main_prog, dist_startup_prog):
            dataloader = DistributedDataLoader(
                dataset,
                feed_list=feed_list,
                places=places,
                return_list=return_list,
                batch_size=batch_size,
                shuffle=shuffle,
                drop_last=drop_last,
                collate_fn=collate_fn,
                num_workers=num_workers,
                use_buffer_reader=use_buffer_reader,
                use_shared_memory=use_shared_memory,
                timeout=timeout,
                worker_init_fn=worker_init_fn,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                split_data=self._strategy.split_data,
                data_parallel_world_size=self._dp_world_sizes,
                data_parallel_rank=self._dp_ranks,
            )

        return dataloader

    def _prepare_dataloader_from_generator(
        self,
        dataset,
        capacity=None,
        use_double_buffer=True,
        iterable=True,
        return_list=False,
        use_multiprocess=False,
        drop_last=True,
        batch_size=1,
        epochs=1,
        steps_per_epoch=None,
        collate_fn=None,
    ):
        """
        Build a ``DistributedDataLoaderFromGenerator`` for this rank.

        The loader is created under ``program_guard`` of the current mode's
        distributed main/startup programs for ``self._cur_rank``, with the
        same feed list and gradient-merge batch-size handling as
        ``_prepare_dataloader``. ``_prepare_reader`` is called before
        returning.

        Returns:
            DistributedDataLoaderFromGenerator: The created loader.
        """
        batch_size = self._get_gradient_merge_batch_size(batch_size)

        dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
        dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank]
        feed_list = self._build_dist_feed_list(dist_main_prog)

        places = paddle.static.cuda_places()
        with static.program_guard(dist_main_prog, dist_startup_prog):
            dataloader = DistributedDataLoaderFromGenerator(
                dataset=dataset,
                feed_list=feed_list,
                capacity=capacity,
                use_double_buffer=use_double_buffer,
                iterable=iterable,
                return_list=return_list,
                use_multiprocess=use_multiprocess,
                drop_last=drop_last,
                places=places,
                batch_size=batch_size,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                collate_fn=collate_fn,
                split_data=self._strategy.split_data,
                data_parallel_world_size=self._dp_world_sizes,
                data_parallel_rank=self._dp_ranks,
            )
        self._prepare_reader()
        return dataloader

    def _get_gradient_merge_batch_size(self, batch_size):
        """
        Return the per-micro-step batch size under gradient merge.

        When ``self._strategy.gradient_merge`` is truthy and ``batch_size`` is
        not None, ``batch_size`` must be divisible by ``self._k_steps`` and is
        floor-divided by it. Otherwise it is returned unchanged.
        """
        if self._strategy.gradient_merge and batch_size is not None:
            assert (
                batch_size % self._k_steps == 0
            ), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
                batch_size, self._k_steps
            )
            batch_size //= self._k_steps
        return batch_size

    def _build_dist_feed_list(self, dist_main_prog):
        """
        Collect the feed variables of the current mode in ``dist_main_prog``.

        Returns one variable per serial input and label var, in that order.
        Variables missing from the global block of ``dist_main_prog`` are
        cloned into it, keeping their original id.
        """
        dist_main_block = dist_main_prog.global_block()

        # NOTE: Get feed_list, then insert dataloader op with sharded var shape.
        # Cause predict_program does not contain labels var,
        # then we will add labels var from serial_program to dist_program,
        # that maintains the length of feed_list equal to the length of dataset's values.
        inputs_var = self._feed_vars[self._mode]["inputs"]
        labels_var = self._feed_vars[self._mode]["labels"]
        feed_list = []
        for var in inputs_var + labels_var:
            if var.name in dist_main_block.vars:
                feed_list.append(dist_main_block.vars[var.name])
            else:
                copy_var = dist_main_block._clone_variable(var, var.persistable)
                copy_var.desc.set_original_id(var.desc.original_id())
                feed_list.append(copy_var)
        return feed_list

    def _tune(self, tune_data, tune_sample_split=None, batch_size=1):
        """Run optimization tuning on ``tune_data`` in 'train' mode.

        Args:
            tune_data: Data forwarded to ``_prepare_data_spec`` (to derive the
                input/label specs) and to ``_optimization_tuning``.
            tune_sample_split: Forwarded unchanged to ``_prepare_data_spec``.
            batch_size (int): Forwarded to both ``_prepare_data_spec`` and
                ``_optimization_tuning``. Default: 1.
        """
        self._mode = 'train'
        specs = self._prepare_data_spec(
            tune_data, tune_sample_split, batch_size
        )
        inputs_spec, labels_spec = specs
        self._inputs_spec = inputs_spec
        self._labels_spec = labels_spec
        # Read the mode back from the instance, exactly as before, in case the
        # spec preparation touched it.
        self._optimization_tuning(self._mode, tune_data, batch_size)

1487
    def _validate_spec(self, specs):
        """Validate input specs and adapt their batch dim for gradient merge.

        Every entry must be a named ``paddle.static.InputSpec``. When gradient
        merge uses ``k_steps > 1``, the leading (batch) dimension of each spec
        is divided by ``k_steps``. A leading dimension with no static size
        (``None`` or a negative value such as ``-1``) and zero-rank specs are
        left unchanged. This matches the dataloader path, which skips the
        division when ``batch_size`` is None.

        Note that the shapes of the given specs are modified in place.

        Args:
            specs (InputSpec|list[InputSpec]|None): The spec(s) to validate.

        Returns:
            list[InputSpec]: The validated specs, or ``[]`` if none were given.

        Raises:
            TypeError: If an entry is not an ``InputSpec``.
            ValueError: If an entry has no name.
            AssertionError: If a static batch size is not divisible by
                ``k_steps``.
        """
        specs = auto_utils.to_list(specs)
        self._k_steps = self._strategy.gradient_merge.k_steps
        if specs is None:
            return []
        for i, spec in enumerate(specs):
            if not isinstance(spec, InputSpec):
                raise TypeError(
                    "'spec' must be object of class `paddle.static.InputSpec`."
                )
            if spec.name is None:
                raise ValueError(
                    "Requires Input[{}].name != None, but receive `None` with {}.".format(
                        i, spec
                    )
                )
            if self._k_steps <= 1:
                continue
            shape = list(spec.shape)
            # A dynamic (None / -1) or absent batch dim has no static size to
            # split across merge steps. Previously this crashed with
            # TypeError, IndexError or a spurious divisibility failure.
            if not shape or shape[0] is None or shape[0] < 0:
                continue
            assert (
                shape[0] % self._k_steps == 0
            ), "Requires batch_size[{}] to be divisible by k_steps[{}].".format(
                spec.shape[0], self._k_steps
            )
            shape[0] //= self._k_steps
            spec.shape = shape
        return specs or []

    def _validate_vars(self, vars):
        """Check that every given variable is a ``Variable``.

        ``vars`` is first passed through ``auto_utils.to_list``. The parameter
        name shadows the builtin but is part of the signature, so it is kept.

        Args:
            vars (Variable|list[Variable]|None): The variable(s) to check.

        Returns:
            list[Variable]: The checked variables, or ``[]`` if none were given.

        Raises:
            TypeError: If any entry is not a ``Variable``.
        """
        checked = auto_utils.to_list(vars)
        if checked is None:
            return []
        if any(not isinstance(item, Variable) for item in checked):
            raise TypeError("'var' must be a `Variable`.")
        return checked or []
1520

1521 1522 1523 1524
    def _is_local_var(self, var):
        """Return whether ``var`` exists in the global block of ``main_program``.

        Args:
            var: Converted to a name with ``_to_name_str`` (presumably a
                variable or its name; confirm against callers).
        """
        name = _to_name_str(var)
        block_vars = self.main_program.global_block().vars
        return name in block_vars

1525 1526 1527 1528
    def _reset_metrics(self):
        """Call ``reset()`` on each metric in ``self._metrics``, in order."""
        for metric_obj in self._metrics:
            metric_obj.reset()

Z
zhaoyingli 已提交
1529 1530 1531
    def _metrics_name(self):
        """Return the names of the reported values.

        The list starts with ``'loss'`` when ``self._loss`` is truthy. It then
        holds the name(s) returned by each metric's ``name()``, each passed
        through ``auto_utils.to_list``.

        Returns:
            list: The collected names.
        """
        names = []
        if self._loss:
            names.append('loss')
        for metric in self._metrics:
            names += auto_utils.to_list(metric.name())
        return names

1535
    def _switch_mode(self, mode):
        """Make ``mode`` current and restore that mode's serial optimizer.

        Args:
            mode (str): A mode for which a distributed main program exists.

        Raises:
            AssertionError: If ``self._dist_main_progs`` has no entry for
                ``mode``.
        """
        is_ready = mode in self._dist_main_progs
        assert is_ready, "{} model is not ready, please call `prepare()` first.".format(
            mode
        )
        self.to_mode(mode)
        mode_context = self._dist_contexts[mode]
        self._optimizer = mode_context._serial_optimizer
1541

1542
    def to_mode(self, mode):
        """Set the current mode of the engine.

        Args:
            mode (str): One of ``"train"``, ``"eval"`` or ``"predict"``.

        Raises:
            AssertionError: If ``mode`` is not one of the accepted values.
        """
        valid_modes = ("train", "eval", "predict")
        assert (
            mode in valid_modes
        ), "mode {} should be one of ['train', 'eval', 'predict']".format(mode)
        self._mode = mode

1550 1551 1552
    def _set_state_dict(self, mode, strict, state_dict, dist_attr):
        """Convert ``state_dict`` to the current distribution and load it.

        Args:
            mode (str): Mode whose distributed main program (for this rank)
                receives the converted state.
            strict (bool): Forwarded to ``Converter.convert``.
            state_dict (dict): The state to convert and load.
            dist_attr: Distributed attributes that ``state_dict`` was produced
                with (presumably as returned by ``load``; confirm with callers).
        """
        target_program = self._dist_main_progs[mode][self._cur_rank]
        mode_context = self._dist_contexts[mode]
        target_dist_attr = auto_utils.get_dist_attr(target_program, mode_context)
        converted = Converter(state_dict, dist_attr, target_dist_attr).convert(
            strict=strict
        )
        target_program.set_state_dict(converted)

    def save(self, path, training=True):
        """
        Saves the model, parameters, optimizer state to path.
        If `training` is set to False, only inference model will be saved.

        Args:
            path (str): The file prefix to save model. The format
                is 'dirname/file_prefix' or 'file_prefix'. If it is an
                empty string, an exception will be raised.
            training (bool, optional): Whether to save for training. If not, save
                for inference only. If `training` is set to True, the optimizer state
                will be saved. Otherwise, only the model and parameters are saved.
                This function will silently overwrite existing file at the target
                location. Default: True.

        Returns:
            None

        Raises:
            AssertionError: If `training` is True and the current mode has no
                serial main program, or if `training` is False and there is
                no distributed 'predict' program.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                train_dataset = MNIST(mode='train', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                optimizer = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=model.parameters())
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, optimizer, metrics)
                engine.fit(train_dataset,
                           epochs=1,
                           batch_size=64)
                engine.save("./my_model")

        """
        if training:
            assert (
                self._mode in self._serial_main_progs
            ), "The '{}' program is not ready, please call `prepare()` first.".format(
                self._mode
            )
            # Training checkpoints store the serial program together with this
            # rank's distributed program and distributed context.
            serial_program = self._serial_main_progs[self._mode]
            dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
            dist_context = self._dist_contexts[self._mode]
            self._saver.save(
                path,
                serial_program=serial_program,
                dist_main_program=dist_main_prog,
                dist_context=dist_context,
            )
        else:
            assert "predict" in self._dist_main_progs, (
                "The 'predict' program is not ready; prepare the engine in "
                "'predict' mode before saving an inference model."
            )
            # The inference model is exported from the 'predict' program using
            # its feed inputs and fetch outputs.
            feed_vars = self._feed_vars["predict"]['inputs']
            fetch_vars = self._fetch_vars["predict"]['outputs']
            dist_main_prog = self._dist_main_progs["predict"][self._cur_rank]
            if self._strategy.qat.enable and self._strategy.qat.onnx_format:
                from paddle.static.quantization import QuantWeightPass

                self._logger.info("export quantized model.")
                self._logger.info(
                    "convert config {}".format(self._strategy.qat.to_dict())
                )
                # Apply QuantWeightPass to every sub-graph of a test-mode graph
                # built from the program, then export the resulting program.
                test_graph = IrGraph(
                    core.Graph(dist_main_prog.desc), for_test=True
                )
                quant_weight_pass = QuantWeightPass(global_scope(), self._place)
                for sub_graph in test_graph.all_sub_graphs():
                    quant_weight_pass.apply(sub_graph)
                dist_main_prog = test_graph.to_program()
            self._saver.save_inference_model(
                path,
                feed_vars,
                fetch_vars,
                self._executor,
                program=dist_main_prog,
            )
1640

1641 1642 1643 1644 1645 1646
    def load(self, path, strict=True, load_optimizer=True):
        """
        Load the stored model, parameters and optimizer states.

        Args:
            path (str): The prefix of files storing the model states and
                optimizer states.
            strict (bool, optional): Whether to skip loading mismatched
                parameters or raise an error when a mismatch happens (a
                parameter is not found in the stored model states, or its
                shape does not match). The value is kept on the engine as
                ``self._strict``. Default: True.
            load_optimizer (bool, optional): If True, the stored optimizer
                states are restored. Otherwise, the optimizer states are
                initialized from scratch. Default: True.

        Returns:
            tuple: ``(state_dict, dist_attr)`` as returned by the saver. Both
            are also stored on the engine as ``self._state_dict`` and
            ``self._dist_attr``.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                train_dataset = MNIST(mode='train', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                optimizer = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=model.parameters())
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, optimizer, metrics)
                engine.fit(train_dataset,
                           epochs=1,
                           batch_size=64)
                engine.save("./my_model")
                engine.load("./my_model")

        """
        self._strict = strict
        self._state_dict, self._dist_attr = self._saver.load(
            path, load_optimizer
        )
        return self._state_dict, self._dist_attr
1692

1693
    def cost(self, inputs_spec=None, labels_spec=None, mode=None):
        """
        Get and Print cost, including memory of every rank,
        max memory among all ranks, and the global cost of one step based on
        communication cost(computation cost is 0 by default).
        In the future, the flops information of every rank and global cost including
        computation cost will be added.

        Args:
            inputs_spec(InputSpec): The specification of inputs. Default: None.
            labels_spec(InputSpec): The specification of labels. Default: None.
            mode (str): The engine mode must be in ["train", "predict", "eval"].
                Default: None, which means the current mode of the engine.

        Returns:
            Return the global execution time (ms) and max memory (B).
            If the strategy's auto mode is "full", nothing is estimated here
            and None is returned.

        Raises:
            ValueError: If `mode` is not a mode known to the engine, or if
                there is no prepared or non-empty static program to estimate.
        """
        # Check parallel mode: in "full" auto mode the cost is computed by the
        # search process instead.
        if self._strategy.auto_mode == "full":
            self._logger.info(
                "The cost will be calculated in the search process when the auto mode is full."
            )
            return

        # Check mode
        mode = mode if mode is not None else self._mode
        assert mode is not None, "Please set mode."
        if mode not in self._has_prepared:
            raise ValueError(
                "The mode {} is not in accepted modes {}".format(
                    mode, list(self._has_prepared.keys())
                )
            )
        self.to_mode(mode)

        not_prepared_msg = (
            "Please call `prepare()` or `fit()` or `evaluate()` or "
            "`predict()` before calling `cost()`."
        )
        if inputs_spec is not None and not self._has_prepared[mode]:
            # Build and plan the programs from the given specs.
            self._inputs_spec = self._validate_spec(inputs_spec)
            self._labels_spec = self._validate_spec(labels_spec)
            self._build(mode)
            self._plan(mode)
        elif in_dygraph_mode() or self._dygraph_mode:
            # NOTE(review): this branch does not consult `_has_prepared`, so it
            # also raises when `mode` was already prepared in dygraph mode and
            # no `inputs_spec` is given -- confirm this is intended.
            raise ValueError(not_prepared_msg)
        else:
            self._logger.info(
                "The program whose cost to be estimated must be static default "
                "program. Otherwise, please call `prepare()` before calling `cost()`."
            )
            program = paddle.static.default_main_program()
            # An empty default program has nothing to estimate.
            if not program.global_block().ops and not self._has_prepared[mode]:
                raise ValueError(not_prepared_msg)

        # Estimate the exec cost and max memory
        global_cost, max_memory = get_cost_from_engine(self, mode)

        return global_cost.time, max_memory

1756 1757
    @property
    def main_program(self):
        """The distributed main program of the current rank for the current mode."""
        progs_of_mode = self._dist_main_progs[self._mode]
        return progs_of_mode[self._cur_rank]
1759 1760 1761

    @property
    def startup_program(self):
        """The distributed startup program of the current rank for the current mode."""
        progs_of_mode = self._dist_startup_progs[self._mode]
        return progs_of_mode[self._cur_rank]
1763 1764 1765

    @property
    def dist_context(self):
        """The distributed context of the current mode."""
        contexts = self._dist_contexts
        return contexts[self._mode]
1767 1768 1769

    @property
    def serial_main_program(self):
        """The serial (non-distributed) main program of the current mode."""
        programs = self._serial_main_progs
        return programs[self._mode]
1771 1772 1773

    @property
    def serial_startup_program(self):
        """The serial (non-distributed) startup program of the current mode."""
        programs = self._serial_startup_progs
        return programs[self._mode]
1775 1776 1777

    @property
    def fetch_vars(self):
        """The fetch variables registered for the current mode."""
        per_mode = self._fetch_vars
        return per_mode[self._mode]
1779 1780 1781

    @property
    def inputs(self):
        """The engine's inputs, as stored in ``self._inputs``."""
        stored_inputs = self._inputs
        return stored_inputs
1783 1784 1785

    @property
    def labels(self):
        """The engine's labels, as stored in ``self._labels``."""
        stored_labels = self._labels
        return stored_labels