fleet_base.py 53.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
16
import copy
17
import warnings
18
import paddle
19
import os
20
import numpy as np
21
from paddle.fluid.framework import dygraph_only, _global_flags
22
from paddle.fluid import compiler
23
from .role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker, RoleMakerBase
24
from .strategy_compiler import StrategyCompiler
25
from .distributed_strategy import DistributedStrategy
26 27
from .meta_optimizer_factory import MetaOptimizerFactory
from .runtime_factory import RuntimeFactory
28
from paddle.fluid.wrapped_decorator import wrap_decorator
29
from paddle.fluid.dygraph import parallel_helper
30
from paddle.fluid.ir import apply_build_strategy
31
from . import topology as tp
32
from .topology import ParallelMode
33
from ..meta_parallel import TensorParallel, model_parallel_random_seed
J
JZ-LIANG 已提交
34
from ..meta_parallel import PipelineParallel, ShardingParallel
35
from ..meta_optimizers import HybridParallelOptimizer
36
from ..meta_optimizers import HybridParallelGradScaler
37

38 39
__all__ = []

40

41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
def apply_ir_passes(main_program, startup_program, config):
    build_strategy = config._user_defined_strategy.build_strategy._copy()
    if not _global_flags()['FLAGS_apply_pass_to_program']:
        return build_strategy

    pipeline_opt = getattr(main_program, "_pipeline_opt", {})
    if pipeline_opt:
        main_program = pipeline_opt["section_program"]
        startup_program = startup_program._pipeline_opt["startup_program"]

    pass_attrs = {"use_cuda": config._is_collective}
    fuse_all_reduce = config._user_defined_strategy.fuse_all_reduce_ops
    if fuse_all_reduce and build_strategy.fuse_all_optimizer_ops:
        # FIXME(zjl): currently, fuse_all_optimizer_ops
        # have conflict with fuse_all_reduce_ops because 
        # RawProgramOptimizer also inserts coalesce_tensor 
        # into program. These two procedures may conflict  
        # in which vars are to be fused. 
        warnings.warn(
            'Currently, the fuse_all_optimizer_ops pass has conflict with fuse_all_reduce_ops pass. Disable the fuse_all_optimizer_ops pass temporarily.'
        )
        build_strategy.fuse_all_optimizer_ops = False

    return apply_build_strategy(main_program, startup_program, build_strategy,
                                pass_attrs)


68 69 70 71 72 73 74 75 76 77 78 79
def _inited_runtime_handler_(func):
    def __impl__(*args, **kwargs):
        cls = args[0]

        if cls._runtime_handle is None:
            raise ValueError("Fleet can not find suitable runtime handler")

        return func(*args, **kwargs)

    return __impl__


80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
def _is_non_distributed_check_(func):
    def __impl__(*args, **kwargs):
        cls = args[0]

        if cls._role_maker is not None and cls._role_maker._is_non_distributed(
        ) is True:
            warnings.warn(
                "%s() function doesn't work when use non_distributed fleet." %
                (func.__name__))
            return

        return func(*args, **kwargs)

    return __impl__


96
inited_runtime_handler = wrap_decorator(_inited_runtime_handler_)
97
is_non_distributed_check = wrap_decorator(_is_non_distributed_check_)
98 99


100 101 102
class Fleet(object):
    """
    Unified API for distributed training of PaddlePaddle
103
    Please reference the https://github.com/PaddlePaddle/FleetX for details
104 105 106 107 108


    Returns:
        Fleet: A Fleet instance

109
    Example for collective training:
1
123malin 已提交
110

111 112
        .. code-block:: python

1
123malin 已提交
113 114
            import paddle
            paddle.enable_static()
115
            import paddle.distributed.fleet as fleet
116 117 118

            fleet.init(is_collective=True)

119 120 121
            strategy = fleet.DistributedStrategy()
            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
            optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
122 123 124 125 126 127 128 129

            # do distributed training


    Example for parameter server training:

        .. code-block:: python

1
123malin 已提交
130 131
            import paddle
            paddle.enable_static()
132 133
            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
S
ShenLiang 已提交
134
            fleet.init(strategy=strategy)
135

136
            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
137
            optimizer = fleet.distributed_optimizer(optimizer)
138

139 140
            if fleet.is_first_worker():
                print("this is first worker")
141

142 143
            print("current node index: {}".format(fleet.worker_index()))
            print("total number of worker num: {}".format(fleet.worker_num()))
144

145 146 147
            if fleet.is_worker():
                print("this is worker")
            print("worker endpoints: {}".format(fleet.worker_endpoints(to_string=True)))
148

149 150
            print("server num: {}".format(fleet.server_num()))
            print("server endpoints: {}".format(fleet.server_endpoints(to_string=True)))
151

152 153 154
            if fleet.is_server():
                print("this is server")
            fleet.stop_worker()
155 156


157 158 159
    """

    def __init__(self):
160
        self._role_maker = None
161
        self.strategy_compiler = None
162
        self._is_collective = False
163
        self._runtime_handle = None
D
Dong Daxiang 已提交
164 165
        self._util = None
        self._context = {}
166

167
    def init(self, role_maker=None, is_collective=False, strategy=None):
168 169 170
        """
        Initialize role_maker in Fleet.

171 172 173 174 175 176 177 178 179 180 181
        This function is responsible for the distributed architecture
        what you want to run your code behind.

        Args:
            role_maker (RoleMakerBase, optional): A ``RoleMakerBase`` containing the configuration
                of environment variables related to distributed training.If you did not initialize 
                the rolemaker by yourself, it will be automatically initialized to PaddleRoleMaker.
                The default value is None.
            is_collective (Boolean, optional): A ``Boolean`` variable determines whether the program 
                runs on the CPU or GPU. False means set distributed training using CPU, and True means
                GPU.The default value is False.The default value is False.
182 183 184 185
            strategy (DistributedStrategy): Extra properties for distributed training. 
                For details, please refer to paddle.distributed.fleet.DistributedStrategy. Default: None.


186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207
        Returns:
            None

        Examples1:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

        Examples2:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init(is_collective=True)

        Examples3:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
1
123malin 已提交
208
                role = fleet.PaddleCloudRoleMaker()
209
                fleet.init(role)
210

211 212 213 214 215 216
        Examples4:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                strategy = fleet.DistributedStrategy()
S
ShenLiang 已提交
217
                fleet.init(strategy=strategy)
218

219
        """
