#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import copy
import os
import warnings

import paddle
from paddle.fluid.framework import dygraph_only
from paddle.fluid import compiler
from paddle.fluid.wrapped_decorator import wrap_decorator
from paddle.fluid.dygraph import parallel_helper

from .role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker, RoleMakerBase
from .strategy_compiler import StrategyCompiler
from .distributed_strategy import DistributedStrategy
from .meta_optimizer_factory import MetaOptimizerFactory
from .runtime_factory import RuntimeFactory


def _inited_runtime_handler_(func):
    """Wrap ``func`` so it only runs once a runtime handle is available.

    Args:
        func (callable): The callable to guard. Its first positional
            argument must expose a ``_runtime_handle`` attribute.

    Returns:
        callable: A wrapper that forwards all arguments to ``func``.

    Raises:
        ValueError: Raised by the wrapper when ``_runtime_handle`` of the
            first positional argument is None.
    """

    def __impl__(*args, **kwargs):
        # The guarded callables are methods, so args[0] is the instance.
        owner = args[0]
        if owner._runtime_handle is not None:
            return func(*args, **kwargs)
        raise ValueError("Fleet can not find suitable runtime handler")

    return __impl__


# Guard used by Fleet methods: warn and return None in non-distributed runs.
def _is_non_distributed_check_(func):
    """Wrap ``func`` so it becomes a warning-only no-op in non-distributed runs.

    Args:
        func (callable): The callable to guard. Its first positional
            argument must expose a ``_role_maker`` attribute.

    Returns:
        callable: A wrapper. When ``_role_maker`` is set and its
        ``_is_non_distributed()`` returns exactly ``True``, the wrapper
        emits a warning and returns None without calling ``func``;
        otherwise it returns ``func(*args, **kwargs)``.
    """

    def __impl__(*args, **kwargs):
        # The guarded callables are methods, so args[0] is the instance.
        owner = args[0]
        role_maker = owner._role_maker
        non_distributed = (role_maker is not None and
                           role_maker._is_non_distributed() is True)
        if not non_distributed:
            return func(*args, **kwargs)

        warnings.warn(
            "%s() function doesn't work when use non_distributed fleet." %
            (func.__name__))
        return None

    return __impl__


# Decorator forms of the two guards above, built with paddle's
# `wrap_decorator` helper; Fleet methods apply them with `@...`.
inited_runtime_handler = wrap_decorator(_inited_runtime_handler_)
is_non_distributed_check = wrap_decorator(_is_non_distributed_check_)


class Fleet(object):
    """
    Unified API for distributed training of PaddlePaddle.
    Please refer to https://github.com/PaddlePaddle/FleetX for details.


    Returns:
        Fleet: A Fleet instance

    Example for collective training:

        .. code-block:: python

            import paddle
            paddle.enable_static()
            import paddle.distributed.fleet as fleet

            fleet.init(is_collective=True)

            strategy = fleet.DistributedStrategy()
            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
            optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)

            # do distributed training


    Example for parameter server training:

        .. code-block:: python

            import paddle
            paddle.enable_static()
            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            fleet.init(strategy=strategy)

            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
            optimizer = fleet.distributed_optimizer(optimizer)

            if fleet.is_first_worker():
                print("this is first worker")

            print("current node index: {}".format(fleet.worker_index()))
            print("total number of worker num: {}".format(fleet.worker_num()))

            if fleet.is_worker():
                print("this is worker")
            print("worker endpoints: {}".format(fleet.worker_endpoints(to_string=True)))

            print("server num: {}".format(fleet.server_num()))
            print("server endpoints: {}".format(fleet.server_endpoints(to_string=True)))

            if fleet.is_server():
                print("this is server")
            fleet.stop_worker()


    """

    def __init__(self):
123
        self._role_maker = None
124
        self.strategy_compiler = None
125
        self._is_collective = False
126
        self._runtime_handle = None
D
Dong Daxiang 已提交
127 128
        self._util = None
        self._context = {}
129

