fleet_base.py 34.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
16
import copy
17
import warnings
18
import paddle
19
from paddle.fluid.framework import dygraph_only
20
from paddle.fluid import compiler
21
from .role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker, RoleMakerBase
22
from .strategy_compiler import StrategyCompiler
23
from .distributed_strategy import DistributedStrategy
24 25
from .meta_optimizer_factory import MetaOptimizerFactory
from .runtime_factory import RuntimeFactory
26
from paddle.fluid.wrapped_decorator import wrap_decorator
27
from paddle.fluid.dygraph import parallel_helper
28

29

30 31 32 33 34 35 36 37 38 39 40 41
def _inited_runtime_handler_(func):
    """Build a wrapper that refuses to run ``func`` without a runtime handler.

    The wrapper treats its first positional argument as the Fleet instance.
    If that instance's ``_runtime_handle`` is still ``None`` it raises
    ``ValueError``; otherwise every argument is forwarded to ``func`` and its
    result is returned unchanged.
    """

    def __impl__(*args, **kwargs):
        fleet_obj = args[0]
        if fleet_obj._runtime_handle is not None:
            return func(*args, **kwargs)
        raise ValueError("Fleet can not find suitable runtime handler")

    return __impl__


def _is_non_distributed_check_(func):
    """Build a wrapper that turns ``func`` into a no-op for non-distributed runs.

    The wrapper treats its first positional argument as the Fleet instance.
    When that instance has a role maker whose ``_is_non_distributed()``
    returns exactly ``True``, a warning naming ``func`` is emitted and
    ``None`` is returned without calling ``func``. In every other case the
    call is forwarded to ``func`` unchanged.
    """

    def __impl__(*args, **kwargs):
        role_maker = args[0]._role_maker
        # Identity check on purpose: only a literal ``True`` short-circuits.
        skip = (role_maker is not None and
                role_maker._is_non_distributed() is True)
        if not skip:
            return func(*args, **kwargs)
        message = ("%s() function doesn't work when use non_distributed fleet."
                   % (func.__name__))
        warnings.warn(message)
        return None

    return __impl__


# Decorator forms of the two checks above, built with paddle's
# ``wrap_decorator`` helper and applied to Fleet methods below.
inited_runtime_handler = wrap_decorator(_inited_runtime_handler_)
is_non_distributed_check = wrap_decorator(_is_non_distributed_check_)


class Fleet(object):
    """
    Unified API for distributed training of PaddlePaddle.

    Please refer to https://github.com/PaddlePaddle/FleetX for details.

    Returns:
        Fleet: A Fleet instance

    Example for collective training:

        .. code-block:: python

            import paddle
            import paddle.distributed.fleet as fleet

            fleet.init(is_collective=True)

            strategy = fleet.DistributedStrategy()
            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
            optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)

            # do distributed training


    Example for parameter server training:

        .. code-block:: python

            import paddle
            import paddle.distributed.fleet as fleet

            fleet.init()

            strategy = fleet.DistributedStrategy()
            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
            optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)

            if fleet.is_first_worker():
                print("this is first worker")

            print("current node index: {}".format(fleet.worker_index()))
            print("total number of workers: {}".format(fleet.worker_num()))

            if fleet.is_worker():
                print("this is worker")
            print("worker endpoints: {}".format(fleet.worker_endpoints(to_string=True)))

            print("server num: {}".format(fleet.server_num()))
            print("server endpoints: {}".format(fleet.server_endpoints(to_string=True)))

            if fleet.is_server():
                print("this is server")
            fleet.stop_worker()

    """

    def __init__(self):
118
        self._role_maker = None
119
        self.strategy_compiler = None
120
        self._is_collective = False
121
        self._runtime_handle = None
D
Dong Daxiang 已提交
122 123
        self._util = None
        self._context = {}
124