S
ShenLiang 已提交
220 221 222
        if strategy is None:
            strategy = DistributedStrategy()
        self._user_defined_strategy = copy.deepcopy(strategy)
223 224

        if role_maker is None:
225 226 227 228 229 230
            if isinstance(is_collective, bool):
                self._is_collective = is_collective
                self._role_maker = PaddleCloudRoleMaker(
                    is_collective=self._is_collective)
            else:
                raise ValueError(
231 232
                    "`is_collective` should be instance of `bool`, but got {}".
                    format(type(is_collective)))
233
        else:
234 235
            if isinstance(role_maker, RoleMakerBase):
                self._role_maker = role_maker
236
                self._is_collective = role_maker._is_collective
237 238 239 240
            else:
                raise ValueError(
                    "`role_maker` should be subclass of `RoleMakerBase`, but got {}".
                    format(type(role_maker)))
241
        self._role_maker._generate_role()
242

243 244 245
        import paddle.distributed.fleet as fleet
        fleet.util._set_role_maker(self._role_maker)

246
        self.strategy_compiler = StrategyCompiler()
247 248 249 250 251 252 253 254 255

        if self._role_maker._is_non_distributed() and self._is_collective:
            if paddle.fluid.core.is_compiled_with_cuda():
                gpus_num = paddle.fluid.core.get_cuda_device_count()
                if gpus_num != 1:
                    raise ValueError(
                        "CUDA_VISIBLE_DEVICES shoule be set only 1 card if you use `python` to launch fleet program."
                    )

256
        if paddle.fluid.framework.in_dygraph_mode():
257
            if self.worker_num() == 1:
258 259 260
                # if worker_num is 1, should construct default topology & hcg
                self._topology = tp.CommunicateTopology()
                self._hcg = tp.HybridCommunicateGroup(self._topology)
261
                return
262 263 264 265
            if parallel_helper._is_parallel_ctx_initialized():
                warnings.warn(
                    "The dygraph parallel environment has been initialized.")
            else:
266 267 268 269 270 271 272 273 274
                # FLAGS_nccl_nrings is used for dynamic graph multi-stream communication
                if "FLAGS_nccl_nrings" in os.environ:
                    warnings.warn(
                        "You have set the environment variable FLAGS_nccl_nrings "
                        "outside the program, so the nccl_comm_num in "
                        "DistributedStrategy will not take effect here.")
                else:
                    os.environ["FLAGS_nccl_nrings"] = str(
                        self._user_defined_strategy.nccl_comm_num)
275
                paddle.distributed.init_parallel_env()
276

277 278 279 280 281 282 283
            # init hybrid parallel environment in dygraph
            if tp._HYBRID_PARALLEL_GROUP is None:
                self._init_hybrid_parallel_env()
            else:
                warnings.warn(
                    "The dygraph hybrid parallel environment has been initialized."
                )
W
WangXi 已提交
284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299
        elif self._is_collective:
            use_sharding = self._user_defined_strategy.sharding

            # global group
            global_rank = self.worker_index()
            global_world_size = self.worker_num()
            # NOTE(wangxi): see sharding_optimizer
            global_ring_id = 3 if use_sharding else 0
            global_ranks = list(range(global_world_size))

            if tp._HYBRID_PARALLEL_GROUP is None: tp._CommunicateGroup()
            cg = tp._HYBRID_PARALLEL_GROUP
            self._hcg = cg
            cg.set_comm_group('global', global_rank, global_world_size,
                              global_ring_id, global_ranks)

Y
Yuang Liu 已提交
300 301 302
            use_tensor_parallel = self._user_defined_strategy.tensor_parallel
            use_mp = use_sharding or use_tensor_parallel

W
WangXi 已提交
303
            # hybrid group
Y
Yuang Liu 已提交
304 305 306 307 308 309 310 311 312 313 314 315 316 317 318
            if use_mp is False: return

            mp_degree_sharding = 1
            mp_degree_tensor_parallel = 1
            if use_sharding:
                sharding_configs = self._user_defined_strategy.sharding_configs
                mp_degree_sharding = int(sharding_configs['mp_degree'])

            if use_tensor_parallel:
                tensor_parallel_configs = self._user_defined_strategy.tensor_parallel_configs
                mp_degree_tensor_parallel = int(tensor_parallel_configs[
                    'tensor_parallel_degree'])

            if use_sharding and use_tensor_parallel:
                assert mp_degree_sharding == mp_degree_tensor_parallel
W
WangXi 已提交
319

Y
Yuang Liu 已提交
320
            mp_degree = mp_degree_sharding if use_sharding else mp_degree_tensor_parallel
W
WangXi 已提交
321 322 323 324 325 326 327 328 329 330 331 332 333

            if mp_degree > 1:
                assert global_world_size % mp_degree == 0
                # NOTE(wangxi): mp_ring_id sync with sharding_optimizer.py _build_groups
                mp_ring_id = 0
                mp_rank = global_rank % mp_degree
                mp_group_id = global_rank // mp_degree
                mp_group_ranks = [
                    idx for idx in global_ranks
                    if idx // mp_degree == mp_group_id
                ]
                cg.set_comm_group('model', mp_rank, mp_degree, mp_ring_id,
                                  mp_group_ranks)
