fleet_base.py 43.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
16
import copy
17
import warnings
18
import paddle
19
import os
20
from paddle.fluid.framework import dygraph_only
21
from paddle.fluid import compiler
22
from .role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker, RoleMakerBase
23
from .strategy_compiler import StrategyCompiler
24
from .distributed_strategy import DistributedStrategy
25 26
from .meta_optimizer_factory import MetaOptimizerFactory
from .runtime_factory import RuntimeFactory
27
from paddle.fluid.wrapped_decorator import wrap_decorator
28
from paddle.fluid.dygraph import parallel_helper
29
from . import topology as tp
30

31

32 33 34 35 36 37 38 39 40 41 42 43
def _inited_runtime_handler_(func):
    def __impl__(*args, **kwargs):
        cls = args[0]

        if cls._runtime_handle is None:
            raise ValueError("Fleet can not find suitable runtime handler")

        return func(*args, **kwargs)

    return __impl__


44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
def _is_non_distributed_check_(func):
    def __impl__(*args, **kwargs):
        cls = args[0]

        if cls._role_maker is not None and cls._role_maker._is_non_distributed(
        ) is True:
            warnings.warn(
                "%s() function doesn't work when use non_distributed fleet." %
                (func.__name__))
            return

        return func(*args, **kwargs)

    return __impl__


60
# Decorator forms of the two checks above, built with paddle's
# ``wrap_decorator`` helper so they can be stacked on ``Fleet`` methods.
is_non_distributed_check = wrap_decorator(_is_non_distributed_check_)
inited_runtime_handler = wrap_decorator(_inited_runtime_handler_)
62 63


64 65 66
class Fleet(object):
    """
    Unified API for distributed training of PaddlePaddle.

    Please refer to https://github.com/PaddlePaddle/FleetX for details.


    Returns:
        Fleet: A Fleet instance

    Example for collective training:

        .. code-block:: python

            import paddle
            paddle.enable_static()
            import paddle.distributed.fleet as fleet

            fleet.init(is_collective=True)

            strategy = fleet.DistributedStrategy()
            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
            optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)

            # do distributed training


    Example for parameter server training:

        .. code-block:: python

            import paddle
            paddle.enable_static()
            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            fleet.init(strategy=strategy)

            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
            optimizer = fleet.distributed_optimizer(optimizer)

            if fleet.is_first_worker():
                print("this is first worker")

            print("current node index: {}".format(fleet.worker_index()))
            print("total number of worker num: {}".format(fleet.worker_num()))

            if fleet.is_worker():
                print("this is worker")
            print("worker endpoints: {}".format(fleet.worker_endpoints(to_string=True)))

            print("server num: {}".format(fleet.server_num()))
            print("server endpoints: {}".format(fleet.server_endpoints(to_string=True)))

            if fleet.is_server():
                print("this is server")
            fleet.stop_worker()

    """

    def __init__(self):
124
        self._role_maker = None
125
        self.strategy_compiler = None
126
        self._is_collective = False
127
        self._runtime_handle = None
D
Dong Daxiang 已提交
128 129
        self._util = None
        self._context = {}
130

