fleet.py 54.4 KB
Newer Older
W
wuhuachaocoding 已提交
1
#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import copy
16
import os
17 18

import paddle
19
from paddle.fluid import compiler
20
from paddle.fluid.framework import in_dygraph_mode
21 22
from paddle.fluid.wrapped_decorator import wrap_decorator
from paddle.framework import _global_flags
23
from paddle.framework.ir import apply_build_strategy
24

W
wuhuachaocoding 已提交
25
from .base import topology as tp
26 27 28 29 30
from .base.distributed_strategy import DistributedStrategy
from .base.meta_optimizer_factory import MetaOptimizerFactory
from .base.role_maker import PaddleCloudRoleMaker, RoleMakerBase
from .base.runtime_factory import RuntimeFactory
from .base.strategy_compiler import StrategyCompiler
W
wuhuachaocoding 已提交
31
from .meta_parallel import model_parallel_random_seed
R
Roc 已提交
32
from .utils.log_util import logger, set_log_level
33

34 35
# Intentionally empty: a star-import of this module exports no names.
__all__ = []

36

37 38 39 40 41 42 43 44 45 46 47 48 49 50
def apply_ir_passes(main_program, startup_program, config):
    """Apply the user's build-strategy IR passes to a program pair.

    Args:
        main_program: Program to transform. When it carries a non-empty
            ``_pipeline_opt`` dict, that dict's ``"section_program"`` is used
            instead, together with
            ``startup_program._pipeline_opt["startup_program"]``.
        startup_program: The startup program matching ``main_program``.
        config: Object exposing ``_user_defined_strategy`` and
            ``_is_collective`` (presumably the Fleet instance - verify
            against callers).

    Returns:
        A copy of the user's build strategy when the global flag
        ``FLAGS_apply_pass_to_program`` is off; otherwise the result of
        ``apply_build_strategy``.
    """
    user_strategy = config._user_defined_strategy
    strategy_copy = user_strategy.build_strategy._copy()

    # Passes are only applied when the global flag asks for it.
    flags = _global_flags()
    if not flags['FLAGS_apply_pass_to_program']:
        return strategy_copy

    # Pipeline-parallel programs keep the real programs in `_pipeline_opt`.
    pipe_cfg = getattr(main_program, "_pipeline_opt", {})
    if pipe_cfg:
        main_program = pipe_cfg["section_program"]
        startup_program = startup_program._pipeline_opt["startup_program"]

    wants_fused_allreduce = user_strategy.fuse_all_reduce_ops
    if wants_fused_allreduce and strategy_copy.fuse_all_optimizer_ops:
        # FIXME(zjl): currently, fuse_all_optimizer_ops
        # have conflict with fuse_all_reduce_ops because
        # RawProgramOptimizer also inserts coalesce_tensor
        # into program. These two procedures may conflict
        # in which vars are to be fused.
        logger.warning(
            'Currently, the fuse_all_optimizer_ops pass has conflict with fuse_all_reduce_ops pass. Disable the fuse_all_optimizer_ops pass temporarily.'
        )
        strategy_copy.fuse_all_optimizer_ops = False

    return apply_build_strategy(
        main_program,
        startup_program,
        strategy_copy,
        {"use_cuda": config._is_collective},
    )
63 64


65 66 67 68 69 70 71 72 73 74 75 76
def _inited_runtime_handler_(func):
    def __impl__(*args, **kwargs):
        cls = args[0]

        if cls._runtime_handle is None:
            raise ValueError("Fleet can not find suitable runtime handler")

        return func(*args, **kwargs)

    return __impl__


77 78 79 80
def _is_non_distributed_check_(func):
    def __impl__(*args, **kwargs):
        cls = args[0]

81 82 83 84
        if (
            cls._role_maker is not None
            and cls._role_maker._is_non_distributed() is True
        ):
R
Roc 已提交
85
            logger.warning(
86 87 88
                "%s() function doesn't work when use non_distributed fleet."
                % (func.__name__)
            )
89 90 91 92 93 94 95
            return

        return func(*args, **kwargs)

    return __impl__


96
# Public decorator forms of the two raw wrappers above, produced by
# `wrap_decorator` and applied to Fleet methods further down.
inited_runtime_handler, is_non_distributed_check = (
    wrap_decorator(_inited_runtime_handler_),
    wrap_decorator(_is_non_distributed_check_),
)
98 99


100
class Fleet:
101 102
    """
    Unified API for distributed training of PaddlePaddle
103
    Please reference the https://github.com/PaddlePaddle/PaddleFleetX for details
104 105 106 107 108


    Returns:
        Fleet: A Fleet instance

109
    Example for collective training:
1
123malin 已提交
110

111 112
        .. code-block:: python

1
123malin 已提交
113 114
            import paddle
            paddle.enable_static()
115
            import paddle.distributed.fleet as fleet
116 117 118

            fleet.init(is_collective=True)

119 120 121
            strategy = fleet.DistributedStrategy()
            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
            optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
122 123 124 125 126 127 128 129

            # do distributed training


    Example for parameter server training:

        .. code-block:: python

1
123malin 已提交
130 131
            import paddle
            paddle.enable_static()
132 133
            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
S
ShenLiang 已提交
134
            fleet.init(strategy=strategy)
135

136
            optimizer = paddle.optimizer.SGD(learning_rate=0.001)
137
            optimizer = fleet.distributed_optimizer(optimizer)
138

139 140
            if fleet.is_first_worker():
                print("this is first worker")
141

142 143
            print("current node index: {}".format(fleet.worker_index()))
            print("total number of worker num: {}".format(fleet.worker_num()))
144

145 146 147
            if fleet.is_worker():
                print("this is worker")
            print("worker endpoints: {}".format(fleet.worker_endpoints(to_string=True)))
148

149 150
            print("server num: {}".format(fleet.server_num()))
            print("server endpoints: {}".format(fleet.server_endpoints(to_string=True)))
151

152 153 154
            if fleet.is_server():
                print("this is server")
            fleet.stop_worker()
155 156


157 158 159
    """

    def __init__(self):
        """Create an uninitialized Fleet; ``init`` fills in the role maker,
        strategy and communication state afterwards."""
        # Role / runtime helpers; all unset until later calls populate them.
        self._role_maker = None
        self._runtime_handle = None
        self._util = None
        self.strategy_compiler = None
        # ParameterServer mode unless `init` selects collective mode.
        self._is_collective = False
        self._context = dict()
        # Placeholder optimizer constructed with learning rate 0.0.
        self.user_defined_optimizer = paddle.optimizer.Optimizer(0.0)
167