334 335 336 337 338 339 340 341

    def _init_hybrid_parallel_env(self):
        """initialize the hybrid environment
        """
        self.hybrid_configs = self._user_defined_strategy.hybrid_configs
        self.dp_degree = self.hybrid_configs["dp_degree"]
        self.mp_degree = self.hybrid_configs["mp_degree"]
        self.pp_degree = self.hybrid_configs["pp_degree"]
J
JZ-LIANG 已提交
342
        self.sharding_degree = self.hybrid_configs["sharding_degree"]
343 344 345

        assert self.mp_degree >= 0, "mp_degree should be greater or equal to 0"
        assert self.pp_degree >= 0, "pp_degree should be greater or equal to 0"
J
JZ-LIANG 已提交
346
        assert self.sharding_degree >= 0, "sharding_degree should be greater or equal to 0"
347 348 349 350 351 352 353 354 355 356 357

        self.mp_degree = max(self.mp_degree, 1)
        self.pp_degree = max(self.pp_degree, 1)

        if self.dp_degree < 0:
            nranks = paddle.distributed.get_world_size()
            self.dp_degree = nranks // (self.mp_degree * self.pp_degree)

        self.dp_degree = max(self.dp_degree, 1)

        self._topology = tp.CommunicateTopology(
J
JZ-LIANG 已提交
358 359 360 361 362
            hybrid_group_names=["data", "pipe", "sharding", "model"],
            dims=[
                self.dp_degree, self.pp_degree, self.sharding_degree,
                self.mp_degree
            ])
363 364 365

        self._hcg = tp.HybridCommunicateGroup(self._topology)

366 367 368 369 370 371 372 373
        if self.mp_degree > 1:
            tensor_parallel_configs = self._user_defined_strategy.tensor_parallel_configs
            tensor_init_seed = tensor_parallel_configs["tensor_init_seed"]
            if tensor_init_seed == -1:
                model_parallel_random_seed()
            else:
                model_parallel_random_seed(tensor_init_seed)

374 375 376 377 378 379 380 381
    def get_hybrid_communicate_group(self):
        assert self._hcg is not None
        return self._hcg

    def get_hybrid_parallel_topology(self):
        assert self._topology is not None
        return self._topology

382 383 384 385 386 387 388
    def is_first_worker(self):
        """
        Check whether the node is the first instance of worker.

        Returns:
            bool: True if this is the first node of worker,
                  False if not.
389

390 391 392 393 394 395 396 397
        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_first_worker()

398
        """
399
        return self._role_maker._is_first_worker()
400 401 402 403 404 405 406

    def worker_index(self):
        """
        Get current worker index.

        Returns:
            int: node id
407 408 409 410

        Examples:

            .. code-block:: python
1
123malin 已提交
411

412 413 414 415
                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_index()

416
        """
417
        return self._role_maker._worker_index()
418 419 420 421 422 423 424

    def worker_num(self):
        """
        Get current total worker number.

        Returns:
            int: worker numbers
1
123malin 已提交
425

426
        Examples:
1
123malin 已提交
427

428 429 430 431 432 433
            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_num()

434
        """
435
        return self._role_maker._worker_num()
436

437 438 439 440 441 442 443 444 445 446 447 448
    def node_num(self):
        return self._role_maker._get_node_num()

    def local_rank(self):
        return self._role_maker._get_local_rank()

    def local_device_ids(self):
        return self._role_maker._get_local_device_ids()

    def world_device_ids(self):
        return self._role_maker._get_world_device_ids()

449 450 451 452 453 454 455
    def is_worker(self):
        """
        Check whether the node is an instance of worker.

        Returns:
            bool: True if this is a node of worker,
                  False if not.
456 457

        Examples:
1
123malin 已提交
458

459 460 461 462 463 464
            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_worker()

465
        """
466
        return self._role_maker._is_worker()
467 468 469

    def worker_endpoints(self, to_string=False):
        """
470
        Get current worker endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].
471 472 473

        Returns:
            list/string: server endpoints
474 475

        Examples:
1
123malin 已提交
476

477 478 479 480 481 482
            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_endpoints()

483 484
        """
        if to_string:
485
            return ",".join(self._role_maker._get_trainer_endpoints())
486
        else:
487
            return self._role_maker._get_trainer_endpoints()
488 489 490 491 492 493 494

    def server_num(self):
        """
        Get current total worker number.

        Returns:
            int: server number
495 496

        Examples:
1
123malin 已提交
497

498
            .. code-block:: python
1
123malin 已提交
499 500 501 502

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_num()
503
        """
504
        return len(self._role_maker._get_pserver_endpoints())
505 506 507 508 509 510 511

    def server_index(self):
        """
        Get current server index.

        Returns:
            int: node id
512 513

        Examples:
1
123malin 已提交
514

515 516 517 518 519 520
            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_index()

521
        """
522
        return self._role_maker._server_index()
523 524 525 526 527 528 529

    def server_endpoints(self, to_string=False):
        """
        Get current server endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].

        Returns:
            list/string: server endpoints
530 531

        Examples:
1
123malin 已提交
532

533 534 535 536 537 538
            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_endpoints()

539
        """
540

541
        if to_string:
542
            return ",".join(self._role_maker._get_pserver_endpoints())
543
        else:
544
            return self._role_maker._get_pserver_endpoints()
545 546 547 548 549 550 551 552

    def is_server(self):
        """
        Check whether the node is an instance of server.

        Returns:
            bool: True if this is a node of server,
                  False if not.
553 554 555 556

        Examples:

            .. code-block:: python
1
123malin 已提交
557

558 559 560 561
                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_server()

562
        """
563
        return self._role_maker._is_server(
564
        ) or self._role_maker._is_heter_worker()
565 566 567

    def barrier_worker(self):
        """
568 569 570 571
        barrier all workers

        Returns:
            None
572
        """
573
        self._role_maker._barrier("worker")
574