125 126 127 128
    def init(self, role_maker=None, is_collective=False):
        """
        Initialize role_maker in Fleet.

        This function sets up the distributed environment that your code
        runs on.

        Args:
            role_maker (RoleMakerBase, optional): A ``RoleMakerBase`` holding the
                configuration of environment variables related to distributed
                training. If it is not given, a ``PaddleCloudRoleMaker`` is
                created automatically. The default value is None.
            is_collective (bool, optional): Only used when ``role_maker`` is None.
                ``True`` selects collective training, ``False`` selects
                parameter-server training. The default value is False.

        Returns:
            None

        Raises:
            ValueError: If ``role_maker`` is None and ``is_collective`` is not a
                ``bool``, or if ``role_maker`` is not a ``RoleMakerBase``.

        Examples1:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

        Examples2:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init(is_collective=True)

        Examples3:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                role = fleet.PaddleCloudRoleMaker()
                fleet.init(role)

        """
        if role_maker is None:
            if not isinstance(is_collective, bool):
                raise ValueError(
                    "`is_collective` should be instance of `bool`, but got {}".
                    format(type(is_collective)))
            self._is_collective = is_collective
            self._role_maker = PaddleCloudRoleMaker(
                is_collective=self._is_collective)
        else:
            if not isinstance(role_maker, RoleMakerBase):
                raise ValueError(
                    "`role_maker` should be subclass of `RoleMakerBase`, but got {}".
                    format(type(role_maker)))
            self._role_maker = role_maker
        self._role_maker._generate_role()

        # Share the same role maker with the module-level ``fleet.util``.
        import paddle.distributed.fleet as fleet
        fleet.util._set_role_maker(self._role_maker)

        self.strategy_compiler = StrategyCompiler()
        if paddle.fluid.framework.in_dygraph_mode():
            # A single worker needs no parallel context.
            if self.worker_num() == 1:
                return
            if parallel_helper._is_parallel_ctx_initialized():
                warnings.warn(
                    "The dygraph parallel environment has been initialized.")
            else:
                paddle.distributed.init_parallel_env()

    def is_first_worker(self):
        """
        Check whether the node is the first instance of worker.

        Returns:
            bool: True if this is the first node of worker,
                  False if not.
205

206 207 208 209 210 211 212 213
        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_first_worker()

214
        """
215
        return self._role_maker._is_first_worker()
216 217 218 219 220 221 222

    def worker_index(self):
        """
        Get current worker index.

        Returns:
            int: node id
223 224 225 226 227 228 229 230

        Examples:

            .. code-block:: python
                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_index()

231
        """
232
        return self._role_maker._worker_index()
233 234 235 236 237 238 239

    def worker_num(self):
        """
        Get current total worker number.

        Returns:
            int: worker numbers
D
Dong Daxiang 已提交
240
        
241 242 243 244 245 246 247
        Examples:
            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_num()

248
        """
249
        return self._role_maker._worker_num()
250 251 252 253 254 255 256 257

    def is_worker(self):
        """
        Check whether the node is an instance of worker.

        Returns:
            bool: True if this is a node of worker,
                  False if not.
258 259 260 261 262 263 264 265

        Examples:
            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_worker()

266
        """
267
        return self._role_maker._is_worker()
268 269 270

    def worker_endpoints(self, to_string=False):
        """
271
        Get current worker endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].
272 273 274

        Returns:
            list/string: server endpoints
275 276 277 278 279 280 281 282

        Examples:
            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_endpoints()

283 284
        """
        if to_string:
285
            return ",".join(self._role_maker._get_trainer_endpoints())
286
        else:
287
            return self._role_maker._get_trainer_endpoints()
288 289 290 291 292 293 294

    def server_num(self):
        """
        Get current total worker number.

        Returns:
            int: server number
295 296 297 298 299 300

        Examples:
            .. code-block:: python
            import paddle.distributed.fleet as fleet
            fleet.init()
            fleet.server_num()
301
        """
302
        return len(self._role_maker._get_pserver_endpoints())
303 304 305 306 307 308 309

    def server_index(self):
        """
        Get current server index.

        Returns:
            int: node id
310 311 312 313 314 315 316 317

        Examples:
            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_index()

318
        """
319
        return self._role_maker._server_index()
320 321 322 323 324 325 326

    def server_endpoints(self, to_string=False):
        """
        Get current server endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].

        Returns:
            list/string: server endpoints
327 328 329 330 331 332 333 334

        Examples:
            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_endpoints()

335
        """
336

337
        if to_string:
338
            return ",".join(self._role_maker._get_pserver_endpoints())
339
        else:
340
            return self._role_maker._get_pserver_endpoints()
341 342 343 344 345 346 347 348

    def is_server(self):
        """
        Check whether the node is an instance of server.

        Returns:
            bool: True if this is a node of server,
                  False if not.
349 350 351 352 353 354 355 356

        Examples:

            .. code-block:: python
                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_server()

357
        """
358
        return self._role_maker._is_server(
359
        ) or self._role_maker._is_heter_worker()
