engine.py 68.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os
Z
zhaoyingli 已提交
16
import copy
17
import logging
18
import random
Z
zhaoyingli 已提交
19
import numbers
20
import numpy as np
21 22 23
from collections import defaultdict

import paddle
24
import paddle.utils as utils
25

26
from paddle import fluid, static
27
from paddle.metric import Metric
28
from paddle.static import InputSpec
29
from paddle.fluid import core
30
from paddle.fluid import Variable
31
from paddle.fluid.layers.utils import flatten
32
from paddle.fluid.executor import global_scope, _to_name_str
33
from paddle.fluid.framework import Operator, _non_static_mode
34 35
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph.parallel import ParallelEnv
36
from paddle.distributed import fleet
37
from paddle.distributed.parallel import _is_global_parallel_initialize
38

39
from .callbacks import config_callbacks
40
from .converter import Converter
41
from .helper import ProgramHelper
42
from .cluster import Cluster, get_default_cluster
43 44
from .planner_v2 import Planner
from .parallelizer_v2 import Parallelizer
45 46
from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver
Z
zhaoyingli 已提交
47 48 49 50
from .dist_loader import (
    DistributedDataLoaderFromGenerator,
    DistributedDataLoader,
)
51
from .process_group import new_process_group, get_all_process_groups
52
from .dist_context import DistributedContext, get_default_distributed_context
53
from .strategy import Strategy
J
JZ-LIANG 已提交
54
from .interface import CollectionNames, get_collection, fetch
Z
zhaoyingli 已提交
55 56
from .utils import to_list, get_dist_attr, get_lr, validate_opt
from .utils import initialize_pg_in_full_mode, get_input_split_info
57
from .cost.estimate_cost import get_cost_from_engine
58

Z
zhaoyingli 已提交
59 60
from ..utils.log_utils import get_logger

61 62

class Engine:
63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
    """
    An Engine object provides the full power of auto parallel to users.
    With its help, users can easily perform distributed training and
    inference. It supports both dynamic graph and static graph modes.

    Args:
        model (paddle.nn.Layer, optional): The model, an instance of
            paddle.nn.Layer (any callable is also accepted).
        loss (Loss|Callable|Variable|None, optional): The loss can be a
            `paddle.nn.Layer` instance, a static-graph `Variable`, or any
            callable function that takes the predicted values and ground truth
            values as input. It can be None when there is no loss.
            Default: None.
        optimizer (Optimizer|None, optional): The optimizer needs to be set in
            training and should be None in eval and predict mode. Default: None.
        metrics (Metric|list[Metric]|None, optional): If metrics is set, all
            metrics will be calculated and output in train/eval mode. Default: None.
        cluster (Cluster|None, optional): The cluster represents the topology information
            about the used physical devices. Default: None. (Unused for now)
        strategy (Strategy|None, optional): The strategy is used to configure the
            parallelization and optimization behaviors. Default: None.

    Examples:

        .. code-block:: python

            import paddle
            import paddle.vision.transforms as T
            from paddle.distributed.fleet import auto
            from paddle.vision.datasets import MNIST

            transform = T.Compose([
                T.Transpose(),
                T.Normalize([127.5], [127.5])
            ])
            train_dataset = MNIST(mode='train', transform=transform)
            valid_dataset = MNIST(mode='test', transform=transform)

            model = paddle.vision.models.LeNet()
            loss = paddle.nn.CrossEntropyLoss()
            optimizer = paddle.optimizer.Adam(
                learning_rate=0.001, parameters=model.parameters())
            metrics = paddle.metric.Accuracy(topk=(1, 2))

            engine = auto.Engine(model, loss, optimizer, metrics)
            # fit
            engine.fit(train_dataset,
                       epochs=2,
                       batch_size=64)
            # evaluate
            engine.evaluate(valid_dataset,
                            batch_size=64)
            # predict
            engine.predict(valid_dataset,
                           batch_size=64)
            # save
            engine.save("./my_model")
            # load
            engine.load("./my_model")

    """
124

Z
zhaoyingli 已提交
125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
    def __init__(
        self,
        model=None,
        loss=None,
        optimizer=None,
        metrics=None,
        cluster=None,
        strategy=None,
    ):
        """Validate the constructor arguments and initialize engine state.

        Args:
            model: A ``paddle.nn.Layer`` or any callable, or None.
            loss: A ``paddle.nn.Layer``, a ``Variable`` or any callable, or None.
            optimizer: A ``paddle.optimizer.Optimizer`` or
                ``paddle.fluid.optimizer.Optimizer``, or None. It is passed
                through ``validate_opt`` before being stored.
            metrics: A ``Metric``, a list of ``Metric`` objects, or None.
            cluster: A ``Cluster``; ``get_default_cluster()`` is used when None.
            strategy: A ``Strategy``; a default ``Strategy()`` is used when None.

        Raises:
            TypeError: If ``model``, ``loss``, ``optimizer``, any metric,
                ``cluster`` or ``strategy`` has an unsupported type.
        """
        if (
            model
            and not isinstance(model, paddle.nn.Layer)
            and not callable(model)
        ):
            raise TypeError(
                "'model' must be sub classes of `paddle.nn.Layer` or any callable function."
            )
        self._model = model

        if (
            loss
            and not isinstance(loss, (paddle.nn.Layer, Variable))
            and not callable(loss)
        ):
            raise TypeError(
                "'loss' must be sub classes of `paddle.nn.Layer` or any callable function or a Variable."
            )
        self._loss = loss

        if optimizer and not isinstance(
            optimizer,
            (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer),
        ):
            raise TypeError(
                "'optimizer' must be object of class `paddle.optimizer.Optimizer`"
                " or `paddle.fluid.optimizer.Optimizer`."
            )
        self._optimizer = validate_opt(optimizer)

        metrics = metrics or []
        for metric in to_list(metrics):
            if metric and not isinstance(metric, Metric):
                raise TypeError(
                    "{} is not sub class of Metric".format(
                        metric.__class__.__name__
                    )
                )
        self._metrics = to_list(metrics)

        if cluster and not isinstance(cluster, Cluster):
            raise TypeError(
                "'cluster' must be object of class `paddle.distributed.auto_parallel.Cluster`"
            )
        self._cluster = cluster or get_default_cluster()

        if strategy and not isinstance(strategy, Strategy):
            raise TypeError(
                "'strategy' must be object of class `paddle.distributed.auto_parallel.Strategy`"
            )
        self._strategy = strategy or Strategy()

        self._logger = get_logger(logging.INFO)
        # A set POD_NAME env var is taken as the sign of a launch through
        # `paddle.distributed.launch`; initialize collective fleet once.
        if os.getenv("POD_NAME") and not _is_global_parallel_initialize():
            self._logger.info(
                "Distribute training by paddle.distributed.launch"
            )
            fleet.init(is_collective=True)

        # for compute cost
        # TODO: remove _fwd_main_progs and _orig_optimizer
        self._fwd_dist_contexts = {}
        self._fwd_main_progs = {}
        # Deep copy so later mutation of self._optimizer does not leak here.
        self._orig_optimizer = copy.deepcopy(self._optimizer)

        self._executor = None
        self._cur_rank = paddle.distributed.get_rank()
        self._nranks = paddle.distributed.get_world_size()
        self._saver = DistributedSaver()

        # Snapshot the user's default programs/context; per-mode copies are
        # derived from these when building.
        self._orig_main_prog = static.default_main_program()
        self._orig_startup_prog = static.default_startup_program()
        self._orig_dist_context = get_default_distributed_context()
        self._dist_contexts = {}
        self._planners = {}
        self._has_prepared = {"train": False, "eval": False, "predict": False}
        self._has_prepared_reader = {
            "train": False,
            "eval": False,
            "predict": False,
        }
        self._inputs_spec = []
        self._labels_spec = []
        self._inputs = []
        self._labels = []
        self._losses = []

        self._mode = None
        self._skip_build = False
        self._outside_dataloader = False
        self._planned_mode = None
        self._dygraph_mode = False
        self._tuning = self._strategy.tuning

        self.history = None

    def _prepare_data_spec(self, data, split, batch_size):
        """Infer ``InputSpec`` lists for inputs and labels from one sample.

        The sample is ``next(iter(data))`` for an ``IterableDataset`` and
        ``data[0]`` for a map-style ``Dataset``. When ``split`` is None the
        sample must unpack into ``(inputs, labels)``; otherwise
        ``sample[:split]`` are the inputs and ``sample[split:]`` the labels.

        Args:
            data (paddle.io.Dataset|paddle.io.IterableDataset): Source data.
            split (int|None): Position separating inputs from labels.
            batch_size (int|None): When given, each spec gets a leading batch
                dimension of this size via ``InputSpec.batch``.

        Returns:
            tuple: ``(inputs_spec, labels_spec)`` after ``self._validate_spec``.

        Raises:
            TypeError: If ``data`` is not a dataset, or a sample item is not a
                number, ``np.ndarray`` or Tensor.
        """
        inputs_spec = []
        labels_spec = []
        # IterableDataset is tested before Dataset, as in the original code.
        if isinstance(data, paddle.io.IterableDataset):
            sample = next(iter(data))
        elif isinstance(data, paddle.io.Dataset):
            sample = data[0]
        else:
            raise TypeError(
                "Data should be a Dataset or IterableDataset, but received {}.".format(
                    type(data).__name__
                )
            )
        if split is None:
            inputs, labels = sample
        else:
            inputs = sample[:split]
            labels = sample[split:]
        inputs = to_list(inputs)
        labels = to_list(labels)

        num_shards = self._strategy.dataset.num_shards

        def _adjust_item_spec(num_shards, spec):
            # Scale the leading dim by the shard count (only for rank > 1).
            if num_shards > 1 and len(spec.shape) > 1:
                spec.shape[0] = spec.shape[0] * num_shards

        def _infer_item_spec(item, name, batch_size, specs):
            # Append the spec inferred from one sample item to ``specs``.
            if isinstance(item, np.ndarray):
                spec = InputSpec.from_numpy(item, name)
                if batch_size is None:
                    _adjust_item_spec(num_shards, spec)
                    specs.append(spec)
                else:
                    specs.append(spec.batch(batch_size))
            elif isinstance(item, (Variable, core.VarBase, core.eager.Tensor)):
                spec = InputSpec.from_tensor(item, name)
                _adjust_item_spec(num_shards, spec)
                if batch_size is None:
                    specs.append(spec)
                else:
                    specs.append(spec.batch(batch_size))
            elif isinstance(item, numbers.Number):
                specs.append(InputSpec([batch_size], type(item), name))
            else:
                raise TypeError(
                    "The sample's dtype returned of dataset should be number, np.ndarray or Tensor, but got {}".format(
                        type(item).__name__
                    )
                )

        if inputs is not None:
            for i, item in enumerate(inputs):
                assert item is not None, "Receive None input."
                name = "input" + str(i)
                _infer_item_spec(item, name, batch_size, inputs_spec)
        if labels is not None:
            for i, item in enumerate(labels):
                assert item is not None, "Receive None label."
                name = "label" + str(i)
                _infer_item_spec(item, name, batch_size, labels_spec)

        inputs_spec = self._validate_spec(inputs_spec)
        labels_spec = self._validate_spec(labels_spec)
        return inputs_spec, labels_spec

