engine.py 67.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os
16
import logging
17 18
import random
import numpy as np
19 20 21
from collections import defaultdict

import paddle
22
import paddle.utils as utils
23

Z
zhaoyingli 已提交
24
from paddle import fluid, static
25
from paddle.metric import Metric
26
from paddle.static import InputSpec
27
from paddle.fluid import core
28
from paddle.fluid import Variable
29
from paddle.fluid.layers.utils import flatten
30
from paddle.fluid.executor import global_scope, _to_name_str
31
from paddle.fluid.framework import Operator, _non_static_mode
32 33
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph.parallel import ParallelEnv
34
from paddle.distributed import fleet
35

Z
zhaoyingli 已提交
36
from .callbacks import config_callbacks
37
from .converter import Converter
38
from .helper import ProgramHelper
39
from .cluster import Cluster, get_default_cluster
40 41
from .planner_v2 import Planner
from .parallelizer_v2 import Parallelizer
42 43
from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver
44
from .dist_loader import DistributedDataLoaderFromGenerator, DistributedDataLoader
Z
zhaoyingli 已提交
45
from .utils import to_list, get_dist_attr, get_lr
46
from .process_group import new_process_group, get_all_process_groups
47
from .dist_context import DistributedContext, get_default_distributed_context
48
from .strategy import Strategy
49
from .interface import CollectionNames, get_collection
Z
zhaoyingli 已提交
50
from ..utils.log_utils import get_logger
51 52
from .utils import initialize_pg_in_full_mode
from .cost.estimate_cost import get_cost_from_engine
53 54 55


class Engine:
    """
    An Engine object can provide the full power of auto parallel to users.
    With the help of it, users can easily obtain the abilities of the
    distributed training and inference. It also supports the dynamic graph and
    static graph at the same time.

    Args:
        model (paddle.nn.Layer, optional): The model is an instance of
            paddle.nn.Layer.
        loss (Loss|Callable|None, optional): The loss can be a `paddle.nn.Layer`
            instance or any callable function taking the predicted values and
            ground truth values as input. It can be None when there is no loss.
            Default: None.
        optimizer (Optimizer|None, optional): The optimizer needs to be set in training
            and should be None in eval and predict mode. Default: None.
        metrics (Metric|list[Metric]|None, optional): If metrics is set, all
            metrics will be calculated and output in train/eval mode. Default: None.
        cluster (Cluster|None, optional): The cluster represents the topology information
            about the used physical devices. Default: None. (Unused for now)
        strategy (Strategy|None, optional): The strategy is used to configure the
            parallelization and optimization behaviors. Default: None.

    Examples:

        .. code-block:: python

            import paddle
            import paddle.vision.transforms as T
            from paddle.distributed.fleet import auto
            from paddle.vision.datasets import MNIST

            transform = T.Compose([
                T.Transpose(),
                T.Normalize([127.5], [127.5])
            ])
            train_dataset = MNIST(mode='train', transform=transform)
            valid_dataset = MNIST(mode='test', transform=transform)

            model = paddle.vision.models.LeNet()
            loss = paddle.nn.CrossEntropyLoss()
            optimizer = paddle.optimizer.Adam(
                learning_rate=0.001, parameters=model.parameters())
            metrics = paddle.metric.Accuracy(topk=(1, 2))

            engine = auto.Engine(model, loss, optimizer, metrics)
            # fit
            engine.fit(train_dataset,
                       epochs=2,
                       batch_size=64)
            # evaluate
            engine.evaluate(valid_dataset,
                            batch_size=64)
            # predict
            engine.predict(valid_dataset,
                           batch_size=64)
            # save
            engine.save("./my_model")
            # load
            engine.load("./my_model")

    """

    def __init__(self,
                 model=None,
                 loss=None,
                 optimizer=None,
                 metrics=None,
                 cluster=None,
                 strategy=None):
        """Validate the user components and set up per-mode bookkeeping.

        Args:
            model: a ``paddle.nn.Layer`` or any callable, or None.
            loss: a loss layer/callable, or None.
            optimizer: a ``paddle.optimizer.Optimizer`` or
                ``paddle.fluid.optimizer.Optimizer``, or None.
            metrics: a ``Metric`` or list of ``Metric``, or None.
            cluster: a ``Cluster``; defaults to ``get_default_cluster()``.
            strategy: a ``Strategy``; defaults to ``Strategy()``.

        Raises:
            TypeError: if model, optimizer, cluster or strategy has the wrong type.
            AssertionError: if an entry of ``metrics`` is not a ``Metric``.
        """
        if model and not isinstance(model,
                                    paddle.nn.Layer) and not callable(model):
            raise TypeError(
                "'model' must be sub classes of `paddle.nn.Layer` or any callable function."
            )
        self._model = model
        self._loss = loss

        if optimizer and not isinstance(
                optimizer,
            (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)):
            raise TypeError(
                "'optimizer' must be object of class `paddle.optimizer.Optimizer`"
                " or `paddle.fluid.optimizer.Optimizer`.")
        self._optimizer = self._validate_opt(optimizer)

        metrics = metrics or []
        for metric in to_list(metrics):
            assert isinstance(metric, Metric), \
                "{} is not sub class of Metric".format(
                    metric.__class__.__name__)
        self._metrics = to_list(metrics)

        if cluster and not isinstance(cluster, Cluster):
            raise TypeError(
                "'cluster' must be object of class `paddle.distributed.auto_parallel.Cluster`"
            )
        self._cluster = cluster or get_default_cluster()

        if strategy and not isinstance(strategy, Strategy):
            raise TypeError(
                "'strategy' must be object of class `paddle.distributed.auto_parallel.Strategy`"
            )
        self._strategy = strategy or Strategy()

        # A POD_NAME environment variable is taken as the sign that the job
        # was started by `paddle.distributed.launch`; fleet is then initialized
        # in collective mode.
        if os.getenv("POD_NAME"):
            print("Distribute training by paddle.distributed.launch",
                  flush=True)
            fleet.init(is_collective=True)

        # The executor is created lazily in `_initialize`.
        self._executor = None
        self._cur_rank = paddle.distributed.get_rank()
        self._nranks = paddle.distributed.get_world_size()
        self._saver = DistributedSaver()

        self._logger = get_logger(logging.INFO)

        # Snapshot of the default programs/context at construction time; the
        # static build clones these.
        self._orig_main_prog = static.default_main_program()
        self._orig_startup_prog = static.default_startup_program()
        self._orig_dist_context = get_default_distributed_context()

        # Per-mode ("train"/"eval"/"predict") state.
        self._dist_contexts = {}
        self._serial_main_progs = {}
        self._serial_startup_progs = {}
        self._dist_main_progs = defaultdict(dict)  # dist main programs
        self._dist_startup_progs = defaultdict(dict)  # dist startup programs
        self._feed_vars = {}
        self._fetch_vars = {}
        self._planners = {}
        self._has_prepared = {"train": False, "eval": False, "predict": False}
        self._has_prepared_reader = {
            "train": False,
            "eval": False,
            "predict": False
        }

        self._inputs_spec = []
        self._labels_spec = []
        self._inputs = []
        self._labels = []

        self._skip_build = False
        self._outside_dataloader = False
        # The first mode passed through `_plan`; later modes copy its
        # dist attrs in `_init_dist_context`.
        self._planned_mode = None
        self._dygraph_mode = False
        self._tuning = self._strategy.tuning
        self._losses = None

        self.history = None