131
    def init(self, role_maker=None, is_collective=False, strategy=None):
        """
        Initialize role_maker in Fleet.

        This function sets up the distributed environment that the rest of
        your code runs in.

        Args:
            role_maker (RoleMakerBase, optional): A ``RoleMakerBase`` containing the configuration
                of environment variables related to distributed training. If you do not create
                the role maker yourself, a ``PaddleCloudRoleMaker`` is created automatically.
                The default value is None.
            is_collective (Boolean, optional): A ``Boolean`` variable that selects collective
                training (True) or parameter server training (False). It is only used when
                ``role_maker`` is None; otherwise ``role_maker._is_collective`` is used.
                The default value is False.
            strategy (DistributedStrategy): Extra properties for distributed training.
                For details, please refer to paddle.distributed.fleet.DistributedStrategy. Default: None.

        Returns:
            None

        Raises:
            ValueError: If ``role_maker`` is None and ``is_collective`` is not a ``bool``;
                if ``role_maker`` is not a ``RoleMakerBase``; or if a non-distributed
                collective job on a CUDA build sees a CUDA device count other than 1.

        Examples1:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

        Examples2:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init(is_collective=True)

        Examples3:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                role = fleet.PaddleCloudRoleMaker()
                fleet.init(role)

        Examples4:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                strategy = fleet.DistributedStrategy()
                fleet.init(strategy=strategy)

        """
        if strategy is None:
            strategy = DistributedStrategy()
        # Keep a private copy so later edits to the caller's strategy object
        # do not silently change fleet's configuration.
        self._user_defined_strategy = copy.deepcopy(strategy)

        if role_maker is None:
            if isinstance(is_collective, bool):
                self._is_collective = is_collective
                self._role_maker = PaddleCloudRoleMaker(
                    is_collective=self._is_collective)
            else:
                raise ValueError(
                    "`is_collective` should be instance of `bool`, but got {}".
                    format(type(is_collective)))
        else:
            if isinstance(role_maker, RoleMakerBase):
                self._role_maker = role_maker
                self._is_collective = role_maker._is_collective
            else:
                raise ValueError(
                    "`role_maker` should be subclass of `RoleMakerBase`, but got {}".
                    format(type(role_maker)))
        self._role_maker._generate_role()

        # Share the role maker with the package-level ``fleet.util`` helper.
        import paddle.distributed.fleet as fleet
        fleet.util._set_role_maker(self._role_maker)

        self.strategy_compiler = StrategyCompiler()

        # A collective job started without a launcher must see exactly one GPU.
        if self._role_maker._is_non_distributed() and self._is_collective:
            if paddle.fluid.core.is_compiled_with_cuda():
                gpus_num = paddle.fluid.core.get_cuda_device_count()
                if gpus_num != 1:
                    raise ValueError(
                        "CUDA_VISIBLE_DEVICES should be set to only 1 card if you use `python` to launch fleet program."
                    )

        if paddle.fluid.framework.in_dygraph_mode():
            # Single-worker dygraph runs need no parallel environment.
            if self.worker_num() == 1:
                return
            if parallel_helper._is_parallel_ctx_initialized():
                warnings.warn(
                    "The dygraph parallel environment has been initialized.")
            else:
                # FLAGS_nccl_nrings is used for dynamic graph multi-stream communication
                if "FLAGS_nccl_nrings" in os.environ:
                    warnings.warn(
                        "You have set the environment variable FLAGS_nccl_nrings "
                        "outside the program, so the nccl_comm_num in "
                        "DistributedStrategy will not take effect here.")
                else:
                    os.environ["FLAGS_nccl_nrings"] = str(
                        self._user_defined_strategy.nccl_comm_num)
                paddle.distributed.init_parallel_env()

            # init hybrid parallel environment in dygraph
            if tp._HYBRID_PARALLEL_GROUP is None:
                self._init_hybrid_parallel_env()
            else:
                warnings.warn(
                    "The dygraph hybrid parallel environment has been initialized."
                )

    def _init_hybrid_parallel_env(self):
        """Build the data/pipeline/model hybrid-parallel topology.

        Reads ``dp_degree``, ``mp_degree`` and ``pp_degree`` from the
        user-defined strategy's ``hybrid_configs``. It clamps each to at
        least 1, and derives ``dp_degree`` from the world size when it is
        negative. It stores the degrees on ``self`` together with the
        resulting topology (``self._topology``) and communicate group
        (``self._hcg``).
        """
        configs = self._user_defined_strategy.hybrid_configs
        self.hybrid_configs = configs
        self.dp_degree = configs["dp_degree"]
        self.mp_degree = configs["mp_degree"]
        self.pp_degree = configs["pp_degree"]

        assert self.mp_degree >= 0, "mp_degree should be greater or equal to 0"
        assert self.pp_degree >= 0, "pp_degree should be greater or equal to 0"

        # Degrees of 0 are clamped to a single partition.
        self.mp_degree = max(self.mp_degree, 1)
        self.pp_degree = max(self.pp_degree, 1)
        model_pipe_size = self.mp_degree * self.pp_degree

        # A negative dp_degree means: use every rank left over after model
        # and pipeline parallelism for data parallelism.
        if self.dp_degree < 0:
            world_size = paddle.distributed.get_world_size()
            self.dp_degree = world_size // model_pipe_size
        self.dp_degree = max(self.dp_degree, 1)

        dims = [self.dp_degree, self.pp_degree, self.mp_degree]
        self._topology = tp.CommunicateTopology(
            hybrid_group_names=["data", "pipe", "model"], dims=dims)
        self._hcg = tp.HybridCommunicateGroup(self._topology)

    def get_hybrid_communicate_group(self):
        assert self._hcg is not None
        return self._hcg

    def get_hybrid_parallel_topology(self):
        assert self._topology is not None
        return self._topology

