#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import copy
import warnings
import paddle
import os
from paddle.fluid.framework import dygraph_only
from paddle.fluid import compiler
from .role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker, RoleMakerBase
from .strategy_compiler import StrategyCompiler
from .distributed_strategy import DistributedStrategy
from .meta_optimizer_factory import MetaOptimizerFactory
from .runtime_factory import RuntimeFactory
from paddle.fluid.wrapped_decorator import wrap_decorator
from paddle.fluid.dygraph import parallel_helper
from . import topology as tp
from .topology import ParallelMode
from ..meta_parallel import ModelParallel
from ..meta_parallel import PipelineParallel
from ..meta_optimizers import HybridParallelOptimizer
from ..meta_optimizers import HybridParallelGradScaler


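# _inited_runtime_handler_ guards methods that need a runtime handle (created
# in minimize()) and raises if none is available; _is_non_distributed_check_
# turns the decorated method into a warning no-op when the fleet was
# initialized in non-distributed mode.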
def _inited_runtime_handler_(func):
    def __impl__(*args, **kwargs):
        cls = args[0]

        if cls._runtime_handle is None:
            raise ValueError("Fleet can not find suitable runtime handler")

        return func(*args, **kwargs)

    return __impl__


def _is_non_distributed_check_(func):
    def __impl__(*args, **kwargs):
        cls = args[0]

        if cls._role_maker is not None and cls._role_maker._is_non_distributed(
        ) is True:
            warnings.warn(
                "%s() function doesn't work when use non_distributed fleet." %
                (func.__name__))
            return

        return func(*args, **kwargs)

    return __impl__


inited_runtime_handler = wrap_decorator(_inited_runtime_handler_)
is_non_distributed_check = wrap_decorator(_is_non_distributed_check_)


class Fleet(object):
    """
    Unified API for distributed training of PaddlePaddle
72
    Please reference the https://github.com/PaddlePaddle/FleetX for details
73 74 75 76 77


    Returns:
        Fleet: A Fleet instance

78
    Example for collective training:
1
123malin 已提交
79

80 81
        .. code-block:: python

1
123malin 已提交
82 83
            import paddle
            paddle.enable_static()
84
            import paddle.distributed.fleet as fleet
85 86 87

            fleet.init(is_collective=True)

88 89 90
            strategy = fleet.DistributedStrategy()
            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
            optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
91 92 93 94 95 96 97 98

            # do distributed training


    Example for parameter server training:

        .. code-block:: python

1
123malin 已提交
99 100
            import paddle
            paddle.enable_static()
101 102
            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
S
ShenLiang 已提交
103
            fleet.init(strategy=strategy)
104

105
            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
106
            optimizer = fleet.distributed_optimizer(optimizer)
107

108 109
            if fleet.is_first_worker():
                print("this is first worker")
110

111 112
            print("current node index: {}".format(fleet.worker_index()))
            print("total number of worker num: {}".format(fleet.worker_num()))
113

114 115 116
            if fleet.is_worker():
                print("this is worker")
            print("worker endpoints: {}".format(fleet.worker_endpoints(to_string=True)))
117

118 119
            print("server num: {}".format(fleet.server_num()))
            print("server endpoints: {}".format(fleet.server_endpoints(to_string=True)))
120

121 122 123
            if fleet.is_server():
                print("this is server")
            fleet.stop_worker()
124 125


126 127 128
    """

    def __init__(self):
        self._role_maker = None
        self.strategy_compiler = None
        self._is_collective = False
        self._runtime_handle = None
        self._util = None
        self._context = {}

    def init(self, role_maker=None, is_collective=False, strategy=None):
        """
        Initialize role_maker in Fleet.

        This function is responsible for initializing the distributed
        architecture that your code will run on.

        Args:
            role_maker (RoleMakerBase, optional): A ``RoleMakerBase`` containing the configuration
                of environment variables related to distributed training. If you did not initialize
                the role maker by yourself, it will be automatically initialized to PaddleCloudRoleMaker.
                The default value is None.
            is_collective (Boolean, optional): A ``Boolean`` variable that determines whether the
                program runs on CPU or GPU. False means distributed training using CPU, and True
                means GPU. The default value is False.
            strategy (DistributedStrategy): Extra properties for distributed training.
                For details, please refer to paddle.distributed.fleet.DistributedStrategy. Default: None.

        Returns:
            None

        Examples1:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

        Examples2:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init(is_collective=True)

        Examples3:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                role = fleet.PaddleCloudRoleMaker()
                fleet.init(role)

        Examples4:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                strategy = fleet.DistributedStrategy()
                fleet.init(strategy=strategy)

        """
        if strategy is None:
            strategy = DistributedStrategy()
        self._user_defined_strategy = copy.deepcopy(strategy)

        if role_maker is None:
            if isinstance(is_collective, bool):
                self._is_collective = is_collective
                self._role_maker = PaddleCloudRoleMaker(
                    is_collective=self._is_collective)
            else:
                raise ValueError(
                    "`is_collective` should be instance of `bool`, but got {}".
                    format(type(is_collective)))
        else:
            if isinstance(role_maker, RoleMakerBase):
                self._role_maker = role_maker
                self._is_collective = role_maker._is_collective
            else:
                raise ValueError(
                    "`role_maker` should be subclass of `RoleMakerBase`, but got {}".
                    format(type(role_maker)))
        self._role_maker._generate_role()

        import paddle.distributed.fleet as fleet
        fleet.util._set_role_maker(self._role_maker)

        self.strategy_compiler = StrategyCompiler()

        if self._role_maker._is_non_distributed() and self._is_collective:
            if paddle.fluid.core.is_compiled_with_cuda():
                gpus_num = paddle.fluid.core.get_cuda_device_count()
                if gpus_num != 1:
                    raise ValueError(
                        "CUDA_VISIBLE_DEVICES should be set to only 1 card if you use `python` to launch fleet program."
                    )

        if paddle.fluid.framework.in_dygraph_mode():
            if self.worker_num() == 1:
                # if worker_num is 1, should construct default topology & hcg
                self._topology = tp.CommunicateTopology()
                self._hcg = tp.HybridCommunicateGroup(self._topology)
                return
            if parallel_helper._is_parallel_ctx_initialized():
                warnings.warn(
                    "The dygraph parallel environment has been initialized.")
            else:
                # FLAGS_nccl_nrings is used for dynamic graph multi-stream communication
                if "FLAGS_nccl_nrings" in os.environ:
                    warnings.warn(
                        "You have set the environment variable FLAGS_nccl_nrings "
                        "outside the program, so the nccl_comm_num in "
                        "DistributedStrategy will not take effect here.")
                else:
                    os.environ["FLAGS_nccl_nrings"] = str(
                        self._user_defined_strategy.nccl_comm_num)
                paddle.distributed.init_parallel_env()

            # init hybrid parallel environment in dygraph
            if tp._HYBRID_PARALLEL_GROUP is None:
                self._init_hybrid_parallel_env()
            else:
                warnings.warn(
                    "The dygraph hybrid parallel environment has been initialized."
                )

    def _init_hybrid_parallel_env(self):
        """initialize the hybrid environment
        """
        self.hybrid_configs = self._user_defined_strategy.hybrid_configs
        self.dp_degree = self.hybrid_configs["dp_degree"]
        self.mp_degree = self.hybrid_configs["mp_degree"]
        self.pp_degree = self.hybrid_configs["pp_degree"]

        assert self.mp_degree >= 0, "mp_degree should be greater than or equal to 0"
        assert self.pp_degree >= 0, "pp_degree should be greater than or equal to 0"

        self.mp_degree = max(self.mp_degree, 1)
        self.pp_degree = max(self.pp_degree, 1)

        if self.dp_degree < 0:
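            # a negative dp_degree means "infer from the world size" so that
            # dp_degree * mp_degree * pp_degree covers all ranks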
            nranks = paddle.distributed.get_world_size()
            self.dp_degree = nranks // (self.mp_degree * self.pp_degree)

        self.dp_degree = max(self.dp_degree, 1)

        self._topology = tp.CommunicateTopology(
            hybrid_group_names=["data", "pipe", "model"],
            dims=[self.dp_degree, self.pp_degree, self.mp_degree])

        self._hcg = tp.HybridCommunicateGroup(self._topology)

    def get_hybrid_communicate_group(self):
        """Return the hybrid communicate group built from the topology."""
        assert self._hcg is not None
        return self._hcg

    def get_hybrid_parallel_topology(self):
        """Return the hybrid parallel communicate topology."""
        assert self._topology is not None
        return self._topology

    def is_first_worker(self):
        """
        Check whether the node is the first instance of worker.

        Returns:
            bool: True if this is the first node of worker,
                  False if not.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_first_worker()

        """
        return self._role_maker._is_first_worker()

    def worker_index(self):
        """
        Get current worker index.

        Returns:
            int: node id

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_index()

        """
        return self._role_maker._worker_index()

    def worker_num(self):
        """
        Get current total worker number.

        Returns:
            int: worker numbers

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_num()

        """
        return self._role_maker._worker_num()

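    # Thin wrappers around role-maker topology queries: the number of physical
    # nodes, this worker's rank within its node, and the device ids visible
    # locally and across the whole job.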
    def node_num(self):
        return self._role_maker._get_node_num()

    def local_rank(self):
        return self._role_maker._get_local_rank()

    def local_device_ids(self):
        return self._role_maker._get_local_device_ids()

    def world_device_ids(self):
        return self._role_maker._get_world_device_ids()

    def is_worker(self):
        """
        Check whether the node is an instance of worker.

        Returns:
            bool: True if this is a node of worker,
                  False if not.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_worker()

        """
        return self._role_maker._is_worker()

    def worker_endpoints(self, to_string=False):
        """
        Get current worker endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].

        Returns:
            list/string: worker endpoints

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_endpoints()

        """
        if to_string:
            return ",".join(self._role_maker._get_trainer_endpoints())
        else:
            return self._role_maker._get_trainer_endpoints()

    def server_num(self):
        """
        Get current total server number.

        Returns:
            int: server number

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_num()

        """
        return len(self._role_maker._get_pserver_endpoints())

    def server_index(self):
        """
        Get current server index.

        Returns:
            int: node id

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_index()

        """
        return self._role_maker._server_index()

    def server_endpoints(self, to_string=False):
        """
        Get current server endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].

        Returns:
            list/string: server endpoints

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_endpoints()

        """
        if to_string:
            return ",".join(self._role_maker._get_pserver_endpoints())
        else:
            return self._role_maker._get_pserver_endpoints()

    def is_server(self):
        """
        Check whether the node is an instance of server.

        Returns:
            bool: True if this is a node of server,
                  False if not.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_server()

        """
        return self._role_maker._is_server(
        ) or self._role_maker._is_heter_worker()

    def barrier_worker(self):
        """
        Barrier all workers in the current distributed job.

        Returns:
            None
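
        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.barrier_worker()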
        """
        self._role_maker._barrier("worker")

    @is_non_distributed_check
    @inited_runtime_handler
    def init_worker(self):
        """
        Initialize `Communicator` for parameter server training.


        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_worker()

        """
        self._runtime_handle._init_worker()

    @is_non_distributed_check
    @inited_runtime_handler
    def init_server(self, *args, **kwargs):
        """
        Run the init_server executor to initialize the startup program. If
        `args` is not empty, load_persistables is also run for incremental training.


        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_server()

        """
        self._runtime_handle._init_server(*args, **kwargs)

    @is_non_distributed_check
    @inited_runtime_handler
    def run_server(self):
        """
        Run the parameter server main program with the executor.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                if fleet.is_server():
                    fleet.init_server()
                    fleet.run_server()

        """
        self._runtime_handle._run_server()

    @is_non_distributed_check
    @inited_runtime_handler
    def stop_worker(self):
        """
        Stop the `Communicator` and notify the parameter servers that training is complete.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.stop_worker()

        """
        self._runtime_handle._stop_worker()

    def save_inference_model(self,
                             executor,
                             dirname,
                             feeded_var_names,
                             target_vars,
                             main_program=None,
                             export_for_deployment=True,
                             mode=0):
        """
        Save the inference model for deployment.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                paddle.enable_static()
                import paddle.distributed.fleet as fleet

                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                exe = paddle.static.Executor(paddle.CPUPlace())
                # feeded_var_names and target_vars come from the net built above
                fleet.save_inference_model(exe, "dirname", feeded_var_names, target_vars)

        """
        self._runtime_handle._save_inference_model(
            executor, dirname, feeded_var_names, target_vars, main_program,
            export_for_deployment, mode)

    def save_persistables(self, executor, dirname, main_program=None, mode=0):
        """
        Saves all persistable tensors from :code:`main_program` to
        the folder :code:`dirname`.

        The :code:`dirname` is used to specify the folder where persistable tensors
        are going to be saved. If you would like to save tensors in separate
        files, set :code:`filename` None.

        Args:
            executor(Executor): The executor to run for saving persistable tensors.
                                You can refer to :ref:`api_guide_executor_en` for
                                more details.

            dirname(str, optional): The saving directory path.
                                When you need to save the parameter to the memory, set it to None.

            main_program(Program, optional): The program whose persistable tensors will
                                             be saved. Default: None.


        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                paddle.enable_static()
                import paddle.distributed.fleet as fleet

                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                exe = paddle.static.Executor(paddle.CPUPlace())
                fleet.save_persistables(exe, "dirname", paddle.static.default_main_program())

        """

        self._runtime_handle._save_persistables(executor, dirname, main_program,
                                                mode)

    def shrink(self, threshold):
        # Delegate to the runtime handle; in parameter-server training this
        # prunes sparse parameters that have not been updated within the
        # given threshold.
        self._runtime_handle._shrink(threshold)

    def distributed_optimizer(self, optimizer, strategy=None):
        """
        Optimizer for distributed training.

        For distributed training, this method rebuilds a new instance of
        DistributedOptimizer, which has basic Optimizer functionality and
        special features for distributed training.

        Args:
            optimizer(Optimizer): The optimizer to be used in distributed training.
            strategy(DistributedStrategy): Extra properties for distributed optimizer.
                It is recommended to use DistributedStrategy in fleet.init(). The strategy
                here is for compatibility. If the strategy in fleet.distributed_optimizer()
                is not None, then it will overwrite the DistributedStrategy in fleet.init(),
                which will take effect in distributed training.

        Returns:
            Fleet: instance of fleet.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.distributed.fleet as fleet
                fleet.init(is_collective=True)
                strategy = fleet.DistributedStrategy()
                optimizer = paddle.optimizer.SGD(learning_rate=0.001)
                optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)

        """
        self.user_defined_optimizer = optimizer

        if strategy is not None:
            if self._is_collective:
                warnings.warn(
                    "It is recommended to use DistributedStrategy "
                    "in fleet.init(). The strategy here is only for compatibility. "
                    "If the strategy in fleet.distributed_optimizer() is "
                    "not None, then it will overwrite the DistributedStrategy in fleet.init(), "
                    "which will take effect in distributed training.")
            self._user_defined_strategy = copy.deepcopy(strategy)

        self._context = {}

        if paddle.fluid.framework.in_dygraph_mode():
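            # with more than one worker, wrap the optimizer for hybrid
            # parallelism; a single-process run keeps the raw optimizer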
            if self.worker_num() > 1:
                return HybridParallelOptimizer(optimizer, self._hcg,
                                               self._user_defined_strategy)
            else:
                return optimizer
        return self

    @dygraph_only
    def distributed_model(self, model):
        """
        Return a distributed data parallel model (only works in dygraph mode).

        Args:
            model (Layer): the user-defined model which inherits Layer.

        Returns:
            distributed data parallel model which inherits Layer.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                # 1. initialize fleet environment
                fleet.init(is_collective=True)

                # 2. create layer & optimizer
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                # 3. get data_parallel model using fleet
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                # 4. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()


        """
        assert model is not None, "model should not be None"
        if self.worker_num() <= 1:
            return model
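        # dispatch on the parallel mode derived from the hybrid topology:
        # pure data parallel, model (tensor) parallel, or pipeline parallel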
        if self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL:
            distributed_model = paddle.DataParallel(
                model,
                comm_buffer_size=self._user_defined_strategy.
                fuse_grad_size_in_MB,
                last_comm_buffer_size=self._user_defined_strategy.
                last_comm_group_size_MB,
                find_unused_parameters=self._user_defined_strategy.
                find_unused_parameters)
        elif self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL:
            distributed_model = ModelParallel(
                model, self._hcg, strategy=self._user_defined_strategy)
        elif self._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
            distributed_model = PipelineParallel(
                model, self._hcg, strategy=self._user_defined_strategy)
        return distributed_model

    @dygraph_only
    def state_dict(self):
        """
        Get state dict information from optimizer.
793
        (Only work in dygraph mode)
794 795 796 797 798 799 800

        Returns: 
            state_dict(dict) : dict contains all the Tensor used by optimizer

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)
                state_dict = adam.state_dict()
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.state_dict()

    @dygraph_only
    def set_state_dict(self, state_dict):
        """
        Load optimizer state dict.
        (Only works in dygraph mode)

        Args: 
            state_dict(dict) : Dict contains all the Tensor needed by optimizer

        Returns:
            None

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)
                state_dict = adam.state_dict()
                paddle.save(state_dict, "paddle_dy")
                para_state_dict = paddle.load("paddle_dy")
                adam.set_state_dict(para_state_dict)
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.set_state_dict(state_dict)

    @dygraph_only
    def set_lr(self, value):
        """
        Set the value of the learning rate manually in the optimizer. 
        (Only works in dygraph mode)

        Args:
            value (float|Tensor): the value of learning rate

        Returns: 
            None 

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                lr_list = [0.2, 0.3, 0.4, 0.5, 0.6]
                for i in range(5):
                    adam.set_lr(lr_list[i])
                    lr = adam.get_lr()
                    print("current lr is {}".format(lr))
                # Print:
                #    current lr is 0.2
                #    current lr is 0.3
                #    current lr is 0.4
                #    current lr is 0.5
                #    current lr is 0.6
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.set_lr(value)

    @dygraph_only
    def get_lr(self):
        """
        Get current step learning rate.
        (Only works in dygraph mode)

        Returns:
            float: The learning rate of the current step.

        Examples:

            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                lr = adam.get_lr()
                print(lr) # 0.01
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.get_lr()

    @dygraph_only
    def step(self):
        """
        Execute the optimizer once.
        (Only works in dygraph mode)

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                # 1. initialize fleet environment
                fleet.init(is_collective=True)

                # 2. create layer & optimizer
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                # 3. get data_parallel model using fleet
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                # 4. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()


        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.step()

    @dygraph_only
    def clear_grad(self):
        """
        Clear the gradients of all optimized parameters for model.
        (Only works in dygraph mode)

        Returns: 
            None

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                # 1. initialize fleet environment
                fleet.init(is_collective=True)

                # 2. create layer & optimizer
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                # 3. get data_parallel model using fleet
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                # 4. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()

        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.clear_grad()

    def _get_amp_optimizer(self):
        # imitate target optimizer retrieval
        amp_optimizer = None
        for optimizer in self.strategy_compiler._get_applied_meta_optimizer():
            if hasattr(optimizer, 'amp_init'):
                amp_optimizer = optimizer
                break

        if amp_optimizer is None:
            if hasattr(self.user_defined_optimizer, 'amp_init'):
                amp_optimizer = self.user_defined_optimizer

        assert amp_optimizer is not None, \
            "amp_init can only be used when the amp(auto mixed precision) strategy is turned on."
        return amp_optimizer

    def get_loss_scaling(self):
        """Return the real-time loss scaling factor.
        """
        amp_optimizer = self._get_amp_optimizer()
        return amp_optimizer.get_loss_scaling()

    def amp_init(self,
                 place,
                 scope=None,
                 test_program=None,
                 use_fp16_test=False):
        """
        Init the amp training, such as cast fp32 parameters to fp16 type.
  
        Args:
            place(CUDAPlace): place is used to initialize 
                fp16 parameters with fp32 values.
            scope(Scope): The scope is used to find fp32 parameters.
            test_program(Program): The program is used for testing.
            use_fp16_test(bool): Whether to use fp16 testing.
            
        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                import paddle.nn.functional as F
                paddle.enable_static()

                def run_example_code():
                    place = paddle.CUDAPlace(0)
                    exe = paddle.static.Executor(place)
                    data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
                    conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
                    # 1) Use fp16_guard to control the range of fp16 kernels used.
                    with paddle.static.amp.fp16_guard():
                        bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
                        pool = F.max_pool2d(bn, kernel_size=2, stride=2)
                        hidden = paddle.static.nn.fc(pool, size=10)
                        loss = paddle.mean(hidden)
                    # 2) Create the optimizer and set `multi_precision` to True.
                    # Setting `multi_precision` to True can help avoid poor accuracy
                    # or slow convergence.
                    optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
                    # 3) These ops in `custom_black_list` will keep in the float32 computation type.
                    amp_list = paddle.static.amp.CustomOpLists(
                        custom_black_list=['pool2d'])
                    # 4) The entry of Paddle AMP.
                    # Enable pure fp16 training by setting `use_pure_fp16` to True.
                    optimizer = paddle.static.amp.decorate(
                        optimizer,
                        amp_list,
                        init_loss_scaling=128.0,
                        use_dynamic_loss_scaling=True,
                        use_pure_fp16=True)
                    # If you don't use the default_startup_program(), you should pass
                    # your defined `startup_program` into `minimize`.
                    optimizer.minimize(loss)
                    exe.run(paddle.static.default_startup_program())
                    # 5) Use `amp_init` after FP32 parameters initialization (such as `exe.run(startup_program)`).
                    # If you want to perform the testing process, you should pass `test_program` into `amp_init`.
                    optimizer.amp_init(place, scope=paddle.static.global_scope())
                    
                if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                    run_example_code()       
        """
        amp_optimizer = self._get_amp_optimizer()
        return amp_optimizer.amp_init(place, scope, test_program, use_fp16_test)
    def _final_strategy(self):
        if "valid_strategy" not in self._context:
            print(
                "WARNING: You may need to call minimize function before this function is called"
            )
            return {}
        else:
            return self._context["valid_strategy"]

    def _get_applied_meta_list(self):
        if "applied_meta_list" not in self._context:
            print(
                "WARNING: You may need to call minimize function before _get_applied_meta_list called"
            )
            return []
        else:
            return self._context["applied_meta_list"]

    def _get_applied_graph_list(self):
        if "applied_graph_list" not in self._context:
            print(
                "WARNING: You may need to call minimize function before _get_applied_graph_list called"
            )
            return []
        else:
            return self._context["applied_graph_list"]

    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
        """
        Add distributed operations to minimize ``loss`` by updating ``parameter_list``.

        Args:
            loss (Tensor): A ``Tensor`` containing the value to minimize.
            startup_program (Program, optional): :ref:`api_fluid_Program` for
                initializing parameters in ``parameter_list``. The default value
                is None, at this time :ref:`api_fluid_default_startup_program` will be used.
            parameter_list (Iterable, optional): Iterable of ``Tensor`` or ``Tensor.name`` to update
                to minimize ``loss``. The default value is None, at this time all parameters
                will be updated.
            no_grad_set (set, optional): Set of ``Tensor`` or ``Tensor.name`` that don't need
                to be updated. The default value is None.

        Returns:
            tuple: tuple (optimize_ops, params_grads), A list of operators appended
            by minimize and a list of (param, grad) tensor pairs, param is
            ``Parameter``, grad is the gradient value corresponding to the parameter.
            The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
            indicate program pruning. If so, the program will be pruned by ``feed`` and
            ``fetch_list`` before run, see details in ``Executor``.

        Examples:

            .. code-block:: python

                import paddle
                paddle.enable_static()
                import paddle.distributed.fleet as fleet
                import paddle.nn.functional as F

                hid_dim = 10
                label_dim = 2
                input_x = paddle.static.data(name='x', shape=[None, 13], dtype='float32')
                input_y = paddle.static.data(name='y', shape=[None, 1], dtype='int64')
                fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim, activation='tanh')
                fc_2 = paddle.static.nn.fc(x=fc_1, size=hid_dim, activation='tanh')
                prediction = paddle.static.nn.fc(x=[fc_2], size=label_dim, activation='softmax')
                cost = F.cross_entropy(input=prediction, label=input_y)
                avg_cost = paddle.mean(x=cost)

                fleet.init(is_collective=True)
                strategy = fleet.DistributedStrategy()
                optimizer = paddle.optimizer.SGD(learning_rate=0.001)
                optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
                optimizer.minimize(avg_cost)

                # for more examples, please reference https://github.com/PaddlePaddle/FleetX

        """
        context = {}
        context["user_defined_strategy"] = copy.deepcopy(
            self._user_defined_strategy)
        if paddle.fluid.framework.in_dygraph_mode():
            # imitate target optimizer retrieval
            target_opt = self.user_defined_optimizer
            self._context = context
            return target_opt.minimize(loss)

        # cache original feed forward program
        self.origin_main_program = loss.block.program
        context["origin_main_program"] = self.origin_main_program
        context["loss"] = loss
        if startup_program is None:
            self.origin_startup_program = \
                paddle.static.default_startup_program().clone(for_test=False)
            startup_program = paddle.static.default_startup_program()
        else:
            self.origin_startup_program = \
                startup_program.clone(for_test=False)

        context["origin_startup_program"] = startup_program
        context["role_maker"] = self._role_maker

        # compile time
        distributed_optimizer_list = \
            MetaOptimizerFactory()._get_valid_meta_optimizers(
                self.user_defined_optimizer)

        context["user_defined_strategy"] = copy.deepcopy(
            self._user_defined_strategy)
        copy_user_defined_strategy = copy.deepcopy(self._user_defined_strategy)

        # trigger the auto-parallel in very strict condition
        # strategy = DistributedStrategy()
        # strategy.auto = True
        # optimizer = paddle.optimizer.SGD(learning_rate=0.1)
        # optimizer = fleet.distributed_optimizer(optimizer, strategy)
        if copy_user_defined_strategy._is_strict_auto():
            # turn on all the strategy for each optimizer
            for opt in distributed_optimizer_list:
                opt._enable_strategy(copy_user_defined_strategy, context)

        valid_optimizer_list = []
        valid_graph_optimizer_list = []
        can_not_apply_optimizer_list = []
        # recall meta optimizers for ranking
        for opt in distributed_optimizer_list:
            opt._set_basic_info(loss, self._role_maker,
                                self.user_defined_optimizer,
                                copy_user_defined_strategy)
            if opt._can_apply() and not opt._is_graph_out():
                valid_optimizer_list.append(opt)
            elif opt._can_apply() and opt._is_graph_out():
                valid_graph_optimizer_list.append(opt)
            else:
                can_not_apply_optimizer_list.append(opt)
        # combine recalled meta optimizers to be a valid meta optimizer
        meta_optimizer, graph_optimizer = \
            self.strategy_compiler.generate_optimizer(
                loss, self._role_maker, self.user_defined_optimizer,
                copy_user_defined_strategy, valid_optimizer_list,
                valid_graph_optimizer_list)

        valid_strategy = self.strategy_compiler._get_valid_strategy(
            copy_user_defined_strategy, can_not_apply_optimizer_list)

        context["valid_strategy"] = copy.deepcopy(valid_strategy)

        applied_meta_list = self.strategy_compiler._get_applied_meta_list()
        applied_graph_list = self.strategy_compiler._get_applied_graph_list()

        context['applied_meta_list'] = applied_meta_list
        context['applied_graph_list'] = applied_graph_list

        self._context = context

        self.valid_strategy = valid_strategy
        self.valid_strategy._enable_env()

        optimize_ops = []
        params_grads = []

        if self._role_maker._is_non_distributed() and not self._is_collective:
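            # single-process fallback: skip the meta optimizers, compile the
            # original program and defer directly to the user-defined optimizer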
            if self._runtime_handle is None:
                self._runtime_handle = RuntimeFactory()._create_runtime(context)

            compiled_program = compiler.CompiledProgram(
                self.origin_main_program).with_data_parallel(
                    loss_name=loss.name, share_vars_from=None)
            loss.block.program._graph = compiled_program
            return self.user_defined_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)

        if meta_optimizer:
            optimize_ops, params_grads = meta_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)

            default_program = paddle.static.default_main_program()

            if id(default_program) != id(loss.block.program):
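                # the meta optimizer may have built a new main program; make
                # the program that owns `loss` the default so that subsequent
                # static-graph APIs operate on the optimized program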
                paddle.fluid.framework.switch_main_program(loss.block.program)

        else:
            optimize_ops, params_grads = self.user_defined_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)

        context["program_optimize_ops"] = optimize_ops
        context["program_params_grads"] = params_grads

        if graph_optimizer:
            optimize_ops, params_grads = graph_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)
            # since we do not encourage users to use graph operations
            # if a graph optimizer takes effect, mostly
            # optimize_ops and params_grads are None
            # i.e. users can not modify current computation graph anymore
            context["graph_optimize_ops"] = optimize_ops
            context["graph_optimize_grads"] = params_grads

        if self._runtime_handle is None:
            self._runtime_handle = RuntimeFactory()._create_runtime(context)

        import paddle.distributed.fleet as fleet
        fleet.util._set_strategy(context["valid_strategy"])

        return optimize_ops, params_grads

    @dygraph_only
    def distributed_scaler(self, scaler):
        """
        Wrap a GradScaler for hybrid parallel training (only works in dygraph mode).
        """
        return HybridParallelGradScaler(scaler, self._hcg)