204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228
    def _prepare_data_spec(self, data, split, batch_size):
        """Infer ``InputSpec`` lists for inputs and labels from one sample of ``data``.

        Args:
            data (paddle.io.Dataset|paddle.io.IterableDataset): the dataset.
                Its first sample (``data[0]`` or ``next(iter(data))``) is used
                to infer shapes, dtypes and names.
            split (int|None): if None, each sample is unpacked as an
                ``(inputs, labels)`` pair; otherwise the items before ``split``
                are inputs and the remaining items are labels.
            batch_size (int|None): if set, it is prepended as the batch
                dimension via ``InputSpec.batch``.

        Returns:
            tuple: ``(inputs_spec, labels_spec)`` as returned by
            ``self._validate_spec``.

        Raises:
            ValueError: if ``data`` is neither a Dataset nor an IterableDataset.
        """
        inputs_spec = []
        labels_spec = []
        # Checked before Dataset: an IterableDataset sample is read by
        # iteration, not by indexing.
        if isinstance(data, paddle.io.IterableDataset):
            sample = next(iter(data))
        elif isinstance(data, paddle.io.Dataset):
            sample = data[0]
        else:
            raise ValueError(
                "Data should be a Dataset or IterableDataset, but received {}.".
                format(type(data).__name__))
        if split is None:
            inputs, labels = sample
        else:
            inputs, labels = sample[:split], sample[split:]
        inputs = to_list(inputs)
        labels = to_list(labels)

        num_shards = self._strategy.dataset.num_shards

        def _adjust_item_spec(num_shards, spec):
            # With a sharded dataset, scale the leading dim of specs with
            # rank > 1 by the number of shards.
            if num_shards > 1 and len(spec.shape) > 1:
                spec.shape[0] = spec.shape[0] * num_shards

        def _infer_item_spec(item, name, batch_size, specs):
            if isinstance(item, np.ndarray):
                spec = InputSpec.from_numpy(item, name)
                if batch_size is None:
                    # Shard scaling only applies to unbatched numpy specs.
                    _adjust_item_spec(num_shards, spec)
                    specs.append(spec)
                else:
                    specs.append(spec.batch(batch_size))
            elif isinstance(item, (Variable, core.VarBase, core.eager.Tensor)):
                spec = InputSpec.from_tensor(item, name)
                # Tensor specs are shard-scaled whether or not they are batched.
                _adjust_item_spec(num_shards, spec)
                if batch_size is None:
                    specs.append(spec)
                else:
                    specs.append(spec.batch(batch_size))
            else:
                # Any other item gets a 1-D spec of length `batch_size`,
                # passing `type(item)` as the dtype.
                specs.append(InputSpec([batch_size], type(item), name))

        if inputs is not None:
            for i, item in enumerate(inputs):
                assert item is not None, "Receive None input."
                _infer_item_spec(item, "input" + str(i), batch_size,
                                 inputs_spec)
        if labels is not None:
            for i, item in enumerate(labels):
                assert item is not None, "Receive None label."
                _infer_item_spec(item, "label" + str(i), batch_size,
                                 labels_spec)

        inputs_spec = self._validate_spec(inputs_spec)
        labels_spec = self._validate_spec(labels_spec)
        return inputs_spec, labels_spec

    def _prepare_data_tensor(self,
                             inputs_spec,
                             labels_spec,
                             inputs=None,
                             labels=None):
        """Create (or reshape) the static feed variables for inputs and labels.

        In dynamic-graph mode nothing is created and ``(None, None)`` is
        returned. Otherwise, for each of inputs/labels with a non-empty spec
        list: if no variables were given, one feed layer is created per spec;
        if variables were given, any whose shape differs from its spec has its
        desc shape overwritten with the spec shape.

        Returns:
            tuple: ``(inputs, labels)``; an entry is passed through unchanged
            when its spec list is empty.
        """
        if _non_static_mode() or self._dygraph_mode:
            return None, None

        def _materialize(kind, specs, tensors):
            # An empty/None spec list leaves the tensors untouched.
            if not specs:
                return tensors
            assert isinstance(specs, list), \
                "{} should be list, but received {}".format(kind, type(specs))
            if tensors is None:
                return [spec._create_feed_layer() for spec in specs]
            assert isinstance(tensors, list), \
                "{} should be list, but received {}".format(kind, type(tensors))
            for spec, tensor in zip(specs, tensors):
                if spec.shape != tensor.shape:
                    tensor.desc.set_shape(spec.shape)
            return tensors

        prepared_inputs = _materialize("inputs", inputs_spec, inputs)
        prepared_labels = _materialize("labels", labels_spec, labels)
        return prepared_inputs, prepared_labels

    def _prepare_reader(self):
        """Move the reader ops of the current mode's dist main program to the front.

        Works on the global block of
        ``self._dist_main_progs[self._mode][self._cur_rank]``. ``self._mode``
        is set elsewhere (presumably by fit/evaluate/predict — TODO confirm).

        Steps:
            1. Remove reader ops left at the front of the block by an earlier run.
            2. Copy the remaining reader ops to the beginning of the block,
               keeping their relative order, and register each copy in the
               mode's DistributedContext.
            3. Delete the original reader ops from both the Python op list
               and the C++ block desc.

        Finally sets ``self._has_prepared_reader[self._mode] = True``.
        """
        dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
        dist_context = self._dist_contexts[self._mode]
        dist_main_block = dist_main_prog.global_block()

        # NOTE: this list may be changed if Paddle changes the existing rules.
        related_reader_ops = [
            "create_py_reader", "create_double_buffer_reader", "read"
        ]
        # remove the first three ops if multiple run fit/evaluate/predict
        # Each `_remove_op(0, ...)` drops the current first op, so the next
        # candidate is re-checked at index 0 on the following iteration.
        if dist_main_block.ops[0].type == 'create_py_reader':
            for i in range(len(related_reader_ops)):
                if dist_main_block.ops[0].type in related_reader_ops:
                    dist_main_block._remove_op(0, sync=False)
        dist_main_block._sync_with_cpp()
        # Step 1: find the reader ops
        reader_op_indices = []
        for idx, op in enumerate(dist_main_block.ops):
            if op.type in related_reader_ops:
                reader_op_indices.append(idx)
        # Step 2: insert the new reader ops to cpp
        # Iterating in reverse means repeated prepends leave the copies in
        # their original order. Only the C++ desc is prepended to here; the
        # Python `ops` list is unchanged, so `ops[idx]` is still the original.
        new_reader_ops = []
        for idx in reversed(reader_op_indices):
            new_op_desc = dist_main_block.desc._prepend_op()
            new_op_desc.copy_from(dist_main_block.ops[idx].desc)
            new_op = Operator(dist_main_block,
                              new_op_desc,
                              type=new_op_desc.type())
            new_reader_ops.append(new_op)
            dist_op = DistributedOperator(new_op)
            dist_context.add_dist_op_for_program(dist_op)
        # Step 3: insert the new reader ops to python
        # `new_reader_ops` is in reverse order, so inserting each at 0
        # restores the original order in the Python list as well.
        for new_op in new_reader_ops:
            dist_main_block.ops.insert(0, new_op)
        # Every original reader op has shifted back by the number of copies.
        for i in range(len(reader_op_indices)):
            reader_op_indices[i] += len(reader_op_indices)
        # Step 4: remove the old reader ops from python and cpp
        # Removing from the highest index down keeps lower indices valid.
        for idx in reversed(reader_op_indices):
            op = dist_main_block.ops.pop(idx)
            dist_main_block.desc._remove_op(idx, idx + 1)
        dist_main_block._sync_with_cpp()
        self._has_prepared_reader[self._mode] = True

    def _prepare_feed(self, data, user_feeds, mode):
        feeds = {}
        if data is not None:
            if isinstance(data, (list, tuple)):
                if len(data) == 1 and isinstance(data[0], dict):
                    for name, data in data[0].items():
                        feeds[name] = data
                else:
                    raise ValueError("Unsupported data {}".format(data))
            elif isinstance(data, dict):
                for name, data in data.items():
                    feeds[name] = data
            else:
                raise ValueError("Unsupported data {}".format(data))
357 358 359
        if user_feeds is not None:
            assert isinstance(user_feeds, dict), \
                "user_feeds must be a dict, but receive {}".format(type(user_feeds).__name__)
360 361
            for name, data in user_feeds.items():
                feeds[name] = data
362 363
        return feeds

