engine.py 69.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import copy
16
import logging
17
import numbers
18 19
import os
import random
20

21 22
import numpy as np

23
import paddle
24
import paddle.distributed.auto_parallel.utils as auto_utils
25
from paddle import static, utils
26
from paddle.distributed import fleet
27 28 29 30
from paddle.fluid.executor import _to_name_str
from paddle.framework import IrGraph
from paddle.framework import _current_expected_place as _get_device
from paddle.framework import core, in_dygraph_mode
31
from paddle.metric import Metric
32
from paddle.static import InputSpec, Operator, Variable, global_scope
33

34
from ..utils.log_utils import get_logger
Z
zhaoyingli 已提交
35
from .callbacks import config_callbacks
36
from .cluster import Cluster, get_default_cluster
37 38 39
from .converter import Converter
from .cost.estimate_cost import get_cost_from_engine
from .dist_context import DistributedContext, get_default_distributed_context
40 41
from .dist_loader import (
    DistributedDataLoader,
42
    DistributedDataLoaderFromGenerator,
43
)
44 45 46
from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver
from .helper import ProgramHelper
47
from .interface import CollectionNames, fetch, get_collection
48 49 50 51
from .parallelizer_v2 import Parallelizer
from .planner_v2 import Planner
from .process_group import get_all_process_groups, new_process_group
from .strategy import Strategy
52

53 54

