fleet_base.py 56.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
16
import copy
17
import warnings
18
import paddle
19
import os
20
from types import MethodType
21
import numpy as np
22
from paddle.fluid.framework import dygraph_only, _global_flags
23
from paddle.fluid import compiler
24
from .role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker, RoleMakerBase
25
from .strategy_compiler import StrategyCompiler
26
from .distributed_strategy import DistributedStrategy
27 28
from .meta_optimizer_factory import MetaOptimizerFactory
from .runtime_factory import RuntimeFactory
29
from paddle.fluid.wrapped_decorator import wrap_decorator
30
from paddle.fluid.dygraph import parallel_helper
31
from paddle.fluid.ir import apply_build_strategy
32
from . import topology as tp
33
from .topology import ParallelMode
34
from ..meta_parallel import TensorParallel, model_parallel_random_seed
J
JZ-LIANG 已提交
35
from ..meta_parallel import PipelineParallel, ShardingParallel
36
from ..meta_optimizers import HybridParallelOptimizer
37
from paddle import _C_ops
38 39
from paddle.fluid import core
from paddle.fluid.dygraph import to_variable
40

41 42
__all__ = []

43

44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
def apply_ir_passes(main_program, startup_program, config):
    """
    Apply the build-strategy IR passes to a pair of programs.

    Args:
        main_program: The main program. If it carries a non-empty
            ``_pipeline_opt``, its ``"section_program"`` entry is used instead.
        startup_program: The startup program. In the pipeline case it is
            replaced by ``startup_program._pipeline_opt["startup_program"]``.
        config: Object exposing ``_user_defined_strategy`` and
            ``_is_collective``.

    Returns:
        The copied build strategy when ``FLAGS_apply_pass_to_program`` is
        off; otherwise the result of ``apply_build_strategy``.
    """
    user_strategy = config._user_defined_strategy
    build_strategy = user_strategy.build_strategy._copy()

    passes_enabled = _global_flags()['FLAGS_apply_pass_to_program']
    if not passes_enabled:
        return build_strategy

    # Pipeline programs keep the real main/startup programs in `_pipeline_opt`.
    pipeline_opt = getattr(main_program, "_pipeline_opt", {})
    if pipeline_opt:
        main_program = pipeline_opt["section_program"]
        startup_program = startup_program._pipeline_opt["startup_program"]

    pass_attrs = {"use_cuda": config._is_collective}

    # FIXME(zjl): currently, fuse_all_optimizer_ops conflicts with
    # fuse_all_reduce_ops because RawProgramOptimizer also inserts
    # coalesce_tensor into the program. The two procedures may disagree on
    # which vars are to be fused, so the optimizer fusion is switched off.
    if user_strategy.fuse_all_reduce_ops and build_strategy.fuse_all_optimizer_ops:
        warnings.warn(
            'Currently, the fuse_all_optimizer_ops pass has conflict with fuse_all_reduce_ops pass. Disable the fuse_all_optimizer_ops pass temporarily.'
        )
        build_strategy.fuse_all_optimizer_ops = False

    return apply_build_strategy(main_program, startup_program, build_strategy,
                                pass_attrs)


71 72 73 74 75 76 77 78 79 80 81 82
def _inited_runtime_handler_(func):
    """
    Wrap ``func`` so it only runs when a runtime handle is available.

    The first positional argument of the wrapped call is treated as the
    owning object; a ``ValueError`` is raised when its ``_runtime_handle``
    is None, otherwise the call is forwarded unchanged.
    """

    def __impl__(*args, **kwargs):
        owner = args[0]
        if owner._runtime_handle is not None:
            return func(*args, **kwargs)
        raise ValueError("Fleet can not find suitable runtime handler")

    return __impl__


83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
def _is_non_distributed_check_(func):
    """
    Wrap ``func`` so it becomes a no-op for a non-distributed fleet.

    The first positional argument of the wrapped call is treated as the
    owning object. When it has a role maker whose ``_is_non_distributed()``
    returns exactly ``True``, a warning is issued and None is returned
    without calling ``func``; otherwise the call is forwarded unchanged.
    """

    def __impl__(*args, **kwargs):
        role_maker = args[0]._role_maker
        if role_maker is None or role_maker._is_non_distributed() is not True:
            return func(*args, **kwargs)

        warnings.warn(
            "%s() function doesn't work when use non_distributed fleet." %
            (func.__name__))
        return None

    return __impl__


99
# Public decorator built from `_inited_runtime_handler_` via `wrap_decorator`;
# the wrapped callable raises ValueError when `_runtime_handle` is None.
inited_runtime_handler = wrap_decorator(_inited_runtime_handler_)
100
# Public decorator built from `_is_non_distributed_check_` via `wrap_decorator`;
# the wrapped callable warns and returns None for a non-distributed role maker.
is_non_distributed_check = wrap_decorator(_is_non_distributed_check_)
101 102


103 104 105
class Fleet(object):
    """
    Unified API for distributed training of PaddlePaddle
106
    Please refer to https://github.com/PaddlePaddle/FleetX for details.
107 108 109 110 111


    Returns:
        Fleet: A Fleet instance

112
    Example for collective training:
1
123malin 已提交
113

114 115
        .. code-block:: python

1
123malin 已提交
116 117
            import paddle
            paddle.enable_static()
118
            import paddle.distributed.fleet as fleet
119 120 121

            fleet.init(is_collective=True)

122 123 124
            strategy = fleet.DistributedStrategy()
            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
            optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
125 126 127 128 129 130 131 132

            # do distributed training


    Example for parameter server training:

        .. code-block:: python

1
123malin 已提交
133 134
            import paddle
            paddle.enable_static()
135 136
            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
S
ShenLiang 已提交
137
            fleet.init(strategy=strategy)
138

139
            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
140
            optimizer = fleet.distributed_optimizer(optimizer)
141

142 143
            if fleet.is_first_worker():
                print("this is first worker")
144

145 146
            print("current node index: {}".format(fleet.worker_index()))
            print("total number of worker num: {}".format(fleet.worker_num()))
147

148 149 150
            if fleet.is_worker():
                print("this is worker")
            print("worker endpoints: {}".format(fleet.worker_endpoints(to_string=True)))
151

152 153
            print("server num: {}".format(fleet.server_num()))
            print("server endpoints: {}".format(fleet.server_endpoints(to_string=True)))
154

155 156 157
            if fleet.is_server():
                print("this is server")
            fleet.stop_worker()
158 159


160 161 162
    """

    def __init__(self):