364
    def _prepare_fetch(self, user_fetches, mode):
        """Collect the variable names to fetch and how to regroup the results.

        Fetch groups are appended in this order:

        * non-predict modes: the ``loss`` group, then one group per entry of
          ``self._fetch_vars[mode]["metrics"]``;
        * predict mode: the ``outputs`` group;
        * always last: user fetches (the FETCHES collection, then ``user_fetches``).

        Args:
            user_fetches (list|None): extra variables to fetch.
            mode (str): "train", "eval" or "predict".

        Returns:
            tuple: ``(fetch_names, fetch_indices)``. ``fetch_names`` is a
            de-duplicated list of variable names. ``fetch_indices[g]`` lists
            the positions in ``fetch_names`` that belong to group ``g``; it may
            be empty.

        Raises:
            AssertionError: if ``user_fetches`` is not None and not a list.
        """
        if user_fetches is not None:
            assert isinstance(user_fetches, list), \
                "user_fetches must be a list, but receive {}".format(type(user_fetches).__name__)
        fetch_names = []
        fetch_indices = []

        def _process_fetch_group(var_list):
            group_indices = []
            for var in var_list:
                # `_is_local_var` (defined elsewhere) decides which vars are
                # fetched — presumably those present on this rank; TODO confirm.
                if self._is_local_var(var):
                    var_name = _to_name_str(var)
                    # Remove duplicate var_names
                    if var_name not in fetch_names:
                        fetch_names.append(var_name)
                    group_indices.append(fetch_names.index(var_name))
            # An empty group is recorded only as an empty index list. No
            # placeholder goes into `fetch_names`: the executor flattens nested
            # lists in `fetch_list`, so a placeholder would yield no output and
            # would shift every later index off its result.
            fetch_indices.append(group_indices)

        fetch_vars = self._fetch_vars[mode]
        if mode != "predict":
            _process_fetch_group(fetch_vars["loss"])
            for var_list in fetch_vars["metrics"]:
                _process_fetch_group(var_list)
        else:
            _process_fetch_group(fetch_vars["outputs"])

        user_fetches_collection = [
            item[1] for item in get_collection(CollectionNames.FETCHES)
        ]
        var_list = (user_fetches_collection or []) + (user_fetches or [])
        _process_fetch_group(var_list)
        return fetch_names, fetch_indices

    def _prepare_logger(self,
                        outs,
                        epoch=None,
                        step=None,
                        lr=None,
                        fetch_names=None,
                        fetch_indices=None,
                        mode=None):
        """Assemble the logs dict for one step from the executor outputs.

        ``fetch_indices`` groups must follow the order produced by
        ``_prepare_fetch``: loss then one group per metric (non-predict modes),
        or outputs (predict mode).

        Args:
            outs (list): executor results, indexed by position in ``fetch_names``.
            epoch, step, lr: optional values copied into the logs. ``step`` is
                reported 1-based.
            fetch_names (list|None): fetched variable names.
            fetch_indices (list[list[int]]): per-group indices into ``outs``.
            mode (str): "train", "eval" or "predict".

        Returns:
            dict: keys ``epoch``/``step``/``lr`` when given, ``loss`` and metric
            names (non-predict), ``outputs`` (predict), and always ``fetches``.
        """
        fetch_names = fetch_names or []
        logs = {}
        if epoch is not None:
            logs["epoch"] = epoch
        if step is not None:
            logs["step"] = step + 1
        if lr is not None:
            logs["lr"] = lr

        group_idx = 0
        if mode != "predict":
            # logging loss (at most one loss variable is expected)
            loss_indices = fetch_indices[group_idx]
            assert len(loss_indices) <= 1
            for idx in loss_indices:
                logs["loss"] = outs[idx][0]
            group_idx += 1
            # logging metrics: each metric consumes one fetch group
            if self._fetch_vars[mode]["metrics"]:
                for metric in self._metrics:
                    metric_out = [outs[idx] for idx in fetch_indices[group_idx]]
                    if metric_out:
                        metric.update(*metric_out)
                        results = metric.accumulate()
                        for i, res in enumerate(to_list(results)):
                            logs[metric.name()[i]] = res
                    group_idx += 1
        else:
            # logging outputs; keys use the position in `outs`
            logs["outputs"] = {
                "out%d" % idx: outs[idx]
                for idx in fetch_indices[group_idx]
            }
            group_idx += 1

        # logging user fetches registered in the FETCHES collection
        logs_fetch = {}
        for name, var in get_collection(CollectionNames.FETCHES):
            if var.name in fetch_names:
                idx = fetch_names.index(var.name)
                logs_fetch[name or var.name] = outs[idx]
        logs["fetches"] = logs_fetch
        return logs
453

454 455 456 457 458 459 460 461 462 463 464
    def _prepare_program(self, mode):
        # Do the build process
        self._build(mode)
        # Do the planning process
        self._plan(mode)
        # Do the parallel process
        self._parallel(mode)
        # Init comm and startup program
        self._initialize(mode)
        self._has_prepared[mode] = True

465
    def _build(self, mode):
        """Build the serial program for ``mode`` and create its DistributedContext.

        In dynamic-graph mode the model is converted through ``ProgramHelper``
        and static mode is re-enabled afterwards. Otherwise the model, loss and
        metrics are traced into clones of the original default main/startup
        programs. In that case the method returns early when a serial main
        program for ``mode`` is already cached in ``self._serial_main_progs``.
        Non-train programs are cloned with ``for_test=True``.
        """
        if _non_static_mode() or self._dygraph_mode:
            paddle.disable_static()
            self._dygraph_mode = True
            self._logger.info("Building model with 'to_static' method.")

            inputs_spec = self._inputs_spec
            labels_spec = self._labels_spec if self._labels_spec else []
            self.program_helper = ProgramHelper(self._model, self._loss,
                                                self._metrics, inputs_spec,
                                                labels_spec)
            # build forward main program
            self.program_helper.build_program(mode)

            self.concrete_program = self.program_helper.concrete_program
            serial_main_prog = self.program_helper.main_program
            serial_startup_prog = self.program_helper.startup_program

            inputs = self.program_helper.input_vars
            outputs = self.program_helper.output_vars
            labels = self.program_helper.label_vars
            losses = self.program_helper.loss_vars
            self._losses = losses
            metrics = self.program_helper.metric_vars

            self._inputs = inputs
            self._labels = labels

            paddle.enable_static()
        else:
            # build program in static mode
            serial_main_prog = self._serial_main_progs.get(mode, None)
            if serial_main_prog is not None:
                return

            outputs = []
            losses = []
            metrics = []
            inputs = self._inputs if self._inputs else []
            labels = self._labels if self._labels else []
            serial_main_prog = self._orig_main_prog.clone()
            serial_startup_prog = self._orig_startup_prog.clone()
            if not self._skip_build:
                with static.program_guard(serial_main_prog, serial_startup_prog), \
                    utils.unique_name.guard():
                    outputs = to_list(self._model(*inputs))
                    if mode != "predict" and self._loss:
                        losses = to_list(self._loss(*(outputs + labels)))
                        self._losses = losses

                    if mode != "predict" and (outputs or labels):
                        for metric in self._metrics:
                            metrics.append(
                                to_list(metric.compute(*(outputs + labels))))
            else:
                losses = to_list(self._loss)
                # Stored under `_losses`, the same attribute the other build
                # paths set (this path previously wrote `self.losses`).
                self._losses = losses

        default_ctx = get_default_distributed_context()
        if not default_ctx.has_annotation:
            # We build the world process group because the data parallel
            # needs all ranks by default.
            new_process_group(list(range(self._nranks)))
            default_ctx.data_parallel = True

        feed_vars = {"inputs": inputs, "labels": labels}

        fetch_vars = {
            "outputs": flatten(outputs),
            "loss": losses,
            "metrics": metrics
        }

        if mode != "train":
            serial_main_prog = serial_main_prog.clone(for_test=True)

        self._set_recompute_ckpts()
        self._dist_contexts[mode] = DistributedContext(
            serial_main_prog, serial_startup_prog, self._optimizer, losses,
            feed_vars, fetch_vars, self._cluster, self._strategy)
        self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale
546

547 548 549
    def _optimization_tuning(self, mode, dataset, batch_size):
        """Run the optimization tuner for ``mode``.

        Builds and plans the program, copies the per-feed-var data-parallel
        world sizes/ranks computed by ``_plan`` onto ``dataset``, and runs
        ``OptimizationTuner.tune()``. When ``tuning.run_after_tuning`` is set,
        the tuner's best config replaces the mode's dist context strategy.

        Raises:
            ValueError: if ``tuning.enable`` is not set.
            AssertionError: if ``mode`` is not ``"train"``.
        """
        if not self._tuning.enable:
            raise ValueError("Please set `tuning.enable=True`.")

        assert mode == "train"
        # Do the build process
        self._build(mode)
        # Do the planning process
        self._plan(mode)

        dataset.dp_world_size = self._dp_world_sizes
        dataset.dp_rank = self._dp_ranks

        from .tuner.optimization_tuner import OptimizationTuner
        self._optimization_tuner = OptimizationTuner(self._tuning.to_dict(),
                                                     self._dist_contexts[mode],
                                                     dataset,
                                                     self._inputs_spec,
                                                     self._labels_spec,
                                                     batch_size=batch_size,
                                                     rank=self._cur_rank)

        self._optimization_tuner.tune()

        if self._tuning.run_after_tuning:
            # update the strategy
            self._dist_contexts[
                mode]._strategy = self._optimization_tuner.get_best_config()