280 281 282 283 284 285 286
    def is_first_worker(self):
        """
        Check whether the node is the first instance of worker.

        Returns:
            bool: True if this is the first node of worker,
                  False if not.
287

288 289 290 291 292 293 294 295
        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_first_worker()

296
        """
297
        return self._role_maker._is_first_worker()
298 299 300 301 302 303 304

    def worker_index(self):
        """
        Get current worker index.

        Returns:
            int: node id
305 306 307 308

        Examples:

            .. code-block:: python
1
123malin 已提交
309

310 311 312 313
                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_index()

314
        """
315
        return self._role_maker._worker_index()
316 317 318 319 320 321 322

    def worker_num(self):
        """
        Get current total worker number.

        Returns:
            int: worker numbers
1
123malin 已提交
323

324
        Examples:
1
123malin 已提交
325

326 327 328 329 330 331
            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_num()

332
        """
333
        return self._role_maker._worker_num()
334

335 336 337 338 339 340 341 342 343 344 345 346
    def node_num(self):
        return self._role_maker._get_node_num()

    def local_rank(self):
        return self._role_maker._get_local_rank()

    def local_device_ids(self):
        return self._role_maker._get_local_device_ids()

    def world_device_ids(self):
        return self._role_maker._get_world_device_ids()

347 348 349 350 351 352 353
    def is_worker(self):
        """
        Check whether the node is an instance of worker.

        Returns:
            bool: True if this is a node of worker,
                  False if not.
354 355

        Examples:
1
123malin 已提交
356

357 358 359 360 361 362
            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_worker()

363
        """
364
        return self._role_maker._is_worker()
365 366 367

    def worker_endpoints(self, to_string=False):
        """
368
        Get current worker endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].
369 370 371

        Returns:
            list/string: server endpoints
372 373

        Examples:
1
123malin 已提交
374

375 376 377 378 379 380
            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_endpoints()

381 382
        """
        if to_string:
383
            return ",".join(self._role_maker._get_trainer_endpoints())
384
        else:
385
            return self._role_maker._get_trainer_endpoints()
386 387 388 389 390 391 392

    def server_num(self):
        """
        Get current total worker number.

        Returns:
            int: server number
393 394

        Examples:
1
123malin 已提交
395

396
            .. code-block:: python
1
123malin 已提交
397 398 399 400

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_num()
401
        """
402
        return len(self._role_maker._get_pserver_endpoints())
403 404 405 406 407 408 409

    def server_index(self):
        """
        Get current server index.

        Returns:
            int: node id
410 411

        Examples:
1
123malin 已提交
412

413 414 415 416 417 418
            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_index()

419
        """
420
        return self._role_maker._server_index()
421 422 423 424 425 426 427

    def server_endpoints(self, to_string=False):
        """
        Get current server endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].

        Returns:
            list/string: server endpoints
428 429

        Examples:
1
123malin 已提交
430

431 432 433 434 435 436
            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_endpoints()

437
        """
438

439
        if to_string:
440
            return ",".join(self._role_maker._get_pserver_endpoints())
441
        else:
442
            return self._role_maker._get_pserver_endpoints()
443 444 445 446 447 448 449 450

    def is_server(self):
        """
        Check whether the node is an instance of server.

        Returns:
            bool: True if this is a node of server,
                  False if not.
451 452 453 454

        Examples:

            .. code-block:: python
1
123malin 已提交
455

456 457 458 459
                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_server()

460
        """
461
        return self._role_maker._is_server(
462
        ) or self._role_maker._is_heter_worker()
463 464 465

    def barrier_worker(self):
        """
466 467 468 469
        barrier all workers

        Returns:
            None
470
        """
471
        self._role_maker._barrier("worker")
472

473
    @is_non_distributed_check
    @inited_runtime_handler
    def init_worker(self):
        """
        Initialize the `Communicator` for parameter server training.

        Through the decorators this is skipped with a warning on a
        non-distributed fleet, and raises ``ValueError`` when no runtime
        handler has been set up.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_worker()

        """
        runtime = self._runtime_handle
        runtime._init_worker()