class Engine:
    """
    An Engine object provides the full power of auto parallel to users.
    With it, users can easily obtain distributed training and inference
    capabilities. It supports both the dynamic graph and the static graph.

    Args:
        model (paddle.nn.Layer|Callable, optional): The model is an instance of
            paddle.nn.Layer or any callable.
        loss (Loss|Callable|Variable|None, optional): The loss can be a
            `paddle.nn.Layer` instance, any callable function taking the
            predicted values and ground truth values as input, or a `Variable`.
            It can be None when there is no loss. Default: None.
        optimizer (Optimizer|None, optional): The optimizer needs to be set in
            training and should be None in eval and predict mode. Default: None.
        metrics (Metric|list[Metric]|None, optional): If metrics is set, all
            metrics will be calculated and output in train/eval mode.
            Default: None.
        cluster (Cluster|None, optional): The cluster represents the topology
            information about the used physical devices. Default: None.
            (Unused for now)
        strategy (Strategy|None, optional): The strategy is used to configure
            the parallelization and optimization behaviors. Default: None.

    Examples:

        .. code-block:: python

            import paddle
            import paddle.vision.transforms as T
            from paddle.distributed.fleet import auto
            from paddle.vision.datasets import MNIST

            transform = T.Compose([
                T.Transpose(),
                T.Normalize([127.5], [127.5])
            ])
            train_dataset = MNIST(mode='train', transform=transform)
            valid_dataset = MNIST(mode='test', transform=transform)

            model = paddle.vision.models.LeNet()
            loss = paddle.nn.CrossEntropyLoss()
            optimizer = paddle.optimizer.Adam(
                learning_rate=0.001, parameters=model.parameters())
            metrics = paddle.metric.Accuracy(topk=(1, 2))

            engine = auto.Engine(model, loss, optimizer, metrics)
            # fit
            engine.fit(train_dataset,
                       epochs=2,
                       batch_size=64)
            # evaluate
            engine.evaluate(valid_dataset,
                            batch_size=64)
            # predict
            engine.predict(valid_dataset,
                           batch_size=64)
            # save
            engine.save("./my_model")
            # load
            engine.load("./my_model")

    """

    def __init__(
        self,
        model=None,
        loss=None,
        optimizer=None,
        metrics=None,
        cluster=None,
        strategy=None,
    ):
        """Validate the engine's components and initialize its state.

        Args:
            model (paddle.nn.Layer|Callable|None): the network to run.
            loss (paddle.nn.Layer|Callable|Variable|None): the loss.
            optimizer (paddle.optimizer.Optimizer|paddle.static.Optimizer|None):
                the optimizer; stored after ``auto_utils.validate_opt``.
            metrics (Metric|list[Metric]|None): metrics to compute.
            cluster (Cluster|None): device topology; falls back to
                ``get_default_cluster()``.
            strategy (Strategy|None): parallel strategy; falls back to a new
                ``Strategy()``.

        Raises:
            TypeError: if any argument has an unsupported type.
        """
        if (
            model
            and not isinstance(model, paddle.nn.Layer)
            and not callable(model)
        ):
            raise TypeError(
                "'model' must be sub classes of `paddle.nn.Layer` or any callable function."
            )
        self._model = model

        if (
            loss
            and not isinstance(loss, (paddle.nn.Layer, Variable))
            and not callable(loss)
        ):
            raise TypeError(
                "'loss' must be sub classes of `paddle.nn.Layer` or any callable function or a Variable."
            )
        self._loss = loss

        if optimizer and not isinstance(
            optimizer,
            (paddle.optimizer.Optimizer, paddle.static.Optimizer),
        ):
            raise TypeError(
                "'optimizer' must be object of class `paddle.optimizer.Optimizer`"
                " or `paddle.static.Optimizer`."
            )
        self._optimizer = auto_utils.validate_opt(optimizer)

        # Normalize once (None -> []); `to_list` presumably wraps a single
        # Metric into a list.
        metrics = auto_utils.to_list(metrics or [])
        for metric in metrics:
            if metric and not isinstance(metric, Metric):
                raise TypeError(
                    "{} is not sub class of Metric".format(
                        metric.__class__.__name__
                    )
                )
        self._metrics = metrics

        if cluster and not isinstance(cluster, Cluster):
            raise TypeError(
                "'cluster' must be object of class `paddle.distributed.auto_parallel.Cluster`"
            )
        self._cluster = cluster or get_default_cluster()

        if strategy and not isinstance(strategy, Strategy):
            raise TypeError(
                "'strategy' must be object of class `paddle.distributed.auto_parallel.Strategy`"
            )
        self._strategy = strategy or Strategy()

        self._logger = get_logger(logging.INFO)
        # A set POD_NAME is treated as a `paddle.distributed.launch` run.
        if os.getenv("POD_NAME"):
            self._logger.info(
                "Distribute training by paddle.distributed.launch"
            )
            fleet.init(is_collective=True)

        # for compute cost
        # TODO: remove _fwd_main_progs and _orig_optimizer
        self._fwd_dist_contexts = {}
        self._fwd_main_progs = {}
        self._orig_optimizer = copy.deepcopy(self._optimizer)

        self._executor = None
        self._cur_rank = paddle.distributed.get_rank()
        self._nranks = paddle.distributed.get_world_size()
        self._saver = DistributedSaver()

        self._orig_main_prog = static.default_main_program()
        self._orig_startup_prog = static.default_startup_program()
        self._orig_dist_context = get_default_distributed_context()
        self._dist_contexts = {}
        self._planners = {}
        self._has_prepared = {"train": False, "eval": False, "predict": False}
        self._has_prepared_reader = {
            "train": False,
            "eval": False,
            "predict": False,
        }
        self._inputs_spec = []
        self._labels_spec = []
        self._inputs = []
        self._labels = []
        self._losses = []

        self._mode = None
        self._skip_build = False
        self._outside_dataloader = False
        self._planned_mode = None
        self._dygraph_mode = False
        self._tuning = self._strategy.tuning

        self.history = None

        # NOTE(review): this sets a process-wide Paddle flag every time an
        # Engine is constructed.
        paddle.framework.set_flags({'FLAGS_new_executor_sequential_run': 1})

    def _prepare_data_spec(self, data, split, batch_size):
        """Infer ``InputSpec`` lists for inputs and labels from one sample.

        The first sample of ``data`` is split into inputs and labels: with
        ``split=None`` the sample itself must unpack into
        ``(inputs, labels)``; otherwise ``sample[:split]`` are the inputs and
        ``sample[split:]`` the labels.

        Args:
            data (paddle.io.IterableDataset|paddle.io.Dataset): the dataset.
            split (int|None): position separating inputs from labels.
            batch_size (int|None): when given, every spec is batched with it.

        Returns:
            tuple: ``(inputs_spec, labels_spec)`` after ``_validate_spec``.

        Raises:
            TypeError: if ``data`` is not a dataset or a sample item is not a
                number, ``np.ndarray`` or Tensor.
        """
        # IterableDataset is tested first, as in the original dispatch order.
        if isinstance(data, paddle.io.IterableDataset):
            sample = next(iter(data))
        elif isinstance(data, paddle.io.Dataset):
            sample = data[0]
        else:
            raise TypeError(
                "Data should be a Dataset or IterableDataset, but received {}.".format(
                    type(data).__name__
                )
            )

        if split is None:
            inputs, labels = sample
        else:
            inputs, labels = sample[:split], sample[split:]

        inputs = auto_utils.to_list(inputs)
        labels = auto_utils.to_list(labels)
        num_shards = self._strategy.dataset.num_shards

        def _scale_leading_dim(spec):
            # Multi-dim specs get their leading dim multiplied by num_shards.
            if num_shards > 1 and len(spec.shape) > 1:
                spec.shape[0] = spec.shape[0] * num_shards

        def _spec_of(item, name):
            if isinstance(item, np.ndarray):
                spec = InputSpec.from_numpy(item, name)
                if batch_size is not None:
                    return spec.batch(batch_size)
                _scale_leading_dim(spec)
                return spec
            if isinstance(item, (Variable, core.eager.Tensor)):
                spec = InputSpec.from_tensor(item, name)
                _scale_leading_dim(spec)
                return spec if batch_size is None else spec.batch(batch_size)
            if isinstance(item, numbers.Number):
                return InputSpec([batch_size], type(item), name)
            raise TypeError(
                "The sample's dtype returned of dataset should be number, np.ndarray or Tensor, but got {}".format(
                    type(item).__name__
                )
            )

        def _collect(items, prefix):
            collected = []
            if items is None:
                return collected
            for position, item in enumerate(items):
                assert item is not None, "Receive None input."
                collected.append(_spec_of(item, f"{prefix}{position}"))
            return collected

        raw_inputs_spec = _collect(inputs, "input")
        raw_labels_spec = _collect(labels, "label")
        return (
            self._validate_spec(raw_inputs_spec),
            self._validate_spec(raw_labels_spec),
        )

    def _prepare_data_tensor(self, inputs_spec, labels_spec, inputs, labels):
        """Force the shapes of static data variables to match their specs.

        For each non-empty spec list, every variable whose shape differs
        from its spec gets its desc shape overwritten with the spec's shape.

        Args:
            inputs_spec (list|None): specs for ``inputs``.
            labels_spec (list|None): specs for ``labels``.
            inputs (list): input variables, parallel to ``inputs_spec``.
            labels (list): label variables, parallel to ``labels_spec``.

        Returns:
            tuple: ``(inputs, labels)``, the same list objects passed in.

        Raises:
            ValueError: in dynamic-graph mode.
            AssertionError: if a spec/variable list is not a list or their
                lengths differ.
        """
        if in_dygraph_mode() or self._dygraph_mode:
            raise ValueError("Only support static graph mode.")

        def _sync_shapes(specs, variables, kind):
            # `kind` is "inputs" or "labels"; it names the arguments in errors.
            assert isinstance(
                specs, list
            ), f"{kind}_spec should be list, but received {type(specs)}"
            assert isinstance(
                variables, list
            ), f"{kind} should be list, but received {type(variables)}"
            assert len(specs) == len(
                variables
            ), f"the number of `{kind}_spec` should be equal to `{kind}`'s."
            for spec, var in zip(specs, variables):
                if spec.shape != var.shape:
                    var.desc.set_shape(spec.shape)

        if inputs_spec:
            _sync_shapes(inputs_spec, inputs, "inputs")
        if labels_spec:
            _sync_shapes(labels_spec, labels, "labels")

        return inputs, labels

    def _prepare_reader(self, feed_list=None):
        """Move the reader ops of the current rank's program to its front.

        Works on ``self._dist_contexts[self._mode]``'s program for
        ``self._cur_rank``. Leading reader ops left by a previous
        fit/evaluate/predict run are dropped first. The remaining reader ops
        (``create_py_reader``, ``create_double_buffer_reader``, ``read``) are
        re-created at the head of the global block, both in the C++ desc and
        in the Python op list, and the old copies are removed. If the main
        program has ``_pipeline_opt`` set, the read ops are also prepended to
        the program of the first task in ``fleet_opt["tasks"]``, and missing
        ``feed_list`` variables are cloned into it.

        Args:
            feed_list (list[Variable]|None): variables to clone into the
                forward task program. Defaults to an empty list.
        """
        if feed_list is None:
            feed_list = []
        dist_context = self._dist_contexts[self._mode]
        dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
        dist_main_block = dist_main_prog.global_block()

        # NOTE: this list may be changed if Paddle changes the existing rules.
        related_reader_ops = [
            "create_py_reader",
            "create_double_buffer_reader",
            "read",
        ]
        # remove the first three ops if multiple run fit/evaluate/predict
        if (
            dist_main_block.ops
            and dist_main_block.ops[0].type == 'create_py_reader'
        ):
            for _ in related_reader_ops:
                if (
                    dist_main_block.ops
                    and dist_main_block.ops[0].type in related_reader_ops
                ):
                    dist_main_block._remove_op(0, sync=False)
        dist_main_block._sync_with_cpp()
        # Step 1: find the reader ops
        reader_op_indices = [
            idx
            for idx, op in enumerate(dist_main_block.ops)
            if op.type in related_reader_ops
        ]
        # Step 2: insert the new reader ops to cpp
        # record the read ops' desc to insert to program of forward task_node
        read_ops_desc = []
        new_reader_ops = []
        # Prepending in reverse keeps the reader ops in their original order.
        for idx in reversed(reader_op_indices):
            new_op_desc = dist_main_block.desc._prepend_op()
            new_op_desc.copy_from(dist_main_block.ops[idx].desc)
            read_ops_desc.append(new_op_desc)
            new_op = Operator(
                dist_main_block, new_op_desc, type=new_op_desc.type()
            )
            new_reader_ops.append(new_op)
            dist_op = DistributedOperator(new_op)
            dist_context.add_dist_op_for_program(dist_op)
        # Step 3: insert the new reader ops to python
        for new_op in new_reader_ops:
            dist_main_block.ops.insert(0, new_op)
        # The old reader ops are shifted right by the number of prepended ops.
        num_new_ops = len(reader_op_indices)
        reader_op_indices = [idx + num_new_ops for idx in reader_op_indices]
        # Step 4: remove the old reader ops from python and cpp
        for idx in reversed(reader_op_indices):
            dist_main_block.ops.pop(idx)
            dist_main_block.desc._remove_op(idx, idx + 1)
        dist_main_block._sync_with_cpp()
        self._has_prepared_reader[self._mode] = True

        # Insert read op to forward TaskNode if 1F1B pass is setted
        if self.main_program._pipeline_opt:
            assert "tasks" in self.main_program._pipeline_opt["fleet_opt"]
            fleet_opt = self.main_program._pipeline_opt["fleet_opt"]
            fwd_task = fleet_opt["tasks"][0]
            fwd_prog = fwd_task.get_program()
            fwd_block = fwd_prog.global_block()

            for var in feed_list:
                if var.name not in fwd_block.vars:
                    fwd_block._clone_variable(var)

            for op_desc in read_ops_desc:
                new_op_desc = fwd_block.desc._prepend_op()
                new_op_desc.copy_from(op_desc)
                new_op = Operator(
                    fwd_block, new_op_desc, type=new_op_desc.type()
                )
                fwd_block.ops.insert(0, new_op)

            fwd_block._sync_with_cpp()
            fwd_task.set_program(fwd_prog)

    def _prepare_feed(self, data, user_feeds, mode):
        """Merge batch data and user feeds into a single feed dict.

        Args:
            data (dict|list|tuple|None): a dict, or a one-element list/tuple
                wrapping a dict, mapping feed names to values.
            user_feeds (dict|None): extra feeds; they override entries of
                ``data`` with the same name.
            mode (str): not used by this method.

        Returns:
            dict: the merged ``{name: value}`` feed dict.

        Raises:
            ValueError: if ``data`` has an unsupported structure.
            AssertionError: if ``user_feeds`` is not a dict.
        """
        feeds = {}
        if data is not None:
            if isinstance(data, (list, tuple)):
                if len(data) == 1 and isinstance(data[0], dict):
                    feeds.update(data[0])
                else:
                    raise ValueError(f"Unsupported data {data}")
            elif isinstance(data, dict):
                feeds.update(data)
            else:
                raise ValueError(f"Unsupported data {data}")
        if user_feeds is not None:
            assert isinstance(
                user_feeds, dict
            ), "user_feeds must be a dict, but receive {}".format(
                type(user_feeds).__name__
            )
            feeds.update(user_feeds)
        return feeds

    def _prepare_fetch(self, user_fetches, mode):
        """Collect the variable names to fetch for ``mode``, in groups.

        Groups are appended in order: for non-predict modes, the loss group
        and then one group per metric; for predict, the outputs group; and
        finally the user-fetch group. Only variables accepted by
        ``self._is_local_var`` are fetched, and each name is stored once.

        Args:
            user_fetches (list|None): extra variables/names to fetch; each is
                registered through ``fetch``.
            mode (str): "train", "eval" or "predict".

        Returns:
            tuple: ``(fetch_names, fetch_indices)``, where ``fetch_indices[g]``
            lists the positions in ``fetch_names`` belonging to group ``g``.
        """
        if user_fetches is not None:
            assert isinstance(
                user_fetches, list
            ), "user_fetches must be a list, but receive {}".format(
                type(user_fetches).__name__
            )
        else:
            user_fetches = []
        fetch_names = []
        fetch_indices = []

        def _process_fetch_group(group_name, var_list):
            # `group_name` only labels the group for readability.
            group_indices = []
            for var in var_list:
                # Remove duplicate var_names
                if self._is_local_var(var):
                    var_name = _to_name_str(var)
                    if var_name not in fetch_names:
                        fetch_names.append(var_name)
                    group_indices.append(fetch_names.index(var_name))
            if not group_indices:
                # NOTE(review): an empty list is appended to `fetch_names` as a
                # placeholder when a group has no local vars; confirm this is
                # what the executor expects.
                fetch_names.append([])
            fetch_indices.append(group_indices)

        fetch_vars = self._dist_contexts[mode].serial_fetch_vars
        if mode != "predict":
            _process_fetch_group("loss", fetch_vars["loss"])
            for i, var_list in enumerate(fetch_vars["metrics"]):
                _process_fetch_group("metrics_" + str(i), var_list)
        else:
            _process_fetch_group("outputs", fetch_vars["outputs"])

        # Register the user fetches in the FETCHES collection, then fetch
        # everything that collection holds.
        for usr_fetch in user_fetches:
            fetch(_to_name_str(usr_fetch))
        user_fetches_collection = [
            item[1] for item in get_collection(CollectionNames.FETCHES)
        ]
        _process_fetch_group("fetches", user_fetches_collection)
        return fetch_names, fetch_indices

    def _prepare_logger(
        self,
        outs,
        epoch=None,
        step=None,
        lr=None,
        fetch_names=None,
        fetch_indices=None,
        mode=None,
    ):
        """Turn one step's fetched values into a logging dict.

        ``fetch_indices`` is consumed group by group, in the order built by
        ``_prepare_fetch``. For non-predict modes this is the loss group and
        then one group per metric; for predict it is the outputs group. User
        fetches are looked up by name in ``fetch_names``.

        Args:
            outs (list): values returned for ``fetch_names``.
            epoch: logged as-is when not None.
            step (int|None): logged 1-based when not None.
            lr: logged as-is when not None.
            fetch_names (list): fetched variable names.
            fetch_indices (list[list[int]]): per-group indices into ``outs``.
            mode (str): "train", "eval" or "predict".

        Returns:
            dict: optional "epoch"/"step"/"lr" entries, then "loss" plus metric
            entries (non-predict) or "outputs" (predict), and "fetches".
        """
        logs = {}
        if epoch is not None:
            logs["epoch"] = epoch
        if step is not None:
            logs["step"] = step + 1
        if lr is not None:
            logs["lr"] = lr

        group_idx = 0
        if mode != "predict":
            # logging loss
            loss_indices = fetch_indices[group_idx]
            assert len(loss_indices) <= 1
            for idx in loss_indices:
                logs["loss"] = outs[idx][0]
            group_idx += 1
            # logging metrics
            dist_context = self._dist_contexts[mode]
            metric_vars = dist_context.serial_fetch_vars["metrics"]
            if metric_vars:
                for metric in self._metrics:
                    metric_out = [outs[idx] for idx in fetch_indices[group_idx]]
                    if metric_out:
                        metric.update(*metric_out)
                        results = auto_utils.to_list(metric.accumulate())
                        # `Metric.name()` may return a single str rather than a
                        # list; indexing a str would yield its characters as
                        # keys, so normalize it to a list of names first.
                        names = auto_utils.to_list(metric.name())
                        for name, res in zip(names, results):
                            logs[name] = res
                    group_idx += 1
        else:
            # logging outputs
            outputs_indices = fetch_indices[group_idx]
            logs["outputs"] = {f"out{idx}": outs[idx] for idx in outputs_indices}
            group_idx += 1

        # logging user fetches
        logs_fetch = {}
        for name, var_name in get_collection(CollectionNames.FETCHES):
            if var_name in fetch_names:
                logs_fetch[name or var_name] = outs[fetch_names.index(var_name)]
        logs["fetches"] = logs_fetch
        return logs

    def _prepare_program(self, mode, init_parameters=True):
        """Build, plan and parallelize the program for ``mode``.

        Communication is then initialized, ``self._initialize(mode)`` runs
        (the startup-program step) when ``init_parameters`` is true, and
        ``mode`` is marked as prepared.

        Args:
            mode (str): "train", "eval" or "predict".
            init_parameters (bool): whether to call ``self._initialize``.
        """
        for stage in (self._build, self._plan, self._parallel):
            stage(mode)
        self._init_comm()
        if init_parameters:
            self._initialize(mode)
        self._has_prepared[mode] = True

    def _build(self, mode):
        """Build the serial program for ``mode`` and register its contexts.

        In dynamic-graph mode the model is converted with ``ProgramHelper``
        (the 'to_static' path). Otherwise the original default programs are
        cloned and the model is traced on feed layers created from the stored
        input/label specs. The static path returns early when a dist context
        for ``mode`` already exists. When ``_skip_build`` is set, only the
        train mode takes its loss from ``self._loss`` (which must be a
        ``Variable``).

        Two ``DistributedContext`` objects are created from identical
        arguments: one in ``self._dist_contexts`` and one in
        ``self._fwd_dist_contexts`` (kept for cost computation).

        Args:
            mode (str): "train", "eval" or "predict".
        """
        if in_dygraph_mode() or self._dygraph_mode:
            paddle.disable_static()
            self._dygraph_mode = True
            self._logger.info("Building model with 'to_static' method.")

            self.program_helper = ProgramHelper(
                self._model,
                self._loss,
                self._metrics,
                self._inputs_spec,
                self._labels_spec,
            )
            # build forward main program
            with utils.unique_name.guard():
                self.program_helper.build_program(mode)

            self.concrete_program = self.program_helper.concrete_program
            serial_main_prog = self.program_helper.main_program
            serial_startup_prog = self.program_helper.startup_program

            self._inputs = self.program_helper.input_vars
            self._labels = self.program_helper.label_vars
            outputs = self.program_helper.output_vars
            self._losses = self.program_helper.loss_vars
            metrics = self.program_helper.metric_vars

            paddle.enable_static()
        else:
            # build program in static mode
            dist_context = self._dist_contexts.get(mode, None)
            if dist_context is not None:
                return

            outputs = []
            metrics = []
            self._losses = []
            serial_main_prog = self._orig_main_prog.clone()
            serial_startup_prog = self._orig_startup_prog.clone()
            if not self._skip_build:
                with static.program_guard(
                    serial_main_prog, serial_startup_prog
                ), utils.unique_name.guard():
                    self._inputs = [
                        s._create_feed_layer() for s in self._inputs_spec
                    ]
                    self._labels = [
                        s._create_feed_layer() for s in self._labels_spec
                    ]

                    outputs = auto_utils.to_list(self._model(*self._inputs))

                    if mode != "predict" and self._loss:
                        assert isinstance(
                            self._loss, paddle.nn.Layer
                        ) or callable(
                            self._loss
                        ), "the type of `loss` of the Engine arguments should be sub classes of `paddle.nn.Layer` or any callable function."
                        self._losses = auto_utils.to_list(
                            self._loss(*(outputs + self._labels))
                        )

                    if mode != "predict" and (outputs or self._labels):
                        for metric in self._metrics:
                            metrics.append(
                                auto_utils.to_list(
                                    metric.compute(*(outputs + self._labels))
                                )
                            )
            elif mode == "train":
                assert isinstance(
                    self._loss, Variable
                ), "the type of `loss` of the Engine arguments should be Variable."
                self._losses = auto_utils.to_list(self._loss)

        default_ctx = get_default_distributed_context()
        if not default_ctx.has_annotation:
            # We build the world process group because the data parallel
            # needs all ranks by default.
            new_process_group(list(range(self._nranks)))
            default_ctx.data_parallel = True
            self._inputs = [
                auto_utils.set_data_parallel(var) for var in self._inputs
            ]
            self._labels = [
                auto_utils.set_data_parallel(var) for var in self._labels
            ]

        feed_vars = {"inputs": self._inputs, "labels": self._labels}

        fetch_vars = {
            "outputs": paddle.utils.flatten(outputs),
            "loss": self._losses,
            "metrics": metrics,
        }

        if mode != "train":
            serial_main_prog = serial_main_prog.clone(for_test=True)

        auto_utils.set_recompute_segments(
            self._model, self._losses, self._strategy, serial_main_prog
        )
        # Both contexts are constructed from exactly the same arguments.
        context_args = (
            serial_main_prog,
            serial_startup_prog,
            self._optimizer,
            self._losses,
            feed_vars,
            fetch_vars,
            self._cluster,
            self._strategy,
        )
        self._dist_contexts[mode] = DistributedContext(*context_args)
        self._fwd_dist_contexts[mode] = DistributedContext(*context_args)
        self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale
        self._fwd_main_progs[mode] = serial_main_prog.clone()

    def _optimization_tuning(self, mode, dataset, batch_size):
        """Tune the parallel configuration of the train program.

        Builds and plans ``mode``, stores the data-parallel world sizes and
        ranks on ``dataset``, and runs an ``OptimizationTuner``. When
        ``tuning.run_after_tuning`` is set, the tuner's best config replaces
        the dist context's strategy.

        Args:
            mode (str): must be "train".
            dataset: dataset handed to the tuner.
            batch_size (int): batch size handed to the tuner.

        Raises:
            ValueError: if ``tuning.enable`` is not set.
            AssertionError: if ``mode`` is not "train".
        """
        if not self._tuning.enable:
            raise ValueError("Please set `tuning.enable=True`.")

        assert mode == "train"
        self._build(mode)
        self._plan(mode)

        dataset.dp_world_size = self._dp_world_sizes
        dataset.dp_rank = self._dp_ranks

        from .tuner.optimization_tuner import OptimizationTuner

        tuner = OptimizationTuner(
            self._dist_contexts[mode],
            dataset,
            self._inputs_spec,
            self._labels_spec,
            batch_size=batch_size,
            rank=self._cur_rank,
        )
        self._optimization_tuner = tuner
        tuner.tune()

        if self._tuning.run_after_tuning:
            # update the strategy
            best_strategy = tuner.get_best_config()
            self._dist_contexts[mode]._strategy = best_strategy

    def _plan(self, mode):
        """Plan ``mode`` and record per-feed data-parallel split info.

        The first planned mode becomes the reference. Later modes first copy
        its op dist attrs via ``_init_dist_context``. After planning,
        ``self._dp_world_sizes`` and ``self._dp_ranks`` hold one entry per
        feed variable present in the serial main block.

        Args:
            mode (str): "train", "eval" or "predict".
        """
        if self._planned_mode is not None:
            self._init_dist_context(mode)
        else:
            self._planned_mode = mode

        planner = Planner(mode, self._dist_contexts[mode])
        self._planners[mode] = planner
        planner.plan()

        # infer data parallel info
        dist_context = self._dist_contexts[mode]
        feed_vars = dist_context.serial_feed_vars
        candidates = feed_vars["inputs"] + feed_vars["labels"]
        block = dist_context.serial_main_program.global_block()
        # TODO: check this feed_list
        feed_list = [
            block.vars[var.name] for var in candidates if var.name in block.vars
        ]

        self._dp_world_sizes = []
        self._dp_ranks = []
        for feed_var in feed_list:
            world_size, rank = auto_utils.get_input_split_info(
                self._cur_rank, feed_var, dist_context
            )
            self._dp_world_sizes.append(world_size)
            self._dp_ranks.append(rank)

    def _parallel(self, mode, all_ranks=False):
        """Parallelize the planned program of ``mode``.

        Args:
            mode (str): "train", "eval" or "predict".
            all_ranks (bool): if True, run ``parallel_all()``; otherwise only
                parallelize for ``self._cur_rank``.
        """
        # The planner's completer is passed along because it may be needed
        # to complete the annotation of the backward and update parts.
        completer = self._planners[mode].completer
        parallelizer = Parallelizer(mode, completer, self._dist_contexts[mode])
        if all_ranks:
            parallelizer.parallel_all()
        else:
            parallelizer.parallel(self._cur_rank)

    def _init_dist_context(self, mode):
        """Copy op dist attrs from the first planned mode into ``mode``.

        This keeps train/eval/predict on the same parallel strategy. Ops are
        paired by (block index, op index) in the two original serial main
        programs.

        Args:
            mode (str): the mode whose dist context is initialized.

        Raises:
            AssertionError: if paired ops have different types.
        """
        target_ctx = self._dist_contexts[mode]
        target_prog = target_ctx._original_serial_main_program
        ref_mode = self._planned_mode
        ref_ctx = self._dist_contexts[ref_mode]
        ref_blocks = ref_ctx._original_serial_main_program.blocks
        for block_idx, block in enumerate(target_prog.blocks):
            for op_idx, op in enumerate(block.ops):
                ref_op = ref_blocks[block_idx].ops[op_idx]
                assert (
                    op.type == ref_op.type
                ), "'{}' mode op '{}' is different with '{}' op '{}'. ".format(
                    mode, op.type, ref_mode, ref_op.type
                )
                shared_attr = ref_ctx.get_op_dist_attr_for_program(ref_op)
                target_ctx.set_op_dist_attr_for_program(op, shared_attr)

    def _init_comm(self):