Z
zhaoyingli 已提交
302
    def _prepare_data_tensor(self, inputs_spec, labels_spec, inputs, labels):
        """Make feed variables' shapes agree with their ``InputSpec``s.

        Only valid in static graph mode. For each non-empty spec list, the
        matching variable list must be a list of equal length; a variable
        whose shape differs from its spec is reshaped in place through
        ``desc.set_shape``.

        Returns:
            tuple: ``(inputs, labels)``, the very lists that were passed in.

        Raises:
            ValueError: When called in dynamic graph mode.
        """
        if _non_static_mode() or self._dygraph_mode:
            raise ValueError("Only support static graph mode.")

        groups = (
            ("inputs", inputs_spec, inputs),
            ("labels", labels_spec, labels),
        )
        for group, specs, variables in groups:
            if not specs:
                continue
            assert isinstance(
                specs, list
            ), "{} should be list, but received {}".format(group, type(specs))
            assert isinstance(
                variables, list
            ), "{} should be list, but received {}".format(
                group, type(variables)
            )
            assert len(specs) == len(
                variables
            ), "the number of `{}_spec` should be equal to `{}`'s.".format(
                group, group
            )
            for spec, variable in zip(specs, variables):
                if spec.shape != variable.shape:
                    variable.desc.set_shape(spec.shape)

        return inputs, labels

339
    def _prepare_reader(self, feed_list=None):
        """Move the reader ops of the current mode's program to its front.

        Operates on the distributed main program of ``self._mode`` for
        ``self._cur_rank``. Reader ops left at the front by an earlier
        fit/evaluate/predict run are removed first. The remaining reader ops
        are then re-created at the start of the block, keeping their relative
        order, and the originals are deleted, on both the Python side and the
        C++ desc side. If the main program has a pipeline config
        (``_pipeline_opt``), the read ops and the variables in ``feed_list``
        are also added to the program of ``fleet_opt["tasks"][0]``.

        Args:
            feed_list (list|None): Variables to clone into that task program
                in the pipeline case. None (the default) means an empty list.
        """
        # None instead of a shared mutable default argument.
        if feed_list is None:
            feed_list = []

        dist_context = self._dist_contexts[self._mode]
        dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
        dist_main_block = dist_main_prog.global_block()

        # NOTE: this list may be changed if Paddle changes the existing rules.
        related_reader_ops = [
            "create_py_reader",
            "create_double_buffer_reader",
            "read",
        ]
        # remove the first three ops if multiple run fit/evaluate/predict
        if dist_main_block.ops[0].type == 'create_py_reader':
            for _ in range(len(related_reader_ops)):
                if dist_main_block.ops[0].type in related_reader_ops:
                    dist_main_block._remove_op(0, sync=False)
        dist_main_block._sync_with_cpp()
        # Step 1: find the reader ops
        reader_op_indices = []
        for idx, op in enumerate(dist_main_block.ops):
            if op.type in related_reader_ops:
                reader_op_indices.append(idx)
        # Step 2: insert the new reader ops to cpp
        # record the read ops' desc to insert to program of forward task_node
        read_ops_desc = []
        new_reader_ops = []
        # Prepending in reverse index order keeps the original relative order.
        for idx in reversed(reader_op_indices):
            new_op_desc = dist_main_block.desc._prepend_op()
            new_op_desc.copy_from(dist_main_block.ops[idx].desc)
            read_ops_desc.append(new_op_desc)
            new_op = Operator(
                dist_main_block, new_op_desc, type=new_op_desc.type()
            )
            new_reader_ops.append(new_op)
            dist_op = DistributedOperator(new_op)
            dist_context.add_dist_op_for_program(dist_op)
        # Step 3: insert the new reader ops to python
        for new_op in new_reader_ops:
            dist_main_block.ops.insert(0, new_op)
        # The old reader ops were shifted right by the number of new ops.
        for i in range(len(reader_op_indices)):
            reader_op_indices[i] += len(reader_op_indices)
        # Step 4: remove the old reader ops from python and cpp
        for idx in reversed(reader_op_indices):
            dist_main_block.ops.pop(idx)
            dist_main_block.desc._remove_op(idx, idx + 1)
        dist_main_block._sync_with_cpp()
        self._has_prepared_reader[self._mode] = True

        # Insert read op to forward TaskNode if 1F1B pass is set
        if self.main_program._pipeline_opt:
            assert "tasks" in self.main_program._pipeline_opt["fleet_opt"]
            fleet_opt = self.main_program._pipeline_opt["fleet_opt"]
            fwd_task = fleet_opt["tasks"][0]
            fwd_prog = fwd_task.get_program()
            fwd_block = fwd_prog.global_block()

            for var in feed_list:
                if var.name not in fwd_block.vars:
                    fwd_block._clone_variable(var)

            for op_desc in read_ops_desc:
                new_op_desc = fwd_block.desc._prepend_op()
                new_op_desc.copy_from(op_desc)
                new_op = Operator(
                    fwd_block, new_op_desc, type=new_op_desc.type()
                )
                fwd_block.ops.insert(0, new_op)

            fwd_block._sync_with_cpp()
            fwd_task.set_program(fwd_prog)

410 411 412 413 414 415 416 417 418 419 420 421 422 423 424
    def _prepare_feed(self, data, user_feeds, mode):
        feeds = {}
        if data is not None:
            if isinstance(data, (list, tuple)):
                if len(data) == 1 and isinstance(data[0], dict):
                    for name, data in data[0].items():
                        feeds[name] = data
                else:
                    raise ValueError("Unsupported data {}".format(data))
            elif isinstance(data, dict):
                for name, data in data.items():
                    feeds[name] = data
            else:
                raise ValueError("Unsupported data {}".format(data))
        if user_feeds is not None:
Z
zhaoyingli 已提交
425 426 427 428 429
            assert isinstance(
                user_feeds, dict
            ), "user_feeds must be a dict, but receive {}".format(
                type(user_feeds).__name__
            )