360 361 362

    def barrier_worker(self):
        """
363 364 365 366
        barrier all workers

        Returns:
            None
367
        """
368
        self._role_maker._barrier("worker")
369

370
    @is_non_distributed_check
    @inited_runtime_handler
    def init_worker(self):
        """
        Initialize the `Communicator` for parameter server training.

        The work is delegated to the runtime handler's ``_init_worker``.
        In non-distributed mode the call is skipped with a warning, and a
        ``ValueError`` is raised if no runtime handler has been set.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_worker()

        """
        runtime = self._runtime_handle
        runtime._init_worker()

    @is_non_distributed_check
    @inited_runtime_handler
    def init_server(self, *args, **kwargs):
        """
        Initialize the server side by running the startup program.

        All positional and keyword arguments are forwarded unchanged to the
        runtime handler's ``_init_server``. Per the original documentation, a
        non-empty ``args`` makes it load persistables for incremental
        training; that logic lives in the runtime handler, not here.
        In non-distributed mode the call is skipped with a warning, and a
        ``ValueError`` is raised if no runtime handler has been set.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_server()

        """
        runtime = self._runtime_handle
        runtime._init_server(*args, **kwargs)

    @is_non_distributed_check
    @inited_runtime_handler
    def run_server(self):
        """
        Run the parameter server main program with an executor.

        The work is delegated to the runtime handler's ``_run_server``.
        In non-distributed mode the call is skipped with a warning, and a
        ``ValueError`` is raised if no runtime handler has been set.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                if fleet.is_server():
                    fleet.init_server()
                    fleet.run_server()

        """
        self._runtime_handle._run_server()

    @is_non_distributed_check
    @inited_runtime_handler
    def stop_worker(self):
        """
        Stop the `Communicator` and notify the parameter servers that training is complete.

        The work is delegated to the runtime handler's ``_stop_worker``.
        In non-distributed mode the call is skipped with a warning, and a
        ``ValueError`` is raised if no runtime handler has been set.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                if fleet.is_worker():
                    fleet.init_worker()
                    # train ...
                    fleet.stop_worker()

        """
        self._runtime_handle._stop_worker()

    def save_inference_model(self,
                             executor,
                             dirname,
                             feeded_var_names,
                             target_vars,
                             main_program=None,
                             export_for_deployment=True):
477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496
        """
        save inference model for inference.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_server()

        """

497 498 499 500 501
        self._runtime_handle._save_inference_model(
            executor, dirname, feeded_var_names, target_vars, main_program,
            export_for_deployment)

    def save_persistables(self, executor, dirname, main_program=None):
502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541
        """

        saves all persistable variables from :code:`main_program` to
        the folder :code:`dirname`. You can refer to

        The :code:`dirname` is used to specify the folder where persistable variables
        are going to be saved. If you would like to save variables in separate
        files, set :code:`filename` None.

        Args:
            executor(Executor): The executor to run for saving persistable variables.
                                You can refer to :ref:`api_guide_executor_en` for
                                more details.

            dirname(str, optional): The saving directory path.
                                When you need to save the parameter to the memory, set it to None.
            main_program(Program, optional): The program whose persistbale variables will
                                             be saved. Default: None.


        Returns:
            None

        Examples:

            .. code-block:: text

                import paddle.distributed.fleet as fleet
                import paddle.fluid as fluid

                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                exe = fluid.Executor(fluid.CPUPlace())
                fleet.save_persistables(exe, "dirname", fluid.default_main_program())

        """

542 543
        self._runtime_handle._save_persistables(executor, dirname, main_program)

