fleet_base.py 65.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
16
import copy
17
import warnings
18
import paddle
19
import os
20
from types import MethodType
21
import numpy as np
22
from paddle.fluid.framework import dygraph_only, _global_flags
23
from paddle.fluid import compiler
24
from .role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker, RoleMakerBase
25
from .strategy_compiler import StrategyCompiler
26
from .distributed_strategy import DistributedStrategy
27 28
from .meta_optimizer_factory import MetaOptimizerFactory
from .runtime_factory import RuntimeFactory
29
from paddle.fluid.wrapped_decorator import wrap_decorator
30
from paddle.fluid.dygraph import parallel_helper
31
from paddle.fluid.ir import apply_build_strategy
32
from . import topology as tp
33
from .topology import ParallelMode
34
from ..meta_parallel import TensorParallel, model_parallel_random_seed
J
JZ-LIANG 已提交
35
from ..meta_parallel import PipelineParallel, ShardingParallel
K
kuizhiqing 已提交
36
from ..meta_optimizers import HybridParallelOptimizer, HeterParallelOptimizer
37
from paddle import _C_ops
38 39
from paddle.fluid import core
from paddle.fluid.dygraph import to_variable
40 41
from paddle.distributed.fleet.utils.recompute import RecomputeFunction
from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar
42

43 44
__all__ = []

45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
_grad_scalar = None


class _RecomputeModelWrapper(paddle.nn.Layer):
    """Wrap a ``paddle.nn.Sequential`` so its forward pass runs in segments.

    The child layers are cut into ``segments`` chunks of
    ``len(layers) // segments`` layers each. Every chunk except the last is
    executed through ``RecomputeFunction.apply``; the last chunk (which also
    absorbs any remainder layers) is executed directly.

    Args:
        model (paddle.nn.Sequential): The model to wrap.
        segments (int): Number of chunks the child layers are divided into.
        preserve_rng_state (bool): Forwarded to ``RecomputeFunction.apply``.
    """

    def __init__(self, model, segments=2, preserve_rng_state=True):
        super(_RecomputeModelWrapper, self).__init__()
        assert isinstance(model, paddle.nn.Sequential), (
            "The model passed to RecomputeModelWrapper must be of type "
            "paddle.nn.Sequential.")
        self._model = model
        self._segments = segments
        self._preserve_rng_state = preserve_rng_state
        child_layers = list(model.children())
        self._layers = child_layers
        self._segment_size = len(child_layers) // segments

    def _run_func(self, begin, end):
        """Return a callable that applies layers ``[begin, end)`` in order."""

        def do_run(input):
            output = input
            for layer_idx in range(begin, end):
                output = self._layers[layer_idx](output)
            return output

        return do_run

    def _checkpoint(self, func, *args, **kwargs):
        """Run ``func`` via ``RecomputeFunction.apply``.

        Only positional ``args`` are forwarded; ``kwargs`` are accepted but
        not used.
        """
        return RecomputeFunction.apply(func, self._preserve_rng_state, *args)

    def forward(self, input):
        """Run all segments but the last through ``_checkpoint``, then the rest."""
        step = self._segment_size
        checkpointed_stop = step * (self._segments - 1)
        tail_begin = 0
        for head in range(0, checkpointed_stop, step):
            tail_begin = head + step
            input = self._checkpoint(self._run_func(head, tail_begin), input)
        # The final chunk runs directly and picks up any remainder layers.
        return self._run_func(tail_begin, len(self._layers))(input)

79

80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
def apply_ir_passes(main_program, startup_program, config):
    """Apply the user's build-strategy passes to the given programs.

    Args:
        main_program: Main program; when it carries a non-empty
            ``_pipeline_opt``, the pipeline section program is used instead.
        startup_program: Startup program; in the pipeline case the program
            stored in its own ``_pipeline_opt`` is used instead.
        config: Object exposing ``_user_defined_strategy`` and
            ``_is_collective``.

    Returns:
        A copy of the user's build strategy when
        ``FLAGS_apply_pass_to_program`` is off, otherwise the result of
        ``apply_build_strategy``.
    """
    build_strategy = config._user_defined_strategy.build_strategy._copy()
    flags = _global_flags()
    if not flags['FLAGS_apply_pass_to_program']:
        return build_strategy

    pipeline_opt = getattr(main_program, "_pipeline_opt", {})
    if pipeline_opt:
        main_program = pipeline_opt["section_program"]
        startup_program = startup_program._pipeline_opt["startup_program"]

    pass_attrs = {"use_cuda": config._is_collective}
    if (config._user_defined_strategy.fuse_all_reduce_ops and
            build_strategy.fuse_all_optimizer_ops):
        # FIXME(zjl): fuse_all_optimizer_ops currently conflicts with
        # fuse_all_reduce_ops because RawProgramOptimizer also inserts
        # coalesce_tensor into the program, and the two procedures may
        # disagree on which vars are to be fused.
        warnings.warn(
            'Currently, the fuse_all_optimizer_ops pass has conflict with fuse_all_reduce_ops pass. Disable the fuse_all_optimizer_ops pass temporarily.'
        )
        build_strategy.fuse_all_optimizer_ops = False

    return apply_build_strategy(main_program, startup_program, build_strategy,
                                pass_attrs)


107 108 109 110 111 112 113 114 115 116 117 118
def _inited_runtime_handler_(func):
    """Decorator factory: require a runtime handle before calling ``func``.

    The wrapper treats its first positional argument as the owning object and
    raises ``ValueError`` when that object's ``_runtime_handle`` is ``None``;
    otherwise it calls ``func`` with the unchanged arguments and returns its
    result.
    """

    def __impl__(*args, **kwargs):
        owner = args[0]

        if owner._runtime_handle is not None:
            return func(*args, **kwargs)

        raise ValueError("Fleet can not find suitable runtime handler")

    return __impl__


119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
def _is_non_distributed_check_(func):
    """Decorator factory: skip ``func`` when fleet is non-distributed.

    The wrapper treats its first positional argument as the owning object.
    If that object has a role maker whose ``_is_non_distributed()`` returns
    exactly ``True``, a warning is emitted and ``None`` is returned without
    calling ``func``. In every other case ``func`` is called normally.
    """

    def __impl__(*args, **kwargs):
        role_maker = args[0]._role_maker

        # Identity comparison on purpose: only a literal ``True`` skips func.
        if role_maker is None or role_maker._is_non_distributed() is not True:
            return func(*args, **kwargs)

        warnings.warn(
            "%s() function doesn't work when use non_distributed fleet." %
            (func.__name__))
        return None

    return __impl__


135
inited_runtime_handler = wrap_decorator(_inited_runtime_handler_)
136
is_non_distributed_check = wrap_decorator(_is_non_distributed_check_)
137 138