130
    def init(self, role_maker=None, is_collective=False, strategy=None):
        """
        Initialize role_maker in Fleet.

        This function is responsible for the distributed architecture
        what you want to run your code behind.

        Args:
            role_maker (RoleMakerBase, optional): A ``RoleMakerBase`` containing the configuration
                of environment variables related to distributed training. If you did not initialize
                the rolemaker by yourself, it will be automatically initialized to
                ``PaddleCloudRoleMaker``. The default value is None.
            is_collective (Boolean, optional): A ``Boolean`` variable determines whether the program
                runs on the CPU or GPU. False means set distributed training using CPU, and True means
                GPU. It is only used when ``role_maker`` is None. The default value is False.
            strategy (DistributedStrategy): Extra properties for distributed training.
                For details, please refer to paddle.distributed.fleet.DistributedStrategy. Default: None.

        Returns:
            None

        Raises:
            ValueError: If ``role_maker`` is None and ``is_collective`` is not a ``bool``,
                if ``role_maker`` is given but is not a ``RoleMakerBase``, or if a
                non-distributed collective run on a CUDA build sees a GPU count other than 1.

        Examples1:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

        Examples2:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init(is_collective=True)

        Examples3:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                role = fleet.PaddleCloudRoleMaker()
                fleet.init(role)

        Examples4:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                strategy = fleet.DistributedStrategy()
                fleet.init(strategy=strategy)

        """
        if strategy is None:
            strategy = DistributedStrategy()
        # Store a deep copy so later changes made by the caller to its own
        # strategy object are not picked up here.
        self._user_defined_strategy = copy.deepcopy(strategy)

        if role_maker is None:
            if not isinstance(is_collective, bool):
                raise ValueError(
                    "`is_collective` should be instance of `bool`, but got {}".
                    format(type(is_collective)))
            self._is_collective = is_collective
            self._role_maker = PaddleCloudRoleMaker(
                is_collective=self._is_collective)
        else:
            if not isinstance(role_maker, RoleMakerBase):
                raise ValueError(
                    "`role_maker` should be subclass of `RoleMakerBase`, but got {}".
                    format(type(role_maker)))
            # NOTE(review): `self._is_collective` is left untouched on this
            # path - confirm that is intended for user-supplied role makers.
            self._role_maker = role_maker
        self._role_maker._generate_role()

        # Imported here rather than at module level - presumably to avoid a
        # circular import with the fleet package; confirm before moving.
        import paddle.distributed.fleet as fleet
        fleet.util._set_role_maker(self._role_maker)

        self.strategy_compiler = StrategyCompiler()

        if self._role_maker._is_non_distributed() and self._is_collective:
            if paddle.fluid.core.is_compiled_with_cuda():
                gpus_num = paddle.fluid.core.get_cuda_device_count()
                if gpus_num != 1:
                    raise ValueError(
                        "CUDA_VISIBLE_DEVICES shoule be set only 1 card if you use `python` to launch fleet program."
                    )

        if paddle.fluid.framework.in_dygraph_mode():
            # A single worker needs no parallel environment.
            if self.worker_num() == 1:
                return
            if parallel_helper._is_parallel_ctx_initialized():
                warnings.warn(
                    "The dygraph parallel environment has been initialized.")
            else:
                # FLAGS_nccl_nrings is used for dynamic graph multi-stream communication
                if "FLAGS_nccl_nrings" in os.environ:
                    warnings.warn(
                        "You have set the environment variable FLAGS_nccl_nrings "
                        "outside the program, so the nccl_comm_num in "
                        "DistributedStrategy will not take effect here.")
                else:
                    os.environ["FLAGS_nccl_nrings"] = str(
                        self._user_defined_strategy.nccl_comm_num)
                paddle.distributed.init_parallel_env()

    def is_first_worker(self):
        """
        Check whether the node is the first instance of worker.

        The answer is delegated to the role maker set by ``init()``.

        Returns:
            bool: True if this is the first node of worker,
                  False if not.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_first_worker()

        """
        return self._role_maker._is_first_worker()

    def worker_index(self):
        """
        Get current worker index.

        The value is delegated to the role maker set by ``init()``.

        Returns:
            int: node id

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_index()

        """
        return self._role_maker._worker_index()

    def worker_num(self):
        """
        Get current total worker number.

        The value is delegated to the role maker set by ``init()``.

        Returns:
            int: worker numbers

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_num()

        """
        return self._role_maker._worker_num()

    def is_worker(self):
        """
        Check whether the node is an instance of worker.

        The answer is delegated to the role maker set by ``init()``.

        Returns:
            bool: True if this is a node of worker,
                  False if not.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_worker()

        """
        return self._role_maker._is_worker()

    def worker_endpoints(self, to_string=False):
        """
        Get current worker endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].

        Args:
            to_string (bool, optional): If True, join the endpoints into one
                comma-separated string. Default: False.

        Returns:
            list/string: worker endpoints

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_endpoints()

        """
        endpoints = self._role_maker._get_trainer_endpoints()
        return ",".join(endpoints) if to_string else endpoints

    def server_num(self):
        """
        Get current total server number.

        Computed as the number of parameter-server endpoints reported by
        the role maker.

        Returns:
            int: server number

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_num()
        """
        return len(self._role_maker._get_pserver_endpoints())

    def server_index(self):
        """
        Get current server index.

        The value is delegated to the role maker set by ``init()``.

        Returns:
            int: node id

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_index()

        """
        return self._role_maker._server_index()

    def server_endpoints(self, to_string=False):
        """
        Get current server endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].

        Args:
            to_string (bool, optional): If True, join the endpoints into one
                comma-separated string. Default: False.

        Returns:
            list/string: server endpoints

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_endpoints()

        """
        endpoints = self._role_maker._get_pserver_endpoints()
        return ",".join(endpoints) if to_string else endpoints

    def is_server(self):
        """
        Check whether the node is an instance of server.

        A node for which the role maker's ``_is_heter_worker()`` is true is
        reported as a server as well.

        Returns:
            bool: True if this is a node of server,
                  False if not.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_server()

        """
        role_maker = self._role_maker
        return role_maker._is_server() or role_maker._is_heter_worker()

    def barrier_worker(self):
        """
        barrier all workers

        Implemented by calling the role maker's ``_barrier`` with the
        group name ``"worker"``.

        Returns:
            None
        """
        self._role_maker._barrier("worker")

    @is_non_distributed_check
    @inited_runtime_handler
    def init_worker(self):
        """
        initialize `Communicator` for parameter server training.

        The call is forwarded to the runtime handle. In non-distributed
        mode it only emits a warning; without a runtime handle it raises
        ``ValueError``.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_worker()

        """
        self._runtime_handle._init_worker()

    @is_non_distributed_check
    @inited_runtime_handler
    def init_server(self, *args, **kwargs):
        """
        init_server executor to initialize startup program,
        if the `args` is not empty, it will run load_persistables for increment training.

        All positional and keyword arguments are forwarded unchanged to the
        runtime handle. In non-distributed mode the call only emits a
        warning; without a runtime handle it raises ``ValueError``.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_server()

        """
        self._runtime_handle._init_server(*args, **kwargs)

    @is_non_distributed_check
    @inited_runtime_handler
    def run_server(self):
        """
        run server will run pserver main program with executor.

        The call is forwarded to the runtime handle. In non-distributed
        mode it only emits a warning; without a runtime handle it raises
        ``ValueError``.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                if fleet.is_server():
                    fleet.init_server()
                    fleet.run_server()

        """
        self._runtime_handle._run_server()

    @is_non_distributed_check
    @inited_runtime_handler
    def stop_worker(self):
        """
        stop `Communicator` and give training complete notice to parameter server.

        The call is forwarded to the runtime handle. In non-distributed
        mode it only emits a warning; without a runtime handle it raises
        ``ValueError``.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_worker()

                # do training

                fleet.stop_worker()

        """
        self._runtime_handle._stop_worker()

    def save_inference_model(self,
                             executor,
                             dirname,
                             feeded_var_names,
                             target_vars,
                             main_program=None,
                             export_for_deployment=True):
        """
        save inference model for inference.

        Every argument is forwarded unchanged, in the order listed, to the
        runtime handle's ``_save_inference_model``.

        Args:
            executor: Forwarded to the runtime handle (presumably the
                Executor that performs the save - confirm against the runtime).
            dirname: Forwarded to the runtime handle (presumably the target
                directory - confirm against the runtime).
            feeded_var_names: Forwarded to the runtime handle.
            target_vars: Forwarded to the runtime handle.
            main_program (optional): Forwarded to the runtime handle. Default: None.
            export_for_deployment (optional): Forwarded to the runtime handle.
                Default: True.

        Returns:
            None

        Examples:

            .. code-block:: text

                import paddle
                paddle.enable_static()
                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                exe = paddle.static.Executor(paddle.CPUPlace())
                # feed_names / fetch_vars are placeholders for your own net
                fleet.save_inference_model(exe, "dirname", feed_names, fetch_vars)

        """
        # NOTE(review): unlike init_worker() and friends this method is not
        # guarded by `inited_runtime_handler`, so a missing runtime handle
        # surfaces as AttributeError on None.
        self._runtime_handle._save_inference_model(
            executor, dirname, feeded_var_names, target_vars, main_program,
            export_for_deployment)

    def save_persistables(self, executor, dirname, main_program=None, mode=1):
        """

        saves all persistable tensors from :code:`main_program` to
        the folder :code:`dirname`.

        The :code:`dirname` is used to specify the folder where persistable tensors
        are going to be saved.

        Args:
            executor(Executor): The executor to run for saving persistable tensors.
                                You can refer to :ref:`api_guide_executor_en` for
                                more details.

            dirname(str, optional): The saving directory path.
                                When you need to save the parameter to the memory, set it to None.
            main_program(Program, optional): The program whose persistable tensors will
                                             be saved. Default: None.
            mode(int, optional): Forwarded unchanged to the runtime handle; its
                                 meaning is defined there. Default: 1.


        Returns:
            None

        Examples:

            .. code-block:: text

                import paddle
                paddle.enable_static()
                import paddle.distributed.fleet as fleet

                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                exe = paddle.static.Executor(paddle.CPUPlace())
                fleet.save_persistables(exe, "dirname", paddle.static.default_main_program())

        """
        # NOTE(review): not guarded by `inited_runtime_handler`, so a missing
        # runtime handle surfaces as AttributeError on None.
        self._runtime_handle._save_persistables(executor, dirname, main_program,
                                                mode)

    def distributed_optimizer(self, optimizer, strategy=None):
        """
        Optimizer for distributed training.

        This method records the user's optimizer on the Fleet instance and
        returns the Fleet instance itself, which then exposes the optimizer
        functions together with the special features for distributed training.

        Args:
            optimizer(Optimizer): The optimizer to be used in distributed training.
            strategy(DistributedStrategy): Extra properties for distributed optimizer.
                It is recommended to use DistributedStrategy in fleet.init(). The strategy
                here is for compatibility. If the strategy in fleet.distributed_optimizer()
                is not None, then it will overwrite the DistributedStrategy in fleet.init(),
                which will take effect in distributed training.

        Returns:
            Fleet: instance of fleet.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.distributed.fleet as fleet
                fleet.init(is_collective=True)
                strategy = fleet.DistributedStrategy()
                optimizer = paddle.optimizer.SGD(learning_rate=0.001)
                optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)

        """
        self.user_defined_optimizer = optimizer

        if strategy is not None:
            warnings.warn(
                "It is recommended to use DistributedStrategy "
                "in fleet.init(). The strategy here is only for compatibility. "
                "If the strategy in fleet.distributed_optimizer() is "
                "not None, then it will overwrite the DistributedStrategy in fleet.init(), "
                "which will take effect in distributed training.")
            # Deep copy, so later edits to the caller's object do not leak in.
            self._user_defined_strategy = copy.deepcopy(strategy)

        # Drop any state cached for a previously registered optimizer.
        self._context = {}
        return self

    @dygraph_only
    def distributed_model(self, model):
        """
        Return distributed data parallel model (Only work in dygraph mode)

        The model is wrapped in ``paddle.DataParallel``; the communication
        buffer sizes are taken from the strategy stored on this Fleet
        instance (``fuse_grad_size_in_MB`` and ``last_comm_group_size_MB``).
        The wrapped model is also kept as ``self.model``.

        Args:
            model (Layer): the user-defined model which inherits Layer.

        Returns:
            distributed data parallel model which inherits Layer.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                # 1. initialize fleet environment
                fleet.init(is_collective=True)

                # 2. create layer & optimizer
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                # 3. get data_parallel model using fleet
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                # 4. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()


        """
        assert model is not None
        strategy = self._user_defined_strategy
        self.model = paddle.DataParallel(
            model,
            comm_buffer_size=strategy.fuse_grad_size_in_MB,
            last_comm_buffer_size=strategy.last_comm_group_size_MB)
        return self.model

    @dygraph_only
    def state_dict(self):
        """
        Get state dict information from optimizer.
        (Only work in dygraph mode)

        The call is delegated to the optimizer registered through
        ``distributed_optimizer()``.

        Returns:
            state_dict(dict) : dict contains all the Tensor used by optimizer

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)
                state_dict = adam.state_dict()
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.state_dict()

    @dygraph_only
    def set_state_dict(self, state_dict):
        """
        Load optimizer state dict.
        (Only work in dygraph mode)

        The call is delegated to the optimizer registered through
        ``distributed_optimizer()``.

        Args:
            state_dict(dict) : Dict contains all the Tensor needed by optimizer

        Returns:
            None

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)
                state_dict = adam.state_dict()
                paddle.save(state_dict, "paddle_dy")
                para_state_dict = paddle.load("paddle_dy")
                adam.set_state_dict(para_state_dict)
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.set_state_dict(state_dict)

    @dygraph_only
    def set_lr(self, value):
        """
        Set the value of the learning rate manually in the optimizer.
        (Only work in dygraph mode)

        The call is delegated to the optimizer registered through
        ``distributed_optimizer()``.

        Args:
            value (float|Tensor): the value of learning rate

        Returns:
            None

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                lr_list = [0.2, 0.3, 0.4, 0.5, 0.6]
                for i in range(5):
                    adam.set_lr(lr_list[i])
                    lr = adam.get_lr()
                    print("current lr is {}".format(lr))
                # Print:
                #    current lr is 0.2
                #    current lr is 0.3
                #    current lr is 0.4
                #    current lr is 0.5
                #    current lr is 0.6
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.set_lr(value)

    @dygraph_only
    def get_lr(self):
        """
        Get current step learning rate.
        (Only work in dygraph mode)

        The call is delegated to the optimizer registered through
        ``distributed_optimizer()``.

        Returns:
            float: The learning rate of the current step.

        Examples:

            .. code-block:: python

                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                lr = adam.get_lr()
                print(lr) # 0.01
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.get_lr()

    @dygraph_only
    def step(self):
        """
        Execute the optimizer once.
        (Only work in dygraph mode)

        The call is delegated to the optimizer registered through
        ``distributed_optimizer()``.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                # 1. initialize fleet environment
                fleet.init(is_collective=True)

                # 2. create layer & optimizer
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                # 3. get data_parallel model using fleet
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                # 4. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()


        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.step()

    @dygraph_only
    def clear_grad(self):
        """
        Clear the gradients of all optimized parameters for model.
        (Only work in dygraph mode)

        The call is delegated to the optimizer registered through
        ``distributed_optimizer()``.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                # 1. initialize fleet environment
                fleet.init(is_collective=True)

                # 2. create layer & optimizer
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                # 3. get data_parallel model using fleet
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                # 4. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()

        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.clear_grad()

    def _final_strategy(self):
        """
        Return the entry stored under ``"valid_strategy"`` in ``self._context``.

        When the key is absent a warning is printed to stdout and an empty
        dict is returned instead.
        """
        ctx = self._context
        if "valid_strategy" in ctx:
            return ctx["valid_strategy"]

        # nothing recorded yet: warn and fall back to an empty mapping
        print(
            "WARNING: You may need to call minimize function before this function is called"
        )
        return {}