576 577 578 579 580 581
    def _plan(self, mode):
        """Plan distributed attributes for ``mode`` and infer data-parallel info.

        The first mode planned becomes the reference mode. Every later mode
        first copies the reference op dist attrs via ``_init_dist_context``.
        Afterwards ``self._dp_world_sizes`` and ``self._dp_ranks`` hold one
        entry per feed variable (inputs, then labels) that exists in the
        serial main program's global block.
        """
        if self._planned_mode is None:
            self._planned_mode = mode
        else:
            self._init_dist_context(mode)

        dist_context = self._dist_contexts[mode]
        self._planners[mode] = Planner(mode, dist_context)
        self._planners[mode].plan()

        # infer data parallel info
        inputs_var = dist_context.serial_feed_vars["inputs"]
        labels_var = dist_context.serial_feed_vars["labels"]
        block = dist_context.serial_main_program.global_block()
        # TODO: check this feed_list
        feed_list = [
            block.vars[var.name]
            for var in inputs_var + labels_var
            if var.name in block.vars
        ]

        self._dp_world_sizes = []
        self._dp_ranks = []
        for feed_var in feed_list:
            dp_world_size, dp_rank = self._get_input_split_info(
                feed_var, dist_context)
            self._dp_world_sizes.append(dp_world_size)
            self._dp_ranks.append(dp_rank)
602

603
    def _parallel(self, mode, all_ranks=False):
        """Parallelize the planned program for ``mode``.

        Args:
            mode (str): "train", "eval" or "predict".
            all_ranks (bool): if True, call ``Parallelizer.parallel_all()``;
                otherwise parallelize only for ``self._cur_rank``.
        """
        # Parallelize program based on the planner's results
        # For now, the completer has to be passed to the planner,
        # because we may use it to complete the annotation of the backward and update.
        parallelizer = Parallelizer(mode, self._planners[mode].completer,
                                    self._dist_contexts[mode])
        if all_ranks:
            parallelizer.parallel_all()
        else:
            parallelizer.parallel(self._cur_rank)
613 614

    def _init_dist_context(self, mode):
615
        # Init dist_context['mode'] with the first planned dist_context
616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632
        # to guarantee that train/eval/predict mode have same parallel strategy
        dist_context = self._dist_contexts[mode]
        origin_main_prog = dist_context._original_serial_main_program
        ref_mode = self._planned_mode
        ref_dist_context = self._dist_contexts[ref_mode]
        ref_origin_main_prog = ref_dist_context._original_serial_main_program
        ref_blocks = ref_origin_main_prog.blocks
        for ib, block in enumerate(origin_main_prog.blocks):
            for iop, op in enumerate(block.ops):
                ref_op = ref_blocks[ib].ops[iop]
                assert op.type == ref_op.type, \
                    "'{}' mode op '{}' is different with '{}' op '{}'. ".format(mode, op.type, ref_mode, ref_op.type)
                ref_op_dist_attr = ref_dist_context.get_op_dist_attr_for_program(
                    ref_op)
                dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr)

    def _initialize(self, mode):
        """Finalize ``mode`` after parallelization.

        Caches the serial/dist programs and feed/fetch vars of the mode's
        DistributedContext. Adopts its serial optimizer. For multi-rank jobs
        it instantiates communication groups. It resolves the device place,
        seeds the RNGs, and initializes the dygraph program helper when in
        dygraph mode. On the first call it creates the executor and runs the
        startup program for variables not yet initialized in the global
        scope. With ``strategy.reinit`` it re-runs the full startup program.
        """
        dist_context = self._dist_contexts[mode]
        # Get the current content from the distributed context
        self._serial_main_progs[mode] = dist_context.serial_main_program
        self._serial_startup_progs[mode] = dist_context.serial_startup_program
        self._dist_main_progs[mode] = dist_context.dist_main_programs
        self._dist_startup_progs[mode] = dist_context.dist_startup_programs
        self._feed_vars[mode] = dist_context.serial_feed_vars
        self._fetch_vars[mode] = dist_context.serial_fetch_vars
        self._optimizer = dist_context._serial_optimizer

        if self._nranks > 1:
            # Traverse different rank programs and traverse each op of them,
            # instantiate communication by process_mapping.
            all_process_groups = get_all_process_groups()
            cur_rank = self._cur_rank
            # NOTE: After the implementation of the unified dynamic and static communication group initialization mode in the future, the initialization logic of full mode will be removed because port occupation error may occur.
            if self._strategy.auto_mode == "full":
                initialize_pg_in_full_mode(all_process_groups, cur_rank)
            else:
                for process_group in all_process_groups:
                    # Only groups containing this rank are instantiated.
                    if cur_rank not in process_group.ranks:
                        continue
                    process_group.instantiate()

        place = _get_device()
        if isinstance(place, fluid.CUDAPlace):
            place = fluid.CUDAPlace(ParallelEnv().dev_id)

        if self._strategy.seed:
            # Offset the seed by the data-parallel rank of the first feed var.
            # Fall back to 0 when `_plan` found no feed vars (indexing an
            # empty list used to raise IndexError here).
            dp_rank = self._dp_ranks[0] if self._dp_ranks else 0
            seed = self._strategy.seed + dp_rank
            paddle.seed(seed)
            np.random.seed(seed)
            random.seed(seed)

        if self._dygraph_mode:
            dist_main_program = self._dist_main_progs[mode][self._cur_rank]
            self.program_helper.init(dist_main_program, place, dist_context)

        if self._executor is None:
            self._executor = paddle.static.Executor(place)
            # Run only the part of the startup program needed for variables
            # that are missing or uninitialized in the global scope.
            uninitialized = []
            dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
            for var in dist_startup_prog.list_vars():
                scope_var = global_scope().find_var(var.name)
                if scope_var and scope_var.get_tensor()._is_initialized():
                    continue
                uninitialized.append(var)
            if uninitialized:
                prune_startup_prog = dist_startup_prog._prune(uninitialized)
                self._executor.run(prune_startup_prog)

            # `_state_dict`/`_dist_attr` are set elsewhere (presumably by a
            # pending load) — TODO confirm.
            if hasattr(self, "_state_dict") and hasattr(self, "_dist_attr"):
                self._set_state_dict(mode, self._strict, self._state_dict,
                                     self._dist_attr)

        if self._strategy.reinit:
            self._logger.info("NOTE: parameters will be re-initialized.")
            dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
            self._executor.run(dist_startup_prog)