139 140 141
class Fleet(object):
    """
    Unified API for distributed training of PaddlePaddle
142
    Please refer to https://github.com/PaddlePaddle/FleetX for details.
143 144 145 146 147


    Returns:
        Fleet: A Fleet instance

148
    Example for collective training:
1
123malin 已提交
149

150 151
        .. code-block:: python

1
123malin 已提交
152 153
            import paddle
            paddle.enable_static()
154
            import paddle.distributed.fleet as fleet
155 156 157

            fleet.init(is_collective=True)

158 159 160
            strategy = fleet.DistributedStrategy()
            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
            optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
161 162 163 164 165 166 167 168

            # do distributed training


    Example for parameter server training:

        .. code-block:: python

1
123malin 已提交
169 170
            import paddle
            paddle.enable_static()
171 172
            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
S
ShenLiang 已提交
173
            fleet.init(strategy=strategy)
174

175
            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
176
            optimizer = fleet.distributed_optimizer(optimizer)
177

178 179
            if fleet.is_first_worker():
                print("this is first worker")
180

181 182
            print("current node index: {}".format(fleet.worker_index()))
            print("total number of worker num: {}".format(fleet.worker_num()))
183

184 185 186
            if fleet.is_worker():
                print("this is worker")
            print("worker endpoints: {}".format(fleet.worker_endpoints(to_string=True)))
187

188 189
            print("server num: {}".format(fleet.server_num()))
            print("server endpoints: {}".format(fleet.server_endpoints(to_string=True)))
190

191 192 193
            if fleet.is_server():
                print("this is server")
            fleet.stop_worker()
194 195


196 197 198
    """

    def __init__(self):
        """Create an uninitialised Fleet; ``init`` fills in the real state."""
        # Role / mode information, populated by ``init``.
        self._role_maker = None
        self._is_collective = False
        # Helpers created later in the workflow.
        self.strategy_compiler = None
        self._runtime_handle = None
        self._util = None
        self._context = {}
205

206
    def init(self, role_maker=None, is_collective=False, strategy=None):
        """
        Initialize role_maker in Fleet.

        This function sets up the distributed architecture that the rest of
        the program runs on.

        Args:
            role_maker (RoleMakerBase, optional): A ``RoleMakerBase`` instance
                describing the distributed environment. If it is None, a
                ``PaddleCloudRoleMaker`` is created automatically.
                The default value is None.
            is_collective (Boolean, optional): Whether to use collective
                training. Only used when ``role_maker`` is None, in which case
                it must be a ``bool``; otherwise the value is taken from the
                given role maker. The default value is False.
            strategy (DistributedStrategy): Extra properties for distributed training.
                A deep copy of it is stored. For details, please refer to
                paddle.distributed.fleet.DistributedStrategy. Default: None.

        Returns:
            None

        Raises:
            ValueError: If ``is_collective`` is not a ``bool``, if
                ``role_maker`` is not a ``RoleMakerBase``, or if a
                non-distributed collective CUDA run sees a device count
                other than 1.

        Examples1:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

        Examples2:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init(is_collective=True)

        Examples3:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                role = fleet.PaddleCloudRoleMaker()
                fleet.init(role)

        Examples4:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                strategy = fleet.DistributedStrategy()
                fleet.init(strategy=strategy)

        """
        chosen_strategy = DistributedStrategy() if strategy is None else strategy
        self._user_defined_strategy = copy.deepcopy(chosen_strategy)

        if role_maker is not None:
            if not isinstance(role_maker, RoleMakerBase):
                raise ValueError(
                    "`role_maker` should be subclass of `RoleMakerBase`, but got {}".
                    format(type(role_maker)))
            self._role_maker = role_maker
            self._is_collective = role_maker._is_collective
        else:
            if not isinstance(is_collective, bool):
                raise ValueError(
                    "`is_collective` should be instance of `bool`, but got {}".
                    format(type(is_collective)))
            self._is_collective = is_collective
            self._role_maker = PaddleCloudRoleMaker(
                is_collective=self._is_collective)
        self._role_maker._generate_role()

        import paddle.distributed.fleet as fleet
        fleet.util._set_role_maker(self._role_maker)

        self.strategy_compiler = StrategyCompiler()

        # A non-distributed collective run on a CUDA build must see exactly
        # one device.
        if self._role_maker._is_non_distributed() and self._is_collective:
            if paddle.fluid.core.is_compiled_with_cuda():
                if paddle.fluid.core.get_cuda_device_count() != 1:
                    raise ValueError(
                        "CUDA_VISIBLE_DEVICES shoule be set only 1 card if you use `python` to launch fleet program."
                    )

        if paddle.fluid.framework._non_static_mode():
            if self.worker_num() == 1:
                # if worker_num is 1, should construct default topology & hcg
                self._topology = tp.CommunicateTopology()
                self._hcg = tp.HybridCommunicateGroup(self._topology)
                return

            if not parallel_helper._is_parallel_ctx_initialized():
                # FLAGS_nccl_nrings is used for dynamic graph multi-stream communication
                if "FLAGS_nccl_nrings" not in os.environ:
                    os.environ["FLAGS_nccl_nrings"] = str(
                        self._user_defined_strategy.nccl_comm_num)
                else:
                    warnings.warn(
                        "You have set the environment variable FLAGS_nccl_nrings "
                        "outside the program, so the nccl_comm_num in "
                        "DistributedStrategy will not take effect here.")
                paddle.distributed.init_parallel_env()
            else:
                warnings.warn(
                    "The dygraph parallel environment has been initialized.")

            # hybrid parallel not support for npu/xpu
            if self._user_defined_strategy.heter_ccl_mode == False:
                # init hybrid parallel environment in dygraph
                if tp._HYBRID_PARALLEL_GROUP is not None:
                    warnings.warn(
                        "The dygraph hybrid parallel environment has been initialized."
                    )
                else:
                    self._init_hybrid_parallel_env()
            return

        if not self._is_collective:
            return

        # Static-graph collective mode: register the communication groups.
        use_sharding = self._user_defined_strategy.sharding

        # global group
        global_rank = self.worker_index()
        global_world_size = self.worker_num()
        # NOTE(wangxi): see sharding_optimizer
        global_ring_id = 3 if use_sharding else 0
        global_ranks = list(range(global_world_size))

        if tp._HYBRID_PARALLEL_GROUP is None:
            tp._CommunicateGroup()
        comm_group = tp._HYBRID_PARALLEL_GROUP
        self._hcg = comm_group
        comm_group.set_comm_group('global', global_rank, global_world_size,
                                  global_ring_id, global_ranks)

        use_tensor_parallel = self._user_defined_strategy.tensor_parallel
        use_mp = use_sharding or use_tensor_parallel

        # hybrid group
        if use_mp is False:
            return

        mp_degree_sharding = 1
        mp_degree_tensor_parallel = 1
        if use_sharding:
            sharding_configs = self._user_defined_strategy.sharding_configs
            mp_degree_sharding = int(sharding_configs['mp_degree'])

        if use_tensor_parallel:
            tensor_parallel_configs = self._user_defined_strategy.tensor_parallel_configs
            mp_degree_tensor_parallel = int(tensor_parallel_configs[
                'tensor_parallel_degree'])

        if use_sharding and use_tensor_parallel:
            assert mp_degree_sharding == mp_degree_tensor_parallel

        if use_sharding:
            mp_degree = mp_degree_sharding
        else:
            mp_degree = mp_degree_tensor_parallel

        if mp_degree <= 1:
            return

        assert global_world_size % mp_degree == 0
        # NOTE(wangxi): mp_ring_id sync with sharding_optimizer.py _build_groups
        mp_ring_id = 0
        mp_group_id, mp_rank = divmod(global_rank, mp_degree)
        mp_group_ranks = [
            rank for rank in global_ranks if rank // mp_degree == mp_group_id
        ]
        comm_group.set_comm_group('model', mp_rank, mp_degree, mp_ring_id,
                                  mp_group_ranks)