163
        self._role_maker = None
164
        self.strategy_compiler = None
165
        self._is_collective = False
166
        self._runtime_handle = None
D
Dong Daxiang 已提交
167 168
        self._util = None
        self._context = {}
169

170
    def init(self, role_maker=None, is_collective=False, strategy=None):
        """
        Initialize role_maker in Fleet.

        This function stores a copy of the strategy, resolves and generates
        the role maker, and then prepares the communication environment:
        the dygraph (hybrid) parallel environment in dygraph mode, or the
        'global' / 'model' communication groups in static collective mode.

        Args:
            role_maker (RoleMakerBase, optional): A ``RoleMakerBase`` containing the configuration
                of environment variables related to distributed training. If it is None, a
                ``PaddleCloudRoleMaker`` is created from ``is_collective``.
                The default value is None.
            is_collective (Boolean, optional): A ``Boolean`` variable determines whether the program
                runs on the CPU or GPU. False means set distributed training using CPU, and True means
                GPU. Only used when ``role_maker`` is None; otherwise the role maker's own
                ``_is_collective`` is taken. The default value is False.
            strategy (DistributedStrategy, optional): Extra properties for distributed training.
                For details, please refer to paddle.distributed.fleet.DistributedStrategy.
                A deep copy is stored. Default: None, which means a default ``DistributedStrategy()``.

        Returns:
            None

        Raises:
            ValueError: If ``is_collective`` is not a ``bool`` (when no role maker is given),
                if ``role_maker`` is not a ``RoleMakerBase``, or if a non-distributed collective
                run on a CUDA build sees a GPU count other than 1.

        Examples1:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

        Examples2:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init(is_collective=True)

        Examples3:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                role = fleet.PaddleCloudRoleMaker()
                fleet.init(role)

        Examples4:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                strategy = fleet.DistributedStrategy()
                fleet.init(strategy=strategy)

        """
        if strategy is None:
            strategy = DistributedStrategy()
        self._user_defined_strategy = copy.deepcopy(strategy)

        # Resolve the role maker (validation first, then assignment).
        if role_maker is None:
            if not isinstance(is_collective, bool):
                raise ValueError(
                    "`is_collective` should be instance of `bool`, but got {}".
                    format(type(is_collective)))
            self._is_collective = is_collective
            self._role_maker = PaddleCloudRoleMaker(
                is_collective=self._is_collective)
        else:
            if not isinstance(role_maker, RoleMakerBase):
                raise ValueError(
                    "`role_maker` should be subclass of `RoleMakerBase`, but got {}".
                    format(type(role_maker)))
            self._role_maker = role_maker
            self._is_collective = role_maker._is_collective
        self._role_maker._generate_role()

        # Imported here rather than at module level (as in the original code).
        import paddle.distributed.fleet as fleet
        fleet.util._set_role_maker(self._role_maker)

        self.strategy_compiler = StrategyCompiler()

        if self._role_maker._is_non_distributed() and self._is_collective:
            if paddle.fluid.core.is_compiled_with_cuda():
                gpus_num = paddle.fluid.core.get_cuda_device_count()
                if gpus_num != 1:
                    raise ValueError(
                        "CUDA_VISIBLE_DEVICES shoule be set only 1 card if you use `python` to launch fleet program."
                    )

        if paddle.fluid.framework.in_dygraph_mode():
            if self.worker_num() == 1:
                # if worker_num is 1, should construct default topology & hcg
                self._topology = tp.CommunicateTopology()
                self._hcg = tp.HybridCommunicateGroup(self._topology)
                return
            if parallel_helper._is_parallel_ctx_initialized():
                warnings.warn(
                    "The dygraph parallel environment has been initialized.")
            else:
                # FLAGS_nccl_nrings is used for dynamic graph multi-stream communication
                if "FLAGS_nccl_nrings" in os.environ:
                    warnings.warn(
                        "You have set the environment variable FLAGS_nccl_nrings "
                        "outside the program, so the nccl_comm_num in "
                        "DistributedStrategy will not take effect here.")
                else:
                    os.environ["FLAGS_nccl_nrings"] = str(
                        self._user_defined_strategy.nccl_comm_num)
                paddle.distributed.init_parallel_env()

            # init hybrid parallel environment in dygraph
            if tp._HYBRID_PARALLEL_GROUP is None:
                self._init_hybrid_parallel_env()
            else:
                warnings.warn(
                    "The dygraph hybrid parallel environment has been initialized."
                )
        elif self._is_collective:
            use_sharding = self._user_defined_strategy.sharding

            # global group
            global_rank = self.worker_index()
            global_world_size = self.worker_num()
            # NOTE(wangxi): see sharding_optimizer
            global_ring_id = 3 if use_sharding else 0
            global_ranks = list(range(global_world_size))

            if tp._HYBRID_PARALLEL_GROUP is None:
                tp._CommunicateGroup()
            cg = tp._HYBRID_PARALLEL_GROUP
            self._hcg = cg
            cg.set_comm_group('global', global_rank, global_world_size,
                              global_ring_id, global_ranks)

            use_tensor_parallel = self._user_defined_strategy.tensor_parallel
            use_mp = use_sharding or use_tensor_parallel

            # hybrid group
            if use_mp is False:
                return

            mp_degree_sharding = 1
            mp_degree_tensor_parallel = 1
            if use_sharding:
                sharding_configs = self._user_defined_strategy.sharding_configs
                mp_degree_sharding = int(sharding_configs['mp_degree'])

            if use_tensor_parallel:
                tensor_parallel_configs = self._user_defined_strategy.tensor_parallel_configs
                mp_degree_tensor_parallel = int(tensor_parallel_configs[
                    'tensor_parallel_degree'])

            if use_sharding and use_tensor_parallel:
                assert mp_degree_sharding == mp_degree_tensor_parallel, (
                    "sharding mp_degree and tensor_parallel_degree must be equal")

            mp_degree = mp_degree_sharding if use_sharding else mp_degree_tensor_parallel

            if mp_degree > 1:
                assert global_world_size % mp_degree == 0, (
                    "worker number must be divisible by mp_degree")
                # NOTE(wangxi): mp_ring_id sync with sharding_optimizer.py _build_groups
                mp_ring_id = 0
                mp_rank = global_rank % mp_degree
                mp_group_id = global_rank // mp_degree
                # Ranks are grouped into consecutive chunks of size mp_degree.
                mp_group_ranks = [
                    idx for idx in global_ranks
                    if idx // mp_degree == mp_group_id
                ]
                cg.set_comm_group('model', mp_rank, mp_degree, mp_ring_id,
                                  mp_group_ranks)