575
    @is_non_distributed_check
576
    @inited_runtime_handler
577 578
    def init_worker(self):
        """
579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596
        initialize `Communicator` for parameter server training.


        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_worker()

597 598 599
        """
        self._runtime_handle._init_worker()

600
    @is_non_distributed_check
601
    @inited_runtime_handler
602
    def init_server(self, *args, **kwargs):
603
        """
604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622
        init_server executor to initialize startup program,
        if the `args` is not empty, it will run load_persistables for increment training.


        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_server()

623
        """
624
        self._runtime_handle._init_server(*args, **kwargs)
625

T
Thunderbrook 已提交
626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648
    def load_model(self, path, mode):
        """
        load fleet model from path


        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.load_model("path", "mode")

        """
        self._runtime_handle.load_model(path, mode)

649
    @is_non_distributed_check
650
    @inited_runtime_handler
651 652
    def run_server(self):
        """
653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670
        run server will run pserver main program with executor.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                if fleet.is_server():
                    fleet.init_server()

671 672 673
        """
        self._runtime_handle._run_server()

674
    @is_non_distributed_check
675
    @inited_runtime_handler
676 677
    def stop_worker(self):
        """
678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694
        stop `Communicator` and give training complete notice to parameter server.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_server()

695 696 697
        """
        self._runtime_handle._stop_worker()

T
tangwei12 已提交
698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740
    def save(self, dirname, feed=[], fetch=[], **configs):
        inference = True

        if not feed and not fetch:
            inference = False

        place = paddle.CPUPlace()
        executor = paddle.static.Executor(place)

        if inference:
            feeded_var_names = []
            fetch_var_names = []

            for var in feed:
                if isinstance(var, str):
                    feeded_var_names.append(var)
                elif isinstance(var, paddle.static.Variable):
                    feeded_var_names.append(var.name)
                else:
                    raise ValueError("feed must be [str|Variable]")

            for var in fetch:
                if isinstance(var, str):
                    fetch_var_names.append(var)
                elif isinstance(var, paddle.static.Variable):
                    fetch_var_names.append(var.name)
                else:
                    raise ValueError("feed must be [str|Variable]")

            fetch_vars = [
                paddle.static.default_main_program().global_block().var(name)
                for name in fetch_var_names
            ]

            self._runtime_handle._save_inference_model(
                executor, dirname, feeded_var_names, fetch_vars, None, True, 0)
        else:
            increment_mode = 0
            if "mode" in configs:
                increment_mode = int(configs["mode"])
            self._runtime_handle._save_persistables(
                executor, dirname, main_program=None, mode=increment_mode)

741 742 743 744 745 746
    def save_inference_model(self,
                             executor,
                             dirname,
                             feeded_var_names,
                             target_vars,
                             main_program=None,
747 748
                             export_for_deployment=True,
                             mode=0):
749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767
        """
        save inference model for inference.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_server()

        """
T
tangwei12 已提交
768 769 770
        # warnings.warn(
        #     "'save_inference_model' is a deprecated, will be deleted after v2.2.0, Please use fleet.save instead."
        # )
771

772 773
        self._runtime_handle._save_inference_model(
            executor, dirname, feeded_var_names, target_vars, main_program,
774
            export_for_deployment, mode)
775

776
    def save_persistables(self, executor, dirname, main_program=None, mode=0):
777 778
        """

1
123malin 已提交
779
        saves all persistable tensors from :code:`main_program` to
780 781
        the folder :code:`dirname`. You can refer to

1
123malin 已提交
782 783
        The :code:`dirname` is used to specify the folder where persistable tensors
        are going to be saved. If you would like to save tensors in separate
784 785 786
        files, set :code:`filename` None.

        Args:
1
123malin 已提交
787
            executor(Executor): The executor to run for saving persistable tensors.
788 789 790 791 792
                                You can refer to :ref:`api_guide_executor_en` for
                                more details.

            dirname(str, optional): The saving directory path.
                                When you need to save the parameter to the memory, set it to None.
1
123malin 已提交
793
            main_program(Program, optional): The program whose persistbale tensors will
794 795 796 797 798 799 800 801 802 803
                                             be saved. Default: None.


        Returns:
            None

        Examples:

            .. code-block:: text

1
123malin 已提交
804 805
                import paddle
                paddle.enable_static()
806 807 808 809 810 811 812
                import paddle.distributed.fleet as fleet

                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

1
123malin 已提交
813 814
                exe = paddle.static.Executor(paddle.CPUPlace())
                fleet.save_persistables(exe, "dirname", paddle.static.default_main_program())
815 816

        """
T
tangwei12 已提交
817 818 819
        # warnings.warn(
        #     "'save_persistables' is a deprecated, will be deleted after v2.2.0, Please use fleet.save instead."
        # )
820

821 822
        self._runtime_handle._save_persistables(executor, dirname, main_program,
                                                mode)
823

824 825 826
    def shrink(self, threshold):
        self._runtime_handle._shrink(threshold)

827
    def distributed_optimizer(self, optimizer, strategy=None):
828
        """
829 830 831 832 833 834 835
        Optimizer for distributed training.

        For the distributed training, this method would rebuild a new instance of DistributedOptimizer.
        Which has basic Optimizer function and special features for distributed training.

        Args:
            optimizer(Optimizer): The executor to run for init server.
836 837 838 839 840
            strategy(DistributedStrategy): Extra properties for distributed optimizer. 
                It is recommended to use DistributedStrategy in fleet.init(). The strategy
                here is for compatibility. If the strategy in fleet.distributed_optimizer() 
                is not None, then it will overwrite the DistributedStrategy in fleet.init(), 
                which will take effect in distributed training.
841

842
        Returns:
843
            Fleet: instance of fleet.
844 845

        Examples:
846

847
            .. code-block:: python
848

1
123malin 已提交
849
                import paddle
850
                import paddle.distributed.fleet as fleet
1
123malin 已提交
851
                fleet.init(is_collective=True)
852 853 854 855
                strategy = fleet.DistributedStrategy()
                optimizer = paddle.optimizer.SGD(learning_rate=0.001)
                optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)

856 857
        """
        self.user_defined_optimizer = optimizer