375 376 377 378 379 380 381 382

    def _init_hybrid_parallel_env(self):
        """Initialize the hybrid parallel environment.

        Reads ``dp_degree``, ``mp_degree``, ``pp_degree`` and
        ``sharding_degree`` from the stored strategy's ``hybrid_configs``,
        normalises them, and builds ``self._topology`` and ``self._hcg`` from
        the four degrees. When ``mp_degree`` is greater than 1, the model
        parallel random seed is initialised as well.

        Normalisation rules:
            * ``mp_degree``, ``pp_degree`` and ``sharding_degree`` must be
              ``>= 0``; a value of 0 is treated as 1.
            * A negative ``dp_degree`` is derived from the world size as
              ``nranks // (mp_degree * pp_degree)``, then floored at 1.
        """
        self.hybrid_configs = self._user_defined_strategy.hybrid_configs
        self.dp_degree = self.hybrid_configs["dp_degree"]
        self.mp_degree = self.hybrid_configs["mp_degree"]
        self.pp_degree = self.hybrid_configs["pp_degree"]
        self.sharding_degree = self.hybrid_configs["sharding_degree"]

        assert self.mp_degree >= 0, "mp_degree should be greater or equal to 0"
        assert self.pp_degree >= 0, "pp_degree should be greater or equal to 0"
        assert self.sharding_degree >= 0, "sharding_degree should be greater or equal to 0"

        self.mp_degree = max(self.mp_degree, 1)
        self.pp_degree = max(self.pp_degree, 1)
        # 0 passes the assertion above but is not a usable topology dimension;
        # normalise it the same way as mp_degree and pp_degree.
        self.sharding_degree = max(self.sharding_degree, 1)

        if self.dp_degree < 0:
            nranks = paddle.distributed.get_world_size()
            # NOTE(review): sharding_degree is not part of this divisor —
            # confirm that is intended when sharding_degree > 1.
            self.dp_degree = nranks // (self.mp_degree * self.pp_degree)

        self.dp_degree = max(self.dp_degree, 1)

        self._topology = tp.CommunicateTopology(
            hybrid_group_names=["data", "pipe", "sharding", "model"],
            dims=[
                self.dp_degree, self.pp_degree, self.sharding_degree,
                self.mp_degree
            ])

        self._hcg = tp.HybridCommunicateGroup(self._topology)

        if self.mp_degree > 1:
            tensor_parallel_configs = self._user_defined_strategy.tensor_parallel_configs
            tensor_init_seed = tensor_parallel_configs["tensor_init_seed"]
            # -1 means no explicit seed was configured.
            if tensor_init_seed == -1:
                model_parallel_random_seed()
            else:
                model_parallel_random_seed(tensor_init_seed)

415 416 417 418 419 420 421 422
    def get_hybrid_communicate_group(self):
        """Return the hybrid communicate group; it must not be None."""
        hcg = self._hcg
        assert hcg is not None
        return hcg

    def get_hybrid_parallel_topology(self):
        """Return the hybrid parallel topology; it must not be None."""
        topology = self._topology
        assert topology is not None
        return topology

423 424 425 426 427 428 429
    def is_first_worker(self):
        """
        Check whether the node is the first instance of worker.

        The answer is delegated to the role maker.

        Returns:
            bool: True if this is the first node of worker,
                  False if not.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_first_worker()

        """
        role_maker = self._role_maker
        return role_maker._is_first_worker()
441 442 443 444 445 446 447

    def worker_index(self):
        """
        Get current worker index.

        The value is delegated to the role maker.

        Returns:
            int: node id

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_index()

        """
        role_maker = self._role_maker
        return role_maker._worker_index()
459 460 461 462 463 464 465

    def worker_num(self):
        """
        Get current total worker number.

        The value is delegated to the role maker.

        Returns:
            int: worker numbers

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_num()

        """
        role_maker = self._role_maker
        return role_maker._worker_num()
477

478 479 480 481 482 483 484 485 486 487 488 489
    def node_num(self):
        """Return the node count reported by the role maker."""
        role_maker = self._role_maker
        return role_maker._get_node_num()

    def local_rank(self):
        """Return the local rank reported by the role maker."""
        role_maker = self._role_maker
        return role_maker._get_local_rank()

    def local_device_ids(self):
        """Return the local device ids reported by the role maker."""
        role_maker = self._role_maker
        return role_maker._get_local_device_ids()

    def world_device_ids(self):
        """Return the world device ids reported by the role maker."""
        role_maker = self._role_maker
        return role_maker._get_world_device_ids()

490 491 492 493 494 495 496
    def is_worker(self):
        """
        Check whether the node is an instance of worker.

        The answer is delegated to the role maker.

        Returns:
            bool: True if this is a node of worker,
                  False if not.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_worker()

        """
        role_maker = self._role_maker
        return role_maker._is_worker()
508 509 510

    def worker_endpoints(self, to_string=False):
        """
        Get current worker endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].

        Args:
            to_string (bool, optional): When true, join the endpoints with
                "," and return a single string. Default: False.

        Returns:
            list/string: worker endpoints

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_endpoints()

        """
        endpoints = self._role_maker._get_trainer_endpoints()
        return ",".join(endpoints) if to_string else endpoints
529 530 531 532 533 534 535

    def server_num(self):
        """
        Get current total server number.

        Computed as the number of parameter-server endpoints known to the
        role maker.

        Returns:
            int: server number

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_num()
        """
        pserver_endpoints = self._role_maker._get_pserver_endpoints()
        return len(pserver_endpoints)
546 547 548 549 550 551 552

    def server_index(self):
        """
        Get current server index.

        The value is delegated to the role maker.

        Returns:
            int: node id

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_index()

        """
        role_maker = self._role_maker
        return role_maker._server_index()
564 565 566 567 568 569 570

    def server_endpoints(self, to_string=False):
        """
        Get current server endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].

        Args:
            to_string (bool, optional): When true, join the endpoints with
                "," and return a single string. Default: False.

        Returns:
            list/string: server endpoints

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_endpoints()

        """
        endpoints = self._role_maker._get_pserver_endpoints()
        return ",".join(endpoints) if to_string else endpoints
586 587 588 589 590 591 592 593

    def is_server(self):
        """
        Check whether the node is an instance of server.

        The answer is delegated to the role maker.

        Returns:
            bool: True if this is a node of server,
                  False if not.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_server()

        """
        role_maker = self._role_maker
        return role_maker._is_server()

606 607
    def barrier_worker(self):
        """
        Barrier all workers.

        Asks the role maker to barrier on the "worker" group.

        Returns:
            None
        """
        role_maker = self._role_maker
        role_maker._barrier("worker")
614

615
    @is_non_distributed_check
    @inited_runtime_handler
    def init_worker(self, scopes=None):
        """
        Initialize `Communicator` for parameter server training.

        Args:
            scopes (optional): Forwarded unchanged to the runtime handle.
                Default: None.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_worker()

        """
        runtime = self._runtime_handle
        runtime._init_worker(scopes)
639

640
    @is_non_distributed_check
    @inited_runtime_handler
    def init_server(self, *args, **kwargs):
        """
        init_server executor to initialize startup program,
        if the `args` is not empty, it will run load_persistables for increment training.

        All positional and keyword arguments are forwarded unchanged to the
        runtime handle.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_server()

        """
        runtime = self._runtime_handle
        runtime._init_server(*args, **kwargs)
665

Z
zmxdream 已提交
666 667
    @is_non_distributed_check
    @inited_runtime_handler
    def load_model(self, path, mode):
        """
        Load fleet model from path.

        Args:
            path: Location of the model; forwarded to the runtime handle.
            mode: Load mode; forwarded to the runtime handle.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.load_model("path", "mode")

        """
        runtime = self._runtime_handle
        runtime.load_model(path, mode)

691
    @is_non_distributed_check
    @inited_runtime_handler
    def run_server(self):
        """
        run server will run pserver main program with executor.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                if fleet.is_server():
                    fleet.init_server()
                    fleet.run_server()

        """
        runtime = self._runtime_handle
        runtime._run_server()

716
    @is_non_distributed_check
    @inited_runtime_handler
    def stop_worker(self):
        """
        stop `Communicator` and give training complete notice to parameter server.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_worker()
                # do training
                fleet.stop_worker()

        """
        runtime = self._runtime_handle
        runtime._stop_worker()