337 338 339 340 341 342 343 344

    def _init_hybrid_parallel_env(self):
        """
        Initialize the hybrid parallel environment.

        Reads the dp/mp/pp/sharding degrees from the strategy's
        ``hybrid_configs``, normalizes them, builds the communicate topology
        and hybrid communicate group, and seeds model-parallel randomness
        when ``mp_degree`` is greater than 1.
        """
        configs = self._user_defined_strategy.hybrid_configs
        self.hybrid_configs = configs
        self.dp_degree = configs["dp_degree"]
        self.mp_degree = configs["mp_degree"]
        self.pp_degree = configs["pp_degree"]
        self.sharding_degree = configs["sharding_degree"]

        assert self.mp_degree >= 0, "mp_degree should be greater or equal to 0"
        assert self.pp_degree >= 0, "pp_degree should be greater or equal to 0"
        assert self.sharding_degree >= 0, "sharding_degree should be greater or equal to 0"

        # A degree of 0 is treated as 1 for model and pipeline parallelism.
        self.mp_degree = max(self.mp_degree, 1)
        self.pp_degree = max(self.pp_degree, 1)

        # A negative dp_degree is derived from the world size.
        if self.dp_degree < 0:
            world_size = paddle.distributed.get_world_size()
            self.dp_degree = world_size // (self.mp_degree * self.pp_degree)
        self.dp_degree = max(self.dp_degree, 1)

        group_names = ["data", "pipe", "sharding", "model"]
        group_dims = [
            self.dp_degree, self.pp_degree, self.sharding_degree,
            self.mp_degree
        ]
        self._topology = tp.CommunicateTopology(
            hybrid_group_names=group_names, dims=group_dims)
        self._hcg = tp.HybridCommunicateGroup(self._topology)

        if self.mp_degree > 1:
            tp_configs = self._user_defined_strategy.tensor_parallel_configs
            seed = tp_configs["tensor_init_seed"]
            # -1 means "no explicit seed": call the seeding helper without one.
            seed_args = () if seed == -1 else (seed, )
            model_parallel_random_seed(*seed_args)

377 378 379 380 381 382 383 384
    def get_hybrid_communicate_group(self):
        assert self._hcg is not None
        return self._hcg

    def get_hybrid_parallel_topology(self):
        assert self._topology is not None
        return self._topology

385 386 387 388 389 390 391
    def is_first_worker(self):
        """
        Check whether the node is the first instance of worker.

        Returns:
            bool: True if this is the first node of worker,
                  False if not.
392

393 394 395 396 397 398 399 400
        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_first_worker()

401
        """
402
        return self._role_maker._is_first_worker()
403 404 405 406 407 408 409

    def worker_index(self):
        """
        Get current worker index.

        Returns:
            int: node id
410 411 412 413

        Examples:

            .. code-block:: python
1
123malin 已提交
414

415 416 417 418
                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_index()

419
        """
420
        return self._role_maker._worker_index()
421 422 423 424 425 426 427

    def worker_num(self):
        """
        Get current total worker number.

        Returns:
            int: worker numbers
1
123malin 已提交
428

429
        Examples:
1
123malin 已提交
430

431 432 433 434 435 436
            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_num()

437
        """
438
        return self._role_maker._worker_num()
439

440 441 442 443 444 445 446 447 448 449 450 451
    def node_num(self):
        return self._role_maker._get_node_num()

    def local_rank(self):
        return self._role_maker._get_local_rank()

    def local_device_ids(self):
        return self._role_maker._get_local_device_ids()

    def world_device_ids(self):
        return self._role_maker._get_world_device_ids()

452 453 454 455 456 457 458
    def is_worker(self):
        """
        Check whether the node is an instance of worker.

        Returns:
            bool: True if this is a node of worker,
                  False if not.
459 460

        Examples:
1
123malin 已提交
461

462 463 464 465 466 467
            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_worker()

468
        """
469
        return self._role_maker._is_worker()
470 471 472

    def worker_endpoints(self, to_string=False):
        """
473
        Get current worker endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].
474 475 476

        Returns:
            list/string: server endpoints
477 478

        Examples:
1
123malin 已提交
479

480 481 482 483 484 485
            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_endpoints()

486 487
        """
        if to_string:
488
            return ",".join(self._role_maker._get_trainer_endpoints())
489
        else:
490
            return self._role_maker._get_trainer_endpoints()
491 492 493 494 495 496 497

    def server_num(self):
        """
        Get current total worker number.

        Returns:
            int: server number
498 499

        Examples:
1
123malin 已提交
500

501
            .. code-block:: python
1
123malin 已提交
502 503 504 505

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_num()
506
        """
507
        return len(self._role_maker._get_pserver_endpoints())
508 509 510 511 512 513 514

    def server_index(self):
        """
        Get current server index.

        Returns:
            int: node id
515 516

        Examples:
1
123malin 已提交
517

518 519 520 521 522 523
            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_index()

524
        """
525
        return self._role_maker._server_index()
526 527 528 529 530 531 532

    def server_endpoints(self, to_string=False):
        """
        Get current server endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].

        Returns:
            list/string: server endpoints
533 534

        Examples:
1
123malin 已提交
535

536 537 538 539 540 541
            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_endpoints()

542
        """
543