858

859
        if strategy is not None:
T
tangwei12 已提交
860 861 862 863 864 865 866
            if self._is_collective:
                warnings.warn(
                    "It is recommended to use DistributedStrategy "
                    "in fleet.init(). The strategy here is only for compatibility. "
                    "If the strategy in fleet.distributed_optimizer() is "
                    "not None, then it will overwrite the DistributedStrategy in fleet.init(), "
                    "which will take effect in distributed training.")
867
            self._user_defined_strategy = copy.deepcopy(strategy)
D
Dong Daxiang 已提交
868 869

        self._context = {}
S
ShenLiang 已提交
870 871

        if paddle.fluid.framework.in_dygraph_mode():
872 873 874 875 876
            if self.worker_num() > 1:
                return HybridParallelOptimizer(optimizer, self._hcg,
                                               self._user_defined_strategy)
            else:
                return optimizer
877 878
        return self

879
    @dygraph_only
880
    def distributed_model(self, model):
881
        """
882 883 884 885 886 887 888
        Return distributed data parallel model (Only work in dygraph mode)

        Args:
            model (Layer): the user-defind model which inherits Layer.

        Returns:
            distributed data parallel model which inherits Layer.
889 890

        Examples:
891

892 893
            .. code-block:: python

894 895 896 897 898 899 900 901 902
                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)
903

904 905
                    def forward(self, x):
                        return self._linear2(self._linear1(x))
906

1
123malin 已提交
907
                # 1. initialize fleet environment
908 909
                fleet.init(is_collective=True)

1
123malin 已提交
910
                # 2. create layer & optimizer
911 912 913 914 915
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

1
123malin 已提交
916
                # 3. get data_parallel model using fleet
917 918 919
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

1
123malin 已提交
920
                # 4. run layer
921 922 923 924 925 926 927 928 929 930 931 932
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()

933

934
        """
935 936 937
        assert model is not None, "model should not be None"
        if self.worker_num() <= 1:
            return model
J
JZ-LIANG 已提交
938 939 940 941 942

        if self._hcg.get_parallel_mode() == ParallelMode.SHARDING_PARALLEL:
            distributed_model = ShardingParallel(
                model, self._hcg, strategy=self._user_defined_strategy)
        elif self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL:
943 944 945 946 947 948 949 950

            # NOTE (JZ-LIANG) init parameters broadcast within sharding group
            # normally it should be done inside DataParallel
            if self.sharding_degree > 1:
                from paddle.distributed.fleet.utils.hybrid_parallel_util import broadcast_mp_parameters, broadcast_sharding_parameters
                assert self.sharding_degree == self._hcg.get_sharding_parallel_world_size(
                )
                broadcast_sharding_parameters(model, self._hcg)
951 952 953 954 955 956 957 958
            distributed_model = paddle.DataParallel(
                model,
                comm_buffer_size=self._user_defined_strategy.
                fuse_grad_size_in_MB,
                last_comm_buffer_size=self._user_defined_strategy.
                last_comm_group_size_MB,
                find_unused_parameters=self._user_defined_strategy.
                find_unused_parameters)
959 960
        elif self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL:
            distributed_model = TensorParallel(
961
                model, self._hcg, strategy=self._user_defined_strategy)
962 963 964
        elif self._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
            distributed_model = PipelineParallel(
                model, self._hcg, strategy=self._user_defined_strategy)
J
JZ-LIANG 已提交
965

966
        return distributed_model
967 968 969 970 971

    @dygraph_only
    def state_dict(self):
        """
        Get state dict information from optimizer.
972
        (Only work in dygraph mode)
973 974 975 976 977 978 979

        Returns: 
            state_dict(dict) : dict contains all the Tensor used by optimizer

        Examples:
            .. code-block:: python

980 981 982 983 984
                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)
985

986
                value = np.arange(26).reshape(2, 13).astype("float32")
1
123malin 已提交
987
                a = paddle.to_tensor(value)
988

989 990
                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())
991

992 993 994
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)
                state_dict = adam.state_dict()
995 996 997 998 999 1000 1001 1002
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.state_dict()

    @dygraph_only
    def set_state_dict(self, state_dict):
        """
        Load optimizer state dict.
1003
        (Only work in dygraph mode)
1004 1005 1006 1007

        Args: 
            state_dict(dict) : Dict contains all the Tensor needed by optimizer

1008 1009
        Returns:
            None
1010 1011 1012 1013

        Examples:
            .. code-block:: python

1014 1015 1016
                import numpy as np
                import paddle
                from paddle.distributed import fleet
1017

1018 1019 1020
                fleet.init(is_collective=True)

                value = np.arange(26).reshape(2, 13).astype("float32")
1
123malin 已提交
1021
                a = paddle.to_tensor(value)
1022

1023 1024
                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())
1025

1026 1027 1028
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)
                state_dict = adam.state_dict()
1
123malin 已提交
1029 1030 1031
                paddle.save(state_dict, "paddle_dy")
                para_state_dict = paddle.load("paddle_dy")
                adam.set_state_dict(para_state_dict)