430 431 432 433 434 435
            for name, data in user_feeds.items():
                feeds[name] = data
        return feeds

    def _prepare_fetch(self, user_fetches, mode):
        """Collect the names to fetch and group their positions for logging.

        Groups are appended in this order: ``loss`` then one group per metric
        (non-predict modes), or ``outputs`` (predict mode), and finally the
        user fetches registered in ``CollectionNames.FETCHES``. Only
        variables accepted by ``self._is_local_var`` are fetched, and a name
        appears in ``fetch_names`` at most once.

        Args:
            user_fetches (list|None): Extra variables/names to fetch; each is
                registered through ``fetch`` before the collection is read.
            mode (str): Key into ``self._dist_contexts``.

        Returns:
            tuple: ``(fetch_names, fetch_indices)``, where ``fetch_indices``
            holds, per group, the positions of its names in ``fetch_names``.
        """
        if user_fetches is None:
            user_fetches = []
        else:
            assert isinstance(
                user_fetches, list
            ), "user_fetches must be a list, but receive {}".format(
                type(user_fetches).__name__
            )

        fetch_names = []
        fetch_indices = []

        def _collect(var_list):
            # Register one fetch group; duplicates reuse the existing slot.
            positions = []
            for var in var_list:
                if not self._is_local_var(var):
                    continue
                var_name = _to_name_str(var)
                if var_name not in fetch_names:
                    fetch_names.append(var_name)
                positions.append(fetch_names.index(var_name))
            if not positions:
                # An empty-list placeholder is appended for groups without
                # local vars. NOTE(review): confirm the executor's handling of
                # this entry does not shift the positions of later groups.
                fetch_names.append([])
            fetch_indices.append(positions)

        serial_fetch_vars = self._dist_contexts[mode].serial_fetch_vars
        if mode == "predict":
            _collect(serial_fetch_vars["outputs"])
        else:
            _collect(serial_fetch_vars["loss"])
            for metric_var_list in serial_fetch_vars["metrics"]:
                _collect(metric_var_list)

        for usr_fetch in user_fetches:
            fetch(_to_name_str(usr_fetch))
        _collect([item[1] for item in get_collection(CollectionNames.FETCHES)])
        return fetch_names, fetch_indices

Z
zhaoyingli 已提交
479 480 481 482 483 484 485 486 487 488
    def _prepare_logger(
        self,
        outs,
        epoch=None,
        step=None,
        lr=None,
        fetch_names=None,
        fetch_indices=None,
        mode=None,
    ):
        """Build the per-step logs dict from executor outputs.

        Args:
            outs (list): Executor results, positionally indexed by the
                entries of ``fetch_indices``.
            epoch (int|None): Stored as ``logs["epoch"]`` when given.
            step (int|None): Zero-based step, stored as ``step + 1``.
            lr (float|None): Stored as ``logs["lr"]`` when given.
            fetch_names (list): Fetched names, presumably as returned by
                ``_prepare_fetch``; used to locate user fetches in ``outs``.
            fetch_indices (list[list[int]]): Per-group positions into
                ``outs`` (loss, then metrics; or outputs in predict mode).
            mode (str): "train", "eval" or "predict".

        Returns:
            dict: ``epoch``/``step``/``lr`` when given; ``loss`` and metric
            results (non-predict) or ``outputs`` (predict); and ``fetches``.

        Note:
            Metrics are updated in place via ``metric.update``.
        """
        logs = {}
        if epoch is not None:
            logs["epoch"] = epoch
        if step is not None:
            logs["step"] = step + 1
        if lr is not None:
            logs["lr"] = lr
        group_idx = 0
        if mode != "predict":
            # logging loss
            loss_indices = fetch_indices[group_idx]
            assert len(loss_indices) <= 1
            for idx in loss_indices:
                logs["loss"] = outs[idx][0]
            group_idx += 1
            # logging metrics
            dist_context = self._dist_contexts[mode]
            metric_vars = dist_context.serial_fetch_vars["metrics"]
            if metric_vars:
                for metric in self._metrics:
                    metrics_indices = fetch_indices[group_idx]
                    metric_out = [outs[idx] for idx in metrics_indices]
                    if metric_out:
                        metric.update(*metric_out)
                        results = metric.accumulate()
                        # Metric.name() may return a single string rather than
                        # a list; indexing a str would yield one character per
                        # result as the log key, so normalize to a list.
                        metric_names = to_list(metric.name())
                        for i, res in enumerate(to_list(results)):
                            logs[metric_names[i]] = res
                    group_idx += 1
        else:
            # logging outputs
            outputs_indices = fetch_indices[group_idx]
            logs_out = {}
            for idx in outputs_indices:
                logs_out["out%d" % (idx)] = outs[idx]
            logs["outputs"] = logs_out
            group_idx += 1
        # logging user fetches
        collect_fetches = get_collection(CollectionNames.FETCHES)
        logs_fetch = {}
        for name, var_name in collect_fetches:
            if var_name in fetch_names:
                idx = fetch_names.index(var_name)
                logs_fetch[name or var_name] = outs[idx]
        logs["fetches"] = logs_fetch
        return logs

537
    def _prepare_program(self, mode, init_parameters=True):
538
        # Do the build process
539 540 541 542
        self._build(mode)
        # Do the planning process
        self._plan(mode)
        # Do the parallel process
543
        self._parallel(mode)
544 545 546 547 548
        # Init comm
        self._init_comm()
        if init_parameters:
            # startup program
            self._initialize(mode)
549
        self._has_prepared[mode] = True
550

551
    def _build(self, mode):
        """Build the serial programs for ``mode`` and create its contexts.

        In dynamic graph mode (or once ``self._dygraph_mode`` is set), the
        model is converted with ``ProgramHelper`` (to_static) and static mode
        is re-enabled afterwards. In static mode, the original default
        programs are cloned and the model, loss and metrics are traced into
        them. That happens unless ``self._skip_build`` is set; in that case,
        for "train", ``self._loss`` must already be a ``Variable``. Static
        mode returns early when a context for ``mode`` already exists.

        Side effects: sets ``self._inputs``, ``self._labels``,
        ``self._losses``, ``self._dist_contexts[mode]``,
        ``self._fwd_dist_contexts[mode]`` and ``self._fwd_main_progs[mode]``.
        It may also create the world process group on the default
        distributed context.
        """
        if _non_static_mode() or self._dygraph_mode:
            paddle.disable_static()
            self._dygraph_mode = True
            self._logger.info("Building model with 'to_static' method.")

            try:
                self.program_helper = ProgramHelper(
                    self._model,
                    self._loss,
                    self._metrics,
                    self._inputs_spec,
                    self._labels_spec,
                )
                # build forward main program
                self.program_helper.build_program(mode)

                self.concrete_program = self.program_helper.concrete_program
                serial_main_prog = self.program_helper.main_program
                serial_startup_prog = self.program_helper.startup_program

                self._inputs = self.program_helper.input_vars
                self._labels = self.program_helper.label_vars
                outputs = self.program_helper.output_vars
                self._losses = self.program_helper.loss_vars
                metrics = self.program_helper.metric_vars
            finally:
                # Always switch back to static mode, even if building fails,
                # so the global graph mode is not left as dynamic.
                paddle.enable_static()
        else:
            # build program in static mode
            dist_context = self._dist_contexts.get(mode, None)
            if dist_context is not None:
                # Already built for this mode.
                return

            outputs = []
            metrics = []
            self._losses = []
            serial_main_prog = self._orig_main_prog.clone()
            serial_startup_prog = self._orig_startup_prog.clone()
            if not self._skip_build:
                with static.program_guard(
                    serial_main_prog, serial_startup_prog
                ), utils.unique_name.guard():
                    self._inputs = [
                        s._create_feed_layer() for s in self._inputs_spec
                    ]
                    self._labels = [
                        s._create_feed_layer() for s in self._labels_spec
                    ]

                    outputs = to_list(self._model(*self._inputs))

                    if mode != "predict" and self._loss:
                        assert isinstance(
                            self._loss, paddle.nn.Layer
                        ) or callable(
                            self._loss
                        ), "the type of `loss` of the Engine arguments should be sub classes of `paddle.nn.Layer` or any callable function."
                        self._losses = to_list(
                            self._loss(*(outputs + self._labels))
                        )

                    if mode != "predict" and (outputs or self._labels):
                        for metric in self._metrics:
                            metrics.append(
                                to_list(
                                    metric.compute(*(outputs + self._labels))
                                )
                            )
            elif mode == "train":
                assert isinstance(
                    self._loss, Variable
                ), "the type of `loss` of the Engine arguments should be Variable."
                self._losses = to_list(self._loss)

        default_ctx = get_default_distributed_context()
        if not default_ctx.has_annotation:
            # We build the world process group because the data parallel
            # needs all ranks by default.
            new_process_group(list(range(self._nranks)))
            default_ctx.data_parallel = True

        feed_vars = {"inputs": self._inputs, "labels": self._labels}

        fetch_vars = {
            "outputs": flatten(outputs),
            "loss": self._losses,
            "metrics": metrics,
        }

        if mode != "train":
            serial_main_prog = serial_main_prog.clone(for_test=True)

        self._set_recompute_ckpts()
        self._dist_contexts[mode] = DistributedContext(
            serial_main_prog,
            serial_startup_prog,
            self._optimizer,
            self._losses,
            feed_vars,
            fetch_vars,
            self._cluster,
            self._strategy,
        )
        self._fwd_dist_contexts[mode] = DistributedContext(
            serial_main_prog,
            serial_startup_prog,
            self._optimizer,
            self._losses,
            feed_vars,
            fetch_vars,
            self._cluster,
            self._strategy,
        )
        self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale
        self._fwd_main_progs[mode] = serial_main_prog.clone()