544
        if to_string:
545
            return ",".join(self._role_maker._get_pserver_endpoints())
546
        else:
547
            return self._role_maker._get_pserver_endpoints()
548 549 550 551 552 553 554 555

    def is_server(self):
        """
        Check whether the node is an instance of server.

        Returns:
            bool: True if this is a node of server,
                  False if not.
556 557 558 559

        Examples:

            .. code-block:: python
1
123malin 已提交
560

561 562 563 564
                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_server()

565
        """
566 567
        return self._role_maker._is_server()

568 569
    def barrier_worker(self):
        """
570 571 572 573
        barrier all workers

        Returns:
            None
574
        """
575
        self._role_maker._barrier("worker")
576

577
    @is_non_distributed_check
578
    @inited_runtime_handler
579 580
    def init_worker(self):
        """
581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598
        initialize `Communicator` for parameter server training.


        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_worker()

599 600 601
        """
        self._runtime_handle._init_worker()

602
    @is_non_distributed_check
603
    @inited_runtime_handler
604
    def init_server(self, *args, **kwargs):
605
        """
606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624
        init_server executor to initialize startup program,
        if the `args` is not empty, it will run load_persistables for increment training.


        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_server()

625
        """
626
        self._runtime_handle._init_server(*args, **kwargs)
627

T
Thunderbrook 已提交
628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650
    def load_model(self, path, mode):
        """
        load fleet model from path


        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.load_model("path", "mode")

        """
        self._runtime_handle.load_model(path, mode)

651
    @is_non_distributed_check
652
    @inited_runtime_handler
653 654
    def run_server(self):
        """
655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672
        run server will run pserver main program with executor.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                if fleet.is_server():
                    fleet.init_server()

673 674 675
        """
        self._runtime_handle._run_server()

676
    @is_non_distributed_check
677
    @inited_runtime_handler
678 679
    def stop_worker(self):
        """
680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696
        stop `Communicator` and give training complete notice to parameter server.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_server()

697 698 699
        """
        self._runtime_handle._stop_worker()

T
tangwei12 已提交
700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742
    def save(self, dirname, feed=[], fetch=[], **configs):
        """
        Save an inference model or the persistables through the runtime handle.

        When ``feed`` or ``fetch`` is non-empty, an inference model is saved:
        feed/fetch entries are resolved to variable names, the fetch
        variables are looked up in the global block of the default main
        program, and the runtime handle's ``_save_inference_model`` is
        called. When both are empty, ``_save_persistables`` is called instead.

        Args:
            dirname (str): Target directory passed to the runtime handle.
            feed (list, optional): Feed variables, each a ``str`` name or a
                ``paddle.static.Variable``. Default: empty.
            fetch (list, optional): Fetch variables, each a ``str`` name or a
                ``paddle.static.Variable``. Default: empty.
            **configs: Only ``"mode"`` is read (converted with ``int``,
                default 0) and only in the persistables case.

        Returns:
            None

        Raises:
            ValueError: If an entry of ``feed`` or ``fetch`` is neither a
                ``str`` nor a ``paddle.static.Variable``.
        """

        # NOTE: the list defaults are kept for interface compatibility; they
        # are only read, never mutated, so sharing them across calls is safe.

        def _var_names(items, kind):
            # Normalize a mix of names and Variables into a list of names.
            names = []
            for item in items:
                if isinstance(item, str):
                    names.append(item)
                elif isinstance(item, paddle.static.Variable):
                    names.append(item.name)
                else:
                    raise ValueError(
                        "{} must be [str|Variable]".format(kind))
            return names

        executor = paddle.static.Executor(paddle.CPUPlace())

        if feed or fetch:
            feeded_var_names = _var_names(feed, "feed")
            fetch_var_names = _var_names(fetch, "fetch")

            global_block = paddle.static.default_main_program().global_block()
            fetch_vars = [global_block.var(name) for name in fetch_var_names]

            self._runtime_handle._save_inference_model(
                executor, dirname, feeded_var_names, fetch_vars, None, True, 0)
        else:
            increment_mode = int(configs.get("mode", 0))
            self._runtime_handle._save_persistables(
                executor, dirname, main_program=None, mode=increment_mode)

743 744 745 746 747 748
    def save_inference_model(self,
                             executor,
                             dirname,
                             feeded_var_names,
                             target_vars,
                             main_program=None,
749 750
                             export_for_deployment=True,
                             mode=0):
751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769
        """
        save inference model for inference.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_server()

        """
T
tangwei12 已提交
770 771 772
        # warnings.warn(
        #     "'save_inference_model' is a deprecated, will be deleted after v2.2.0, Please use fleet.save instead."
        # )
773

774 775
        self._runtime_handle._save_inference_model(
            executor, dirname, feeded_var_names, target_vars, main_program,
776
            export_for_deployment, mode)
777

778
    def save_persistables(self, executor, dirname, main_program=None, mode=0):
779 780
        """

1
123malin 已提交
781
        saves all persistable tensors from :code:`main_program` to
782 783
        the folder :code:`dirname`. You can refer to

1
123malin 已提交
784 785
        The :code:`dirname` is used to specify the folder where persistable tensors
        are going to be saved. If you would like to save tensors in separate
786 787 788
        files, set :code:`filename` None.

        Args:
1
123malin 已提交
789
            executor(Executor): The executor to run for saving persistable tensors.
790 791 792 793 794
                                You can refer to :ref:`api_guide_executor_en` for
                                more details.

            dirname(str, optional): The saving directory path.
                                When you need to save the parameter to the memory, set it to None.
1
123malin 已提交
795
            main_program(Program, optional): The program whose persistbale tensors will
796 797 798 799 800 801 802 803 804 805
                                             be saved. Default: None.


        Returns:
            None

        Examples:

            .. code-block:: text

1
123malin 已提交
806 807
                import paddle
                paddle.enable_static()
808 809 810 811 812 813 814
                import paddle.distributed.fleet as fleet

                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

1
123malin 已提交
815 816
                exe = paddle.static.Executor(paddle.CPUPlace())
                fleet.save_persistables(exe, "dirname", paddle.static.default_main_program())
817 818

        """