168 169 170 171 172 173 174
    def init(
        self,
        role_maker=None,
        is_collective=False,
        strategy=None,
        log_level="INFO",
    ):
        """
        Initialize role_maker in Fleet.

        This function sets up the distributed environment your program runs
        in: the role maker, the strategy compiler and, depending on the mode,
        the parallel environment and communication groups.

        Args:
            role_maker (RoleMakerBase, optional): A ``RoleMakerBase`` containing the configuration
                of environment variables related to distributed training. If you did not initialize
                the rolemaker by yourself, it will be automatically initialized to PaddleCloudRoleMaker.
                The default value is None.
            is_collective (Boolean, optional): A ``Boolean`` variable determines whether the program
                runs on Collective mode or ParameterServer mode. True means the program runs on
                Collective mode, and False means running on ParameterServer mode. It is only used
                when ``role_maker`` is None; otherwise the mode is taken from ``role_maker``.
                The default value is False.
            strategy (DistributedStrategy): Extra properties for distributed training.
                For details, please refer to paddle.distributed.fleet.DistributedStrategy.
                A deep copy is stored, so later changes to the passed object have no effect.
                Default: None (a default ``DistributedStrategy`` is used).
            log_level (Integer, String, optional): A ``Integer`` or ``String`` Variable determining how high
                the logging level is. Default is "INFO".

        Returns:
            Fleet: this Fleet instance (``self``).

        Raises:
            ValueError: If ``role_maker`` is None and ``is_collective`` is not a ``bool``;
                if ``role_maker`` is not a ``RoleMakerBase``; or if the program runs
                non-distributed in collective mode on a CUDA build that sees a GPU
                count other than 1.

        Examples1:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

        Examples2:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init(is_collective=True)

        Examples3:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                role = fleet.PaddleCloudRoleMaker()
                fleet.init(role)

        Examples4:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                strategy = fleet.DistributedStrategy()
                fleet.init(strategy=strategy)

        Examples5:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init(log_level = "DEBUG")

        """
        set_log_level(log_level)

        if strategy is None:
            strategy = DistributedStrategy()
        # Keep a private copy so later edits to the caller's object are not seen.
        self._user_defined_strategy = copy.deepcopy(strategy)

        if role_maker is None:
            if not isinstance(is_collective, bool):
                raise ValueError(
                    "`is_collective` should be instance of `bool`, but got {}".format(
                        type(is_collective)
                    )
                )
            self._is_collective = is_collective
            self._role_maker = PaddleCloudRoleMaker(
                is_collective=self._is_collective
            )
        else:
            if not isinstance(role_maker, RoleMakerBase):
                raise ValueError(
                    "`role_maker` should be subclass of `RoleMakerBase`, but got {}".format(
                        type(role_maker)
                    )
                )
            self._role_maker = role_maker
            self._is_collective = role_maker._is_collective
        self._role_maker._generate_role()

        import paddle.distributed.fleet as fleet

        fleet.util._set_role_maker(self._role_maker)

        self.strategy_compiler = StrategyCompiler()

        # A plain `python` launch of a collective program must see exactly one GPU.
        if self._role_maker._is_non_distributed() and self._is_collective:
            if paddle.framework.core.is_compiled_with_cuda():
                gpus_num = paddle.framework.core.get_cuda_device_count()
                if gpus_num != 1:
                    raise ValueError(
                        "CUDA_VISIBLE_DEVICES shoule be set only 1 card if you use `python` to launch fleet program."
                    )

        if in_dygraph_mode():
            self._init_dygraph_parallel_env()
        elif self._is_collective:
            self._init_static_collective_groups()
        # Every successful path returns the instance, so callers get a
        # consistent value regardless of mode or world size.
        return self

    def _init_dygraph_parallel_env(self):
        """Set up the dygraph parallel environment for ``init``.

        With a single worker only a default topology and hybrid communicate
        group are created. Otherwise the parallel context is initialized
        (unless it already is) and, when ``heter_ccl_mode`` is off, the hybrid
        parallel environment is initialized too (unless it already is).
        """
        from paddle.distributed import parallel_helper

        if self.worker_num() == 1:
            # if worker_num is 1, should construct default topology & hcg
            self._topology = tp.CommunicateTopology()
            self._hcg = tp.HybridCommunicateGroup(self._topology)
            return

        if parallel_helper._is_parallel_ctx_initialized():
            logger.warning(
                "The dygraph parallel environment has been initialized."
            )
        else:
            # FLAGS_nccl_nrings is used for dynamic graph multi-stream communication
            if "FLAGS_nccl_nrings" in os.environ:
                logger.warning(
                    "You have set the environment variable FLAGS_nccl_nrings "
                    "outside the program, so the nccl_comm_num in "
                    "DistributedStrategy will not take effect here."
                )
            else:
                os.environ["FLAGS_nccl_nrings"] = str(
                    self._user_defined_strategy.nccl_comm_num
                )
            paddle.distributed.init_parallel_env()

        # hybrid parallel not support for npu/xpu
        if self._user_defined_strategy.heter_ccl_mode:
            return
        # init hybrid parallel environment in dygraph
        if tp._HYBRID_PARALLEL_GROUP is None:
            self._init_hybrid_parallel_env()
        else:
            logger.warning(
                "The dygraph hybrid parallel environment has been initialized."
            )

    def _init_static_collective_groups(self):
        """Register the static-graph collective communication groups.

        Always registers the 'global' group. When sharding or tensor parallel
        is enabled with a model-parallel degree above 1, also registers the
        'model' group containing this rank's consecutive block of ranks.
        """
        use_sharding = self._user_defined_strategy.sharding

        # global group
        global_rank = self.worker_index()
        global_world_size = self.worker_num()
        # NOTE(wangxi): see sharding_optimizer
        global_ring_id = 3 if use_sharding else 0
        global_ranks = list(range(global_world_size))

        if tp._HYBRID_PARALLEL_GROUP is None:
            tp._CommunicateGroup()
        cg = tp._HYBRID_PARALLEL_GROUP
        self._hcg = cg
        cg.set_comm_group(
            'global',
            global_rank,
            global_world_size,
            global_ring_id,
            global_ranks,
        )

        use_tensor_parallel = self._user_defined_strategy.tensor_parallel
        use_mp = use_sharding or use_tensor_parallel

        # hybrid group
        if not use_mp:
            return

        mp_degree_sharding = 1
        mp_degree_tensor_parallel = 1
        if use_sharding:
            sharding_configs = self._user_defined_strategy.sharding_configs
            mp_degree_sharding = int(sharding_configs['mp_degree'])

        if use_tensor_parallel:
            tensor_parallel_configs = (
                self._user_defined_strategy.tensor_parallel_configs
            )
            mp_degree_tensor_parallel = int(
                tensor_parallel_configs['tensor_parallel_degree']
            )

        if use_sharding and use_tensor_parallel:
            assert mp_degree_sharding == mp_degree_tensor_parallel, (
                "sharding_configs['mp_degree'] ({}) must equal "
                "tensor_parallel_configs['tensor_parallel_degree'] ({}) "
                "when both sharding and tensor_parallel are enabled".format(
                    mp_degree_sharding, mp_degree_tensor_parallel
                )
            )

        mp_degree = (
            mp_degree_sharding if use_sharding else mp_degree_tensor_parallel
        )
        if mp_degree <= 1:
            return

        assert global_world_size % mp_degree == 0, (
            "world size ({}) must be divisible by the model parallel "
            "degree ({})".format(global_world_size, mp_degree)
        )
        # NOTE(wangxi): mp_ring_id sync with sharding_optimizer.py _build_groups
        mp_ring_id = 0
        mp_rank = global_rank % mp_degree
        mp_group_id = global_rank // mp_degree
        mp_group_ranks = [
            idx for idx in global_ranks if idx // mp_degree == mp_group_id
        ]
        cg.set_comm_group(
            'model', mp_rank, mp_degree, mp_ring_id, mp_group_ranks
        )