696 697
    def fit(self,
            train_data,
            train_sample_split=None,
            batch_size=1,
            epochs=1,
            steps_per_epoch=None,
            log_freq=10,
            save_dir=None,
            save_freq=1,
            valid_data=None,
            valid_sample_split=None,
            valid_freq=1,
            valid_steps=None,
            collate_fn=None,
            callbacks=None,
            verbose=2):
        """
        Trains the model for a fixed number of epochs. If `valid_data` is set,
        evaluation is run at the end of every `valid_freq`-th epoch.

        Args:
            train_data (Dataset): An instance of paddle.io.Dataset used for training.
            train_sample_split (int, optional): Each sample of the train dataset is assumed
                to be a (input, label) pair by default and has two items. If each sample has
                more than two items, train_sample_split specifies how to split these items into
                input and label. The items before it are input and the left are label. Default: None.
            batch_size (int, optional): The batch size of train_data and valid_data if provided.
                The user's data will be used directly without batching if set to None. Default: 1.
            epochs (int, optional): The number of epochs to train the model. Default: 1.
            steps_per_epoch (int, optional): The total number of steps (batches of samples)
                executed in one epoch before starting the next one. If None, it is equal to
                the number of samples in your dataset divided by the batch size. Default: None.
            log_freq (int, optional): Logging frequency forwarded to the callbacks
                (and to the evaluation run). Default: 10.
            save_dir (str|None, optional): Save directory forwarded to the callbacks.
                Default: None.
            save_freq (int, optional): Save frequency forwarded to the callbacks. Default: 1.
            valid_data (Dataset, optional): An instance of paddle.io.Dataset used for
                evaluation at the end of an epoch. No evaluation will be done if set to None.
                Evaluation reuses `batch_size`, `log_freq`, `collate_fn`, `callbacks` and
                `verbose`. Default: None.
            valid_sample_split (int, optional): Only relevant if valid_data is provided.
                Each sample of the valid dataset is assumed to be a (input, label) pair
                by default and has two items. If each sample has more than two items,
                valid_sample_split specifies how to split these items into input and label.
                The items before it are input and the left are label. Default: None.
            valid_freq (int, optional): Only relevant if valid_data is provided. This specifies
                how many training epochs before a new evaluation is performed. Default: 1.
            valid_steps (int, optional): Only relevant if valid_data is provided.
                It is the total number of steps (batches of samples) to draw before
                stopping validation at the end of every epoch. If None, validation will run until the
                `valid_data` dataset is exhausted. The validation will start from the
                beginning of the dataset at each epoch. Default: None.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list, None for only stack each fields of sample in axis
                0. Default None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during training. Default: None.
            verbose (int, optional): Verbosity level forwarded to the callbacks. Default: 2.

        Returns:
            The engine's ``history`` attribute (``self.history``); presumably it is
            filled in by a history callback configured through ``config_callbacks``.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                train_dataset = MNIST(mode='train', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                optimizer = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=model.parameters())
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, optimizer, metrics)
                engine.fit(train_dataset,
                           epochs=2,
                           batch_size=64)
        """
        self._mode = 'train'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            train_data, train_sample_split, batch_size)
        self._inputs, self._labels = self._prepare_data_tensor(
            self._inputs_spec, self._labels_spec)
        # Build the train program on first use; later calls just switch back to it.
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        assert self._mode in self._dist_main_progs, \
            "train model is not ready, please call `engine._prepare_program('train')` first."

        # iterable=False: no feed is passed to the executor below, so batches
        # are expected to be pulled from this loader by the program itself.
        train_dataloader = self._prepare_dataloader_from_generator(
            dataset=train_data,
            capacity=70,
            iterable=False,
            batch_size=batch_size,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            collate_fn=collate_fn)

        fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)

        cbks = config_callbacks(
            callbacks,
            engine=self,
            batch_size=batch_size,
            epochs=epochs,
            steps=train_dataloader._steps,
            log_freq=log_freq,
            save_freq=save_freq,
            save_dir=save_dir,
            verbose=verbose,
            metrics=self._metrics_name(),
            acc_step=self._k_steps,
        )

        # Bound before the loop so `on_end` still receives a dict when
        # `epochs` is 0 (otherwise `logs` would be unbound there).
        logs = {}
        cbks.on_begin('train')
        for epoch in range(epochs):
            logs = {}
            cbks.on_epoch_begin(epoch)
            for step, _ in enumerate(train_dataloader):
                cbks.on_batch_begin('train', step, logs)
                try:
                    outs = self._executor.run(
                        self.main_program,
                        fetch_list=fetch_names,
                        use_program_cache=self._strategy.use_cache,
                        return_numpy=self._strategy.return_numpy)
                except core.EOFException:
                    # The reader signalled end-of-data for this epoch.
                    break
                lr = get_lr(self._optimizer)
                logs = self._prepare_logger(outs, epoch, step, lr, fetch_names,
                                            fetch_indices, self._mode)
                cbks.on_batch_end('train', step, logs)

            if valid_data and (epoch + 1) % valid_freq == 0:
                val_logs = self.evaluate(valid_data, valid_sample_split,
                                         batch_size, valid_steps, log_freq,
                                         collate_fn, callbacks, verbose)
                val_logs = {
                    "val_" + name: val
                    for name, val in val_logs.items()
                }
                logs.update(val_logs)
                # evaluate() switched the engine to 'eval'; restore training mode.
                self._switch_mode("train")
            else:
                self._reset_metrics()

            cbks.on_epoch_end(epoch, logs)

        cbks.on_end('train', logs)
        return self.history
852

853
    def evaluate(self,
                 valid_data,
                 valid_sample_split=None,
                 batch_size=1,
                 steps=None,
                 log_freq=10,
                 collate_fn=None,
                 callbacks=None,
                 verbose=2):
        """
        Evaluate the loss and metrics of the model on evaluation data.

        Args:
            valid_data (Dataset): An instance of paddle.io.Dataset used for evaluation.
            valid_sample_split (int, optional): Each sample of the eval dataset is assumed
                to be a (input, label) pair by default and has two items. If each sample has
                more than two items, valid_sample_split specifies how to split these items into
                input and label. The items before it are input and the left are label. Default: None.
            batch_size (int, optional): The batch size of valid_data. The user's data will
                be used directly without batching if set to None. Default: 1.
            steps (int, optional): It is the total number of steps (batches of samples) to draw before
                stopping evaluation. If None, evaluation will run until the `valid_data` dataset is exhausted.
                The evaluation will start from the beginning of the dataset in each run. Default: None.
            log_freq (int, optional): Logging frequency forwarded to the callbacks. Default: 10.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list, None for only stack each fields of sample in axis
                0. Default None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during evaluating. Default: None.
            verbose (int, optional): Verbosity level forwarded to the callbacks. Default: 2.

        Returns:
            dict: The logs built by ``_prepare_logger`` for the last evaluated step,
            or an empty dict if no step was run.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                valid_dataset = MNIST(mode='test', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, metrics=metrics)
                engine.evaluate(valid_dataset, batch_size=64)

        """
        self._mode = 'eval'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            valid_data, valid_sample_split, batch_size)
        self._inputs, self._labels = self._prepare_data_tensor(
            self._inputs_spec, self._labels_spec)
        # Build the eval program on first use; later calls just switch back to it.
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        assert self._mode in self._dist_main_progs, \
            "eval model is not ready, please call `engine._prepare_program('eval')` first."
        # iterable=False: no feed is passed to the executor below, so batches
        # are expected to be pulled from this loader by the program itself.
        valid_dataloader = self._prepare_dataloader_from_generator(
            dataset=valid_data,
            capacity=70,
            iterable=False,
            batch_size=batch_size,
            steps_per_epoch=steps,
            collate_fn=collate_fn)

        fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)

        cbks = config_callbacks(
            callbacks,
            engine=self,
            batch_size=batch_size,
            log_freq=log_freq,
            verbose=verbose,
            metrics=self._metrics_name(),
        )

        eval_steps = valid_dataloader._steps
        cbks.on_begin('eval', {
            'steps': eval_steps,
            'metrics': self._metrics_name()
        })
        logs = {}
        for step, _ in enumerate(valid_dataloader):
            cbks.on_batch_begin('eval', step, logs)
            try:
                outs = self._executor.run(
                    self.main_program,
                    fetch_list=fetch_names,
                    use_program_cache=self._strategy.use_cache,
                    return_numpy=self._strategy.return_numpy)
            except core.EOFException:
                # The reader signalled end-of-data.
                break
            logs = self._prepare_logger(outs, None, step, None, fetch_names,
                                        fetch_indices, self._mode)
            cbks.on_batch_end('eval', step, logs)
        cbks.on_end('eval', logs)
        self._reset_metrics()
        return logs
961

962 963
    def predict(self,
                test_data,
                test_sample_split=None,
                batch_size=1,
                steps=None,
                collate_fn=None,
                callbacks=None,
                verbose=2):
        """
        Compute the output predictions on testing data.

        Args:
            test_data (Dataset): An instance of paddle.io.Dataset used for prediction.
            test_sample_split (int, optional): Each sample of the test dataset is assumed
                to be a (input, label) pair by default and has two items. If each sample has
                more than two items, test_sample_split specifies how to split these items into
                input and label. The items before it are input and the left are label. Default: None.
            batch_size (int, optional): The batch size of test_data. The user's data will
                be used directly without batching if set to None. Default: 1.
            steps (int, optional): It is the total number of steps (batches of samples) to draw before
                stopping predict. If None, predict will run until the `test_data` dataset is exhausted.
                The predict will start from the beginning of the dataset in each run. Default: None.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list, None for only stack each fields of sample in axis
                0. Default None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during testing. Default: None.
            verbose (int, optional): Verbosity level forwarded to the callbacks. Default: 2.

        Returns:
            list: One entry per executed step. Each entry is the list of values of
            ``logs["outputs"]`` produced by ``_prepare_logger`` for that step.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                valid_dataset = MNIST(mode='test', transform=transform)

                model = paddle.vision.models.LeNet()

                engine = auto.Engine(model)
                engine.predict(valid_dataset, batch_size=64)
        """
        self._mode = 'predict'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            test_data, test_sample_split, batch_size)
        self._inputs, self._labels = self._prepare_data_tensor(
            self._inputs_spec, self._labels_spec)
        # Build the predict program on first use; later calls just switch back to it.
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        assert self._mode in self._dist_main_progs, \
            "predict model is not ready, please call `engine._prepare_program('predict')` first."

        # iterable=False: no feed is passed to the executor below, so batches
        # are expected to be pulled from this loader by the program itself.
        test_dataloader = self._prepare_dataloader_from_generator(
            dataset=test_data,
            capacity=70,
            iterable=False,
            batch_size=batch_size,
            steps_per_epoch=steps,
            collate_fn=collate_fn)

        fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)

        outputs = []
        cbks = config_callbacks(callbacks, engine=self, verbose=verbose)
        test_steps = test_dataloader._steps
        cbks.on_begin('predict', {'steps': test_steps})
        logs = {}
        for step, _ in enumerate(test_dataloader):
            cbks.on_batch_begin('predict', step, logs)
            try:
                outs = self._executor.run(
                    self.main_program,
                    fetch_list=fetch_names,
                    use_program_cache=self._strategy.use_cache,
                    return_numpy=self._strategy.return_numpy)
            except core.EOFException:
                # The reader signalled end-of-data.
                break
            logs = self._prepare_logger(outs, None, step, None, fetch_names,
                                        fetch_indices, self._mode)
            cbks.on_batch_end('predict', step, logs)
            outputs.append(list(logs["outputs"].values()))
        cbks.on_end('predict', logs)
        return outputs

    def dataloader(self,
                   dataset,
                   batch_size=1,
                   shuffle=False,
                   drop_last=False,
                   collate_fn=None,
                   num_workers=0,
                   use_buffer_reader=True,
                   use_shared_memory=True,
                   timeout=0,
                   worker_init_fn=None,
                   epochs=1,
                   steps_per_epoch=None,
                   sample_split=1,
                   mode=None):
        """
        Build a ``DistributedDataLoader`` over ``dataset`` for the current mode.

        When ``mode`` is given the engine switches to it first. The input/label
        specs and placeholder tensors are derived from ``dataset`` (split at
        ``sample_split``). The program for the mode is built on first use and
        merely switched to on later calls. All loader options are forwarded
        to ``_prepare_dataloader`` with ``return_list=False``.

        Returns:
            The loader returned by ``_prepare_dataloader``.
        """
        if mode is not None:
            self.to_mode(mode)

        spec_pair = self._prepare_data_spec(dataset, sample_split, batch_size)
        self._inputs_spec, self._labels_spec = spec_pair
        tensor_pair = self._prepare_data_tensor(self._inputs_spec,
                                                self._labels_spec)
        self._inputs, self._labels = tensor_pair

        if self._has_prepared[self._mode]:
            self._switch_mode(self._mode)
        else:
            self._prepare_program(self._mode)

        loader_options = dict(
            return_list=False,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            collate_fn=collate_fn,
            num_workers=num_workers,
            use_buffer_reader=use_buffer_reader,
            use_shared_memory=use_shared_memory,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
        )
        return self._prepare_dataloader(dataset, **loader_options)