544
    def distributed_optimizer(self, optimizer, strategy=None):
        """
        Optimizer for distributed training.

        For distributed training, this method stores the user's optimizer and
        returns this Fleet instance, which exposes the basic Optimizer
        functions plus the features needed for distributed training.

        Args:
            optimizer(Optimizer): The optimizer to wrap for distributed training.
            strategy(DistributedStrategy, optional): Extra properties for the
                distributed optimizer. A default ``DistributedStrategy`` is used
                when None. Ignored in dygraph mode. Default: None.

        Returns:
            Fleet: this fleet instance.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.distributed.fleet as fleet
                role = fleet.role_maker.PaddleCloudRoleMaker(is_collective=True)
                fleet.init(role)
                strategy = fleet.DistributedStrategy()
                optimizer = paddle.optimizer.SGD(learning_rate=0.001)
                optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)

        """
        self.user_defined_optimizer = optimizer
        # In dygraph mode no strategy is recorded; calls such as step() and
        # clear_grad() are forwarded to ``user_defined_optimizer``.
        if paddle.fluid.framework.in_dygraph_mode():
            return self

        if strategy is None:
            strategy = DistributedStrategy()

        # Keep a private copy so later edits to the caller's strategy object
        # do not leak into this fleet.
        self._user_defined_strategy = copy.deepcopy(strategy)
        self._context = {}
        return self

    @dygraph_only
    def distributed_model(self, model):
        """
        Wrap ``model`` for data-parallel training (only works in dygraph mode).

        The model is wrapped in ``paddle.DataParallel``; the wrapper is also
        stored on ``self.model``.

        Args:
            model (Layer): the user-defined model, a subclass of ``Layer``.

        Returns:
            The ``paddle.DataParallel`` wrapper around ``model``.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                # 1. enable dynamic mode
                paddle.disable_static()

                # 2. initialize fleet environment
                fleet.init(is_collective=True)

                # 3. create layer & optimizer
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                # 4. get data_parallel model using fleet
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                # 5. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()

        """
        assert model is not None
        wrapped = paddle.DataParallel(model)
        self.model = wrapped
        return wrapped

    @dygraph_only
    def state_dict(self):
        """
        Get the state dict of the wrapped optimizer (only works in dygraph mode).

        Forwards to the ``state_dict()`` of the optimizer passed to
        ``distributed_optimizer``.

        Returns:
            state_dict(dict): dict containing all the Tensors used by the optimizer

        Examples:

            .. code-block:: python

                import paddle
                from paddle.distributed import fleet

                paddle.disable_static()
                fleet.init(is_collective=True)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)
                state_dict = adam.state_dict()

        """
        wrapped_opt = self.user_defined_optimizer
        return wrapped_opt.state_dict()

    @dygraph_only
    def set_state_dict(self, state_dict):
        """
        Load a state dict into the wrapped optimizer (only works in dygraph mode).

        Forwards to the ``set_state_dict()`` of the optimizer passed to
        ``distributed_optimizer`` and returns whatever it returns.

        Args:
            state_dict(dict): Dict containing all the Tensors needed by the optimizer

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                from paddle.distributed import fleet

                paddle.disable_static()
                fleet.init(is_collective=True)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)
                state_dict = adam.state_dict()
                paddle.framework.save(state_dict, "paddle_dy")
                para_state_dict, opti_state_dict = paddle.framework.load("paddle_dy")
                adam.set_state_dict(opti_state_dict)

        """
        wrapped_opt = self.user_defined_optimizer
        return wrapped_opt.set_state_dict(state_dict)

    @dygraph_only
    def set_lr(self, value):
        """
        Manually set the learning rate of the wrapped optimizer (only works in dygraph mode).

        Forwards to the ``set_lr()`` of the optimizer passed to
        ``distributed_optimizer``.

        Args:
            value (float|Tensor): the new learning rate

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                from paddle.distributed import fleet

                paddle.disable_static()
                fleet.init(is_collective=True)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                for lr in [0.2, 0.3, 0.4, 0.5, 0.6]:
                    adam.set_lr(lr)
                    print("current lr is {}".format(adam.get_lr()))
                # Print:
                #    current lr is 0.2
                #    current lr is 0.3
                #    current lr is 0.4
                #    current lr is 0.5
                #    current lr is 0.6

        """
        wrapped_opt = self.user_defined_optimizer
        return wrapped_opt.set_lr(value)

    @dygraph_only
    def get_lr(self):
        """
        Get the current step learning rate (only works in dygraph mode).

        Forwards to the ``get_lr()`` of the optimizer passed to
        ``distributed_optimizer``.

        Returns:
            float: The learning rate of the current step.

        Examples:

            .. code-block:: python

                import paddle
                from paddle.distributed import fleet

                paddle.disable_static()
                fleet.init(is_collective=True)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                lr = adam.get_lr()
                print(lr) # 0.01

        """
        wrapped_opt = self.user_defined_optimizer
        return wrapped_opt.get_lr()

    @dygraph_only
    def step(self):
        """
        Run one update step of the wrapped optimizer (only works in dygraph mode).

        Forwards to the ``step()`` of the optimizer passed to
        ``distributed_optimizer``.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                paddle.disable_static()
                fleet.init(is_collective=True)

                layer = nn.Linear(10, 1)
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                outputs = dp_layer(paddle.randn([10, 10], 'float32'))
                loss = loss_fn(outputs, paddle.randn([10, 1], 'float32'))
                loss.backward()

                adam.step()
                adam.clear_grad()

        """
        wrapped_opt = self.user_defined_optimizer
        return wrapped_opt.step()

    @dygraph_only
    def clear_grad(self):
        """
        Clear the gradients of all parameters optimized by the wrapped optimizer (only works in dygraph mode).

        Forwards to the ``clear_grad()`` of the optimizer passed to
        ``distributed_optimizer``.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                paddle.disable_static()
                fleet.init(is_collective=True)

                layer = nn.Linear(10, 1)
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                outputs = dp_layer(paddle.randn([10, 10], 'float32'))
                loss = loss_fn(outputs, paddle.randn([10, 1], 'float32'))
                loss.backward()

                adam.step()
                adam.clear_grad()

        """
        wrapped_opt = self.user_defined_optimizer
        return wrapped_opt.clear_grad()

    def _final_strategy(self):
        if "valid_strategy" not in self._context:
            print(
                "WARNING: You may need to call minimize function before this function is called"
            )
            return {}
        else:
            return self._context["valid_strategy"]