384 385

    def _init_hybrid_parallel_env(self):
        """Initialize the dygraph hybrid-parallel topology and communicate group.

        Reads ``dp_degree``, ``mp_degree``, ``pp_degree`` and
        ``sharding_degree`` from ``self._user_defined_strategy.hybrid_configs``
        and stores them on ``self``. mp/pp/sharding degrees of 0 are treated
        as 1. A negative ``dp_degree`` is inferred as the world size divided
        by the product of the other three degrees; the result is at least 1.

        When ``mp_degree > 1`` the model-parallel random seed is also set from
        ``tensor_parallel_configs["tensor_init_seed"]``. A value of -1 calls
        ``model_parallel_random_seed()`` without an explicit seed.

        Raises:
            AssertionError: if mp, pp or sharding degree is negative.
        """
        self.hybrid_configs = self._user_defined_strategy.hybrid_configs
        self.dp_degree = self.hybrid_configs["dp_degree"]
        self.mp_degree = self.hybrid_configs["mp_degree"]
        self.pp_degree = self.hybrid_configs["pp_degree"]
        self.sharding_degree = self.hybrid_configs["sharding_degree"]

        assert self.mp_degree >= 0, "mp_degree should be greater or equal to 0"
        assert self.pp_degree >= 0, "pp_degree should be greater or equal to 0"
        assert (
            self.sharding_degree >= 0
        ), "sharding_degree should be greater or equal to 0"

        # A degree of 0 means "unused". Clamp every non-dp dim to 1 so that none
        # of them zeroes out the topology's world size.
        self.mp_degree = max(self.mp_degree, 1)
        self.pp_degree = max(self.pp_degree, 1)
        self.sharding_degree = max(self.sharding_degree, 1)

        if self.dp_degree < 0:
            # Infer dp from the ranks left over by every other parallel dim
            # (sharding included), so dp * pp * sharding * mp == world size.
            nranks = paddle.distributed.get_world_size()
            self.dp_degree = nranks // (
                self.mp_degree * self.pp_degree * self.sharding_degree
            )

        self.dp_degree = max(self.dp_degree, 1)

        self._topology = tp.CommunicateTopology(
            hybrid_group_names=["data", "pipe", "sharding", "model"],
            dims=[
                self.dp_degree,
                self.pp_degree,
                self.sharding_degree,
                self.mp_degree,
            ],
        )

        self._hcg = tp.HybridCommunicateGroup(self._topology)

        if self.mp_degree > 1:
            tensor_parallel_configs = (
                self._user_defined_strategy.tensor_parallel_configs
            )
            tensor_init_seed = tensor_parallel_configs["tensor_init_seed"]
            if tensor_init_seed == -1:
                model_parallel_random_seed()
            else:
                model_parallel_random_seed(tensor_init_seed)

430 431 432 433 434 435 436 437
    def get_hybrid_communicate_group(self):
        """Return the hybrid communicate group stored on this Fleet.

        ``_hcg`` is not created by ``__init__``. It is set by ``init`` in
        dygraph or collective mode, or by ``_init_hybrid_parallel_env``.

        Returns:
            The communicate group stored in ``self._hcg``.

        Raises:
            AssertionError: if no group has been created yet.
        """
        # getattr: `_hcg` does not exist at all before `init` has created it.
        hcg = getattr(self, "_hcg", None)
        assert hcg is not None, (
            "The hybrid communicate group has not been created, "
            "please call fleet.init first."
        )
        return hcg

    def get_hybrid_parallel_topology(self):
        """Return the hybrid parallel topology stored on this Fleet.

        ``_topology`` is not created by ``__init__``. It is set by ``init``
        (single-worker dygraph) or by ``_init_hybrid_parallel_env``.

        Returns:
            The topology stored in ``self._topology``.

        Raises:
            AssertionError: if no topology has been created yet.
        """
        # getattr: `_topology` does not exist at all before it is created.
        topology = getattr(self, "_topology", None)
        assert topology is not None, (
            "The hybrid parallel topology has not been created, "
            "please call fleet.init first."
        )
        return topology

438 439 440 441 442 443 444
    def is_first_worker(self):
        """
        Tell whether this process is the first worker instance.

        Returns:
            bool: True for the first worker node, False otherwise.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_first_worker()

        """
        role_maker = self._role_maker
        return role_maker._is_first_worker()
456 457 458 459 460 461 462

    def worker_index(self):
        """
        Return the index of the current worker.

        Returns:
            int: id of this worker node.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_index()

        """
        role_maker = self._role_maker
        return role_maker._worker_index()
474 475 476 477 478 479 480

    def worker_num(self):
        """
        Return how many workers take part in training.

        Returns:
            int: total number of workers.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_num()

        """
        role_maker = self._role_maker
        return role_maker._worker_num()
492

493 494 495 496 497 498 499 500 501 502 503 504
    def node_num(self):
        """Return the node count reported by the role maker's ``_get_node_num``."""
        role_maker = self._role_maker
        return role_maker._get_node_num()

    def local_rank(self):
        """Return the local rank reported by the role maker's ``_get_local_rank``."""
        role_maker = self._role_maker
        return role_maker._get_local_rank()

    def local_device_ids(self):
        """Return the local device ids reported by the role maker's
        ``_get_local_device_ids``."""
        role_maker = self._role_maker
        return role_maker._get_local_device_ids()

    def world_device_ids(self):
        """Return the world device ids reported by the role maker's
        ``_get_world_device_ids``."""
        role_maker = self._role_maker
        return role_maker._get_world_device_ids()

505 506 507 508 509 510 511
    def is_worker(self):
        """
        Tell whether this process acts as a worker.

        Returns:
            bool: True for a worker node, False otherwise.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_worker()

        """
        role_maker = self._role_maker
        return role_maker._is_worker()
523

524 525 526
    def is_coordinator(self):
        """Tell whether this process is a coordinator, as reported by the
        role maker's ``_is_coordinator``."""
        role_maker = self._role_maker
        return role_maker._is_coordinator()

