# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import warnings

import paddle
from paddle.framework import core
from paddle.static import (
    CompiledProgram,
    Executor,
    ParallelExecutor,
    Program,
    Variable,
    default_main_program,
    default_startup_program,
    save_inference_model,
)

from ..base.private_helper_function import wait_server_ready
from .runtime_base import RuntimeBase

__all__ = []


class ParameterServerRuntime(RuntimeBase):
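    """Parameter-server runtime built on the transpiler-style fleet implementation.

    Workers start a Communicator from the compiled send/recv contexts, servers
    run the transpiled main program, and the save/load helpers move dense,
    sparse and large-scale parameters between the pservers and the filesystem.
    """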
    def __init__(self):
        super().__init__()
        self._communicator = None

    def _set_basic_info(self, context):
        self.context = context
        self.role_maker = context["role_maker"]
        self.origin_main_program = context["origin_main_program"]
        self.origin_startup_program = context["origin_startup_program"]
        self.async_strategy = self._get_distributed_strategy()
        self.compiled_strategy = self.build_compiled_startegy()

    def _get_distributed_strategy(self):
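        """
        Map the validated DistributedStrategy onto a trainer strategy:
        sync when `a_sync` is off, async when `a_sync` is on and k_steps == 0,
        geo when `a_sync` is on and k_steps > 0.
        """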
        strategy = None

        from paddle.incubate.distributed.fleet.parameter_server.distribute_transpiler.distributed_strategy import (
            StrategyFactory,
        )

        dist_strategy = self.context["valid_strategy"]
        k_steps = dist_strategy.a_sync_configs["k_steps"]

        if not dist_strategy.a_sync and k_steps == 0:
            strategy = StrategyFactory.create_sync_strategy()

        if dist_strategy.a_sync and k_steps == 0:
            strategy = StrategyFactory.create_async_strategy()

        if dist_strategy.a_sync and k_steps > 0:
            strategy = StrategyFactory.create_geo_strategy(k_steps)

        if not strategy:
            raise ValueError("k_steps must be invalid value, please check")

        return strategy

    def build_compiled_startegy(self):
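        """
        Build the CompileTimeStrategy that later provides the send/recv
        contexts and the placement of sparse/dense variables on the pservers.
        """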
        from paddle.incubate.distributed.fleet.parameter_server.ir.public import (
            CompileTimeStrategy,
        )

        compiled_config = CompileTimeStrategy(
            self.origin_main_program,
            self.origin_main_program,
            self.async_strategy,
            self.role_maker,
        )
        return compiled_config

    def _load_sparse_params(
        self, executor, dirname, varnames, main_program=None
    ):
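        """
        Build a temporary program of `sparse_tensor_load` ops and run it, so
        that the current server loads its shard of each SelectedRows variable
        listed in `varnames` from `dirname`.
        """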
        assert varnames is not None
        check_vars = []
        load_prog = Program()
        load_block = load_prog.global_block()

        def _in_varnames(var):
            return var.name in varnames

        load_vars = list(
            filter(_in_varnames, default_main_program().list_vars())
        )
        if main_program is None:
            main_program = self.origin_main_program

        from paddle.incubate.distributed.fleet.parameter_server.ir.public import (
            _get_varname_parts,
        )

        for each_var in load_vars:
            assert isinstance(each_var, Variable)

            origin_varname, _, _ = _get_varname_parts(each_var.name)

            new_var = paddle.static.io._clone_var_in_block(load_block, each_var)
            var_path = os.path.join(dirname, origin_varname)
            if not os.path.exists(var_path):
                raise ValueError(
                    "SelectedRows var {} can not find at {}".format(
                        new_var.name, var_path
                    )
                )

            if os.path.isfile(var_path):
                load_block.append_op(
                    type='sparse_tensor_load',
                    inputs={},
                    outputs={'Out': [new_var]},
                    attrs={
                        'file_path': os.path.join(dirname, origin_varname),
                        'node_index': self.role_maker._server_index(),
                        'node_num': self.role_maker._server_num(),
                        'shape': each_var.shape,
                    },
                )
            check_vars.append(each_var)

        executor.run(load_prog)

    def _load_distributed_params(self, dirname, varnames):
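        """
        Load large-scale (distributed) sparse parameters from `dirname` into
        the server-side LargeScaleKV storage, one variable at a time.
        """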
        from paddle.distributed.communicator import LargeScaleKV
        from paddle.incubate.distributed.fleet.parameter_server.ir.public import (
            _get_varname_parts,
        )

        scale_kv = LargeScaleKV()
        for varname in varnames:
            origin_varname, _, _ = _get_varname_parts(varname)
            sparse_dir = os.path.join(dirname, origin_varname, varname)
            scale_kv.load(varname, sparse_dir)

    @staticmethod
    def __exclude_vars(exclude_var_names=[]):
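        """
        Return a predicate that keeps persistable variables while filtering out
        `exclude_var_names`, gradients (`@GRAD`), `learning_rate_0`, and
        feed/fetch/reader variables.
        """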
        def is_valid(var):
            if var.name in exclude_var_names:
                return False

            from paddle.incubate.distributed.fleet.parameter_server.ir.public import (
                _get_varname_parts,
            )

            origin_varname, _, _ = _get_varname_parts(var.name)
            if origin_varname.endswith("@GRAD"):
                return False

            if origin_varname == "learning_rate_0":
                return False

            if (
                var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH
                or var.desc.type() == core.VarDesc.VarType.FETCH_LIST
                or var.desc.type() == core.VarDesc.VarType.READER
            ):
                return False
            return var.persistable

        return is_valid

    def _init_worker(self):
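        """
        Start the worker-side Communicator: collect strategy-specific kwargs
        (sync/geo), optionally wait for pservers (and heter workers) when
        `launch_barrier` is set, then create the Communicator with the compiled
        send/recv contexts and start it.
        """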
        def sync_strategy_envs():
            kwargs = {}
            kwargs[
                "pserver_endpoints"
            ] = self.role_maker._get_pserver_endpoints()
            kwargs["trainer_id"] = self.role_maker._worker_index()
            return kwargs

        def geo_strategy_envs():
            from paddle.incubate.distributed.fleet.parameter_server.ir.public import (
                get_sparse_tablenames,
            )

            def get_sparse_attrs():
                opt_init_map = {}
                opt_init_map["gaussian_random"] = ["seed", "mean", "std"]
                opt_init_map["fill_constant"] = ["value"]
                opt_init_map["uniform_random"] = ["seed", "min", "max"]
                opt_init_map["truncated_gaussian_random"] = [
                    "seed",
                    "mean",
                    "std",
                ]

                dist_varnames = get_sparse_tablenames(
                    self.origin_main_program, True
                )
                sparse_varnames = get_sparse_tablenames(
                    self.origin_main_program, False
                )

                if len(dist_varnames) != 0:
                    raise ValueError(
                        "GeoStrategy does not support large scale embedding now, please use paddle.static.nn.embedding"
                    )

                init_attrs = []
                for value_name in sparse_varnames:
                    value_var = self.origin_main_program.global_block().vars[
                        value_name
                    ]
                    value_attr = [
                        value_name,
                        ",".join([str(dim) for dim in value_var.shape]),
                    ]
                    for op in self.origin_startup_program.global_block().ops:
                        if (
                            op.type in opt_init_map.keys()
                            and value_name == op.output("Out")[0]
                        ):
                            init_attr = [op.type]
                            for attr in opt_init_map[op.type]:
                                init_attr.append(str(op.attr(attr)))
                            value_attr.append("&".join(init_attr))
                            init_attrs.append(":".join(value_attr))
                            break
                return "#".join(init_attrs)

            kwargs = {}
            kwargs["trainers"] = self.role_maker._worker_num()
            kwargs["sparse_attrs"] = get_sparse_attrs()
            return kwargs

        from paddle.incubate.distributed.fleet.parameter_server.distribute_transpiler.distributed_strategy import (
            GeoStrategy,
            SyncStrategy,
        )
        from paddle.incubate.distributed.fleet.parameter_server.ir.public import (
            _get_lr_ops,
            _has_global_step,
        )

        trainer_config = self.async_strategy.get_trainer_runtime_config()
        print(trainer_config)

        dist_strategy = self.context["valid_strategy"]
        launch_barrier = dist_strategy.a_sync_configs["launch_barrier"]
        if launch_barrier:
            # trainers wait for servers to be ready
            wait_server_ready(self.role_maker._get_pserver_endpoints())

            # in ps-heter mode, also wait for heter workers to be ready
            if (
                self.role_maker._is_heter_parameter_server_mode
                and self.role_maker._is_worker()
            ):
                wait_server_ready(self.role_maker._get_heter_worker_endpoints())

        lrs = _has_global_step(_get_lr_ops(self.origin_main_program))

        if lrs:
            kwargs = {"need_global_step": "1"}
        else:
            kwargs = {"need_global_step": "0"}

        if isinstance(self.async_strategy, GeoStrategy):
            geo_kwargs = geo_strategy_envs()
            kwargs.update(geo_kwargs)
        if isinstance(self.async_strategy, SyncStrategy):
            sync_kwargs = sync_strategy_envs()
            kwargs.update(sync_kwargs)

        kwargs = kwargs if kwargs else None

        send_ctx = self.compiled_strategy.get_communicator_send_context()

        if self.compiled_strategy.is_geo_mode():
            recv_ctx = self.compiled_strategy.get_communicator_recv_context(
                recv_type=4
            )
        else:
            recv_ctx = self.compiled_strategy.get_communicator_recv_context(
                recv_type=1
            )

        from paddle.distributed.communicator import Communicator

        self._communicator = Communicator(
            trainer_config.mode, kwargs, trainer_config.get_communicator_flags()
        )
        self._communicator.init_with_ctx(send_ctx, recv_ctx)

        if not self._communicator.is_running():
            self._communicator.start()
        else:
            warnings.warn("communicator has been initialized, skip")

    def _get_executor(self):
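        """
        Return a CPUPlace executor by default; in ps-heter mode, heter workers
        get a CUDAPlace/XPUPlace executor according to
        `heter_worker_device_guard`.
        """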
        executor = Executor(paddle.CPUPlace())
        if self.role_maker._is_heter_parameter_server_mode:
            heter_worker_device_guard = (
                self.context["valid_strategy"]
                .a_sync_configs["heter_worker_device_guard"]
                .upper()
            )
            if heter_worker_device_guard not in ["GPU", "XPU", "CPU"]:
                raise ValueError(
                    "Heter Worker Not Support Device {}".format(
                        heter_worker_device_guard
                    )
                )
            if self.role_maker._is_heter_worker():
                if heter_worker_device_guard == "GPU":
                    executor = Executor(
                        paddle.CUDAPlace(
                            int(os.getenv("FLAGS_selected_gpus", "0"))
                        )
                    )
                elif heter_worker_device_guard == "XPU":
                    executor = Executor(
                        paddle.XPUPlace(
                            int(os.getenv("FLAGS_selected_xpus", "0"))
                        )
                    )
        return executor

    def _init_server(self, *args, **kwargs):
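        """
        Run the startup program on this server and, if a model directory is
        given, load dense, sparse and large-scale parameters from it.
        """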
        if len(args) > 1:
            raise ValueError("init server can only accept 1 args: `dirname`")
        elif len(args) == 1:
            model_dirname = args[0]
        else:
            model_dirname = None

        executor = self._get_executor()
        if (
            self.role_maker._is_heter_worker()
            and self.context["valid_strategy"].a_sync_configs["launch_barrier"]
        ):
            # heter trainers wait for servers to be ready
            wait_server_ready(self.role_maker._get_pserver_endpoints())
349
        executor.run(default_startup_program())

        if self.role_maker._is_heter_worker():
            self._init_worker()
            return

        sparse_varnames = self.compiled_strategy.get_sparse_varname_on_ps(False)
        sparse_related_optimize_varnames = []
        for var_name in sparse_varnames:
            sparse_related_optimize_varnames += (
                self.compiled_strategy.get_optimize_varname_on_ps(var_name)
            )
        sparse_related_optimize_varnames = list(
            set(sparse_related_optimize_varnames)
        )
        distributed_varnames = self.compiled_strategy.get_sparse_varname_on_ps(
            True
        )
        distributed_related_optimize_varnames = []
        for var_name in distributed_varnames:
            distributed_related_optimize_varnames += (
                self.compiled_strategy.get_optimize_varname_on_ps(var_name)
            )
        distributed_related_optimize_varnames = list(
            set(distributed_related_optimize_varnames)
        )

        remaining_vars = list(
            filter(
                ParameterServerRuntime.__exclude_vars(
                    sparse_varnames
                    + distributed_varnames
                    + sparse_related_optimize_varnames
                    + distributed_related_optimize_varnames
                ),
                default_main_program().list_vars(),
            )
        )

        if not model_dirname:
            return

        if not os.path.isdir(model_dirname):
            raise ValueError("There is no directory named '%s'", model_dirname)

        # load dense
        paddle.static.load_vars(
            executor,
            main_program=default_main_program(),
            dirname=model_dirname,
            vars=remaining_vars,
        )

        # load sparse
        self._load_sparse_params(
            executor=executor,
            dirname=model_dirname,
            varnames=sparse_varnames + sparse_related_optimize_varnames,
        )

        # load large scale
        self._load_distributed_params(
            dirname=model_dirname,
            varnames=distributed_varnames
            + distributed_related_optimize_varnames,
        )

    def _run_server(self):
        executor = self._get_executor()
        executor.run(default_main_program())

    def _stop_worker(self):
        self._communicator.stop()
        executor = self._get_executor()
        executor.close()

    def _get_optimizer_status(self, op, param_name):
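        """
        For a supported optimizer op type, return the accumulator variable
        names for `param_name`, split into "reshaped" and "origin-shaped" name
        lists (e.g. adam: moment1_0/moment2_0 vs. beta1_pow_acc_0/beta2_pow_acc_0).
        """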
        supported_opts = [
            "sgd",
            "adam",
            "adagrad",
            "adamax",
            "momentum",
            "lars_momentum",
            "rmsprop",
            "decayed_adagrad",
            "ftrl",
        ]

        reshaped_val_map = {}
        reshaped_val_map["sgd"] = []
        reshaped_val_map["adam"] = ["moment1_0", "moment2_0"]
        reshaped_val_map["adagrad"] = ["moment_0"]
        reshaped_val_map["adamax"] = ["moment_0", "inf_norm_0"]
        reshaped_val_map["momentum"] = ["velocity_0"]
        reshaped_val_map["lars_momentum"] = ["velocity_0"]
        reshaped_val_map["rmsprop"] = [
            "momentum_0",
            "mean_square_0",
            "mean_grad_0",
        ]
        reshaped_val_map["decayed_adagrad"] = ["moment_0"]
        reshaped_val_map["ftrl"] = ["squared_0", "linear_0"]

        orishaped_val_map = {}
        orishaped_val_map["adam"] = ["beta1_pow_acc_0", "beta2_pow_acc_0"]
        orishaped_val_map["adamax"] = ["beta1_pow_acc_0"]

        if op not in supported_opts:
            raise ValueError(
                "fleet can not support optimizer: {}, only these can be supported: {}".format(
                    op, supported_opts
                )
            )

        reshaped_names = [
            param_name + "_" + val for val in reshaped_val_map[op]
        ]

        if op not in orishaped_val_map:
            origin_names = []
        else:
            origin_names = [
                param_name + "_" + val for val in orishaped_val_map[op]
            ]
        return reshaped_names, origin_names

    def _get_optimizer_op(self, param_name):
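        """
        Return the optimizer op in the origin main program whose `Param` input
        is `param_name` (None if no such op exists).
        """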
        from paddle.incubate.distributed.fleet.parameter_server.ir.public import (
            _get_optimize_ops,
        )

        opts = _get_optimize_ops(self.origin_main_program)
        for op in opts:
            if (
                "Param" in op.input_names
                and "LearningRate" in op.input_names
                and op.input("Param")[0] == param_name
            ):
                return op

    def _save_dense_params(self, executor, dirname, context, main_program):
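        """
        Pull the latest dense parameters (plus their optimizer accumulators)
        from the pservers with `recv_save` ops and write them under `dirname`;
        returns the list of saved origin varnames.
        """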
        self._communicator.recv()

        prog = Program()
        block = prog.global_block()
        local_vars = []

        for name, var_ctx in context.items():
            if len(var_ctx.origin_varnames()) != 1:
                raise ValueError("Dense can not support split now.")

            varname = var_ctx.origin_varnames()[0]
            local_vars.append(varname)

            optimizer = self._get_optimizer_op(varname)
            reshaped_varnames, origin_varnames = self._get_optimizer_status(
                optimizer.type, varname
            )

            for var_name in [varname] + reshaped_varnames + origin_varnames:
                var = self.origin_main_program.global_block().vars[var_name]
                block.append_op(
                    type='recv_save',
                    attrs={
                        "trainer_id": self.role_maker._worker_index(),
                        "shape": var.shape,
                        "slice_shapes": [",".join([str(i) for i in var.shape])],
                        "slice_varnames": [var.name],
                        "remote_varnames": [var.name],
                        "is_sparse": False,
                        "endpoints": var_ctx.split_endpoints(),
                        "file_path": os.path.join(dirname, var.name),
                    },
                )

        executor.run(prog)
        return local_vars

    def _save_sparse_params(self, executor, dirname, context, main_program):
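        """
        Save sparse parameters and their optimizer accumulators: each
        `recv_save` op gathers the slices of a variable from the pservers and
        writes them under `dirname`; returns the saved context keys.
        """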
        prog = Program()
        block = prog.global_block()
        local_vars = []

        for name, var_ctx in context.items():
            if len(var_ctx.origin_varnames()) != 1:
                raise ValueError("Dense can not support split now.")

            varname = var_ctx.origin_varnames()[0]
            local_vars.append(varname)

            optimizer = self._get_optimizer_op(varname)
            reshaped_varnames, origin_varnames = self._get_optimizer_status(
                optimizer.type, varname
            )

            var = self.origin_main_program.global_block().vars[varname]
            slice_shapes = []
            dims1 = ",".join([str(i) for i in var.shape[1:]])

            for section in var_ctx.sections():
                slice_shapes.append(str(section) + dims1)

            block.append_op(
                type='recv_save',
                attrs={
                    "trainer_id": self.role_maker._worker_index(),
                    "shape": var.shape,
                    "slice_shapes": slice_shapes,
                    "slice_varnames": var_ctx.split_varnames(),
                    "remote_varnames": var_ctx.split_varnames(),
                    "is_sparse": True,
                    "endpoints": var_ctx.split_endpoints(),
                    "pserver_num": len(
                        self.role_maker._get_pserver_endpoints()
                    ),
                    "file_path": os.path.join(dirname, var.name),
                },
            )

            for reshaped_varname in reshaped_varnames:
                var = self.origin_main_program.global_block().vars[
                    reshaped_varname
                ]

                slice_varnames = []
                remote_varnames = []
                for i in range(len(var_ctx.split_varnames())):
                    slice_varnames.append(
                        "{}.block{}".format(reshaped_varname, i)
                    )
                    remote_varnames.append(reshaped_varname)

                block.append_op(
                    type='recv_save',
                    attrs={
                        "trainer_id": self.role_maker._worker_index(),
                        "shape": var.shape,
                        "slice_shapes": slice_shapes,
                        "slice_varnames": slice_varnames,
                        "remote_varnames": remote_varnames,
                        "is_sparse": True,
                        "endpoints": var_ctx.split_endpoints(),
                        "pserver_num": len(
                            self.role_maker._get_pserver_endpoints()
                        ),
                        "file_path": os.path.join(dirname, var.name),
                    },
                )

            for origin_varname in origin_varnames:
                var = self.origin_main_program.global_block().vars[
                    origin_varname
                ]

                block.append_op(
                    type='recv_save',
                    attrs={
                        "trainer_id": self.role_maker._worker_index(),
                        "shape": var.shape,
                        "slice_shapes": [",".join([str(i) for i in var.shape])],
                        "slice_varnames": [origin_varname],
                        "remote_varnames": [origin_varname],
                        "is_sparse": False,
                        "endpoints": var_ctx.split_endpoints()[:1],
                        "file_path": os.path.join(dirname, var.name),
                    },
                )
        executor.run(prog)
        return context.keys()

    def _save_distributed_params(self, executor, dirname, context, mode):
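        """
        Ask the pservers to checkpoint large-scale (distributed) parameters via
        `checkpoint_notify` ops; each server writes its own shard under
        `dirname`. Returns the saved context keys.
        """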
        prog = Program()
        block = prog.global_block()

        for name, var_ctx in context.items():
            block.append_op(
                type='checkpoint_notify',
                attrs={
                    "varname": name,
                    "mode": mode,
                    "slice_varnames": var_ctx.split_varnames(),
                    "remote_varnames": var_ctx.split_varnames(),
                    "endpoints": var_ctx.split_endpoints(),
                    "dirname": dirname,
                },
            )

        executor.run(prog)
        return context.keys()

    def _save_distributed_persistables(
        self, executor, dirname, main_program, mode
    ):
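        """
        Save everything that lives on the pservers: dense (recv_type=1), sparse
        (recv_type=2) and distributed (recv_type=3) variables are fetched or
        checkpointed first, then the remaining local persistables are saved
        with paddle.static.save_vars.
        """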
        dense_ctx = self.compiled_strategy.get_communicator_recv_context(
            recv_type=1, use_origin_program=True
        )

        sparse_ctx = self.compiled_strategy.get_communicator_recv_context(
            recv_type=2, use_origin_program=True
        )

        distributed_ctx = self.compiled_strategy.get_communicator_recv_context(
            recv_type=3, use_origin_program=True
        )

        recv_dense_varnames = self._save_dense_params(
            executor, dirname, dense_ctx, main_program
        )

        recv_sparse_varnames = self._save_sparse_params(
            executor, dirname, sparse_ctx, main_program
        )

        recv_distributed_varnames = self._save_distributed_params(
            executor, dirname, distributed_ctx, mode
        )

        saved_varnames = (
            recv_dense_varnames
            + list(recv_sparse_varnames)
            + list(recv_distributed_varnames)
        )

        remaining_vars = list(
            filter(
                ParameterServerRuntime.__exclude_vars(saved_varnames),
                main_program.list_vars(),
            )
        )

        paddle.static.save_vars(
            executor,
            main_program=main_program,
            dirname=dirname,
            vars=remaining_vars,
        )

    def _ps_inference_save_persistables(
        self, executor, dirname, main_program=None, mode=0, **kwargs
    ):
        """
        This function collects all variables with `persistable==True` from the
        given `main_program` and saves them to the folder `dirname` or the file
        `filename`.

        The `dirname` is used to specify the folder where persistable variables
        are going to be saved. If you would like to save variables in separate
        files, set `filename` None; if you would like to save all variables in a
        single file, use `filename` to specify the file name.
        """

        if isinstance(executor, ParallelExecutor):
            raise TypeError(
                "in fleet.save_persistables() function, executor must be as Executor type, ParallelExecutor is not allowed"
            )

        if not isinstance(executor, Executor):
            raise TypeError(
                "in fleet.save_persistables() function, executor must be as Executor type"
            )

        if main_program is None:
            main_program = self.compiled_strategy.get_origin_ps_main_program()

        if isinstance(main_program, CompiledProgram):
            raise TypeError(
                "in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed"
            )

        self._save_distributed_persistables(
            executor, dirname, main_program, mode
        )

    def _ps_inference_save_inference_model(
        self,
        executor,
        dirname,
        feeded_var_names,
        target_vars,
        main_program=None,
        export_for_deployment=True,
    ):
        """
        Prune the given `main_program` to build a new program especially for inference,
        and then save it and all related parameters to given `dirname` by the `executor`.
        """

        if isinstance(executor, ParallelExecutor):
            raise TypeError(
                "in fleet.save_inference_model() function, executor must be as Executor type, ParallelExecutor is not allowed"
            )

        if not isinstance(executor, Executor):
            raise TypeError(
                "in fleet.save_inference_model() function, executor must be as Executor type"
            )

        if main_program is not None:
            if isinstance(main_program, CompiledProgram):
                raise TypeError(
                    "in fleet.save_inference_model() function, main_program must be as Program type, CompiledProgram is not allowed"
                )
            save_inference_model(
                dirname,
                feeded_var_names,
                target_vars,
                executor,
                main_program,
                None,
                None,
                export_for_deployment,
            )
        else:
            save_inference_model(
                dirname,
                feeded_var_names,
                target_vars,
                executor,
                self.origin_main_program,
                None,
                None,
                export_for_deployment,
                True,
            )

            model_basename = "__model__"
            model_filename = os.path.join(dirname, model_basename)

            with open(model_filename, "rb") as f:
                program_desc_str = f.read()

            program = Program.parse_from_string(program_desc_str)
            program._copy_dist_param_info_from(default_main_program())
            self._ps_inference_save_persistables(
                executor, dirname, program, mode=0
            )

    def _save_inference_model(self, *args, **kwargs):
        self._ps_inference_save_inference_model(*args, **kwargs)

    def _save_persistables(self, *args, **kwargs):
        self._ps_inference_save_persistables(*args, **kwargs)