Z
zmxdream 已提交
740 741
    @is_non_distributed_check
    @inited_runtime_handler
    def save(self, dirname, feed=[], fetch=[], **configs):
        """
        Save either an inference model or the persistable variables.

        When ``feed`` or ``fetch`` is non-empty, an inference model is saved
        through the runtime handle. When both are empty, the persistables are
        saved instead.

        Args:
            dirname (str): Target directory, forwarded to the runtime handle.
            feed (list[str|Variable], optional): Feed variables or their
                names. Default: empty.
            fetch (list[str|Variable], optional): Fetch variables or their
                names; they are looked up in the global block of the default
                main program. Default: empty.
            **configs: Optional settings. ``mode`` (converted with ``int``)
                is used as the save mode for persistables; default 0.

        Returns:
            None

        Raises:
            ValueError: If an element of ``feed`` or ``fetch`` is neither a
                ``str`` nor a ``paddle.static.Variable``.
        """

        # The list defaults in the signature are kept for interface
        # compatibility; they are only read, never mutated.
        def _to_var_names(items, arg_name):
            """Normalise a list of names/Variables to a list of names."""
            names = []
            for item in items:
                if isinstance(item, str):
                    names.append(item)
                elif isinstance(item, paddle.static.Variable):
                    names.append(item.name)
                else:
                    raise ValueError("%s must be [str|Variable]" % arg_name)
            return names

        place = paddle.CPUPlace()
        executor = paddle.static.Executor(place)

        if feed or fetch:
            feeded_var_names = _to_var_names(feed, "feed")
            # Fix: invalid fetch entries used to be reported as "feed".
            fetch_var_names = _to_var_names(fetch, "fetch")

            global_block = paddle.static.default_main_program().global_block()
            fetch_vars = [global_block.var(name) for name in fetch_var_names]

            self._runtime_handle._save_inference_model(
                executor, dirname, feeded_var_names, fetch_vars, None, True, 0)
        else:
            increment_mode = int(configs["mode"]) if "mode" in configs else 0
            self._runtime_handle._save_persistables(
                executor, dirname, main_program=None, mode=increment_mode)

Z
zmxdream 已提交
785 786
    @is_non_distributed_check
    @inited_runtime_handler
    def save_inference_model(self,
                             executor,
                             dirname,
                             feeded_var_names,
                             target_vars,
                             main_program=None,
                             export_for_deployment=True,
                             mode=0):
        """
        Save inference model for inference.

        Every argument is forwarded unchanged, in order, to the runtime
        handle.

        Args:
            executor: Executor used for saving.
            dirname: Target directory.
            feeded_var_names: Names of the feed variables.
            target_vars: Variables to fetch at inference time.
            main_program (optional): Program to save. Default: None.
            export_for_deployment (bool, optional): Default: True.
            mode (int, optional): Save mode. Default: 0.

        Returns:
            None

        Examples:

            .. code-block:: text

                import paddle
                paddle.enable_static()
                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                exe = paddle.static.Executor(paddle.CPUPlace())
                # fleet.save_inference_model(exe, "dirname", feeded_var_names, target_vars)

        """
        # warnings.warn(
        #     "'save_inference_model' is a deprecated, will be deleted after v2.2.0, Please use fleet.save instead."
        # )

        runtime = self._runtime_handle
        runtime._save_inference_model(executor, dirname, feeded_var_names,
                                      target_vars, main_program,
                                      export_for_deployment, mode)
821

Z
zmxdream 已提交
822 823
    @is_non_distributed_check
    @inited_runtime_handler
    def save_persistables(self, executor, dirname, main_program=None, mode=0):
        """
        Save all persistable tensors from :code:`main_program` to
        the folder :code:`dirname`.

        The :code:`dirname` is used to specify the folder where persistable tensors
        are going to be saved. All arguments are forwarded unchanged to the
        runtime handle.

        Args:
            executor(Executor): The executor to run for saving persistable tensors.
                                You can refer to :ref:`api_guide_executor_en` for
                                more details.

            dirname(str, optional): The saving directory path.
                                When you need to save the parameter to the memory, set it to None.
            main_program(Program, optional): The program whose persistable tensors will
                                             be saved. Default: None.
            mode(int, optional): Save mode passed to the runtime handle. Default: 0.


        Returns:
            None

        Examples:

            .. code-block:: text

                import paddle
                paddle.enable_static()
                import paddle.distributed.fleet as fleet

                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                exe = paddle.static.Executor(paddle.CPUPlace())
                fleet.save_persistables(exe, "dirname", paddle.static.default_main_program())

        """
        # warnings.warn(
        #     "'save_persistables' is a deprecated, will be deleted after v2.2.0, Please use fleet.save instead."
        # )

        runtime = self._runtime_handle
        runtime._save_persistables(executor, dirname, main_program, mode)
871

Z
zhaocaibei123 已提交
872 873 874 875 876
    @is_non_distributed_check
    @inited_runtime_handler
    def save_cache_model(self, dirname, **configs):
        """
        Forward a cache-model save request to the runtime handle.

        Args:
            dirname: Passed through as the first argument of the runtime
                handle's ``_save_cache_model`` (presumably a directory path).
            **configs: Extra keyword options, forwarded untouched.

        Returns:
            Whatever the runtime handle's ``_save_cache_model`` returns.
        """
        runtime = self._runtime_handle
        return runtime._save_cache_model(dirname, **configs)

877
    def shrink(self, threshold=None):
        """
        Forward a shrink request to the runtime handle.

        Args:
            threshold(optional): Passed unchanged to the runtime handle's
                ``_shrink``. Default: None.

        Returns:
            None
        """
        runtime = self._runtime_handle
        runtime._shrink(threshold)

880
    def distributed_optimizer(self, optimizer, strategy=None):
        """
        Optimizer for distributed training.

        Registers the user's optimizer on this fleet object and returns the
        object that should be used as the optimizer from now on, which offers
        the basic Optimizer functions plus the distributed-training features.

        Args:
            optimizer(Optimizer): The user-defined optimizer to be used for
                distributed training.
            strategy(DistributedStrategy): Extra properties for distributed optimizer.
                It is recommended to use DistributedStrategy in fleet.init(). The strategy
                here is for compatibility. If the strategy in fleet.distributed_optimizer()
                is not None, then it will overwrite the DistributedStrategy in fleet.init(),
                which will take effect in distributed training.

        Returns:
            In static graph mode, this fleet instance. In dygraph mode, the
            original ``optimizer`` when there is at most one worker, otherwise
            a ``HybridParallelOptimizer`` (or a ``HeterParallelOptimizer`` when
            ``heter_ccl_mode`` is enabled) wrapping it.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.distributed.fleet as fleet
                fleet.init(is_collective=True)
                strategy = fleet.DistributedStrategy()
                optimizer = paddle.optimizer.SGD(learning_rate=0.001)
                optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)

        """
        self.user_defined_optimizer = optimizer

        if strategy is not None:
            if self._is_collective:
                warnings.warn(
                    "It is recommended to use DistributedStrategy "
                    "in fleet.init(). The strategy here is only for compatibility. "
                    "If the strategy in fleet.distributed_optimizer() is "
                    "not None, then it will overwrite the DistributedStrategy in fleet.init(), "
                    "which will take effect in distributed training.")
            # Keep a private copy so later changes made by the caller to its
            # strategy object do not leak into the fleet.
            self._user_defined_strategy = copy.deepcopy(strategy)

        # Every call starts from an empty compile context.
        self._context = {}

        if not paddle.fluid.framework._non_static_mode():
            # Static graph mode: hand back this fleet object.
            return self

        if self.worker_num() > 1:
            dist_strategy = self._user_defined_strategy
            if dist_strategy.heter_ccl_mode == False:
                return HybridParallelOptimizer(optimizer, self._hcg,
                                               dist_strategy)
            return HeterParallelOptimizer(optimizer, dist_strategy)

        # Single worker in dygraph mode: nothing to wrap.
        return optimizer