666

667 668 669
    def _optimization_tuning(self, mode, dataset, batch_size):
        if not self._tuning.enable:
            raise ValueError("Please set `tuning.enable=True`.")
670

671 672 673 674 675 676 677 678
        assert mode == "train"
        # Do the build process
        self._build(mode)
        # Do the planning process
        self._plan(mode)

        dataset.dp_world_size = self._dp_world_sizes
        dataset.dp_rank = self._dp_ranks
679 680

        from .tuner.optimization_tuner import OptimizationTuner
Z
zhaoyingli 已提交
681 682 683 684 685 686 687 688 689 690

        self._optimization_tuner = OptimizationTuner(
            self._tuning.to_dict(),
            self._dist_contexts[mode],
            dataset,
            self._inputs_spec,
            self._labels_spec,
            batch_size=batch_size,
            rank=self._cur_rank,
        )
691 692 693

        self._optimization_tuner.tune()

694
        if self._tuning.run_after_tuning:
695 696
            # update the strategy
            self._dist_contexts[
Z
zhaoyingli 已提交
697 698
                mode
            ]._strategy = self._optimization_tuner.get_best_config()
699

700 701 702 703 704 705
    def _plan(self, mode):
        """Plan ``mode`` and record its per-feed data-parallel split info.

        The first planned mode is remembered in ``self._planned_mode``. Later
        modes first copy its op dist attrs via ``_init_dist_context``. After
        planning, ``self._dp_world_sizes`` and ``self._dp_ranks`` hold one
        entry per feed variable found in the serial main program.
        """
        if self._planned_mode is not None:
            self._init_dist_context(mode)
        else:
            self._planned_mode = mode

        planner = Planner(mode, self._dist_contexts[mode])
        self._planners[mode] = planner
        planner.plan()

        # infer data parallel info
        dist_context = self._dist_contexts[mode]
        feed_vars = dist_context.serial_feed_vars
        block = dist_context.serial_main_program.global_block()
        # TODO: check this feed_list
        feed_list = [
            block.vars[var.name]
            for var in feed_vars["inputs"] + feed_vars["labels"]
            if var.name in block.vars
        ]

        self._dp_world_sizes = []
        self._dp_ranks = []
        for feed_var in feed_list:
            world_size, rank = get_input_split_info(
                self._cur_rank, feed_var, dist_context
            )
            self._dp_world_sizes.append(world_size)
            self._dp_ranks.append(rank)
727

728
    def _parallel(self, mode, all_ranks=False):
        """Parallelize the planned program of ``mode``.

        Args:
            mode (str): Key into ``self._planners``/``self._dist_contexts``.
            all_ranks (bool): Parallelize for every rank instead of only
                ``self._cur_rank``. Default: False.
        """
        # The planner's completer is handed over because it may be needed to
        # complete the annotation of the backward and update parts.
        completer = self._planners[mode].completer
        parallelizer = Parallelizer(mode, completer, self._dist_contexts[mode])
        if all_ranks:
            parallelizer.parallel_all()
        else:
            parallelizer.parallel(self._cur_rank)
739 740

    def _init_dist_context(self, mode):
741
        # Init dist_context['mode'] with the first planned dist_context
742 743 744 745 746 747 748 749 750 751
        # to guarantee that train/eval/predict mode have same parallel strategy
        dist_context = self._dist_contexts[mode]
        origin_main_prog = dist_context._original_serial_main_program
        ref_mode = self._planned_mode
        ref_dist_context = self._dist_contexts[ref_mode]
        ref_origin_main_prog = ref_dist_context._original_serial_main_program
        ref_blocks = ref_origin_main_prog.blocks
        for ib, block in enumerate(origin_main_prog.blocks):
            for iop, op in enumerate(block.ops):
                ref_op = ref_blocks[ib].ops[iop]
Z
zhaoyingli 已提交
752 753 754 755 756 757 758 759
                assert (
                    op.type == ref_op.type
                ), "'{}' mode op '{}' is different with '{}' op '{}'. ".format(
                    mode, op.type, ref_mode, ref_op.type
                )
                ref_op_dist_attr = (
                    ref_dist_context.get_op_dist_attr_for_program(ref_op)
                )
760 761
                dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr)

762
    def _init_comm(self):
763 764 765 766
        if self._nranks > 1:
            # Traverse different rank programs and traverse each op of them,
            # instantiate communication by process_mapping.
            all_process_groups = get_all_process_groups()
767

768
            if self._strategy.auto_mode == "full":
769
                initialize_pg_in_full_mode(all_process_groups, self._cur_rank)
770 771 772 773 774
            else:
                for process_group in all_process_groups:
                    if self._cur_rank not in process_group.ranks:
                        continue
                    process_group.instantiate()
775

776
    def _initialize(self, mode):
        """Prepare the executor and parameters for running ``mode``.

        Steps:
          1. Choose the execution place; on CUDA the device id comes from
             ``ParallelEnv().dev_id``.
          2. If ``self._strategy.seed`` is truthy, seed paddle, numpy and
             ``random`` with ``seed + dp_rank``, where ``dp_rank`` is the
             first entry of ``self._dp_ranks`` (0 when that list is empty).
          3. In dygraph mode, initialize ``self.program_helper`` with this
             rank's distributed main program.
          4. On the first call (no executor yet), create the executor, run the
             startup program pruned to the variables that are not yet
             initialized in the global scope, and apply a pending state dict
             when both ``_state_dict`` and ``_dist_attr`` are present.
          5. If ``self._strategy.reinit`` is set, re-run the whole startup
             program of this rank.
        """
        place = _get_device()
        if isinstance(place, fluid.CUDAPlace):
            place = fluid.CUDAPlace(ParallelEnv().dev_id)

        if self._strategy.seed:
            # ``_dp_ranks`` is empty when planning found no feed var; fall
            # back to a zero offset instead of raising IndexError.
            dp_rank = self._dp_ranks[0] if self._dp_ranks else 0
            seed = self._strategy.seed + dp_rank
            paddle.seed(seed)
            np.random.seed(seed)
            random.seed(seed)

        dist_context = self._dist_contexts[mode]
        if self._dygraph_mode:
            dist_main_program = dist_context.dist_main_programs[self._cur_rank]
            self.program_helper.init(dist_main_program, place, dist_context)

        if self._executor is None:
            self._executor = paddle.static.Executor(place)
            dist_startup_prog = dist_context.dist_startup_programs[
                self._cur_rank
            ]
            # Only run initializers for variables that are not already
            # initialized in the global scope, so existing values are kept.
            scope = global_scope()
            uninitialized = []
            for var in dist_startup_prog.list_vars():
                scope_var = scope.find_var(var.name)
                if scope_var and scope_var.get_tensor()._is_initialized():
                    continue
                uninitialized.append(var)
            if uninitialized:
                prune_startup_prog = dist_startup_prog._prune(uninitialized)
                self._executor.run(prune_startup_prog)

            if hasattr(self, "_state_dict") and hasattr(self, "_dist_attr"):
                self._set_state_dict(
                    mode, self._strict, self._state_dict, self._dist_attr
                )

        if self._strategy.reinit:
            self._logger.info("NOTE: parameters will be re-initialized.")
            dist_startup_prog = dist_context.dist_startup_programs[
                self._cur_rank
            ]
            self._executor.run(dist_startup_prog)