T
tangwei12 已提交
819 820 821
        # warnings.warn(
        #     "'save_persistables' is a deprecated, will be deleted after v2.2.0, Please use fleet.save instead."
        # )
822

823 824
        self._runtime_handle._save_persistables(executor, dirname, main_program,
                                                mode)
825

826
    def shrink(self, threshold=None):
827 828
        self._runtime_handle._shrink(threshold)

829
    def distributed_optimizer(self, optimizer, strategy=None):
        """
        Optimizer for distributed training.

        Records the user optimizer (and, if given, a deep copy of the
        strategy) on this Fleet instance and resets ``self._context``.

        Args:
            optimizer(Optimizer): The user-defined optimizer to be used for
                distributed training.
            strategy(DistributedStrategy): Extra properties for distributed optimizer.
                It is recommended to use DistributedStrategy in fleet.init(). The strategy
                here is for compatibility. If the strategy in fleet.distributed_optimizer()
                is not None, then it will overwrite the DistributedStrategy in fleet.init(),
                which will take effect in distributed training.

        Returns:
            In dygraph mode: a ``HybridParallelOptimizer`` wrapping
            ``optimizer`` when there is more than one worker, otherwise
            ``optimizer`` itself. In all other cases: this Fleet instance.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.distributed.fleet as fleet
                fleet.init(is_collective=True)
                strategy = fleet.DistributedStrategy()
                optimizer = paddle.optimizer.SGD(learning_rate=0.001)
                optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)

        """
        self.user_defined_optimizer = optimizer

        if strategy is not None:
            if self._is_collective:
                warnings.warn(
                    "It is recommended to use DistributedStrategy "
                    "in fleet.init(). The strategy here is only for compatibility. "
                    "If the strategy in fleet.distributed_optimizer() is "
                    "not None, then it will overwrite the DistributedStrategy in fleet.init(), "
                    "which will take effect in distributed training.")
            # Deep copy so later changes to the caller's object do not leak in.
            self._user_defined_strategy = copy.deepcopy(strategy)

        self._context = {}

        if paddle.fluid.framework.in_dygraph_mode():
            if self.worker_num() > 1:
                return HybridParallelOptimizer(optimizer, self._hcg,
                                               self._user_defined_strategy)
            return optimizer
        return self

881
    @dygraph_only
    def distributed_model(self, model):
        """
        Return the distributed wrapper of ``model`` that matches the current
        parallel mode. (Only work in dygraph mode)

        With at most one worker the model is returned unchanged. Otherwise the
        parallel mode reported by the hybrid communicate group selects one of
        ``ShardingParallel``, ``paddle.DataParallel``, ``TensorParallel`` or
        ``PipelineParallel``.

        Args:
            model (Layer): the user-defined model which inherits Layer.

        Returns:
            The distributed model which inherits Layer, or ``model`` itself
            when there is at most one worker.

        Raises:
            AssertionError: if ``model`` is None, or if ``sharding_degree``
                disagrees with the sharding world size of the communicate group.
            ValueError: if the parallel mode is not one of the supported modes.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                # 1. initialize fleet environment
                fleet.init(is_collective=True)

                # 2. create layer & optimizer
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                # 3. get data_parallel model using fleet
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                # 4. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()

        """
        assert model is not None, "model should not be None"
        if self.worker_num() <= 1:
            return model

        # Read the mode once instead of re-querying it for every branch.
        parallel_mode = self._hcg.get_parallel_mode()
        strategy = self._user_defined_strategy

        if parallel_mode == ParallelMode.SHARDING_PARALLEL:
            return ShardingParallel(model, self._hcg, strategy=strategy)

        if parallel_mode == ParallelMode.DATA_PARALLEL:
            # NOTE (JZ-LIANG) init parameters broadcast within sharding group
            # normally it should be done inside DataParallel
            if self.sharding_degree > 1:
                from paddle.distributed.fleet.utils.hybrid_parallel_util import broadcast_sharding_parameters
                assert self.sharding_degree == self._hcg.get_sharding_parallel_world_size(
                )
                broadcast_sharding_parameters(model, self._hcg)
            return paddle.DataParallel(
                model,
                comm_buffer_size=strategy.fuse_grad_size_in_MB,
                last_comm_buffer_size=strategy.last_comm_group_size_MB,
                find_unused_parameters=strategy.find_unused_parameters)

        if parallel_mode == ParallelMode.TENSOR_PARALLEL:
            return TensorParallel(model, self._hcg, strategy=strategy)

        if parallel_mode == ParallelMode.PIPELINE_PARALLEL:
            return PipelineParallel(model, self._hcg, strategy=strategy)

        # Previously an unmatched mode fell through to an UnboundLocalError on
        # the return statement; fail with an explicit message instead.
        raise ValueError(
            "Unsupported parallel mode for distributed_model: {}".format(
                parallel_mode))