936
    @dygraph_only
937
    def distributed_model(self, model):
        """
        Return distributed data parallel model (Only work in dygraph mode)

        With more than one worker the model is processed in this order:

        1. if ``strategy.amp`` is on, it is decorated for pure fp16 (``O2``)
           when ``amp_configs['use_pure_fp16']`` is set, and a
           ``paddle.amp.GradScaler`` is built from ``amp_configs``;
        2. if ``strategy.recompute`` is on, it is wrapped for recompute;
        3. it is wrapped according to ``heter_ccl_mode`` or to the parallel
           mode reported by the hybrid communicate group.

        Args:
            model (Layer): the user-defined model which inherits Layer.

        Returns:
            distributed data parallel model which inherits Layer. When there
            is at most one worker the input ``model`` is returned unchanged.
            When the parallel mode matches none of the sharding / data /
            tensor / pipeline modes, the model is returned without a parallel
            wrapper (amp / recompute processing still applies).

        Raises:
            AssertionError: if ``model`` is None, or if ``sharding_degree``
                disagrees with the sharding world size of the communicate group.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                # 1. initialize fleet environment
                fleet.init(is_collective=True)

                # 2. create layer & optimizer
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                # 3. get data_parallel model using fleet
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                # 4. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()

        """
        # NOTE(review): this rebinds `_grad_scalar` in THIS module's namespace
        # only. The name is imported from
        # paddle.fluid.dygraph.varbase_patch_methods, so please confirm that
        # the consumer of the scaler actually observes the new value.
        global _grad_scalar

        assert model is not None, "model should not be None"
        if self.worker_num() <= 1:
            return model

        strategy = self._user_defined_strategy

        def _as_data_parallel(layer):
            # The heter-ccl path and the plain data-parallel path build
            # DataParallel from exactly the same strategy settings.
            return paddle.DataParallel(
                layer,
                comm_buffer_size=strategy.fuse_grad_size_in_MB,
                last_comm_buffer_size=strategy.last_comm_group_size_MB,
                find_unused_parameters=strategy.find_unused_parameters)

        if strategy.amp:
            amp_configs = strategy.amp_configs
            if amp_configs['use_pure_fp16']:
                # Pure fp16 corresponds to AMP level O2.
                model = paddle.amp.decorate(
                    models=model,
                    optimizers=None,
                    level="O2",
                    master_weight=None,
                    save_dtype=None)

            _grad_scalar = paddle.amp.GradScaler(
                init_loss_scaling=amp_configs['init_loss_scaling'],
                incr_ratio=amp_configs['incr_ratio'],
                decr_ratio=amp_configs['decr_ratio'],
                incr_every_n_steps=amp_configs['incr_every_n_steps'],
                decr_every_n_nan_or_inf=amp_configs['decr_every_n_nan_or_inf'],
                use_dynamic_loss_scaling=amp_configs[
                    'use_dynamic_loss_scaling'])

        if strategy.recompute:
            model = _RecomputeModelWrapper(model)

        if strategy.heter_ccl_mode:
            return _as_data_parallel(model)

        # Query the mode once instead of once per branch.
        parallel_mode = self._hcg.get_parallel_mode()

        if parallel_mode == ParallelMode.SHARDING_PARALLEL:
            model = ShardingParallel(model, self._hcg, strategy=strategy)
        elif parallel_mode == ParallelMode.DATA_PARALLEL:
            # NOTE (JZ-LIANG) init parameters broadcast within sharding group
            # normally it should be done inside DataParallel
            if self.sharding_degree > 1:
                from paddle.distributed.fleet.utils.hybrid_parallel_util import broadcast_sharding_parameters
                assert self.sharding_degree == self._hcg.get_sharding_parallel_world_size(
                )
                broadcast_sharding_parameters(model, self._hcg)
            model = _as_data_parallel(model)
        elif parallel_mode == ParallelMode.TENSOR_PARALLEL:
            model = TensorParallel(model, self._hcg, strategy=strategy)
        elif parallel_mode == ParallelMode.PIPELINE_PARALLEL:
            model = PipelineParallel(model, self._hcg, strategy=strategy)

        return model
1070 1071 1072 1073 1074

    @dygraph_only
    def state_dict(self):
        """
        Fetch the state dict of the wrapped optimizer.
        (Only work in dygraph mode)

        Returns:
            state_dict(dict) : dict contains all the Tensor used by optimizer

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)
                state_dict = adam.state_dict()
        """
        # Delegate to the optimizer registered via distributed_optimizer().
        wrapped_optimizer = self.user_defined_optimizer
        return wrapped_optimizer.state_dict()

    @dygraph_only
    def set_state_dict(self, state_dict):
        """
        Restore the wrapped optimizer from a state dict.
        (Only work in dygraph mode)

        Args:
            state_dict(dict) : Dict contains all the Tensor needed by optimizer

        Returns:
            None (the result of the wrapped optimizer's ``set_state_dict`` is
            passed through).

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)
                state_dict = adam.state_dict()
                paddle.save(state_dict, "paddle_dy")
                para_state_dict = paddle.load("paddle_dy")
                adam.set_state_dict(para_state_dict)
        """
        # Delegate to the optimizer registered via distributed_optimizer().
        wrapped_optimizer = self.user_defined_optimizer
        return wrapped_optimizer.set_state_dict(state_dict)

    @dygraph_only
    def set_lr(self, value):
        """
        Manually override the learning rate of the wrapped optimizer.
        (Only work in dygraph mode)

        Args:
            value (float|Tensor): the value of learning rate

        Returns:
            None (the result of the wrapped optimizer's ``set_lr`` is passed
            through).

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                lr_list = [0.2, 0.3, 0.4, 0.5, 0.6]
                for i in range(5):
                    adam.set_lr(lr_list[i])
                    lr = adam.get_lr()
                    print("current lr is {}".format(lr))
                # Print:
                #    current lr is 0.2
                #    current lr is 0.3
                #    current lr is 0.4
                #    current lr is 0.5
                #    current lr is 0.6
        """
        # Delegate to the optimizer registered via distributed_optimizer().
        wrapped_optimizer = self.user_defined_optimizer
        return wrapped_optimizer.set_lr(value)

    @dygraph_only
    def get_lr(self):
        """
        Read the learning rate of the current step from the wrapped optimizer.
        (Only work in dygraph mode)

        Returns:
            float: The learning rate of the current step.

        Examples:

            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                lr = adam.get_lr()
                print(lr) # 0.01
        """
        # Delegate to the optimizer registered via distributed_optimizer().
        wrapped_optimizer = self.user_defined_optimizer
        return wrapped_optimizer.get_lr()

    @dygraph_only
    def step(self):
        """
        Run one update step of the wrapped optimizer.
        (Only work in dygraph mode)

        Returns:
            None (the result of the wrapped optimizer's ``step`` is passed
            through).

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                # 1. initialize fleet environment
                fleet.init(is_collective=True)

                # 2. create layer & optimizer
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                # 3. get data_parallel model using fleet
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                # 4. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()

        """
        # Delegate to the optimizer registered via distributed_optimizer().
        wrapped_optimizer = self.user_defined_optimizer
        return wrapped_optimizer.step()

    @dygraph_only
    def clear_grad(self):
        """
        Reset the gradients of all parameters handled by the wrapped optimizer.
        (Only work in dygraph mode)

        Returns:
            None (the result of the wrapped optimizer's ``clear_grad`` is
            passed through).

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                # 1. initialize fleet environment
                fleet.init(is_collective=True)

                # 2. create layer & optimizer
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                # 3. get data_parallel model using fleet
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                # 4. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()

        """
        # Delegate to the optimizer registered via distributed_optimizer().
        wrapped_optimizer = self.user_defined_optimizer
        return wrapped_optimizer.clear_grad()