Z
zhaoyingli 已提交
818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835
    def fit(
        self,
        train_data,
        train_sample_split=None,
        batch_size=1,
        epochs=1,
        steps_per_epoch=None,
        log_freq=10,
        save_dir=None,
        save_freq=1,
        valid_data=None,
        valid_sample_split=None,
        valid_freq=1,
        valid_steps=None,
        collate_fn=None,
        callbacks=None,
        verbose=2,
    ):
        """
        Trains the model for a fixed number of epochs. If `valid_data` is set,
        evaluation is run after every `valid_freq` epochs.

        Args:
            train_data (Dataset): An instance of paddle.io.Dataset.
            train_sample_split (int, optional): Each sample of the train dataset is assumed
                to be a (input, label) pair by default and has two items. If each sample has
                more than two items, train_sample_split specifies how to split these items into
                input and label. The items before it are input and the left are label. Default: None.
            batch_size (int, optional): The batch size of train_data and valid_data if provided.
                The user's data will be used directly without batching if set to None. Default: 1.
            epochs (int, optional): The number of epochs to train the model. Default: 1.
            steps_per_epoch (int, optional): The total number of steps (batches of samples)
                executed in one epoch before starting the next one. If None, it is equal to
                the number of samples in your dataset divided by the batch size. Default: None.
            log_freq (int, optional): Forwarded to ``config_callbacks`` as ``log_freq``;
                presumably the number of steps between log lines. Default: 10.
            save_dir (str, optional): Forwarded to ``config_callbacks`` as ``save_dir``;
                presumably the checkpoint directory. Default: None.
            save_freq (int, optional): Forwarded to ``config_callbacks`` as ``save_freq``;
                presumably the checkpoint frequency. Default: 1.
            valid_data (Dataset, optional): An instance of paddle.io.Dataset used for
                evaluation. No evaluation will be done if it is falsy. Default: None.
            valid_sample_split (int, optional): Only relevant if valid_data is provided.
                Each sample of the valid dataset is assumed to be a (input, label) pair
                by default and has two items. If each sample has more than two items,
                valid_sample_split specifies how to split these items into input and label.
                The items before it are input and the left are label. Default: None.
            valid_freq (int, optional): Only relevant if valid_data is provided. This specifies
                how many training epochs pass before a new evaluation is performed. Default: 1.
            valid_steps (int, optional): Only relevant if valid_data is provided.
                It is the total number of steps (batches of samples) to draw before
                stopping validation at the end of every epoch. If None, validation will run until the
                `valid_data` dataset is exhausted. The validation will start from the
                beginning of the dataset at each epoch. Default: None.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list, None for only stack each fields of sample in axis
                0. Default None.
            callbacks (Callback|None, optional): A list of `Callback` instances passed to
                ``config_callbacks`` and invoked during training. Default: None.
            verbose (int, optional): Forwarded to ``config_callbacks`` as ``verbose``.
                Default: 2.

        Returns:
            The engine's ``history`` attribute. NOTE(review): it is not assigned in
            this method; presumably one of the callbacks built by ``config_callbacks``
            fills it in — confirm.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                train_dataset = MNIST(mode='train', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                optimizer = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=model.parameters())
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, optimizer, metrics)
                engine.fit(train_dataset,
                           epochs=2,
                           batch_size=64)
        """
        self._mode = 'train'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            train_data, train_sample_split, batch_size
        )
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        train_dataloader = self._prepare_dataloader_from_generator(
            dataset=train_data,
            capacity=70,
            iterable=False,
            batch_size=batch_size,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            collate_fn=collate_fn,
        )

        fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)

        cbks = config_callbacks(
            callbacks,
            engine=self,
            batch_size=batch_size,
            epochs=epochs,
            steps=train_dataloader._steps,
            log_freq=log_freq,
            save_freq=save_freq,
            save_dir=save_dir,
            verbose=verbose,
            metrics=self._metrics_name(),
            acc_step=self._k_steps,
        )

        cbks.on_begin('train')
        # Bound before the loop so `on_end` still receives logs when
        # epochs == 0 (previously an UnboundLocalError).
        logs = {}
        for epoch in range(epochs):
            logs = {}
            cbks.on_epoch_begin(epoch)
            for step, _ in enumerate(train_dataloader):
                cbks.on_batch_begin('train', step, logs)
                try:
                    outs = self._executor.run(
                        self.main_program,
                        fetch_list=fetch_names,
                        use_program_cache=self._strategy.use_cache,
                        return_numpy=self._strategy.return_numpy,
                    )
                except core.EOFException:
                    # The non-iterable reader signals the end of an epoch.
                    break
                lr = get_lr(self.optimizer)
                logs = self._prepare_logger(
                    outs,
                    epoch,
                    step,
                    lr,
                    fetch_names,
                    fetch_indices,
                    self._mode,
                )
                cbks.on_batch_end('train', step, logs)

            if valid_data and (epoch + 1) % valid_freq == 0:
                val_logs = self.evaluate(
                    valid_data,
                    valid_sample_split=valid_sample_split,
                    batch_size=batch_size,
                    steps=valid_steps,
                    log_freq=log_freq,
                    collate_fn=collate_fn,
                    callbacks=callbacks,
                    verbose=verbose,
                )
                val_logs = {
                    "val_" + name: val for name, val in val_logs.items()
                }
                logs.update(val_logs)
                # evaluate() switched the engine to eval mode.
                self._switch_mode("train")
            else:
                self._reset_metrics()

            cbks.on_epoch_end(epoch, logs)

        cbks.on_end('train', logs)
        return self.history
987

Z
zhaoyingli 已提交
988 989 990 991 992 993 994 995 996 997 998
    def evaluate(
        self,
        valid_data,
        valid_sample_split=None,
        batch_size=1,
        steps=None,
        log_freq=10,
        collate_fn=None,
        callbacks=None,
        verbose=2,
    ):
        """
        Evaluate the loss and metrics of the model on evaluation data.

        Args:
            valid_data (Dataset): An instance of paddle.io.Dataset.
            valid_sample_split (int, optional): Each sample of the eval dataset is assumed
                to be a (input, label) pair by default and has two items. If each sample has
                more than two items, valid_sample_split specifies how to split these items into
                input and label. The items before it are input and the left are label. Default: None.
            batch_size (int, optional): The batch size of valid_data. The user's data will
                be used directly without batching if set to None. Default: 1.
            steps (int, optional): It is the total number of steps (batches of samples) to draw before
                stopping evaluation. If None, evaluation will run until the `valid_data` dataset is exhausted.
                The evaluation will start from the beginning of the dataset in each run. Default: None.
            log_freq (int, optional): Forwarded to ``config_callbacks`` as ``log_freq``;
                presumably the number of steps between log lines. Default: 10.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list, None for only stack each fields of sample in axis
                0. Default None.
            callbacks (Callback|None, optional): A list of `Callback` instances passed to
                ``config_callbacks`` and invoked during evaluation. Default: None.
            verbose (int, optional): Forwarded to ``config_callbacks`` as ``verbose``.
                Default: 2.

        Returns:
            dict: the logs built by ``_prepare_logger`` for the last executed step,
            or an empty dict if no step ran. Metrics are reset before returning.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                valid_dataset = MNIST(mode='test', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, metrics=metrics)
                engine.evaluate(valid_dataset, batch_size=64)

        """
        self._mode = 'eval'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            valid_data, valid_sample_split, batch_size
        )
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        valid_dataloader = self._prepare_dataloader_from_generator(
            dataset=valid_data,
            capacity=70,
            iterable=False,
            batch_size=batch_size,
            steps_per_epoch=steps,
            collate_fn=collate_fn,
        )

        fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)

        cbks = config_callbacks(
            callbacks,
            engine=self,
            batch_size=batch_size,
            log_freq=log_freq,
            verbose=verbose,
            metrics=self._metrics_name(),
        )

        eval_steps = valid_dataloader._steps
        cbks.on_begin(
            'eval', {'steps': eval_steps, 'metrics': self._metrics_name()}
        )
        logs = {}
        for step, _ in enumerate(valid_dataloader):
            cbks.on_batch_begin('eval', step, logs)
            try:
                outs = self._executor.run(
                    self.main_program,
                    fetch_list=fetch_names,
                    use_program_cache=self._strategy.use_cache,
                    return_numpy=self._strategy.return_numpy,
                )
            except core.EOFException:
                # The non-iterable reader signals the end of the data.
                break
            logs = self._prepare_logger(
                outs, None, step, None, fetch_names, fetch_indices, self._mode
            )
            cbks.on_batch_end('eval', step, logs)
        cbks.on_end('eval', logs)
        self._reset_metrics()
        return logs
1097

Z
zhaoyingli 已提交
1098 1099 1100 1101 1102 1103 1104 1105 1106 1107
    def predict(
        self,
        test_data,
        test_sample_split=None,
        batch_size=1,
        steps=None,
        collate_fn=None,
        callbacks=None,
        verbose=2,
    ):
        """
        Compute the output predictions on testing data.

        Args:
            test_data (Dataset): An instance of paddle.io.Dataset.
            test_sample_split (int, optional): Each sample of the test dataset is assumed
                to be a (input, label) pair by default and has two items. If each sample has
                more than two items, test_sample_split specifies how to split these items into
                input and label. The items before it are input and the left are label. Default: None.
            batch_size (int, optional): The batch size of test_data. The user's data will
                be used directly without batching if set to None. Default: 1.
            steps (int, optional): It is the total number of steps (batches of samples) to draw before
                stopping predict. If None, predict will run until the `test_data` dataset is exhausted.
                The predict will start from the beginning of the dataset in each run. Default: None.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list, None for only stack each fields of sample in axis
                0. Default None.
            callbacks (Callback|None, optional): A list of `Callback` instances passed to
                ``config_callbacks`` and invoked during prediction. Default: None.
            verbose (int, optional): Forwarded to ``config_callbacks`` as ``verbose``.
                Default: 2.

        Returns:
            list: one entry per executed step; each entry is
            ``list(logs["outputs"].values())`` for the logs that
            ``_prepare_logger`` built for that step.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                valid_dataset = MNIST(mode='test', transform=transform)

                model = paddle.vision.models.LeNet()

                engine = auto.Engine(model)
                engine.predict(valid_dataset, batch_size=64)
        """
        self._mode = 'predict'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            test_data, test_sample_split, batch_size
        )
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        test_dataloader = self._prepare_dataloader_from_generator(
            dataset=test_data,
            capacity=70,
            iterable=False,
            batch_size=batch_size,
            steps_per_epoch=steps,
            collate_fn=collate_fn,
        )

        fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)

        outputs = []
        cbks = config_callbacks(callbacks, engine=self, verbose=verbose)
        test_steps = test_dataloader._steps
        cbks.on_begin('predict', {'steps': test_steps})
        logs = {}
        for step, _ in enumerate(test_dataloader):
            cbks.on_batch_begin('predict', step, logs)
            try:
                outs = self._executor.run(
                    self.main_program,
                    fetch_list=fetch_names,
                    use_program_cache=self._strategy.use_cache,
                    return_numpy=self._strategy.return_numpy,
                )
            except core.EOFException:
                # The non-iterable reader signals the end of the data.
                break
            logs = self._prepare_logger(
                outs, None, step, None, fetch_names, fetch_indices, self._mode
            )
            cbks.on_batch_end('predict', step, logs)
            outputs.append(list(logs["outputs"].values()))
        cbks.on_end('predict', logs)
        return outputs