1032 1033 1034 1035 1036 1037 1038 1039
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.set_state_dict(state_dict)

    @dygraph_only
    def set_lr(self, value):
        """
        Set the value of the learning rate manually in the optimizer. 
1040
        (Only work in dygraph mode)
1041

1042 1043 1044
        Args:
            value (float|Tensor): the value of learning rate

1045 1046
        Returns: 
            None 
1047 1048 1049 1050

        Examples:
            .. code-block:: python

1051 1052 1053
                import numpy as np
                import paddle
                from paddle.distributed import fleet
1054

1055
                fleet.init(is_collective=True)
1056

1057
                value = np.arange(26).reshape(2, 13).astype("float32")
1
123malin 已提交
1058
                a = paddle.to_tensor(value)
1059

1060 1061
                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())
1062

1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                lr_list = [0.2, 0.3, 0.4, 0.5, 0.6]
                for i in range(5):
                    adam.set_lr(lr_list[i])
                    lr = adam.get_lr()
                    print("current lr is {}".format(lr))
                # Print:
                #    current lr is 0.2
                #    current lr is 0.3
                #    current lr is 0.4
                #    current lr is 0.5
                #    current lr is 0.6
1077 1078 1079 1080 1081 1082 1083 1084
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.set_lr(value)

    @dygraph_only
    def get_lr(self):
        """
        Get current step learning rate.
1085
        (Only work in dygraph mode)
1086 1087 1088 1089 1090

        Returns:
            float: The learning rate of the current step.

        Examples:
1
123malin 已提交
1091

1092 1093
            .. code-block:: python

1094 1095 1096 1097 1098
                import numpy as np
                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)
1099

1100
                value = np.arange(26).reshape(2, 13).astype("float32")
1
123malin 已提交
1101
                a = paddle.to_tensor(value)
1102

1103 1104
                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())
1105

1106 1107
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)
1108

1109 1110
                lr = adam.get_lr()
                print(lr) # 0.01
1111 1112 1113 1114 1115 1116 1117 1118
        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.get_lr()

    @dygraph_only
    def step(self):
        """
        Execute the optimizer once.
1119
        (Only work in dygraph mode)
1120

1121 1122
        Returns:
            None
1123 1124

        Examples:
1
123malin 已提交
1125

1126 1127
            .. code-block:: python

1128 1129 1130
                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet
1131

1132 1133 1134 1135 1136
                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)
1137

1138 1139
                    def forward(self, x):
                        return self._linear2(self._linear1(x))
1140

1
123malin 已提交
1141
                # 1. initialize fleet environment
1142 1143
                fleet.init(is_collective=True)

1
123malin 已提交
1144
                # 2. create layer & optimizer
1145 1146 1147 1148 1149
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

1
123malin 已提交
1150
                # 3. get data_parallel model using fleet
1151 1152 1153
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

1
123malin 已提交
1154
                # 4. run layer
1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()


        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.step()

    @dygraph_only
    def clear_grad(self):
        """
1175 1176
        Clear the gradients of all optimized parameters for model.
        (Only work in dygraph mode)
1177

1178 1179
        Returns: 
            None
1180 1181

        Examples:
1
123malin 已提交
1182

1183 1184
            .. code-block:: python

1185 1186 1187
                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet
1188

1189 1190 1191 1192 1193
                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)
1194

1195 1196
                    def forward(self, x):
                        return self._linear2(self._linear1(x))
1197

1
123malin 已提交
1198
                # 1. initialize fleet environment
1199 1200
                fleet.init(is_collective=True)

1
123malin 已提交
1201
                # 2. create layer & optimizer
1202 1203 1204 1205 1206
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

1
123malin 已提交
1207
                # 3. get data_parallel model using fleet
1208 1209 1210
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

1
123malin 已提交
1211
                # 4. run layer
1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()

        """
        # imitate target optimizer retrieval
        return self.user_defined_optimizer.clear_grad()

1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244
    def _get_amp_optimizer(self):
        # imitate target optimizer retrieval
        amp_optimizer = None
        for optimizer in self.strategy_compiler._get_applied_meta_optimizer():
            if hasattr(optimizer, 'amp_init'):
                amp_optimizer = optimizer
                break

        if amp_optimizer is None:
            if hasattr(self.user_defined_optimizer, 'amp_init'):
                amp_optimizer = self.user_defined_optimizer

        assert amp_optimizer is not None, \
            "amp_init can only be used when the amp(auto mixed precision) strategy is turned on."
        return amp_optimizer

    def get_loss_scaling(self):
1245 1246
        """Return the real-time loss scaling factor.
        """
1247 1248 1249
        amp_optimizer = self._get_amp_optimizer()
        return amp_optimizer.get_loss_scaling()

H
huangxu96 已提交
1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309
    def amp_init(self,
                 place,
                 scope=None,
                 test_program=None,
                 use_fp16_test=False):
        """
        Init the amp training, such as cast fp32 parameters to fp16 type.
  
        Args:
            place(CUDAPlace): place is used to initialize 
                fp16 parameters with fp32 values.
            scope(Scope): The scope is used to find fp32 parameters.
            test_program(Program): The program is used for testing.
            use_fp16_test(bool): Whether to use fp16 testing.
            
        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                import paddle.nn.functional as F
                paddle.enable_static()

                def run_example_code():
                    place = paddle.CUDAPlace(0)
                    exe = paddle.static.Executor(place)
                    data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
                    conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
                    # 1) Use fp16_guard to control the range of fp16 kernels used.
                    with paddle.static.amp.fp16_guard():
                        bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
                        pool = F.max_pool2d(bn, kernel_size=2, stride=2)
                        hidden = paddle.static.nn.fc(pool, size=10)
                        loss = paddle.mean(hidden)
                    # 2) Create the optimizer and set `multi_precision` to True.
                    # Setting `multi_precision` to True can avoid the poor accuracy
                    # or the slow convergence in a way. 
                    optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
                    # 3) These ops in `custom_black_list` will keep in the float32 computation type.
                    amp_list = paddle.static.amp.CustomOpLists(
                        custom_black_list=['pool2d'])
                    # 4) The entry of Paddle AMP.
                    # Enable pure fp16 training by setting `use_pure_fp16` to True.
                    optimizer = paddle.static.amp.decorate(
                        optimizer,
                        amp_list,
                        init_loss_scaling=128.0,
                        use_dynamic_loss_scaling=True,
                        use_pure_fp16=True)
                    # If you don't use the default_startup_program(), you sholud pass
                    # your defined `startup_program` into `minimize`.
                    optimizer.minimize(loss)
                    exe.run(paddle.static.default_startup_program())
                    # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
                    # If you want to perform the testing process, you should pass `test_program` into `amp_init`.
                    optimizer.amp_init(place, scope=paddle.static.global_scope())
                    
                if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                    run_example_code()       
        """
1310
        amp_optimizer = self._get_amp_optimizer()
1311
        return amp_optimizer.amp_init(place, scope, test_program, use_fp16_test)
H
huangxu96 已提交
1312

D
Dong Daxiang 已提交
1313 1314 1315 1316 1317 1318 1319 1320 1321
    def _final_strategy(self):
        if "valid_strategy" not in self._context:
            print(
                "WARNING: You may need to call minimize function before this function is called"
            )
            return {}
        else:
            return self._context["valid_strategy"]

1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339
    def _get_applied_meta_list(self):
        if "applied_meta_list" not in self._context:
            print(
                "WARNING: You may need to call minimize function before _get_applied_meta_list called"
            )
            return []
        else:
            return self._context["applied_meta_list"]

    def _get_applied_graph_list(self):
        if "applied_graph_list" not in self._context:
            print(
                "WARNING: You may need to call minimize function before _get_applied_graph_list called"
            )
            return []
        else:
            return self._context["applied_graph_list"]

1340 1341 1342 1343 1344 1345 1346 1347 1348
    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
        """
        Add distributed operations to minimize ``loss`` by updating ``parameter_list``.

        Args:
