#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import paddle
import os
from paddle.fluid.framework import _global_flags
from paddle.fluid import compiler
from .base.role_maker import PaddleCloudRoleMaker, RoleMakerBase
from .base.strategy_compiler import StrategyCompiler
from .base.distributed_strategy import DistributedStrategy
from .base.meta_optimizer_factory import MetaOptimizerFactory
from .base.runtime_factory import RuntimeFactory
from paddle.fluid.wrapped_decorator import wrap_decorator
from paddle.fluid.dygraph import parallel_helper
from paddle.fluid.ir import apply_build_strategy
from .base import topology as tp
from .meta_parallel import model_parallel_random_seed
from .utils.log_util import logger, set_log_level

__all__ = []


def apply_ir_passes(main_program, startup_program, config):
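    # Apply the IR passes implied by the user-defined BuildStrategy to the
    # main/startup programs; if FLAGS_apply_pass_to_program is off, the copied
    # BuildStrategy is returned unchanged.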
    build_strategy = config._user_defined_strategy.build_strategy._copy()
    if not _global_flags()['FLAGS_apply_pass_to_program']:
        return build_strategy

    pipeline_opt = getattr(main_program, "_pipeline_opt", {})
    if pipeline_opt:
        main_program = pipeline_opt["section_program"]
        startup_program = startup_program._pipeline_opt["startup_program"]

    pass_attrs = {"use_cuda": config._is_collective}
    fuse_all_reduce = config._user_defined_strategy.fuse_all_reduce_ops
    if fuse_all_reduce and build_strategy.fuse_all_optimizer_ops:
        # FIXME(zjl): currently, fuse_all_optimizer_ops
        # have conflict with fuse_all_reduce_ops because
        # RawProgramOptimizer also inserts coalesce_tensor
        # into program. These two procedures may conflict
        # in which vars are to be fused.
        logger.warning(
            'Currently, the fuse_all_optimizer_ops pass conflicts with the fuse_all_reduce_ops pass. Disabling the fuse_all_optimizer_ops pass temporarily.'
        )
        build_strategy.fuse_all_optimizer_ops = False

    return apply_build_strategy(
        main_program, startup_program, build_strategy, pass_attrs
    )


def _inited_runtime_handler_(func):
    def __impl__(*args, **kwargs):
        cls = args[0]

        if cls._runtime_handle is None:
            raise ValueError("Fleet can not find suitable runtime handler")

        return func(*args, **kwargs)

    return __impl__


def _is_non_distributed_check_(func):
    def __impl__(*args, **kwargs):
        cls = args[0]

        if (
            cls._role_maker is not None
            and cls._role_maker._is_non_distributed() is True
        ):
            logger.warning(
                "%s() function doesn't work when use non_distributed fleet."
                % (func.__name__)
            )
            return

        return func(*args, **kwargs)

    return __impl__


inited_runtime_handler = wrap_decorator(_inited_runtime_handler_)
is_non_distributed_check = wrap_decorator(_is_non_distributed_check_)


class Fleet(object):
    """
    Unified API for distributed training of PaddlePaddle
    Please refer to https://github.com/PaddlePaddle/PaddleFleetX for details


    Returns:
        Fleet: A Fleet instance

    Example for collective training:

        .. code-block:: python

            import paddle
            paddle.enable_static()
            import paddle.distributed.fleet as fleet

            fleet.init(is_collective=True)

            strategy = fleet.DistributedStrategy()
            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
            optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)

            # do distributed training


    Example for parameter server training:

        .. code-block:: python

            import paddle
            paddle.enable_static()
            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            fleet.init(strategy=strategy)

            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
            optimizer = fleet.distributed_optimizer(optimizer)

            if fleet.is_first_worker():
                print("this is first worker")

            print("current node index: {}".format(fleet.worker_index()))
            print("total number of worker num: {}".format(fleet.worker_num()))

            if fleet.is_worker():
                print("this is worker")
            print("worker endpoints: {}".format(fleet.worker_endpoints(to_string=True)))

            print("server num: {}".format(fleet.server_num()))
            print("server endpoints: {}".format(fleet.server_endpoints(to_string=True)))

            if fleet.is_server():
                print("this is server")
            fleet.stop_worker()


    """

    def __init__(self):
        self._role_maker = None
        self.strategy_compiler = None
        self._is_collective = False
        self._runtime_handle = None
        self._util = None
        self._context = {}
        self.user_defined_optimizer = paddle.optimizer.Optimizer(0.0)

    def init(
        self,
        role_maker=None,
        is_collective=False,
        strategy=None,
        log_level="INFO",
    ):
        """
        Initialize role_maker in Fleet.

        This function is responsible for selecting the distributed architecture
        that your code runs on.

        Args:
            role_maker (RoleMakerBase, optional): A ``RoleMakerBase`` containing the configuration
                of environment variables related to distributed training. If you did not initialize
                the role maker by yourself, it will be automatically initialized to PaddleCloudRoleMaker.
                The default value is None.
            is_collective (Boolean, optional): A ``Boolean`` variable that determines whether the program
                runs on Collective mode or ParameterServer mode. True means the program runs on
                Collective mode, and False means running on ParameterServer mode. The default value
                is False.
            strategy (DistributedStrategy): Extra properties for distributed training.
                For details, please refer to paddle.distributed.fleet.DistributedStrategy. Default: None.
            log_level (Integer, String, optional): An ``Integer`` or ``String`` variable determining
                the logging level. Default is "INFO".


        Returns:
            None

        Examples1:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

        Examples2:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init(is_collective=True)

        Examples3:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                role = fleet.PaddleCloudRoleMaker()
                fleet.init(role)

        Examples4:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                strategy = fleet.DistributedStrategy()
                fleet.init(strategy=strategy)

        Examples5:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                strategy = fleet.DistributedStrategy()
                fleet.init(log_level = "DEBUG")
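
        Examples6:

            .. code-block:: python

                # A minimal sketch of a dygraph hybrid-parallel setup; it
                # assumes the job is launched with enough processes to match
                # the configured degrees.
                import paddle.distributed.fleet as fleet
                strategy = fleet.DistributedStrategy()
                strategy.hybrid_configs = {
                    "dp_degree": 2,
                    "mp_degree": 1,
                    "pp_degree": 1,
                }
                fleet.init(is_collective=True, strategy=strategy)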

        """

        set_log_level(log_level)

        if strategy is None:
            strategy = DistributedStrategy()
        self._user_defined_strategy = copy.deepcopy(strategy)

        if role_maker is None:
            if isinstance(is_collective, bool):
                self._is_collective = is_collective
                self._role_maker = PaddleCloudRoleMaker(
                    is_collective=self._is_collective
                )
            else:
                raise ValueError(
                    "`is_collective` should be instance of `bool`, but got {}".format(
                        type(is_collective)
                    )
                )
        else:
            if isinstance(role_maker, RoleMakerBase):
                self._role_maker = role_maker
                self._is_collective = role_maker._is_collective
            else:
                raise ValueError(
                    "`role_maker` should be subclass of `RoleMakerBase`, but got {}".format(
                        type(role_maker)
                    )
                )
        self._role_maker._generate_role()

        import paddle.distributed.fleet as fleet

        fleet.util._set_role_maker(self._role_maker)

        self.strategy_compiler = StrategyCompiler()

        if self._role_maker._is_non_distributed() and self._is_collective:
            if paddle.fluid.core.is_compiled_with_cuda():
                gpus_num = paddle.fluid.core.get_cuda_device_count()
                if gpus_num != 1:
                    raise ValueError(
                        "CUDA_VISIBLE_DEVICES shoule be set only 1 card if you use `python` to launch fleet program."
                    )

        if paddle.fluid.framework._non_static_mode():
            if self.worker_num() == 1:
                # if worker_num is 1, should construct default topology & hcg
                self._topology = tp.CommunicateTopology()
                self._hcg = tp.HybridCommunicateGroup(self._topology)
                return
            if parallel_helper._is_parallel_ctx_initialized():
                logger.warning(
                    "The dygraph parallel environment has been initialized."
                )
            else:
                # FLAGS_nccl_nrings is used for dynamic graph multi-stream communication
                if "FLAGS_nccl_nrings" in os.environ:
                    logger.warning(
                        "You have set the environment variable FLAGS_nccl_nrings "
                        "outside the program, so the nccl_comm_num in "
                        "DistributedStrategy will not take effect here."
                    )
                else:
                    os.environ["FLAGS_nccl_nrings"] = str(
                        self._user_defined_strategy.nccl_comm_num
                    )
                paddle.distributed.init_parallel_env()

            # hybrid parallel is not supported on npu/xpu
            if not self._user_defined_strategy.heter_ccl_mode:
                # init hybrid parallel environment in dygraph
                if tp._HYBRID_PARALLEL_GROUP is None:
                    self._init_hybrid_parallel_env()
                else:
                    logger.warning(
                        "The dygraph hybrid parallel environment has been initialized."
                    )
        elif self._is_collective:
            use_sharding = self._user_defined_strategy.sharding

            # global group
            global_rank = self.worker_index()
            global_world_size = self.worker_num()
            # NOTE(wangxi): see sharding_optimizer
            global_ring_id = 3 if use_sharding else 0
            global_ranks = list(range(global_world_size))

            if tp._HYBRID_PARALLEL_GROUP is None:
                tp._CommunicateGroup()
            cg = tp._HYBRID_PARALLEL_GROUP
            self._hcg = cg
            cg.set_comm_group(
                'global',
                global_rank,
                global_world_size,
                global_ring_id,
                global_ranks,
            )

            use_tensor_parallel = self._user_defined_strategy.tensor_parallel
            use_mp = use_sharding or use_tensor_parallel

            # hybrid group
            if use_mp is False:
                return

            mp_degree_sharding = 1
            mp_degree_tensor_parallel = 1
            if use_sharding:
                sharding_configs = self._user_defined_strategy.sharding_configs
                mp_degree_sharding = int(sharding_configs['mp_degree'])

            if use_tensor_parallel:
                tensor_parallel_configs = (
                    self._user_defined_strategy.tensor_parallel_configs
                )
                mp_degree_tensor_parallel = int(
                    tensor_parallel_configs['tensor_parallel_degree']
                )

            if use_sharding and use_tensor_parallel:
                assert mp_degree_sharding == mp_degree_tensor_parallel

            mp_degree = (
                mp_degree_sharding
                if use_sharding
                else mp_degree_tensor_parallel
            )

            if mp_degree > 1:
                assert global_world_size % mp_degree == 0
                # NOTE(wangxi): mp_ring_id sync with sharding_optimizer.py _build_groups
                mp_ring_id = 0
                mp_rank = global_rank % mp_degree
                mp_group_id = global_rank // mp_degree
                mp_group_ranks = [
                    idx
                    for idx in global_ranks
                    if idx // mp_degree == mp_group_id
                ]
                cg.set_comm_group(
                    'model', mp_rank, mp_degree, mp_ring_id, mp_group_ranks
                )
        return self

    def _init_hybrid_parallel_env(self):
        """initialize the hybrid environment"""
        self.hybrid_configs = self._user_defined_strategy.hybrid_configs
        self.dp_degree = self.hybrid_configs["dp_degree"]
        self.mp_degree = self.hybrid_configs["mp_degree"]
        self.pp_degree = self.hybrid_configs["pp_degree"]
        self.sharding_degree = self.hybrid_configs["sharding_degree"]

        assert self.mp_degree >= 0, "mp_degree should be greater than or equal to 0"
        assert self.pp_degree >= 0, "pp_degree should be greater than or equal to 0"
        assert (
            self.sharding_degree >= 0
        ), "sharding_degree should be greater or equal to 0"

        self.mp_degree = max(self.mp_degree, 1)
        self.pp_degree = max(self.pp_degree, 1)

        if self.dp_degree < 0:
            nranks = paddle.distributed.get_world_size()
            self.dp_degree = nranks // (self.mp_degree * self.pp_degree)

        self.dp_degree = max(self.dp_degree, 1)

        self._topology = tp.CommunicateTopology(
            hybrid_group_names=["data", "pipe", "sharding", "model"],
            dims=[
                self.dp_degree,
                self.pp_degree,
                self.sharding_degree,
                self.mp_degree,
            ],
        )

        self._hcg = tp.HybridCommunicateGroup(self._topology)

        if self.mp_degree > 1:
            tensor_parallel_configs = (
                self._user_defined_strategy.tensor_parallel_configs
            )
            tensor_init_seed = tensor_parallel_configs["tensor_init_seed"]
            if tensor_init_seed == -1:
                model_parallel_random_seed()
            else:
                model_parallel_random_seed(tensor_init_seed)

    def get_hybrid_communicate_group(self):
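        """
        Return the HybridCommunicateGroup built during ``fleet.init()``;
        it is only available after ``init()`` has set up the parallel groups.
        """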
        assert self._hcg is not None
        return self._hcg

    def get_hybrid_parallel_topology(self):
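        """
        Return the CommunicateTopology describing the hybrid parallel layout;
        it is only available after ``fleet.init()`` has built the topology.
        """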
        assert self._topology is not None
        return self._topology

    def is_first_worker(self):
        """
        Check whether the node is the first instance of worker.

        Returns:
            bool: True if this is the first node of worker,
                  False if not.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_first_worker()

        """
        return self._role_maker._is_first_worker()

    def worker_index(self):
        """
        Get current worker index.

        Returns:
            int: node id

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_index()

        """
        return self._role_maker._worker_index()

    def worker_num(self):
        """
        Get current total worker number.

        Returns:
            int: worker numbers

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_num()

        """
        return self._role_maker._worker_num()

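    # Thin accessors that proxy the role maker: number of physical nodes, the
    # local rank on this node, and the device ids visible locally and across
    # all nodes.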
    def node_num(self):
        return self._role_maker._get_node_num()

    def local_rank(self):
        return self._role_maker._get_local_rank()

    def local_device_ids(self):
        return self._role_maker._get_local_device_ids()

    def world_device_ids(self):
        return self._role_maker._get_world_device_ids()

    def is_worker(self):
        """
        Check whether the node is an instance of worker.

        Returns:
            bool: True if this is a node of worker,
                  False if not.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_worker()

        """
        return self._role_maker._is_worker()

    def is_coordinator(self):
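        # Whether this node plays the coordinator role (see ``init_coordinator``).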
        return self._role_maker._is_coordinator()

    def worker_endpoints(self, to_string=False):
        """
        Get current worker endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].

        Returns:
            list/string: worker endpoints

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_endpoints()

        """
        if to_string:
            return ",".join(self._role_maker._get_trainer_endpoints())
        else:
            return self._role_maker._get_trainer_endpoints()

    def server_num(self):
        """
        Get current total server number.

        Returns:
            int: server number

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_num()
        """
        return len(self._role_maker._get_pserver_endpoints())

    def server_index(self):
        """
        Get current server index.

        Returns:
            int: node id

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_index()

        """
        return self._role_maker._server_index()

    def server_endpoints(self, to_string=False):
        """
        Get current server endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].

        Returns:
            list/string: server endpoints

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_endpoints()

        """

        if to_string:
            return ",".join(self._role_maker._get_pserver_endpoints())
        else:
            return self._role_maker._get_pserver_endpoints()

    def is_server(self):
        """
        Check whether the node is an instance of server.

        Returns:
            bool: True if this is a node of server,
                  False if not.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_server()

        """
        return self._role_maker._is_server()

    def barrier_worker(self):
        """
        barrier all workers

        Returns:
            None
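
        Examples:

            .. code-block:: python

                # A minimal sketch: assumes fleet has been initialized in
                # parameter-server mode and this process is a worker.
                import paddle.distributed.fleet as fleet
                fleet.init()

                if fleet.is_worker():
                    fleet.barrier_worker()
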
        """
        self._role_maker._barrier("worker")

    @is_non_distributed_check
    @inited_runtime_handler
    def init_worker(self, scopes=None):
        """
        initialize `Communicator` for parameter server training.


        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_worker()

        """
        self._runtime_handle._init_worker(scopes)

    @is_non_distributed_check
    @inited_runtime_handler
    def init_coordinator(self, scopes=None):
        """
        initialize coordinator node
        """
        self._runtime_handle._init_coordinator(scopes)

    def make_fl_strategy(self):
        self._runtime_handle._make_fl_strategy()

    @is_non_distributed_check
    @inited_runtime_handler
    def get_fl_client(self):
        """
        get worker(training node) ptr
        """
        return self._runtime_handle._worker

    @is_non_distributed_check
    @inited_runtime_handler
    def init_server(self, *args, **kwargs):
        """
        init_server runs an executor to initialize the startup program;
        if `args` is not empty, it will run load_persistables for incremental training.


        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_server()

        """
        self._runtime_handle._init_server(*args, **kwargs)

    @is_non_distributed_check
    @inited_runtime_handler
    def load_model(self, path, mode):
        """
        load fleet model from path


        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.load_model("path", mode=0)

        """
        self._runtime_handle._load_persistables(path, mode)

    @is_non_distributed_check
    @inited_runtime_handler
    def load_one_table(self, table_id, path, mode):
        """
        load one fleet table from the given path


        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.load_one_table(0, "path", mode=0)

        """
        self._runtime_handle._load_one_table(table_id, path, mode)

    @is_non_distributed_check
    @inited_runtime_handler
    def load_inference_model(self, path, mode):
        """
        load fleet inference model from path


        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.load_inference_model("path", mode=1)

        """
        self._runtime_handle._load_inference_model(path, mode)

    @is_non_distributed_check
    @inited_runtime_handler
    def run_server(self):
        """
        run_server runs the pserver main program with the executor.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                if fleet.is_server():
                    fleet.init_server()
                    fleet.run_server()

        """
        self._runtime_handle._run_server()

    @is_non_distributed_check
    @inited_runtime_handler
    def stop_worker(self):
        """
        stop `Communicator` and give training complete notice to parameter server.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.stop_worker()

        """
        self._runtime_handle._stop_worker()

    @is_non_distributed_check
    @inited_runtime_handler
    def save(self, dirname, feed=[], fetch=[], **configs):
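        # Unified save entry point: when non-empty ``feed``/``fetch`` lists are
        # given, an inference model is exported; otherwise persistable
        # parameters are saved (an optional increment ``mode`` may be passed
        # through ``configs``).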
        inference = True

        if not feed and not fetch:
            inference = False

        place = paddle.CPUPlace()
        executor = paddle.static.Executor(place)

        if inference:
            feeded_var_names = []
            fetch_var_names = []

            for var in feed:
                if isinstance(var, str):
                    feeded_var_names.append(var)
                elif isinstance(var, paddle.static.Variable):
                    feeded_var_names.append(var.name)
                else:
                    raise ValueError("feed must be [str|Variable]")

            for var in fetch:
                if isinstance(var, str):
                    fetch_var_names.append(var)
                elif isinstance(var, paddle.static.Variable):
                    fetch_var_names.append(var.name)
                else:
                    raise ValueError("feed must be [str|Variable]")

            fetch_vars = [
                paddle.static.default_main_program().global_block().var(name)
                for name in fetch_var_names
            ]

            self._runtime_handle._save_inference_model(
                executor, dirname, feeded_var_names, fetch_vars, None, True, 0
            )
        else:
            increment_mode = 0
            if "mode" in configs:
                increment_mode = int(configs["mode"])
            self._runtime_handle._save_persistables(
                executor, dirname, main_program=None, mode=increment_mode
            )

    @is_non_distributed_check
    @inited_runtime_handler
    def save_inference_model(
        self,
        executor,
        dirname,
        feeded_var_names,
        target_vars,
        main_program=None,
        export_for_deployment=True,
        mode=0,
    ):
        """
        save the inference model for deployment.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_server()

        """

        self._runtime_handle._save_inference_model(
            executor,
            dirname,
            feeded_var_names,
            target_vars,
            main_program,
            export_for_deployment,
            mode,
        )

    @is_non_distributed_check
    @inited_runtime_handler
    def save_persistables(self, executor, dirname, main_program=None, mode=0):
        """

        saves all persistable tensors from :code:`main_program` to
        the folder :code:`dirname`.

        The :code:`dirname` is used to specify the folder where persistable tensors
        are going to be saved. If you would like to save tensors in separate
        files, set :code:`filename` None.

        Args:
            executor(Executor): The executor to run for saving persistable tensors.
                                You can refer to :ref:`api_guide_executor_en` for
                                more details.

            dirname(str, optional): The saving directory path.
                                When you need to save the parameter to the memory, set it to None.
            main_program(Program, optional): The program whose persistable tensors will
                                             be saved. Default: None.


        Returns:
            None

        Examples:

            .. code-block:: text

                import paddle
                paddle.enable_static()
                import paddle.distributed.fleet as fleet

                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                exe = paddle.static.Executor(paddle.CPUPlace())
                fleet.save_persistables(exe, "dirname", paddle.static.default_main_program())

        """
        self._runtime_handle._save_persistables(
            executor, dirname, main_program, mode
        )

    @is_non_distributed_check
    @inited_runtime_handler
    def save_cache_model(self, dirname, **configs):
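        # Delegate cache-model saving to the runtime handle; extra ``configs``
        # are forwarded unchanged.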
        return self._runtime_handle._save_cache_model(dirname, **configs)

    @is_non_distributed_check
    @inited_runtime_handler
    def check_save_pre_patch_done(self):
        return self._runtime_handle._check_save_pre_patch_done()

    @is_non_distributed_check
    @inited_runtime_handler
    def save_one_table(self, table_id, path, mode):
        """
        save one fleet table to the given path


        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.save_one_table(0, "path", mode=0)

        """
        self._runtime_handle._save_one_table(table_id, path, mode)

    @is_non_distributed_check
    @inited_runtime_handler
    def save_dense_params(
        self, executor, dirname, scope, program, var_names=None
    ):
        """
        save dense parameters to the given path


        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                import paddle
                place = paddle.fluid.CPUPlace()
                exe = paddle.fluid.Executor(place)

                # build net
                # fleet.distributed_optimizer(...)

                fleet.save_dense_params(exe, "path", scope=paddle.static.global_scope(), program=paddle.static.default_main_program())

        """
        self._runtime_handle._save_dense_params(
            executor, dirname, scope, program, var_names
        )

    def shrink(self, threshold=None):
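        # Delegate shrinking of the distributed tables to the runtime handle;
        # ``threshold`` is forwarded unchanged.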
        self._runtime_handle._shrink(threshold)

    def distributed_optimizer(self, optimizer, strategy=None):
        """
        Optimizer for distributed training.

        For the distributed training, this method would rebuild a new instance of DistributedOptimizer.
        Which has basic Optimizer function and special features for distributed training.

        Args:
            optimizer(Optimizer): The optimizer to be wrapped for distributed training.
            strategy(DistributedStrategy): Extra properties for distributed optimizer.
                It is recommended to use DistributedStrategy in fleet.init(). The strategy
                here is for compatibility. If the strategy in fleet.distributed_optimizer()
                is not None, then it will overwrite the DistributedStrategy in fleet.init(),
                which will take effect in distributed training.

        Returns:
            Fleet: instance of fleet.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.distributed.fleet as fleet
                fleet.init(is_collective=True)
                strategy = fleet.DistributedStrategy()
                optimizer = paddle.optimizer.SGD(learning_rate=0.001)
                optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)

        """
        self.user_defined_optimizer = optimizer

        if strategy is not None:
            if self._is_collective:
                logger.warning(
                    "It is recommended to use DistributedStrategy "
                    "in fleet.init(). The strategy here is only for compatibility. "
                    "If the strategy in fleet.distributed_optimizer() is "
                    "not None, then it will overwrite the DistributedStrategy in fleet.init(), "
                    "which will take effect in distributed training."
                )
            self._user_defined_strategy = copy.deepcopy(strategy)

        self._context = {}

        return self

    def _get_amp_optimizer(self):
        # imitate target optimizer retrieval
        amp_optimizer = None
        for optimizer in self.strategy_compiler._get_applied_meta_optimizer():
            if hasattr(optimizer, 'amp_init'):
                amp_optimizer = optimizer
                break

        if amp_optimizer is None:
            if hasattr(self.user_defined_optimizer, 'amp_init'):
                amp_optimizer = self.user_defined_optimizer

        assert (
            amp_optimizer is not None
        ), "amp_init can only be used when the amp(auto mixed precision) strategy is turned on."
        return amp_optimizer

    def get_loss_scaling(self):
        """Return the real-time loss scaling factor."""
        amp_optimizer = self._get_amp_optimizer()
        return amp_optimizer.get_loss_scaling()

    def amp_init(
        self, place, scope=None, test_program=None, use_fp16_test=False
    ):
        """
        Initialize AMP training, e.g. cast fp32 parameters to fp16 type.

        Args:
            place(CUDAPlace): place is used to initialize
                fp16 parameters with fp32 values.
            scope(Scope): The scope is used to find fp32 parameters.
            test_program(Program): The program is used for testing.
            use_fp16_test(bool): Whether to use fp16 testing.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.nn.functional as F
                paddle.enable_static()

                def run_example_code():
                    place = paddle.CUDAPlace(0)
                    exe = paddle.static.Executor(place)
                    data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
                    conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
                    # 1) Use fp16_guard to control the range of fp16 kernels used.
                    with paddle.static.amp.fp16_guard():
                        bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
                        pool = F.max_pool2d(bn, kernel_size=2, stride=2)
                        hidden = paddle.static.nn.fc(pool, size=10)
                        loss = paddle.mean(hidden)
                    # 2) Create the optimizer and set `multi_precision` to True.
                    # Setting `multi_precision` to True can avoid poor accuracy
                    # or slow convergence in some cases.
                    optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
                    # 3) These ops in `custom_black_list` will keep in the float32 computation type.
                    amp_list = paddle.static.amp.CustomOpLists(
                        custom_black_list=['pool2d'])
                    # 4) The entry of Paddle AMP.
                    # Enable pure fp16 training by setting `use_pure_fp16` to True.
                    optimizer = paddle.static.amp.decorate(
                        optimizer,
                        amp_list,
                        init_loss_scaling=128.0,
                        use_dynamic_loss_scaling=True,
                        use_pure_fp16=True)
                    # If you don't use the default_startup_program(), you should pass
                    # your defined `startup_program` into `minimize`.
                    optimizer.minimize(loss)
                    exe.run(paddle.static.default_startup_program())
                    # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
                    # If you want to perform the testing process, you should pass `test_program` into `amp_init`.
                    optimizer.amp_init(place, scope=paddle.static.global_scope())

                if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                    run_example_code()
        """
        amp_optimizer = self._get_amp_optimizer()
        return amp_optimizer.amp_init(place, scope, test_program, use_fp16_test)

    def _final_strategy(self):
        if "valid_strategy" not in self._context:
            print(
                "WARNING: You may need to call minimize function before this function is called"
            )
            return {}
        else:
            return self._context["valid_strategy"]

    def _get_applied_meta_list(self):
        if "applied_meta_list" not in self._context:
            print(
                "WARNING: You may need to call minimize function before _get_applied_meta_list called"
            )
            return []
        else:
            return self._context["applied_meta_list"]

    def _get_applied_graph_list(self):
        if "applied_graph_list" not in self._context:
            print(
                "WARNING: You may need to call minimize function before _get_applied_graph_list called"
            )
            return []
        else:
            return self._context["applied_graph_list"]

    def minimize(
        self, loss, startup_program=None, parameter_list=None, no_grad_set=None
    ):
        """
        Add distributed operations to minimize ``loss`` by updating ``parameter_list``.

        Args:
            loss (Tensor): A ``Tensor`` containing the value to minimize.
            startup_program (Program, optional): :ref:`api_fluid_Program` for
                initializing parameters in ``parameter_list``. The default value
                is None, at this time :ref:`api_fluid_default_startup_program` will be used.
            parameter_list (Iterable, optional): Iterable of ``Tensor`` or ``Tensor.name`` to update
                to minimize ``loss``. The default value is None, at this time all parameters
                will be updated.
            no_grad_set (set, optional): Set of ``Tensor``  or ``Tensor.name`` that don't need
                to be updated. The default value is None.

        Returns:
            tuple: tuple (optimize_ops, params_grads), A list of operators appended
            by minimize and a list of (param, grad) tensor pairs, param is
            ``Parameter``, grad is the gradient value corresponding to the parameter.
            The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
            indicate program pruning. If so, the program will be pruned by ``feed`` and
            ``fetch_list`` before run, see details in ``Executor``.

        Examples:

            .. code-block:: python

                import paddle
                paddle.enable_static()
                import paddle.distributed.fleet as fleet
                import paddle.nn.functional as F

                hid_dim = 10
                label_dim = 2
                input_x = paddle.static.data(name='x', shape=[None, 13], dtype='float32')
                input_y = paddle.static.data(name='y', shape=[None, 1], dtype='int64')
                fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim, activation='tanh')
                fc_2 = paddle.static.nn.fc(x=fc_1, size=hid_dim, activation='tanh')
                prediction = paddle.static.nn.fc(x=[fc_2], size=label_dim, activation='softmax')
                cost = F.cross_entropy(input=prediction, label=input_y)
                avg_cost = paddle.mean(x=cost)

                fleet.init(is_collective=True)
                strategy = fleet.DistributedStrategy()
                optimizer = paddle.optimizer.SGD(learning_rate=0.001)
                optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
                optimizer.minimize(avg_cost)

                # for more examples, please reference https://github.com/PaddlePaddle/PaddleFleetX

        """
        if not isinstance(loss, list):
            return self._minimize_impl(
                loss, startup_program, parameter_list, no_grad_set
            )
        else:
            if (
                paddle.fluid.framework._non_static_mode()
                or self._role_maker._is_non_distributed()
                or self._is_collective
            ):
                raise ValueError("loss can be list only in PS mode")
            return self._minimize_losses_impl(
                loss, startup_program, parameter_list, no_grad_set
            )

    def _minimize_impl(
        self, loss, startup_program=None, parameter_list=None, no_grad_set=None
    ):
        context = {}
        context["user_defined_strategy"] = copy.deepcopy(
            self._user_defined_strategy
        )
        if paddle.fluid.framework._non_static_mode():
            # imitate target optimizer retrieval
            target_opt = self.user_defined_optimizer
            self._context = context
            return target_opt.minimize(loss)

        # cache original feed forward program
        self.origin_main_program = loss.block.program
        # add distributed attr
        if not hasattr(self.origin_main_program, "distributed_info_"):
            setattr(self.origin_main_program, "distributed_info_", dict())
            self.origin_main_program.distributed_info_[
                "dp_degree"
            ] = self._user_defined_strategy.sharding_configs["dp_degree"]
            self.origin_main_program.distributed_info_[
                "mp_degree"
            ] = self._user_defined_strategy.sharding_configs["mp_degree"]
            self.origin_main_program.distributed_info_[
                "pp_degree"
            ] = self._user_defined_strategy.sharding_configs["pp_degree"]
            self.origin_main_program.distributed_info_[
                "sharding_degree"
            ] = self._user_defined_strategy.sharding_configs["sharding_degree"]

        context["origin_main_program"] = self.origin_main_program
        context["origin_main_programs"] = [self.origin_main_program]
        context["loss"] = loss
        if startup_program is None:
            self.origin_startup_program = (
                paddle.static.default_startup_program().clone(for_test=False)
            )
            startup_program = paddle.static.default_startup_program()
        else:
            self.origin_startup_program = startup_program.clone(for_test=False)

        context["origin_startup_program"] = startup_program
        context["origin_startup_programs"] = [startup_program]
        context["role_maker"] = self._role_maker

        # Use the auto-parallel's routines instead
        if (
            self._user_defined_strategy.semi_auto
            or self._user_defined_strategy.auto_search
        ):
            from ..auto_parallel.parallelizer import AutoParallelizer

            auto_parallelizer = AutoParallelizer(self)
            (
                optimize_ops,
                params_grads,
                dist_startup_prog,
                dist_main_prog,
            ) = auto_parallelizer.parallelize(
                loss, startup_program, parameter_list, no_grad_set
            )

            return optimize_ops, params_grads, dist_startup_prog, dist_main_prog

        # compile time
        distributed_optimizer_list = (
            MetaOptimizerFactory()._get_valid_meta_optimizers(
                self.user_defined_optimizer
            )
        )

        context["user_defined_strategy"] = copy.deepcopy(
            self._user_defined_strategy
        )
        copy_user_defined_strategy = copy.deepcopy(self._user_defined_strategy)

        # trigger the auto-parallel in very strict condition
        # strategy = DistributedStrategy()
        # strategy.auto = True
        # optimizer = paddle.optimizer.SGD(learning_rate=0.1)
        # optimizer = fleet.distributed_optimizer(optimizer, strategy)
        if copy_user_defined_strategy._is_strict_auto():
            # turn on all the strategy for each optimizer
            for opt in distributed_optimizer_list:
                opt._enable_strategy(copy_user_defined_strategy, context)

        valid_optimizer_list = []
        valid_graph_optimizer_list = []
        can_not_apply_optimizer_list = []
        # recall meta optimizers for ranking
        for opt in distributed_optimizer_list:
            opt._set_basic_info(
                loss,
                self._role_maker,
                self.user_defined_optimizer,
                copy_user_defined_strategy,
            )
            if opt._can_apply() and not opt._is_graph_out():
                valid_optimizer_list.append(opt)
            elif opt._can_apply() and opt._is_graph_out():
                valid_graph_optimizer_list.append(opt)
            else:
                can_not_apply_optimizer_list.append(opt)
        # combine recalled meta optimizers to be a valid meta optimizer
        (
            meta_optimizer,
            graph_optimizer,
        ) = self.strategy_compiler.generate_optimizer(
            loss,
            self._role_maker,
            self.user_defined_optimizer,
            copy_user_defined_strategy,
            valid_optimizer_list,
            valid_graph_optimizer_list,
        )

        valid_strategy = self.strategy_compiler._get_valid_strategy(
            copy_user_defined_strategy, can_not_apply_optimizer_list
        )

        context["valid_strategy"] = copy.deepcopy(valid_strategy)
        logger.debug("valid_strategy: " + str(context["valid_strategy"]))
        logger.debug(
            "user_defined_strategy: " + str(context["user_defined_strategy"])
        )

        applied_meta_list = self.strategy_compiler._get_applied_meta_list()
        applied_graph_list = self.strategy_compiler._get_applied_graph_list()

        context['applied_meta_list'] = applied_meta_list
        context['applied_graph_list'] = applied_graph_list

        self._context = context

        self.valid_strategy = valid_strategy
        self.valid_strategy._enable_env()

        optimize_ops = []
        params_grads = []

        if self._role_maker._is_non_distributed() and not self._is_collective:
            if self._runtime_handle is None:
                self._runtime_handle = RuntimeFactory()._create_runtime(context)

            compiled_program = compiler.CompiledProgram(
                self.origin_main_program
            ).with_data_parallel(loss_name=loss.name, share_vars_from=None)
            loss.block.program._graph = compiled_program
            return self.user_defined_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set
            )

        if meta_optimizer:
            logger.debug(
                "before minimize program id: " + str(id(loss.block.program))
            )
            optimize_ops, params_grads = meta_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set
            )
            logger.debug(
                "after minimize program id: " + str(id(loss.block.program))
            )
            default_program = paddle.static.default_main_program()
            logger.debug("default program id: " + str(id(default_program)))

            if id(default_program) != id(loss.block.program):
                paddle.fluid.framework.switch_main_program(loss.block.program)
            logger.debug(
                "default program id after switch: " + str(id(default_program))
            )

        else:
            optimize_ops, params_grads = self.user_defined_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set
            )

        context["program_optimize_ops"] = optimize_ops
        context["program_params_grads"] = params_grads

        if graph_optimizer:
            logger.debug(
                "before graph minimize program id: "
                + str(id(loss.block.program))
            )
            optimize_ops, params_grads = graph_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set=no_grad_set
            )
            # since we do not encourage users to use graph operations
            # if a graph optimizer takes effect, mostly
            # optimize_ops and params_grads are None
            # i.e. users can not modify current computation graph anymore
            context["graph_optimize_ops"] = optimize_ops
            context["graph_optimize_grads"] = params_grads
        else:
            apply_ir_passes(loss.block.program, startup_program, self)

        if not self._role_maker._is_heter_parameter_server_mode:
            program = paddle.static.default_main_program()
            opt_info = {} if program._fleet_opt is None else program._fleet_opt
            opt_info["mpi_size"] = self.worker_num()
            opt_info["mpi_rank"] = self.worker_index()
            for (
                k,
                v,
            ) in self._user_defined_strategy.trainer_desc_configs.items():
                if v or k not in opt_info:
                    opt_info[k] = v
            program._fleet_opt = opt_info

        if self._runtime_handle is None:
            self._runtime_handle = RuntimeFactory()._create_runtime(context)

        import paddle.distributed.fleet as fleet

        fleet.util._set_strategy(context["valid_strategy"])

        return optimize_ops, params_grads

    def _minimize_losses_impl(
        self,
        losses,
        startup_programs=None,
        parameter_list=None,
        no_grad_set=None,
    ):
        context = {}

        # cache original feed forward program
        self.origin_main_program = losses[0].block.program
        context["origin_main_program"] = self.origin_main_program
        context["origin_main_programs"] = []
        for loss in losses:
            context["origin_main_programs"].append(loss.block.program)
        context["loss"] = losses

        if startup_programs is None:
            if len(losses) == 1:
                startup_programs = [paddle.static.default_startup_program()]
            else:
                raise ValueError(
                    "startup_program can't be None when loss is list."
                )
        self.origin_startup_program = startup_programs[0].clone(for_test=False)
        context["origin_startup_program"] = startup_programs[0]
        context["origin_startup_programs"] = []
        for program in startup_programs:
            context["origin_startup_programs"].append(program)

        context["role_maker"] = self._role_maker

        context["user_defined_strategy"] = copy.deepcopy(
            self._user_defined_strategy
        )

        context["valid_strategy"] = copy.deepcopy(self._user_defined_strategy)

        self._context = context

        self.valid_strategy = context["valid_strategy"]
        self.valid_strategy._enable_env()

        optimize_ops = []
        params_grads = []

        from .meta_optimizers import ParameterServerOptimizer

        ps_optimizer = ParameterServerOptimizer(self.user_defined_optimizer)
        ps_optimizer._set_basic_info(
            losses,
            self._role_maker,
            self.user_defined_optimizer,
            self._user_defined_strategy,
        )
        optimize_ops, params_grads = ps_optimizer.minimize_losses_impl(
            losses, startup_programs, parameter_list, no_grad_set=no_grad_set
        )

        # default_program = paddle.static.default_main_program()

        # if id(default_program) != id(losses[0].block.program):
        #     paddle.fluid.framework.switch_main_program(losses[0].block.program)

        context["program_optimize_ops"] = optimize_ops
        context["program_params_grads"] = params_grads

        for loss in losses:
            program = loss.block.program
            opt_info = {} if program._fleet_opt is None else program._fleet_opt
            opt_info["mpi_size"] = self.worker_num()
            opt_info["mpi_rank"] = self.worker_index()
            for (
                k,
                v,
            ) in self._user_defined_strategy.trainer_desc_configs.items():
                if v or k not in opt_info:
                    opt_info[k] = v
            program._fleet_opt = opt_info
            logger.debug(
                "fleet base opt info: "
                + str(id(program))
                + str(program._fleet_opt)
            )

        if self._runtime_handle is None:
            self._runtime_handle = RuntimeFactory()._create_runtime(context)

        import paddle.distributed.fleet as fleet

        fleet.util._set_strategy(context["valid_strategy"])

        return optimize_ops, params_grads