498
    @is_non_distributed_check
    @inited_runtime_handler
    def init_server(self, *args, **kwargs):
        """
        Run the startup program for the parameter server.

        All positional and keyword arguments are forwarded unchanged to the
        runtime handler's ``_init_server``. Per the original documentation, a
        non-empty ``args`` makes it load persistables for incremental
        training.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_server()

        """
        runtime = self._runtime_handle
        runtime._init_server(*args, **kwargs)
523

524
    @is_non_distributed_check
    @inited_runtime_handler
    def run_server(self):
        """
        Run the pserver main program with an executor.

        Delegates to the runtime handler's ``_run_server``. It is skipped with
        a warning on a non-distributed fleet and raises ``ValueError`` when no
        runtime handler has been set up (see the decorators).

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                if fleet.is_server():
                    fleet.init_server()
                    fleet.run_server()

        """
        self._runtime_handle._run_server()

549
    @is_non_distributed_check
    @inited_runtime_handler
    def stop_worker(self):
        """
        Stop the `Communicator` and give a training-complete notice to the parameter server.

        Delegates to the runtime handler's ``_stop_worker``. It is skipped with
        a warning on a non-distributed fleet and raises ``ValueError`` when no
        runtime handler has been set up (see the decorators).

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_worker()
                # ... run training ...
                fleet.stop_worker()

        """
        self._runtime_handle._stop_worker()

573 574 575 576 577 578
    def save_inference_model(self,
                             executor,
                             dirname,
                             feeded_var_names,
                             target_vars,
                             main_program=None,
579 580
                             export_for_deployment=True,
                             mode=0):
581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600
        """
        save inference model for inference.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_server()

        """

601 602
        self._runtime_handle._save_inference_model(
            executor, dirname, feeded_var_names, target_vars, main_program,
603
            export_for_deployment, mode)
604

605
    def save_persistables(self, executor, dirname, main_program=None, mode=0):
606 607
        """

1
123malin 已提交
608
        saves all persistable tensors from :code:`main_program` to
609 610
        the folder :code:`dirname`. You can refer to

1
123malin 已提交
611 612
        The :code:`dirname` is used to specify the folder where persistable tensors
        are going to be saved. If you would like to save tensors in separate
613 614 615
        files, set :code:`filename` None.

        Args:
1
123malin 已提交
616
            executor(Executor): The executor to run for saving persistable tensors.
617 618 619 620 621
                                You can refer to :ref:`api_guide_executor_en` for
                                more details.

            dirname(str, optional): The saving directory path.
                                When you need to save the parameter to the memory, set it to None.
1
123malin 已提交
622
            main_program(Program, optional): The program whose persistbale tensors will
623 624 625 626 627 628 629 630 631 632
                                             be saved. Default: None.


        Returns:
            None

        Examples:

            .. code-block:: text

1
123malin 已提交
633 634
                import paddle
                paddle.enable_static()
635 636 637 638 639 640 641
                import paddle.distributed.fleet as fleet

                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

1
123malin 已提交
642 643
                exe = paddle.static.Executor(paddle.CPUPlace())
                fleet.save_persistables(exe, "dirname", paddle.static.default_main_program())
644 645 646

        """

647 648
        self._runtime_handle._save_persistables(executor, dirname, main_program,
                                                mode)
649

650 651 652
    def shrink(self, threshold):
        self._runtime_handle._shrink(threshold)