1194

Z
zhaoyingli 已提交
1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211
    def dataloader(
        self,
        dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        collate_fn=None,
        num_workers=0,
        use_buffer_reader=True,
        use_shared_memory=True,
        timeout=0,
        worker_init_fn=None,
        epochs=1,
        steps_per_epoch=None,
        sample_split=1,
        mode=None,
    ):
        """Create a distributed dataloader over ``dataset`` for the current mode.

        If ``mode`` is given, the engine switches to it first. The input and
        label specs are derived from ``dataset``, ``sample_split`` and
        ``batch_size``. The mode's program is prepared on first use and only
        switched to afterwards. All loader options are forwarded unchanged
        to ``_prepare_dataloader``, always with ``return_list=False``.

        Returns:
            The dataloader produced by ``_prepare_dataloader``.
        """
        if mode is not None:
            self.to_mode(mode)

        data_specs = self._prepare_data_spec(dataset, sample_split, batch_size)
        self._inputs_spec, self._labels_spec = data_specs

        if self._has_prepared[self._mode]:
            self._switch_mode(self._mode)
        else:
            self._prepare_program(self._mode)

        loader_options = {
            "return_list": False,
            "batch_size": batch_size,
            "shuffle": shuffle,
            "drop_last": drop_last,
            "collate_fn": collate_fn,
            "num_workers": num_workers,
            "use_buffer_reader": use_buffer_reader,
            "use_shared_memory": use_shared_memory,
            "timeout": timeout,
            "worker_init_fn": worker_init_fn,
            "epochs": epochs,
            "steps_per_epoch": steps_per_epoch,
        }
        return self._prepare_dataloader(dataset, **loader_options)
1238

Z
zhaoyingli 已提交
1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253
    def dataloader_from_generator(
        self,
        dataset,
        capacity=70,
        use_double_buffer=True,
        iterable=True,
        use_multiprocess=False,
        drop_last=True,
        batch_size=1,
        epochs=1,
        steps_per_epoch=None,
        collate_fn=None,
        sample_split=1,
        mode=None,
    ):
        """Create a generator-based distributed dataloader for the current mode.

        If ``mode`` is given, the engine switches to it first. The input and
        label specs are derived from ``dataset``, ``sample_split`` and
        ``batch_size``. The mode's program is prepared on first use and only
        switched to afterwards. All loader options are forwarded to
        ``_prepare_dataloader_from_generator`` with ``return_list=False``.

        Returns:
            The dataloader produced by ``_prepare_dataloader_from_generator``.
        """
        if mode is not None:
            self.to_mode(mode)

        data_specs = self._prepare_data_spec(dataset, sample_split, batch_size)
        self._inputs_spec, self._labels_spec = data_specs

        if self._has_prepared[self._mode]:
            self._switch_mode(self._mode)
        else:
            self._prepare_program(self._mode)

        loader_options = {
            "dataset": dataset,
            "capacity": capacity,
            "use_double_buffer": use_double_buffer,
            "iterable": iterable,
            "return_list": False,
            "use_multiprocess": use_multiprocess,
            "drop_last": drop_last,
            "batch_size": batch_size,
            "epochs": epochs,
            "steps_per_epoch": steps_per_epoch,
            "collate_fn": collate_fn,
        }
        return self._prepare_dataloader_from_generator(**loader_options)

Z
zhaoyingli 已提交
1279 1280 1281 1282 1283 1284 1285 1286 1287
    def prepare(
        self,
        inputs_spec=None,
        labels_spec=None,
        inputs=None,
        labels=None,
        main_program=None,
        startup_program=None,
        mode=None,
        init_parameters=True,
    ):
        """Prepare the program of a mode, or switch to it if already prepared.

        Args:
            inputs_spec / labels_spec: specs of the input / label data,
                validated by ``_validate_spec``.
            inputs / labels: existing input / label variables, validated by
                ``_validate_vars``. When any is given, model building is
                skipped (``self._skip_build = True``).
            main_program / startup_program: programs to use. When inputs,
                labels or specs are given and a program is None, paddle's
                default main / startup program is used instead.
            mode (str|None): switch to this mode first when given.
            init_parameters (bool): forwarded to ``_prepare_program``.

        Raises:
            ValueError: if no mode is set.
            AssertionError: if no inputs, labels or specs are passed and no
                specs were recorded earlier (e.g. by ``dataloader(...)``).
        """
        if mode is not None:
            self.to_mode(mode)

        if not self._mode:
            raise ValueError(
                "Please set mode to be prepared with `prepare(mode=...)`"
            )

        if self._has_prepared[self._mode]:
            return

        inputs_spec = self._validate_spec(inputs_spec)
        labels_spec = self._validate_spec(labels_spec)
        inputs = self._validate_vars(inputs)
        labels = self._validate_vars(labels)

        def _fill_default_programs():
            # Fall back to paddle's default programs when none were given.
            if self._orig_main_prog is None:
                self._orig_main_prog = static.default_main_program()
            if self._orig_startup_prog is None:
                self._orig_startup_prog = static.default_startup_program()

        self._orig_main_prog = main_program
        self._orig_startup_prog = startup_program
        if inputs or labels:
            self._skip_build = True
            inputs, labels = self._prepare_data_tensor(
                inputs_spec, labels_spec, inputs, labels
            )
            _fill_default_programs()
        elif inputs_spec or labels_spec:
            self._outside_dataloader = True
            _fill_default_programs()
        else:
            assert (
                self._inputs_spec and self._labels_spec
            ), "Please call the dataloader(...) before calling prepare(...)"
            # Keep the specs recorded earlier; the validated (empty)
            # arguments would otherwise overwrite the specs just checked.
            inputs_spec, labels_spec = self._inputs_spec, self._labels_spec

        self._inputs_spec, self._labels_spec = inputs_spec, labels_spec
        self._inputs, self._labels = inputs, labels
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode, init_parameters)
        else:
            self._switch_mode(self._mode)
1334 1335 1336 1337 1338 1339

    def run(self, data=None, feed=None, fetch_list=None, mode=None):
        """Execute the main program once and return the processed logs.

        Args:
            data: data converted into the feed dict by ``_prepare_feed``.
            feed (dict|None): explicit feed, also given to ``_prepare_feed``.
            fetch_list (list|None): extra fetch targets, resolved by
                ``_prepare_fetch``.
            mode (str|None): switch to this mode before running when given.

        Returns:
            The logs built by ``_prepare_logger`` from the fetched outputs.
        """
        if mode is not None:
            self.to_mode(mode)

        feed_dict = self._prepare_feed(data, feed, self._mode)
        fetch_names, fetch_indices = self._prepare_fetch(fetch_list, self._mode)

        # With an outside dataloader the reader is set up lazily, once per mode.
        reader_missing = (
            self._outside_dataloader
            and not self._has_prepared_reader[self._mode]
        )
        if reader_missing:
            self._prepare_reader()

        strategy = self._strategy
        outs = self._executor.run(
            self.main_program,
            feed=feed_dict,
            fetch_list=fetch_names,
            use_program_cache=strategy.use_cache,
            return_numpy=strategy.return_numpy,
        )
        return self._prepare_logger(
            outs, None, None, None, fetch_names, fetch_indices, self._mode
        )