1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111
    def dataloader_from_generator(self,
                                  dataset,
                                  capacity=70,
                                  use_double_buffer=True,
                                  iterable=True,
                                  use_multiprocess=False,
                                  drop_last=True,
                                  batch_size=1,
                                  epochs=1,
                                  steps_per_epoch=None,
                                  collate_fn=None,
                                  sample_split=1,
                                  mode=None):
        """
        Build a generator-based distributed dataloader over ``dataset``.

        When ``mode`` is given the engine switches to it first. Specs and
        placeholder tensors are derived from ``dataset`` (split at
        ``sample_split``). The program for the current mode is built on first
        use and merely switched to afterwards. Loader options are forwarded to
        ``_prepare_dataloader_from_generator`` with ``return_list=False``.

        Returns:
            The loader returned by ``_prepare_dataloader_from_generator``.
        """
        if mode is not None:
            self.to_mode(mode)

        spec_pair = self._prepare_data_spec(dataset, sample_split, batch_size)
        self._inputs_spec, self._labels_spec = spec_pair
        tensor_pair = self._prepare_data_tensor(self._inputs_spec,
                                                self._labels_spec)
        self._inputs, self._labels = tensor_pair

        if self._has_prepared[self._mode]:
            self._switch_mode(self._mode)
        else:
            self._prepare_program(self._mode)

        generator_options = dict(
            dataset=dataset,
            capacity=capacity,
            use_double_buffer=use_double_buffer,
            iterable=iterable,
            return_list=False,
            use_multiprocess=use_multiprocess,
            drop_last=drop_last,
            batch_size=batch_size,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            collate_fn=collate_fn,
        )
        return self._prepare_dataloader_from_generator(**generator_options)

    def prepare(self,
                inputs_spec=None,
                labels_spec=None,
                inputs=None,
                labels=None,
                main_program=None,
                startup_program=None,
                mode=None):
        """
        Prepare the engine's program for the current mode from user-given data.

        Three cases are handled:
          * ``inputs``/``labels`` given: use these tensors directly (sets
            ``_skip_build``);
          * only ``inputs_spec``/``labels_spec`` given: build placeholder tensors
            from the specs (sets ``_outside_dataloader``);
          * nothing given: specs must already exist, e.g. from a previous
            ``dataloader(...)`` call; nothing else is done.

        In the first two cases ``main_program``/``startup_program`` default to
        the static default programs, and the program for the mode is built on
        first use or switched to afterwards.
        """
        if mode is not None:
            self.to_mode(mode)

        if inputs or labels:
            self._skip_build = True
            self._inputs_spec, self._labels_spec = inputs_spec, labels_spec
            self._inputs, self._labels = self._prepare_data_tensor(
                self._inputs_spec, self._labels_spec, inputs, labels)
        elif inputs_spec or labels_spec:
            self._inputs_spec, self._labels_spec = inputs_spec, labels_spec
            self._outside_dataloader = True
            self._inputs, self._labels = self._prepare_data_tensor(
                self._inputs_spec, self._labels_spec)
        else:
            assert self._inputs_spec and self._labels_spec, \
                "Please call the dataloader(...) before calling prepare(...)"
            return

        # Shared by both explicit-data paths above.
        self._orig_main_prog = (main_program if main_program is not None else
                                static.default_main_program())
        self._orig_startup_prog = (startup_program
                                   if startup_program is not None else
                                   static.default_startup_program())
        if self._has_prepared[self._mode]:
            self._switch_mode(self._mode)
        else:
            self._prepare_program(self._mode)

1182
    def run(self, data=None, feed=None, fetch_list=None, mode=None):
        """
        Execute the current mode's main program once and return its logs.

        ``data``/``feed`` are turned into a feed dict by ``_prepare_feed``, and
        ``fetch_list`` into fetch names/indices by ``_prepare_fetch``. When the
        engine uses an outside dataloader whose reader is not yet prepared for
        this mode, ``_prepare_reader`` is called first.

        Returns:
            The result of ``_prepare_logger`` for the executor outputs.
        """
        if mode is not None:
            self.to_mode(mode)
        current = self._mode

        feed_dict = self._prepare_feed(data, feed, current)
        fetch_names, fetch_indices = self._prepare_fetch(fetch_list, current)

        reader_missing = (self._outside_dataloader
                          and not self._has_prepared_reader[current])
        if reader_missing:
            self._prepare_reader()

        strategy = self._strategy
        outs = self._executor.run(self.main_program,
                                  feed=feed_dict,
                                  fetch_list=fetch_names,
                                  use_program_cache=strategy.use_cache,
                                  return_numpy=strategy.return_numpy)
        return self._prepare_logger(outs, None, None, None, fetch_names,
                                    fetch_indices, current)
1198

1199 1200
    def _prepare_dataloader(self,
                            dataset,
                            return_list=True,
                            batch_size=1,
                            shuffle=False,
                            drop_last=False,
                            collate_fn=None,
                            num_workers=0,
                            use_buffer_reader=True,
                            use_shared_memory=True,
                            timeout=0,
                            worker_init_fn=None,
                            epochs=1,
                            steps_per_epoch=None):
        """
        Create a ``DistributedDataLoader`` bound to the distributed programs of
        the current mode and rank.

        The loader options are forwarded unchanged to ``DistributedDataLoader``,
        except ``batch_size``, which is divided by ``self._k_steps`` under
        gradient merge. Data-parallel splitting uses ``self._strategy.split_data``,
        ``self._dp_world_sizes`` and ``self._dp_ranks``.

        Returns:
            DistributedDataLoader: the loader, created under a program guard of
            the distributed main/startup programs.

        Raises:
            AssertionError: if gradient merge is active and ``batch_size`` is not
                divisible by ``self._k_steps``.
        """
        # Under gradient merge the loader is given batch_size // k_steps,
        # so the user batch must split evenly.
        if self._strategy.gradient_merge and batch_size is not None:
            assert batch_size % self._k_steps == 0, \
                "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self._k_steps)
            batch_size //= self._k_steps

        dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
        dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank]
        dist_main_block = dist_main_prog.global_block()

        # NOTE: Get feed_list, then insert dataloader op with sharded var shape.
        # Because predict_program does not contain the labels var,
        # we add the labels var from serial_program to dist_program,
        # which keeps the length of feed_list equal to the length of dataset's values.
        inputs_var = self._feed_vars[self._mode]["inputs"]
        labels_var = self._feed_vars[self._mode]["labels"]
        feed_list = []
        for var in inputs_var + labels_var:
            if var.name in dist_main_block.vars:
                feed_list.append(dist_main_block.vars[var.name])
            else:
                copy_var = dist_main_block._clone_variable(var, var.persistable)
                copy_var.desc.set_original_id(var.desc.original_id())
                feed_list.append(copy_var)

        # insert read op at the end of program
        places = paddle.static.cuda_places()
        with static.program_guard(dist_main_prog, dist_startup_prog):
            dataloader = DistributedDataLoader(
                dataset,
                feed_list=feed_list,
                places=places,
                return_list=return_list,
                batch_size=batch_size,
                shuffle=shuffle,
                drop_last=drop_last,
                collate_fn=collate_fn,
                num_workers=num_workers,
                use_buffer_reader=use_buffer_reader,
                use_shared_memory=use_shared_memory,
                timeout=timeout,
                worker_init_fn=worker_init_fn,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                split_data=self._strategy.split_data,
                data_parallel_world_size=self._dp_world_sizes,
                data_parallel_rank=self._dp_ranks)

        return dataloader