767 768 769 770
        if self._nranks > 1:
            # Traverse different rank programs and traverse each op of them,
            # instantiate communication by process_mapping.
            all_process_groups = get_all_process_groups()
771

772
            if self._strategy.auto_mode == "full":
773
                auto_utils.initialize_pg_in_full_mode(
774
                    all_process_groups, self._cur_rank
775
                )
776 777
            else:
                for process_group in all_process_groups:
778
                    if self._cur_rank not in process_group.ranks:
779 780
                        continue
                    process_group.instantiate()
781

782
    def _initialize(self, mode):
        """Set up the place, random seeds, executor and parameters for ``mode``.

        Steps:
          1. Pick the execution place; a ``CUDAPlace`` is re-created with
             ``paddle.distributed.ParallelEnv().dev_id``.
          2. If ``strategy.seed`` is set (not None), seed paddle, numpy and
             ``random`` with ``seed`` offset by the data-parallel rank of the
             first feed var (0 if there is none).
          3. In dygraph mode, initialize ``program_helper`` with this rank's
             dist main program.
          4. On the first call (no executor yet), create the executor, run the
             startup program pruned to the vars not yet initialized in the
             global scope, and apply a pending state dict if one was stored.
          5. If ``strategy.reinit`` is set, run the full startup program again.

        Args:
            mode (str): The mode whose dist_context is used.
        """
        self._place = _get_device()
        if isinstance(self._place, paddle.framework.CUDAPlace):
            self._place = paddle.framework.CUDAPlace(
                paddle.distributed.ParallelEnv().dev_id
            )

        # Compare with None: a truthiness test would silently ignore an
        # explicit ``seed=0``.
        if self._strategy.seed is not None:
            # `_dp_ranks` is empty when no feed var was found; use rank 0
            # instead of failing with IndexError.
            dp_rank = self._dp_ranks[0] if self._dp_ranks else 0
            seed = self._strategy.seed + dp_rank
            paddle.seed(seed)
            np.random.seed(seed)
            random.seed(seed)

        dist_context = self._dist_contexts[mode]
        if self._dygraph_mode:
            dist_main_program = dist_context.dist_main_programs[self._cur_rank]
            self.program_helper.init(
                dist_main_program, self._place, dist_context
            )

        if self._executor is None:
            self._executor = paddle.static.Executor(self._place)
            dist_startup_prog = dist_context.dist_startup_programs[
                self._cur_rank
            ]
            # Only run startup ops for vars that are not already initialized
            # in the global scope, so existing values are not overwritten.
            uninitialized = []
            for var in dist_startup_prog.list_vars():
                scope_var = global_scope().find_var(var.name)
                if scope_var and scope_var.get_tensor()._is_initialized():
                    continue
                uninitialized.append(var)
            if uninitialized:
                prune_startup_prog = dist_startup_prog._prune(uninitialized)
                self._executor.run(prune_startup_prog)

            if hasattr(self, "_state_dict") and hasattr(self, "_dist_attr"):
                self._set_state_dict(
                    mode, self._strict, self._state_dict, self._dist_attr
                )

        if self._strategy.reinit:
            self._logger.info("NOTE: parameters will be re-initialized.")
            dist_startup_prog = dist_context.dist_startup_programs[
                self._cur_rank
            ]
            self._executor.run(dist_startup_prog)