Z
zhaoyingli 已提交
1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372
    def _prepare_dataloader(
        self,
        dataset,
        return_list=True,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        collate_fn=None,
        num_workers=0,
        use_buffer_reader=True,
        use_shared_memory=True,
        timeout=0,
        worker_init_fn=None,
        epochs=1,
        steps_per_epoch=None,
    ):
        """Build a ``DistributedDataLoader`` inside this rank's dist programs.

        When gradient merge is enabled, ``batch_size`` must be divisible by
        ``self._k_steps`` and is divided by it before reaching the loader.
        The feed list holds one variable per serial feed var (inputs, then
        labels), taken from this rank's dist main block.
        """
        if self._strategy.gradient_merge and batch_size is not None:
            k_steps = self._k_steps
            assert (
                batch_size % k_steps == 0
            ), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
                batch_size, k_steps
            )
            batch_size //= k_steps

        dist_context = self._dist_contexts[self._mode]
        rank = self._cur_rank
        dist_main_prog = dist_context.dist_main_programs[rank]
        dist_startup_prog = dist_context.dist_startup_programs[rank]
        main_block = dist_main_prog.global_block()

        def _resolve_feed_var(serial_var):
            # NOTE: the predict program has no label vars, so vars missing
            # from the dist main block are cloned from the serial program.
            # This keeps len(feed_list) equal to the number of dataset fields.
            if serial_var.name in main_block.vars:
                return main_block.vars[serial_var.name]
            clone = main_block._clone_variable(
                serial_var, serial_var.persistable
            )
            clone.desc.set_original_id(serial_var.desc.original_id())
            return clone

        serial_feeds = dist_context.serial_feed_vars
        feed_list = [
            _resolve_feed_var(serial_var)
            for serial_var in serial_feeds["inputs"] + serial_feeds["labels"]
        ]

        # The loader ops are inserted at the end of the dist programs.
        places = paddle.static.cuda_places()
        with static.program_guard(dist_main_prog, dist_startup_prog):
            return DistributedDataLoader(
                dataset,
                feed_list=feed_list,
                places=places,
                return_list=return_list,
                batch_size=batch_size,
                shuffle=shuffle,
                drop_last=drop_last,
                collate_fn=collate_fn,
                num_workers=num_workers,
                use_buffer_reader=use_buffer_reader,
                use_shared_memory=use_shared_memory,
                timeout=timeout,
                worker_init_fn=worker_init_fn,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                split_data=self._strategy.split_data,
                data_parallel_world_size=self._dp_world_sizes,
                data_parallel_rank=self._dp_ranks,
            )

Z
zhaoyingli 已提交
1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441
    def _prepare_dataloader_from_generator(
        self,
        dataset,
        capacity=None,
        use_double_buffer=True,
        iterable=True,
        return_list=False,
        use_multiprocess=False,
        drop_last=True,
        batch_size=1,
        epochs=1,
        steps_per_epoch=None,
        collate_fn=None,
    ):
        """Build a ``DistributedDataLoaderFromGenerator`` for this rank.

        When gradient merge is enabled, ``batch_size`` must be divisible by
        ``self._k_steps`` and is divided by it before reaching the loader.
        The feed list holds one variable per serial feed var (inputs, then
        labels), taken from this rank's dist main block. After the loader is
        built, ``_prepare_reader`` is called with that feed list.
        """
        if self._strategy.gradient_merge and batch_size is not None:
            k_steps = self._k_steps
            assert (
                batch_size % k_steps == 0
            ), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
                batch_size, k_steps
            )
            batch_size //= k_steps

        dist_context = self._dist_contexts[self._mode]
        rank = self._cur_rank
        dist_main_prog = dist_context.dist_main_programs[rank]
        dist_startup_prog = dist_context.dist_startup_programs[rank]
        main_block = dist_main_prog.global_block()

        def _resolve_feed_var(serial_var):
            # NOTE: the predict program has no label vars, so vars missing
            # from the dist main block are cloned from the serial program.
            # This keeps len(feed_list) equal to the number of dataset fields.
            if serial_var.name in main_block.vars:
                return main_block.vars[serial_var.name]
            clone = main_block._clone_variable(
                serial_var, serial_var.persistable
            )
            clone.desc.set_original_id(serial_var.desc.original_id())
            return clone

        serial_feeds = dist_context.serial_feed_vars
        feed_list = [
            _resolve_feed_var(serial_var)
            for serial_var in serial_feeds["inputs"] + serial_feeds["labels"]
        ]

        places = paddle.static.cuda_places()
        with static.program_guard(dist_main_prog, dist_startup_prog):
            loader = DistributedDataLoaderFromGenerator(
                dataset=dataset,
                feed_list=feed_list,
                capacity=capacity,
                use_double_buffer=use_double_buffer,
                iterable=iterable,
                return_list=return_list,
                use_multiprocess=use_multiprocess,
                drop_last=drop_last,
                places=places,
                batch_size=batch_size,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                collate_fn=collate_fn,
                split_data=self._strategy.split_data,
                data_parallel_world_size=self._dp_world_sizes,
                data_parallel_rank=self._dp_ranks,
            )
        self._prepare_reader(feed_list)
        return loader

    def _tune(self, tune_data, tune_sample_split=None, batch_size=1):
        self._mode = 'train'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
Z
zhaoyingli 已提交
1497 1498
            tune_data, tune_sample_split, batch_size
        )
1499 1500
        self._optimization_tuning(self._mode, tune_data, batch_size)

1501 1502
    def _validate_spec(self, specs):
        """Validate input specs and scale their batch dim for gradient merge.

        Also refreshes `self._k_steps` from
        `self._strategy.gradient_merge.k_steps`. When `k_steps > 1`, each
        spec's known batch dimension (`shape[0]`) is divided by `k_steps`, and
        the spec is updated in place. An unknown batch dimension (`None` or
        negative, i.e. dynamic) is left as is, and so is a spec with an empty
        shape; neither can be meaningfully divided.

        Args:
            specs (InputSpec|list|tuple|None): The spec(s) to validate.

        Returns:
            list: The validated specs, or an empty list when there are none.

        Raises:
            TypeError: If an item is not a `paddle.static.InputSpec`.
            ValueError: If an item has no name.
            AssertionError: If a known batch size is not divisible by k_steps.
        """
        specs = to_list(specs)
        self._k_steps = self._strategy.gradient_merge.k_steps
        if specs is not None:
            for i, spec in enumerate(specs):
                if not isinstance(spec, InputSpec):
                    raise TypeError(
                        "'spec' must be object of class `paddle.static.InputSpec`."
                    )
                if spec.name is None:
                    raise ValueError(
                        "Requires Input[{}].name != None, but receive `None` with {}.".format(
                            i, spec
                        )
                    )
                if self._k_steps > 1:
                    shape = list(spec.shape)
                    # A dynamic/unknown batch dim must stay unknown; `None % k`
                    # would raise and `-1 % k` would fail the assertion.
                    if not shape or shape[0] is None or shape[0] < 0:
                        continue
                    assert (
                        shape[0] % self._k_steps == 0
                    ), "Requires batch_size[{}] to be divisible by k_steps[{}].".format(
                        spec.shape[0], self._k_steps
                    )
                    shape[0] //= self._k_steps
                    spec.shape = shape
        return specs or []

    def _validate_vars(self, vars):
        """Normalize `vars` with `to_list` and check each item is a `Variable`.

        Returns:
            list: The normalized variables, or an empty list when there are none.

        Raises:
            TypeError: If any item is not a `Variable`.
        """
        vars = to_list(vars)
        if vars is not None:
            for var in vars:
                if not isinstance(var, Variable):
                    raise TypeError("'var' must be a `Variable`.")
        return vars or []

    def _is_local_var(self, var):
        """Return True if `var` is in the global block of `self.main_program`.

        `var` is reduced to a name with `_to_name_str` before the lookup.
        """
        var_name = _to_name_str(var)
        return var_name in self.main_program.global_block().vars

    def _set_recompute_ckpts(self):
        """Resolve recompute checkpoints and write them back to the strategy.

        NOTE: hack to enable recompute in engine api for GPT-3.
        TODO: support more PaddleNLP/CV models here.

        For a `paddle.nn.Layer` named 'GPTForPretraining' or
        'GPTForPretrainingAuto' that has a `gpt` attribute, the checkpoints
        come from `model.gpt.checkpoints`. Otherwise they come from
        `strategy.recompute.checkpoints`. When recompute is enabled, a copy
        of the resolved list is stored in the strategy and logged.
        """
        recompute = self._strategy.recompute
        model = self._model
        model_name = model.__class__.__name__

        # Extract ckpts from specific known models, else use the strategy's.
        is_known_gpt = (
            isinstance(model, paddle.nn.Layer)
            and hasattr(model, "gpt")
            and model_name in ('GPTForPretraining', 'GPTForPretrainingAuto')
        )
        exact_ckpts = (
            model.gpt.checkpoints if is_known_gpt else recompute.checkpoints
        )

        # Modify strategy.
        if recompute.enable:
            recompute.checkpoints = exact_ckpts[:]
            logs = {
                'Model Class': model_name,
                'Applied Recompute ckpts': exact_ckpts,
            }
            self._logger.info(logs)

    def _reset_metrics(self):
        for metric in self._metrics:
            metric.reset()
1571

1572 1573 1574 1575 1576 1577
    def _metrics_name(self):
        """Return the list of reported metric names.

        Starts with 'loss' when `self._loss` is truthy. It is followed by the
        name(s) returned by each metric's `name()`, normalized with `to_list`.
        """
        metrics_name = ['loss'] if self._loss else []
        for metric in self._metrics:
            metrics_name.extend(to_list(metric.name()))
        return metrics_name

    def _switch_mode(self, mode):
Z
zhaoyingli 已提交
1579
        assert (
1580
            mode in self._dist_contexts
Z
zhaoyingli 已提交
1581
        ), "{} model is not ready, please call `prepare()` first.".format(mode)