969 970 971 972 973

    @dygraph_only
    def state_dict(self):
        """
        Return the state dict of the user-defined optimizer.
        (Only work in dygraph mode)

        Returns:
            state_dict(dict) : dict contains all the Tensor used by optimizer

        Examples:
            .. code-block:: python

                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)
                state_dict = adam.state_dict()
        """
        # Forward to the optimizer registered through distributed_optimizer().
        inner_optimizer = self.user_defined_optimizer
        return inner_optimizer.state_dict()

    @dygraph_only
    def set_state_dict(self, state_dict):
        """
        Load a state dict into the user-defined optimizer.
        (Only work in dygraph mode)

        Args:
            state_dict(dict) : Dict contains all the Tensor needed by optimizer

        Returns:
            Whatever the wrapped optimizer's ``set_state_dict`` returns
            (documented upstream as None).

        Examples:
            .. code-block:: python

                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)
                state_dict = adam.state_dict()
                paddle.save(state_dict, "paddle_dy")
                para_state_dict = paddle.load("paddle_dy")
                adam.set_state_dict(para_state_dict)
        """
        # Forward to the optimizer registered through distributed_optimizer().
        inner_optimizer = self.user_defined_optimizer
        return inner_optimizer.set_state_dict(state_dict)

    @dygraph_only
    def set_lr(self, value):
        """
        Manually set the learning rate of the user-defined optimizer.
        (Only work in dygraph mode)

        Args:
            value (float|Tensor): the value of learning rate

        Returns:
            Whatever the wrapped optimizer's ``set_lr`` returns
            (documented upstream as None).

        Examples:
            .. code-block:: python

                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                lr_list = [0.2, 0.3, 0.4, 0.5, 0.6]
                for i in range(5):
                    adam.set_lr(lr_list[i])
                    lr = adam.get_lr()
                    print("current lr is {}".format(lr))
                # Print:
                #    current lr is 0.2
                #    current lr is 0.3
                #    current lr is 0.4
                #    current lr is 0.5
                #    current lr is 0.6
        """
        # Forward to the optimizer registered through distributed_optimizer().
        inner_optimizer = self.user_defined_optimizer
        return inner_optimizer.set_lr(value)

    @dygraph_only
    def get_lr(self):
        """
        Return the learning rate of the current step from the user-defined optimizer.
        (Only work in dygraph mode)

        Returns:
            float: The learning rate of the current step.

        Examples:

            .. code-block:: python

                import paddle
                from paddle.distributed import fleet

                fleet.init(is_collective=True)

                layer = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01, parameters=layer.parameters())

                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                lr = adam.get_lr()
                print(lr) # 0.01
        """
        # Forward to the optimizer registered through distributed_optimizer().
        inner_optimizer = self.user_defined_optimizer
        return inner_optimizer.get_lr()

    @dygraph_only
    def step(self):
        """
        Run one update step of the user-defined optimizer.
        (Only work in dygraph mode)

        Returns:
            Whatever the wrapped optimizer's ``step`` returns
            (documented upstream as None).

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                # 1. initialize fleet environment
                fleet.init(is_collective=True)

                # 2. create layer & optimizer
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                # 3. get data_parallel model using fleet
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                # 4. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()

        """
        # Forward to the optimizer registered through distributed_optimizer().
        inner_optimizer = self.user_defined_optimizer
        return inner_optimizer.step()

    @dygraph_only
    def clear_grad(self):
        """
        Clear the gradients of all parameters handled by the user-defined optimizer.
        (Only work in dygraph mode)

        Returns:
            Whatever the wrapped optimizer's ``clear_grad`` returns
            (documented upstream as None).

        Examples:

            .. code-block:: python

                import paddle
                import paddle.nn as nn
                from paddle.distributed import fleet

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                # 1. initialize fleet environment
                fleet.init(is_collective=True)

                # 2. create layer & optimizer
                layer = LinearNet()
                loss_fn = nn.MSELoss()
                adam = paddle.optimizer.Adam(
                    learning_rate=0.001, parameters=layer.parameters())

                # 3. get data_parallel model using fleet
                adam = fleet.distributed_optimizer(adam)
                dp_layer = fleet.distributed_model(layer)

                # 4. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                print("loss:", loss.numpy())

                loss.backward()

                adam.step()
                adam.clear_grad()

        """
        # Forward to the optimizer registered through distributed_optimizer().
        inner_optimizer = self.user_defined_optimizer
        return inner_optimizer.clear_grad()

1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246
    def _get_amp_optimizer(self):
        # imitate target optimizer retrieval
        amp_optimizer = None
        for optimizer in self.strategy_compiler._get_applied_meta_optimizer():
            if hasattr(optimizer, 'amp_init'):
                amp_optimizer = optimizer
                break

        if amp_optimizer is None:
            if hasattr(self.user_defined_optimizer, 'amp_init'):
                amp_optimizer = self.user_defined_optimizer

        assert amp_optimizer is not None, \
            "amp_init can only be used when the amp(auto mixed precision) strategy is turned on."
        return amp_optimizer

    def get_loss_scaling(self):
        """Return the real-time loss scaling factor of the AMP optimizer.
        """
        return self._get_amp_optimizer().get_loss_scaling()

H
huangxu96 已提交
1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311
    def amp_init(self,
                 place,
                 scope=None,
                 test_program=None,
                 use_fp16_test=False):
        """
        Init the amp training, such as cast fp32 parameters to fp16 type.

        The call is forwarded, with the arguments in the same order, to the
        optimizer found by ``_get_amp_optimizer``.

        Args:
            place(CUDAPlace): place is used to initialize
                fp16 parameters with fp32 values.
            scope(Scope): The scope is used to find fp32 parameters.
            test_program(Program): The program is used for testing.
            use_fp16_test(bool): Whether to use fp16 testing.

        Returns:
            Whatever the AMP optimizer's ``amp_init`` returns.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle
                import paddle.nn.functional as F
                paddle.enable_static()

                def run_example_code():
                    place = paddle.CUDAPlace(0)
                    exe = paddle.static.Executor(place)
                    data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
                    conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
                    # 1) Use fp16_guard to control the range of fp16 kernels used.
                    with paddle.static.amp.fp16_guard():
                        bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
                        pool = F.max_pool2d(bn, kernel_size=2, stride=2)
                        hidden = paddle.static.nn.fc(pool, size=10)
                        loss = paddle.mean(hidden)
                    # 2) Create the optimizer and set `multi_precision` to True.
                    # Setting `multi_precision` to True can avoid the poor accuracy
                    # or the slow convergence in a way.
                    optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
                    # 3) These ops in `custom_black_list` will keep in the float32 computation type.
                    amp_list = paddle.static.amp.CustomOpLists(
                        custom_black_list=['pool2d'])
                    # 4) The entry of Paddle AMP.
                    # Enable pure fp16 training by setting `use_pure_fp16` to True.
                    optimizer = paddle.static.amp.decorate(
                        optimizer,
                        amp_list,
                        init_loss_scaling=128.0,
                        use_dynamic_loss_scaling=True,
                        use_pure_fp16=True)
                    # If you don't use the default_startup_program(), you should pass
                    # your defined `startup_program` into `minimize`.
                    optimizer.minimize(loss)
                    exe.run(paddle.static.default_startup_program())
                    # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
                    # If you want to perform the testing process, you should pass `test_program` into `amp_init`.
                    optimizer.amp_init(place, scope=paddle.static.global_scope())

                if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                    run_example_code()
        """
        return self._get_amp_optimizer().amp_init(place, scope, test_program,
                                                  use_fp16_test)