828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845
    def fit(
        self,
        train_data,
        train_sample_split=None,
        batch_size=1,
        epochs=1,
        steps_per_epoch=None,
        log_freq=10,
        save_dir=None,
        save_freq=1,
        valid_data=None,
        valid_sample_split=None,
        valid_freq=1,
        valid_steps=None,
        collate_fn=None,
        callbacks=None,
        verbose=2,
    ):
        """
        Trains the model for a fixed number of epochs. If `valid_data` is set,
        evaluation will be done at the end of every `valid_freq` epochs.

        Args:
            train_data (Dataset): An instance of paddle.io.Dataset.
            train_sample_split (int, optional): Each sample of the train dataset is assumed
                to be a (input, label) pair by default and has two items. If each sample has
                more than two items, train_sample_split specifies how to split these items into
                input and label. The items before it are input and the left are label. Default: None.
            batch_size (int, optional): The batch size of train_data and valid_data if provided.
                The user's data will be used directly without batching if set to None. Default: 1.
            epochs (int, optional): The number of epochs to train the model. Default: 1.
            steps_per_epoch (int, optional): The total number of steps (batches of samples)
                executed in one epoch before starting the next one. If None, it is equal to
                the number of samples in your dataset divided by the batch size. Default: None.
            log_freq (int, optional): The logging frequency, forwarded to the
                callbacks. Default: 10.
            save_dir (str, optional): Forwarded to the callbacks as `save_dir`.
                Default: None.
            save_freq (int, optional): Forwarded to the callbacks as `save_freq`.
                Default: 1.
            valid_data (Dataset, optional): An instance of paddle.io.Dataset used for
                evaluation. No evaluation will be done if it is None (or otherwise
                falsy). Default: None.
            valid_sample_split (int, optional): Only relevant if valid_data is provided.
                Each sample of the valid dataset is assumed to be a (input, label) pair
                by default and has two items. If each sample has more than two items,
                valid_sample_split specifies how to split these items into input and label.
                The items before it are input and the left are label. Default: None.
            valid_freq (int, optional): Only relevant if valid_data is provided. This specifies
                how many training epochs before a new evaluation is performed. Default: 1.
            valid_steps (int, optional): Only relevant if valid_data is provided.
                It is the total number of steps (batches of samples) to draw before
                stopping validation at the end of every epoch. If None, validation will run until the
                `valid_data` dataset is exhausted. The validation will start from the
                beginning of the dataset at each epoch. Default: None.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list, None for only stack each fields of sample in axis
                0. Default None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during training and evaluation. Default: None.
            verbose (int, optional): The verbosity level, forwarded to the callbacks.
                Default: 2.

        Returns:
            The engine's `history` attribute.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                train_dataset = MNIST(mode='train', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                optimizer = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=model.parameters())
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, optimizer, metrics)
                engine.fit(train_dataset,
                           epochs=2,
                           batch_size=64)
        """
        self._mode = 'train'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            train_data, train_sample_split, batch_size
        )
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        train_dataloader = self._prepare_dataloader_from_generator(
            dataset=train_data,
            capacity=70,
            iterable=False,
            batch_size=batch_size,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            collate_fn=collate_fn,
        )

        fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)

        cbks = config_callbacks(
            callbacks,
            engine=self,
            batch_size=batch_size,
            epochs=epochs,
            steps=train_dataloader._steps,
            log_freq=log_freq,
            save_freq=save_freq,
            save_dir=save_dir,
            verbose=verbose,
            metrics=self._metrics_name(),
            acc_step=self._k_steps,
        )

        # Bound before the epoch loop so `on_end` still receives a dict when
        # `epochs` is 0 (otherwise `logs` would be unbound below).
        logs = {}
        cbks.on_begin('train')
        for epoch in range(epochs):
            logs = {}
            cbks.on_epoch_begin(epoch)
            for step, _ in enumerate(train_dataloader):
                cbks.on_batch_begin('train', step, logs)
                try:
                    outs = self._executor.run(
                        self.main_program,
                        fetch_list=fetch_names,
                        use_program_cache=self._strategy.use_cache,
                        return_numpy=self._strategy.return_numpy,
                    )
                except core.EOFException:
                    # The non-iterable reader reports exhaustion this way.
                    break
                lr = auto_utils.get_lr(self._optimizer)
                logs = self._prepare_logger(
                    outs,
                    epoch,
                    step,
                    lr,
                    fetch_names,
                    fetch_indices,
                    self._mode,
                )
                cbks.on_batch_end('train', step, logs)

            if valid_data and (epoch + 1) % valid_freq == 0:
                val_logs = self.evaluate(
                    valid_data,
                    valid_sample_split=valid_sample_split,
                    batch_size=batch_size,
                    steps=valid_steps,
                    log_freq=log_freq,
                    collate_fn=collate_fn,
                    callbacks=callbacks,
                    verbose=verbose,
                )
                logs.update(
                    {"val_" + name: val for name, val in val_logs.items()}
                )
                # evaluate() switched the engine to eval mode; switch back.
                self._switch_mode("train")
            else:
                self._reset_metrics()

            cbks.on_epoch_end(epoch, logs)

        cbks.on_end('train', logs)
        return self.history
