#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import copy
import warnings
import paddle
import os
from types import MethodType
import numpy as np
from paddle.fluid.framework import dygraph_only, _global_flags
from paddle.fluid import compiler
from .role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker, RoleMakerBase
from .strategy_compiler import StrategyCompiler
from .distributed_strategy import DistributedStrategy
from .meta_optimizer_factory import MetaOptimizerFactory
from .runtime_factory import RuntimeFactory
from paddle.fluid.wrapped_decorator import wrap_decorator
from paddle.fluid.dygraph import parallel_helper
from paddle.fluid.ir import apply_build_strategy
from . import topology as tp
from .topology import ParallelMode
from ..meta_parallel import TensorParallel, model_parallel_random_seed
from ..meta_parallel import PipelineParallel, ShardingParallel
from ..meta_optimizers import HybridParallelOptimizer, HeterParallelOptimizer
from paddle import _C_ops
from paddle.fluid import core
from paddle.fluid.dygraph import to_variable
from paddle.distributed.fleet.utils.recompute import LegacyRecomputeFunction
from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar

__all__ = []

_grad_scalar = None
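# Module-level handle for the AMP GradScaler that Fleet.distributed_model()
# creates when the amp strategy is enabled; the assignment above intentionally
# shadows the _grad_scalar imported above.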


class _RecomputeModelWrapper(paddle.nn.Layer):
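    # This wrapper implements segment-wise activation recomputation: the wrapped
    # paddle.nn.Sequential is split into ``segments`` chunks, and every chunk
    # except the last is re-run during the backward pass instead of keeping its
    # intermediate activations, trading extra compute for lower memory use.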
    def __init__(self, model, segments=2, preserve_rng_state=True):
        super(_RecomputeModelWrapper, self).__init__()
        assert isinstance(model, paddle.nn.Sequential), (
            "The model passed to RecomputeModelWrapper must be of type "
            "paddle.nn.Sequential.")
        self._model = model
        self._segments = segments
        self._preserve_rng_state = preserve_rng_state
        self._layers = list(model.children())
        self._segment_size = len(self._layers) // segments

    def _run_func(self, begin, end):

        def do_run(input):
            for i in range(begin, end):
                input = self._layers[i](input)
            return input

        return do_run

    def _checkpoint(self, func, *args, **kwargs):
        return LegacyRecomputeFunction.apply(func, self._preserve_rng_state,
                                             *args)

    def forward(self, input):
        end = 0
        for begin in range(0, self._segment_size * (self._segments - 1),
                           self._segment_size):
            end = begin + self._segment_size
            input = self._checkpoint(self._run_func(begin, end), input)
        return self._run_func(end, len(self._layers))(input)

def apply_ir_passes(main_program, startup_program, config):
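    # When FLAGS_apply_pass_to_program is off this returns the copied build
    # strategy without applying any pass; for pipeline programs the passes are
    # applied to the section program instead of the outer program.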
    build_strategy = config._user_defined_strategy.build_strategy._copy()
    if not _global_flags()['FLAGS_apply_pass_to_program']:
        return build_strategy

    pipeline_opt = getattr(main_program, "_pipeline_opt", {})
    if pipeline_opt:
        main_program = pipeline_opt["section_program"]
        startup_program = startup_program._pipeline_opt["startup_program"]

    pass_attrs = {"use_cuda": config._is_collective}
    fuse_all_reduce = config._user_defined_strategy.fuse_all_reduce_ops
    if fuse_all_reduce and build_strategy.fuse_all_optimizer_ops:
        # FIXME(zjl): currently, fuse_all_optimizer_ops
        # conflicts with fuse_all_reduce_ops because
        # RawProgramOptimizer also inserts coalesce_tensor
        # into the program. The two passes may disagree on
        # which vars are to be fused.
        warnings.warn(
            'Currently, the fuse_all_optimizer_ops pass conflicts with the fuse_all_reduce_ops pass. Disabling the fuse_all_optimizer_ops pass temporarily.'
        )
        build_strategy.fuse_all_optimizer_ops = False

    return apply_build_strategy(main_program, startup_program, build_strategy,
                                pass_attrs)


def _inited_runtime_handler_(func):
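    # Decorator guard: the wrapped Fleet method may only run after a runtime
    # handle has been created (e.g. by fleet.minimize()); otherwise it raises a
    # ValueError.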
    def __impl__(*args, **kwargs):
        cls = args[0]

        if cls._runtime_handle is None:
            raise ValueError("Fleet can not find suitable runtime handler")

        return func(*args, **kwargs)

    return __impl__

def _is_non_distributed_check_(func):
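    # Decorator guard: in a non-distributed run the wrapped Fleet method is
    # skipped with a warning instead of being executed.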
    def __impl__(*args, **kwargs):
        cls = args[0]

        if cls._role_maker is not None and cls._role_maker._is_non_distributed(
        ) is True:
            warnings.warn(
                "%s() function doesn't work when using non_distributed fleet." %
                (func.__name__))
            return

        return func(*args, **kwargs)

    return __impl__


inited_runtime_handler = wrap_decorator(_inited_runtime_handler_)
is_non_distributed_check = wrap_decorator(_is_non_distributed_check_)

class Fleet(object):
    """
    Unified API for distributed training of PaddlePaddle
    Please refer to https://github.com/PaddlePaddle/FleetX for details.


    Returns:
        Fleet: A Fleet instance

    Example for collective training:

        .. code-block:: python

            import paddle
            paddle.enable_static()
            import paddle.distributed.fleet as fleet

            fleet.init(is_collective=True)

            strategy = fleet.DistributedStrategy()
            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
            optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)

            # do distributed training


    Example for parameter server training:

        .. code-block:: python

            import paddle
            paddle.enable_static()
            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            fleet.init(strategy=strategy)

            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
            optimizer = fleet.distributed_optimizer(optimizer)

            if fleet.is_first_worker():
                print("this is first worker")

            print("current node index: {}".format(fleet.worker_index()))
            print("total number of worker num: {}".format(fleet.worker_num()))

            if fleet.is_worker():
                print("this is worker")
            print("worker endpoints: {}".format(fleet.worker_endpoints(to_string=True)))

            print("server num: {}".format(fleet.server_num()))
            print("server endpoints: {}".format(fleet.server_endpoints(to_string=True)))

            if fleet.is_server():
                print("this is server")
            fleet.stop_worker()

    """

    def __init__(self):
        self._role_maker = None
        self.strategy_compiler = None
        self._is_collective = False
        self._runtime_handle = None
        self._util = None
        self._context = {}

    def init(self, role_maker=None, is_collective=False, strategy=None):
        """
        Initialize role_maker in Fleet.

        This function is responsible for initializing the distributed
        architecture that your code will run on.

        Args:
            role_maker (RoleMakerBase, optional): A ``RoleMakerBase`` containing the configuration
                of environment variables related to distributed training. If you did not initialize
                the role maker yourself, it will be automatically initialized to PaddleCloudRoleMaker.
                The default value is None.
            is_collective (Boolean, optional): A ``Boolean`` variable that determines whether the
                program runs distributed training on CPU or GPU. False means CPU, and True means
                GPU. The default value is False.
            strategy (DistributedStrategy): Extra properties for distributed training.
                For details, please refer to paddle.distributed.fleet.DistributedStrategy. Default: None.


        Returns:
            None

        Examples1:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

        Examples2:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init(is_collective=True)

        Examples3:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                role = fleet.PaddleCloudRoleMaker()
                fleet.init(role)

        Examples4:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                strategy = fleet.DistributedStrategy()
                fleet.init(strategy=strategy)

        """
        if strategy is None:
            strategy = DistributedStrategy()
        self._user_defined_strategy = copy.deepcopy(strategy)

        if role_maker is None:
            if isinstance(is_collective, bool):
                self._is_collective = is_collective
                self._role_maker = PaddleCloudRoleMaker(
                    is_collective=self._is_collective)
            else:
                raise ValueError(
                    "`is_collective` should be instance of `bool`, but got {}".
                    format(type(is_collective)))
        else:
            if isinstance(role_maker, RoleMakerBase):
                self._role_maker = role_maker
                self._is_collective = role_maker._is_collective
            else:
                raise ValueError(
                    "`role_maker` should be subclass of `RoleMakerBase`, but got {}"
                    .format(type(role_maker)))
        self._role_maker._generate_role()

        import paddle.distributed.fleet as fleet
        fleet.util._set_role_maker(self._role_maker)

        self.strategy_compiler = StrategyCompiler()

        if self._role_maker._is_non_distributed() and self._is_collective:
            if paddle.fluid.core.is_compiled_with_cuda():
                gpus_num = paddle.fluid.core.get_cuda_device_count()
                if gpus_num != 1:
                    raise ValueError(
                        "CUDA_VISIBLE_DEVICES should be set to only 1 card if you use `python` to launch fleet program."
                    )

        if paddle.fluid.framework._non_static_mode():
            if self.worker_num() == 1:
                # if worker_num is 1, should construct default topology & hcg
                self._topology = tp.CommunicateTopology()
                self._hcg = tp.HybridCommunicateGroup(self._topology)
                return
            if parallel_helper._is_parallel_ctx_initialized():
                warnings.warn(
                    "The dygraph parallel environment has been initialized.")
            else:
                # FLAGS_nccl_nrings is used for dynamic graph multi-stream communication
                if "FLAGS_nccl_nrings" in os.environ:
                    warnings.warn(
                        "You have set the environment variable FLAGS_nccl_nrings "
                        "outside the program, so the nccl_comm_num in "
                        "DistributedStrategy will not take effect here.")
                else:
                    os.environ["FLAGS_nccl_nrings"] = str(
                        self._user_defined_strategy.nccl_comm_num)
                paddle.distributed.init_parallel_env()

            # hybrid parallel is not supported on npu/xpu
            if self._user_defined_strategy.heter_ccl_mode == False:
                # init hybrid parallel environment in dygraph
                if tp._HYBRID_PARALLEL_GROUP is None:
                    self._init_hybrid_parallel_env()
                else:
                    warnings.warn(
                        "The dygraph hybrid parallel environment has been initialized."
                    )
        elif self._is_collective:
            use_sharding = self._user_defined_strategy.sharding

            # global group
            global_rank = self.worker_index()
            global_world_size = self.worker_num()
            # NOTE(wangxi): see sharding_optimizer
            global_ring_id = 3 if use_sharding else 0
            global_ranks = list(range(global_world_size))

            if tp._HYBRID_PARALLEL_GROUP is None: tp._CommunicateGroup()
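            # NOTE: _CommunicateGroup() is expected to register itself as
            # tp._HYBRID_PARALLEL_GROUP in its constructor, which is why the
            # bare call above is not assigned to anything.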
            cg = tp._HYBRID_PARALLEL_GROUP
            self._hcg = cg
            cg.set_comm_group('global', global_rank, global_world_size,
                              global_ring_id, global_ranks)

            use_tensor_parallel = self._user_defined_strategy.tensor_parallel
            use_mp = use_sharding or use_tensor_parallel

            # hybrid group
            if use_mp is False: return

            mp_degree_sharding = 1
            mp_degree_tensor_parallel = 1
            if use_sharding:
                sharding_configs = self._user_defined_strategy.sharding_configs
                mp_degree_sharding = int(sharding_configs['mp_degree'])

            if use_tensor_parallel:
                tensor_parallel_configs = self._user_defined_strategy.tensor_parallel_configs
                mp_degree_tensor_parallel = int(
                    tensor_parallel_configs['tensor_parallel_degree'])

            if use_sharding and use_tensor_parallel:
                assert mp_degree_sharding == mp_degree_tensor_parallel

            mp_degree = mp_degree_sharding if use_sharding else mp_degree_tensor_parallel

            if mp_degree > 1:
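                # Worked example: with 8 ranks and mp_degree=2, the model
                # parallel groups are {0,1}, {2,3}, {4,5}, {6,7}; a rank's
                # mp_rank is its index inside its group.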
                assert global_world_size % mp_degree == 0
                # NOTE(wangxi): mp_ring_id sync with sharding_optimizer.py _build_groups
                mp_ring_id = 0
                mp_rank = global_rank % mp_degree
                mp_group_id = global_rank // mp_degree
                mp_group_ranks = [
                    idx for idx in global_ranks
                    if idx // mp_degree == mp_group_id
                ]
                cg.set_comm_group('model', mp_rank, mp_degree, mp_ring_id,
                                  mp_group_ranks)

    def _init_hybrid_parallel_env(self):
        """initialize the hybrid environment
        """
        self.hybrid_configs = self._user_defined_strategy.hybrid_configs
        self.dp_degree = self.hybrid_configs["dp_degree"]
        self.mp_degree = self.hybrid_configs["mp_degree"]
        self.pp_degree = self.hybrid_configs["pp_degree"]
        self.sharding_degree = self.hybrid_configs["sharding_degree"]

        assert self.mp_degree >= 0, "mp_degree should be greater or equal to 0"
        assert self.pp_degree >= 0, "pp_degree should be greater or equal to 0"
        assert self.sharding_degree >= 0, "sharding_degree should be greater or equal to 0"

        self.mp_degree = max(self.mp_degree, 1)
        self.pp_degree = max(self.pp_degree, 1)

        if self.dp_degree < 0:
            nranks = paddle.distributed.get_world_size()
            self.dp_degree = nranks // (self.mp_degree * self.pp_degree)

        self.dp_degree = max(self.dp_degree, 1)
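        # A dp_degree left at its -1 default was inferred above from
        # world_size // (mp_degree * pp_degree); the four degrees below define
        # a 4-D (data, pipe, sharding, model) communication topology.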

        self._topology = tp.CommunicateTopology(
            hybrid_group_names=["data", "pipe", "sharding", "model"],
            dims=[
                self.dp_degree, self.pp_degree, self.sharding_degree,
                self.mp_degree
            ])

        self._hcg = tp.HybridCommunicateGroup(self._topology)

        if self.mp_degree > 1:
            tensor_parallel_configs = self._user_defined_strategy.tensor_parallel_configs
            tensor_init_seed = tensor_parallel_configs["tensor_init_seed"]
            if tensor_init_seed == -1:
                model_parallel_random_seed()
            else:
                model_parallel_random_seed(tensor_init_seed)

    def get_hybrid_communicate_group(self):
        assert self._hcg is not None
        return self._hcg

    def get_hybrid_parallel_topology(self):
        assert self._topology is not None
        return self._topology

    def is_first_worker(self):
        """
        Check whether the node is the first instance of worker.

        Returns:
            bool: True if this is the first node of worker,
                  False if not.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_first_worker()

        """
        return self._role_maker._is_first_worker()

    def worker_index(self):
        """
        Get current worker index.

        Returns:
            int: node id

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_index()

        """
        return self._role_maker._worker_index()

    def worker_num(self):
        """
        Get current total worker number.

        Returns:
            int: worker numbers

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_num()

        """
        return self._role_maker._worker_num()

    def node_num(self):
        return self._role_maker._get_node_num()

    def local_rank(self):
        return self._role_maker._get_local_rank()

    def local_device_ids(self):
        return self._role_maker._get_local_device_ids()

    def world_device_ids(self):
        return self._role_maker._get_world_device_ids()

    def is_worker(self):
        """
        Check whether the node is an instance of worker.

        Returns:
            bool: True if this is a node of worker,
                  False if not.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_worker()

        """
        return self._role_maker._is_worker()

    def worker_endpoints(self, to_string=False):
        """
        Get current worker endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].

        Returns:
            list/string: worker endpoints

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_endpoints()

        """
        if to_string:
            return ",".join(self._role_maker._get_trainer_endpoints())
        else:
            return self._role_maker._get_trainer_endpoints()

    def server_num(self):
        """
        Get current total server number.

        Returns:
            int: server number

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_num()
        """
        return len(self._role_maker._get_pserver_endpoints())

    def server_index(self):
        """
        Get current server index.

        Returns:
            int: node id

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_index()

        """
        return self._role_maker._server_index()

    def server_endpoints(self, to_string=False):
        """
        Get current server endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].

        Returns:
            list/string: server endpoints

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_endpoints()

        """

        if to_string:
            return ",".join(self._role_maker._get_pserver_endpoints())
        else:
            return self._role_maker._get_pserver_endpoints()

    def is_server(self):
        """
        Check whether the node is an instance of server.

        Returns:
            bool: True if this is a node of server,
                  False if not.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_server()

        """
        return self._role_maker._is_server()

    def barrier_worker(self):
        """
        barrier all workers

        Returns:
            None
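
        Examples:

            .. code-block:: python

                # A minimal sketch that mirrors the other Fleet examples.
                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.barrier_worker()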
        """
        self._role_maker._barrier("worker")

    @is_non_distributed_check
    @inited_runtime_handler
    def init_worker(self, scopes=None):
        """
        initialize `Communicator` for parameter server training.


        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_worker()

        """
        self._runtime_handle._init_worker(scopes)

    @is_non_distributed_check
    @inited_runtime_handler
    def init_server(self, *args, **kwargs):
        """
        init_server executor to initialize the startup program.
        If `args` is not empty, it will run load_persistables for incremental training.


        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_server()

        """
        self._runtime_handle._init_server(*args, **kwargs)

    @is_non_distributed_check
    @inited_runtime_handler
    def load_model(self, path, mode):
        """
        load fleet model from path


        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.load_model("path", "mode")

        """
        self._runtime_handle.load_model(path, mode)

    @is_non_distributed_check
    @inited_runtime_handler
    def run_server(self):
        """
        run_server runs the pserver main program with the executor.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                if fleet.is_server():
                    fleet.init_server()

        """
        self._runtime_handle._run_server()

    @is_non_distributed_check
    @inited_runtime_handler
    def stop_worker(self):
        """
        stop `Communicator` and give training complete notice to parameter server.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.stop_worker()

        """
        self._runtime_handle._stop_worker()

    @is_non_distributed_check
    @inited_runtime_handler
    def save(self, dirname, feed=[], fetch=[], **configs):
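        # When ``feed``/``fetch`` are given this exports an inference program;
        # otherwise it saves persistable variables (a checkpoint), with
        # ``configs["mode"]`` selecting the save mode for the persistables path.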
        inference = True

        if not feed and not fetch:
            inference = False

        place = paddle.CPUPlace()
        executor = paddle.static.Executor(place)

        if inference:
            feeded_var_names = []
            fetch_var_names = []

            for var in feed:
                if isinstance(var, str):
                    feeded_var_names.append(var)
                elif isinstance(var, paddle.static.Variable):
                    feeded_var_names.append(var.name)
                else:
                    raise ValueError("feed must be [str|Variable]")

            for var in fetch:
                if isinstance(var, str):
                    fetch_var_names.append(var)
                elif isinstance(var, paddle.static.Variable):
                    fetch_var_names.append(var.name)
                else:
                    raise ValueError("feed must be [str|Variable]")

            fetch_vars = [
                paddle.static.default_main_program().global_block().var(name)
                for name in fetch_var_names
            ]

            self._runtime_handle._save_inference_model(executor, dirname,
                                                       feeded_var_names,
                                                       fetch_vars, None, True,
                                                       0)
        else:
            increment_mode = 0
            if "mode" in configs:
                increment_mode = int(configs["mode"])
            self._runtime_handle._save_persistables(executor,
                                                    dirname,
                                                    main_program=None,
                                                    mode=increment_mode)

    @is_non_distributed_check
    @inited_runtime_handler
    def save_inference_model(self,
                             executor,
                             dirname,
                             feeded_var_names,
                             target_vars,
                             main_program=None,
                             export_for_deployment=True,
                             mode=0):
        """
        save inference model for prediction deployment.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_server()

        """
        # warnings.warn(
        #     "'save_inference_model' is deprecated and will be deleted after v2.2.0. Please use fleet.save instead."
        # )

        self._runtime_handle._save_inference_model(executor, dirname,
                                                   feeded_var_names,
                                                   target_vars, main_program,
                                                   export_for_deployment, mode)

    @is_non_distributed_check
    @inited_runtime_handler
    def save_persistables(self, executor, dirname, main_program=None, mode=0):
        """

        saves all persistable tensors from :code:`main_program` to
        the folder :code:`dirname`.

        The :code:`dirname` is used to specify the folder where persistable tensors
        are going to be saved. If you would like to save tensors in separate
        files, set :code:`filename` to None.

        Args:
            executor(Executor): The executor to run for saving persistable tensors.
                                You can refer to :ref:`api_guide_executor_en` for
                                more details.

            dirname(str, optional): The saving directory path.
                                When you need to save the parameter to the memory, set it to None.
            main_program(Program, optional): The program whose persistable tensors will
                                             be saved. Default: None.


        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                paddle.enable_static()
                import paddle.distributed.fleet as fleet

                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                exe = paddle.static.Executor(paddle.CPUPlace())
                fleet.save_persistables(exe, "dirname", paddle.static.default_main_program())

        """
        # warnings.warn(
        #     "'save_persistables' is deprecated and will be deleted after v2.2.0. Please use fleet.save instead."
        # )

        self._runtime_handle._save_persistables(executor, dirname, main_program,
                                                mode)

    @is_non_distributed_check
    @inited_runtime_handler
    def save_cache_model(self, dirname, **configs):
        return self._runtime_handle._save_cache_model(dirname, **configs)

    def shrink(self, threshold=None):
        self._runtime_handle._shrink(threshold)

    def distributed_optimizer(self, optimizer, strategy=None):
        """
        Optimizer for distributed training.

        For distributed training, this method rebuilds a new instance of DistributedOptimizer,
        which has the basic Optimizer functionality and special features for distributed training.

        Args:
            optimizer(Optimizer): The optimizer to be used in distributed training.
            strategy(DistributedStrategy): Extra properties for distributed optimizer.
                It is recommended to use DistributedStrategy in fleet.init(). The strategy
                here is for compatibility. If the strategy in fleet.distributed_optimizer()
                is not None, then it will overwrite the DistributedStrategy in fleet.init(),
                which will take effect in distributed training.

        Returns:
            Fleet: instance of fleet.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.distributed.fleet as fleet
                fleet.init(is_collective=True)
                strategy = fleet.DistributedStrategy()
                optimizer = paddle.optimizer.SGD(learning_rate=0.001)
                optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)

        """
        self.user_defined_optimizer = optimizer

        if strategy is not None:
            if self._is_collective:
                warnings.warn(
                    "It is recommended to use DistributedStrategy "
                    "in fleet.init(). The strategy here is only for compatibility. "
                    "If the strategy in fleet.distributed_optimizer() is "
                    "not None, then it will overwrite the DistributedStrategy in fleet.init(), "
                    "which will take effect in distributed training.")
            self._user_defined_strategy = copy.deepcopy(strategy)

        self._context = {}

        if paddle.fluid.framework._non_static_mode():
            if self.worker_num() > 1:
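                # With more than one worker in dygraph mode, wrap the optimizer:
                # HybridParallelOptimizer for the regular collective path, or
                # HeterParallelOptimizer when heter_ccl_mode is enabled.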
                if self._user_defined_strategy.heter_ccl_mode == False:
                    return HybridParallelOptimizer(optimizer, self._hcg,
                                                   self._user_defined_strategy)
                else:
                    return HeterParallelOptimizer(optimizer,
                                                  self._user_defined_strategy)
            else:
                return optimizer
        return self

    @dygraph_only
    def distributed_model(self, model):
        """
        Return the distributed data parallel model (only works in dygraph mode).

        Args:
            model (Layer): the user-defined model which inherits Layer.

        Returns:
            distributed data parallel model which inherits Layer.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                # 1. initialize fleet environment
                fleet.init(is_collective=True)

                # 2. create layer & optimizer
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                # 3. get data_parallel model using fleet
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                # 4. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()

        """
        assert model is not None, "model should not be None"
        if self.worker_num() <= 1:
            return model

        amp_enable = False
        recompute_enable = False
        strategy = self._user_defined_strategy
        if strategy.amp == True:
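            # Pure fp16 (O2) decorates the model in place; the GradScaler built
            # from amp_configs below is stored in the module-level _grad_scalar,
            # which the patched dygraph backward is expected to pick up.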
            amp_enable = True
            amp_level = "O2" if strategy.amp_configs['use_pure_fp16'] else "O1"
            if amp_level.upper() == "O2":
                model = paddle.amp.decorate(models=model,
                                            optimizers=None,
                                            level="O2",
                                            master_weight=None,
                                            save_dtype=None)
            init_loss_scaling = strategy.amp_configs['init_loss_scaling']
            incr_ratio = strategy.amp_configs['incr_ratio']
            decr_ratio = strategy.amp_configs['decr_ratio']
            incr_every_n_steps = strategy.amp_configs['incr_every_n_steps']
            decr_every_n_nan_or_inf = strategy.amp_configs[
                'decr_every_n_nan_or_inf']
            use_dynamic_loss_scaling = strategy.amp_configs[
                'use_dynamic_loss_scaling']

            global _grad_scalar
            _grad_scalar = paddle.amp.GradScaler(
                init_loss_scaling=init_loss_scaling,
                incr_ratio=incr_ratio,
                decr_ratio=decr_ratio,
                incr_every_n_steps=incr_every_n_steps,
                decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
                use_dynamic_loss_scaling=use_dynamic_loss_scaling)

        if strategy.recompute == True:
            recompute_enable = True
            model = _RecomputeModelWrapper(model)
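            # _RecomputeModelWrapper only accepts a paddle.nn.Sequential model
            # (enforced by the assert in its constructor).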

        if self._user_defined_strategy.heter_ccl_mode == True:
            distributed_model = paddle.DataParallel(
                model,
                comm_buffer_size=self._user_defined_strategy.
                fuse_grad_size_in_MB,
                last_comm_buffer_size=self._user_defined_strategy.
                last_comm_group_size_MB,
                find_unused_parameters=self._user_defined_strategy.
                find_unused_parameters)
            return distributed_model

        if self._hcg.get_parallel_mode() == ParallelMode.SHARDING_PARALLEL:
            model = ShardingParallel(model,
                                     self._hcg,
                                     strategy=self._user_defined_strategy)
        elif self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL:

            # NOTE (JZ-LIANG) init parameters broadcast within sharding group
            # normally it should be done inside DataParallel
            if self.sharding_degree > 1:
                from paddle.distributed.fleet.utils.hybrid_parallel_util import broadcast_mp_parameters, broadcast_sharding_parameters
                assert self.sharding_degree == self._hcg.get_sharding_parallel_world_size(
                )
                broadcast_sharding_parameters(model, self._hcg)
            model = paddle.DataParallel(
                model,
                comm_buffer_size=self._user_defined_strategy.
                fuse_grad_size_in_MB,
                last_comm_buffer_size=self._user_defined_strategy.
                last_comm_group_size_MB,
                find_unused_parameters=self._user_defined_strategy.
                find_unused_parameters)
        elif self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL:
            model = TensorParallel(model,
                                   self._hcg,
                                   strategy=self._user_defined_strategy)
        elif self._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
            model = PipelineParallel(model,
                                     self._hcg,
                                     strategy=self._user_defined_strategy)

        return model

    @dygraph_only
    def state_dict(self):
        """
        Get state dict information from optimizer.
        (Only works in dygraph mode)

        Returns:
            state_dict(dict) : dict contains all the Tensor used by optimizer

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)
                state_dict = adam.state_dict()
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.state_dict()

    @dygraph_only
    def set_state_dict(self, state_dict):
        """
        Load optimizer state dict.
        (Only works in dygraph mode)

        Args:
            state_dict(dict) : Dict contains all the Tensor needed by optimizer

        Returns:
            None

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)
                state_dict = adam.state_dict()
                paddle.save(state_dict, "paddle_dy")
                para_state_dict = paddle.load("paddle_dy")
                adam.set_state_dict(para_state_dict)
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.set_state_dict(state_dict)

    @dygraph_only
    def set_lr(self, value):
        """
        Set the value of the learning rate manually in the optimizer.
        (Only works in dygraph mode)

        Args:
            value (float|Tensor): the value of learning rate

        Returns:
            None

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                lr_list = [0.2, 0.3, 0.4, 0.5, 0.6]
                for i in range(5):
                    adam.set_lr(lr_list[i])
                    lr = adam.get_lr()
                    print("current lr is {}".format(lr))
                # Print:
                #    current lr is 0.2
                #    current lr is 0.3
                #    current lr is 0.4
                #    current lr is 0.5
                #    current lr is 0.6
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.set_lr(value)

    @dygraph_only
    def get_lr(self):
        """
        Get current step learning rate.
        (Only works in dygraph mode)

        Returns:
            float: The learning rate of the current step.

        Examples:

            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                lr = adam.get_lr()
                print(lr) # 0.01
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.get_lr()

    @dygraph_only
    def step(self):
        """
        Execute the optimizer once.
        (Only works in dygraph mode)

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                # 1. initialize fleet environment
                fleet.init(is_collective=True)

                # 2. create layer & optimizer
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                # 3. get data_parallel model using fleet
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                # 4. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()

        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.step()

    @dygraph_only
    def clear_grad(self):
        """
        Clear the gradients of all optimized parameters of the model.
        (Only works in dygraph mode)

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                # 1. initialize fleet environment
                fleet.init(is_collective=True)

                # 2. create layer & optimizer
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                # 3. get data_parallel model using fleet
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                # 4. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()

        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.clear_grad()

    def _get_amp_optimizer(self):
        # imitate target optimizer retrieval
        amp_optimizer = None
        for optimizer in self.strategy_compiler._get_applied_meta_optimizer():
            if hasattr(optimizer, 'amp_init'):
                amp_optimizer = optimizer
                break

        if amp_optimizer is None:
            if hasattr(self.user_defined_optimizer, 'amp_init'):
                amp_optimizer = self.user_defined_optimizer

        assert amp_optimizer is not None, \
            "amp_init can only be used when the amp(auto mixed precision) strategy is turned on."
        return amp_optimizer

    def get_loss_scaling(self):
        """Return the real-time loss scaling factor.
        """
        amp_optimizer = self._get_amp_optimizer()
        return amp_optimizer.get_loss_scaling()

    def amp_init(self,
                 place,
                 scope=None,
                 test_program=None,
                 use_fp16_test=False):
        """
        Init the amp training, such as casting fp32 parameters to fp16 type.
  
        Args:
            place(CUDAPlace): place is used to initialize 
                fp16 parameters with fp32 values.
            scope(Scope): The scope is used to find fp32 parameters.
            test_program(Program): The program is used for testing.
            use_fp16_test(bool): Whether to use fp16 testing.
            
        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                import paddle.nn.functional as F
                paddle.enable_static()

                def run_example_code():
                    place = paddle.CUDAPlace(0)
                    exe = paddle.static.Executor(place)
                    data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
                    conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
                    # 1) Use fp16_guard to control the range of fp16 kernels used.
                    with paddle.static.amp.fp16_guard():
                        bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
                        pool = F.max_pool2d(bn, kernel_size=2, stride=2)
                        hidden = paddle.static.nn.fc(pool, size=10)
                        loss = paddle.mean(hidden)
                    # 2) Create the optimizer and set `multi_precision` to True.
                    # Setting `multi_precision` to True can avoid the poor accuracy
                    # or the slow convergence in a way. 
                    optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
                    # 3) These ops in `custom_black_list` will keep in the float32 computation type.
                    amp_list = paddle.static.amp.CustomOpLists(
                        custom_black_list=['pool2d'])
                    # 4) The entry of Paddle AMP.
                    # Enable pure fp16 training by setting `use_pure_fp16` to True.
                    optimizer = paddle.static.amp.decorate(
                        optimizer,
                        amp_list,
                        init_loss_scaling=128.0,
                        use_dynamic_loss_scaling=True,
                        use_pure_fp16=True)
                    # If you don't use the default_startup_program(), you should pass
                    # your defined `startup_program` into `minimize`.
                    optimizer.minimize(loss)
                    exe.run(paddle.static.default_startup_program())
                    # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
                    # If you want to perform the testing process, you should pass `test_program` into `amp_init`.
                    optimizer.amp_init(place, scope=paddle.static.global_scope())
                    
                if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                    run_example_code()       
        """
        amp_optimizer = self._get_amp_optimizer()
        return amp_optimizer.amp_init(place, scope, test_program, use_fp16_test)

    def _final_strategy(self):
        if "valid_strategy" not in self._context:
            print(
                "WARNING: You may need to call minimize function before this function is called"
            )
            return {}
        else:
            return self._context["valid_strategy"]

    def _get_applied_meta_list(self):
        if "applied_meta_list" not in self._context:
            print(
                "WARNING: You may need to call minimize() before _get_applied_meta_list() is called"
            )
            return []
        else:
            return self._context["applied_meta_list"]

    def _get_applied_graph_list(self):
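        """
        Return the list of graph optimizers applied by ``minimize``, or an empty
        list (with a warning) if ``minimize`` has not been called yet.
        """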
        if "applied_graph_list" not in self._context:
            print(
                "WARNING: You may need to call minimize() before _get_applied_graph_list() is called"
            )
            return []
        else:
            return self._context["applied_graph_list"]

    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
        """
        Add distributed operations to minimize ``loss`` by updating ``parameter_list``.

        Args:
            loss (Tensor): A ``Tensor`` containing the value to minimize.
            startup_program (Program, optional): :ref:`api_fluid_Program` for
                initializing parameters in ``parameter_list``. The default value
                is None, in which case :ref:`api_fluid_default_startup_program` will be used.
            parameter_list (Iterable, optional): Iterable of ``Tensor`` or ``Tensor.name`` to update
                to minimize ``loss``. The default value is None, in which case all parameters
                will be updated.
            no_grad_set (set, optional): Set of ``Tensor`` or ``Tensor.name`` that don't need
                to be updated. The default value is None.

        Returns:
            tuple: tuple (optimize_ops, params_grads), where ``optimize_ops`` is a list of
            operators appended by minimize and ``params_grads`` is a list of (param, grad)
            tensor pairs; param is a ``Parameter`` and grad is the gradient value
            corresponding to the parameter. The returned tuple can be passed to ``fetch_list``
            in ``Executor.run()`` to indicate program pruning. If so, the program will be
            pruned by ``feed`` and ``fetch_list`` before run; see details in ``Executor``.

        Examples:

            .. code-block:: python

                import paddle
                paddle.enable_static()
                import paddle.distributed.fleet as fleet
                import paddle.nn.functional as F

                hid_dim = 10
                label_dim = 2
                input_x = paddle.static.data(name='x', shape=[None, 13], dtype='float32')
                input_y = paddle.static.data(name='y', shape=[None, 1], dtype='int64')
                fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim, activation='tanh')
                fc_2 = paddle.static.nn.fc(x=fc_1, size=hid_dim, activation='tanh')
                prediction = paddle.static.nn.fc(x=[fc_2], size=label_dim, activation='softmax')
                cost = F.cross_entropy(input=prediction, label=input_y)
                avg_cost = paddle.mean(x=cost)

                fleet.init(is_collective=True)
                strategy = fleet.DistributedStrategy()
                optimizer = paddle.optimizer.SGD(learning_rate=0.001)
                optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
                optimizer.minimize(avg_cost)

                # for more examples, please reference https://github.com/PaddlePaddle/FleetX
        """
        if not isinstance(loss, list):
            return self._minimize_impl(loss, startup_program, parameter_list,
                                       no_grad_set)
        else:
            if paddle.fluid.framework._non_static_mode(
            ) or self._role_maker._is_non_distributed() or self._is_collective:
                raise ValueError("loss can be a list only in PS mode")
            return self._minimize_losses_impl(loss, startup_program,
                                              parameter_list, no_grad_set)

    def _minimize_impl(self,
                       loss,
                       startup_program=None,
                       parameter_list=None,
                       no_grad_set=None):
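        """
        Static-graph implementation of ``minimize`` for a single loss: ranks and
        applies the valid meta/graph optimizers chosen by the strategy compiler,
        falls back to the user-defined optimizer in dygraph or non-distributed
        runs, and caches the resulting context for later queries.
        """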
        context = {}
        context["user_defined_strategy"] = copy.deepcopy(
            self._user_defined_strategy)
        if paddle.fluid.framework._non_static_mode():
            # imitate target optimizer retrieval
            target_opt = self.user_defined_optimizer
            self._context = context
            return target_opt.minimize(loss)

        # cache original feed forward program
        self.origin_main_program = loss.block.program
        # add distributed attr
        if not hasattr(self.origin_main_program, "distributed_info_"):
            setattr(self.origin_main_program, "distributed_info_", dict())
            self.origin_main_program.distributed_info_[
                "dp_degree"] = self._user_defined_strategy.sharding_configs[
                    "dp_degree"]
            self.origin_main_program.distributed_info_[
                "mp_degree"] = self._user_defined_strategy.sharding_configs[
                    "mp_degree"]
            self.origin_main_program.distributed_info_[
                "pp_degree"] = self._user_defined_strategy.sharding_configs[
                    "pp_degree"]
            self.origin_main_program.distributed_info_[
                "sharding_degree"] = self._user_defined_strategy.sharding_configs[
                    "sharding_degree"]

        context["origin_main_program"] = self.origin_main_program
        context["origin_main_programs"] = [self.origin_main_program]
        context["loss"] = loss
        if startup_program is None:
            self.origin_startup_program = \
                paddle.static.default_startup_program().clone(for_test=False)
            startup_program = paddle.static.default_startup_program()
        else:
            self.origin_startup_program = \
                startup_program.clone(for_test=False)

        context["origin_startup_program"] = startup_program
        context["origin_startup_programs"] = [startup_program]
        context["role_maker"] = self._role_maker

        # Use the auto-parallel's routines instead
        if self._user_defined_strategy.semi_auto or self._user_defined_strategy.auto_search:
            from ...auto_parallel.parallelizer import AutoParallelizer
            auto_parallelizer = AutoParallelizer(self)
            optimize_ops, params_grads, dist_startup_prog, dist_main_prog = auto_parallelizer.parallelize(
                loss, startup_program, parameter_list, no_grad_set)

            return optimize_ops, params_grads, dist_startup_prog, dist_main_prog

        # compile time
        distributed_optimizer_list = \
            MetaOptimizerFactory()._get_valid_meta_optimizers(
                self.user_defined_optimizer)

        context["user_defined_strategy"] = copy.deepcopy(
            self._user_defined_strategy)
        copy_user_defined_strategy = copy.deepcopy(self._user_defined_strategy)

        # trigger the auto-parallel in very strict condition
        # strategy = DistributedStrategy()
        # strategy.auto = True
        # optimizer = paddle.optimizer.SGD(learning_rate=0.1)
        # optimizer = fleet.distributed_optimizer(optimizer, strategy)
        if copy_user_defined_strategy._is_strict_auto():
            # turn on all strategies for each optimizer
            for opt in distributed_optimizer_list:
                opt._enable_strategy(copy_user_defined_strategy, context)

        valid_optimizer_list = []
        valid_graph_optimizer_list = []
        can_not_apply_optimizer_list = []
        # recall meta optimizers for ranking
        for opt in distributed_optimizer_list:
            opt._set_basic_info(loss, self._role_maker,
                                self.user_defined_optimizer,
                                copy_user_defined_strategy)
            if opt._can_apply() and not opt._is_graph_out():
                valid_optimizer_list.append(opt)
            elif opt._can_apply() and opt._is_graph_out():
                valid_graph_optimizer_list.append(opt)
            else:
                can_not_apply_optimizer_list.append(opt)
        # combine recalled meta optimizers to be a valid meta optimizer
        meta_optimizer, graph_optimizer = \
            self.strategy_compiler.generate_optimizer(
                loss, self._role_maker, self.user_defined_optimizer,
                copy_user_defined_strategy, valid_optimizer_list,
                valid_graph_optimizer_list)

        valid_strategy = self.strategy_compiler._get_valid_strategy(
            copy_user_defined_strategy, can_not_apply_optimizer_list)

        context["valid_strategy"] = copy.deepcopy(valid_strategy)
        # print("valid_strategy:", context["valid_strategy"])
        # print("user_defined_strategy:", context["user_defined_strategy"])

        applied_meta_list = self.strategy_compiler._get_applied_meta_list()
        applied_graph_list = self.strategy_compiler._get_applied_graph_list()

        context['applied_meta_list'] = applied_meta_list
        context['applied_graph_list'] = applied_graph_list

        self._context = context

        self.valid_strategy = valid_strategy
        self.valid_strategy._enable_env()

        optimize_ops = []
        params_grads = []

        if self._role_maker._is_non_distributed() and not self._is_collective:
            if self._runtime_handle is None:
                self._runtime_handle = RuntimeFactory()._create_runtime(context)

            compiled_program = compiler.CompiledProgram(
                self.origin_main_program).with_data_parallel(
                    loss_name=loss.name, share_vars_from=None)
            loss.block.program._graph = compiled_program
            return self.user_defined_optimizer.minimize(loss,
                                                        startup_program,
                                                        parameter_list,
                                                        no_grad_set=no_grad_set)

        if meta_optimizer:
            # print("before minimize program id:", id(loss.block.program))
            optimize_ops, params_grads = meta_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)
            # print("after minimize program id:", id(loss.block.program))

            default_program = paddle.static.default_main_program()
            # print("default program id:", id(default_program))

            if id(default_program) != id(loss.block.program):
                paddle.fluid.framework.switch_main_program(loss.block.program)
            # print("default program id after switch:", id(default_program))

        else:
            optimize_ops, params_grads = self.user_defined_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)

        context["program_optimize_ops"] = optimize_ops
        context["program_params_grads"] = params_grads

        if graph_optimizer:
            # print("before graph minimize program id:", id(loss.block.program))
            optimize_ops, params_grads = graph_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)
            # Since we do not encourage users to use graph operations directly,
            # optimize_ops and params_grads are usually None when a graph
            # optimizer takes effect, i.e. users can no longer modify the
            # current computation graph.
            context["graph_optimize_ops"] = optimize_ops
            context["graph_optimize_grads"] = params_grads
        else:
            apply_ir_passes(loss.block.program, startup_program, self)

        if not self._role_maker._is_heter_parameter_server_mode:
            program = paddle.static.default_main_program()
            opt_info = {} if program._fleet_opt is None else program._fleet_opt
            opt_info["mpi_size"] = self.worker_num()
            opt_info["mpi_rank"] = self.worker_index()
            for k, v in self._user_defined_strategy.trainer_desc_configs.items(
            ):
                if v or k not in opt_info:
                    opt_info[k] = v
            program._fleet_opt = opt_info

        if self._runtime_handle is None:
            self._runtime_handle = RuntimeFactory()._create_runtime(context)

        import paddle.distributed.fleet as fleet
        fleet.util._set_strategy(context["valid_strategy"])

        return optimize_ops, params_grads

    def _minimize_losses_impl(self,
                              losses,
                              startup_programs=None,
                              parameter_list=None,
                              no_grad_set=None):
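        """
        Implementation of ``minimize`` for a list of losses (parameter server
        mode only): each loss keeps its own main/startup program, and all of them
        are minimized together by the ``ParameterServerOptimizer``.
        """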
        context = {}

        # cache original feed forward program
        self.origin_main_program = losses[0].block.program
        context["origin_main_program"] = self.origin_main_program
        context["origin_main_programs"] = []
        for loss in losses:
            context["origin_main_programs"].append(loss.block.program)
        context["loss"] = losses

        if startup_programs is None:
            if len(losses) == 1:
                startup_programs = [paddle.static.default_startup_program()]
            else:
                raise ValueError(
                    "startup_program can't be None when loss is a list.")
        self.origin_startup_program = startup_programs[0].clone(for_test=False)
        context["origin_startup_program"] = startup_programs[0]
        context["origin_startup_programs"] = []
        for program in startup_programs:
            context["origin_startup_programs"].append(program)

        context["role_maker"] = self._role_maker

        context["user_defined_strategy"] = copy.deepcopy(
            self._user_defined_strategy)

        context["valid_strategy"] = copy.deepcopy(self._user_defined_strategy)

        self._context = context

        self.valid_strategy = context["valid_strategy"]
        self.valid_strategy._enable_env()

        optimize_ops = []
        params_grads = []

        from ..meta_optimizers import ParameterServerOptimizer
        ps_optimizer = ParameterServerOptimizer(self.user_defined_optimizer)
        ps_optimizer._set_basic_info(losses, self._role_maker,
                                     self.user_defined_optimizer,
                                     self._user_defined_strategy)
        optimize_ops, params_grads = ps_optimizer.minimize_losses_impl(
            losses, startup_programs, parameter_list, no_grad_set=no_grad_set)

        # default_program = paddle.static.default_main_program()

        # if id(default_program) != id(losses[0].block.program):
        #     paddle.fluid.framework.switch_main_program(losses[0].block.program)

        context["program_optimize_ops"] = optimize_ops
        context["program_params_grads"] = params_grads

        for loss in losses:
            program = loss.block.program
            opt_info = {} if program._fleet_opt is None else program._fleet_opt
            opt_info["mpi_size"] = self.worker_num()
            opt_info["mpi_rank"] = self.worker_index()
            for k, v in self._user_defined_strategy.trainer_desc_configs.items(
            ):
                if v or k not in opt_info:
                    opt_info[k] = v
            program._fleet_opt = opt_info
            # print("fleet base opt info:", id(program), program._fleet_opt)

        if self._runtime_handle is None:
            self._runtime_handle = RuntimeFactory()._create_runtime(context)

        import paddle.distributed.fleet as fleet
        fleet.util._set_strategy(context["valid_strategy"])

        return optimize_ops, params_grads

    @dygraph_only
    def distributed_scaler(self, scaler):
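        """
        Wrap a ``paddle.amp.GradScaler`` for hybrid parallel training. In tensor
        parallel and pipeline parallel modes, the scaler's unscale step is patched
        so that the found-inf flag is all-reduced across ranks before it is used.

        A minimal usage sketch (assuming collective training has already been
        initialized, e.g. via ``fleet.init(is_collective=True)`` with a hybrid
        parallel strategy; model and optimizer construction are omitted):

        .. code-block:: python

            # sketch only; assumes a hybrid parallel setup is in place
            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
            scaler = fleet.distributed_scaler(scaler)
        """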

        def unscale_method(self, optimizer):
            if not self._enable:
                return
            if getattr(optimizer, '_param_groups', None) and isinstance(
                    optimizer._param_groups[0], dict):
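                # the optimizer uses parameter groups: collect gradients from
                # every group and split them by dtype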
                param_grads = []
                param_grads_fp16 = []
                param_grads_fp32 = []
                for group in optimizer._param_groups:
                    for param in group['params']:
                        if param._grad_ivar() is not None:
                            param_grads.append(param._grad_ivar())
                            if param._grad_ivar(
                            ).dtype == core.VarDesc.VarType.FP16:
                                param_grads_fp16.append(param._grad_ivar())
                            else:
                                param_grads_fp32.append(param._grad_ivar())
            else:
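                # flat parameter list: split gradients by dtype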
                param_grads = [
                    param._grad_ivar() for param in optimizer._parameter_list
                    if param._grad_ivar() is not None
                ]
                param_grads_fp16 = [
                    param._grad_ivar() for param in optimizer._parameter_list
                    if (param._grad_ivar() is not None) and (
                        param._grad_ivar().dtype == core.VarDesc.VarType.FP16)
                ]
                param_grads_fp32 = [
                    param._grad_ivar() for param in optimizer._parameter_list
                    if (param._grad_ivar() is not None) and (
                        param._grad_ivar().dtype == core.VarDesc.VarType.FP32)
                ]
            temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool_))
            temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool_))
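            # Unscale fp16 and fp32 gradients in place; the temp flags record
            # whether any inf/nan was found in the local gradients.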
            if len(param_grads_fp16):
                _C_ops.check_finite_and_unscale(param_grads_fp16, self._scale,
                                                param_grads_fp16,
                                                temp_found_inf_fp16)
            if len(param_grads_fp32):
                _C_ops.check_finite_and_unscale(param_grads_fp32, self._scale,
                                                param_grads_fp32,
                                                temp_found_inf_fp32)

            self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
            is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")

            # TODO(shenliang03): since dp allreduce in the optimizer runs after
            # the grad scaler, check_finite needs to synchronize global
            # information. In the future, we should use check_group to speed this up.
            paddle.distributed.all_reduce(is_found_inf,
                                          op=paddle.distributed.ReduceOp.MAX,
                                          group=None)
            self._found_inf = is_found_inf.numpy()[0]

        # Only tensor_parallel and pipeline_parallel need to modify scaler
        if self._hcg.get_parallel_mode() in (ParallelMode.TENSOR_PARALLEL,
                                             ParallelMode.PIPELINE_PARALLEL):
            scaler._unscale = MethodType(unscale_method, scaler)

        return scaler