527 528
    def worker_endpoints(self, to_string=False):
        """
        Get current worker endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].

        Args:
            to_string (bool, optional): If True, return the endpoints joined
                with "," into a single string. Default: False.

        Returns:
            list/string: worker endpoints

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.worker_endpoints()

        """
        endpoints = self._role_maker._get_trainer_endpoints()
        return ",".join(endpoints) if to_string else endpoints
547 548 549 550 551 552 553

    def server_num(self):
        """
        Get current total server number.

        The count is the number of parameter-server endpoints known to the
        role maker.

        Returns:
            int: server number

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_num()
        """
        return len(self._role_maker._get_pserver_endpoints())
564 565 566 567 568 569 570

    def server_index(self):
        """
        Return the index of the current server.

        Returns:
            int: id of this server node.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_index()

        """
        role_maker = self._role_maker
        return role_maker._server_index()
582 583 584 585 586 587 588

    def server_endpoints(self, to_string=False):
        """
        Return the server endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].

        Args:
            to_string (bool, optional): If True, the endpoints are joined with
                "," into one string. Default: False.

        Returns:
            list/string: server endpoints

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.server_endpoints()

        """
        if not to_string:
            return self._role_maker._get_pserver_endpoints()
        return ",".join(self._role_maker._get_pserver_endpoints())
604 605 606 607 608 609 610 611

    def is_server(self):
        """
        Tell whether this process acts as a server.

        Returns:
            bool: True for a server node, False otherwise.

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                fleet.is_server()

        """
        role_maker = self._role_maker
        return role_maker._is_server()

624 625
    def barrier_worker(self):
        """
        Block until all workers reach this barrier.

        Returns:
            None
        """
        role_maker = self._role_maker
        role_maker._barrier("worker")
632

633
    @is_non_distributed_check
    @inited_runtime_handler
    def init_worker(self, scopes=None):
        """
        Set up the `Communicator` for parameter server training.

        Args:
            scopes (optional): Forwarded unchanged to the runtime handle's
                ``_init_worker``. Default: None.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_worker()

        """
        runtime = self._runtime_handle
        runtime._init_worker(scopes)
657

658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676
    @is_non_distributed_check
    @inited_runtime_handler
    def init_coordinator(self, scopes=None):
        """
        Initialize the coordinator node.

        Args:
            scopes (optional): Forwarded unchanged to the runtime handle's
                ``_init_coordinator``. Default: None.
        """
        runtime = self._runtime_handle
        runtime._init_coordinator(scopes)

    def make_fl_strategy(self):
        """Delegate to the runtime handle's ``_make_fl_strategy``.

        Unlike most runtime-backed methods here, this one is not guarded by
        the runtime-handler / non-distributed decorators.
        """
        runtime = self._runtime_handle
        runtime._make_fl_strategy()

    @is_non_distributed_check
    @inited_runtime_handler
    def get_fl_client(self):
        """
        Return the worker (training node) object held by the runtime handle
        as ``_worker``.
        """
        runtime = self._runtime_handle
        return runtime._worker

677
    @is_non_distributed_check
    @inited_runtime_handler
    def init_server(self, *args, **kwargs):
        """
        Run the startup program on the server through the runtime handle.

        All positional and keyword arguments are forwarded to the runtime
        handle's ``_init_server``. Per the original description, non-empty
        ``args`` make it run load_persistables for incremental training.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.init_server()

        """
        runtime = self._runtime_handle
        runtime._init_server(*args, **kwargs)
702

Z
zmxdream 已提交
703 704
    @is_non_distributed_check
    @inited_runtime_handler
    def load_model(self, path, mode):
        """
        Load a fleet model from ``path``.

        Args:
            path (str): Location to load from.
            mode: Forwarded to the runtime handle's ``_load_persistables``;
                its meaning is defined there.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.load_model("path", mode=0)

        """
        runtime = self._runtime_handle
        runtime._load_persistables(path, mode)

    @is_non_distributed_check
    @inited_runtime_handler
    def load_one_table(self, table_id, path, mode):
        """
        Load a single fleet table from ``path``.

        Args:
            table_id: Id of the table to load.
            path (str): Location to load from.
            mode: Forwarded to the runtime handle's ``_load_one_table``;
                its meaning is defined there.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.load_one_table(0, "path", mode=0)

        """
        runtime = self._runtime_handle
        runtime._load_one_table(table_id, path, mode)

    @is_non_distributed_check
    @inited_runtime_handler
    def load_inference_model(self, path, mode):
        """
        Load a fleet inference model from ``path``.

        Args:
            path (str): Location to load from.
            mode: Forwarded to the runtime handle's ``_load_inference_model``;
                its meaning is defined there.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.load_inference_model("path", mode=1)

        """
        runtime = self._runtime_handle
        runtime._load_inference_model(path, mode)
T
Thunderbrook 已提交
777

778
    @is_non_distributed_check
    @inited_runtime_handler
    def run_server(self):
        """
        run server will run pserver main program with executor.

        The work is delegated to the runtime handle's ``_run_server``.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                if fleet.is_server():
                    fleet.init_server()
                    fleet.run_server()

        """
        self._runtime_handle._run_server()

803
    @is_non_distributed_check
    @inited_runtime_handler
    def stop_worker(self):
        """
        stop `Communicator` and give training complete notice to parameter server.

        The work is delegated to the runtime handle's ``_stop_worker``.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                if fleet.is_worker():
                    fleet.init_worker()
                    # train ...
                    fleet.stop_worker()

        """
        self._runtime_handle._stop_worker()

Z
zmxdream 已提交
827 828
    @is_non_distributed_check
    @inited_runtime_handler
    def save(self, dirname, feed=None, fetch=None, **configs):
        """
        Save the model through the runtime handle.

        If ``feed`` or ``fetch`` is non-empty an inference model is saved
        (``_save_inference_model``); otherwise all persistables are saved
        (``_save_persistables``). Both paths use an executor on
        ``paddle.CPUPlace()``.

        Args:
            dirname (str): Target directory.
            feed (list[str|Variable], optional): Feed variables or their names.
                Default: None (treated as empty).
            fetch (list[str|Variable], optional): Fetch variables or their
                names; they are resolved by name in the default main
                program's global block. Default: None (treated as empty).
            **configs: Only ``mode`` is read, in the persistables branch.
                It is converted with ``int`` and defaults to 0.

        Returns:
            None

        Raises:
            ValueError: if an element of ``feed`` or ``fetch`` is neither a
                ``str`` nor a ``paddle.static.Variable``.
        """
        # None defaults avoid shared mutable default lists.
        feed = [] if feed is None else feed
        fetch = [] if fetch is None else fetch

        place = paddle.CPUPlace()
        executor = paddle.static.Executor(place)

        if not feed and not fetch:
            increment_mode = int(configs.get("mode", 0))
            self._runtime_handle._save_persistables(
                executor, dirname, main_program=None, mode=increment_mode
            )
            return

        def _to_names(variables, kind):
            # Normalize a list of names/Variables into a list of names.
            names = []
            for var in variables:
                if isinstance(var, str):
                    names.append(var)
                elif isinstance(var, paddle.static.Variable):
                    names.append(var.name)
                else:
                    raise ValueError("{} must be [str|Variable]".format(kind))
            return names

        feeded_var_names = _to_names(feed, "feed")
        fetch_var_names = _to_names(fetch, "fetch")

        global_block = paddle.static.default_main_program().global_block()
        fetch_vars = [global_block.var(name) for name in fetch_var_names]

        self._runtime_handle._save_inference_model(
            executor, dirname, feeded_var_names, fetch_vars, None, True, 0
        )