997

998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008
    def evaluate(
        self,
        valid_data,
        valid_sample_split=None,
        batch_size=1,
        steps=None,
        log_freq=10,
        collate_fn=None,
        callbacks=None,
        verbose=2,
    ):
        """
        Evaluate the loss and metrics of the model on evaluation data.

        Args:
            valid_data (Dataset): An instance of paddle.io.Dataset.
            valid_sample_split (int, optional): Each sample of the eval dataset is assumed
                to be a (input, label) pair by default and has two items. If each sample has
                more than two items, valid_sample_split specifies how to split these items into
                input and label. The items before it are input and the left are label. Default: None.
            batch_size (int, optional): The batch size of valid_data. The user's data will
                be used directly without batching if set to None. Default: 1.
            steps (int, optional): It is the total number of steps (batches of samples) to draw before
                stopping evaluation. If None, evaluation will run until the `valid_data` dataset is exhausted.
                The evaluation will start from the beginning of the dataset in each run. Default: None.
            log_freq (int, optional): The logging frequency, forwarded to the
                callbacks. Default: 10.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list, None for only stack each fields of sample in axis
                0. Default None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during evaluating. Default: None.
            verbose (int, optional): The verbosity level, forwarded to the
                callbacks. Default: 2.

        Returns:
            dict: The logs built by `_prepare_logger` for the last evaluated
            step, or an empty dict if no step ran. Metrics are reset afterwards.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                valid_dataset = MNIST(mode='test', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, metrics=metrics)
                engine.evaluate(valid_dataset, batch_size=64)

        """
        self._mode = 'eval'
        specs = self._prepare_data_spec(
            valid_data, valid_sample_split, batch_size
        )
        self._inputs_spec, self._labels_spec = specs

        if self._has_prepared[self._mode]:
            self._switch_mode(self._mode)
        else:
            self._prepare_program(self._mode)

        eval_loader = self._prepare_dataloader_from_generator(
            dataset=valid_data,
            capacity=70,
            iterable=False,
            batch_size=batch_size,
            steps_per_epoch=steps,
            collate_fn=collate_fn,
        )
        names, indices = self._prepare_fetch(None, mode=self._mode)

        callback_list = config_callbacks(
            callbacks,
            engine=self,
            batch_size=batch_size,
            log_freq=log_freq,
            verbose=verbose,
            metrics=self._metrics_name(),
        )
        callback_list.on_begin(
            'eval',
            {'steps': eval_loader._steps, 'metrics': self._metrics_name()},
        )

        logs = {}
        for step, _ in enumerate(eval_loader):
            callback_list.on_batch_begin('eval', step, logs)
            try:
                outs = self._executor.run(
                    self.main_program,
                    fetch_list=names,
                    use_program_cache=self._strategy.use_cache,
                    return_numpy=self._strategy.return_numpy,
                )
            except core.EOFException:
                # Stop as soon as the reader reports it is exhausted.
                break
            logs = self._prepare_logger(
                outs, None, step, None, names, indices, self._mode
            )
            callback_list.on_batch_end('eval', step, logs)

        callback_list.on_end('eval', logs)
        self._reset_metrics()
        return logs
1107

1108 1109 1110 1111 1112 1113 1114 1115 1116 1117
    def predict(
        self,
        test_data,
        test_sample_split=None,
        batch_size=1,
        steps=None,
        collate_fn=None,
        callbacks=None,
        verbose=2,
    ):
        """
        Compute the output predictions on testing data.

        Args:
            test_data (Dataset): An instance of paddle.io.Dataset.
            test_sample_split (int, optional): Each sample of the test dataset is assumed
                to be a (input, label) pair by default and has two items. If each sample has
                more than two items, test_sample_split specifies how to split these items into
                input and label. The items before it are input and the left are label. Default: None.
            batch_size (int, optional): The batch size of test_data. The user's data will
                be used directly without batching if set to None. Default: 1.
            steps (int, optional): It is the total number of steps (batches of samples) to draw before
                stopping predict. If None, predict will run until the `test_data` dataset is exhausted.
                The predict will start from the beginning of the dataset in each run. Default: None.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list, None for only stack each fields of sample in axis
                0. Default None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during testing. Default: None.
            verbose (int, optional): The verbosity level, forwarded to the
                callbacks. Default: 2.

        Returns:
            list: One entry per executed step, each being
            `list(logs["outputs"].values())` from `_prepare_logger`.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                valid_dataset = MNIST(mode='test', transform=transform)

                model = paddle.vision.models.LeNet()

                engine = auto.Engine(model)
                engine.predict(valid_dataset, batch_size=64)
        """
        self._mode = 'predict'
        specs = self._prepare_data_spec(
            test_data, test_sample_split, batch_size
        )
        self._inputs_spec, self._labels_spec = specs

        if self._has_prepared[self._mode]:
            self._switch_mode(self._mode)
        else:
            self._prepare_program(self._mode)

        test_loader = self._prepare_dataloader_from_generator(
            dataset=test_data,
            capacity=70,
            iterable=False,
            batch_size=batch_size,
            steps_per_epoch=steps,
            collate_fn=collate_fn,
        )
        names, indices = self._prepare_fetch(None, mode=self._mode)

        results = []
        callback_list = config_callbacks(callbacks, engine=self, verbose=verbose)
        callback_list.on_begin('predict', {'steps': test_loader._steps})

        logs = {}
        for step, _ in enumerate(test_loader):
            callback_list.on_batch_begin('predict', step, logs)
            try:
                outs = self._executor.run(
                    self.main_program,
                    fetch_list=names,
                    use_program_cache=self._strategy.use_cache,
                    return_numpy=self._strategy.return_numpy,
                )
            except core.EOFException:
                # Stop as soon as the reader reports it is exhausted.
                break
            logs = self._prepare_logger(
                outs, None, step, None, names, indices, self._mode
            )
            callback_list.on_batch_end('predict', step, logs)
            results.append(list(logs["outputs"].values()))

        callback_list.on_end('predict', logs)
        return results

