# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import warnings

import paddle
from paddle.framework import core
from paddle.static import (
    CompiledProgram,
    Executor,
    Program,
    Variable,
    default_main_program,
    default_startup_program,
    save_inference_model,
)

from ..base.private_helper_function import wait_server_ready
from .runtime_base import RuntimeBase

__all__ = []


class ParameterServerRuntime(RuntimeBase):
    def __init__(self):
        super().__init__()
        self._communicator = None

    def _set_basic_info(self, context):
        self.context = context
        self.role_maker = context["role_maker"]
        self.origin_main_program = context["origin_main_program"]
        self.origin_startup_program = context["origin_startup_program"]
        self.async_strategy = self._get_distributed_strategy()
        self.compiled_strategy = self.build_compiled_startegy()

    def _get_distributed_strategy(self):
        strategy = None

        from paddle.incubate.distributed.fleet.parameter_server.distribute_transpiler.distributed_strategy import (
            StrategyFactory,
        )

        dist_strategy = self.context["valid_strategy"]
        k_steps = dist_strategy.a_sync_configs["k_steps"]
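        # Strategy selection from the validated DistributedStrategy:
        #   a_sync == False, k_steps == 0  -> sync strategy
        #   a_sync == True,  k_steps == 0  -> async strategy
        #   a_sync == True,  k_steps  > 0  -> geo strategy (with k_steps)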

        if not dist_strategy.a_sync and k_steps == 0:
            strategy = StrategyFactory.create_sync_strategy()

        if dist_strategy.a_sync and k_steps == 0:
            strategy = StrategyFactory.create_async_strategy()

        if dist_strategy.a_sync and k_steps > 0:
            strategy = StrategyFactory.create_geo_strategy(k_steps)

        if not strategy:
            raise ValueError("Invalid combination of a_sync and k_steps, please check")

        return strategy

    def build_compiled_startegy(self):
        from paddle.incubate.distributed.fleet.parameter_server.ir.public import (
            CompileTimeStrategy,
        )

        compiled_config = CompileTimeStrategy(
            self.origin_main_program,
            self.origin_main_program,
            self.async_strategy,
            self.role_maker,
        )
        return compiled_config

    def _load_sparse_params(
        self, executor, dirname, varnames, main_program=None
    ):
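        # Build a temporary Program containing one `sparse_tensor_load` op per
        # requested variable, then run it so this server loads only its own
        # shard (node_index / node_num) of each SelectedRows var from `dirname`.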
        assert varnames is not None
        check_vars = []
        load_prog = Program()
        load_block = load_prog.global_block()

        def _in_varnames(var):
            return var.name in varnames

        load_vars = list(
            filter(_in_varnames, default_main_program().list_vars())
        )
        if main_program is None:
            main_program = self.origin_main_program

        from paddle.incubate.distributed.fleet.parameter_server.ir.public import (
            _get_varname_parts,
        )

        for each_var in load_vars:
            assert isinstance(each_var, Variable)

            origin_varname, _, _ = _get_varname_parts(each_var.name)
            new_var = paddle.static.io._clone_var_in_block(load_block, each_var)
            var_path = os.path.join(dirname, origin_varname)
            if not os.path.exists(var_path):
                raise ValueError(
                    "SelectedRows var {} can not be found at {}".format(
                        new_var.name, var_path
                    )
                )

            if os.path.isfile(var_path):
                load_block.append_op(
                    type='sparse_tensor_load',
                    inputs={},
                    outputs={'Out': [new_var]},
                    attrs={
                        'file_path': os.path.join(dirname, origin_varname),
                        'node_index': self.role_maker._server_index(),
                        'node_num': self.role_maker._server_num(),
                        'shape': each_var.shape,
                    },
                )
            check_vars.append(each_var)

        executor.run(load_prog)

    def _load_distributed_params(self, dirname, varnames):
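        # Large-scale (distributed) sparse tables are loaded through the
        # LargeScaleKV service; each table is read from the
        # `dirname/origin_varname/varname` directory written at save time.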
        from paddle.distributed.communicator import LargeScaleKV
        from paddle.incubate.distributed.fleet.parameter_server.ir.public import (
            _get_varname_parts,
        )

        scale_kv = LargeScaleKV()
        for varname in varnames:
            origin_varname, _, _ = _get_varname_parts(varname)
            sparse_dir = os.path.join(dirname, origin_varname, varname)
            scale_kv.load(varname, sparse_dir)

    @staticmethod
    def __exclude_vars(exclude_var_names=[]):
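        # Returns a predicate for `filter`: keep only persistable variables,
        # dropping the explicitly excluded names, gradient vars ("@GRAD"),
        # "learning_rate_0", and feed/fetch/reader variables.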
        def is_valid(var):
            if var.name in exclude_var_names:
                return False

            from paddle.incubate.distributed.fleet.parameter_server.ir.public import (
                _get_varname_parts,
            )

            origin_varname, _, _ = _get_varname_parts(var.name)
            if origin_varname.endswith("@GRAD"):
                return False

            if origin_varname == "learning_rate_0":
                return False

            if (
                var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH
                or var.desc.type() == core.VarDesc.VarType.FETCH_LIST
                or var.desc.type() == core.VarDesc.VarType.READER
            ):
                return False
            return var.persistable

        return is_valid

    def _init_worker(self):
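        # Assemble the kwargs the Communicator needs (sync mode: pserver
        # endpoints and trainer_id; geo mode: trainer count and serialized
        # sparse-table init attrs), optionally wait for servers / heter
        # workers, then create and start the Communicator with the send/recv
        # contexts derived from the compiled strategy.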
        def sync_strategy_envs():
            kwargs = {}
            kwargs[
                "pserver_endpoints"
            ] = self.role_maker._get_pserver_endpoints()
            kwargs["trainer_id"] = self.role_maker._worker_index()
            return kwargs

        def geo_strategy_envs():
            from paddle.incubate.distributed.fleet.parameter_server.ir.public import (
                get_sparse_tablenames,
            )

            def get_sparse_attrs():
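                # Serialize the init info of every sparse table for the geo
                # communicator: each table becomes
                # "name:dim1,dim2,...:init_op&attr1&attr2&..." and the entries
                # are joined with "#".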
                opt_init_map = {}
                opt_init_map["gaussian_random"] = ["seed", "mean", "std"]
                opt_init_map["fill_constant"] = ["value"]
                opt_init_map["uniform_random"] = ["seed", "min", "max"]
                opt_init_map["truncated_gaussian_random"] = [
                    "seed",
                    "mean",
                    "std",
                ]

                dist_varnames = get_sparse_tablenames(
                    self.origin_main_program, True
                )
                sparse_varnames = get_sparse_tablenames(
                    self.origin_main_program, False
                )

                if len(dist_varnames) != 0:
                    raise ValueError(
                        "GeoStrategy can not support large scale embedding now, please use paddle.static.nn.embedding"
                    )

                init_attrs = []
                for value_name in sparse_varnames:
                    value_var = self.origin_main_program.global_block().vars[
                        value_name
                    ]
                    value_attr = [
                        value_name,
                        ",".join([str(dim) for dim in value_var.shape]),
                    ]
                    for op in self.origin_startup_program.global_block().ops:
                        if (
                            op.type in opt_init_map.keys()
                            and value_name == op.output("Out")[0]
                        ):
                            init_attr = [op.type]
                            for attr in opt_init_map[op.type]:
                                init_attr.append(str(op.attr(attr)))
                            value_attr.append("&".join(init_attr))
                            init_attrs.append(":".join(value_attr))
                            break
                return "#".join(init_attrs)

            kwargs = {}
            kwargs["trainers"] = self.role_maker._worker_num()
            kwargs["sparse_attrs"] = get_sparse_attrs()
            return kwargs

        from paddle.incubate.distributed.fleet.parameter_server.distribute_transpiler.distributed_strategy import (
            GeoStrategy,
            SyncStrategy,
        )
        from paddle.incubate.distributed.fleet.parameter_server.ir.public import (
            _get_lr_ops,
            _has_global_step,
        )

        trainer_config = self.async_strategy.get_trainer_runtime_config()
        print(trainer_config)

        dist_strategy = self.context["valid_strategy"]
        launch_barrier = dist_strategy.a_sync_configs["launch_barrier"]
        if launch_barrier:
            # trainers wait until servers are ready
            wait_server_ready(self.role_maker._get_pserver_endpoints())

            # in ps-heter mode, also wait until heter workers are ready
            if (
                self.role_maker._is_heter_parameter_server_mode
                and self.role_maker._is_worker()
            ):
                wait_server_ready(self.role_maker._get_heter_worker_endpoints())

        lrs = _has_global_step(_get_lr_ops(self.origin_main_program))

        if lrs:
            kwargs = {"need_global_step": "1"}
        else:
            kwargs = {"need_global_step": "0"}

        if isinstance(self.async_strategy, GeoStrategy):
            geo_kwargs = geo_strategy_envs()
            kwargs.update(geo_kwargs)
        if isinstance(self.async_strategy, SyncStrategy):
            sync_kwargs = sync_strategy_envs()
            kwargs.update(sync_kwargs)

        kwargs = kwargs if kwargs else None

        send_ctx = self.compiled_strategy.get_communicator_send_context()

        if self.compiled_strategy.is_geo_mode():
            recv_ctx = self.compiled_strategy.get_communicator_recv_context(
                recv_type=4
            )
        else:
            recv_ctx = self.compiled_strategy.get_communicator_recv_context(
                recv_type=1
            )

        from paddle.distributed.communicator import Communicator

        self._communicator = Communicator(
            trainer_config.mode, kwargs, trainer_config.get_communicator_flags()
        )
        self._communicator.init_with_ctx(send_ctx, recv_ctx)

        if not self._communicator.is_running():
            self._communicator.start()
        else:
            warnings.warn("communicator has been initialized, skip")

    def _get_executor(self):
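        # Default to a CPU executor; in heter parameter-server mode, a heter
        # worker gets a CUDA/XPU executor according to the configured
        # heter_worker_device_guard.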
        executor = Executor(paddle.CPUPlace())
        if self.role_maker._is_heter_parameter_server_mode:
            heter_worker_device_guard = (
                self.context["valid_strategy"]
                .a_sync_configs["heter_worker_device_guard"]
                .upper()
            )
            if heter_worker_device_guard not in ["GPU", "XPU", "CPU"]:
                raise ValueError(
                    "Heter Worker does not support device {}".format(
                        heter_worker_device_guard
                    )
                )
            if self.role_maker._is_heter_worker():
                if heter_worker_device_guard == "GPU":
                    executor = Executor(
                        paddle.CUDAPlace(
                            int(os.getenv("FLAGS_selected_gpus", "0"))
                        )
                    )
                elif heter_worker_device_guard == "XPU":
                    executor = Executor(
                        paddle.XPUPlace(
                            int(os.getenv("FLAGS_selected_xpus", "0"))
                        )
                    )
        return executor

    def _init_server(self, *args, **kwargs):
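        # Optional single positional arg `dirname`: when given, the server
        # restores dense, sparse and large-scale variables from that directory
        # after running the startup program.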
        if len(args) > 1:
            raise ValueError("init server can only accept 1 arg: `dirname`")
        elif len(args) == 1:
            model_dirname = args[0]
        else:
            model_dirname = None

        executor = self._get_executor()
        if (
            self.role_maker._is_heter_worker()
            and self.context["valid_strategy"].a_sync_configs["launch_barrier"]
        ):
            # heter trainers wait until servers are ready
            wait_server_ready(self.role_maker._get_pserver_endpoints())
        executor.run(default_startup_program())

        if self.role_maker._is_heter_worker():
            self._init_worker()
            return

        sparse_varnames = self.compiled_strategy.get_sparse_varname_on_ps(False)
        sparse_related_optimize_varnames = []
        for var_name in sparse_varnames:
            sparse_related_optimize_varnames += (
                self.compiled_strategy.get_optimize_varname_on_ps(var_name)
            )
        sparse_related_optimize_varnames = list(
            set(sparse_related_optimize_varnames)
        )
        distributed_varnames = self.compiled_strategy.get_sparse_varname_on_ps(
            True
        )
        distributed_related_optimize_varnames = []
        for var_name in distributed_varnames:
            distributed_related_optimize_varnames += (
                self.compiled_strategy.get_optimize_varname_on_ps(var_name)
            )
        distributed_related_optimize_varnames = list(
            set(distributed_related_optimize_varnames)
        )

        remaining_vars = list(
            filter(
                ParameterServerRuntime.__exclude_vars(
                    sparse_varnames
                    + distributed_varnames
                    + sparse_related_optimize_varnames
                    + distributed_related_optimize_varnames
                ),
                default_main_program().list_vars(),
            )
        )

        if not model_dirname:
            return

        if not os.path.isdir(model_dirname):
            raise ValueError("There is no directory named '%s'" % model_dirname)

        # load dense
        paddle.static.load_vars(
            executor,
            main_program=default_main_program(),
            dirname=model_dirname,
            vars=remaining_vars,
        )

        # load sparse
        self._load_sparse_params(
            executor=executor,
            dirname=model_dirname,
            varnames=sparse_varnames + sparse_related_optimize_varnames,
        )

        # load large scale
        self._load_distributed_params(
            dirname=model_dirname,
            varnames=distributed_varnames
            + distributed_related_optimize_varnames,
        )

    def _run_server(self):
        executor = self._get_executor()
        executor.run(default_main_program())

    def _stop_worker(self):
        self._communicator.stop()
        executor = self._get_executor()
        executor.close()

    def _get_optimizer_status(self, op, param_name):
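        # For a supported optimizer type, return the accumulator variable names
        # belonging to `param_name`: those shaped like the parameter
        # (reshaped_names) and those with their own shape, such as the beta
        # pow accumulators (origin_names).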
        supported_opts = [
            "sgd",
            "adam",
            "adagrad",
            "adamax",
            "momentum",
            "lars_momentum",
            "rmsprop",
            "decayed_adagrad",
            "ftrl",
        ]

        reshaped_val_map = {}
        reshaped_val_map["sgd"] = []
        reshaped_val_map["adam"] = ["moment1_0", "moment2_0"]
        reshaped_val_map["adagrad"] = ["moment_0"]
        reshaped_val_map["adamax"] = ["moment_0", "inf_norm_0"]
        reshaped_val_map["momentum"] = ["velocity_0"]
        reshaped_val_map["lars_momentum"] = ["velocity_0"]
        reshaped_val_map["rmsprop"] = [
            "momentum_0",
            "mean_square_0",
            "mean_grad_0",
        ]
        reshaped_val_map["decayed_adagrad"] = ["moment_0"]
        reshaped_val_map["ftrl"] = ["squared_0", "linear_0"]

        orishaped_val_map = {}
        orishaped_val_map["adam"] = ["beta1_pow_acc_0", "beta2_pow_acc_0"]
        orishaped_val_map["adamax"] = ["beta1_pow_acc_0"]

        if op not in supported_opts:
            raise ValueError(
                "fleet does not support optimizer: {}; supported optimizers are: {}".format(
                    op, supported_opts
                )
            )

        reshaped_names = [
            param_name + "_" + val for val in reshaped_val_map[op]
        ]

        if op not in orishaped_val_map:
            origin_names = []
        else:
            origin_names = [
                param_name + "_" + val for val in orishaped_val_map[op]
            ]
        return reshaped_names, origin_names

    def _get_optimizer_op(self, param_name):
        from paddle.incubate.distributed.fleet.parameter_server.ir.public import (
            _get_optimize_ops,
        )

        opts = _get_optimize_ops(self.origin_main_program)
        for op in opts:
            if (
                "Param" in op.input_names
                and "LearningRate" in op.input_names
                and op.input("Param")[0] == param_name
            ):
                return op

    def _save_dense_params(self, executor, dirname, context, main_program):
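        # Pull the latest dense values, then build and run a temporary Program
        # of `recv_save` ops that fetch every dense var (plus its optimizer
        # accumulators) from the pservers and write them under `dirname`.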
        self._communicator.recv()

        prog = Program()
        block = prog.global_block()
        local_vars = []

        for name, var_ctx in context.items():
            if len(var_ctx.origin_varnames()) != 1:
                raise ValueError("Dense can not support split now.")

            varname = var_ctx.origin_varnames()[0]
            local_vars.append(varname)

            optimizer = self._get_optimizer_op(varname)
            reshaped_varnames, origin_varnames = self._get_optimizer_status(
                optimizer.type, varname
            )

            for var_name in [varname] + reshaped_varnames + origin_varnames:
                var = self.origin_main_program.global_block().vars[var_name]
                block.append_op(
                    type='recv_save',
                    attrs={
                        "trainer_id": self.role_maker._worker_index(),
                        "shape": var.shape,
                        "slice_shapes": [",".join([str(i) for i in var.shape])],
                        "slice_varnames": [var.name],
                        "remote_varnames": [var.name],
                        "is_sparse": False,
                        "endpoints": var_ctx.split_endpoints(),
                        "file_path": os.path.join(dirname, var.name),
                    },
                )

        executor.run(prog)
        return local_vars

    def _save_sparse_params(self, executor, dirname, context, main_program):
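        # Build and run a temporary Program of `recv_save` ops that gather the
        # shards of each sparse var (and its optimizer accumulators) from the
        # pservers, using the section sizes recorded in the recv context, and
        # write them under `dirname`.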
        prog = Program()
        block = prog.global_block()
        local_vars = []

        for name, var_ctx in context.items():
            if len(var_ctx.origin_varnames()) != 1:
                raise ValueError("Sparse can not support split now.")

            varname = var_ctx.origin_varnames()[0]
            local_vars.append(varname)

            optimizer = self._get_optimizer_op(varname)
            reshaped_varnames, origin_varnames = self._get_optimizer_status(
                optimizer.type, varname
            )

            var = self.origin_main_program.global_block().vars[varname]
            slice_shapes = []
            dims1 = ",".join([str(i) for i in var.shape[1:]])

            for section in var_ctx.sections():
                slice_shapes.append(str(section) + dims1)

            block.append_op(
                type='recv_save',
                attrs={
                    "trainer_id": self.role_maker._worker_index(),
                    "shape": var.shape,
                    "slice_shapes": slice_shapes,
                    "slice_varnames": var_ctx.split_varnames(),
                    "remote_varnames": var_ctx.split_varnames(),
                    "is_sparse": True,
                    "endpoints": var_ctx.split_endpoints(),
                    "pserver_num": len(
                        self.role_maker._get_pserver_endpoints()
                    ),
                    "file_path": os.path.join(dirname, var.name),
                },
            )

            for reshaped_varname in reshaped_varnames:
                var = self.origin_main_program.global_block().vars[
                    reshaped_varname
                ]

                slice_varnames = []
                remote_varnames = []
                for i in range(len(var_ctx.split_varnames())):
                    slice_varnames.append(
                        "{}.block{}".format(reshaped_varname, i)
                    )
                    remote_varnames.append(reshaped_varname)

                block.append_op(
                    type='recv_save',
                    attrs={
                        "trainer_id": self.role_maker._worker_index(),
                        "shape": var.shape,
                        "slice_shapes": slice_shapes,
                        "slice_varnames": slice_varnames,
                        "remote_varnames": remote_varnames,
                        "is_sparse": True,
                        "endpoints": var_ctx.split_endpoints(),
                        "pserver_num": len(
                            self.role_maker._get_pserver_endpoints()
                        ),
                        "file_path": os.path.join(dirname, var.name),
                    },
                )

            for origin_varname in origin_varnames:
                var = self.origin_main_program.global_block().vars[
                    origin_varname
                ]

                block.append_op(
                    type='recv_save',
                    attrs={
                        "trainer_id": self.role_maker._worker_index(),
                        "shape": var.shape,
                        "slice_shapes": [",".join([str(i) for i in var.shape])],
                        "slice_varnames": [origin_varname],
                        "remote_varnames": [origin_varname],
                        "is_sparse": False,
                        "endpoints": var_ctx.split_endpoints()[:1],
                        "file_path": os.path.join(dirname, var.name),
                    },
                )
        executor.run(prog)
        return context.keys()

    def _save_distributed_params(self, executor, dirname, context, mode):
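        # Large-scale tables are saved by the pservers themselves: a
        # `checkpoint_notify` op per table tells each server to dump its own
        # shards into `dirname`.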
        prog = Program()
        block = prog.global_block()

        for name, var_ctx in context.items():
            block.append_op(
                type='checkpoint_notify',
                attrs={
                    "varname": name,
                    "mode": mode,
                    "slice_varnames": var_ctx.split_varnames(),
                    "remote_varnames": var_ctx.split_varnames(),
                    "endpoints": var_ctx.split_endpoints(),
                    "dirname": dirname,
                },
            )

        executor.run(prog)
        return context.keys()

    def _save_distributed_persistables(
        self, executor, dirname, main_program, mode
    ):
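        # recv_type: 1 = dense, 2 = sparse, 3 = distributed (large-scale).
        # Save each group through its own path, then save the remaining local
        # persistable vars with paddle.static.save_vars.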
        dense_ctx = self.compiled_strategy.get_communicator_recv_context(
            recv_type=1, use_origin_program=True
        )

        sparse_ctx = self.compiled_strategy.get_communicator_recv_context(
            recv_type=2, use_origin_program=True
        )

        distributed_ctx = self.compiled_strategy.get_communicator_recv_context(
            recv_type=3, use_origin_program=True
        )

        recv_dense_varnames = self._save_dense_params(
            executor, dirname, dense_ctx, main_program
        )

        recv_sparse_varnames = self._save_sparse_params(
            executor, dirname, sparse_ctx, main_program
        )

        recv_distributed_varnames = self._save_distributed_params(
            executor, dirname, distributed_ctx, mode
        )

        saved_varnames = (
            recv_dense_varnames
            + list(recv_sparse_varnames)
            + list(recv_distributed_varnames)
        )

        remaining_vars = list(
672 673 674 675 676 677
            filter(
                ParameterServerRuntime.__exclude_vars(saved_varnames),
                main_program.list_vars(),
            )
        )

        paddle.static.save_vars(
            executor,
            main_program=main_program,
            dirname=dirname,
            vars=remaining_vars,
        )

    def _ps_inference_save_persistables(
        self, executor, dirname, main_program=None, mode=0, **kwargs
    ):
        """
        This function filters out all variables with `persistable==True` from the
        given `main_program` and then saves these variables to the folder `dirname`
        or file `filename`.

        The `dirname` is used to specify the folder where persistable variables
        are going to be saved. If you would like to save variables in separate
        files, set `filename` None; if you would like to save all variables in a
        single file, use `filename` to specify the file name.
        """

        if not isinstance(executor, Executor):
            raise TypeError(
                "in fleet.save_persistables() function, executor must be an instance of Executor"
            )

        if main_program is None:
            main_program = self.compiled_strategy.get_origin_ps_main_program()

        if isinstance(main_program, CompiledProgram):
            raise TypeError(
                "in fleet.save_persistables() function, main_program must be a Program; CompiledProgram is not allowed"
            )

        self._save_distributed_persistables(
            executor, dirname, main_program, mode
        )

    def _ps_inference_save_inference_model(
        self,
        executor,
        dirname,
        feeded_var_names,
        target_vars,
        main_program=None,
        export_for_deployment=True,
    ):
        """
        Prune the given `main_program` to build a new program especially for inference,
        and then save it and all related parameters to the given `dirname` by the `executor`.
        """

        if not isinstance(executor, Executor):
            raise TypeError(
                "in fleet.save_inference_model() function, executor must be an instance of Executor"
            )

        if main_program is not None:
            if isinstance(main_program, CompiledProgram):
                raise TypeError(
                    "in fleet.save_inference_model() function, main_program must be a Program; CompiledProgram is not allowed"
                )
            save_inference_model(
                dirname,
                feeded_var_names,
                target_vars,
                executor,
                main_program,
                None,
                None,
                export_for_deployment,
            )
        else:
            save_inference_model(
                dirname,
                feeded_var_names,
                target_vars,
                executor,
                self.origin_main_program,
                None,
                None,
                export_for_deployment,
                True,
            )

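            # Reload the pruned program that save_inference_model just wrote,
            # copy the distributed-parameter info from the default main
            # program, and save the matching persistables alongside it.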
            model_basename = "__model__"
            model_filename = os.path.join(dirname, model_basename)

            with open(model_filename, "rb") as f:
                program_desc_str = f.read()

            program = Program.parse_from_string(program_desc_str)
            program._copy_dist_param_info_from(default_main_program())
            self._ps_inference_save_persistables(
                executor, dirname, program, mode=0
            )

    def _save_inference_model(self, *args, **kwargs):
        self._ps_inference_save_inference_model(*args, **kwargs)

    def _save_persistables(self, *args, **kwargs):
        self._ps_inference_save_persistables(*args, **kwargs)