1
123malin 已提交
1349
            loss (Tensor): A ``Tensor`` containing the value to minimize.
1350 1351 1352
            startup_program (Program, optional): :ref:`api_fluid_Program` for
                initializing parameters in ``parameter_list``. The default value
                is None, at this time :ref:`api_fluid_default_startup_program` will be used.
1
123malin 已提交
1353
            parameter_list (Iterable, optional): Iterable of ``Tensor`` or ``Tensor.name`` to update
1354 1355
                to minimize ``loss``. The default value is None, at this time all parameters
                will be updated.
1
123malin 已提交
1356
            no_grad_set (set, optional): Set of ``Tensor``  or ``Tensor.name`` that don't need
1357 1358 1359 1360
                to be updated. The default value is None.

        Returns:
            tuple: tuple (optimize_ops, params_grads), A list of operators appended
1
123malin 已提交
1361
            by minimize and a list of (param, grad) tensor pairs, param is
1362
            ``Parameter``, grad is the gradient value corresponding to the parameter.
1363 1364
            The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
            indicate program pruning. If so, the program will be pruned by ``feed`` and
1365 1366 1367
            ``fetch_list`` before run, see details in ``Executor``.

        Examples:
1
123malin 已提交
1368

1369
            .. code-block:: python
1370

1371
                import paddle
1
123malin 已提交
1372
                paddle.enable_static()
1373
                import paddle.distributed.fleet as fleet
1
123malin 已提交
1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384
                import paddle.nn.functional as F

                hid_dim = 10
                label_dim = 2
                input_x = paddle.static.data(name='x', shape=[None, 13], dtype='float32')
                input_y = paddle.static.data(name='y', shape=[None, 1], dtype='int64')
                fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim, activation='tanh')
                fc_2 = paddle.static.nn.fc(x=fc_1, size=hid_dim, activation='tanh')
                prediction = paddle.static.nn.fc(x=[fc_2], size=label_dim, activation='softmax')
                cost = F.cross_entropy(input=prediction, label=input_y)
                avg_cost = paddle.mean(x=cost)
1385

1
123malin 已提交
1386
                fleet.init(is_collective=True)
1387 1388 1389 1390
                strategy = fleet.DistributedStrategy()
                optimizer = paddle.optimizer.SGD(learning_rate=0.001)
                optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
                optimizer.minimize(avg_cost)
1391

1392
                # for more examples, please reference https://github.com/PaddlePaddle/FleetX