1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347
    def _get_amp_optimizer(self):
        """
        Locate the optimizer that exposes ``amp_init``.

        The applied meta optimizers reported by the strategy compiler are
        searched first (first match wins); the user-defined optimizer is the
        fallback.

        Returns:
            The first optimizer that has an ``amp_init`` attribute.

        Raises:
            AssertionError: if no such optimizer exists.
        """
        applied = self.strategy_compiler._get_applied_meta_optimizer()
        found = next(
            (candidate for candidate in applied
             if hasattr(candidate, 'amp_init')), None)

        if found is None and hasattr(self.user_defined_optimizer, 'amp_init'):
            found = self.user_defined_optimizer

        assert found is not None, \
            "amp_init can only be used when the amp(auto mixed precision) strategy is turned on."
        return found

    def get_loss_scaling(self):
        """Return the real-time loss scaling factor.

        The value is read from the optimizer found by ``_get_amp_optimizer``.
        """
        return self._get_amp_optimizer().get_loss_scaling()

H
huangxu96 已提交
1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412
    def amp_init(self,
                 place,
                 scope=None,
                 test_program=None,
                 use_fp16_test=False):
        """
        Init the amp training, such as cast fp32 parameters to fp16 type.

        The call is forwarded to the optimizer found by ``_get_amp_optimizer``.

        Args:
            place(CUDAPlace): place is used to initialize
                fp16 parameters with fp32 values.
            scope(Scope): The scope is used to find fp32 parameters.
            test_program(Program): The program is used for testing.
            use_fp16_test(bool): Whether to use fp16 testing.

        Returns:
            Whatever the underlying optimizer's ``amp_init`` returns.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                import paddle.nn.functional as F
                paddle.enable_static()

                def run_example_code():
                    place = paddle.CUDAPlace(0)
                    exe = paddle.static.Executor(place)
                    data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
                    conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
                    # 1) Use fp16_guard to control the range of fp16 kernels used.
                    with paddle.static.amp.fp16_guard():
                        bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
                        pool = F.max_pool2d(bn, kernel_size=2, stride=2)
                        hidden = paddle.static.nn.fc(pool, size=10)
                        loss = paddle.mean(hidden)
                    # 2) Create the optimizer and set `multi_precision` to True.
                    # Setting `multi_precision` to True can avoid the poor accuracy
                    # or the slow convergence in a way.
                    optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
                    # 3) These ops in `custom_black_list` will keep in the float32 computation type.
                    amp_list = paddle.static.amp.CustomOpLists(
                        custom_black_list=['pool2d'])
                    # 4) The entry of Paddle AMP.
                    # Enable pure fp16 training by setting `use_pure_fp16` to True.
                    optimizer = paddle.static.amp.decorate(
                        optimizer,
                        amp_list,
                        init_loss_scaling=128.0,
                        use_dynamic_loss_scaling=True,
                        use_pure_fp16=True)
                    # If you don't use the default_startup_program(), you should pass
                    # your defined `startup_program` into `minimize`.
                    optimizer.minimize(loss)
                    exe.run(paddle.static.default_startup_program())
                    # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
                    # If you want to perform the testing process, you should pass `test_program` into `amp_init`.
                    optimizer.amp_init(place, scope=paddle.static.global_scope())

                if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                    run_example_code()
        """
        return self._get_amp_optimizer().amp_init(place, scope, test_program,
                                                  use_fp16_test)
H
huangxu96 已提交
1415

D
Dong Daxiang 已提交
1416 1417 1418 1419 1420 1421 1422 1423 1424
    def _final_strategy(self):
        """
        Return the ``"valid_strategy"`` entry of the compile context.

        When the entry is missing, a warning is printed and an empty dict is
        returned instead.
        """
        if "valid_strategy" in self._context:
            return self._context["valid_strategy"]

        print(
            "WARNING: You may need to call minimize function before this function is called"
        )
        return {}

1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442
    def _get_applied_meta_list(self):
        """
        Return the ``"applied_meta_list"`` entry of the compile context.

        When the entry is missing, a warning is printed and an empty list is
        returned instead.
        """
        if "applied_meta_list" in self._context:
            return self._context["applied_meta_list"]

        print(
            "WARNING: You may need to call minimize function before _get_applied_meta_list called"
        )
        return []

    def _get_applied_graph_list(self):
        """
        Return the ``"applied_graph_list"`` entry of the compile context.

        When the entry is missing, a warning is printed and an empty list is
        returned instead.
        """
        if "applied_graph_list" in self._context:
            return self._context["applied_graph_list"]

        print(
            "WARNING: You may need to call minimize function before _get_applied_graph_list called"
        )
        return []

1443 1444 1445 1446 1447 1448 1449 1450 1451
    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
        """
        Add distributed operations to minimize ``loss`` by updating ``parameter_list``.

        A single loss is handled by ``_minimize_impl``. A list of losses is
        handled by ``_minimize_losses_impl`` and is accepted only in PS mode.

        Args:
            loss (Tensor): A ``Tensor`` containing the value to minimize.
            startup_program (Program, optional): :ref:`api_fluid_Program` for
                initializing parameters in ``parameter_list``. The default value
                is None, at this time :ref:`api_fluid_default_startup_program` will be used.
            parameter_list (Iterable, optional): Iterable of ``Tensor`` or ``Tensor.name`` to update
                to minimize ``loss``. The default value is None, at this time all parameters
                will be updated.
            no_grad_set (set, optional): Set of ``Tensor``  or ``Tensor.name`` that don't need
                to be updated. The default value is None.

        Returns:
            tuple: tuple (optimize_ops, params_grads), A list of operators appended
            by minimize and a list of (param, grad) tensor pairs, param is
            ``Parameter``, grad is the gradient value corresponding to the parameter.
            The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
            indicate program pruning. If so, the program will be pruned by ``feed`` and
            ``fetch_list`` before run, see details in ``Executor``.

        Raises:
            ValueError: if ``loss`` is a list while running in dygraph mode,
                in non-distributed mode, or in collective mode.

        Examples:

            .. code-block:: python

                import paddle
                paddle.enable_static()
                import paddle.distributed.fleet as fleet
                import paddle.nn.functional as F

                hid_dim = 10
                label_dim = 2
                input_x = paddle.static.data(name='x', shape=[None, 13], dtype='float32')
                input_y = paddle.static.data(name='y', shape=[None, 1], dtype='int64')
                fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim, activation='tanh')
                fc_2 = paddle.static.nn.fc(x=fc_1, size=hid_dim, activation='tanh')
                prediction = paddle.static.nn.fc(x=[fc_2], size=label_dim, activation='softmax')
                cost = F.cross_entropy(input=prediction, label=input_y)
                avg_cost = paddle.mean(x=cost)

                fleet.init(is_collective=True)
                strategy = fleet.DistributedStrategy()
                optimizer = paddle.optimizer.SGD(learning_rate=0.001)
                optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
                optimizer.minimize(avg_cost)

                # for more examples, please reference https://github.com/PaddlePaddle/FleetX

        """
        if isinstance(loss, list):
            # Several losses at once are only meaningful for the parameter
            # server; reject every other running mode.
            outside_ps_mode = (paddle.fluid.framework._non_static_mode() or
                               self._role_maker._is_non_distributed() or
                               self._is_collective)
            if outside_ps_mode:
                raise ValueError("loss can be list only in PS mode")
            return self._minimize_losses_impl(loss, startup_program,
                                              parameter_list, no_grad_set)

        return self._minimize_impl(loss, startup_program, parameter_list,
                                   no_grad_set)

    def _minimize_impl(self,
                       loss,
                       startup_program=None,
                       parameter_list=None,
                       no_grad_set=None):