653
    def distributed_optimizer(self, optimizer, strategy=None):
        """
        Register an optimizer (and optionally a strategy) for distributed training.

        The optimizer is stored on the fleet instance. In static graph mode the
        fleet instance is returned, and it takes the place of the optimizer in
        user code (see the example).

        Args:
            optimizer(Optimizer): The optimizer to use for distributed training.
            strategy(DistributedStrategy): Extra properties for distributed optimizer.
                It is recommended to use DistributedStrategy in fleet.init(). The strategy
                here is for compatibility. If the strategy in fleet.distributed_optimizer()
                is not None, then it will overwrite the DistributedStrategy in fleet.init(),
                which will take effect in distributed training.

        Returns:
            Fleet: this fleet instance in static graph mode. In dygraph mode the
            ``optimizer`` argument itself is returned (temporary behavior, see
            the TODO in the body).

        Examples:

            .. code-block:: python

                import paddle
                import paddle.distributed.fleet as fleet
                fleet.init(is_collective=True)
                strategy = fleet.DistributedStrategy()
                optimizer = paddle.optimizer.SGD(learning_rate=0.001)
                optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)

        """
        self.user_defined_optimizer = optimizer

        if strategy is not None:
            if self._is_collective:
                warnings.warn(
                    "It is recommended to use DistributedStrategy "
                    "in fleet.init(). The strategy here is only for compatibility. "
                    "If the strategy in fleet.distributed_optimizer() is "
                    "not None, then it will overwrite the DistributedStrategy in fleet.init(), "
                    "which will take effect in distributed training.")
            # Private copy, as in ``init``: later edits to the caller's
            # object must not affect fleet.
            self._user_defined_strategy = copy.deepcopy(strategy)

        # Start from an empty context for the newly registered optimizer.
        self._context = {}

        # TODO(shenliang03): This is a temporary solution to support amp. In the case of a dynamic graph,
        # the optimizer is returned directly. This problem will be fixed in the future.
        if paddle.fluid.framework.in_dygraph_mode():
            return optimizer
        return self