1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331
    def _prepare_dataloader_from_generator(self,
                                           dataset,
                                           capacity=None,
                                           use_double_buffer=True,
                                           iterable=True,
                                           return_list=False,
                                           use_multiprocess=False,
                                           drop_last=True,
                                           batch_size=1,
                                           epochs=1,
                                           steps_per_epoch=None,
                                           collate_fn=None):
        """
        Create a ``DistributedDataLoaderFromGenerator`` bound to the distributed
        programs of the current mode and rank, then call ``_prepare_reader``.

        The loader options are forwarded unchanged, except ``batch_size``, which
        is divided by ``self._k_steps`` under gradient merge. Data-parallel
        splitting uses ``self._strategy.split_data``, ``self._dp_world_sizes``
        and ``self._dp_ranks``.

        Returns:
            DistributedDataLoaderFromGenerator: the loader, created under a
            program guard of the distributed main/startup programs.

        Raises:
            AssertionError: if gradient merge is active and ``batch_size`` is not
                divisible by ``self._k_steps``.
        """
        # Under gradient merge the loader is given batch_size // k_steps,
        # so the user batch must split evenly.
        if self._strategy.gradient_merge and batch_size is not None:
            assert batch_size % self._k_steps == 0, \
                "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self._k_steps)
            batch_size //= self._k_steps

        dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
        dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank]
        dist_main_block = dist_main_prog.global_block()

        # NOTE: Get feed_list, then insert dataloader op with sharded var shape.
        # Because predict_program does not contain the labels var,
        # we add the labels var from serial_program to dist_program,
        # which keeps the length of feed_list equal to the length of dataset's values.
        inputs_var = self._feed_vars[self._mode]["inputs"]
        labels_var = self._feed_vars[self._mode]["labels"]
        feed_list = []
        for var in inputs_var + labels_var:
            if var.name in dist_main_block.vars:
                feed_list.append(dist_main_block.vars[var.name])
            else:
                copy_var = dist_main_block._clone_variable(var, var.persistable)
                copy_var.desc.set_original_id(var.desc.original_id())
                feed_list.append(copy_var)

        places = paddle.static.cuda_places()
        with static.program_guard(dist_main_prog, dist_startup_prog):
            dataloader = DistributedDataLoaderFromGenerator(
                dataset=dataset,
                feed_list=feed_list,
                capacity=capacity,
                use_double_buffer=use_double_buffer,
                iterable=iterable,
                return_list=return_list,
                use_multiprocess=use_multiprocess,
                drop_last=drop_last,
                places=places,
                batch_size=batch_size,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                collate_fn=collate_fn,
                split_data=self._strategy.split_data,
                data_parallel_world_size=self._dp_world_sizes,
                data_parallel_rank=self._dp_ranks)
        self._prepare_reader()
        return dataloader

    def _tune(self, tune_data, tune_sample_split=None, batch_size=1):
        self._mode = 'train'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            tune_data, tune_sample_split, batch_size)
        self._inputs, self._labels = self._prepare_data_tensor(
            self._inputs_spec, self._labels_spec)
        self._optimization_tuning(self._mode, tune_data, batch_size)

1332 1333
    def _validate_spec(self, specs):
        """Validate input specs and apply gradient-merge batch splitting.

        Side effect: stores ``self._strategy.gradient_merge.k_steps`` in
        ``self._k_steps``. When ``k_steps > 1``, the first dimension of every
        spec (treated here as the batch size) is divided by ``k_steps``; the
        spec objects passed in are modified in place.

        Args:
            specs: A single spec or a list of specs; normalized with
                ``to_list`` (presumably None stays None -- see the guard
                below).

        Returns:
            The normalized specs, validated and possibly reshaped.

        Raises:
            TypeError: If an element is not a ``paddle.static.InputSpec``.
            ValueError: If a spec has no name.
            AssertionError: If ``k_steps > 1`` and the first dimension is not
                divisible by ``k_steps``.
        """
        specs = to_list(specs)
        self._k_steps = self._strategy.gradient_merge.k_steps
        if specs is not None:
            for i, spec in enumerate(specs):
                # Raise explicitly instead of using ``assert``: assertions
                # are stripped under ``python -O`` and the check would vanish.
                if not isinstance(spec, InputSpec):
                    raise TypeError(
                        "Requires Input[{}] to be `paddle.static.InputSpec`, "
                        "but receive {}.".format(i, type(spec)))
                if spec.name is None:
                    raise ValueError(
                        "Requires Input[{}].name != None, but receive `None` with {}."
                        .format(i, spec))
                if self._k_steps > 1:
                    # Each of the k_steps micro-steps sees 1/k_steps of the
                    # batch, so the per-step placeholder shape shrinks.
                    shape = list(spec.shape)
                    assert shape[0] % self._k_steps == 0, \
                        "Requires batch_size[{}] to be divisible by k_steps[{}].".format(spec.shape[0], self._k_steps)
                    shape[0] //= self._k_steps
                    spec.shape = shape
        return specs

1350 1351 1352 1353
    def _is_local_var(self, var):
        """Return whether ``var`` is declared in the global block of
        ``self.main_program``.

        ``var`` is turned into a variable name with ``_to_name_str`` before
        the lookup.
        """
        name = _to_name_str(var)
        global_block_vars = self.main_program.global_block().vars
        return name in global_block_vars

1354 1355
    def _get_input_split_info(self, var, dist_context):
        """Deduce how the input data ``var`` is split among the cluster.

        The first entry of the tensor's ``dims_mapping`` is treated as the
        batch axis. If it is mapped onto a process-mesh axis of size > 1, the
        data is considered split across the communication group along that
        axis.

        Args:
            var: The input variable to inspect.
            dist_context: Distributed context that owns ``var``'s dist attr.

        Returns:
            tuple: ``(group size, index of this rank within the group)``, or
            ``(1, 0)`` when the input is not split.
        """
        from .utils import _get_comm_group, _get_corresponding_rank

        dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
        mesh = dist_attr.process_mesh
        mapping = dist_attr.dims_mapping

        # A rank outside this mesh is mapped onto a rank inside it.
        if self._cur_rank in mesh.processes:
            rank_id = self._cur_rank
        else:
            rank_id = _get_corresponding_rank(dist_context, mesh,
                                              self._cur_rank)

        batch_axis = mapping[0]
        is_split = batch_axis > -1 and mesh.topology[batch_axis] > 1
        if not is_split:
            return 1, 0

        group = _get_comm_group(mesh.processes, mesh.topology, batch_axis,
                                rank_id)
        return len(group), group.index(rank_id)
1376

1377 1378 1379 1380
    def _set_recompute_ckpts(self):
        """Resolve recompute checkpoints and write them into the strategy.

        NOTE: hack to enable recompute in the engine API for GPT-3. For a
        ``paddle.nn.Layer`` that has a ``gpt`` attribute and whose class is
        ``GPTForPretraining`` or ``GPTForPretrainingAuto``, the checkpoints
        are taken from ``self._model.gpt.checkpoints``. Otherwise the ones
        already configured in ``self._strategy.recompute.checkpoints`` are
        used. The strategy is modified (and the result logged) only when
        recompute is enabled.

        TODO: support more PaddleNLP/CV models here.
        """
        recompute = self._strategy.recompute
        model = self._model
        gpt_class_names = ['GPTForPretraining', 'GPTForPretrainingAuto']

        # Extract checkpoints by specific model.
        use_gpt_ckpts = (isinstance(model, paddle.nn.Layer)
                         and hasattr(model, "gpt")
                         and model.__class__.__name__ in gpt_class_names)
        exact_ckpts = (model.gpt.checkpoints
                       if use_gpt_ckpts else recompute.checkpoints)

        # Modify strategy only when recompute is switched on.
        if not recompute.enable:
            return
        recompute.checkpoints = exact_ckpts[:]
        self._logger.info({
            'Model Class': model.__class__.__name__,
            'Applied Recompute ckpts': exact_ckpts
        })