D
Dong Daxiang 已提交
1513 1514 1515
        context = {}
        context["user_defined_strategy"] = copy.deepcopy(
            self._user_defined_strategy)
J
Jiabin Yang 已提交
1516
        if paddle.fluid.framework._non_static_mode():
1517 1518
            # imitate target optimizer retrieval
            target_opt = self.user_defined_optimizer
D
Dong Daxiang 已提交
1519
            self._context = context
1520 1521
            return target_opt.minimize(loss)

1522 1523
        # cache original feed forward program
        self.origin_main_program = loss.block.program
B
Baibaifan 已提交
1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539
        # add distributed attr
        if not hasattr(self.origin_main_program, "distributed_info_"):
            setattr(self.origin_main_program, "distributed_info_", dict())
            self.origin_main_program.distributed_info_[
                "dp_degree"] = self._user_defined_strategy.sharding_configs[
                    "dp_degree"]
            self.origin_main_program.distributed_info_[
                "mp_degree"] = self._user_defined_strategy.sharding_configs[
                    "mp_degree"]
            self.origin_main_program.distributed_info_[
                "pp_degree"] = self._user_defined_strategy.sharding_configs[
                    "pp_degree"]
            self.origin_main_program.distributed_info_[
                "sharding_degree"] = self._user_defined_strategy.sharding_configs[
                    "sharding_degree"]

1540
        context["origin_main_program"] = self.origin_main_program
1541
        context["origin_main_programs"] = [self.origin_main_program]
1542
        context["loss"] = loss
1543 1544
        if startup_program == None:
            self.origin_startup_program = \
1545 1546
                paddle.static.default_startup_program().clone(for_test=False)
            startup_program = paddle.static.default_startup_program()
1547 1548 1549
        else:
            self.origin_startup_program = \
                startup_program.clone(for_test=False)
1550

1551
        context["origin_startup_program"] = startup_program
1552
        context["origin_startup_programs"] = [startup_program]
1553
        context["role_maker"] = self._role_maker
1554

1555
        # Use the auto-parallel's routines instead
1556
        if self._user_defined_strategy.semi_auto or self._user_defined_strategy.auto_search:
1557 1558 1559 1560
            from ...auto_parallel.parallelizer import AutoParallelizer
            auto_parallelizer = AutoParallelizer(self)
            optimize_ops, params_grads, dist_startup_prog, dist_main_prog = auto_parallelizer.parallelize(
                loss, startup_program, parameter_list, no_grad_set)
1561

1562 1563
            return optimize_ops, params_grads, dist_startup_prog, dist_main_prog

1564 1565 1566 1567
        # compile time
        distributed_optimizer_list = \
            MetaOptimizerFactory()._get_valid_meta_optimizers(
                self.user_defined_optimizer)
D
Dong Daxiang 已提交
1568

D
Dong Daxiang 已提交
1569 1570 1571
        context["user_defined_strategy"] = copy.deepcopy(
            self._user_defined_strategy)
        copy_user_defined_strategy = copy.deepcopy(self._user_defined_strategy)
1572 1573 1574 1575 1576 1577

        # trigger the auto-parallel in very strict condition
        # strategy = DistributedStrategy()
        # strategy.auto = True
        # optimizer = paddle.optimizer.SGD(learning_rate=0.1)
        # optimizer = fleet.distributed_optimizer(optimizer, strategy)
D
Dong Daxiang 已提交
1578
        if copy_user_defined_strategy._is_strict_auto():
1579 1580
            # turn on all the strategy for each optimizer
            for opt in distributed_optimizer_list:
D
Dong Daxiang 已提交
1581
                opt._enable_strategy(copy_user_defined_strategy, context)
1582

1583 1584
        valid_optimizer_list = []
        valid_graph_optimizer_list = []
D
Dong Daxiang 已提交
1585
        can_not_apply_optimizer_list = []
1586 1587 1588 1589
        # recall meta optimizers for ranking
        for opt in distributed_optimizer_list:
            opt._set_basic_info(loss, self._role_maker,
                                self.user_defined_optimizer,
D
Dong Daxiang 已提交
1590
                                copy_user_defined_strategy)
1591 1592
            if opt._can_apply() and not opt._is_graph_out():
                valid_optimizer_list.append(opt)
D
Dong Daxiang 已提交
1593
            elif opt._can_apply() and opt._is_graph_out():
1594
                valid_graph_optimizer_list.append(opt)
D
Dong Daxiang 已提交
1595 1596
            else:
                can_not_apply_optimizer_list.append(opt)
1597
        # combine recalled meta optimizers to be a valid meta optimizer
D
Dong Daxiang 已提交
1598
        meta_optimizer, graph_optimizer = \
1599 1600
            self.strategy_compiler.generate_optimizer(
                loss, self._role_maker, self.user_defined_optimizer,
D
Dong Daxiang 已提交
1601
                copy_user_defined_strategy, valid_optimizer_list,
1602
                valid_graph_optimizer_list)
D
Dong Daxiang 已提交
1603

D
Dong Daxiang 已提交
1604
        valid_strategy = self.strategy_compiler._get_valid_strategy(
D
Dong Daxiang 已提交
1605 1606 1607
            copy_user_defined_strategy, can_not_apply_optimizer_list)

        context["valid_strategy"] = copy.deepcopy(valid_strategy)
1608 1609
        # print("valid_strategy:", context["valid_strategy"])
        # print("user_defined_strategy:", context["user_defined_strategy"])
1610

1611 1612 1613 1614 1615 1616
        applied_meta_list = self.strategy_compiler._get_applied_meta_list()
        applied_graph_list = self.strategy_compiler._get_applied_graph_list()

        context['applied_meta_list'] = applied_meta_list
        context['applied_graph_list'] = applied_graph_list

D
Dong Daxiang 已提交
1617
        self._context = context
1618

D
Dong Daxiang 已提交
1619
        self.valid_strategy = valid_strategy
1620
        self.valid_strategy._enable_env()
D
Dong Daxiang 已提交
1621

1622 1623
        optimize_ops = []
        params_grads = []
1624

1625 1626 1627 1628 1629 1630 1631 1632 1633
        if self._role_maker._is_non_distributed() and not self._is_collective:
            if self._runtime_handle is None:
                self._runtime_handle = RuntimeFactory()._create_runtime(context)

            compiled_program = compiler.CompiledProgram(
                self.origin_main_program).with_data_parallel(
                    loss_name=loss.name, share_vars_from=None)
            loss.block.program._graph = compiled_program
            return self.user_defined_optimizer.minimize(
M
MRXLT 已提交
1634
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)
1635

1636
        if meta_optimizer:
1637
            # print("before minimize program id:", id(loss.block.program))
1638
            optimize_ops, params_grads = meta_optimizer.minimize(
M
MRXLT 已提交
1639
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)
1640
            # print("after minimize program id:", id(loss.block.program))
1641

1642
            default_program = paddle.static.default_main_program()
1643
            # print("default program id:", id(default_program))
1644 1645 1646

            if id(default_program) != id(loss.block.program):
                paddle.fluid.framework.switch_main_program(loss.block.program)
1647
            # print("default program id after switch:", id(default_program))
1648

1649 1650
        else:
            optimize_ops, params_grads = self.user_defined_optimizer.minimize(
M
MRXLT 已提交
1651
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)
1652

1653 1654
        context["program_optimize_ops"] = optimize_ops
        context["program_params_grads"] = params_grads