1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221
    def dataloader(
        self,
        dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        collate_fn=None,
        num_workers=0,
        use_buffer_reader=True,
        use_shared_memory=True,
        timeout=0,
        worker_init_fn=None,
        epochs=1,
        steps_per_epoch=None,
        sample_split=1,
        mode=None,
    ):
        """Build a distributed dataloader for the current (or given) mode.

        The data specs are derived from ``dataset`` and ``sample_split``. The
        program for the mode is prepared on first use and switched to
        otherwise. The loader is created with ``return_list=False``.

        Args:
            dataset: The dataset to load from.
            mode (str, optional): If given, switch to this mode first via
                ``to_mode``. Default: None.
            The remaining arguments are forwarded to ``_prepare_dataloader``
            (``sample_split`` is only used to build the data specs).

        Returns:
            The loader returned by ``_prepare_dataloader``.
        """
        if mode is not None:
            self.to_mode(mode)
        specs = self._prepare_data_spec(dataset, sample_split, batch_size)
        self._inputs_spec, self._labels_spec = specs

        if self._has_prepared[self._mode]:
            self._switch_mode(self._mode)
        else:
            self._prepare_program(self._mode)

        return self._prepare_dataloader(
            dataset,
            return_list=False,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            collate_fn=collate_fn,
            num_workers=num_workers,
            use_buffer_reader=use_buffer_reader,
            use_shared_memory=use_shared_memory,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
        )

1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263
    def dataloader_from_generator(
        self,
        dataset,
        capacity=70,
        use_double_buffer=True,
        iterable=True,
        use_multiprocess=False,
        drop_last=True,
        batch_size=1,
        epochs=1,
        steps_per_epoch=None,
        collate_fn=None,
        sample_split=1,
        mode=None,
    ):
        """Build a generator-based distributed dataloader for the current mode.

        The data specs are derived from ``dataset`` and ``sample_split``. The
        program for the mode is prepared on first use and switched to
        otherwise. The loader is created with ``return_list=False``.

        Args:
            dataset: The dataset to load from.
            mode (str, optional): If given, switch to this mode first via
                ``to_mode``. Default: None.
            The remaining arguments are forwarded to
            ``_prepare_dataloader_from_generator`` (``sample_split`` is only
            used to build the data specs).

        Returns:
            The loader returned by ``_prepare_dataloader_from_generator``.
        """
        if mode is not None:
            self.to_mode(mode)
        specs = self._prepare_data_spec(dataset, sample_split, batch_size)
        self._inputs_spec, self._labels_spec = specs

        if self._has_prepared[self._mode]:
            self._switch_mode(self._mode)
        else:
            self._prepare_program(self._mode)

        return self._prepare_dataloader_from_generator(
            dataset=dataset,
            capacity=capacity,
            use_double_buffer=use_double_buffer,
            iterable=iterable,
            return_list=False,
            use_multiprocess=use_multiprocess,
            drop_last=drop_last,
            batch_size=batch_size,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            collate_fn=collate_fn,
        )

1289 1290 1291 1292 1293 1294 1295 1296 1297
    def prepare(
        self,
        inputs_spec=None,
        labels_spec=None,
        inputs=None,
        labels=None,
        main_program=None,
        startup_program=None,
        mode=None,
        init_parameters=True,
    ):
        """Prepare (build, plan and parallelize) the program for a mode.

        Three ways of describing the data are supported:
          * ``inputs`` / ``labels`` tensors: program building is skipped
            (``_skip_build``) and the tensors go through ``_prepare_data_tensor``.
          * ``inputs_spec`` / ``labels_spec`` only: the engine expects an
            outside dataloader (``_outside_dataloader``).
          * neither: the specs recorded by an earlier ``dataloader(...)`` call
            are reused.
        In the first two cases ``main_program`` / ``startup_program`` default
        to the static default programs.

        Args:
            inputs_spec, labels_spec: Specs, validated by ``_validate_spec``.
            inputs, labels: Tensors, validated by ``_validate_vars``.
            main_program, startup_program: Optional original programs.
            mode (str, optional): If given, switch to this mode first.
            init_parameters (bool, optional): Forwarded to
                ``_prepare_program``. Default: True.

        Raises:
            ValueError: If no mode is set.
            AssertionError: If no data description is given and no specs were
                recorded earlier.
        """
        if mode is not None:
            self.to_mode(mode)

        if not self._mode:
            raise ValueError(
                "Please set mode to be prepared with `prepare(mode=...)`"
            )

        if self._has_prepared[self._mode]:
            return

        inputs_spec = self._validate_spec(inputs_spec)
        labels_spec = self._validate_spec(labels_spec)
        inputs = self._validate_vars(inputs)
        labels = self._validate_vars(labels)

        self._orig_main_prog = main_program
        self._orig_startup_prog = startup_program

        use_given_specs = True
        if inputs or labels:
            self._skip_build = True
            inputs, labels = self._prepare_data_tensor(
                inputs_spec, labels_spec, inputs, labels
            )
        elif inputs_spec or labels_spec:
            self._outside_dataloader = True
        else:
            # Reuse the specs recorded by an earlier dataloader(...) call.
            # They must not be overwritten with the empty arguments here,
            # which the previous unconditional assignment did.
            use_given_specs = False
            assert (
                self._inputs_spec and self._labels_spec
            ), "Please call the dataloader(...) before calling prepare(...)"

        if use_given_specs:
            if self._orig_main_prog is None:
                self._orig_main_prog = static.default_main_program()
            if self._orig_startup_prog is None:
                self._orig_startup_prog = static.default_startup_program()
            self._inputs_spec, self._labels_spec = inputs_spec, labels_spec

        self._inputs, self._labels = inputs, labels
        # `_has_prepared[self._mode]` is False here (early return above), so
        # the program always needs to be prepared.
        self._prepare_program(self._mode, init_parameters)

1345
    def run(self, data=None, feed=None, fetch_list=None, mode=None):
        """Execute the main program of the current mode once.

        Args:
            data: Forwarded to ``_prepare_feed`` together with ``feed``.
            feed: Forwarded to ``_prepare_feed`` together with ``data``.
            fetch_list: Fetch targets, forwarded to ``_prepare_fetch``.
            mode (str, optional): If given, switch to this mode first via
                ``to_mode``. Default: None.

        Returns:
            The logs built by ``_prepare_logger`` from the fetched outputs.
        """
        if mode is not None:
            self.to_mode(mode)
        feed_dict = self._prepare_feed(data, feed, self._mode)
        names, indices = self._prepare_fetch(fetch_list, self._mode)

        # With an outside dataloader, set up the reader unless it is already
        # marked as prepared for this mode.
        if self._outside_dataloader:
            if not self._has_prepared_reader[self._mode]:
                self._prepare_reader()

        outs = self._executor.run(
            self.main_program,
            feed=feed_dict,
            fetch_list=names,
            use_program_cache=self._strategy.use_cache,
            return_numpy=self._strategy.return_numpy,
        )
        return self._prepare_logger(
            outs, None, None, None, names, indices, self._mode
        )
1366

1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382
    def _prepare_dataloader(
        self,
        dataset,
        return_list=True,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        collate_fn=None,
        num_workers=0,
        use_buffer_reader=True,
        use_shared_memory=True,
        timeout=0,
        worker_init_fn=None,
        epochs=1,
        steps_per_epoch=None,
    ):
        """Create a ``DistributedDataLoader`` for this rank's dist program.

        With gradient merge, ``batch_size`` must be divisible by
        ``self._k_steps`` and is divided by it before reaching the loader.
        The feed list mirrors the serial program's feed vars (inputs, then
        labels). Vars missing from the dist main program, e.g. labels in
        predict mode, are cloned into it. This keeps the feed list as long as
        the dataset's values.

        Returns:
            DistributedDataLoader: The loader, created under a program guard
            of this rank's dist main/startup programs.

        Raises:
            AssertionError: If gradient merge is on and ``batch_size`` is not
                divisible by ``k_steps``.
        """
        if self._strategy.gradient_merge and batch_size is not None:
            assert (
                batch_size % self._k_steps == 0
            ), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
                batch_size, self._k_steps
            )
            batch_size = batch_size // self._k_steps

        ctx = self._dist_contexts[self._mode]
        main_prog = ctx.dist_main_programs[self._cur_rank]
        startup_prog = ctx.dist_startup_programs[self._cur_rank]
        main_block = main_prog.global_block()

        def _feed_var(var):
            # Reuse the var if the dist program already has it; otherwise
            # clone it from the serial program and keep its original id.
            if var.name in main_block.vars:
                return main_block.vars[var.name]
            cloned = main_block._clone_variable(var, var.persistable)
            cloned.desc.set_original_id(var.desc.original_id())
            return cloned

        serial_inputs = ctx.serial_feed_vars["inputs"]
        serial_labels = ctx.serial_feed_vars["labels"]
        feed_list = [_feed_var(var) for var in serial_inputs + serial_labels]

        # The read op is inserted at the end of the dist main program.
        places = paddle.static.cuda_places()
        with static.program_guard(main_prog, startup_prog):
            loader = DistributedDataLoader(
                dataset,
                feed_list=feed_list,
                places=places,
                return_list=return_list,
                batch_size=batch_size,
                shuffle=shuffle,
                drop_last=drop_last,
                collate_fn=collate_fn,
                num_workers=num_workers,
                use_buffer_reader=use_buffer_reader,
                use_shared_memory=use_shared_memory,
                timeout=timeout,
                worker_init_fn=worker_init_fn,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                split_data=self._strategy.split_data,
                data_parallel_world_size=self._dp_world_sizes,
                data_parallel_rank=self._dp_ranks,
            )
        return loader

