# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import logging
import numbers
import os
import random

import numpy as np

import paddle
import paddle.distributed.auto_parallel.static.utils as auto_utils
from paddle import static, utils
from paddle.distributed import fleet
from paddle.fluid.executor import _to_name_str
from paddle.framework import IrGraph
from paddle.framework import _current_expected_place as _get_device
from paddle.framework import core, in_dynamic_mode
from paddle.metric import Metric
from paddle.static import InputSpec, Operator, Variable, global_scope

from ...utils.log_utils import get_logger
from ..interface import CollectionNames, fetch, get_collection
from ..strategy import Strategy
from .callbacks import config_callbacks
from .cluster import Cluster, get_default_cluster
from .converter import Converter
from .cost.estimate_cost import get_cost_from_engine
from .dist_context import DistributedContext, get_default_distributed_context
from .dist_loader import (
    DistributedDataLoader,
    DistributedDataLoaderFromGenerator,
)
from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver
from .helper import ProgramHelper
from .parallelizer_v2 import Parallelizer
from .planner_v2 import Planner
from .process_group import get_all_process_groups, new_process_group


class Engine:
    """
    An Engine object can provide the full power of auto parallel to users.
    With its help, users can easily obtain the abilities of distributed
    training and inference. It also supports the dynamic graph and the
    static graph at the same time.

    Args:
        model (paddle.nn.Layer, optional): The model is an instance of
            paddle.nn.Layer.
        loss (Loss|Callable|None, optional): The loss can be a `paddle.nn.Layer`
            instance or any callable function that takes the predicted values and
            ground truth values as input. It can be None when there is no loss.
            Default: None.
        optimizer (Optimizer|None, optional): The optimizer needs to be set in training
            and should be None in eval and predict mode. Default: None.
        metrics (Metric|list[Metric]|None, optional): If metrics is set, all
            metrics will be calculated and output in train/eval mode. Default: None.
        cluster (Cluster|None, optional): The cluster represents the topology information
            about the used physical devices. Default: None. (Unused for now)
        strategy (Strategy|None, optional): The strategy is used to configure the
            parallelization and optimization behaviors. Default: None.

    Examples:

        .. code-block:: python

            import paddle
            import paddle.vision.transforms as T
            from paddle.distributed.fleet import auto
            from paddle.vision.datasets import MNIST

            transform = T.Compose([
                T.Transpose(),
                T.Normalize([127.5], [127.5])
            ])
            train_dataset = MNIST(mode='train', transform=transform)
            valid_dataset = MNIST(mode='test', transform=transform)

            model = paddle.vision.models.LeNet()
            loss = paddle.nn.CrossEntropyLoss()
            optimizer = paddle.optimizer.Adam(
                learning_rate=0.001, parameters=model.parameters())
            metrics = paddle.metric.Accuracy(topk=(1, 2))

            engine = auto.Engine(model, loss, optimizer, metrics)
            # fit
            engine.fit(train_dataset,
                       epochs=2,
                       batch_size=64)
            # evaluate
            engine.evaluate(valid_dataset,
                            batch_size=64)
            # predict
            engine.predict(valid_dataset,
                           batch_size=64)
            # save
            engine.save("./my_model")
            # load
            engine.load("./my_model")

    """

    def __init__(
        self,
        model=None,
        loss=None,
        optimizer=None,
        metrics=None,
        cluster=None,
        strategy=None,
    ):
        if (
            model
            and not isinstance(model, paddle.nn.Layer)
            and not callable(model)
        ):
            raise TypeError(
                "'model' must be a subclass of `paddle.nn.Layer` or any callable function."
            )
        self._model = model

        if (
            loss
            and not isinstance(loss, (paddle.nn.Layer, Variable))
            and not callable(loss)
        ):
            raise TypeError(
                "'loss' must be a subclass of `paddle.nn.Layer`, a callable function or a Variable."
            )
        self._loss = loss

        if optimizer and not isinstance(
            optimizer,
            (paddle.optimizer.Optimizer, paddle.static.Optimizer),
        ):
            raise TypeError(
                "'optimizer' must be an object of class `paddle.optimizer.Optimizer`"
                " or `paddle.static.Optimizer`."
            )
        self._optimizer = auto_utils.validate_opt(optimizer)

        metrics = metrics or []
        for metric in auto_utils.to_list(metrics):
            if metric and not isinstance(metric, Metric):
                raise TypeError(
                    "{} is not a subclass of Metric".format(
                        metric.__class__.__name__
                    )
                )
        self._metrics = auto_utils.to_list(metrics)

        if cluster and not isinstance(cluster, Cluster):
            raise TypeError(
                "'cluster' must be an instance of `paddle.distributed.auto_parallel.Cluster`"
            )

        if strategy and not isinstance(strategy, Strategy):
            raise TypeError(
                "'strategy' must be an instance of `paddle.distributed.auto_parallel.Strategy`"
            )
        self._strategy = strategy or Strategy()

        self._logger = get_logger(logging.INFO)

        self._json_config = None
        if cluster:
            self._cluster = cluster
        else:
            if os.getenv("PADDLE_AUTO_PARALLEL_CONFIG"):
                try:
                    path = os.getenv("PADDLE_AUTO_PARALLEL_CONFIG")
                    with open(path, "r") as f:
                        self._json_config = json.load(f)
                except Exception:
                    self._logger.info(
                        "Failed to load the json config file; please check it. The engine will run with the default config."
                    )
                    self._json_config = None
            self._cluster = get_default_cluster(self._json_config)

        if os.getenv("POD_NAME"):
            self._logger.info(
                "Distributed training launched by paddle.distributed.launch"
            )
            fleet.init(is_collective=True)

        # for compute cost
        # TODO: remove _fwd_main_progs and _orig_optimizer
        self._fwd_dist_contexts = {}
        self._fwd_main_progs = {}
        self._orig_optimizer = copy.deepcopy(self._optimizer)

        self._executor = None
        self._cur_rank = paddle.distributed.get_rank()
        self._nranks = paddle.distributed.get_world_size()
        self._saver = DistributedSaver()

        self._orig_main_prog = static.default_main_program()
        self._orig_startup_prog = static.default_startup_program()
        self._orig_dist_context = get_default_distributed_context()
        self._dist_contexts = {}
        self._planners = {}
        self._has_prepared = {"train": False, "eval": False, "predict": False}
        self._has_prepared_reader = {
            "train": False,
            "eval": False,
            "predict": False,
        }
        self._inputs_spec = []
        self._labels_spec = []
        self._inputs = []
        self._labels = []
        self._losses = []

        self._mode = None
        self._skip_build = False
        self._outside_dataloader = False
        self._planned_mode = None
        self._dygraph_mode = False
        self._tuning = self._strategy.tuning
        self._acc_steps = 1
        if self._strategy.gradient_merge.enable:
            self._acc_steps = self._strategy.gradient_merge.k_steps
        elif self._strategy.pipeline.enable:
            self._acc_steps = self._strategy.pipeline.accumulate_steps

        self.history = None

        paddle.framework.set_flags({'FLAGS_new_executor_sequential_run': 1})
        paddle.framework.set_flags({'FLAGS_new_executor_static_build': 1})

    def _prepare_data_spec(self, data, split, batch_size):
        inputs_spec = []
        labels_spec = []
        if isinstance(data, paddle.io.IterableDataset):
            if split is None:
                inputs, labels = next(iter(data))
            else:
                sample = next(iter(data))
                inputs = sample[:split]
                labels = sample[split:]
        elif isinstance(data, paddle.io.Dataset):
            if split is None:
                inputs, labels = data[0]
            else:
                sample = data[0]
                inputs = sample[:split]
                labels = sample[split:]
        else:
            raise TypeError(
                "Data should be a Dataset or IterableDataset, but received {}.".format(
                    type(data).__name__
                )
            )
        inputs = auto_utils.to_list(inputs)
        labels = auto_utils.to_list(labels)

        num_shards = self._strategy.dataset.num_shards

        def _adjust_item_spec(num_shards, spec):
            if num_shards > 1 and len(spec.shape) > 1:
                spec.shape[0] = spec.shape[0] * num_shards

        def _infer_item_spec(item, name, batch_size, specs):
            if isinstance(item, np.ndarray):
                spec = InputSpec.from_numpy(item, name)
                if batch_size is None:
                    _adjust_item_spec(num_shards, spec)
                    specs.append(spec)
                else:
                    specs.append(spec.batch(batch_size))
            elif isinstance(item, (Variable, core.eager.Tensor)):
                spec = InputSpec.from_tensor(item, name)
                _adjust_item_spec(num_shards, spec)
                if batch_size is None:
                    specs.append(spec)
                else:
                    specs.append(spec.batch(batch_size))
            elif isinstance(item, numbers.Number):
                specs.append(InputSpec([batch_size], type(item), name))
            else:
                raise TypeError(
                    "Each item of a dataset sample should be a number, np.ndarray or Tensor, but got {}".format(
                        type(item).__name__
                    )
                )

        if inputs is not None:
            for i, item in enumerate(inputs):
                assert item is not None, "Received a None input."
                name = "input" + str(i)
                _infer_item_spec(item, name, batch_size, inputs_spec)
        if labels is not None:
            for i, item in enumerate(labels):
                assert item is not None, "Received a None label."
                name = "label" + str(i)
                _infer_item_spec(item, name, batch_size, labels_spec)

        inputs_spec = self._validate_spec(inputs_spec)
        labels_spec = self._validate_spec(labels_spec)
        return inputs_spec, labels_spec
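
    # Illustration only, not part of the original file: for a dataset whose
    # samples look like (np.ndarray of shape [28, 28], np.int64 scalar label),
    # a call such as `self._prepare_data_spec(dataset, None, 64)` would be
    # expected to return roughly
    #     inputs_spec = [InputSpec(shape=(64, 28, 28), dtype=..., name='input0')]
    #     labels_spec = [InputSpec(shape=(64,), dtype=..., name='label0')]
    # where the leading dimension comes from `spec.batch(batch_size)`, and is
    # scaled by `num_shards` only when batch_size is None.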

    def _prepare_data_tensor(self, inputs_spec, labels_spec, inputs, labels):
        if in_dynamic_mode() or self._dygraph_mode:
            raise ValueError("Only static graph mode is supported.")

        if inputs_spec:
            assert isinstance(
                inputs_spec, list
            ), "inputs_spec should be a list, but received {}".format(
                type(inputs_spec)
            )
            assert isinstance(
                inputs, list
            ), f"inputs should be a list, but received {type(inputs)}"
            assert len(inputs_spec) == len(
                inputs
            ), "the number of `inputs_spec` should be equal to `inputs`'s."
            for input_spec, input in zip(inputs_spec, inputs):
                if input_spec.shape != input.shape:
                    input.desc.set_shape(input_spec.shape)
        if labels_spec:
            assert isinstance(
                labels_spec, list
            ), "labels_spec should be a list, but received {}".format(
                type(labels_spec)
            )
            assert isinstance(
                labels, list
            ), f"labels should be a list, but received {type(labels)}"
            assert len(labels_spec) == len(
                labels
            ), "the number of `labels_spec` should be equal to `labels`'s."
            for label_spec, label in zip(labels_spec, labels):
                if label_spec.shape != label.shape:
                    label.desc.set_shape(label_spec.shape)

        return inputs, labels

    def _prepare_reader(self, feed_list=[]):
        dist_context = self._dist_contexts[self._mode]
        dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
        dist_main_block = dist_main_prog.global_block()

        # NOTE: this list may be changed if Paddle changes the existing rules.
        related_reader_ops = [
            "create_py_reader",
            "create_double_buffer_reader",
            "read",
        ]
        # remove the first three ops if fit/evaluate/predict is run multiple times
        if dist_main_block.ops[0].type == 'create_py_reader':
            for i in range(len(related_reader_ops)):
                if dist_main_block.ops[0].type in related_reader_ops:
                    dist_main_block._remove_op(0, sync=False)
        dist_main_block._sync_with_cpp()
        # Step 1: find the reader ops
        reader_op_indices = []
        for idx, op in enumerate(dist_main_block.ops):
            if op.type in related_reader_ops:
                reader_op_indices.append(idx)
        # Step 2: insert the new reader ops to cpp
        # record the read ops' desc to insert into the program of the forward task_node
        read_ops_desc = []
        new_reader_ops = []
        for idx in reversed(reader_op_indices):
            new_op_desc = dist_main_block.desc._prepend_op()
            new_op_desc.copy_from(dist_main_block.ops[idx].desc)
            read_ops_desc.append(new_op_desc)
            new_op = Operator(
                dist_main_block, new_op_desc, type=new_op_desc.type()
            )
            new_reader_ops.append(new_op)
            dist_op = DistributedOperator(new_op)
            dist_context.add_dist_op_for_program(dist_op)
        # Step 3: insert the new reader ops to python
        for new_op in new_reader_ops:
            dist_main_block.ops.insert(0, new_op)
        for i in range(len(reader_op_indices)):
            reader_op_indices[i] += len(reader_op_indices)
        # Step 4: remove the old reader ops from python and cpp
        for idx in reversed(reader_op_indices):
            op = dist_main_block.ops.pop(idx)
            dist_main_block.desc._remove_op(idx, idx + 1)
        dist_main_block._sync_with_cpp()
        self._has_prepared_reader[self._mode] = True

        # Insert the read ops into the forward TaskNode for the fleet executor if the 1F1B pass is set
        if (
            self.main_program._pipeline_opt
            and not auto_utils.use_new_executor()
        ):
            assert "tasks" in self.main_program._pipeline_opt["fleet_opt"]
            fleet_opt = self.main_program._pipeline_opt["fleet_opt"]
            fwd_task = None
            if self._strategy.pipeline.schedule_mode == "1F1B":
                fwd_task = fleet_opt["tasks"][1]
            elif self._strategy.pipeline.schedule_mode == "stream":
                fwd_task = fleet_opt["tasks"][0]
            assert fwd_task is not None
            fwd_prog = fwd_task.get_program()
            fwd_block = fwd_prog.global_block()

            for var in feed_list:
                if var.name not in fwd_block.vars:
                    fwd_block._clone_variable(var)

            for op_desc in read_ops_desc:
                new_op_desc = fwd_block.desc._prepend_op()
                new_op_desc.copy_from(op_desc)
                new_op = Operator(
                    fwd_block, new_op_desc, type=new_op_desc.type()
                )
                fwd_block.ops.insert(0, new_op)

            fwd_block._sync_with_cpp()
            fwd_task.set_program(fwd_prog)
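
    # Illustration only, not part of the original file: after `_prepare_reader`
    # runs, the reader ops are expected to sit at the very front of the
    # distributed main block, roughly
    #     create_py_reader -> create_double_buffer_reader -> read -> <model ops>
    # so that data is fed before any compute op executes, regardless of where
    # earlier passes had left those ops.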

    def _prepare_feed(self, data, user_feeds, mode):
        feeds = {}
        if data is not None:
            if isinstance(data, (list, tuple)):
                if len(data) == 1 and isinstance(data[0], dict):
                    for name, value in data[0].items():
                        feeds[name] = value
                else:
                    raise ValueError(f"Unsupported data {data}")
            elif isinstance(data, dict):
                for name, value in data.items():
                    feeds[name] = value
            else:
                raise ValueError(f"Unsupported data {data}")
        if user_feeds is not None:
            assert isinstance(
                user_feeds, dict
            ), "user_feeds must be a dict, but received {}".format(
                type(user_feeds).__name__
            )
            for name, data in user_feeds.items():
                feeds[name] = data
        return feeds
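
    # Illustration only, not part of the original file (the names below are
    # hypothetical): given a dataloader batch of `[{'input0': x, 'label0': y}]`
    # and `user_feeds={'custom_mask': m}`, `_prepare_feed` would merge them into
    #     {'input0': x, 'label0': y, 'custom_mask': m}
    # with user-provided entries overriding data entries of the same name.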

    def _prepare_fetch(self, user_fetches, mode):
        if user_fetches is not None:
            assert isinstance(
                user_fetches, list
            ), "user_fetches must be a list, but received {}".format(
                type(user_fetches).__name__
            )
        fetch_names = []
        fetch_indices = []

        def _process_fetch_group(group_name, var_list):
            group_indices = []
            for var in var_list:
                # Remove duplicate var_names
                if self._is_local_var(var):
                    var_name = _to_name_str(var)
                    if var_name not in fetch_names:
                        fetch_names.append(var_name)
                    group_indices.append(fetch_names.index(var_name))
            fetch_indices.append(group_indices)

        dist_context = self._dist_contexts[mode]
        fetch_vars = dist_context.serial_fetch_vars
        if mode != "predict":
            _process_fetch_group("loss", fetch_vars["loss"])
            metrics = fetch_vars["metrics"]
            for i, var_list in enumerate(metrics):
                _process_fetch_group("metrics_" + str(i), var_list)
        if mode == "predict":
            _process_fetch_group("outputs", fetch_vars["outputs"])
        for usr_fetch in user_fetches or []:
            var_name = _to_name_str(usr_fetch)
            fetch(var_name)
        user_fetches_collection = [
            item[1] for item in get_collection(CollectionNames.FETCHES)
        ]
        var_list = user_fetches_collection or []
        _process_fetch_group("fetches", var_list)
        return fetch_names, fetch_indices
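
    # Illustration only, not part of the original file: in "train" mode with a
    # single loss and one metric, the result might look like
    #     fetch_names   = ['loss_var', 'metric_var_0', 'metric_var_1']
    #     fetch_indices = [[0], [1, 2], []]  # groups: loss, metrics_0, fetches
    # `_prepare_logger` later walks `fetch_indices` group by group to slice the
    # executor outputs.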

    def _prepare_logger(
        self,
        outs,
        epoch=None,
        step=None,
        lr=None,
        fetch_names=None,
        fetch_indices=None,
        mode=None,
    ):
        logs = {}
        if epoch is not None:
            logs["epoch"] = epoch
        if step is not None:
            logs["step"] = step + 1
        if lr is not None:
            logs["lr"] = lr
        group_idx = 0
        if mode != "predict":
            # logging loss
            loss_indices = fetch_indices[group_idx]
            assert len(loss_indices) <= 1
            for idx in loss_indices:
                logs["loss"] = outs[idx]
            group_idx += 1
            # logging metrics
            dist_context = self._dist_contexts[mode]
            metric_vars = dist_context.serial_fetch_vars["metrics"]
            if metric_vars:
                for metric in self._metrics:
                    metrics_indices = fetch_indices[group_idx]
                    metric_out = []
                    for idx in metrics_indices:
                        metric_out.append(outs[idx])
                    if metric_out:
                        metric.update(*metric_out)
                        results = metric.accumulate()
                        for i, res in enumerate(auto_utils.to_list(results)):
                            logs[metric.name()[i]] = res
                    group_idx += 1
        # logging outputs
        elif mode == "predict":
            outputs_indices = fetch_indices[group_idx]
            logs_out = {}
            for idx in outputs_indices:
                logs_out["out%d" % (idx)] = outs[idx]
            logs["outputs"] = logs_out
            group_idx += 1
        # logging user fetches
        collect_fetches = get_collection(CollectionNames.FETCHES)
        logs_fetch = {}
        for name, var_name in collect_fetches:
            if var_name in fetch_names:
                idx = fetch_names.index(var_name)
                logs_fetch[name or var_name] = outs[idx]
        logs["fetches"] = logs_fetch
        return logs
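
    # Illustration only, not part of the original file: in "train" mode the
    # returned `logs` dict might look like
    #     {'epoch': 0, 'step': 11, 'lr': 0.001, 'loss': array(2.31),
    #      'acc_top1': 0.55, 'acc_top2': 0.78, 'fetches': {}}
    # where the metric keys are taken from `Metric.name()`.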

    def _prepare_program(self, mode, init_parameters=True):
        # Do the build process
        self._build(mode)
        # Do the planning process
        self._plan(mode)
        # Do the parallel process
        self._parallel(mode)
        # Init comm
        self._init_comm()
        if init_parameters:
            # startup program
            self._initialize(mode)
        self._has_prepared[mode] = True
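
    # Illustration only, not part of the original file: a sketch of how the
    # preparation is triggered, assuming a fresh engine with no prepared mode:
    #     engine = auto.Engine(model, loss, optimizer)
    #     engine.fit(train_dataset, batch_size=64)  # calls _prepare_program('train')
    #     engine.evaluate(valid_dataset)            # calls _prepare_program('eval')
    # Later calls with an already prepared mode go through `_switch_mode`
    # instead of rebuilding the program.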

    def _build(self, mode):
        if in_dynamic_mode() or self._dygraph_mode:
            paddle.disable_static()
            self._dygraph_mode = True
            self._logger.info("Building model with 'to_static' method.")

            self.program_helper = ProgramHelper(
                self._model,
                self._loss,
                self._metrics,
                self._inputs_spec,
                self._labels_spec,
            )
            # build forward main program
            with utils.unique_name.guard():
                self.program_helper.build_program(mode)

            self.concrete_program = self.program_helper.concrete_program
            serial_main_prog = self.program_helper.main_program
            serial_startup_prog = self.program_helper.startup_program

            self._inputs = self.program_helper.input_vars
            self._labels = self.program_helper.label_vars
            outputs = self.program_helper.output_vars
            self._losses = self.program_helper.loss_vars
            metrics = self.program_helper.metric_vars

            paddle.enable_static()
        else:
            # build program in static mode
            dist_context = self._dist_contexts.get(mode, None)
            if dist_context is not None:
                return

            outputs = []
            metrics = []
            self._losses = []
            serial_main_prog = self._orig_main_prog.clone()
            serial_startup_prog = self._orig_startup_prog.clone()
            if not self._skip_build:
                with static.program_guard(
                    serial_main_prog, serial_startup_prog
                ), utils.unique_name.guard():
                    self._inputs = [
                        s._create_feed_layer() for s in self._inputs_spec
                    ]
                    self._labels = [
                        s._create_feed_layer() for s in self._labels_spec
                    ]

                    outputs = auto_utils.to_list(self._model(*self._inputs))

                    if mode != "predict" and self._loss:
                        assert isinstance(
                            self._loss, paddle.nn.Layer
                        ) or callable(
                            self._loss
                        ), "the `loss` argument of the Engine should be a subclass of `paddle.nn.Layer` or a callable function."
                        self._losses = auto_utils.to_list(
                            self._loss(*(outputs + self._labels))
                        )

                    if mode != "predict" and (outputs or self._labels):
                        for metric in self._metrics:
                            metrics.append(
                                auto_utils.to_list(
                                    metric.compute(*(outputs + self._labels))
                                )
                            )
            elif mode == "train":
                assert isinstance(
                    self._loss, Variable
                ), "the `loss` argument of the Engine should be a Variable."
                self._losses = auto_utils.to_list(self._loss)

        default_ctx = get_default_distributed_context()
        if not default_ctx.has_annotation:
            # We build the world process group because data parallel
            # needs all ranks by default.
            new_process_group(list(range(self._nranks)))
            default_ctx.data_parallel = True
            self._inputs = [
                auto_utils.set_data_parallel(var) for var in self._inputs
            ]
            self._labels = [
                auto_utils.set_data_parallel(var) for var in self._labels
            ]

        feed_vars = {"inputs": self._inputs, "labels": self._labels}

        fetch_vars = {
            "outputs": paddle.utils.flatten(outputs),
            "loss": self._losses,
            "metrics": metrics,
        }

        if mode != "train":
            serial_main_prog = serial_main_prog.clone(for_test=True)

        auto_utils.set_recompute_segments(
            self._model, self._losses, self._strategy, serial_main_prog
        )
        self._dist_contexts[mode] = DistributedContext(
            serial_main_prog,
            serial_startup_prog,
            self._optimizer,
            self._losses,
            feed_vars,
            fetch_vars,
            self._cluster,
            self._strategy,
            self._json_config,
        )
        self._fwd_dist_contexts[mode] = DistributedContext(
            serial_main_prog,
            serial_startup_prog,
            self._optimizer,
            self._losses,
            feed_vars,
            fetch_vars,
            self._cluster,
            self._strategy,
            self._json_config,
        )
        self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale
        self._fwd_main_progs[mode] = serial_main_prog.clone()
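
    # Illustration only, not part of the original file: after `_build('train')`
    # the serial context holds dicts shaped roughly like
    #     feed_vars  = {"inputs": [input0, ...], "labels": [label0, ...]}
    #     fetch_vars = {"outputs": [...], "loss": [loss_var], "metrics": [[...]]}
    # which `_prepare_fetch` and `_prepare_logger` consume later on.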

    def _optimization_tuning(self, mode, dataset, batch_size):
        if not self._tuning.enable:
            raise ValueError("Please set `tuning.enable=True`.")

        assert mode == "train"
        # Do the build process
        self._build(mode)
        # Do the planning process
        self._plan(mode)

        dataset.dp_world_size = self._dp_world_sizes
        dataset.dp_rank = self._dp_ranks

        from .tuner.optimization_tuner import OptimizationTuner

        self._optimization_tuner = OptimizationTuner(
            self._dist_contexts[mode],
            dataset,
            self._inputs_spec,
            self._labels_spec,
            batch_size=batch_size,
            rank=self._cur_rank,
        )

        self._optimization_tuner.tune()

        if self._tuning.run_after_tuning:
            # update the strategy
            self._dist_contexts[
                mode
            ]._strategy = self._optimization_tuner.get_best_config()

    def _plan(self, mode):
        if self._planned_mode is None:
            self._planned_mode = mode
        else:
            self._init_dist_context(mode)

        self._planners[mode] = Planner(mode, self._dist_contexts[mode])
        self._planners[mode].plan()

        # infer data parallel info
        inputs_var = self._dist_contexts[mode].serial_feed_vars["inputs"]
        labels_var = self._dist_contexts[mode].serial_feed_vars["labels"]
        block = self._dist_contexts[mode].serial_main_program.global_block()
        # TODO: check this feed_list
        feed_list = []
        for var in inputs_var + labels_var:
            if var.name in block.vars:
                feed_list.append(block.vars[var.name])

        self._dp_world_sizes = []
        self._dp_ranks = []
        for feed_var in feed_list:
            dp_world_size, dp_rank = auto_utils.get_input_split_info(
                self._cur_rank, feed_var, self._dist_contexts[mode]
            )
            self._dp_world_sizes.append(dp_world_size)
            self._dp_ranks.append(dp_rank)
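
    # Illustration only, not part of the original file: with 8 ranks and every
    # feed var split across a data-parallel group of size 4, each entry of
    # `_dp_world_sizes` would be 4 and each entry of `_dp_ranks` this rank's
    # position in [0, 4); these values are later used to shard the dataset and
    # to offset the random seed in `_initialize`.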

    def _parallel(self, mode, all_ranks=False):
        # Parallelize program based on the planner's results
        # For now, the completer has to be passed to the Parallelizer,
        # because we may use it to complete the annotation of the backward and update.
        parallelizer = Parallelizer(
            mode,
            self._planners[mode].completer,
            self._dist_contexts[mode],
        )
        if not all_ranks:
            parallelizer.parallel(self._cur_rank)
        else:
            parallelizer.parallel_all()

    def _init_dist_context(self, mode):
        # Init dist_context['mode'] with the first planned dist_context
        # to guarantee that train/eval/predict modes have the same parallel strategy
        dist_context = self._dist_contexts[mode]
        origin_main_prog = dist_context._original_serial_main_program
        ref_mode = self._planned_mode
        ref_dist_context = self._dist_contexts[ref_mode]
        ref_origin_main_prog = ref_dist_context._original_serial_main_program
        ref_blocks = ref_origin_main_prog.blocks
        for ib, block in enumerate(origin_main_prog.blocks):
            for iop, op in enumerate(block.ops):
                ref_op = ref_blocks[ib].ops[iop]
                assert (
                    op.type == ref_op.type
                ), "'{}' mode op '{}' is different from '{}' op '{}'.".format(
                    mode, op.type, ref_mode, ref_op.type
                )
                ref_op_dist_attr = (
                    ref_dist_context.get_op_dist_attr_for_program(ref_op)
                )
                dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr)

    def _init_comm(self):
        if self._nranks > 1:
            # Traverse the programs of different ranks and each of their ops,
            # instantiating communication groups by process_mapping.
            all_process_groups = get_all_process_groups()

            if self._strategy.auto_mode == "full_random":
                auto_utils.initialize_pg_in_full_mode(
                    all_process_groups, self._cur_rank
                )
            else:
                for process_group in all_process_groups:
                    process_group.instantiate()

    def _initialize(self, mode):
        self._place = _get_device()
        if isinstance(self._place, paddle.framework.CUDAPlace):
            self._place = paddle.framework.CUDAPlace(
                paddle.distributed.ParallelEnv().dev_id
            )

        if self._strategy.seed:
            paddle.seed(self._strategy.seed + self._dp_ranks[0])
            np.random.seed(self._strategy.seed + self._dp_ranks[0])
            random.seed(self._strategy.seed + self._dp_ranks[0])

        dist_context = self._dist_contexts[mode]
        if self._dygraph_mode:
            dist_main_program = dist_context.dist_main_programs[self._cur_rank]
            self.program_helper.init(
                dist_main_program, self._place, dist_context
            )

        if self._executor is None:
            self._executor = paddle.static.Executor(self._place)
            uninitialized = []
            dist_startup_prog = dist_context.dist_startup_programs[
                self._cur_rank
            ]
            for var in dist_startup_prog.list_vars():
                scope_var = global_scope().find_var(var.name)
                if scope_var and scope_var.get_tensor()._is_initialized():
                    continue
                uninitialized.append(var)
            if uninitialized:
                prune_startup_prog = dist_startup_prog._prune(uninitialized)
                self._executor.run(prune_startup_prog)

            if hasattr(self, "_state_dict") and hasattr(self, "_dist_attr"):
                self._set_state_dict(
                    mode, self._strict, self._state_dict, self._dist_attr
                )

        if self._strategy.reinit:
            self._logger.info("NOTE: parameters will be re-initialized.")
            dist_startup_prog = dist_context.dist_startup_programs[
                self._cur_rank
            ]
            self._executor.run(dist_startup_prog)

    def fit(
        self,
        train_data,
        train_sample_split=None,
        batch_size=1,
        epochs=1,
        steps_per_epoch=None,
        log_freq=10,
        save_dir=None,
        save_freq=1,
        valid_data=None,
        valid_sample_split=None,
        valid_freq=1,
        valid_steps=None,
        collate_fn=None,
        callbacks=None,
        verbose=2,
        nvprof_range=[-1, -1],
    ):
        """
        Trains the model for a fixed number of epochs. If `valid_data` is set,
        evaluation will be done at the end of each epoch.

        Args:
            train_data (Dataset): An instance of paddle.io.Dataset. Default: None.
            train_sample_split (int, optional): Each sample of the train dataset is assumed
                to be a (input, label) pair by default and has two items. If each sample has
                more than two items, train_sample_split specifies how to split these items into
                input and label. The items before it are the input and the rest are the label.
                Default: None.
            batch_size (int, optional): The batch size of train_data and valid_data if provided.
                The user's data will be used directly without batching if set to None. Default: 1.
            epochs (int, optional): The number of epochs to train the model. Default: 1.
            steps_per_epoch (int, optional): The total number of steps (batches of samples)
                executed in one epoch before starting the next one. If None, it is equal to
                the number of samples in your dataset divided by the batch size. Default: None.
            valid_data (Dataset, optional): An instance of paddle.io.Dataset used for
                evaluation at the end of each epoch. No evaluation will be done if set to None.
                Default: None. (Unsupported for now)
            valid_freq (int, optional): Only relevant if valid_data is provided. This specifies
                how many training epochs are run before a new evaluation is performed. Default: 1.
            valid_sample_split (int, optional): Only relevant if valid_data is provided.
                Each sample of the valid dataset is assumed to be a (input, label) pair
                by default and has two items. If each sample has more than two items,
                valid_sample_split specifies how to split these items into input and label.
                The items before it are the input and the rest are the label. Default: None.
            valid_steps (int, optional): Only relevant if valid_data is provided.
                It is the total number of steps (batches of samples) to draw before
                stopping validation at the end of every epoch. If None, validation will run until the
                `valid_data` dataset is exhausted. The validation will start from the
                beginning of the dataset at each epoch. Default: None.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list, None for only stacking each field of the samples along axis
                0. Default None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during training. Default: None. (Unused for now)
            nvprof_range(list, optional): A list of integers indicating nvprof ranges in the
                form of [start_step, end_step]. Note that if start_step >= end_step, nvprof
                will not be applied.

        Returns:
            The training `history` recorded by the callbacks.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                train_dataset = MNIST(mode='train', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                optimizer = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=model.parameters())
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, optimizer, metrics)
                engine.fit(train_dataset,
                           epochs=2,
                           batch_size=64)
        """
        self._mode = 'train'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            train_data, train_sample_split, batch_size
        )
        micro_batch_size = self._validate_batch_size(batch_size)
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        train_dataloader = self._prepare_dataloader_from_generator(
            dataset=train_data,
            capacity=70,
            iterable=False,
            batch_size=micro_batch_size,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            collate_fn=collate_fn,
        )

        fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)

        cbks = config_callbacks(
            callbacks,
            engine=self,
            batch_size=micro_batch_size,
            epochs=epochs,
            steps=train_dataloader._steps,
            log_freq=log_freq,
            save_freq=save_freq,
            save_dir=save_dir,
            verbose=verbose,
            metrics=self._metrics_name(),
            acc_step=1
            if self._strategy.pipeline.enable
            else self._acc_steps,  # lr update once every local batch
        )

        cbks.on_begin('train')
        for epoch in range(epochs):
            logs = {}
            cbks.on_epoch_begin(epoch)

            for step, _ in enumerate(train_dataloader):
                with paddle.profiler.utils._nvprof_range(
                    iter_id=step, start=nvprof_range[0], end=nvprof_range[1]
                ):
                    cbks.on_batch_begin('train', step, logs)
                    try:
                        outs = self._executor.run(
                            self.main_program,
                            fetch_list=fetch_names,
                            use_program_cache=self._strategy.use_cache,
                            return_numpy=self._strategy.return_numpy,
                        )
                    except core.EOFException:
                        break
                    lr = auto_utils.get_lr(self.optimizer)
                    logs = self._prepare_logger(
                        outs,
                        epoch,
                        step,
                        lr,
                        fetch_names,
                        fetch_indices,
                        self._mode,
                    )
                    cbks.on_batch_end('train', step, logs)

            if valid_data and (epoch + 1) % valid_freq == 0:
                val_logs = self.evaluate(
                    valid_data,
                    valid_sample_split,
                    batch_size,
                    valid_steps,
                    log_freq,
                    collate_fn,
                    callbacks,
                    verbose,
                )
                val_logs = {
                    "val_" + name: val for name, val in val_logs.items()
                }
                logs.update(val_logs)
                self._switch_mode("train")
            else:
                self._reset_metrics()

            cbks.on_epoch_end(epoch, logs)

        cbks.on_end('train', logs)
        return self.history

    def evaluate(
        self,
        valid_data,
        valid_sample_split=None,
        batch_size=1,
        steps=None,
        log_freq=10,
        collate_fn=None,
        callbacks=None,
        verbose=2,
    ):
        """
        Evaluate the loss and metrics of the model on evaluation data.

        Args:
            valid_data (Dataset): An instance of paddle.io.Dataset. Default: None.
            valid_sample_split (int, optional): Each sample of the eval dataset is assumed
                to be a (input, label) pair by default and has two items. If each sample has
                more than two items, valid_sample_split specifies how to split these items into
                input and label. The items before it are the input and the rest are the label.
                Default: None.
            batch_size (int, optional): The batch size of valid_data. The user's data will
                be used directly without batching if set to None. Default: 1.
            steps (int, optional): It is the total number of steps (batches of samples) to draw before
                stopping evaluation. If None, evaluation will run until the `valid_data` dataset is exhausted.
                The evaluation will start from the beginning of the dataset in each run. Default: None.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list, None for only stacking each field of the samples along axis
                0. Default None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during evaluating. Default: None. (Unused for now)

        Returns:
            A dict of logs containing the loss and metric values.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                valid_dataset = MNIST(mode='test', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, metrics=metrics)
                engine.evaluate(valid_dataset, batch_size=64)

        """
        self._mode = 'eval'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            valid_data, valid_sample_split, batch_size
        )
        micro_batch_size = self._validate_batch_size(batch_size)
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        valid_dataloader = self._prepare_dataloader_from_generator(
            dataset=valid_data,
            capacity=70,
            iterable=False,
            batch_size=micro_batch_size,
            steps_per_epoch=steps,
            collate_fn=collate_fn,
        )

        fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)

        cbks = config_callbacks(
            callbacks,
            engine=self,
            batch_size=micro_batch_size,
            log_freq=log_freq,
            verbose=verbose,
            metrics=self._metrics_name(),
        )

        eval_steps = valid_dataloader._steps
        cbks.on_begin(
            'eval', {'steps': eval_steps, 'metrics': self._metrics_name()}
        )
        logs = {}
        for step, _ in enumerate(valid_dataloader):
            cbks.on_batch_begin('eval', step, logs)
            try:
                outs = self._executor.run(
                    self.main_program,
                    fetch_list=fetch_names,
                    use_program_cache=self._strategy.use_cache,
                    return_numpy=self._strategy.return_numpy,
                )
            except core.EOFException:
                break
            logs = self._prepare_logger(
                outs, None, step, None, fetch_names, fetch_indices, self._mode
            )
            cbks.on_batch_end('eval', step, logs)
        cbks.on_end('eval', logs)
        self._reset_metrics()
        return logs

    def predict(
        self,
        test_data,
        test_sample_split=None,
        batch_size=1,
        steps=None,
        collate_fn=None,
        callbacks=None,
        verbose=2,
    ):
        """
        Compute the output predictions on testing data.

        Args:
            test_data (Dataset): An instance of paddle.io.Dataset. Default: None.
            test_sample_split (int, optional): Each sample of the test dataset is assumed
                to be a (input, label) pair by default and has two items. If each sample has
                more than two items, test_sample_split specifies how to split these items into
                input and label. The items before it are the input and the rest are the label.
                Default: None.
            batch_size (int, optional): The batch size of test_data. The user's data will
                be used directly without batching if set to None. Default: 1.
            steps (int, optional): It is the total number of steps (batches of samples) to draw before
                stopping predict. If None, predict will run until the `test_data` dataset is exhausted.
                The predict will start from the beginning of the dataset in each run. Default: None.
            collate_fn(callable, optional): function to generate mini-batch data by merging
                the sample list, None for only stacking each field of the samples along axis
                0. Default None.
            callbacks (Callback|None, optional): A list of `Callback` instances to apply
                during testing. Default: None. (Unused for now)

        Returns:
            A list of the model outputs for each batch.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                valid_dataset = MNIST(mode='test', transform=transform)

                model = paddle.vision.models.LeNet()

                engine = auto.Engine(model)
                engine.predict(valid_dataset, batch_size=64)
        """
        self._mode = 'predict'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            test_data, test_sample_split, batch_size
        )
        micro_batch_size = self._validate_batch_size(batch_size)
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        test_dataloader = self._prepare_dataloader_from_generator(
            dataset=test_data,
            capacity=70,
            iterable=False,
            batch_size=micro_batch_size,
            steps_per_epoch=steps,
            collate_fn=collate_fn,
        )

        fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode)

        outputs = []
        cbks = config_callbacks(callbacks, engine=self, verbose=verbose)
        test_steps = test_dataloader._steps
        cbks.on_begin('predict', {'steps': test_steps})
        logs = {}
        for step, _ in enumerate(test_dataloader):
            cbks.on_batch_begin('predict', step, logs)
            try:
                outs = self._executor.run(
                    self.main_program,
                    fetch_list=fetch_names,
                    use_program_cache=self._strategy.use_cache,
                    return_numpy=self._strategy.return_numpy,
                )
            except core.EOFException:
                break
            logs = self._prepare_logger(
                outs, None, step, None, fetch_names, fetch_indices, self._mode
            )
            cbks.on_batch_end('predict', step, logs)
            outputs.append(list(logs["outputs"].values()))
        cbks.on_end('predict', logs)
        return outputs

    def dataloader(
        self,
        dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        collate_fn=None,
        num_workers=0,
        use_buffer_reader=True,
        use_shared_memory=True,
        timeout=0,
        worker_init_fn=None,
        epochs=1,
        steps_per_epoch=None,
        sample_split=1,
        mode=None,
        places=None,
    ):
        if mode is not None:
            self.to_mode(mode)
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            dataset, sample_split, batch_size
        )
        micro_batch_size = self._validate_batch_size(batch_size)
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        dataloader = self._prepare_dataloader(
            dataset,
            return_list=False,
            batch_size=micro_batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            collate_fn=collate_fn,
            num_workers=num_workers,
            use_buffer_reader=use_buffer_reader,
            use_shared_memory=use_shared_memory,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            places=places,
        )
        return dataloader

    def dataloader_from_generator(
        self,
        dataset,
        capacity=70,
        use_double_buffer=True,
        iterable=True,
        use_multiprocess=False,
        drop_last=True,
        batch_size=1,
        epochs=1,
        steps_per_epoch=None,
        collate_fn=None,
        sample_split=1,
        mode=None,
    ):
        if mode is not None:
            self.to_mode(mode)
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            dataset, sample_split, batch_size
        )
        micro_batch_size = self._validate_batch_size(batch_size)
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:
            self._switch_mode(self._mode)

        dataloader = self._prepare_dataloader_from_generator(
            dataset=dataset,
            capacity=capacity,
            use_double_buffer=use_double_buffer,
            iterable=iterable,
            return_list=False,
            use_multiprocess=use_multiprocess,
            drop_last=drop_last,
            batch_size=micro_batch_size,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            collate_fn=collate_fn,
        )
        return dataloader
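
    # NOTE: A sketch for `dataloader_from_generator` above (placeholder
    # names). With `iterable=False`, iterating the loader only drives the
    # steps; each batch is consumed by the inserted reader op inside
    # `_executor.run(...)`, as in `predict` above:
    #
    #     eval_loader = engine.dataloader_from_generator(
    #         eval_dataset, iterable=False, batch_size=64, mode="eval"
    #     )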

    def prepare(
        self,
        inputs_spec=None,
        labels_spec=None,
        inputs=None,
        labels=None,
        main_program=None,
        startup_program=None,
        mode=None,
        init_parameters=True,
    ):
        if mode is not None:
            self.to_mode(mode)

        if not self._mode:
            raise ValueError(
                "Please set the mode to be prepared, e.g. `prepare(mode=...)`."
            )

        if self._has_prepared[self._mode]:
            return

        inputs_spec = self._validate_spec(inputs_spec)
        labels_spec = self._validate_spec(labels_spec)
        inputs = self._validate_vars(inputs)
        labels = self._validate_vars(labels)

        self._orig_main_prog = main_program
        self._orig_startup_prog = startup_program
        if inputs or labels:
            self._skip_build = True
            inputs, labels = self._prepare_data_tensor(
                inputs_spec, labels_spec, inputs, labels
            )
            if self._orig_main_prog is None:
                self._orig_main_prog = static.default_main_program()
            if self._orig_startup_prog is None:
                self._orig_startup_prog = static.default_startup_program()
        elif inputs_spec or labels_spec:
            self._outside_dataloader = True
            if self._orig_main_prog is None:
                self._orig_main_prog = static.default_main_program()
            if self._orig_startup_prog is None:
                self._orig_startup_prog = static.default_startup_program()
        else:
            assert (
                self._inputs_spec and self._labels_spec
            ), "Please call the dataloader(...) before calling prepare(...)"

        self._inputs_spec, self._labels_spec = inputs_spec, labels_spec
        self._inputs, self._labels = inputs, labels
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode, init_parameters)
        else:
            self._switch_mode(self._mode)
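
    # NOTE: A usage sketch for `prepare` with explicit specs (placeholder
    # shapes and names), for the case where data is fed without an Engine
    # dataloader:
    #
    #     inputs_spec = paddle.static.InputSpec([64, 784], 'float32', 'x')
    #     labels_spec = paddle.static.InputSpec([64, 1], 'int64', 'label')
    #     engine.prepare(inputs_spec, labels_spec, mode="train")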

    def run(self, data=None, feed=None, fetch_list=None, mode=None):
        if mode is not None:
            self.to_mode(mode)
        feed_dict = self._prepare_feed(data, feed, self._mode)
        fetch_names, fetch_indices = self._prepare_fetch(fetch_list, self._mode)
        if (
            self._outside_dataloader
            and not self._has_prepared_reader[self._mode]
        ):
            self._prepare_reader()
        outs = self._executor.run(
            self.main_program,
            feed=feed_dict,
            fetch_list=fetch_names,
            use_program_cache=self._strategy.use_cache,
            return_numpy=self._strategy.return_numpy,
        )
        logs = self._prepare_logger(
            outs, None, None, None, fetch_names, fetch_indices, self._mode
        )
        return logs
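
    # NOTE: A sketch of one customized step with `run` (placeholder names),
    # assuming `prepare` was called with matching specs:
    #
    #     logs = engine.run(feed=feed_dict, fetch_list=fetch_vars,
    #                       mode="train")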

    def _prepare_dataloader(
        self,
        dataset,
        return_list=True,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        collate_fn=None,
        num_workers=0,
        use_buffer_reader=True,
        use_shared_memory=True,
        timeout=0,
        worker_init_fn=None,
        epochs=1,
        steps_per_epoch=None,
        places=None,
    ):
        dist_context = self._dist_contexts[self._mode]
        dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
        dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
        dist_main_block = dist_main_prog.global_block()

        # NOTE: Build feed_list, then insert the dataloader op with the sharded
        # var shapes. Because predict_program does not contain the labels var,
        # the labels var is copied from serial_program to dist_program, which
        # keeps the length of feed_list equal to the length of the dataset's values.
        inputs_var = dist_context.serial_feed_vars["inputs"]
        labels_var = dist_context.serial_feed_vars["labels"]
        feed_list = []
        for var in inputs_var + labels_var:
            if var.name in dist_main_block.vars:
                feed_list.append(dist_main_block.vars[var.name])
            else:
                copy_var = dist_main_block._clone_variable(var, var.persistable)
                copy_var.desc.set_original_id(var.desc.original_id())
                feed_list.append(copy_var)

        # Insert the read op at the end of the program.
        with static.program_guard(dist_main_prog, dist_startup_prog):
            dataloader = DistributedDataLoader(
                dataset,
                feed_list=feed_list,
                places=places,
                return_list=return_list,
                batch_size=batch_size,
                shuffle=shuffle,
                drop_last=drop_last,
                collate_fn=collate_fn,
                num_workers=num_workers,
                use_buffer_reader=use_buffer_reader,
                use_shared_memory=use_shared_memory,
                timeout=timeout,
                worker_init_fn=worker_init_fn,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                split_data=self._strategy.split_data,
                data_parallel_world_size=self._dp_world_sizes,
                data_parallel_rank=self._dp_ranks,
            )

        return dataloader

    def _prepare_dataloader_from_generator(
        self,
        dataset,
        capacity=None,
        use_double_buffer=True,
        iterable=True,
        return_list=False,
        use_multiprocess=False,
        drop_last=True,
        batch_size=1,
        epochs=1,
        steps_per_epoch=None,
        collate_fn=None,
    ):
        dist_context = self._dist_contexts[self._mode]
        dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
        dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
        dist_main_block = dist_main_prog.global_block()

        # NOTE: Build feed_list, then insert the dataloader op with the sharded
        # var shapes. Because predict_program does not contain the labels var,
        # the labels var is copied from serial_program to dist_program, which
        # keeps the length of feed_list equal to the length of the dataset's values.
        inputs_var = dist_context.serial_feed_vars["inputs"]
        labels_var = dist_context.serial_feed_vars["labels"]
        feed_list = []
        for var in inputs_var + labels_var:
            if var.name in dist_main_block.vars:
                feed_list.append(dist_main_block.vars[var.name])
            else:
                copy_var = dist_main_block._clone_variable(var, var.persistable)
                copy_var.desc.set_original_id(var.desc.original_id())
                feed_list.append(copy_var)

        places = paddle.static.cuda_places()
        with static.program_guard(dist_main_prog, dist_startup_prog):
            dataloader = DistributedDataLoaderFromGenerator(
                dataset=dataset,
                feed_list=feed_list,
                capacity=capacity,
                use_double_buffer=use_double_buffer,
                iterable=iterable,
                return_list=return_list,
                use_multiprocess=use_multiprocess,
                drop_last=drop_last,
                places=places,
                batch_size=batch_size,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                collate_fn=collate_fn,
                split_data=self._strategy.split_data,
                data_parallel_world_size=self._dp_world_sizes,
                data_parallel_rank=self._dp_ranks,
                acc_steps=1
                if not self._strategy.pipeline.enable
                else self._acc_steps,
            )
        self._prepare_reader(feed_list)
        return dataloader

    def _tune(self, tune_data, tune_sample_split=None, batch_size=1):
        self._mode = 'train'
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            tune_data, tune_sample_split, batch_size
        )
        self._optimization_tuning(self._mode, tune_data, batch_size)

    def _validate_batch_size(self, batch_size):
        if batch_size is None:
            return None
        assert (
            batch_size % self._acc_steps == 0
        ), "Requires batch_size:[{}] to be divisible by acc_steps:[{}].".format(
            batch_size, self._acc_steps
        )
        return batch_size // self._acc_steps
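
    # NOTE: Worked example for `_validate_batch_size` (hypothetical numbers):
    # with gradient accumulation, the user-facing batch size is split into
    # micro-batches, e.g. batch_size=64 and _acc_steps=4 give a
    # micro_batch_size of 64 // 4 = 16 per step.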

    def _validate_spec(self, specs):
        specs = auto_utils.to_list(specs)
        if specs is not None:
            for i, spec in enumerate(specs):
                if not isinstance(spec, InputSpec):
                    raise TypeError(
                        "'spec' must be object of class `paddle.static.InputSpec`."
                    )
                if spec.name is None:
                    raise ValueError(
                        "Requires Input[{}].name != None, but received `None` with {}.".format(
                            i, spec
                        )
                    )
                if self._acc_steps > 1:
                    shape = list(spec.shape)
                    assert (
                        shape[0] % self._acc_steps == 0
                    ), "Requires batch_size[{}] to be divisible by acc_steps[{}].".format(
                        spec.shape[0], self._acc_steps
                    )
                    shape[0] //= self._acc_steps
                    spec.shape = shape
        return specs or []
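
    # NOTE: Sketch of the shape rewrite in `_validate_spec` (hypothetical
    # numbers): an InputSpec with shape [64, 784] and _acc_steps=4 comes out
    # with shape [16, 784], i.e. the spec then describes one micro-batch
    # rather than the full user batch.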

    def _validate_vars(self, vars):
        vars = auto_utils.to_list(vars)
        if vars is not None:
            for i, var in enumerate(vars):
                if not isinstance(var, Variable):
                    raise TypeError("'var' must be a `Variable`.")
        return vars or []

    def _is_local_var(self, var):
        var_name = _to_name_str(var)
        return var_name in self.main_program.global_block().vars

    def _reset_metrics(self):
        for metric in self._metrics:
            metric.reset()

    def _metrics_name(self):
        metrics_name = ['loss'] if self._loss else []
        for m in self._metrics:
            metrics_name.extend(auto_utils.to_list(m.name()))
        return metrics_name

    def _switch_mode(self, mode):
        assert (
            mode in self._dist_contexts
        ), f"{mode} model is not ready, please call `prepare()` first."
        self.to_mode(mode)

    def to_mode(self, mode):
        assert mode in [
            "train",
            "eval",
            "predict",
        ], f"mode {mode} should be one of ['train', 'eval', 'predict']"
        self._mode = mode

    def _set_state_dict(self, mode, strict, state_dict, dist_attr):
        dist_context = self._dist_contexts[mode]
        program = dist_context.dist_main_programs[self._cur_rank]
        cur_dist_attr = auto_utils.get_dist_attr(program, dist_context)
        converter = Converter(state_dict, dist_attr, cur_dist_attr)
        state_dict = converter.convert(strict=strict)
        for name, param in program.state_dict().items():
            param_array = np.array(param)
            if name not in state_dict:
                continue
            if param_array.dtype != state_dict[name].dtype:
                self._logger.info(
                    "cast {}'s dtype from '{}' to '{}'".format(
                        name,
                        str(state_dict[name].dtype),
                        str(param_array.dtype),
                    )
                )
                state_dict[name] = state_dict[name].astype(param_array.dtype)
        program.set_state_dict(state_dict)

    def save(self, path, training=True):
        """
        Saves the model, parameters and optimizer state to the given path.
        If `training` is set to False, only the inference model will be saved.

        Args:
            path (str): The file prefix to save the model. The format
                is 'dirname/file_prefix' or 'file_prefix'. If it is an empty
                string, an exception will be raised.
            training (bool, optional): Whether to save for training. If not, save
                for inference only. If `training` is set to True, the optimizer state
                will be saved. Otherwise, only the model and parameters are saved.
                This function will silently overwrite an existing file at the target
                location. Default: True.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                train_dataset = MNIST(mode='train', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                optimizer = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=model.parameters())
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, optimizer, metrics)
                engine.fit(train_dataset,
                           epochs=1,
                           batch_size=64)
                engine.save("./my_model")

        """
        if training:
            assert self._mode in self._dist_contexts
            dist_context = self._dist_contexts[self._mode]
            serial_program = dist_context.serial_main_program
            dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
            self._saver.save(
                path,
                serial_program=serial_program,
                dist_main_program=dist_main_prog,
                dist_context=dist_context,
            )
        else:
            assert "predict" in self._dist_contexts
            dist_context = self._dist_contexts["predict"]
            feed_vars = dist_context.serial_feed_vars['inputs']
            fetch_vars = dist_context.serial_fetch_vars['outputs']
            dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
            if self._strategy.qat.enable and self._strategy.qat.onnx_format:
                from paddle.static.quantization import QuantWeightPass

                self._logger.info("export quantized model.")
                self._logger.info(
                    f"convert config {self._strategy.qat.to_dict()}"
                )
                test_graph = IrGraph(
                    core.Graph(dist_main_prog.desc), for_test=True
                )
                quant_weight_pass = QuantWeightPass(global_scope(), self._place)
                for sub_graph in test_graph.all_sub_graphs():
                    quant_weight_pass.apply(sub_graph)
                dist_main_prog = test_graph.to_program()
            self._saver.save_inference_model(
                path,
                feed_vars,
                fetch_vars,
                self._executor,
                program=dist_main_prog,
            )

    def load(self, path, strict=True, load_optimizer=True):
        """
        Load the stored model, parameters and optimizer states.

        Args:
            path (str): The prefix of files storing the model states and
                optimizer states.
            strict (bool, optional): Whether to raise an error when a mismatch
                happens (a parameter is not found in the file storing the model
                states, or a mismatched shape is received) instead of skipping
                the loading of that parameter. Default: True.
            load_optimizer (bool, optional): If True, the stored optimizer
                states are restored. Otherwise, the optimizer states are
                initialized from scratch. Default: True.

        Returns:
            The loaded state dict and the corresponding distributed attributes.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.vision.transforms as T
                from paddle.distributed.fleet import auto
                from paddle.vision.datasets import MNIST

                transform = T.Compose([
                    T.Transpose(),
                    T.Normalize([127.5], [127.5])
                ])
                train_dataset = MNIST(mode='train', transform=transform)

                model = paddle.vision.models.LeNet()
                loss = paddle.nn.CrossEntropyLoss()
                optimizer = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=model.parameters())
                metrics = paddle.metric.Accuracy(topk=(1, 2))

                engine = auto.Engine(model, loss, optimizer, metrics)
                engine.fit(train_dataset,
                           epochs=1,
                           batch_size=64)
                engine.save("./my_model")
                engine.load("./my_model")

        """
        self._strict = strict
        self._state_dict, self._dist_attr = self._saver.load(
            path, load_optimizer
        )
        return self._state_dict, self._dist_attr

    def cost(self, inputs_spec=None, labels_spec=None, mode=None):
        """
        Get and print the cost information, including the memory of every rank,
        the max memory among all ranks, and the global cost of one step based on
        the communication cost (the computation cost is 0 by default).
        In the future, the flops information of every rank and the global cost
        including the computation cost will be added.

        Args:
            inputs_spec(InputSpec): The specification of inputs. Default: None.
            labels_spec(InputSpec): The specification of labels. Default: None.
            mode (str): The engine mode must be in ["train", "predict", "eval"]. Default: None.

        Returns:
            The global execution time (ms) and the max memory (B).
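
        Examples:
            A minimal sketch, assuming an engine that has already been fitted
            or prepared as in the examples of `save` and `load`:

            .. code-block:: python

                # after engine.fit(...) or engine.prepare(...)
                time_cost, max_memory = engine.cost(mode="train")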

        """
        # Check parallel mode
        if self._strategy.auto_mode == "full":
            self._logger.info(
                "The cost will be calculated in the search process when the auto mode is full."
            )
            return

        # Check mode
        mode = mode if mode is not None else self._mode
        assert mode is not None, "Please set mode."
        if mode not in self._has_prepared:
            raise ValueError(
                "The mode {} is not in accepted modes {}".format(
                    mode, list(self._has_prepared.keys())
                )
            )
        self.to_mode(mode)

        if inputs_spec is not None and not self._has_prepared[mode]:
            self._inputs_spec = self._validate_spec(inputs_spec)
            self._labels_spec = self._validate_spec(labels_spec)
            self._build(mode)
            self._plan(mode)
        else:
            if in_dynamic_mode() or self._dygraph_mode:
                raise ValueError(
                    "Please call `prepare()`, `fit()`, `evaluate()` or `predict()` before calling `cost()`."
                )
            else:
                self._logger.info(
                    "The program whose cost is to be estimated must be the static default program. Otherwise, please call `prepare()` before calling `cost()`."
                )
                program = paddle.static.default_main_program()
                if (
                    not program.global_block().ops
                    and not self._has_prepared[mode]
                ):
                    raise ValueError(
                        "Please call `prepare()`, `fit()`, `evaluate()` or `predict()` before calling `cost()`."
                    )

        # Estimate the exec cost and max memory
        global_cost, max_memory = get_cost_from_engine(self, mode)

        return global_cost.time, max_memory

    @property
    def main_program(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.dist_main_programs[self._cur_rank]

    @property
    def startup_program(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.dist_startup_programs[self._cur_rank]

    @property
    def dist_context(self):
        return self._dist_contexts[self._mode]

    @property
    def serial_main_program(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.serial_main_program

    @property
    def serial_startup_program(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.serial_startup_program

    @property
    def feed_vars(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.serial_feed_vars

    @property
    def fetch_vars(self):
        dist_context = self._dist_contexts[self._mode]
        return dist_context.serial_fetch_vars

    @property
    def optimizer(self):
        dist_context = self._dist_contexts[self._mode]
        if dist_context._serial_optimizer:
            return dist_context._serial_optimizer
        return self._optimizer

    @property
    def inputs(self):
        return self._inputs

    @property
    def labels(self):
        return self._labels