H
huangxu96 已提交
1314

D
Dong Daxiang 已提交
1315 1316 1317 1318 1319 1320 1321 1322 1323
    def _final_strategy(self):
        if "valid_strategy" not in self._context:
            print(
                "WARNING: You may need to call minimize function before this function is called"
            )
            return {}
        else:
            return self._context["valid_strategy"]

1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341
    def _get_applied_meta_list(self):
        if "applied_meta_list" not in self._context:
            print(
                "WARNING: You may need to call minimize function before _get_applied_meta_list called"
            )
            return []
        else:
            return self._context["applied_meta_list"]

    def _get_applied_graph_list(self):
        if "applied_graph_list" not in self._context:
            print(
                "WARNING: You may need to call minimize function before _get_applied_graph_list called"
            )
            return []
        else:
            return self._context["applied_graph_list"]

1342 1343 1344 1345 1346 1347 1348 1349 1350
    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
        """
        Add distributed operations to minimize ``loss`` by updating ``parameter_list``.

        Args:
            loss (Tensor): A ``Tensor`` containing the value to minimize.
            startup_program (Program, optional): :ref:`api_fluid_Program` for
                initializing parameters in ``parameter_list``. The default value
                is None, at this time :ref:`api_fluid_default_startup_program` will be used.
            parameter_list (Iterable, optional): Iterable of ``Tensor`` or ``Tensor.name`` to update
                to minimize ``loss``. The default value is None, at this time all parameters
                will be updated.
            no_grad_set (set, optional): Set of ``Tensor``  or ``Tensor.name`` that don't need
                to be updated. The default value is None.

        Returns:
            tuple: tuple (optimize_ops, params_grads), A list of operators appended
            by minimize and a list of (param, grad) tensor pairs, param is
            ``Parameter``, grad is the gradient value corresponding to the parameter.
            The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
            indicate program pruning. If so, the program will be pruned by ``feed`` and
            ``fetch_list`` before run, see details in ``Executor``.

            Other return shapes produced by this method:

            * In dygraph mode the result of the user-defined optimizer's
              ``minimize(loss)`` is returned as-is; ``startup_program``,
              ``parameter_list`` and ``no_grad_set`` are not forwarded.
            * When the strategy's ``semi_auto`` flag is set, a 4-tuple
              ``(optimize_ops, params_grads, dist_startup_prog, dist_main_prog)``
              from the auto-parallelizer is returned.

        Examples:

            .. code-block:: python

                import paddle
                paddle.enable_static()
                import paddle.distributed.fleet as fleet
                import paddle.nn.functional as F

                hid_dim = 10
                label_dim = 2
                input_x = paddle.static.data(name='x', shape=[None, 13], dtype='float32')
                input_y = paddle.static.data(name='y', shape=[None, 1], dtype='int64')
                fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim, activation='tanh')
                fc_2 = paddle.static.nn.fc(x=fc_1, size=hid_dim, activation='tanh')
                prediction = paddle.static.nn.fc(x=[fc_2], size=label_dim, activation='softmax')
                cost = F.cross_entropy(input=prediction, label=input_y)
                avg_cost = paddle.mean(x=cost)

                fleet.init(is_collective=True)
                strategy = fleet.DistributedStrategy()
                optimizer = paddle.optimizer.SGD(learning_rate=0.001)
                optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
                optimizer.minimize(avg_cost)

                # for more examples, please reference https://github.com/PaddlePaddle/FleetX

        """
        context = {}
        context["user_defined_strategy"] = copy.deepcopy(
            self._user_defined_strategy)
        if paddle.fluid.framework.in_dygraph_mode():
            # imitate target optimizer retrieval
            target_opt = self.user_defined_optimizer
            self._context = context
            return target_opt.minimize(loss)

        # cache original feed forward program
        self.origin_main_program = loss.block.program
        context["origin_main_program"] = self.origin_main_program
        context["loss"] = loss

        # Identity check: `== None` would dispatch to a user-defined __eq__.
        if startup_program is None:
            startup_program = paddle.static.default_startup_program()
        self.origin_startup_program = startup_program.clone(for_test=False)

        context["origin_startup_program"] = startup_program
        context["role_maker"] = self._role_maker

        # Use the auto-parallel's routines instead
        if self._user_defined_strategy.semi_auto:
            from ...auto_parallel.parallelizer import AutoParallelizer
            auto_parallelizer = AutoParallelizer(self)
            optimize_ops, params_grads, dist_startup_prog, dist_main_prog = auto_parallelizer.parallelize(
                loss, startup_program, parameter_list, no_grad_set)

            return optimize_ops, params_grads, dist_startup_prog, dist_main_prog

        # compile time
        distributed_optimizer_list = \
            MetaOptimizerFactory()._get_valid_meta_optimizers(
                self.user_defined_optimizer)

        # context["user_defined_strategy"] already holds a deep copy made at
        # the top of this method; only the working copy is created here.
        copy_user_defined_strategy = copy.deepcopy(self._user_defined_strategy)

        # trigger the auto-parallel in very strict condition
        # strategy = DistributedStrategy()
        # strategy.auto = True
        # optimizer = paddle.optimizer.SGD(learning_rate=0.1)
        # optimizer = fleet.distributed_optimizer(optimizer, strategy)
        if copy_user_defined_strategy._is_strict_auto():
            # turn on all the strategy for each optimizer
            for opt in distributed_optimizer_list:
                opt._enable_strategy(copy_user_defined_strategy, context)

        valid_optimizer_list = []
        valid_graph_optimizer_list = []
        can_not_apply_optimizer_list = []
        # recall meta optimizers for ranking
        for opt in distributed_optimizer_list:
            opt._set_basic_info(loss, self._role_maker,
                                self.user_defined_optimizer,
                                copy_user_defined_strategy)
            # Evaluate applicability once per optimizer, then route it.
            if not opt._can_apply():
                can_not_apply_optimizer_list.append(opt)
            elif opt._is_graph_out():
                valid_graph_optimizer_list.append(opt)
            else:
                valid_optimizer_list.append(opt)

        # combine recalled meta optimizers to be a valid meta optimizer
        meta_optimizer, graph_optimizer = \
            self.strategy_compiler.generate_optimizer(
                loss, self._role_maker, self.user_defined_optimizer,
                copy_user_defined_strategy, valid_optimizer_list,
                valid_graph_optimizer_list)

        valid_strategy = self.strategy_compiler._get_valid_strategy(
            copy_user_defined_strategy, can_not_apply_optimizer_list)

        context["valid_strategy"] = copy.deepcopy(valid_strategy)

        applied_meta_list = self.strategy_compiler._get_applied_meta_list()
        applied_graph_list = self.strategy_compiler._get_applied_graph_list()

        context['applied_meta_list'] = applied_meta_list
        context['applied_graph_list'] = applied_graph_list

        self._context = context

        self.valid_strategy = valid_strategy
        self.valid_strategy._enable_env()

        if self._role_maker._is_non_distributed() and not self._is_collective:
            if self._runtime_handle is None:
                self._runtime_handle = RuntimeFactory()._create_runtime(context)

            compiled_program = compiler.CompiledProgram(
                self.origin_main_program).with_data_parallel(
                    loss_name=loss.name, share_vars_from=None)
            loss.block.program._graph = compiled_program
            return self.user_defined_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)

        if meta_optimizer:
            optimize_ops, params_grads = meta_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)

            default_program = paddle.static.default_main_program()

            # Make the program that owns `loss` the default main program.
            if default_program is not loss.block.program:
                paddle.fluid.framework.switch_main_program(loss.block.program)

        else:
            optimize_ops, params_grads = self.user_defined_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)

        context["program_optimize_ops"] = optimize_ops
        context["program_params_grads"] = params_grads

        if graph_optimizer:
            optimize_ops, params_grads = graph_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set)
            # since we do not encourage users to use graph operations
            # if a graph optimizer takes effect, mostly
            # optimizers_ops and params_grads are None
            # i.e. users can not modify current computation graph anymore
            context["graph_optimize_ops"] = optimize_ops
            context["graph_optimize_grads"] = params_grads
        else:
            apply_ir_passes(loss.block.program, startup_program, self)

        if not self._role_maker._is_heter_parameter_server_mode:
            program = paddle.static.default_main_program()
            opt_info = {}
            opt_info["mpi_size"] = self.worker_num()
            opt_info["mpi_rank"] = self.worker_index()
            opt_info.update(
                self._user_defined_strategy.trainer_desc_configs.items())
            program._fleet_opt = opt_info

        if self._runtime_handle is None:
            self._runtime_handle = RuntimeFactory()._create_runtime(context)

        # Imported here rather than at module top, as in the original code.
        import paddle.distributed.fleet as fleet
        fleet.util._set_strategy(context["valid_strategy"])

        return optimize_ops, params_grads