1404
    def _validate_opt(self, optimizer):
1405 1406 1407
        if optimizer is not None:
            optimizer._parameter_list = None
            optimizer._param_groups = None
1408 1409
        return optimizer

1410 1411 1412 1413
    def _reset_metrics(self):
        for metric in self._metrics:
            metric.reset()

Z
zhaoyingli 已提交
1414 1415 1416 1417 1418 1419
    def _metrics_name(self):
        """Return the list of names of the reported quantities.

        ``'loss'`` comes first when ``self._loss`` is truthy. It is followed
        by the name(s) returned by each metric's ``name()``, normalized to a
        list with ``to_list``.
        """
        head = ['loss'] if self._loss else []
        return head + [
            name for metric in self._metrics for name in to_list(metric.name())
        ]

1420
    def _switch_mode(self, mode):
1421
        self.to_mode(mode)
Z
zhaoyingli 已提交
1422
        self._optimizer = self._dist_contexts[mode]._serial_optimizer
1423

1424 1425 1426 1427 1428
    def to_mode(self, mode):
        """Set the engine's active mode.

        Args:
            mode (str): One of ``"train"``, ``"eval"`` or ``"predict"``.

        Raises:
            AssertionError: If ``mode`` is not one of the accepted values.
        """
        supported_modes = ("train", "eval", "predict")
        assert mode in supported_modes, \
            "mode {} should be one of ['train', 'eval', 'predict']".format(mode)
        self._mode = mode

1429 1430 1431 1432 1433 1434 1435 1436 1437
    def _set_state_dict(self, mode, strict, state_dict, dist_attr):
        """Re-layout ``state_dict`` for this rank and load it into ``mode``'s
        distributed main program.

        Args:
            mode (str): Mode whose program receives the state.
            strict (bool): Forwarded to ``Converter.convert``.
            state_dict: The state to load, laid out according to
                ``dist_attr``.
            dist_attr: Distributed attributes describing ``state_dict``.
        """
        program = self._dist_main_progs[mode][self._cur_rank]
        target_dist_attr = get_dist_attr(program, self._dist_contexts[mode])
        converted_state = Converter(state_dict, dist_attr,
                                    target_dist_attr).convert(strict=strict)
        program.set_state_dict(converted_state)

    def save(self, path, training=True):
        """
        Saves the model, parameters and optimizer state to ``path``.

        If ``training`` is False, only an inference model is exported.

        Args:
            path (str): The file prefix to save the model to. The format is
                'dirname/file_prefix' or 'file_prefix'. An exception will be
                raised if it is an empty string.
            training (bool, optional): Whether to save for training. If True,
                the program of the current mode is saved together with the
                optimizer state. Otherwise only the model and parameters of
                the "predict" program are saved for inference. This function
                silently overwrites an existing file at the target location.
                Default: True.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                train_dataset = MNIST(mode='train', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                optimizer = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=model.parameters())
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, optimizer, metrics)
                engine.fit(train_dataset,
                           epochs=1,
                           batch_size=64)
                engine.save("./my_model")

        """
        if not training:
            # Inference export always uses the "predict" program and its
            # feed/fetch variables, whatever the current mode is.
            assert "predict" in self._dist_main_progs
            feed_vars = self._feed_vars["predict"]['inputs']
            fetch_vars = self._fetch_vars["predict"]['outputs']
            predict_prog = self._dist_main_progs["predict"][self._cur_rank]
            self._saver.save_inference_model(path,
                                             feed_vars,
                                             fetch_vars,
                                             self._executor,
                                             program=predict_prog)
            return

        current = self._mode
        assert current in self._serial_main_progs
        self._saver.save(
            path,
            serial_program=self._serial_main_progs[current],
            dist_main_program=self._dist_main_progs[current][self._cur_rank],
            dist_context=self._dist_contexts[current])
1501

1502 1503 1504 1505 1506 1507
    def load(self, path, strict=True, load_optimizer=True):
        """
        Load the stored model, parameters and optimizer states.

        The loaded states are kept on the engine (``self._state_dict`` and
        ``self._dist_attr``, with ``strict`` in ``self._strict``) and are also
        returned to the caller.

        Args:
            path (str): The prefix of the files storing the model states and
                optimizer states.
            strict (bool, optional): Whether to raise an error when a
                mismatch happens (a parameter not found in the stored model
                states, or a mismatched shape) instead of skipping the
                mismatched parameter. Default: True.
            load_optimizer (bool, optional): If True, the stored optimizer
                states are restored. Otherwise, the optimizer states are
                initialized from scratch. Default: True.

        Returns:
            tuple: ``(state_dict, dist_attr)`` exactly as returned by
            ``self._saver.load(path, load_optimizer)``.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                train_dataset = MNIST(mode='train', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                optimizer = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=model.parameters())
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, optimizer, metrics)
                engine.fit(train_dataset,
                           epochs=1,
                           batch_size=64)
                engine.save("./my_model")
                engine.load("./my_model")

        """
        self._strict = strict
        self._state_dict, self._dist_attr = self._saver.load(
            path, load_optimizer)
        return self._state_dict, self._dist_attr
1552

1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600
    def cost(self, inputs_spec=None, labels_spec=None, mode="train"):
        """
        Estimate the cost of one step in ``mode``.

        The estimate covers the global execution time of one step, based on
        communication cost (computation cost is 0 by default), and the max
        memory among all ranks. In the future, the flops information of every
        rank and a global cost including computation cost will be added.

        Args:
            inputs_spec(InputSpec, optional): The specification of inputs. If
                given, the program for ``mode`` is built and planned from it.
                Default: None.
            labels_spec(InputSpec, optional): The specification of labels.
                Default: None.
            mode (str): The engine mode, one of ["train", "predict", "eval"].
                Default: "train".

        Returns:
            tuple: ``(global execution time in ms, max memory in B)``, or
            None when ``strategy.auto_mode`` is "full" (in that case the cost
            is calculated during the search process and a notice is printed).

        Raises:
            ValueError: If ``mode`` is not accepted, or if ``inputs_spec`` is
                None while ``_non_static_mode()`` or ``self._dygraph_mode``
                is true.
        """
        # Check parallel mode
        if self._strategy.auto_mode == "full":
            print(
                "The cost will be calculated in the search process when the auto mode is full."
            )
            return

        # Check mode
        accepted_modes = ["train", "predict", "eval"]
        if mode not in accepted_modes:
            raise ValueError("The mode {} is not in accepted modes {}".format(
                mode, accepted_modes))
        self.to_mode(mode)

        if inputs_spec is not None:
            self._inputs_spec, self._labels_spec = inputs_spec, labels_spec
            self._inputs, self._labels = self._prepare_data_tensor(
                self._inputs_spec, self._labels_spec)
            self._build(mode)
            self._plan(mode)
        elif _non_static_mode() or self._dygraph_mode:
            raise ValueError(
                "Please call `engine._prepare_program('mode')` firstly when in the static graph mode."
            )

        # Estimate the exec cost and max memory
        global_cost, max_memory = get_cost_from_engine(self, mode)

        return global_cost.time, max_memory

1601 1602
    @property
    def main_program(self):
        """Distributed main program of the current rank in the active mode."""
        programs_for_mode = self._dist_main_progs[self._mode]
        return programs_for_mode[self._cur_rank]
1604 1605 1606

    @property
    def startup_program(self):
        """Distributed startup program of the current rank in the active mode."""
        programs_for_mode = self._dist_startup_progs[self._mode]
        return programs_for_mode[self._cur_rank]
1608 1609 1610

    @property
    def dist_context(self):
        """Distributed context associated with the active mode."""
        active = self._mode
        return self._dist_contexts[active]
1612 1613 1614

    @property
    def serial_main_program(self):
        """Serial (non-partitioned) main program of the active mode."""
        active = self._mode
        return self._serial_main_progs[active]
1616 1617 1618

    @property
    def serial_startup_program(self):
        """Serial (non-partitioned) startup program of the active mode."""
        active = self._mode
        return self._serial_startup_progs[active]
1620 1621 1622

    @property
    def fetch_vars(self):
        """Fetch-variable collection registered for the active mode."""
        active = self._mode
        return self._fetch_vars[active]
1624 1625 1626

    @property
    def inputs(self):
        """Input data placeholders last prepared by the engine."""
        prepared_inputs = self._inputs
        return prepared_inputs
1628 1629 1630

    @property
    def labels(self):
        """Label data placeholders last prepared by the engine."""
        prepared_labels = self._labels
        return prepared_labels