703
    @dygraph_only
    def distributed_model(self, model):
        """
        Wrap ``model`` in ``paddle.DataParallel`` (dygraph mode only).

        The communication settings are taken from the user-defined strategy:
        ``fuse_grad_size_in_MB``, ``last_comm_group_size_MB`` and
        ``find_unused_parameters``. The wrapped model is also kept on
        ``self.model``.

        Args:
            model (Layer): the user-defined model which inherits Layer.

        Returns:
            distributed data parallel model which inherits Layer.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                # 1. initialize fleet environment
                fleet.init(is_collective=True)

                # 2. create layer & optimizer
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                # 3. get data_parallel model using fleet
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                # 4. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()

        """
        assert model is not None
        strategy = self._user_defined_strategy
        self.model = paddle.DataParallel(
            model,
            comm_buffer_size=strategy.fuse_grad_size_in_MB,
            last_comm_buffer_size=strategy.last_comm_group_size_MB,
            find_unused_parameters=strategy.find_unused_parameters)
        return self.model

    @dygraph_only
    def state_dict(self):
        """
        Return the state dict of the optimizer registered via ``distributed_optimizer``.
        (Only works in dygraph mode.)

        Returns:
            state_dict(dict) : dict containing all the Tensors used by the optimizer

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)
                state_dict = adam.state_dict()
        """
        # Delegate to the wrapped user optimizer.
        optimizer = self.user_defined_optimizer
        return optimizer.state_dict()

    @dygraph_only
    def set_state_dict(self, state_dict):
        """
        Load a state dict into the optimizer registered via ``distributed_optimizer``.
        (Only works in dygraph mode.)

        Args:
            state_dict(dict) : Dict containing all the Tensors needed by the optimizer

        Returns:
            Whatever the wrapped optimizer's ``set_state_dict`` returns
            (None for the example below).

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)
                state_dict = adam.state_dict()
                paddle.save(state_dict, "paddle_dy")
                para_state_dict = paddle.load("paddle_dy")
                adam.set_state_dict(para_state_dict)
        """
        # Delegate to the wrapped user optimizer.
        optimizer = self.user_defined_optimizer
        return optimizer.set_state_dict(state_dict)

    @dygraph_only
    def set_lr(self, value):
        """
        Manually set the learning rate of the optimizer registered via ``distributed_optimizer``.
        (Only works in dygraph mode.)

        Args:
            value (float|Tensor): the value of learning rate

        Returns:
            Whatever the wrapped optimizer's ``set_lr`` returns (documented
            as None).

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                lr_list = [0.2, 0.3, 0.4, 0.5, 0.6]
                for i in range(5):
                    adam.set_lr(lr_list[i])
                    lr = adam.get_lr()
                    print("current lr is {}".format(lr))
                # Print:
                #    current lr is 0.2
                #    current lr is 0.3
                #    current lr is 0.4
                #    current lr is 0.5
                #    current lr is 0.6
        """
        # Delegate to the wrapped user optimizer.
        optimizer = self.user_defined_optimizer
        return optimizer.set_lr(value)

    @dygraph_only
    def get_lr(self):
        """
        Return the current-step learning rate of the optimizer registered via
        ``distributed_optimizer``. (Only works in dygraph mode.)

        Returns:
            float: The learning rate of the current step.

        Examples:

            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                lr = adam.get_lr()
                print(lr) # 0.01
        """
        # Delegate to the wrapped user optimizer.
        optimizer = self.user_defined_optimizer
        return optimizer.get_lr()

    @dygraph_only
    def step(self):
        """
        Run one parameter-update step of the wrapped optimizer.
        (Only works in dygraph mode.)

        This simply forwards to ``self.user_defined_optimizer.step()``.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                # 1. initialize fleet environment
                fleet.init(is_collective=True)

                # 2. create layer & optimizer
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                # 3. get data_parallel model using fleet
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                # 4. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()
        """
        # Delegate to the optimizer object the user originally supplied.
        wrapped_optimizer = self.user_defined_optimizer
        return wrapped_optimizer.step()

    @dygraph_only
    def clear_grad(self):
        """
        Clear the gradients of all parameters optimized by the wrapped optimizer.
        (Only works in dygraph mode.)

        This simply forwards to ``self.user_defined_optimizer.clear_grad()``.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                # 1. initialize fleet environment
                fleet.init(is_collective=True)

                # 2. create layer & optimizer
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                # 3. get data_parallel model using fleet
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                # 4. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()
        """
        # Delegate to the optimizer object the user originally supplied.
        wrapped_optimizer = self.user_defined_optimizer
        return wrapped_optimizer.clear_grad()

    def amp_init(self,
                 place,
                 scope=None,
                 test_program=None,
                 use_fp16_test=False):
        """
        Init the amp training, such as cast fp32 parameters to fp16 type.

        The call is forwarded to the first applied meta optimizer that
        provides ``amp_init``. If none of them does, it is forwarded to the
        user-defined optimizer, provided that optimizer has ``amp_init``.

        Args:
            place(CUDAPlace): place is used to initialize
                fp16 parameters with fp32 values.
            scope(Scope, optional): The scope is used to find fp32 parameters.
                Default is None.
            test_program(Program, optional): The program is used for testing.
                Default is None.
            use_fp16_test(bool, optional): Whether to use fp16 testing.
                Default is False.

        Returns:
            Whatever the selected optimizer's ``amp_init`` returns.

        Raises:
            AssertionError: If neither an applied meta optimizer nor the
                user-defined optimizer provides ``amp_init`` (i.e. the amp
                strategy is not turned on).

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                import paddle.nn.functional as F
                paddle.enable_static()

                def run_example_code():
                    place = paddle.CUDAPlace(0)
                    exe = paddle.static.Executor(place)
                    data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
                    conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
                    # 1) Use fp16_guard to control the range of fp16 kernels used.
                    with paddle.static.amp.fp16_guard():
                        bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
                        pool = F.max_pool2d(bn, kernel_size=2, stride=2)
                        hidden = paddle.static.nn.fc(pool, size=10)
                        loss = paddle.mean(hidden)
                    # 2) Create the optimizer and set `multi_precision` to True.
                    # Setting `multi_precision` to True can avoid the poor accuracy
                    # or the slow convergence in a way.
                    optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
                    # 3) These ops in `custom_black_list` will keep in the float32 computation type.
                    amp_list = paddle.static.amp.CustomOpLists(
                        custom_black_list=['pool2d'])
                    # 4) The entry of Paddle AMP.
                    # Enable pure fp16 training by setting `use_pure_fp16` to True.
                    optimizer = paddle.static.amp.decorate(
                        optimizer,
                        amp_list,
                        init_loss_scaling=128.0,
                        use_dynamic_loss_scaling=True,
                        use_pure_fp16=True)
                    # If you don't use the default_startup_program(), you should pass
                    # your defined `startup_program` into `minimize`.
                    optimizer.minimize(loss)
                    exe.run(paddle.static.default_startup_program())
                    # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
                    # If you want to perform the testing process, you should pass `test_program` into `amp_init`.
                    optimizer.amp_init(place, scope=paddle.static.global_scope())

                if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                    run_example_code()
        """
        # Prefer the first applied meta optimizer that knows how to run amp init.
        amp_optimizer = next(
            (opt for opt in self.strategy_compiler._get_applied_meta_optimizer()
             if hasattr(opt, 'amp_init')), None)

        # Otherwise fall back to the user's own optimizer, e.g. one that was
        # already wrapped by ``paddle.static.amp.decorate`` as in the example.
        if amp_optimizer is None and hasattr(self.user_defined_optimizer,
                                             'amp_init'):
            amp_optimizer = self.user_defined_optimizer

        if amp_optimizer is None:
            # Raised explicitly rather than via ``assert`` so the check still
            # fires under ``python -O``; AssertionError is kept so existing
            # callers that catch it keep working.
            raise AssertionError(
                "amp_init can only be used when the amp(auto mixed precision) strategy is turned on."
            )

        return amp_optimizer.amp_init(place, scope, test_program, use_fp16_test)
    def _final_strategy(self):
        if "valid_strategy" not in self._context:
            print(
                "WARNING: You may need to call minimize function before this function is called"
            )
            return {}
        else:
            return self._context["valid_strategy"]

    def _get_applied_meta_list(self):
        if "applied_meta_list" not in self._context:
            print(
                "WARNING: You may need to call minimize function before _get_applied_meta_list called"
            )
            return []
        else:
            return self._context["applied_meta_list"]

    def _get_applied_graph_list(self):
        if "applied_graph_list" not in self._context:
            print(
                "WARNING: You may need to call minimize function before _get_applied_graph_list called"
            )
            return []
        else:
            return self._context["applied_graph_list"]

    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
        """
        Add distributed operations to minimize ``loss`` by updating ``parameter_list``.

        Args:
            loss (Tensor): A ``Tensor`` containing the value to minimize.
            startup_program (Program, optional): :ref:`api_fluid_Program` for
                initializing parameters in ``parameter_list``. The default value
                is None, at this time :ref:`api_fluid_default_startup_program` will be used.
            parameter_list (Iterable, optional): Iterable of ``Tensor`` or ``Tensor.name`` to update
                to minimize ``loss``. The default value is None, at this time all parameters
                will be updated.
            no_grad_set (set, optional): Set of ``Tensor`` or ``Tensor.name`` that don't need
                to be updated. The default value is None.

        Returns:
            tuple: tuple (optimize_ops, params_grads), A list of operators appended
            by minimize and a list of (param, grad) tensor pairs, param is
            ``Parameter``, grad is the gradient value corresponding to the parameter.
            The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
            indicate program pruning. If so, the program will be pruned by ``feed`` and
            ``fetch_list`` before run, see details in ``Executor``.

            In dygraph mode, and when the role maker reports a non-distributed
            run outside collective mode, the return value of the user-defined
            optimizer's ``minimize`` is passed through unchanged.

        Examples:

            .. code-block:: python

                import paddle
                paddle.enable_static()
                import paddle.distributed.fleet as fleet
                import paddle.nn.functional as F

                hid_dim = 10
                label_dim = 2
                input_x = paddle.static.data(name='x', shape=[None, 13], dtype='float32')
                input_y = paddle.static.data(name='y', shape=[None, 1], dtype='int64')
                fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim, activation='tanh')
                fc_2 = paddle.static.nn.fc(x=fc_1, size=hid_dim, activation='tanh')
                prediction = paddle.static.nn.fc(x=[fc_2], size=label_dim, activation='softmax')
                cost = F.cross_entropy(input=prediction, label=input_y)
                avg_cost = paddle.mean(x=cost)

                fleet.init(is_collective=True)
                strategy = fleet.DistributedStrategy()
                optimizer = paddle.optimizer.SGD(learning_rate=0.001)
                optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
                optimizer.minimize(avg_cost)

                # for more examples, please reference https://github.com/PaddlePaddle/FleetX

        """
        context = {}
        context["user_defined_strategy"] = copy.deepcopy(
            self._user_defined_strategy)

        if paddle.fluid.framework.in_dygraph_mode():
            # Dygraph: no program rewriting, delegate to the user's optimizer.
            # NOTE(review): startup_program / parameter_list / no_grad_set are
            # not forwarded on this path -- confirm this is intended.
            self._context = context
            return self.user_defined_optimizer.minimize(loss)

        # Cache the original feed-forward program.
        self.origin_main_program = loss.block.program
        context["origin_main_program"] = self.origin_main_program
        context["loss"] = loss

        # Keep an untouched clone of the startup program; the live one is what
        # the optimizers below receive.
        if startup_program is None:
            self.origin_startup_program = \
                paddle.static.default_startup_program().clone(for_test=False)
            startup_program = paddle.static.default_startup_program()
        else:
            self.origin_startup_program = \
                startup_program.clone(for_test=False)

        context["origin_startup_program"] = startup_program
        context["role_maker"] = self._role_maker

        # compile time
        distributed_optimizer_list = \
            MetaOptimizerFactory()._get_valid_meta_optimizers(
                self.user_defined_optimizer)

        # context["user_defined_strategy"] already holds a deep copy made at
        # the top of this method; this second, independent copy is the one the
        # meta optimizers are allowed to modify.
        copy_user_defined_strategy = copy.deepcopy(self._user_defined_strategy)

        # trigger the auto-parallel in very strict condition
        # strategy = DistributedStrategy()
        # strategy.auto = True
        # optimizer = paddle.optimizer.SGD(learning_rate=0.1)
        # optimizer = fleet.distributed_optimizer(optimizer, strategy)
        if copy_user_defined_strategy._is_strict_auto():
            # turn on all the strategy for each optimizer
            for opt in distributed_optimizer_list:
                opt._enable_strategy(copy_user_defined_strategy, context)

        valid_optimizer_list = []
        valid_graph_optimizer_list = []
        can_not_apply_optimizer_list = []
        # recall meta optimizers for ranking
        for opt in distributed_optimizer_list:
            opt._set_basic_info(loss, self._role_maker,
                                self.user_defined_optimizer,
                                copy_user_defined_strategy)
            # Query _can_apply() once per optimizer (assumed side-effect free),
            # then split applicable ones into program vs. graph optimizers.
            if not opt._can_apply():
                can_not_apply_optimizer_list.append(opt)
            elif opt._is_graph_out():
                valid_graph_optimizer_list.append(opt)
            else:
                valid_optimizer_list.append(opt)

        # combine recalled meta optimizers to be a valid meta optimizer
        meta_optimizer, graph_optimizer = \
            self.strategy_compiler.generate_optimizer(
                loss, self._role_maker, self.user_defined_optimizer,
                copy_user_defined_strategy, valid_optimizer_list,
                valid_graph_optimizer_list)

        valid_strategy = self.strategy_compiler._get_valid_strategy(
            copy_user_defined_strategy, can_not_apply_optimizer_list)

        context["valid_strategy"] = copy.deepcopy(valid_strategy)

        context['applied_meta_list'] = \
            self.strategy_compiler._get_applied_meta_list()
        context['applied_graph_list'] = \
            self.strategy_compiler._get_applied_graph_list()

        # self._context aliases this dict, so the entries added further down
        # (program_/graph_ optimize results) are visible through it as well.
        self._context = context

        self.valid_strategy = valid_strategy
        self.valid_strategy._enable_env()

        if self._role_maker._is_non_distributed() and not self._is_collective:
            if self._runtime_handle is None:
                self._runtime_handle = RuntimeFactory()._create_runtime(context)

            compiled_program = compiler.CompiledProgram(
                self.origin_main_program).with_data_parallel(
                    loss_name=loss.name, share_vars_from=None)
            loss.block.program._graph = compiled_program
            return self.user_defined_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)

        if meta_optimizer:
            optimize_ops, params_grads = meta_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)

            # Make the loss's program the default main program if the meta
            # optimizer left a different one active.
            if paddle.static.default_main_program() is not loss.block.program:
                paddle.fluid.framework.switch_main_program(loss.block.program)
        else:
            optimize_ops, params_grads = self.user_defined_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)

        context["program_optimize_ops"] = optimize_ops
        context["program_params_grads"] = params_grads

        if graph_optimizer:
            optimize_ops, params_grads = graph_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)
            # since we do not encourage users to use graph operations
            # if a graph optimizer takes effect, mostly
            # optimizers_ops and params_grads are None
            # i.e. users can not modify current computation graph anymore
            context["graph_optimize_ops"] = optimize_ops
            context["graph_optimize_grads"] = params_grads

        if self._runtime_handle is None:
            self._runtime_handle = RuntimeFactory()._create_runtime(context)

        # Function-scope import, kept as in the original (presumably to avoid
        # a circular import with the fleet package -- confirm).
        import paddle.distributed.fleet as fleet
        fleet.util._set_strategy(context["valid_strategy"])

        return optimize_ops, params_grads