T
tangwei12 已提交
873

Z
zmxdream 已提交
874 875
    @is_non_distributed_check
    @inited_runtime_handler
    def save_inference_model(
        self,
        executor,
        dirname,
        feeded_var_names,
        target_vars,
        main_program=None,
        export_for_deployment=True,
        mode=0,
    ):
        """
        save inference model for inference.

        All arguments are forwarded unchanged, in order, to the runtime
        handle's ``_save_inference_model``.

        Args:
            executor(Executor): The executor used for saving.
            dirname(str): The saving directory path.
            feeded_var_names(list[str]): Names of the feed variables.
            target_vars(list[Variable]): The fetch (target) variables.
            main_program(Program, optional): Program to save from. Default: None.
            export_for_deployment(bool, optional): Default: True.
            mode(int, optional): Interpreted by the runtime handle. Default: 0.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                paddle.enable_static()
                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                exe = paddle.static.Executor(paddle.CPUPlace())
                # feeded_var_names / target_vars come from the net built above
                fleet.save_inference_model(exe, "dirname", feeded_var_names, target_vars)

        """
        self._runtime_handle._save_inference_model(
            executor,
            dirname,
            feeded_var_names,
            target_vars,
            main_program,
            export_for_deployment,
            mode,
        )
915

Z
zmxdream 已提交
916 917
    @is_non_distributed_check
    @inited_runtime_handler
    def save_persistables(self, executor, dirname, main_program=None, mode=0):
        """
        Save all persistable tensors from :code:`main_program` to the folder
        :code:`dirname`.

        The :code:`dirname` is used to specify the folder where persistable
        tensors are going to be saved. The call is forwarded unchanged to the
        runtime handle's ``_save_persistables``.

        Args:
            executor(Executor): The executor to run for saving persistable tensors.
                                You can refer to :ref:`api_guide_executor_en` for
                                more details.

            dirname(str, optional): The saving directory path.
                                When you need to save the parameter to the memory, set it to None.
            main_program(Program, optional): The program whose persistable tensors will
                                             be saved. Default: None.
            mode(int, optional): Passed through to the runtime handle, which
                                 defines its meaning. Default: 0.

        Returns:
            None

        Examples:

            .. code-block:: text

                import paddle
                paddle.enable_static()
                import paddle.distributed.fleet as fleet

                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                exe = paddle.static.Executor(paddle.CPUPlace())
                fleet.save_persistables(exe, "dirname", paddle.static.default_main_program())

        """
        self._runtime_handle._save_persistables(
            executor, dirname, main_program, mode
        )
962

Z
zhaocaibei123 已提交
963 964 965 966 967
    @is_non_distributed_check
    @inited_runtime_handler
    def save_cache_model(self, dirname, **configs):
        """
        Save the cache model under ``dirname`` through the runtime handle.

        Args:
            dirname(str): Target path forwarded to the runtime.
            **configs: Extra keyword options forwarded unchanged.

        Returns:
            Whatever the runtime handle's ``_save_cache_model`` returns.
        """
        runtime = self._runtime_handle
        return runtime._save_cache_model(dirname, **configs)

    @is_non_distributed_check
    @inited_runtime_handler
    def check_save_pre_patch_done(self):
        """
        Query the runtime handle's ``_check_save_pre_patch_done``.

        Returns:
            The runtime handle's result, unchanged.
        """
        runtime = self._runtime_handle
        return runtime._check_save_pre_patch_done()

    @is_non_distributed_check
    @inited_runtime_handler
    def save_cache_table(
        self, table_id, pass_id, mem_cache_key_threshold=4000000000
    ):
        """
        Save a cache table through the runtime handle.

        Args:
            table_id(int): Id of the table, forwarded to the runtime.
            pass_id(int): Pass id, forwarded to the runtime.
            mem_cache_key_threshold(int, optional): Threshold forwarded to the
                runtime. Default: 4000000000.

        Returns:
            Whatever the runtime handle's ``_save_cache_table`` returns.
        """
        runtime = self._runtime_handle
        result = runtime._save_cache_table(
            table_id, pass_id, mem_cache_key_threshold
        )
        return result

    @is_non_distributed_check
    @inited_runtime_handler
    def save_one_table(self, table_id, path, mode):
        """
        Save a single fleet table to ``path`` through the runtime handle.

        Args:
            table_id(int): Id of the table to save.
            path(str): Destination path for the table.
            mode(int): Save mode forwarded unchanged to the runtime handle's
                ``_save_one_table``.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.save_one_table(0, "path", mode=0)

        """
        self._runtime_handle._save_one_table(table_id, path, mode)

    @is_non_distributed_check
    @inited_runtime_handler
    def save_dense_params(
        self, executor, dirname, scope, program, var_names=None
    ):
        """
        Save dense parameters by delegating to the runtime handle.

        All arguments are forwarded unchanged to the runtime handle's
        ``_save_dense_params``.

        Args:
            executor(Executor): The executor used for saving.
            dirname(str): The saving directory path.
            scope(Scope): The scope forwarded to the runtime (the example
                passes ``paddle.static.global_scope()``).
            program(Program): The program forwarded to the runtime (the
                example passes ``paddle.static.default_main_program()``).
            var_names(list, optional): Variable names forwarded to the
                runtime. Default: None.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()
                import paddle
                place = paddle.CPUPlace()
                exe =  paddle.static.Executor(place)

                # build net
                # fleet.distributed_optimizer(...)

                fleet.save_dense_params(exe, "path", scope=paddle.static.global_scope(), program=paddle.static.default_main_program())

        """
        self._runtime_handle._save_dense_params(
            executor, dirname, scope, program, var_names
        )

    @is_non_distributed_check
    @inited_runtime_handler
    def shrink(self, threshold=None):
        """
        Ask the runtime handle to shrink, passing ``threshold`` through.

        Args:
            threshold(optional): Forwarded unchanged to ``_shrink``.
                Default: None.

        Returns:
            None
        """
        runtime = self._runtime_handle
        runtime._shrink(threshold)

    def distributed_optimizer(self, optimizer, strategy=None):
        """
        Optimizer for distributed training.

        For distributed training, this method records ``optimizer`` as the
        user-defined optimizer and resets the internal context. The returned
        fleet object provides the basic Optimizer functions plus the special
        features for distributed training.

        Args:
            optimizer(Optimizer): The optimizer to be used for distributed training.
            strategy(DistributedStrategy): Extra properties for distributed optimizer.
                It is recommended to use DistributedStrategy in fleet.init(). The strategy
                here is for compatibility. If the strategy in fleet.distributed_optimizer()
                is not None, then it will overwrite the DistributedStrategy in fleet.init(),
                which will take effect in distributed training. A deep copy of it is kept.

        Returns:
            Fleet: instance of fleet.

        Examples:

            .. code-block:: python

                import paddle
                import paddle.distributed.fleet as fleet
                fleet.init(is_collective=True)
                strategy = fleet.DistributedStrategy()
                optimizer = paddle.optimizer.SGD(learning_rate=0.001)
                optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)

        """
        self.user_defined_optimizer = optimizer

        if strategy is not None:
            if self._is_collective:
                logger.warning(
                    "It is recommended to use DistributedStrategy "
                    "in fleet.init(). The strategy here is only for compatibility. "
                    "If the strategy in fleet.distributed_optimizer() is "
                    "not None, then it will overwrite the DistributedStrategy in fleet.init(), "
                    "which will take effect in distributed training."
                )
            self._user_defined_strategy = copy.deepcopy(strategy)

        # minimize() repopulates the context; drop any stale state.
        self._context = {}

        return self

    def _get_amp_optimizer(self):
        # imitate target optimizer retrieval
        amp_optimizer = None
        for optimizer in self.strategy_compiler._get_applied_meta_optimizer():
            if hasattr(optimizer, 'amp_init'):
                amp_optimizer = optimizer
                break

        if amp_optimizer is None:
            if hasattr(self.user_defined_optimizer, 'amp_init'):
                amp_optimizer = self.user_defined_optimizer