1393 1394

        """
D
Dong Daxiang 已提交
1395 1396 1397
        context = {}
        context["user_defined_strategy"] = copy.deepcopy(
            self._user_defined_strategy)
1398 1399 1400
        if paddle.fluid.framework.in_dygraph_mode():
            # imitate target optimizer retrieval
            target_opt = self.user_defined_optimizer
D
Dong Daxiang 已提交
1401
            self._context = context
1402 1403
            return target_opt.minimize(loss)

1404 1405
        # cache original feed forward program
        self.origin_main_program = loss.block.program
1406 1407
        context["origin_main_program"] = self.origin_main_program
        context["loss"] = loss
1408 1409
        if startup_program == None:
            self.origin_startup_program = \
1410 1411
                paddle.static.default_startup_program().clone(for_test=False)
            startup_program = paddle.static.default_startup_program()
1412 1413 1414
        else:
            self.origin_startup_program = \
                startup_program.clone(for_test=False)
1415

1416 1417
        context["origin_startup_program"] = startup_program
        context["role_maker"] = self._role_maker
1418

1419 1420 1421 1422 1423 1424 1425 1426
        # Use the auto-parallel's routines instead
        if self._user_defined_strategy.semi_auto:
            from ...auto_parallel.parallelizer import AutoParallelizer
            auto_parallelizer = AutoParallelizer(self)
            optimize_ops, params_grads, dist_startup_prog, dist_main_prog = auto_parallelizer.parallelize(
                loss, startup_program, parameter_list, no_grad_set)
            return optimize_ops, params_grads, dist_startup_prog, dist_main_prog

1427 1428 1429 1430
        # compile time
        distributed_optimizer_list = \
            MetaOptimizerFactory()._get_valid_meta_optimizers(
                self.user_defined_optimizer)
D
Dong Daxiang 已提交
1431

D
Dong Daxiang 已提交
1432 1433 1434
        context["user_defined_strategy"] = copy.deepcopy(
            self._user_defined_strategy)
        copy_user_defined_strategy = copy.deepcopy(self._user_defined_strategy)
1435 1436 1437 1438 1439 1440

        # trigger the auto-parallel in very strict condition
        # strategy = DistributedStrategy()
        # strategy.auto = True
        # optimizer = paddle.optimizer.SGD(learning_rate=0.1)
        # optimizer = fleet.distributed_optimizer(optimizer, strategy)
D
Dong Daxiang 已提交
1441
        if copy_user_defined_strategy._is_strict_auto():
1442 1443
            # turn on all the strategy for each optimizer
            for opt in distributed_optimizer_list:
D
Dong Daxiang 已提交
1444
                opt._enable_strategy(copy_user_defined_strategy, context)
1445

1446 1447
        valid_optimizer_list = []
        valid_graph_optimizer_list = []
D
Dong Daxiang 已提交
1448
        can_not_apply_optimizer_list = []
1449 1450 1451 1452
        # recall meta optimizers for ranking
        for opt in distributed_optimizer_list:
            opt._set_basic_info(loss, self._role_maker,
                                self.user_defined_optimizer,
D
Dong Daxiang 已提交
1453
                                copy_user_defined_strategy)
1454 1455
            if opt._can_apply() and not opt._is_graph_out():
                valid_optimizer_list.append(opt)
D
Dong Daxiang 已提交
1456
            elif opt._can_apply() and opt._is_graph_out():
1457
                valid_graph_optimizer_list.append(opt)
D
Dong Daxiang 已提交
1458 1459
            else:
                can_not_apply_optimizer_list.append(opt)
1460
        # combine recalled meta optimizers to be a valid meta optimizer
D
Dong Daxiang 已提交
1461
        meta_optimizer, graph_optimizer = \
1462 1463
            self.strategy_compiler.generate_optimizer(
                loss, self._role_maker, self.user_defined_optimizer,
D
Dong Daxiang 已提交
1464
                copy_user_defined_strategy, valid_optimizer_list,
1465
                valid_graph_optimizer_list)
D
Dong Daxiang 已提交
1466

D
Dong Daxiang 已提交
1467
        valid_strategy = self.strategy_compiler._get_valid_strategy(
D
Dong Daxiang 已提交
1468 1469 1470
            copy_user_defined_strategy, can_not_apply_optimizer_list)

        context["valid_strategy"] = copy.deepcopy(valid_strategy)
1471

1472 1473 1474 1475 1476 1477
        applied_meta_list = self.strategy_compiler._get_applied_meta_list()
        applied_graph_list = self.strategy_compiler._get_applied_graph_list()

        context['applied_meta_list'] = applied_meta_list
        context['applied_graph_list'] = applied_graph_list

D
Dong Daxiang 已提交
1478
        self._context = context
1479

D
Dong Daxiang 已提交
1480
        self.valid_strategy = valid_strategy
1481
        self.valid_strategy._enable_env()
D
Dong Daxiang 已提交
1482

1483 1484
        optimize_ops = []
        params_grads = []
1485

1486 1487 1488 1489 1490 1491 1492 1493 1494
        if self._role_maker._is_non_distributed() and not self._is_collective:
            if self._runtime_handle is None:
                self._runtime_handle = RuntimeFactory()._create_runtime(context)

            compiled_program = compiler.CompiledProgram(
                self.origin_main_program).with_data_parallel(
                    loss_name=loss.name, share_vars_from=None)
            loss.block.program._graph = compiled_program
            return self.user_defined_optimizer.minimize(
M
MRXLT 已提交
1495
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)
1496

1497 1498
        if meta_optimizer:
            optimize_ops, params_grads = meta_optimizer.minimize(
M
MRXLT 已提交
1499
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)
1500

1501
            default_program = paddle.static.default_main_program()
1502 1503 1504 1505

            if id(default_program) != id(loss.block.program):
                paddle.fluid.framework.switch_main_program(loss.block.program)

1506 1507
        else:
            optimize_ops, params_grads = self.user_defined_optimizer.minimize(
M
MRXLT 已提交
1508
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)
1509

1510 1511
        context["program_optimize_ops"] = optimize_ops
        context["program_params_grads"] = params_grads
1512

1513
        if graph_optimizer:
D
Dong Daxiang 已提交
1514
            optimize_ops, params_grads = graph_optimizer.minimize(
M
MRXLT 已提交
1515
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)
1516 1517 1518 1519
            # since we do not encourage users to use graph operations
            # if a graph optimizer takes effect, mostly
            # optimizers_ops and params_grads are None
            # i.e. users can not modify current computation graph anymore
1520 1521
            context["graph_optimize_ops"] = optimize_ops
            context["graph_optimize_grads"] = params_grads
1522 1523
        else:
            apply_ir_passes(loss.block.program, startup_program, self)
1524

1525 1526 1527 1528 1529 1530 1531 1532
        program = paddle.static.default_main_program()
        opt_info = {}
        opt_info["mpi_size"] = self.worker_num()
        opt_info["mpi_rank"] = self.worker_index()
        for k, v in self._user_defined_strategy.trainer_desc_configs.items():
            opt_info[k] = v
        program._fleet_opt = opt_info

1533
        if self._runtime_handle is None:
1534
            self._runtime_handle = RuntimeFactory()._create_runtime(context)
1535

1536 1537
        import paddle.distributed.fleet as fleet
        fleet.util._set_strategy(context["valid_strategy"])
1538 1539

        return optimize_ops, params_grads
1540 1541 1542 1543

    @dygraph_only
    def distributed_scaler(self, scaler):
        return HybridParallelGradScaler(scaler, self._hcg)