1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451
    def _prepare_dataloader_from_generator(
        self,
        dataset,
        capacity=None,
        use_double_buffer=True,
        iterable=True,
        return_list=False,
        use_multiprocess=False,
        drop_last=True,
        batch_size=1,
        epochs=1,
        steps_per_epoch=None,
        collate_fn=None,
    ):
        """Create a ``DistributedDataLoaderFromGenerator`` for this rank.

        With gradient merge, ``batch_size`` must be divisible by
        ``self._k_steps`` and is divided by it before reaching the loader.
        The feed list mirrors the serial program's feed vars (inputs, then
        labels). Vars missing from the dist main program, e.g. labels in
        predict mode, are cloned into it. After the loader is built, the feed
        list is passed to ``_prepare_reader``.

        Returns:
            DistributedDataLoaderFromGenerator: The loader, created under a
            program guard of this rank's dist main/startup programs.

        Raises:
            AssertionError: If gradient merge is on and ``batch_size`` is not
                divisible by ``k_steps``.
        """
        if self._strategy.gradient_merge and batch_size is not None:
            assert (
                batch_size % self._k_steps == 0
            ), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
                batch_size, self._k_steps
            )
            batch_size = batch_size // self._k_steps

        ctx = self._dist_contexts[self._mode]
        main_prog = ctx.dist_main_programs[self._cur_rank]
        startup_prog = ctx.dist_startup_programs[self._cur_rank]
        main_block = main_prog.global_block()

        def _feed_var(var):
            # Reuse the var if the dist program already has it; otherwise
            # clone it from the serial program and keep its original id.
            if var.name in main_block.vars:
                return main_block.vars[var.name]
            cloned = main_block._clone_variable(var, var.persistable)
            cloned.desc.set_original_id(var.desc.original_id())
            return cloned

        serial_inputs = ctx.serial_feed_vars["inputs"]
        serial_labels = ctx.serial_feed_vars["labels"]
        feed_list = [_feed_var(var) for var in serial_inputs + serial_labels]

        places = paddle.static.cuda_places()
        with static.program_guard(main_prog, startup_prog):
            loader = DistributedDataLoaderFromGenerator(
                dataset=dataset,
                feed_list=feed_list,
                capacity=capacity,
                use_double_buffer=use_double_buffer,
                iterable=iterable,
                return_list=return_list,
                use_multiprocess=use_multiprocess,
                drop_last=drop_last,
                places=places,
                batch_size=batch_size,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                collate_fn=collate_fn,
                split_data=self._strategy.split_data,
                data_parallel_world_size=self._dp_world_sizes,
                data_parallel_rank=self._dp_ranks,
            )
        self._prepare_reader(feed_list)
        return loader

    def _tune(self, tune_data, tune_sample_split=None, batch_size=1):
        self._mode = 'train'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
1507 1508
            tune_data, tune_sample_split, batch_size
        )
1509 1510
        self._optimization_tuning(self._mode, tune_data, batch_size)

1511
    def _validate_spec(self, specs):
        """Validate ``InputSpec`` objects and adapt them for gradient merge.

        Also caches ``self._strategy.gradient_merge.k_steps`` in
        ``self._k_steps``.

        Args:
            specs (InputSpec|list[InputSpec]|None): The specs to check. They
                are normalized with ``auto_utils.to_list``.

        Returns:
            list[InputSpec]: The validated specs, or an empty list when
            ``specs`` is None.

            When ``k_steps > 1``, each spec with a static, non-negative batch
            dimension is returned as a copy whose leading dimension is
            divided by ``k_steps``. The caller's spec objects are never
            mutated, so validating the same specs twice does not divide the
            batch dimension twice.

            Specs with a dynamic batch dimension (``None`` or negative, e.g.
            ``-1``) or with no dimensions are returned unchanged.

        Raises:
            TypeError: If an element is not a ``paddle.static.InputSpec``.
            ValueError: If an element has no ``name``.
        """
        specs = auto_utils.to_list(specs)
        self._k_steps = self._strategy.gradient_merge.k_steps
        if specs is None:
            return []

        validated = []
        for i, spec in enumerate(specs):
            if not isinstance(spec, InputSpec):
                raise TypeError(
                    "'spec' must be object of class `paddle.static.InputSpec`."
                )
            if spec.name is None:
                raise ValueError(
                    "Requires Input[{}].name != None, but receive `None` with {}.".format(
                        i, spec
                    )
                )
            if self._k_steps > 1:
                shape = list(spec.shape)
                batch = shape[0] if shape else None
                # A dynamic batch dim (None / -1) has no fixed size to split.
                # The micro-batch is dynamic as well, so keep the dim as-is
                # instead of failing on `None % k_steps`.
                if batch is not None and batch >= 0:
                    assert (
                        batch % self._k_steps == 0
                    ), "Requires batch_size[{}] to be divisible by k_steps[{}].".format(
                        batch, self._k_steps
                    )
                    shape[0] = batch // self._k_steps
                    # Copy rather than mutate the caller's spec. Otherwise
                    # re-validating it would shrink the batch dim again.
                    spec = copy.copy(spec)
                    spec.shape = shape
            validated.append(spec)
        return validated

    def _validate_vars(self, vars):
        """Check that every element of ``vars`` is a static ``Variable``.

        Args:
            vars (Variable|list[Variable]|None): Normalized with
                ``auto_utils.to_list`` before checking.

        Returns:
            list: The normalized list, or ``[]`` when it is empty or None.

        Raises:
            TypeError: If any element is not a ``Variable``.
        """
        var_list = auto_utils.to_list(vars)
        if var_list is not None:
            for candidate in var_list:
                if isinstance(candidate, Variable):
                    continue
                raise TypeError("'var' must be a `Variable`.")
        return var_list or []
1544

1545 1546 1547 1548
    def _is_local_var(self, var):
        """Return True if ``var`` (or a var name) is in the main program's global block."""
        name = _to_name_str(var)
        global_vars = self.main_program.global_block().vars
        return name in global_vars

1549 1550 1551 1552
    def _reset_metrics(self):
        for metric in self._metrics:
            metric.reset()

Z
zhaoyingli 已提交
1553 1554 1555
    def _metrics_name(self):
        """Return the names of the values reported during a run.

        ``'loss'`` comes first when the engine has a (truthy) loss. It is
        followed by the name(s) returned by each metric's ``name()``, in
        registration order.
        """
        names = []
        if self._loss:
            names.append('loss')
        for metric in self._metrics:
            names += auto_utils.to_list(metric.name())
        return names

1559
    def _switch_mode(self, mode):
1560
        assert (
1561
            mode in self._dist_contexts
1562
        ), f"{mode} model is not ready, please call `prepare()` first."
1563
        self.to_mode(mode)
Z
zhaoyingli 已提交
1564
        self._optimizer = self._dist_contexts[mode]._serial_optimizer
1565

1566
    def to_mode(self, mode):
        """Set the engine's active mode to one of 'train', 'eval' or 'predict'."""
        allowed_modes = ("train", "eval", "predict")
        assert (
            mode in allowed_modes
        ), f"mode {mode} should be one of ['train', 'eval', 'predict']"
        self._mode = mode