1103 1104 1105
        assert (
            amp_optimizer is not None
        ), "amp_init can only be used when the amp(auto mixed precision) strategy is turned on."
1106 1107 1108
        return amp_optimizer

    def get_loss_scaling(self):
        """Return the real-time loss scaling factor reported by the AMP optimizer."""
        return self._get_amp_optimizer().get_loss_scaling()

    def amp_init(
        self, place, scope=None, test_program=None, use_fp16_test=False
    ):
        """
        Initialize AMP training, e.g. cast fp32 parameters to the fp16 type.

        The call is forwarded to the AMP-capable optimizer located by
        ``_get_amp_optimizer``.

        Args:
            place(CUDAPlace): place is used to initialize
                fp16 parameters with fp32 values.
            scope(Scope): The scope is used to find fp32 parameters.
            test_program(Program): The program is used for testing.
            use_fp16_test(bool): Whether to use fp16 testing.

        Returns:
            Whatever the underlying AMP optimizer's ``amp_init`` returns.

        Raises:
            AssertionError: if the amp (auto mixed precision) strategy is not
                turned on.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.nn.functional as F
                paddle.enable_static()

                def run_example_code():
                    place = paddle.CUDAPlace(0)
                    exe = paddle.static.Executor(place)
                    data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
                    conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
                    # 1) Use fp16_guard to control the range of fp16 kernels used.
                    with paddle.static.amp.fp16_guard():
                        bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
                        pool = F.max_pool2d(bn, kernel_size=2, stride=2)
                        hidden = paddle.static.nn.fc(pool, size=10)
                        loss = paddle.mean(hidden)
                    # 2) Create the optimizer and set `multi_precision` to True.
                    # Setting `multi_precision` to True can avoid the poor accuracy
                    # or the slow convergence in a way.
                    optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
                    # 3) These ops in `custom_black_list` will keep in the float32 computation type.
                    amp_list = paddle.static.amp.CustomOpLists(
                        custom_black_list=['pool2d'])
                    # 4) The entry of Paddle AMP.
                    # Enable pure fp16 training by setting `use_pure_fp16` to True.
                    optimizer = paddle.static.amp.decorate(
                        optimizer,
                        amp_list,
                        init_loss_scaling=128.0,
                        use_dynamic_loss_scaling=True,
                        use_pure_fp16=True)
                    # If you don't use the default_startup_program(), you should pass
                    # your defined `startup_program` into `minimize`.
                    optimizer.minimize(loss)
                    exe.run(paddle.static.default_startup_program())
                    # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
                    # If you want to perform the testing process, you should pass `test_program` into `amp_init`.
                    optimizer.amp_init(place, scope=paddle.static.global_scope())

                if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
                    run_example_code()
        """
        amp_optimizer = self._get_amp_optimizer()
        return amp_optimizer.amp_init(place, scope, test_program, use_fp16_test)

    def _final_strategy(self):
        if "valid_strategy" not in self._context:
            print(
                "WARNING: You may need to call minimize function before this function is called"
            )
            return {}
        else:
            return self._context["valid_strategy"]

1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199
    def _get_applied_meta_list(self):
        if "applied_meta_list" not in self._context:
            print(
                "WARNING: You may need to call minimize function before _get_applied_meta_list called"
            )
            return []
        else:
            return self._context["applied_meta_list"]

    def _get_applied_graph_list(self):
        if "applied_graph_list" not in self._context:
            print(
                "WARNING: You may need to call minimize function before _get_applied_graph_list called"
            )
            return []
        else:
            return self._context["applied_graph_list"]

