# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import warnings

import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.static import (
    CompiledProgram,
    Executor,
    ParallelExecutor,
    Program,
    Variable,
    default_main_program,
    default_startup_program,
    save_inference_model,
)

from ..base.private_helper_function import wait_server_ready
from .runtime_base import RuntimeBase

__all__ = []


class ParameterServerRuntime(RuntimeBase):
    def __init__(self):
        super().__init__()
        self._communicator = None

    def _set_basic_info(self, context):
        self.context = context
        self.role_maker = context["role_maker"]
        self.origin_main_program = context["origin_main_program"]
        self.origin_startup_program = context["origin_startup_program"]
        self.async_strategy = self._get_distributed_strategy()
        self.compiled_strategy = self.build_compiled_startegy()

    def _get_distributed_strategy(self):
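        # The (a_sync, k_steps) pair selects the strategy: sync when a_sync is
        # False and k_steps == 0, async when a_sync is True and k_steps == 0,
        # geo(k_steps) when a_sync is True and k_steps > 0.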
        strategy = None

        from paddle.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import (
            StrategyFactory,
        )

        dist_strategy = self.context["valid_strategy"]
        k_steps = dist_strategy.a_sync_configs["k_steps"]

        if not dist_strategy.a_sync and k_steps == 0:
            strategy = StrategyFactory.create_sync_strategy()

        if dist_strategy.a_sync and k_steps == 0:
            strategy = StrategyFactory.create_async_strategy()

        if dist_strategy.a_sync and k_steps > 0:
            strategy = StrategyFactory.create_geo_strategy(k_steps)

        if not strategy:
            raise ValueError("k_steps must be invalid value, please check")

        return strategy

    def build_compiled_startegy(self):
        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
            CompileTimeStrategy,
        )

        compiled_config = CompileTimeStrategy(
            self.origin_main_program,
            self.origin_main_program,
            self.async_strategy,
            self.role_maker,
        )
        return compiled_config

    def _load_sparse_params(
        self, executor, dirname, varnames, main_program=None
    ):
        assert varnames is not None
        check_vars = []
        load_prog = Program()
        load_block = load_prog.global_block()

        def _in_varnames(var):
            return var.name in varnames

        load_vars = list(
            filter(_in_varnames, default_main_program().list_vars())
        )
        if main_program is None:
            main_program = self.origin_main_program

        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
            _get_varname_parts,
        )

        for each_var in load_vars:
            assert isinstance(each_var, Variable)

            origin_varname, _, _ = _get_varname_parts(each_var.name)

            new_var = fluid.io._clone_var_in_block_(load_block, each_var)
            var_path = os.path.join(dirname, origin_varname)
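            # the sparse_tensor_load op below carries node_index/node_num so
            # each server only loads its own shard of the SelectedRows variable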
            if not os.path.exists(var_path):
                raise ValueError(
                    "SelectedRows var {} can not be found at {}".format(
                        new_var.name, var_path
                    )
                )

            if os.path.isfile(var_path):
                load_block.append_op(
                    type='sparse_tensor_load',
                    inputs={},
                    outputs={'Out': [new_var]},
                    attrs={
                        'file_path': os.path.join(dirname, origin_varname),
                        'node_index': self.role_maker._server_index(),
                        'node_num': self.role_maker._server_num(),
                        'shape': each_var.shape,
                    },
                )
            check_vars.append(each_var)

        executor.run(load_prog)

    def _load_distributed_params(self, dirname, varnames):
        from paddle.distributed.communicator import LargeScaleKV
        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
            _get_varname_parts,
        )

        scale_kv = LargeScaleKV()
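        # each large-scale table shard is stored under
        # <dirname>/<origin_varname>/<varname> and loaded into the server-side KV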
        for varname in varnames:
            origin_varname, _, _ = _get_varname_parts(varname)
            sparse_dir = os.path.join(dirname, origin_varname, varname)
            scale_kv.load(varname, sparse_dir)

    @staticmethod
    def __exclude_vars(exclude_var_names=[]):
        def is_valid(var):
            if var.name in exclude_var_names:
                return False

            from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
                _get_varname_parts,
            )

            origin_varname, _, _ = _get_varname_parts(var.name)
            if origin_varname.endswith("@GRAD"):
                return False

            if origin_varname == "learning_rate_0":
                return False

            if (
                var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH
                or var.desc.type() == core.VarDesc.VarType.FETCH_LIST
                or var.desc.type() == core.VarDesc.VarType.READER
            ):
                return False
            return var.persistable

        return is_valid

    def _init_worker(self):
        def sync_strategy_envs():
            kwargs = {}
            kwargs[
                "pserver_endpoints"
            ] = self.role_maker._get_pserver_endpoints()
            kwargs["trainer_id"] = self.role_maker._worker_index()
            return kwargs

        def geo_strategy_envs():
            from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
                get_sparse_tablenames,
            )

            def get_sparse_attrs():
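                # Build the "sparse_attrs" string passed to the communicator:
                # "<name>:<dims>:<init_op>&<attr>..." entries joined by "#".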
                opt_init_map = {}
                opt_init_map["gaussian_random"] = ["seed", "mean", "std"]
                opt_init_map["fill_constant"] = ["value"]
                opt_init_map["uniform_random"] = ["seed", "min", "max"]
                opt_init_map["truncated_gaussian_random"] = [
                    "seed",
                    "mean",
                    "std",
                ]

                dist_varnames = get_sparse_tablenames(
                    self.origin_main_program, True
                )
                sparse_varnames = get_sparse_tablenames(
                    self.origin_main_program, False
                )

                if len(dist_varnames) != 0:
                    raise ValueError(
                        "GeoStrategy does not support large scale embedding now, please use paddle.static.nn.embedding"
                    )

                init_attrs = []
                for value_name in sparse_varnames:
                    value_var = self.origin_main_program.global_block().vars[
                        value_name
                    ]
                    value_attr = [
                        value_name,
                        ",".join([str(dim) for dim in value_var.shape]),
                    ]
                    for op in self.origin_startup_program.global_block().ops:
                        if (
                            op.type in opt_init_map.keys()
                            and value_name == op.output("Out")[0]
                        ):
                            init_attr = [op.type]
                            for attr in opt_init_map[op.type]:
                                init_attr.append(str(op.attr(attr)))
                            value_attr.append("&".join(init_attr))
                            init_attrs.append(":".join(value_attr))
                            break
                return "#".join(init_attrs)

            kwargs = {}
            kwargs["trainers"] = self.role_maker._worker_num()
            kwargs["sparse_attrs"] = get_sparse_attrs()
            return kwargs

        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
            _get_lr_ops,
            _has_global_step,
        )
        from paddle.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import (
            GeoStrategy,
            SyncStrategy,
        )

        trainer_config = self.async_strategy.get_trainer_runtime_config()
        print(trainer_config)

        dist_strategy = self.context["valid_strategy"]
        launch_barrier = dist_strategy.a_sync_configs["launch_barrier"]
        if launch_barrier:
            # trainers wait for the servers to be ready
            wait_server_ready(self.role_maker._get_pserver_endpoints())

            # in ps-heter mode, wait for the heter workers to be ready
            if (
                self.role_maker._is_heter_parameter_server_mode
                and self.role_maker._is_worker()
            ):
                wait_server_ready(self.role_maker._get_heter_worker_endpoints())

        lrs = _has_global_step(_get_lr_ops(self.origin_main_program))

        if lrs:
            kwargs = {"need_global_step": "1"}
        else:
            kwargs = {"need_global_step": "0"}

        if isinstance(self.async_strategy, GeoStrategy):
            geo_kwargs = geo_strategy_envs()
            kwargs.update(geo_kwargs)
        if isinstance(self.async_strategy, SyncStrategy):
            sync_kwargs = sync_strategy_envs()
            kwargs.update(sync_kwargs)

        kwargs = kwargs if kwargs else None

        send_ctx = self.compiled_strategy.get_communicator_send_context()

        if self.compiled_strategy.is_geo_mode():
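            # recv_type=4 appears to cover all recv vars in geo mode, while
            # recv_type=1 is dense-only (see _save_distributed_persistables)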
            recv_ctx = self.compiled_strategy.get_communicator_recv_context(
                recv_type=4
            )
        else:
            recv_ctx = self.compiled_strategy.get_communicator_recv_context(
                recv_type=1
            )

        from paddle.distributed.communicator import Communicator

        self._communicator = Communicator(
            trainer_config.mode, kwargs, trainer_config.get_communicator_flags()
        )
        self._communicator.init_with_ctx(send_ctx, recv_ctx)

        if not self._communicator.is_running():
            self._communicator.start()
        else:
            warnings.warn("communicator has been initialized, skip")

    def _get_executor(self):
        executor = Executor(paddle.CPUPlace())
        if self.role_maker._is_heter_parameter_server_mode:
            heter_worker_device_guard = (
                self.context["valid_strategy"]
                .a_sync_configs["heter_worker_device_guard"]
                .upper()
            )
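            # heter workers run on the device named in a_sync_configs; the card
            # id comes from FLAGS_selected_gpus / FLAGS_selected_xpus below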
            if heter_worker_device_guard not in ["GPU", "XPU", "CPU"]:
                raise ValueError(
                    "Heter Worker Not Support Device {}".format(
                        heter_worker_device_guard
                    )
                )
            if self.role_maker._is_heter_worker():
                if heter_worker_device_guard == "GPU":
                    executor = Executor(
                        paddle.CUDAPlace(
                            int(os.getenv("FLAGS_selected_gpus", "0"))
                        )
                    )
                elif heter_worker_device_guard == "XPU":
                    executor = Executor(
                        paddle.XPUPlace(
                            int(os.getenv("FLAGS_selected_xpus", "0"))
                        )
                    )
        return executor

    def _init_server(self, *args, **kwargs):
        if len(args) > 1:
            raise ValueError("init server can only accept 1 args: `dirname`")
        elif len(args) == 1:
            model_dirname = args[0]
        else:
            model_dirname = None
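        # when a dirname is given, dense vars are restored below via load_vars,
        # sparse tables via _load_sparse_params and large-scale tables via
        # _load_distributed_params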

        executor = self._get_executor()
        if (
            self.role_maker._is_heter_worker()
            and self.context["valid_strategy"].a_sync_configs["launch_barrier"]
        ):
            # heter trainers wait for the servers to be ready
            wait_server_ready(self.role_maker._get_pserver_endpoints())
        executor.run(default_startup_program())

        if self.role_maker._is_heter_worker():
            self._init_worker()
            return

        sparse_varnames = self.compiled_strategy.get_sparse_varname_on_ps(False)
        sparse_related_optimize_varnames = []
        for var_name in sparse_varnames:
            sparse_related_optimize_varnames += (
                self.compiled_strategy.get_optimize_varname_on_ps(var_name)
            )
        sparse_related_optimize_varnames = list(
            set(sparse_related_optimize_varnames)
        )
        distribtued_varnames = self.compiled_strategy.get_sparse_varname_on_ps(
            True
        )
        distributed_related_optimize_varnames = []
        for var_name in distribtued_varnames:
            distributed_related_optimize_varnames += (
                self.compiled_strategy.get_optimize_varname_on_ps(var_name)
            )
        distributed_related_optimize_varnames = list(
            set(distributed_related_optimize_varnames)
        )

        remaining_vars = list(
            filter(
                ParameterServerRuntime.__exclude_vars(
                    sparse_varnames
                    + distribtued_varnames
                    + sparse_related_optimize_varnames
                    + distributed_related_optimize_varnames
                ),
                default_main_program().list_vars(),
            )
        )

        if not model_dirname:
            return

        if not os.path.isdir(model_dirname):
            raise ValueError("There is no directory named '%s'", model_dirname)

        # load dense
        paddle.static.load_vars(
            executor,
            main_program=default_main_program(),
            dirname=model_dirname,
            vars=remaining_vars,
        )

        # load sparse
        self._load_sparse_params(
            executor=executor,
            dirname=model_dirname,
            varnames=sparse_varnames + sparse_related_optimize_varnames,
        )

        # load large scale
        self._load_distributed_params(
            dirname=model_dirname,
            varnames=distribtued_varnames
            + distributed_related_optimize_varnames,
        )

    def _run_server(self):
        executor = self._get_executor()
        executor.run(default_main_program())

    def _stop_worker(self):
        self._communicator.stop()
        executor = self._get_executor()
        executor.close()

    def _get_optimizer_status(self, op, param_name):
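        # Map an optimizer type to the names of its accumulator vars for
        # `param_name`: the "reshaped" ones follow the parameter's shape, the
        # others (e.g. beta pow accumulators) keep their original shape.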
        supported_opts = [
            "sgd",
            "adam",
            "adagrad",
            "adamax",
            "momentum",
            "lars_momentum",
            "rmsprop",
            "decayed_adagrad",
            "ftrl",
        ]

        reshaped_val_map = {}
        reshaped_val_map["sgd"] = []
        reshaped_val_map["adam"] = ["moment1_0", "moment2_0"]
        reshaped_val_map["adagrad"] = ["moment_0"]
        reshaped_val_map["adamax"] = ["moment_0", "inf_norm_0"]
        reshaped_val_map["momentum"] = ["velocity_0"]
        reshaped_val_map["lars_momentum"] = ["velocity_0"]
        reshaped_val_map["rmsprop"] = [
            "momentum_0",
            "mean_square_0",
            "mean_grad_0",
        ]
        reshaped_val_map["decayed_adagrad"] = ["moment_0"]
        reshaped_val_map["ftrl"] = ["squared_0", "linear_0"]

        orishaped_val_map = {}
        orishaped_val_map["adam"] = ["beta1_pow_acc_0", "beta2_pow_acc_0"]
        orishaped_val_map["adamax"] = ["beta1_pow_acc_0"]

        if op not in supported_opts:
            raise ValueError(
                "fleet does not support optimizer: {}, only these are supported: {}".format(
                    op, supported_opts
                )
            )

        reshaped_names = [
            param_name + "_" + val for val in reshaped_val_map[op]
        ]

        if op not in orishaped_val_map:
            origin_names = []
        else:
            origin_names = [
                param_name + "_" + val for val in orishaped_val_map[op]
            ]
        return reshaped_names, origin_names

    def _get_optimizer_op(self, param_name):
        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
            _get_optimize_ops,
        )

        opts = _get_optimize_ops(self.origin_main_program)
        for op in opts:
            if (
                "Param" in op.input_names
                and "LearningRate" in op.input_names
                and op.input("Param")[0] == param_name
            ):
                return op

    def _save_dense_params(self, executor, dirname, context, main_program):
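        # Pull the latest dense params from the pservers (communicator.recv),
        # then use recv_save ops to fetch each param and its optimizer status
        # vars and write them under `dirname`.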
        self._communicator.recv()

        prog = Program()
        block = prog.global_block()
        local_vars = []

        for name, var_ctx in context.items():
            if len(var_ctx.origin_varnames()) != 1:
                raise ValueError("Dense can not support split now.")

            varname = var_ctx.origin_varnames()[0]
            local_vars.append(varname)

            optimizer = self._get_optimizer_op(varname)
            reshaped_varnames, origin_varnames = self._get_optimizer_status(
                optimizer.type, varname
            )

            for var_name in [varname] + reshaped_varnames + origin_varnames:
                var = self.origin_main_program.global_block().vars[var_name]
                block.append_op(
                    type='recv_save',
                    attrs={
                        "trainer_id": self.role_maker._worker_index(),
                        "shape": var.shape,
                        "slice_shapes": [",".join([str(i) for i in var.shape])],
                        "slice_varnames": [var.name],
                        "remote_varnames": [var.name],
                        "is_sparse": False,
                        "endpoints": var_ctx.split_endpoints(),
                        "file_path": os.path.join(dirname, var.name),
                    },
                )

        executor.run(prog)
        return local_vars

    def _save_sparse_params(self, executor, dirname, context, main_program):
        prog = Program()
        block = prog.global_block()
        local_vars = []

        for name, var_ctx in context.items():
            if len(var_ctx.origin_varnames()) != 1:
                raise ValueError("Dense can not support split now.")

            varname = var_ctx.origin_varnames()[0]
            local_vars.append(varname)

            optimizer = self._get_optimizer_op(varname)
            reshaped_varnames, origin_varnames = self._get_optimizer_status(
                optimizer.type, varname
            )

            var = self.origin_main_program.global_block().vars[varname]
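            # sparse tables are split by rows across pservers; build one slice
            # shape string per section held by each pserver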
            slice_shapes = []
            dims1 = ",".join([str(i) for i in var.shape[1:]])

            for section in var_ctx.sections():
                slice_shapes.append(str(section) + dims1)

            block.append_op(
                type='recv_save',
                attrs={
                    "trainer_id": self.role_maker._worker_index(),
                    "shape": var.shape,
                    "slice_shapes": slice_shapes,
                    "slice_varnames": var_ctx.split_varnames(),
                    "remote_varnames": var_ctx.split_varnames(),
                    "is_sparse": True,
                    "endpoints": var_ctx.split_endpoints(),
                    "pserver_num": len(
                        self.role_maker._get_pserver_endpoints()
                    ),
                    "file_path": os.path.join(dirname, var.name),
                },
            )

            for reshaped_varname in reshaped_varnames:
                var = self.origin_main_program.global_block().vars[
                    reshaped_varname
                ]

                slice_varnames = []
                remote_varnames = []
                for i in range(len(var_ctx.split_varnames())):
                    slice_varnames.append(
                        "{}.block{}".format(reshaped_varname, i)
                    )
                    remote_varnames.append(reshaped_varname)

                block.append_op(
                    type='recv_save',
                    attrs={
                        "trainer_id": self.role_maker._worker_index(),
                        "shape": var.shape,
                        "slice_shapes": slice_shapes,
                        "slice_varnames": slice_varnames,
                        "remote_varnames": remote_varnames,
                        "is_sparse": True,
                        "endpoints": var_ctx.split_endpoints(),
                        "pserver_num": len(
                            self.role_maker._get_pserver_endpoints()
                        ),
                        "file_path": os.path.join(dirname, var.name),
                    },
                )

            for origin_varname in origin_varnames:
                var = self.origin_main_program.global_block().vars[
                    origin_varname
                ]

                block.append_op(
                    type='recv_save',
                    attrs={
                        "trainer_id": self.role_maker._worker_index(),
                        "shape": var.shape,
                        "slice_shapes": [",".join([str(i) for i in var.shape])],
                        "slice_varnames": [origin_varname],
                        "remote_varnames": [origin_varname],
                        "is_sparse": False,
                        "endpoints": var_ctx.split_endpoints()[:1],
                        "file_path": os.path.join(dirname, var.name),
                    },
                )
        executor.run(prog)
        return context.keys()

    def _save_distributed_params(self, executor, dirname, context, mode):
        prog = Program()
        block = prog.global_block()

        for name, var_ctx in context.items():
            block.append_op(
                type='checkpoint_notify',
                attrs={
                    "varname": name,
                    "mode": mode,
                    "slice_varnames": var_ctx.split_varnames(),
                    "remote_varnames": var_ctx.split_varnames(),
                    "endpoints": var_ctx.split_endpoints(),
                    "dirname": dirname,
                },
            )

        executor.run(prog)
        return context.keys()

    def _save_distributed_persistables(
        self, executor, dirname, main_program, mode
    ):
        dense_ctx = self.compiled_strategy.get_communicator_recv_context(
            recv_type=1, use_origin_program=True
        )

        sparse_ctx = self.compiled_strategy.get_communicator_recv_context(
            recv_type=2, use_origin_program=True
        )

        distributed_ctx = self.compiled_strategy.get_communicator_recv_context(
            recv_type=3, use_origin_program=True
        )

        recv_dense_varnames = self._save_dense_params(
            executor, dirname, dense_ctx, main_program
        )

        recv_sparse_varnames = self._save_sparse_params(
            executor, dirname, sparse_ctx, main_program
        )

        recv_distributed_varnames = self._save_distributed_params(
            executor, dirname, distributed_ctx, mode
        )

        saved_varnames = (
            recv_dense_varnames
            + list(recv_sparse_varnames)
            + list(recv_distributed_varnames)
        )
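        # whatever was not pulled back from the pservers above is saved
        # directly from the local scope with save_vars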

        remaining_vars = list(
            filter(
                ParameterServerRuntime.__exclude_vars(saved_varnames),
                main_program.list_vars(),
            )
        )

        paddle.static.save_vars(
            executor,
            main_program=main_program,
            dirname=dirname,
            vars=remaining_vars,
        )

    def _ps_inference_save_persistables(
        self, executor, dirname, main_program=None, mode=0, **kwargs
    ):
        """
        This function filters out all variables with `persistable==True` from the
        given `main_program` and then saves these variables to the folder `dirname`
        or file `filename`.

        The `dirname` is used to specify the folder where persistable variables
        are going to be saved. If you would like to save variables in separate
        files, set `filename` None; if you would like to save all variables in a
        single file, use `filename` to specify the file name.
        """

        if isinstance(executor, ParallelExecutor):
            raise TypeError(
                "in fleet.save_persistables() function, executor must be as Executor type, ParallelExecutor is not allowed"
            )

        if not isinstance(executor, Executor):
            raise TypeError(
                "in fleet.save_persistables() function, executor must be as Executor type"
            )

        if main_program is None:
            main_program = self.compiled_strategy.get_origin_ps_main_program()

        if isinstance(main_program, CompiledProgram):
            raise TypeError(
                "in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed"
            )

        self._save_distributed_persistables(
            executor, dirname, main_program, mode
        )

    def _ps_inference_save_inference_model(
        self,
        executor,
        dirname,
        feeded_var_names,
        target_vars,
        main_program=None,
        export_for_deployment=True,
    ):
        """
        Prune the given `main_program` to build a new program especially for inference,
        and then save it and all related parameters to given `dirname` by the `executor`.
        """

        if isinstance(executor, ParallelExecutor):
            raise TypeError(
                "in fleet.save_inference_model() function, executor must be as Executor type, ParallelExecutor is not allowed"
            )

        if not isinstance(executor, Executor):
            raise TypeError(
                "in fleet.save_inference_model() function, executor must be as Executor type"
            )

        if main_program is not None:
            if isinstance(main_program, CompiledProgram):
                raise TypeError(
                    "in fleet.save_inference_model() function, main_program must be as Program type, CompiledProgram is not allowed"
                )
            save_inference_model(
                dirname,
                feeded_var_names,
                target_vars,
                executor,
                main_program,
                None,
                None,
                export_for_deployment,
            )
        else:
            save_inference_model(
                dirname,
                feeded_var_names,
                target_vars,
                executor,
                self.origin_main_program,
                None,
                None,
                export_for_deployment,
                True,
            )

            model_basename = "__model__"
            model_filename = os.path.join(dirname, model_basename)
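            # re-read the pruned program that save_inference_model just wrote,
            # attach the distributed param info, then save persistables through
            # the parameter-server path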

            with open(model_filename, "rb") as f:
                program_desc_str = f.read()

            program = Program.parse_from_string(program_desc_str)
            program._copy_dist_param_info_from(default_main_program())
            self._ps_inference_save_persistables(
                executor, dirname, program, mode=0
            )

    def _save_inference_model(self, *args, **kwargs):
        self._ps_inference_save_inference_model(*args, **kwargs)

    def _save_persistables(self, *args, **kwargs):
        self._ps_inference_save_persistables(*args, **kwargs)