1655

1656
        if graph_optimizer:
1657
            # print("before graph minimize program id:", id(loss.block.program))
D
Dong Daxiang 已提交
1658
            optimize_ops, params_grads = graph_optimizer.minimize(
M
MRXLT 已提交
1659
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)
1660 1661 1662 1663
            # We do not encourage users to use graph operations: when a
            # graph optimizer takes effect, optimize_ops and params_grads
            # are mostly None, i.e. users can no longer modify the current
            # computation graph.
1664 1665
            context["graph_optimize_ops"] = optimize_ops
            context["graph_optimize_grads"] = params_grads
1666 1667
        else:
            apply_ir_passes(loss.block.program, startup_program, self)
1668

1669 1670
        if not self._role_maker._is_heter_parameter_server_mode:
            program = paddle.static.default_main_program()
1671 1672 1673 1674 1675
            opt_info = {} if program._fleet_opt is None else program._fleet_opt
            opt_info["mpi_size"] = self.worker_num()
            opt_info["mpi_rank"] = self.worker_index()
            for k, v in self._user_defined_strategy.trainer_desc_configs.items(
            ):
1676
                if v or k not in opt_info:
1677
                    opt_info[k] = v
1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748
            program._fleet_opt = opt_info

        if self._runtime_handle is None:
            self._runtime_handle = RuntimeFactory()._create_runtime(context)

        import paddle.distributed.fleet as fleet
        fleet.util._set_strategy(context["valid_strategy"])

        return optimize_ops, params_grads

    def _minimize_losses_impl(self,
                              losses,
                              startup_programs=None,
                              parameter_list=None,
                              no_grad_set=None):
        context = {}

        # cache original feed forward program
        self.origin_main_program = losses[0].block.program
        context["origin_main_program"] = self.origin_main_program
        context["origin_main_programs"] = []
        for loss in losses:
            context["origin_main_programs"].append(loss.block.program)
        context["loss"] = losses

        if startup_programs is None:
            if len(losses) == 1:
                startup_programs = [paddle.static.default_startup_program()]
            else:
                raise ValueError(
                    "startup_program can't be None when loss is list.")
        self.origin_startup_program = startup_programs[0].clone(for_test=False)
        context["origin_startup_program"] = startup_programs[0]
        context["origin_startup_programs"] = []
        for program in startup_programs:
            context["origin_startup_programs"].append(program)

        context["role_maker"] = self._role_maker

        context["user_defined_strategy"] = copy.deepcopy(
            self._user_defined_strategy)

        context["valid_strategy"] = copy.deepcopy(self._user_defined_strategy)

        self._context = context

        self.valid_strategy = context["valid_strategy"]
        self.valid_strategy._enable_env()

        optimize_ops = []
        params_grads = []

        from ..meta_optimizers import ParameterServerOptimizer
        ps_optimizer = ParameterServerOptimizer(self.user_defined_optimizer)
        ps_optimizer._set_basic_info(losses, self._role_maker,
                                     self.user_defined_optimizer,
                                     self._user_defined_strategy)
        optimize_ops, params_grads = ps_optimizer.minimize_losses_impl(
            losses, startup_programs, parameter_list, no_grad_set=no_grad_set)

        # default_program = paddle.static.default_main_program()

        # if id(default_program) != id(losses[0].block.program):
        #     paddle.fluid.framework.switch_main_program(losses[0].block.program)

        context["program_optimize_ops"] = optimize_ops
        context["program_params_grads"] = params_grads

        for loss in losses:
            program = loss.block.program
            opt_info = {} if program._fleet_opt is None else program._fleet_opt
1749 1750 1751 1752
            opt_info["mpi_size"] = self.worker_num()
            opt_info["mpi_rank"] = self.worker_index()
            for k, v in self._user_defined_strategy.trainer_desc_configs.items(
            ):
1753
                if v or k not in opt_info:
1754
                    opt_info[k] = v
1755
            program._fleet_opt = opt_info
1756
            # print("fleet base opt info:", id(program), program._fleet_opt)
1757

1758
        if self._runtime_handle is None:
1759
            self._runtime_handle = RuntimeFactory()._create_runtime(context)
1760

1761 1762
        import paddle.distributed.fleet as fleet
        fleet.util._set_strategy(context["valid_strategy"])
1763 1764

        return optimize_ops, params_grads
1765 1766 1767

    @dygraph_only
    def distributed_scaler(self, scaler):
        """Adapt a gradient scaler to the current hybrid-parallel mode.

        When the hybrid communicate group reports tensor-parallel or
        pipeline-parallel mode, ``scaler._unscale`` is replaced by a method
        that unscales the optimizer's gradients, checks them for non-finite
        values and then all-reduces the found-inf flag with ``ReduceOp.MAX``.
        In any other mode the scaler is returned untouched.

        Args:
            scaler: The scaler object to patch. The replacement method reads
                its ``_enable`` and ``_scale`` attributes and writes
                ``_found_inf``.

        Returns:
            The same ``scaler`` object that was passed in.
        """

        def unscale_method(self, optimizer):
            # ``self`` is the scaler this function is bound to below,
            # not the enclosing fleet object.
            if not self._enable:
                return

            fp16_dtype = core.VarDesc.VarType.FP16
            fp32_dtype = core.VarDesc.VarType.FP32
            if getattr(optimizer, '_param_groups', None) and isinstance(
                    optimizer._param_groups[0], dict):
                param_grads_fp16 = []
                param_grads_fp32 = []
                for group in optimizer._param_groups:
                    for param in group['params']:
                        grad = param._grad_ivar()
                        if grad is None:
                            continue
                        # Every non-FP16 gradient goes to the FP32 list here.
                        if grad.dtype == fp16_dtype:
                            param_grads_fp16.append(grad)
                        else:
                            param_grads_fp32.append(grad)
            else:
                grads = [
                    param._grad_ivar() for param in optimizer._parameter_list
                ]
                grads = [grad for grad in grads if grad is not None]
                param_grads_fp16 = [
                    grad for grad in grads if grad.dtype == fp16_dtype
                ]
                # NOTE(review): unlike the param-group branch above, this
                # branch keeps only gradients that are exactly FP32, so
                # gradients of any other dtype are not unscaled. Kept as
                # in the original code; confirm whether this is intended.
                param_grads_fp32 = [
                    grad for grad in grads if grad.dtype == fp32_dtype
                ]

            # ``np.bool`` was removed in NumPy 1.24; builtin ``bool`` yields
            # the same boolean dtype on every NumPy version.
            temp_found_inf_fp16 = to_variable(np.array([0]).astype(bool))
            temp_found_inf_fp32 = to_variable(np.array([0]).astype(bool))
            if len(param_grads_fp16):
                _C_ops.check_finite_and_unscale(param_grads_fp16, self._scale,
                                                param_grads_fp16,
                                                temp_found_inf_fp16)
            if len(param_grads_fp32):
                _C_ops.check_finite_and_unscale(param_grads_fp32, self._scale,
                                                param_grads_fp32,
                                                temp_found_inf_fp32)

            self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
            is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")

            # TODO(shenliang03) Since dp allreduce in the optimizer is
            # after the gradscaler, check_finite needs to synchronize global
            # information. In the future, we should use check_group to speed.
            paddle.distributed.all_reduce(
                is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None)
            self._found_inf = is_found_inf.numpy()[0]

        # Only tensor_parallel and pipeline_parallel need to modify scaler
        if self._hcg.get_parallel_mode() in (ParallelMode.TENSOR_PARALLEL,
                                             ParallelMode.PIPELINE_PARALLEL):
            scaler._unscale = MethodType(unscale_method, scaler)

        return scaler