1582 1583 1584
        self.to_mode(mode)

    def to_mode(self, mode):
        """Set the engine's current mode.

        Args:
            mode (str): One of 'train', 'eval' or 'predict'.

        Raises:
            AssertionError: If `mode` is not one of the accepted values.
        """
        assert mode in [
            "train",
            "eval",
            "predict",
        ], "mode {} should be one of ['train', 'eval', 'predict']".format(mode)
        self._mode = mode

    def _set_state_dict(self, mode, strict, state_dict, dist_attr):
        """Load `state_dict` into the current rank's main program for `mode`.

        `state_dict` was saved under the distributed attributes `dist_attr`.
        It is converted by `Converter` to match the distributed attributes of
        the current program (from `get_dist_attr`), then set on the program.
        `strict` is forwarded to `Converter.convert`.
        """
        dist_context = self._dist_contexts[mode]
        program = dist_context.dist_main_programs[self._cur_rank]
        cur_dist_attr = get_dist_attr(program, dist_context)
        converter = Converter(state_dict, dist_attr, cur_dist_attr)
        state_dict = converter.convert(strict=strict)
        program.set_state_dict(state_dict)

    def save(self, path, training=True):
        """
        Saves the model, parameters, optimizer state to path.
        If `training` is set to False, only inference model will be saved.

        When `training` is True, the distributed context of the current mode
        is used. Otherwise, the 'predict' mode's context is used: its
        `serial_feed_vars['inputs']` and `serial_fetch_vars['outputs']` become
        the inference model's feeds and fetches.

        Args:
            path (str): The file prefix to save model. The format
                is 'dirname/file_prefix' or 'file_prefix'. If it is an
                empty string, an exception will be raised.
            training (bool, optional): Whether to save for training. If not, save
                for inference only. If `training` is set to True, the optimizer state
                will be saved. Otherwise, only the model and parameters are saved.
                This function will silently overwrite existing file at the target
                location. Default: True.

        Returns:
            None

        Raises:
            AssertionError: If the required mode ('predict' when `training`
                is False, else the current mode) has not been prepared.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                train_dataset = MNIST(mode='train', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                optimizer = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=model.parameters())
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, optimizer, metrics)
                engine.fit(train_dataset,
                           epochs=1,
                           batch_size=64)
                engine.save("./my_model")

        """
        if training:
            assert (
                self._mode in self._dist_contexts
            ), "Mode '{}' is not prepared, please call `prepare()` or `fit()` first.".format(
                self._mode
            )
            dist_context = self._dist_contexts[self._mode]
            serial_program = dist_context.serial_main_program
            dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
            self._saver.save(
                path,
                serial_program=serial_program,
                dist_main_program=dist_main_prog,
                dist_context=dist_context,
            )
        else:
            assert (
                "predict" in self._dist_contexts
            ), "Mode 'predict' is not prepared, please prepare it or call `predict()` first."
            dist_context = self._dist_contexts["predict"]
            feed_vars = dist_context.serial_feed_vars['inputs']
            fetch_vars = dist_context.serial_fetch_vars['outputs']
            dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
            self._saver.save_inference_model(
                path,
                feed_vars,
                fetch_vars,
                self._executor,
                program=dist_main_prog,
            )

    def load(self, path, strict=True, load_optimizer=True):
        """
        Load the stored model, parameters and optimizer states.

        The states are read through `self._saver.load`. They are stored on
        the engine as `self._state_dict` / `self._dist_attr`, and `strict` is
        stored as `self._strict`. They are not applied to any program here;
        presumably they are applied later, once the programs are built.
        TODO confirm.

        Args:
            path (str): The prefix of files storing the model states and
                optimizer states.
            strict (bool, optional): Whether to raise an error when a mismatch
                happens, or to skip the mismatched parameter. A mismatch means
                a parameter is not found in the stored model states, or its
                shape differs. Default: True.
            load_optimizer (bool, optional): If True, the stored optimizer
                states are restored. Otherwise, the optimizer states are
                initialized from scratch. Default: True.

        Returns:
            tuple: ``(state_dict, dist_attr)`` as returned by the saver.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                train_dataset = MNIST(mode='train', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                optimizer = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=model.parameters())
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, optimizer, metrics)
                engine.fit(train_dataset,
                           epochs=1,
                           batch_size=64)
                engine.save("./my_model")
                engine.load("./my_model")

        """
        self._strict = strict
        self._state_dict, self._dist_attr = self._saver.load(
            path, load_optimizer
        )
        return self._state_dict, self._dist_attr

    def cost(self, inputs_spec=None, labels_spec=None, mode=None):
        """
        Get and Print cost, including memory of every rank,
        max memory among all ranks, and the global cost of one step based on
        communication cost(computation cost is 0 by default).
        In the future, the flops information of every rank and global cost including
        computation cost will be added.

        If `mode` has not been prepared yet, there are three cases. When
        `inputs_spec` is given, the program is built and planned from the
        specs. In dygraph mode, an error is raised. Otherwise, the static
        default main program is used; it must contain ops.

        Args:
            inputs_spec(InputSpec): The specification of inputs. Default: None.
            labels_spec(InputSpec): The specification of labels. Default: None.
            mode (str): The engine mode must be in ["train", "predict", "eval"]. Default: None.

        Returns:
            Return the global execution time (ms) and max memory (B). Returns
            None when the strategy's auto mode is "full".

        Raises:
            ValueError: If `mode` is not an accepted mode. Also raised when
                nothing has been prepared and the cost cannot be estimated
                (dygraph mode, or an empty default main program).
        """
        # Check parallel mode
        if self._strategy.auto_mode == "full":
            self._logger.info(
                "The cost will be calculated in the search process when the auto mode is full."
            )
            return

        # Check mode
        mode = mode if mode is not None else self._mode
        assert mode is not None, "Please set mode."
        if mode not in self._has_prepared:
            raise ValueError(
                "The mode {} is not in accepted modes {}".format(
                    mode, list(self._has_prepared.keys())
                )
            )
        self.to_mode(mode)

        # An already-prepared mode needs no further checks. Previously a
        # prepared engine in dygraph mode still raised "please call prepare()".
        if not self._has_prepared[mode]:
            if inputs_spec is not None:
                self._inputs_spec = self._validate_spec(inputs_spec)
                self._labels_spec = self._validate_spec(labels_spec)
                self._build(mode)
                self._plan(mode)
            elif _non_static_mode() or self._dygraph_mode:
                raise ValueError(
                    "Please call `prepare()` or `fit()` or  `evaluate()` or  `predict()` before calling `cost()`."
                )
            else:
                self._logger.info(
                    "The program whose cost to be estimated must be static default program. Otherwise, please call `prepare()` before calling `cost()`."
                )
                program = paddle.static.default_main_program()
                if not program.global_block().ops:
                    raise ValueError(
                        "Please call `prepare()` or `fit()` or  `evaluate()` or  `predict()` before calling `cost()`."
                    )

        # Estimate the exec cost and max memory
        global_cost, max_memory = get_cost_from_engine(self, mode)

        return global_cost.time, max_memory

    @property
    def main_program(self):
        """The distributed main program of the current rank for the current mode."""
        dist_context = self._dist_contexts[self._mode]
        return dist_context.dist_main_programs[self._cur_rank]

    @property
    def startup_program(self):
        """The distributed startup program of the current rank for the current mode."""
        dist_context = self._dist_contexts[self._mode]
        return dist_context.dist_startup_programs[self._cur_rank]

    @property
    def dist_context(self):
        """The distributed context registered for the current mode."""
        return self._dist_contexts[self._mode]

    @property
    def serial_main_program(self):
        """The serial (non-partitioned) main program of the current mode."""
        dist_context = self._dist_contexts[self._mode]
        return dist_context.serial_main_program

    @property
    def serial_startup_program(self):
        """The serial (non-partitioned) startup program of the current mode."""
        dist_context = self._dist_contexts[self._mode]
        return dist_context.serial_startup_program

    @property
    def feed_vars(self):
        """The serial feed variables (e.g. 'inputs', 'labels') of the current mode."""
        dist_context = self._dist_contexts[self._mode]
        return dist_context.serial_feed_vars

    @property
    def fetch_vars(self):
        """The serial fetch variables (e.g. 'outputs') of the current mode."""
        dist_context = self._dist_contexts[self._mode]
        return dist_context.serial_fetch_vars

    @property
    def optimizer(self):
        """The optimizer in use for the current mode.

        Prefers the distributed context's `_serial_optimizer` when it is
        truthy. Otherwise, falls back to the optimizer the engine was built
        with.
        """
        dist_context = self._dist_contexts[self._mode]
        if dist_context._serial_optimizer:
            return dist_context._serial_optimizer
        return self._optimizer

    @property
    def inputs(self):
        """The engine's input variables (`self._inputs`)."""
        return self._inputs

    @property
    def labels(self):
        """The engine's label variables (`self._labels`)."""
        return self._labels