920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942
    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
        """
        Add distributed operations to minimize ``loss`` by updating ``parameter_list``.

        Args:
            loss (Variable): A ``Variable`` containing the value to minimize.
            startup_program (Program, optional): :ref:`api_fluid_Program` for
                initializing parameters in ``parameter_list``. The default value
                is None, at this time :ref:`api_fluid_default_startup_program` will be used.
            parameter_list (Iterable, optional): Iterable of ``Variable`` or ``Variable.name`` to update
                to minimize ``loss``. The default value is None, at this time all parameters
                will be updated.
            no_grad_set (set, optional): Set of ``Variable``  or ``Variable.name`` that don't need
                to be updated. The default value is None.

        Returns:
            tuple: tuple (optimize_ops, params_grads), A list of operators appended
            by minimize and a list of (param, grad) variable pairs, param is
            ``Parameter``, grad is the gradient value corresponding to the parameter.
            The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
            indicate program pruning. If so, the program will be pruned by ``feed`` and
            ``fetch_list`` before run, see details in ``Executor``.

            In dygraph mode, and in the non-distributed, non-collective case,
            the result of the wrapped user optimizer's ``minimize`` is
            returned unchanged.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.distributed.fleet as fleet

                paddle.enable_static()

                hid_dim = 10
                label_dim = 2
                input_x = paddle.static.data(name='x', shape=[None, 13], dtype='float32')
                input_y = paddle.static.data(name='y', shape=[None, 1], dtype='int64')

                fc_1 = paddle.fluid.layers.fc(input=input_x, size=hid_dim, act='tanh')
                fc_2 = paddle.fluid.layers.fc(input=fc_1, size=hid_dim, act='tanh')
                prediction = paddle.fluid.layers.fc(input=[fc_2], size=label_dim, act='softmax')
                cost = paddle.fluid.layers.cross_entropy(input=prediction, label=input_y)
                avg_cost = paddle.fluid.layers.mean(x=cost)

                role = fleet.role_maker.PaddleCloudRoleMaker(is_collective=True)
                fleet.init(role)
                strategy = fleet.DistributedStrategy()
                optimizer = paddle.optimizer.SGD(learning_rate=0.001)
                optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
                optimizer.minimize(avg_cost)

                # for more examples, please reference https://github.com/PaddlePaddle/FleetX

        """
        context = {}
        # Snapshot of the user's strategy; kept in the context for later inspection.
        context["user_defined_strategy"] = copy.deepcopy(
            self._user_defined_strategy)

        if paddle.fluid.framework.in_dygraph_mode():
            # Dygraph: no program rewriting happens here; delegate straight to
            # the user's optimizer. Note that only ``loss`` is forwarded.
            # imitate target optimizer retrieval
            target_opt = self.user_defined_optimizer
            self._context = context
            return target_opt.minimize(loss)

        # cache original feed forward program
        self.origin_main_program = loss.block.program
        context["origin_main_program"] = self.origin_main_program
        context["loss"] = loss

        if startup_program is None:
            # Keep an untouched clone of the startup program before any
            # optimizer appends to it.
            self.origin_startup_program = \
                paddle.static.default_startup_program().clone(for_test=False)
            startup_program = paddle.static.default_startup_program()
        else:
            self.origin_startup_program = \
                startup_program.clone(for_test=False)

        context["origin_startup_program"] = startup_program
        context["role_maker"] = self._role_maker

        # compile time
        distributed_optimizer_list = \
            MetaOptimizerFactory()._get_valid_meta_optimizers(
                self.user_defined_optimizer)

        # Meta optimizers receive their own deep copy, so whatever they do to
        # the strategy never touches the user's object or the context snapshot.
        copy_user_defined_strategy = copy.deepcopy(self._user_defined_strategy)

        # trigger the auto-parallel in very strict condition
        # strategy = DistributedStrategy()
        # strategy.auto = True
        # optimizer = paddle.optimizer.SGD(learning_rate=0.1)
        # optimizer = fleet.distributed_optimizer(optimizer, strategy)
        if copy_user_defined_strategy._is_strict_auto():
            # turn on all the strategy for each optimizer
            for opt in distributed_optimizer_list:
                opt._enable_strategy(copy_user_defined_strategy, context)

        valid_optimizer_list = []
        valid_graph_optimizer_list = []
        can_not_apply_optimizer_list = []
        # recall meta optimizers for ranking
        for opt in distributed_optimizer_list:
            opt._set_basic_info(loss, self._role_maker,
                                self.user_defined_optimizer,
                                copy_user_defined_strategy)
            # Classify each optimizer: not applicable, applicable and producing
            # a graph, or applicable and producing a program.
            if not opt._can_apply():
                can_not_apply_optimizer_list.append(opt)
            elif opt._is_graph_out():
                valid_graph_optimizer_list.append(opt)
            else:
                valid_optimizer_list.append(opt)

        # combine recalled meta optimizers to be a valid meta optimizer
        meta_optimizer, graph_optimizer = \
            self.strategy_compiler.generate_optimizer(
                loss, self._role_maker, self.user_defined_optimizer,
                copy_user_defined_strategy, valid_optimizer_list,
                valid_graph_optimizer_list)

        valid_strategy = self.strategy_compiler._get_valid_strategy(
            copy_user_defined_strategy, can_not_apply_optimizer_list)

        context["valid_strategy"] = copy.deepcopy(valid_strategy)

        self._context = context

        self.valid_strategy = valid_strategy
        self.valid_strategy._enable_env()

        if self._role_maker._is_non_distributed() and not self._is_collective:
            # Non-distributed, non-collective: compile the original program
            # with data parallel and let the user's optimizer do the rest.
            if self._runtime_handle is None:
                self._runtime_handle = RuntimeFactory()._create_runtime(context)

            compiled_program = compiler.CompiledProgram(
                self.origin_main_program).with_data_parallel(
                    loss_name=loss.name, share_vars_from=None)
            loss.block.program._graph = compiled_program
            return self.user_defined_optimizer.minimize(
                loss,
                startup_program=startup_program,
                parameter_list=parameter_list,
                no_grad_set=no_grad_set)

        if meta_optimizer:
            optimize_ops, params_grads = meta_optimizer.minimize(
                loss,
                startup_program=startup_program,
                parameter_list=parameter_list,
                no_grad_set=no_grad_set)

            # Ensure loss's program is the default main program afterwards.
            default_program = paddle.static.default_main_program()
            if default_program is not loss.block.program:
                paddle.fluid.framework.switch_main_program(loss.block.program)
        else:
            optimize_ops, params_grads = self.user_defined_optimizer.minimize(
                loss,
                startup_program=startup_program,
                parameter_list=parameter_list,
                no_grad_set=no_grad_set)

        context["program_optimize_ops"] = optimize_ops
        context["program_params_grads"] = params_grads

        if graph_optimizer:
            optimize_ops, params_grads = graph_optimizer.minimize(
                loss,
                startup_program=startup_program,
                parameter_list=parameter_list,
                no_grad_set=no_grad_set)
            # since we do not encourage users to use graph operations
            # if a graph optimizer takes effect, mostly
            # optimizers_ops and params_grads are None
            # i.e. users can not modify current computation graph anymore
            context["graph_optimize_ops"] = optimize_ops
            context["graph_optimize_grads"] = params_grads

        if self._runtime_handle is None:
            self._runtime_handle = RuntimeFactory()._create_runtime(context)

        # Function-scope import; presumably avoids a circular import with the
        # package that defines this module — TODO confirm.
        import paddle.distributed.fleet as fleet
        fleet.util._set_strategy(context["valid_strategy"])

        return optimize_ops, params_grads