1545 1546 1547

    @dygraph_only
    def distributed_scaler(self, scaler):
1548 1549 1550 1551 1552 1553
        def unscale_method(self, optimizer):
            """Unscale the optimizer's gradients and agree on the inf/nan flag.

            ``self`` is the object this function gets bound to (presumably the
            scaler passed to ``distributed_scaler`` -- confirm at the binding
            site); it must provide ``_enable`` and ``_scale`` and receives
            ``_found_inf``.

            Gradients are split into an FP16 list and an FP32 list, each list is
            passed to ``check_finite_and_unscale``, and the resulting found-inf
            flag is MAX all-reduced over the default group.

            Args:
                optimizer: optimizer exposing either ``_param_groups`` (a list
                    of dicts with a ``'params'`` entry) or ``_parameter_list``.

            Returns:
                None. Does nothing when ``self._enable`` is false.
            """
            if not self._enable:
                return

            fp16_dtype = core.VarDesc.VarType.FP16
            fp32_dtype = core.VarDesc.VarType.FP32
            param_grads_fp16 = []
            param_grads_fp32 = []

            if getattr(optimizer, '_param_groups', None) and isinstance(
                    optimizer._param_groups[0], dict):
                for group in optimizer._param_groups:
                    for param in group['params']:
                        grad = param._grad_ivar()
                        if grad is None:
                            continue
                        # NOTE(review): here every non-FP16 gradient lands in
                        # the FP32 list, unlike the flat-list branch below.
                        if grad.dtype == fp16_dtype:
                            param_grads_fp16.append(grad)
                        else:
                            param_grads_fp32.append(grad)
            else:
                for param in optimizer._parameter_list:
                    grad = param._grad_ivar()
                    if grad is None:
                        continue
                    # Gradients that are neither FP16 nor FP32 are skipped.
                    if grad.dtype == fp16_dtype:
                        param_grads_fp16.append(grad)
                    elif grad.dtype == fp32_dtype:
                        param_grads_fp32.append(grad)

            # `np.bool` was removed in NumPy 1.24; `np.bool_` is the scalar type.
            temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool_))
            temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool_))
            if len(param_grads_fp16):
                _C_ops.check_finite_and_unscale(param_grads_fp16, self._scale,
                                                param_grads_fp16,
                                                temp_found_inf_fp16)
            if len(param_grads_fp32):
                _C_ops.check_finite_and_unscale(param_grads_fp32, self._scale,
                                                param_grads_fp32,
                                                temp_found_inf_fp32)

            self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
            is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")

            # TODO(shenliang03) Since dp allreduce in the optimizer is
            # after the gradscaler, check_finite needs to synchronize global
            # information. In the future, we should use check_group to speed.
            paddle.distributed.all_reduce(
                is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None)
            self._found_inf = is_found_inf.numpy()[0]
1600 1601 1602 1603 1604 1605 1606

        # Only tensor_parallel and pipeline_parallel need to modify scaler
        if self._hcg.get_parallel_mode() in (ParallelMode.TENSOR_PARALLEL,
                                             ParallelMode.PIPELINE_PARALLEL):
            scaler._unscale = MethodType(unscale_method, scaler)

        return scaler