1200 1201 1202
    def minimize(
        self, loss, startup_program=None, parameter_list=None, no_grad_set=None
    ):
        """
        Add distributed operations to minimize ``loss`` by updating ``parameter_list``.

        Args:
            loss (Tensor|list): A ``Tensor`` containing the value to minimize.
                A list of losses is accepted only in parameter-server (PS) mode.
            startup_program (Program, optional): :ref:`api_fluid_Program` for
                initializing parameters in ``parameter_list``. The default value
                is None, at this time :ref:`api_fluid_default_startup_program` will be used.
                When ``loss`` is a list, this argument is passed on as the
                startup programs of the multi-loss implementation.
            parameter_list (Iterable, optional): Iterable of ``Tensor`` or ``Tensor.name`` to update
                to minimize ``loss``. The default value is None, at this time all parameters
                will be updated.
            no_grad_set (set, optional): Set of ``Tensor``  or ``Tensor.name`` that don't need
                to be updated. The default value is None.

        Returns:
            tuple: tuple (optimize_ops, params_grads), A list of operators appended
            by minimize and a list of (param, grad) tensor pairs, param is
            ``Parameter``, grad is the gradient value corresponding to the parameter.
            The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
            indicate program pruning. If so, the program will be pruned by ``feed`` and
            ``fetch_list`` before run, see details in ``Executor``.

        Raises:
            ValueError: if ``loss`` is a list while running in dygraph mode,
                with a non-distributed role maker, or in collective mode.

        Examples:

            .. code-block:: python

                import paddle
                paddle.enable_static()
                import paddle.distributed.fleet as fleet
                import paddle.nn.functional as F

                hid_dim = 10
                label_dim = 2
                input_x = paddle.static.data(name='x', shape=[None, 13], dtype='float32')
                input_y = paddle.static.data(name='y', shape=[None, 1], dtype='int64')
                fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim, activation='tanh')
                fc_2 = paddle.static.nn.fc(x=fc_1, size=hid_dim, activation='tanh')
                prediction = paddle.static.nn.fc(x=[fc_2], size=label_dim, activation='softmax')
                cost = F.cross_entropy(input=prediction, label=input_y)
                avg_cost = paddle.mean(x=cost)

                fleet.init(is_collective=True)
                strategy = fleet.DistributedStrategy()
                optimizer = paddle.optimizer.SGD(learning_rate=0.001)
                optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
                optimizer.minimize(avg_cost)

                # for more examples, please reference https://github.com/PaddlePaddle/PaddleFleetX

        """
        if isinstance(loss, list):
            # Multiple losses are only supported by the PS code path.
            is_ps_mode = not (
                in_dygraph_mode()
                or self._role_maker._is_non_distributed()
                or self._is_collective
            )
            if not is_ps_mode:
                raise ValueError("loss can be list only in PS mode")
            return self._minimize_losses_impl(
                loss, startup_program, parameter_list, no_grad_set
            )

        return self._minimize_impl(
            loss, startup_program, parameter_list, no_grad_set
        )

    def _minimize_impl(
        self, loss, startup_program=None, parameter_list=None, no_grad_set=None
    ):
        """
        Single-loss implementation behind :meth:`minimize`.

        Args:
            loss (Tensor): The loss to minimize.
            startup_program (Program, optional): Program used to initialize
                parameters. When None, the default startup program is used.
            parameter_list (Iterable, optional): Parameters to update.
            no_grad_set (set, optional): Tensors that do not need gradients.

        Returns:
            - dygraph mode: the result of the user-defined optimizer's
              ``minimize(loss)``;
            - static mode with ``semi_auto`` or ``auto_search`` enabled:
              ``(optimize_ops, params_grads, dist_startup_prog, dist_main_prog)``
              produced by the auto-parallelizer;
            - non-distributed role maker outside collective mode: the result
              of the user-defined optimizer's ``minimize``;
            - otherwise: ``(optimize_ops, params_grads)``.
        """
        context = {}
        context["user_defined_strategy"] = copy.deepcopy(
            self._user_defined_strategy
        )

        if in_dygraph_mode():
            # imitate target optimizer retrieval
            target_opt = self.user_defined_optimizer
            self._context = context
            return target_opt.minimize(loss)

        # cache original feed forward program
        self.origin_main_program = loss.block.program
        # add distributed attr (only the first time this program is seen)
        if not hasattr(self.origin_main_program, "distributed_info_"):
            sharding_configs = self._user_defined_strategy.sharding_configs
            self.origin_main_program.distributed_info_ = {
                key: sharding_configs[key]
                for key in (
                    "dp_degree",
                    "mp_degree",
                    "pp_degree",
                    "sharding_degree",
                )
            }

        context["origin_main_program"] = self.origin_main_program
        context["origin_main_programs"] = [self.origin_main_program]
        context["loss"] = loss
        if startup_program is None:
            startup_program = paddle.static.default_startup_program()
        self.origin_startup_program = startup_program.clone(for_test=False)

        context["origin_startup_program"] = startup_program
        context["origin_startup_programs"] = [startup_program]
        context["role_maker"] = self._role_maker

        # Use the auto-parallel's routines instead
        if (
            self._user_defined_strategy.semi_auto
            or self._user_defined_strategy.auto_search
        ):
            from ..auto_parallel.parallelizer import AutoParallelizer

            auto_parallelizer = AutoParallelizer(self)
            (
                optimize_ops,
                params_grads,
                dist_startup_prog,
                dist_main_prog,
            ) = auto_parallelizer.parallelize(
                loss, startup_program, parameter_list, no_grad_set
            )

            return (
                optimize_ops,
                params_grads,
                dist_startup_prog,
                dist_main_prog,
            )

        # context["user_defined_strategy"] already holds a deep copy taken at
        # the top of this method, and nothing above mutates the strategy, so
        # it is not copied a second time here.
        copy_user_defined_strategy = copy.deepcopy(self._user_defined_strategy)

        can_not_apply_optimizer_list = []
        # fix set collective and fleet ps gpu error
        if (
            self._is_collective
            and len(self._user_defined_strategy.sparse_table_configs) > 0
        ):
            context["use_fleet_ps"] = True
            from .meta_optimizers import ParameterServerOptimizer

            meta_optimizer = ParameterServerOptimizer(
                self.user_defined_optimizer
            )
            meta_optimizer._set_basic_info(
                loss,
                self._role_maker,
                self.user_defined_optimizer,
                copy_user_defined_strategy,
            )
            can_not_apply_optimizer_list.append(meta_optimizer)

            # meaningless, just for compatibility with other code
            graph_optimizer = None

        else:
            # compile time
            distributed_optimizer_list = (
                MetaOptimizerFactory()._get_valid_meta_optimizers(
                    self.user_defined_optimizer
                )
            )
            # trigger the auto-parallel in very strict condition
            # strategy = DistributedStrategy()
            # strategy.auto = True
            # optimizer = paddle.optimizer.SGD(learning_rate=0.1)
            # optimizer = fleet.distributed_optimizer(optimizer, strategy)
            if copy_user_defined_strategy._is_strict_auto():
                # turn on all the strategy for each optimizer
                for opt in distributed_optimizer_list:
                    opt._enable_strategy(copy_user_defined_strategy, context)

            valid_optimizer_list = []
            valid_graph_optimizer_list = []
            # recall meta optimizers for ranking
            for opt in distributed_optimizer_list:
                opt._set_basic_info(
                    loss,
                    self._role_maker,
                    self.user_defined_optimizer,
                    copy_user_defined_strategy,
                )
                if opt._can_apply() and not opt._is_graph_out():
                    valid_optimizer_list.append(opt)
                elif opt._can_apply() and opt._is_graph_out():
                    valid_graph_optimizer_list.append(opt)
                else:
                    can_not_apply_optimizer_list.append(opt)
            # combine recalled meta optimizers to be a valid meta optimizer
            (
                meta_optimizer,
                graph_optimizer,
            ) = self.strategy_compiler.generate_optimizer(
                loss,
                self._role_maker,
                self.user_defined_optimizer,
                copy_user_defined_strategy,
                valid_optimizer_list,
                valid_graph_optimizer_list,
            )

        valid_strategy = self.strategy_compiler._get_valid_strategy(
            copy_user_defined_strategy, can_not_apply_optimizer_list
        )

        context["valid_strategy"] = copy.deepcopy(valid_strategy)
        logger.debug("valid_strategy: " + str(context["valid_strategy"]))
        logger.debug(
            "user_defined_strategy: " + str(context["user_defined_strategy"])
        )

        applied_meta_list = self.strategy_compiler._get_applied_meta_list()
        applied_graph_list = self.strategy_compiler._get_applied_graph_list()

        context['applied_meta_list'] = applied_meta_list
        context['applied_graph_list'] = applied_graph_list

        self._context = context

        self.valid_strategy = valid_strategy
        self.valid_strategy._enable_env()

        if self._role_maker._is_non_distributed() and not self._is_collective:
            if self._runtime_handle is None:
                self._runtime_handle = RuntimeFactory()._create_runtime(
                    context
                )

            compiled_program = compiler.CompiledProgram(
                self.origin_main_program
            )
            loss.block.program._graph = compiled_program
            return self.user_defined_optimizer.minimize(
                loss,
                startup_program,
                parameter_list,
                no_grad_set=no_grad_set,
            )

        if meta_optimizer:
            logger.debug(
                "before minimize program id: " + str(id(loss.block.program))
            )
            optimize_ops, params_grads = meta_optimizer.minimize(
                loss,
                startup_program,
                parameter_list,
                no_grad_set=no_grad_set,
            )
            logger.debug(
                "after minimize program id: " + str(id(loss.block.program))
            )
            default_program = paddle.static.default_main_program()
            logger.debug("default program id: " + str(id(default_program)))

            if default_program is not loss.block.program:
                paddle.framework.switch_main_program(loss.block.program)
            # Re-query the default program: the local above still points at
            # the program that was the default *before* the switch.
            logger.debug(
                "default program id after switch: "
                + str(id(paddle.static.default_main_program()))
            )

        else:
            optimize_ops, params_grads = self.user_defined_optimizer.minimize(
                loss,
                startup_program,
                parameter_list,
                no_grad_set=no_grad_set,
            )

        context["program_optimize_ops"] = optimize_ops
        context["program_params_grads"] = params_grads

        if graph_optimizer:
            logger.debug(
                "before graph minimize program id: "
                + str(id(loss.block.program))
            )
            optimize_ops, params_grads = graph_optimizer.minimize(
                loss,
                startup_program,
                parameter_list,
                no_grad_set=no_grad_set,
            )
            # since we do not encourage users to use graph operations
            # if a graph optimizer takes effect, mostly
            # optimizers_ops and params_grads are None
            # i.e. users can not modify current computation graph anymore
            context["graph_optimize_ops"] = optimize_ops
            context["graph_optimize_grads"] = params_grads
        else:
            apply_ir_passes(loss.block.program, startup_program, self)

        if not self._role_maker._is_heter_parameter_server_mode:
            program = paddle.static.default_main_program()
            opt_info = {} if program._fleet_opt is None else program._fleet_opt
            opt_info["mpi_size"] = self.worker_num()
            opt_info["mpi_rank"] = self.worker_index()
            trainer_desc_configs = (
                self._user_defined_strategy.trainer_desc_configs
            )
            for k, v in trainer_desc_configs.items():
                # user-provided non-empty values override existing entries
                if v or k not in opt_info:
                    opt_info[k] = v
            program._fleet_opt = opt_info

        if self._runtime_handle is None:
            self._runtime_handle = RuntimeFactory()._create_runtime(context)

        import paddle.distributed.fleet as fleet

        fleet.util._set_strategy(context["valid_strategy"])

        return optimize_ops, params_grads

    def _minimize_losses_impl(
        self,
        losses,
        startup_programs=None,
        parameter_list=None,
        no_grad_set=None,
    ):
        """
        Multi-loss implementation behind :meth:`minimize` (PS mode only).

        All losses are optimized together by a ``ParameterServerOptimizer``.
        Afterwards, each loss's program gets ``mpi_size``/``mpi_rank`` and the
        strategy's ``trainer_desc_configs`` merged into its ``_fleet_opt``.

        Args:
            losses (list): Loss tensors; ``losses[0]`` defines the cached
                origin main program.
            startup_programs (list, optional): Startup programs. May be None
                only when exactly one loss is given, in which case the default
                startup program is used.
            parameter_list (Iterable, optional): Parameters to update.
            no_grad_set (set, optional): Tensors that do not need gradients.

        Returns:
            tuple: ``(optimize_ops, params_grads)`` from the PS optimizer.

        Raises:
            ValueError: if ``startup_programs`` is None and more than one loss
                is given.
        """
        context = {}

        # cache original feed forward program
        self.origin_main_program = losses[0].block.program
        context["origin_main_program"] = self.origin_main_program
        context["origin_main_programs"] = [
            loss.block.program for loss in losses
        ]
        context["loss"] = losses

        if startup_programs is None:
            if len(losses) != 1:
                raise ValueError(
                    "startup_program can't be None when loss is list."
                )
            startup_programs = [paddle.static.default_startup_program()]
        self.origin_startup_program = startup_programs[0].clone(for_test=False)
        context["origin_startup_program"] = startup_programs[0]
        context["origin_startup_programs"] = list(startup_programs)

        context["role_maker"] = self._role_maker

        context["user_defined_strategy"] = copy.deepcopy(
            self._user_defined_strategy
        )

        context["valid_strategy"] = copy.deepcopy(self._user_defined_strategy)

        self._context = context

        self.valid_strategy = context["valid_strategy"]
        self.valid_strategy._enable_env()

        from .meta_optimizers import ParameterServerOptimizer

        ps_optimizer = ParameterServerOptimizer(self.user_defined_optimizer)
        ps_optimizer._set_basic_info(
            losses,
            self._role_maker,
            self.user_defined_optimizer,
            self._user_defined_strategy,
        )
        optimize_ops, params_grads = ps_optimizer.minimize_losses_impl(
            losses, startup_programs, parameter_list, no_grad_set=no_grad_set
        )

        context["program_optimize_ops"] = optimize_ops
        context["program_params_grads"] = params_grads

        for loss in losses:
            program = loss.block.program
            opt_info = {} if program._fleet_opt is None else program._fleet_opt
            opt_info["mpi_size"] = self.worker_num()
            opt_info["mpi_rank"] = self.worker_index()
            trainer_desc_configs = (
                self._user_defined_strategy.trainer_desc_configs
            )
            for k, v in trainer_desc_configs.items():
                # user-provided non-empty values override existing entries
                if v or k not in opt_info:
                    opt_info[k] = v
            program._fleet_opt = opt_info
            logger.debug(
                "fleet base opt info: "
                + str(id(program))
                + str(program._fleet_opt)
            )

        if self._runtime_handle is None:
            self._runtime_handle = RuntimeFactory()._create_runtime(context)

        import paddle.distributed.fleet as fleet

        fleet.util._set_strategy(context["valid_strategy"])

        return optimize_ops, params_grads