1574 1575
    def _set_state_dict(self, mode, strict, state_dict, dist_attr):
        """Convert loaded states to ``mode``'s distribution and apply them.

        Args:
            mode (str): Mode whose distributed context and current-rank main
                program receive the states.
            strict (bool): Forwarded to ``Converter.convert``.
            state_dict (dict): Loaded states, looked up by the names in the
                program's own ``state_dict()``.
            dist_attr: Distributed attributes the states were saved with
                (presumably as returned by the saver's ``load``).

        A converted value whose dtype differs from the program's current
        parameter is cast to the parameter's dtype, with an info log, before
        ``program.set_state_dict`` is called.
        """
        dist_context = self._dist_contexts[mode]
        program = dist_context.dist_main_programs[self._cur_rank]
        cur_dist_attr = auto_utils.get_dist_attr(program, dist_context)
        converter = Converter(state_dict, dist_attr, cur_dist_attr)
        state_dict = converter.convert(strict=strict)
        for name, param in program.state_dict().items():
            # Check membership first. np.array(param) copies the whole
            # parameter into a NumPy array, which is wasted work for
            # parameters that are not being restored.
            if name not in state_dict:
                continue
            loaded = state_dict[name]
            target_dtype = np.array(param).dtype
            if loaded.dtype != target_dtype:
                self._logger.info(
                    "cast {}'s dtype from '{}' to '{}'".format(
                        name,
                        str(loaded.dtype),
                        str(target_dtype),
                    )
                )
                state_dict[name] = loaded.astype(target_dtype)
        program.set_state_dict(state_dict)

    def save(self, path, training=True):
        """
        Save the model, parameters and optimizer state to ``path``.

        When ``training`` is False, only the inference model is saved.

        Args:
            path (str): The file prefix to save the model to, in the form
                'dirname/file_prefix' or 'file_prefix'. If it is an empty
                string, an exception will be raised.
            training (bool, optional): Whether to save for training. If True,
                the optimizer state is saved too. Otherwise, only the model
                and parameters are saved for inference. Existing files at the
                target location are silently overwritten. Default: True.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                train_dataset = MNIST(mode='train', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                optimizer = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=model.parameters())
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, optimizer, metrics)
                engine.fit(train_dataset,
                           epochs=1,
                           batch_size=64)
                engine.save("./my_model")

        """
        if training:
            # Training checkpoint: delegate to the saver, which receives both
            # the serial program and this rank's distributed program.
            assert self._mode in self._dist_contexts
            context = self._dist_contexts[self._mode]
            self._saver.save(
                path,
                serial_program=context.serial_main_program,
                dist_main_program=context.dist_main_programs[self._cur_rank],
                dist_context=context,
            )
            return

        # Inference export always uses the 'predict' context.
        assert "predict" in self._dist_contexts
        context = self._dist_contexts["predict"]
        inputs = context.serial_feed_vars['inputs']
        outputs = context.serial_fetch_vars['outputs']
        program = context.dist_main_programs[self._cur_rank]
        qat = self._strategy.qat
        if qat.enable and qat.onnx_format:
            from paddle.static.quantization import QuantWeightPass

            self._logger.info("export quantized model.")
            self._logger.info(f"convert config {qat.to_dict()}")
            graph = IrGraph(core.Graph(program.desc), for_test=True)
            weight_pass = QuantWeightPass(global_scope(), self._place)
            for sub_graph in graph.all_sub_graphs():
                weight_pass.apply(sub_graph)
            program = graph.to_program()
        self._saver.save_inference_model(
            path,
            inputs,
            outputs,
            self._executor,
            program=program,
        )
1678

1679 1680 1681 1682 1683 1684
    def load(self, path, strict=True, load_optimizer=True):
        """
        Load the stored model, parameters and optimizer states.

        This method only reads the stored states through the saver and keeps
        them on the engine (``self._state_dict``, ``self._dist_attr`` and
        ``self._strict``). It does not apply them to a program itself.

        Args:
            path (str): The prefix of the files storing the model states and
                optimizer states.
            strict (bool, optional): Whether to raise an error or skip the
                mismatched parameters when a mismatch happens (a parameter is
                not found in the stored model states, or its shape does not
                match). The value is stored as ``self._strict``. Default: True.
            load_optimizer (bool, optional): If True, the stored optimizer
                states are restored. Otherwise, the optimizer states are
                initialized from scratch. Default: True.

        Returns:
            tuple: ``(state_dict, dist_attr)`` as returned by the saver's
            ``load``.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                train_dataset = MNIST(mode='train', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                optimizer = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=model.parameters())
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, optimizer, metrics)
                engine.fit(train_dataset,
                           epochs=1,
                           batch_size=64)
                engine.save("./my_model")
                engine.load("./my_model")

        """
        self._strict = strict
        self._state_dict, self._dist_attr = self._saver.load(
            path, load_optimizer
        )
        return self._state_dict, self._dist_attr
1730

1731
    def cost(self, inputs_spec=None, labels_spec=None, mode=None):
        """
        Estimate the cost of one step for the given mode.

        The estimate comes from ``get_cost_from_engine``. It covers the memory
        of every rank, the max memory among all ranks, and the global cost of
        one step based on communication cost (computation cost is 0 by
        default). In the future, the flops information of every rank and a
        global cost including computation cost will be added.

        Args:
            inputs_spec (InputSpec|list[InputSpec], optional): The
                specification of inputs. When given and ``mode`` has not been
                prepared yet, the model is built and planned for ``mode``
                first. Default: None.
            labels_spec (InputSpec|list[InputSpec], optional): The
                specification of labels. Default: None.
            mode (str, optional): The engine mode, one of "train", "predict"
                or "eval". Defaults to the engine's current mode.

        Returns:
            tuple: The global execution time (ms) and the max memory (B).
            Returns None without estimating when ``strategy.auto_mode`` is
            "full", because the cost is then computed during the search
            process.

        Raises:
            ValueError: If ``mode`` is not an accepted mode, or if there is no
                program to estimate. This happens in dygraph mode without a
                prior ``prepare()``, or when the default static program is
                empty and ``mode`` has not been prepared.
        """
        # Check parallel mode
        if self._strategy.auto_mode == "full":
            self._logger.info(
                "The cost will be calculated in the search process when the auto mode is full."
            )
            return

        # Check mode
        mode = mode if mode is not None else self._mode
        assert mode is not None, "Please set mode."
        if mode not in self._has_prepared:
            raise ValueError(
                "The mode {} is not in accepted modes {}".format(
                    mode, list(self._has_prepared.keys())
                )
            )
        self.to_mode(mode)

        if inputs_spec is not None and not self._has_prepared[mode]:
            self._inputs_spec = self._validate_spec(inputs_spec)
            self._labels_spec = self._validate_spec(labels_spec)
            self._build(mode)
            self._plan(mode)
        else:
            if in_dygraph_mode() or self._dygraph_mode:
                raise ValueError(
                    "Please call `prepare()` or `fit()` or  `evaluate()` or  `predict()` before calling `cost()`."
                )
            else:
                self._logger.info(
                    "The program whose cost to be estimated must be static default program. Otherwise, please call `prepare()` before calling `cost()`."
                )
                program = paddle.static.default_main_program()
                # The original condition tested `not ops` twice (copy-paste).
                # A single check is logically identical.
                if (
                    not program.global_block().ops
                    and not self._has_prepared[mode]
                ):
                    raise ValueError(
                        "Please call `prepare()` or `fit()` or  `evaluate()` or  `predict()` before calling `cost()`."
                    )

        # Estimate the exec cost and max memory
        global_cost, max_memory = get_cost_from_engine(self, mode)

        return global_cost.time, max_memory

1794 1795
    @property
    def main_program(self):
        """The distributed main program of the current rank for the active mode."""
        return self._dist_contexts[self._mode].dist_main_programs[self._cur_rank]
1798 1799 1800

    @property
    def startup_program(self):
        """The distributed startup program of the current rank for the active mode."""
        return self._dist_contexts[self._mode].dist_startup_programs[self._cur_rank]
1803 1804 1805

    @property
    def dist_context(self):
        """The ``DistributedContext`` registered for the active mode."""
        active_mode = self._mode
        return self._dist_contexts[active_mode]
1807 1808 1809

    @property
    def serial_main_program(self):
        """The serial (non-distributed) main program of the active mode."""
        return self._dist_contexts[self._mode].serial_main_program
1812 1813 1814

    @property
    def serial_startup_program(self):
        """The serial (non-distributed) startup program of the active mode."""
        return self._dist_contexts[self._mode].serial_startup_program

    @property
    def feed_vars(self):
        """The serial feed variables of the active mode (e.g. ``'inputs'``/``'labels'``)."""
        return self._dist_contexts[self._mode].serial_feed_vars
1822 1823 1824

    @property
    def fetch_vars(self):
        """The serial fetch variables of the active mode."""
        return self._dist_contexts[self._mode].serial_fetch_vars

    @property
    def optimizer(self):
        """The active mode's serial optimizer if it is truthy, otherwise the engine's own."""
        serial_optimizer = self._dist_contexts[self._mode]._serial_optimizer
        return serial_optimizer or self._optimizer
1834 1835 1836

    @property
    def inputs(self):
        """The engine's input variables (as stored in ``self._inputs``)."""
        stored_inputs = self._inputs
        return stored_inputs
1838 1839 1840

    @property
    def labels(self):
        """The engine's label variables (as stored in ``self._labels``)."""
        stored_labels = self._labels
        return stored_labels