# fleet_base.py
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import copy
import warnings
import paddle
import os
from paddle.fluid.framework import dygraph_only
from paddle.fluid import compiler
from .role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker, RoleMakerBase
from .strategy_compiler import StrategyCompiler
from .distributed_strategy import DistributedStrategy
from .meta_optimizer_factory import MetaOptimizerFactory
from .runtime_factory import RuntimeFactory
from paddle.fluid.wrapped_decorator import wrap_decorator
from paddle.fluid.dygraph import parallel_helper
from . import topology as tp
from .topology import ParallelMode
from ..meta_parallel import ModelParallel
from ..meta_parallel import PipelineParallel
from ..meta_optimizers import HybridParallelOptimizer
from ..meta_optimizers import HybridParallelGradScaler

__all__ = []

def _inited_runtime_handler_(func):
    def __impl__(*args, **kwargs):
        cls = args[0]

        if cls._runtime_handle is None:
            raise ValueError("Fleet can not find suitable runtime handler")

        return func(*args, **kwargs)

    return __impl__


def _is_non_distributed_check_(func):
    def __impl__(*args, **kwargs):
        cls = args[0]

        if cls._role_maker is not None and cls._role_maker._is_non_distributed(
        ) is True:
            warnings.warn(
                "%s() function doesn't work when using non-distributed fleet." %
                (func.__name__))
            return

        return func(*args, **kwargs)

    return __impl__


# expose the checks above as decorators; they are applied below to the Fleet
# methods that require an initialized runtime handle
inited_runtime_handler = wrap_decorator(_inited_runtime_handler_)
is_non_distributed_check = wrap_decorator(_is_non_distributed_check_)


class Fleet(object):
    """
    Unified API for distributed training of PaddlePaddle
74
    Please reference the https://github.com/PaddlePaddle/FleetX for details
75 76 77 78 79


    Returns:
        Fleet: A Fleet instance

80
    Example for collective training:
1
123malin 已提交
81

82 83
        .. code-block:: python

1
123malin 已提交
84 85
            import paddle
            paddle.enable_static()
86
            import paddle.distributed.fleet as fleet
87 88 89

            fleet.init(is_collective=True)

90 91 92
            strategy = fleet.DistributedStrategy()
            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
            optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
93 94 95 96 97 98 99 100

            # do distributed training


    Example for parameter server training:

        .. code-block:: python

1
123malin 已提交
101 102
            import paddle
            paddle.enable_static()
103 104
            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
S
ShenLiang 已提交
105
            fleet.init(strategy=strategy)
106

107
            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
108
            optimizer = fleet.distributed_optimizer(optimizer)
109

110 111
            if fleet.is_first_worker():
                print("this is first worker")
112

113 114
            print("current node index: {}".format(fleet.worker_index()))
            print("total number of worker num: {}".format(fleet.worker_num()))
115

116 117 118
            if fleet.is_worker():
                print("this is worker")
            print("worker endpoints: {}".format(fleet.worker_endpoints(to_string=True)))
119

120 121
            print("server num: {}".format(fleet.server_num()))
            print("server endpoints: {}".format(fleet.server_endpoints(to_string=True)))
122

123 124 125
            if fleet.is_server():
                print("this is server")
            fleet.stop_worker()
126 127


128 129 130
    """

    def __init__(self):
        self._role_maker = None
        self.strategy_compiler = None
        self._is_collective = False
        self._runtime_handle = None
        self._util = None
        self._context = {}

    def init(self, role_maker=None, is_collective=False, strategy=None):
        """
        Initialize role_maker in Fleet.

        This function is responsible for setting up the distributed
        architecture that you want your code to run behind.

        Args:
            role_maker (RoleMakerBase, optional): A ``RoleMakerBase`` containing the configuration
                of environment variables related to distributed training. If you did not initialize
                the role maker by yourself, it will be automatically initialized to PaddleCloudRoleMaker.
                The default value is None.
            is_collective (Boolean, optional): A ``Boolean`` variable that determines whether the
                program runs collective training on GPU or parameter server training on CPU.
                False means parameter server training using CPU, and True means collective
                training using GPU. The default value is False.
            strategy (DistributedStrategy): Extra properties for distributed training.
                For details, please refer to paddle.distributed.fleet.DistributedStrategy. Default: None.


        Returns:
            None

        Examples1:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

        Examples2:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init(is_collective=True)

        Examples3:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                role = fleet.PaddleCloudRoleMaker()
                fleet.init(role)

        Examples4:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                strategy = fleet.DistributedStrategy()
                fleet.init(strategy=strategy)

        """
        if strategy is None:
            strategy = DistributedStrategy()
        self._user_defined_strategy = copy.deepcopy(strategy)

        if role_maker is None:
            if isinstance(is_collective, bool):
                self._is_collective = is_collective
                self._role_maker = PaddleCloudRoleMaker(
                    is_collective=self._is_collective)
            else:
                raise ValueError(
                    "`is_collective` should be instance of `bool`, but got {}".
                    format(type(is_collective)))
        else:
            if isinstance(role_maker, RoleMakerBase):
                self._role_maker = role_maker
                self._is_collective = role_maker._is_collective
            else:
                raise ValueError(
                    "`role_maker` should be subclass of `RoleMakerBase`, but got {}".
                    format(type(role_maker)))
        self._role_maker._generate_role()

        import paddle.distributed.fleet as fleet
        fleet.util._set_role_maker(self._role_maker)

        self.strategy_compiler = StrategyCompiler()

        if self._role_maker._is_non_distributed() and self._is_collective:
            if paddle.fluid.core.is_compiled_with_cuda():
                gpus_num = paddle.fluid.core.get_cuda_device_count()
                if gpus_num != 1:
                    raise ValueError(
                        "CUDA_VISIBLE_DEVICES should be set to only 1 card if you use `python` to launch fleet program."
                    )

        if paddle.fluid.framework.in_dygraph_mode():
            if self.worker_num() == 1:
                # if worker_num is 1, should construct default topology & hcg
                self._topology = tp.CommunicateTopology()
                self._hcg = tp.HybridCommunicateGroup(self._topology)
                return
            if parallel_helper._is_parallel_ctx_initialized():
                warnings.warn(
                    "The dygraph parallel environment has been initialized.")
            else:
                # FLAGS_nccl_nrings is used for dynamic graph multi-stream communication
                if "FLAGS_nccl_nrings" in os.environ:
                    warnings.warn(
                        "You have set the environment variable FLAGS_nccl_nrings "
                        "outside the program, so the nccl_comm_num in "
                        "DistributedStrategy will not take effect here.")
                else:
                    os.environ["FLAGS_nccl_nrings"] = str(
                        self._user_defined_strategy.nccl_comm_num)
                paddle.distributed.init_parallel_env()

            # init hybrid parallel environment in dygraph
            if tp._HYBRID_PARALLEL_GROUP is None:
                self._init_hybrid_parallel_env()
            else:
                warnings.warn(
                    "The dygraph hybrid parallel environment has been initialized."
                )

    def _init_hybrid_parallel_env(self):
        """initialize the hybrid environment
        """
        self.hybrid_configs = self._user_defined_strategy.hybrid_configs
        self.dp_degree = self.hybrid_configs["dp_degree"]
        self.mp_degree = self.hybrid_configs["mp_degree"]
        self.pp_degree = self.hybrid_configs["pp_degree"]

        assert self.mp_degree >= 0, "mp_degree should be greater than or equal to 0"
        assert self.pp_degree >= 0, "pp_degree should be greater than or equal to 0"

        self.mp_degree = max(self.mp_degree, 1)
        self.pp_degree = max(self.pp_degree, 1)

        if self.dp_degree < 0:
            nranks = paddle.distributed.get_world_size()
            self.dp_degree = nranks // (self.mp_degree * self.pp_degree)

        self.dp_degree = max(self.dp_degree, 1)
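        # e.g. with 8 ranks, mp_degree=2 and pp_degree=2, the data parallel
        # degree is inferred as 8 // (2 * 2) = 2, giving
        # dims = [dp, pp, mp] = [2, 2, 2]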

        self._topology = tp.CommunicateTopology(
            hybrid_group_names=["data", "pipe", "model"],
            dims=[self.dp_degree, self.pp_degree, self.mp_degree])

        self._hcg = tp.HybridCommunicateGroup(self._topology)

    def get_hybrid_communicate_group(self):
        """Return the HybridCommunicateGroup created by fleet.init() in
        dygraph hybrid parallel training."""
        assert self._hcg is not None
        return self._hcg

    def get_hybrid_parallel_topology(self):
        """Return the CommunicateTopology created by fleet.init() in
        dygraph hybrid parallel training."""
        assert self._topology is not None
        return self._topology

    def is_first_worker(self):
        """
        Check whether the node is the first instance of worker.

        Returns:
            bool: True if this is the first node of worker,
                  False if not.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_first_worker()

        """
        return self._role_maker._is_first_worker()

    def worker_index(self):
        """
        Get current worker index.

        Returns:
            int: node id

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_index()

        """
        return self._role_maker._worker_index()

    def worker_num(self):
        """
        Get current total worker number.

        Returns:
            int: worker numbers

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_num()

        """
        return self._role_maker._worker_num()
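    # thin wrappers over the role maker's node / device topology queries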
    def node_num(self):
        return self._role_maker._get_node_num()

    def local_rank(self):
        return self._role_maker._get_local_rank()

    def local_device_ids(self):
        return self._role_maker._get_local_device_ids()

    def world_device_ids(self):
        return self._role_maker._get_world_device_ids()

    def is_worker(self):
        """
        Check whether the node is an instance of worker.

        Returns:
            bool: True if this is a node of worker,
                  False if not.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_worker()

        """
        return self._role_maker._is_worker()

    def worker_endpoints(self, to_string=False):
        """
        Get current worker endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].

        Args:
            to_string (bool, optional): Whether to return the endpoints as a
                comma-separated string. Default: False.

        Returns:
            list/string: worker endpoints

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_endpoints()

        """
        if to_string:
            return ",".join(self._role_maker._get_trainer_endpoints())
        else:
            return self._role_maker._get_trainer_endpoints()

    def server_num(self):
        """
        Get current total server number.

        Returns:
            int: server number

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_num()

        """
        return len(self._role_maker._get_pserver_endpoints())

    def server_index(self):
        """
        Get current server index.

        Returns:
            int: node id

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_index()

        """
        return self._role_maker._server_index()

    def server_endpoints(self, to_string=False):
        """
        Get current server endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].

        Args:
            to_string (bool, optional): Whether to return the endpoints as a
                comma-separated string. Default: False.

        Returns:
            list/string: server endpoints

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_endpoints()

        """
        if to_string:
            return ",".join(self._role_maker._get_pserver_endpoints())
        else:
            return self._role_maker._get_pserver_endpoints()

    def is_server(self):
        """
        Check whether the node is an instance of server.

        Returns:
            bool: True if this is a node of server,
                  False if not.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_server()

        """
        return self._role_maker._is_server(
        ) or self._role_maker._is_heter_worker()

    def barrier_worker(self):
        """
        barrier all workers.

        Returns:
            None
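
        Examples:

            .. code-block:: python

                # a minimal usage sketch, mirroring the other examples in this file
                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.barrier_worker()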
        """
        self._role_maker._barrier("worker")

    @is_non_distributed_check
    @inited_runtime_handler
    def init_worker(self):
        """
        initialize `Communicator` for parameter server training.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_worker()

        """
        self._runtime_handle._init_worker()


    @is_non_distributed_check
    @inited_runtime_handler
    def init_server(self, *args, **kwargs):
        """
        init_server runs the executor to initialize the startup program; if
        `args` is not empty, it will also run load_persistables for
        incremental training.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_server()

        """
        self._runtime_handle._init_server(*args, **kwargs)

    @is_non_distributed_check
    @inited_runtime_handler
    def run_server(self):
        """
        run_server runs the pserver main program with the executor.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                if fleet.is_server():
                    fleet.init_server()
                    fleet.run_server()

        """
        self._runtime_handle._run_server()

    @is_non_distributed_check
    @inited_runtime_handler
    def stop_worker(self):
        """
        stop `Communicator` and give training complete notice to parameter server.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_worker()
                # ... do distributed training ...
                fleet.stop_worker()

        """
        self._runtime_handle._stop_worker()

    def save_inference_model(self,
                             executor,
                             dirname,
                             feeded_var_names,
                             target_vars,
                             main_program=None,
                             export_for_deployment=True,
                             mode=0):
        """
        save inference model for inference.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                paddle.enable_static()
                import paddle.distributed.fleet as fleet

                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                # (sketch) `feed_names` and `fetch_vars` come from the net built above
                exe = paddle.static.Executor(paddle.CPUPlace())
                fleet.save_inference_model(exe, "dirname", feed_names, fetch_vars)

        """

        self._runtime_handle._save_inference_model(
            executor, dirname, feeded_var_names, target_vars, main_program,
            export_for_deployment, mode)

    def save_persistables(self, executor, dirname, main_program=None, mode=0):
        """
        saves all persistable tensors from :code:`main_program` to
        the folder :code:`dirname`.

        The :code:`dirname` is used to specify the folder where persistable tensors
        are going to be saved. If you would like to save tensors in separate
        files, set :code:`filename` None.

        Args:
            executor(Executor): The executor to run for saving persistable tensors.
                                You can refer to :ref:`api_guide_executor_en` for
                                more details.

            dirname(str, optional): The saving directory path.
                                When you need to save the parameter to the memory, set it to None.

            main_program(Program, optional): The program whose persistable tensors will
                                             be saved. Default: None.


        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                paddle.enable_static()
                import paddle.distributed.fleet as fleet

                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                exe = paddle.static.Executor(paddle.CPUPlace())
                fleet.save_persistables(exe, "dirname", paddle.static.default_main_program())

        """

        self._runtime_handle._save_persistables(executor, dirname, main_program,
                                                mode)

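    # forwarded to the runtime handle; in parameter server training this is
    # used to shrink stale sparse parameters (the exact semantics depend on
    # the runtime implementation)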
    def shrink(self, threshold):
        self._runtime_handle._shrink(threshold)

    def distributed_optimizer(self, optimizer, strategy=None):
        """
        Optimizer for distributed training.

        For distributed training, this method rebuilds a new instance of
        DistributedOptimizer, which has basic Optimizer functionality and
        special features for distributed training.

        Args:
            optimizer(Optimizer): The optimizer defined by the user for training.
            strategy(DistributedStrategy): Extra properties for distributed optimizer.
                It is recommended to use DistributedStrategy in fleet.init(). The strategy
                here is for compatibility. If the strategy in fleet.distributed_optimizer()
                is not None, then it will overwrite the DistributedStrategy in fleet.init(),
                which will take effect in distributed training.

        Returns:
            Fleet: instance of fleet.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.distributed.fleet as fleet
                fleet.init(is_collective=True)
                strategy = fleet.DistributedStrategy()
                optimizer = paddle.optimizer.SGD(learning_rate=0.001)
                optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)

        """
        self.user_defined_optimizer = optimizer

        if strategy is not None:
            if self._is_collective:
                warnings.warn(
                    "It is recommended to use DistributedStrategy "
                    "in fleet.init(). The strategy here is only for compatibility. "
                    "If the strategy in fleet.distributed_optimizer() is "
                    "not None, then it will overwrite the DistributedStrategy in fleet.init(), "
                    "which will take effect in distributed training.")
            self._user_defined_strategy = copy.deepcopy(strategy)

        self._context = {}

        if paddle.fluid.framework.in_dygraph_mode():
            if self.worker_num() > 1:
                return HybridParallelOptimizer(optimizer, self._hcg,
                                               self._user_defined_strategy)
            else:
                return optimizer
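        # in static graph mode Fleet itself is returned as the optimizer;
        # the concrete meta optimizer is selected later, inside minimize()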
        return self

    @dygraph_only
    def distributed_model(self, model):
        """
        Return distributed data parallel model (Only work in dygraph mode)

        Args:
            model (Layer): the user-defined model which inherits Layer.

        Returns:
            distributed data parallel model which inherits Layer.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                # 1. initialize fleet environment
                fleet.init(is_collective=True)

                # 2. create layer & optimizer
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                # 3. get data_parallel model using fleet
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                # 4. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()


        """
        assert model is not None, "model should not be None"
        if self.worker_num() <= 1:
            return model
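        # dispatch on the parallel mode of the hybrid communicate group:
        # data parallelism wraps the model with paddle.DataParallel, while
        # model and pipeline parallelism use the meta_parallel wrappers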
        if self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL:
            distributed_model = paddle.DataParallel(
                model,
                comm_buffer_size=self._user_defined_strategy.
                fuse_grad_size_in_MB,
                last_comm_buffer_size=self._user_defined_strategy.
                last_comm_group_size_MB,
                find_unused_parameters=self._user_defined_strategy.
                find_unused_parameters)
        elif self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL:
            distributed_model = ModelParallel(
                model, self._hcg, strategy=self._user_defined_strategy)
        elif self._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
            distributed_model = PipelineParallel(
                model, self._hcg, strategy=self._user_defined_strategy)
        return distributed_model

    @dygraph_only
    def state_dict(self):
        """
        Get state dict information from optimizer.
        (Only work in dygraph mode)

        Returns:
            state_dict(dict) : dict contains all the Tensor used by optimizer

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)
                state_dict = adam.state_dict()
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.state_dict()

    @dygraph_only
    def set_state_dict(self, state_dict):
        """
        Load optimizer state dict.
        (Only work in dygraph mode)

        Args:
            state_dict(dict) : Dict contains all the Tensor needed by optimizer

        Returns:
            None

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)
                state_dict = adam.state_dict()
                paddle.save(state_dict, "paddle_dy")
                para_state_dict = paddle.load("paddle_dy")
                adam.set_state_dict(para_state_dict)
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.set_state_dict(state_dict)

    @dygraph_only
    def set_lr(self, value):
        """
        Set the value of the learning rate manually in the optimizer.
        (Only work in dygraph mode)

        Args:
            value (float|Tensor): the value of learning rate

        Returns:
            None

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                lr_list = [0.2, 0.3, 0.4, 0.5, 0.6]
                for i in range(5):
                    adam.set_lr(lr_list[i])
                    lr = adam.get_lr()
                    print("current lr is {}".format(lr))
                # Print:
                #    current lr is 0.2
                #    current lr is 0.3
                #    current lr is 0.4
                #    current lr is 0.5
                #    current lr is 0.6
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.set_lr(value)

    @dygraph_only
    def get_lr(self):
        """
        Get current step learning rate.
        (Only work in dygraph mode)

        Returns:
            float: The learning rate of the current step.

        Examples:

            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                lr = adam.get_lr()
                print(lr) # 0.01
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.get_lr()

    @dygraph_only
    def step(self):
        """
        Execute the optimizer once.
        (Only work in dygraph mode)

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                # 1. initialize fleet environment
                fleet.init(is_collective=True)

                # 2. create layer & optimizer
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                # 3. get data_parallel model using fleet
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                # 4. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()


        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.step()

    @dygraph_only
    def clear_grad(self):
        """
        Clear the gradients of all optimized parameters for model.
        (Only work in dygraph mode)

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                # 1. initialize fleet environment
                fleet.init(is_collective=True)

                # 2. create layer & optimizer
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                # 3. get data_parallel model using fleet
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                # 4. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()

        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.clear_grad()

    def _get_amp_optimizer(self):
        # imitate target optimizer retrieval
        amp_optimizer = None
        for optimizer in self.strategy_compiler._get_applied_meta_optimizer():
            if hasattr(optimizer, 'amp_init'):
                amp_optimizer = optimizer
                break

        if amp_optimizer is None:
            if hasattr(self.user_defined_optimizer, 'amp_init'):
                amp_optimizer = self.user_defined_optimizer

        assert amp_optimizer is not None, \
            "amp_init can only be used when the amp(auto mixed precision) strategy is turned on."
        return amp_optimizer

    def get_loss_scaling(self):
        """Return the real-time loss scaling factor.
        """
        amp_optimizer = self._get_amp_optimizer()
        return amp_optimizer.get_loss_scaling()

    def amp_init(self,
                 place,
                 scope=None,
                 test_program=None,
                 use_fp16_test=False):
        """
        Init the amp training, such as cast fp32 parameters to fp16 type.
  
        Args:
            place(CUDAPlace): place is used to initialize 
                fp16 parameters with fp32 values.
            scope(Scope): The scope is used to find fp32 parameters.
            test_program(Program): The program is used for testing.
            use_fp16_test(bool): Whether to use fp16 testing.
            
        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                import paddle.nn.functional as F
                paddle.enable_static()

                def run_example_code():
                    place = paddle.CUDAPlace(0)
                    exe = paddle.static.Executor(place)
                    data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
                    conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
                    # 1) Use fp16_guard to control the range of fp16 kernels used.
                    with paddle.static.amp.fp16_guard():
                        bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
                        pool = F.max_pool2d(bn, kernel_size=2, stride=2)
                        hidden = paddle.static.nn.fc(pool, size=10)
                        loss = paddle.mean(hidden)
                    # 2) Create the optimizer and set `multi_precision` to True.
                    # Setting `multi_precision` to True can help avoid poor
                    # accuracy or slow convergence.
                    optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
                    # 3) These ops in `custom_black_list` will keep in the float32 computation type.
                    amp_list = paddle.static.amp.CustomOpLists(
                        custom_black_list=['pool2d'])
                    # 4) The entry of Paddle AMP.
                    # Enable pure fp16 training by setting `use_pure_fp16` to True.
                    optimizer = paddle.static.amp.decorate(
                        optimizer,
                        amp_list,
                        init_loss_scaling=128.0,
                        use_dynamic_loss_scaling=True,
                        use_pure_fp16=True)
                    # If you don't use the default_startup_program(), you should pass
                    # your defined `startup_program` into `minimize`.
                    optimizer.minimize(loss)
                    exe.run(paddle.static.default_startup_program())
                    # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
                    # If you want to perform the testing process, you should pass `test_program` into `amp_init`.
                    optimizer.amp_init(place, scope=paddle.static.global_scope())
                    
                if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                    run_example_code()       
        """
        amp_optimizer = self._get_amp_optimizer()
        return amp_optimizer.amp_init(place, scope, test_program, use_fp16_test)

    def _final_strategy(self):
        if "valid_strategy" not in self._context:
            print(
                "WARNING: You may need to call minimize function before this function is called"
            )
            return {}
        else:
            return self._context["valid_strategy"]

    def _get_applied_meta_list(self):
        if "applied_meta_list" not in self._context:
            print(
                "WARNING: You may need to call minimize function before _get_applied_meta_list called"
            )
            return []
        else:
            return self._context["applied_meta_list"]

    def _get_applied_graph_list(self):
        if "applied_graph_list" not in self._context:
            print(
                "WARNING: You may need to call minimize function before _get_applied_graph_list called"
            )
            return []
        else:
            return self._context["applied_graph_list"]

    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
        """
        Add distributed operations to minimize ``loss`` by updating ``parameter_list``.

        Args:
            loss (Tensor): A ``Tensor`` containing the value to minimize.
            startup_program (Program, optional): :ref:`api_fluid_Program` for
                initializing parameters in ``parameter_list``. The default value
                is None, at this time :ref:`api_fluid_default_startup_program` will be used.
            parameter_list (Iterable, optional): Iterable of ``Tensor`` or ``Tensor.name`` to update
                to minimize ``loss``. The default value is None, at this time all parameters
                will be updated.
            no_grad_set (set, optional): Set of ``Tensor``  or ``Tensor.name`` that don't need
                to be updated. The default value is None.

        Returns:
            tuple: tuple (optimize_ops, params_grads), A list of operators appended
            by minimize and a list of (param, grad) tensor pairs, param is
            ``Parameter``, grad is the gradient value corresponding to the parameter.
            The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
            indicate program pruning. If so, the program will be pruned by ``feed`` and
            ``fetch_list`` before run, see details in ``Executor``.

        Examples:

            .. code-block:: python

                import paddle
                paddle.enable_static()
                import paddle.distributed.fleet as fleet
                import paddle.nn.functional as F

                hid_dim = 10
                label_dim = 2
                input_x = paddle.static.data(name='x', shape=[None, 13], dtype='float32')
                input_y = paddle.static.data(name='y', shape=[None, 1], dtype='int64')
                fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim, activation='tanh')
                fc_2 = paddle.static.nn.fc(x=fc_1, size=hid_dim, activation='tanh')
                prediction = paddle.static.nn.fc(x=[fc_2], size=label_dim, activation='softmax')
                cost = F.cross_entropy(input=prediction, label=input_y)
                avg_cost = paddle.mean(x=cost)

                fleet.init(is_collective=True)
                strategy = fleet.DistributedStrategy()
                optimizer = paddle.optimizer.SGD(learning_rate=0.001)
                optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
                optimizer.minimize(avg_cost)

                # for more examples, please reference https://github.com/PaddlePaddle/FleetX

        """
        context = {}
        context["user_defined_strategy"] = copy.deepcopy(
            self._user_defined_strategy)
        if paddle.fluid.framework.in_dygraph_mode():
            # imitate target optimizer retrieval
            target_opt = self.user_defined_optimizer
            self._context = context
            return target_opt.minimize(loss)

        # cache original feed forward program
        self.origin_main_program = loss.block.program
        context["origin_main_program"] = self.origin_main_program
        context["loss"] = loss
        if startup_program is None:
            self.origin_startup_program = \
                paddle.static.default_startup_program().clone(for_test=False)
            startup_program = paddle.static.default_startup_program()
        else:
            self.origin_startup_program = \
                startup_program.clone(for_test=False)

        context["origin_startup_program"] = startup_program
        context["role_maker"] = self._role_maker

        # compile time
        distributed_optimizer_list = \
            MetaOptimizerFactory()._get_valid_meta_optimizers(
                self.user_defined_optimizer)

        context["user_defined_strategy"] = copy.deepcopy(
            self._user_defined_strategy)
        copy_user_defined_strategy = copy.deepcopy(self._user_defined_strategy)

        # trigger the auto-parallel in very strict condition
        # strategy = DistributedStrategy()
        # strategy.auto = True
        # optimizer = paddle.optimizer.SGD(learning_rate=0.1)
        # optimizer = fleet.distributed_optimizer(optimizer, strategy)
        if copy_user_defined_strategy._is_strict_auto():
            # turn on all the strategy for each optimizer
            for opt in distributed_optimizer_list:
                opt._enable_strategy(copy_user_defined_strategy, context)

        valid_optimizer_list = []
        valid_graph_optimizer_list = []
        can_not_apply_optimizer_list = []
        # recall meta optimizers for ranking
        for opt in distributed_optimizer_list:
            opt._set_basic_info(loss, self._role_maker,
                                self.user_defined_optimizer,
                                copy_user_defined_strategy)
            if opt._can_apply() and not opt._is_graph_out():
                valid_optimizer_list.append(opt)
            elif opt._can_apply() and opt._is_graph_out():
                valid_graph_optimizer_list.append(opt)
            else:
                can_not_apply_optimizer_list.append(opt)
        # combine recalled meta optimizers to be a valid meta optimizer
        meta_optimizer, graph_optimizer = \
            self.strategy_compiler.generate_optimizer(
                loss, self._role_maker, self.user_defined_optimizer,
                copy_user_defined_strategy, valid_optimizer_list,
                valid_graph_optimizer_list)

        valid_strategy = self.strategy_compiler._get_valid_strategy(
            copy_user_defined_strategy, can_not_apply_optimizer_list)

        context["valid_strategy"] = copy.deepcopy(valid_strategy)

        applied_meta_list = self.strategy_compiler._get_applied_meta_list()
        applied_graph_list = self.strategy_compiler._get_applied_graph_list()

        context['applied_meta_list'] = applied_meta_list
        context['applied_graph_list'] = applied_graph_list

        self._context = context

        self.valid_strategy = valid_strategy
        self.valid_strategy._enable_env()

        optimize_ops = []
        params_grads = []

        if self._role_maker._is_non_distributed() and not self._is_collective:
            if self._runtime_handle is None:
                self._runtime_handle = RuntimeFactory()._create_runtime(context)

            compiled_program = compiler.CompiledProgram(
                self.origin_main_program).with_data_parallel(
                    loss_name=loss.name, share_vars_from=None)
            loss.block.program._graph = compiled_program
            return self.user_defined_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)

        if meta_optimizer:
            optimize_ops, params_grads = meta_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)

            default_program = paddle.static.default_main_program()

            if id(default_program) != id(loss.block.program):
                paddle.fluid.framework.switch_main_program(loss.block.program)

        else:
            optimize_ops, params_grads = self.user_defined_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)

        context["program_optimize_ops"] = optimize_ops
        context["program_params_grads"] = params_grads

        if graph_optimizer:
            optimize_ops, params_grads = graph_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)
            # since we do not encourage users to use graph operations
            # if a graph optimizer takes effect, mostly
            # optimizers_ops and params_grads are None
            # i.e. users can not modify current computation graph anymore
            context["graph_optimize_ops"] = optimize_ops
            context["graph_optimize_grads"] = params_grads

        if self._runtime_handle is None:
            self._runtime_handle = RuntimeFactory()._create_runtime(context)

        import paddle.distributed.fleet as fleet
        fleet.util._set_strategy(context["valid_strategy"])

        return optimize_ops, params_grads

    @dygraph_only
    def distributed_scaler(self, scaler):
        """Wrap an ``amp.GradScaler`` so that gradient scaling works with
        the hybrid parallel communicate group. (Only work in dygraph mode)
        """
        return HybridParallelGradScaler(scaler, self._hcg)