970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987
    def _get_applied_meta_list(self):
        """
        Return the entry stored under ``"applied_meta_list"`` in ``self._context``.

        When the key is absent a warning is printed to stdout and an empty
        list is returned instead.
        """
        ctx = self._context
        if "applied_meta_list" in ctx:
            return ctx["applied_meta_list"]

        # nothing recorded yet: warn and fall back to an empty list
        print(
            "WARNING: You may need to call minimize function before _get_applied_meta_list called"
        )
        return []

    def _get_applied_graph_list(self):
        """
        Return the entry stored under ``"applied_graph_list"`` in ``self._context``.

        When the key is absent a warning is printed to stdout and an empty
        list is returned instead.
        """
        ctx = self._context
        if "applied_graph_list" in ctx:
            return ctx["applied_graph_list"]

        # nothing recorded yet: warn and fall back to an empty list
        print(
            "WARNING: You may need to call minimize function before _get_applied_graph_list called"
        )
        return []

988 989 990 991 992 993 994 995 996
    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
        """
        Add distributed operations to minimize ``loss`` by updating ``parameter_list``.

        In dygraph mode the call is forwarded to the user defined optimizer with
        ``loss`` only; ``startup_program``, ``parameter_list`` and ``no_grad_set``
        are not used on that path. Otherwise the applicable meta optimizers are
        selected and combined, the resulting strategy and the applied optimizer
        lists are recorded in ``self._context``, and the combined optimizers are
        run on ``loss``.

        Args:
            loss (Tensor): A ``Tensor`` containing the value to minimize.
            startup_program (Program, optional): :ref:`api_fluid_Program` for
                initializing parameters in ``parameter_list``. The default value
                is None, at this time :ref:`api_fluid_default_startup_program` will be used.
            parameter_list (Iterable, optional): Iterable of ``Tensor`` or ``Tensor.name`` to update
                to minimize ``loss``. The default value is None, at this time all parameters
                will be updated.
            no_grad_set (set, optional): Set of ``Tensor``  or ``Tensor.name`` that don't need
                to be updated. The default value is None.

        Returns:
            tuple: tuple (optimize_ops, params_grads), A list of operators appended
            by minimize and a list of (param, grad) tensor pairs, param is
            ``Parameter``, grad is the gradient value corresponding to the parameter.
            The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
            indicate program pruning. If so, the program will be pruned by ``feed`` and
            ``fetch_list`` before run, see details in ``Executor``.

            In dygraph mode, and when the role maker reports a non-distributed
            run while fleet is not collective, the return value is whatever the
            user defined optimizer's ``minimize`` returns.

        Examples:

            .. code-block:: python

                import paddle
                paddle.enable_static()
                import paddle.distributed.fleet as fleet
                import paddle.nn.functional as F

                hid_dim = 10
                label_dim = 2
                input_x = paddle.static.data(name='x', shape=[None, 13], dtype='float32')
                input_y = paddle.static.data(name='y', shape=[None, 1], dtype='int64')
                fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim, activation='tanh')
                fc_2 = paddle.static.nn.fc(x=fc_1, size=hid_dim, activation='tanh')
                prediction = paddle.static.nn.fc(x=[fc_2], size=label_dim, activation='softmax')
                cost = F.cross_entropy(input=prediction, label=input_y)
                avg_cost = paddle.mean(x=cost)

                fleet.init(is_collective=True)
                strategy = fleet.DistributedStrategy()
                optimizer = paddle.optimizer.SGD(learning_rate=0.001)
                optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
                optimizer.minimize(avg_cost)

                # for more examples, please reference https://github.com/PaddlePaddle/FleetX

        """
        context = {}
        context["user_defined_strategy"] = copy.deepcopy(
            self._user_defined_strategy)
        if paddle.fluid.framework.in_dygraph_mode():
            # dygraph: publish the minimal context and delegate directly to
            # the user defined optimizer
            target_opt = self.user_defined_optimizer
            self._context = context
            return target_opt.minimize(loss)

        # cache original feed forward program
        self.origin_main_program = loss.block.program
        context["origin_main_program"] = self.origin_main_program
        context["loss"] = loss
        if startup_program is None:
            self.origin_startup_program = \
                paddle.static.default_startup_program().clone(for_test=False)
            startup_program = paddle.static.default_startup_program()
        else:
            self.origin_startup_program = \
                startup_program.clone(for_test=False)

        # the context keeps the live startup program; the clone taken above
        # is kept on self.origin_startup_program
        context["origin_startup_program"] = startup_program
        context["role_maker"] = self._role_maker

        # compile time
        distributed_optimizer_list = \
            MetaOptimizerFactory()._get_valid_meta_optimizers(
                self.user_defined_optimizer)

        # snapshot the strategy again: one copy is stored in the context and a
        # separate working copy is handed to the meta optimizers below
        context["user_defined_strategy"] = copy.deepcopy(
            self._user_defined_strategy)
        copy_user_defined_strategy = copy.deepcopy(self._user_defined_strategy)

        # trigger the auto-parallel in very strict condition
        # strategy = DistributedStrategy()
        # strategy.auto = True
        # optimizer = paddle.optimizer.SGD(learning_rate=0.1)
        # optimizer = fleet.distributed_optimizer(optimizer, strategy)
        if copy_user_defined_strategy._is_strict_auto():
            # turn on all the strategy for each optimizer
            for opt in distributed_optimizer_list:
                opt._enable_strategy(copy_user_defined_strategy, context)

        valid_optimizer_list = []
        valid_graph_optimizer_list = []
        can_not_apply_optimizer_list = []
        # recall meta optimizers for ranking
        for opt in distributed_optimizer_list:
            opt._set_basic_info(loss, self._role_maker,
                                self.user_defined_optimizer,
                                copy_user_defined_strategy)
            # _can_apply() is evaluated once per optimizer; _is_graph_out()
            # is only consulted for optimizers that can be applied
            if not opt._can_apply():
                can_not_apply_optimizer_list.append(opt)
            elif opt._is_graph_out():
                valid_graph_optimizer_list.append(opt)
            else:
                valid_optimizer_list.append(opt)
        # combine recalled meta optimizers to be a valid meta optimizer
        meta_optimizer, graph_optimizer = \
            self.strategy_compiler.generate_optimizer(
                loss, self._role_maker, self.user_defined_optimizer,
                copy_user_defined_strategy, valid_optimizer_list,
                valid_graph_optimizer_list)

        valid_strategy = self.strategy_compiler._get_valid_strategy(
            copy_user_defined_strategy, can_not_apply_optimizer_list)

        context["valid_strategy"] = copy.deepcopy(valid_strategy)

        applied_meta_list = self.strategy_compiler._get_applied_meta_list()
        applied_graph_list = self.strategy_compiler._get_applied_graph_list()

        context['applied_meta_list'] = applied_meta_list
        context['applied_graph_list'] = applied_graph_list

        self._context = context

        self.valid_strategy = valid_strategy
        self.valid_strategy._enable_env()

        if self._role_maker._is_non_distributed() and not self._is_collective:
            if self._runtime_handle is None:
                self._runtime_handle = RuntimeFactory()._create_runtime(context)

            compiled_program = compiler.CompiledProgram(
                self.origin_main_program).with_data_parallel(
                    loss_name=loss.name, share_vars_from=None)
            loss.block.program._graph = compiled_program
            return self.user_defined_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)

        if meta_optimizer:
            optimize_ops, params_grads = meta_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)

            default_program = paddle.static.default_main_program()

            # make the program that owns ``loss`` the default main program
            # if it is not already
            if default_program is not loss.block.program:
                paddle.fluid.framework.switch_main_program(loss.block.program)

        else:
            optimize_ops, params_grads = self.user_defined_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)

        context["program_optimize_ops"] = optimize_ops
        context["program_params_grads"] = params_grads

        if graph_optimizer:
            optimize_ops, params_grads = graph_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)
            # since we do not encourage users to use graph operations
            # if a graph optimizer takes effect, mostly
            # optimizers_ops and params_grads are None
            # i.e. users can not modify current computation graph anymore
            context["graph_optimize_ops"] = optimize_ops
            context["graph_optimize_grads"] = params_grads

        if self._runtime_handle is None:
            self._runtime_handle = RuntimeFactory()._create_runtime(context)

        # imported here rather than at module level
        # NOTE(review): presumably to avoid a circular import - confirm
        import paddle.distributed.fleet as fleet
        fleet.util._set_strategy(context["valid_strategy"])

        return optimize_ops, params_grads