# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import warnings

import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import Program
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.executor import Executor
from paddle.fluid.parallel_executor import ParallelExecutor
from paddle.fluid.framework import Variable

from .runtime_base import RuntimeBase
from ..base.private_helper_function import wait_server_ready

__all__ = []
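
# Illustrative sketch (not part of this module): how this runtime is usually
# reached through the fleet API. Exact entry points may differ across Paddle
# versions.
#
#     import paddle.distributed.fleet as fleet
#
#     fleet.init()
#     strategy = fleet.DistributedStrategy()
#     strategy.a_sync = True                 # async / GEO parameter server mode
#     optimizer = fleet.distributed_optimizer(optimizer, strategy)
#     optimizer.minimize(loss)
#
#     if fleet.is_server():
#         fleet.init_server()                # -> ParameterServerRuntime._init_server
#         fleet.run_server()                 # -> ParameterServerRuntime._run_server
#     else:
#         fleet.init_worker()                # -> ParameterServerRuntime._init_worker
#         # ... run the training loop ...
#         fleet.stop_worker()                # -> ParameterServerRuntime._stop_worker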


class ParameterServerRuntime(RuntimeBase):
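    """Runtime handler for parameter-server (PS) distributed training.

    On trainers it builds the compile-time strategy and starts the
    Communicator; on parameter servers it runs the server main program
    and supports saving/loading dense, sparse and distributed
    (large-scale) parameters.
    """
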
    def __init__(self):
        super().__init__()
        self._communicator = None

    def _set_basic_info(self, context):
        self.context = context
        self.role_maker = context["role_maker"]
        self.origin_main_program = context["origin_main_program"]
        self.origin_startup_program = context["origin_startup_program"]
        self.async_strategy = self._get_distributed_strategy()
        self.compiled_strategy = self.build_compiled_strategy()

    def _get_distributed_strategy(self):
        strategy = None

        from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import (
            StrategyFactory,
        )

        dist_strategy = self.context["valid_strategy"]
        k_steps = dist_strategy.a_sync_configs["k_steps"]
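        # Strategy selection, based on the validated DistributedStrategy:
        #   a_sync == False and k_steps == 0  -> synchronous training
        #   a_sync == True  and k_steps == 0  -> fully asynchronous training
        #   a_sync == True  and k_steps  > 0  -> GEO training with k_steps local steps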

        if not dist_strategy.a_sync and k_steps == 0:
            strategy = StrategyFactory.create_sync_strategy()

        if dist_strategy.a_sync and k_steps == 0:
            strategy = StrategyFactory.create_async_strategy()

        if dist_strategy.a_sync and k_steps > 0:
            strategy = StrategyFactory.create_geo_strategy(k_steps)

        if not strategy:
            raise ValueError("k_steps must be invalid value, please check")

        return strategy

    def build_compiled_strategy(self):
        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
            CompileTimeStrategy,
        )

        compiled_config = CompileTimeStrategy(
            self.origin_main_program,
            self.origin_main_program,
            self.async_strategy,
            self.role_maker,
        )
        return compiled_config

    def _load_sparse_params(
        self, executor, dirname, varnames, main_program=None
    ):
        assert varnames is not None
        check_vars = []
        load_prog = Program()
        load_block = load_prog.global_block()

        def _in_varnames(var):
            return var.name in varnames

        load_vars = list(
            filter(_in_varnames, fluid.default_main_program().list_vars())
        )
        if main_program is None:
            main_program = self.origin_main_program

        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
            _get_varname_parts,
        )

        for each_var in load_vars:
            assert isinstance(each_var, Variable)

            origin_varname, _, _ = _get_varname_parts(each_var.name)

            new_var = fluid.io._clone_var_in_block_(load_block, each_var)
            var_path = os.path.join(dirname, origin_varname)
            if not os.path.exists(var_path):
                raise ValueError(
                    "SelectedRows var {} can not be found at {}".format(
                        new_var.name, var_path
                    )
                )

            if os.path.isfile(var_path):
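                # 'sparse_tensor_load' loads only this server's shard of the
                # SelectedRows variable (selected by node_index / node_num).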
                load_block.append_op(
                    type='sparse_tensor_load',
                    inputs={},
                    outputs={'Out': [new_var]},
                    attrs={
                        'file_path': os.path.join(dirname, origin_varname),
                        'node_index': self.role_maker._server_index(),
                        'node_num': self.role_maker._server_num(),
                        'shape': each_var.shape,
                    },
                )
            check_vars.append(each_var)

        executor.run(load_prog)

    def _load_distributed_params(self, dirname, varnames):
        from paddle.fluid.communicator import LargeScaleKV
        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
            _get_varname_parts,
        )

        scale_kv = LargeScaleKV()
        for varname in varnames:
            origin_varname, _, _ = _get_varname_parts(varname)
            sparse_dir = os.path.join(dirname, origin_varname, varname)
            scale_kv.load(varname, sparse_dir)

    @staticmethod
    def __exclude_vars(exclude_var_names=None):
        if exclude_var_names is None:
            exclude_var_names = []

        def is_valid(var):
            if var.name in exclude_var_names:
                return False

            from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
                _get_varname_parts,
            )

            origin_varname, _, _ = _get_varname_parts(var.name)
            if origin_varname.endswith("@GRAD"):
                return False

            if origin_varname == "learning_rate_0":
                return False

            if (
                var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH
                or var.desc.type() == core.VarDesc.VarType.FETCH_LIST
                or var.desc.type() == core.VarDesc.VarType.READER
            ):
                return False
            return var.persistable

        return is_valid

    def _init_worker(self):
        def sync_strategy_envs():
            kwargs = {}
            kwargs[
                "pserver_endpoints"
            ] = self.role_maker._get_pserver_endpoints()
            kwargs["trainer_id"] = self.role_maker._worker_index()
            return kwargs

        def geo_strategy_envs():
            from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
                get_sparse_tablenames,
            )

            def get_sparse_attrs():
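                # Each sparse table is encoded as
                # "name:dim1,dim2,...:init_op&attr1&attr2..." and the tables
                # are joined with '#'; the result is passed to the communicator
                # as the 'sparse_attrs' flag.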
                opt_init_map = {}
                opt_init_map["gaussian_random"] = ["seed", "mean", "std"]
                opt_init_map["fill_constant"] = ["value"]
                opt_init_map["uniform_random"] = ["seed", "min", "max"]
                opt_init_map["truncated_gaussian_random"] = [
                    "seed",
                    "mean",
                    "std",
                ]

                dist_varnames = get_sparse_tablenames(
                    self.origin_main_program, True
                )
                sparse_varnames = get_sparse_tablenames(
                    self.origin_main_program, False
                )

                if len(dist_varnames) != 0:
                    raise ValueError(
                        "GeoStrategy can not support large scale embeding now, please use fluid.layers.embedding"
                    )

                init_attrs = []
                for value_name in sparse_varnames:
                    value_var = self.origin_main_program.global_block().vars[
                        value_name
                    ]
                    value_attr = [
                        value_name,
                        ",".join([str(dim) for dim in value_var.shape]),
                    ]
                    for op in self.origin_startup_program.global_block().ops:
                        if (
                            op.type in opt_init_map.keys()
                            and value_name == op.output("Out")[0]
                        ):
                            init_attr = [op.type]
                            for attr in opt_init_map[op.type]:
                                init_attr.append(str(op.attr(attr)))
                            value_attr.append("&".join(init_attr))
                            init_attrs.append(":".join(value_attr))
                            break
                return "#".join(init_attrs)

            kwargs = {}
            kwargs["trainers"] = self.role_maker._worker_num()
            kwargs["sparse_attrs"] = get_sparse_attrs()
            return kwargs

        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
            _get_lr_ops,
            _has_global_step,
        )

        from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import (
            SyncStrategy,
            GeoStrategy,
        )

        trainer_config = self.async_strategy.get_trainer_runtime_config()
        print(trainer_config)

        dist_strategy = self.context["valid_strategy"]
        launch_barrier = dist_strategy.a_sync_configs["launch_barrier"]
        if launch_barrier:
            # trainers wait until the parameter servers are ready
            wait_server_ready(self.role_maker._get_pserver_endpoints())

            # in ps-heter mode, also wait until the heter workers are ready
            if (
                self.role_maker._is_heter_parameter_server_mode
                and self.role_maker._is_worker()
            ):
                wait_server_ready(self.role_maker._get_heter_worker_endpoints())

        lrs = _has_global_step(_get_lr_ops(self.origin_main_program))

        if lrs:
            kwargs = {"need_global_step": "1"}
        else:
            kwargs = {"need_global_step": "0"}

        if isinstance(self.async_strategy, GeoStrategy):
            geo_kwargs = geo_strategy_envs()
            kwargs.update(geo_kwargs)
        if isinstance(self.async_strategy, SyncStrategy):
            sync_kwargs = sync_strategy_envs()
            kwargs.update(sync_kwargs)

        kwargs = kwargs if kwargs else None

        send_ctx = self.compiled_strategy.get_communicator_send_context()

        if self.compiled_strategy.is_geo_mode():
            recv_ctx = self.compiled_strategy.get_communicator_recv_context(
                recv_type=4
            )
        else:
            recv_ctx = self.compiled_strategy.get_communicator_recv_context(
                recv_type=1
            )

        from paddle.fluid.communicator import Communicator

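        # The Communicator exchanges parameters/gradients with the parameter
        # servers in the background, using the send/recv contexts built above.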
        self._communicator = Communicator(
292 293
            trainer_config.mode, kwargs, trainer_config.get_communicator_flags()
        )
        self._communicator.init_with_ctx(send_ctx, recv_ctx)

        if not self._communicator.is_running():
            self._communicator.start()
        else:
            warnings.warn("communicator has been initialized, skip")

    def _get_executor(self):
        executor = fluid.Executor(fluid.CPUPlace())
        if self.role_maker._is_heter_parameter_server_mode:
            heter_worker_device_guard = (
                self.context["valid_strategy"]
                .a_sync_configs["heter_worker_device_guard"]
                .upper()
            )
            if heter_worker_device_guard not in ["GPU", "XPU", "CPU"]:
                raise ValueError(
                    "Heter worker does not support device {}".format(
                        heter_worker_device_guard
                    )
                )
            if self.role_maker._is_heter_worker():
                if heter_worker_device_guard == "GPU":
                    executor = Executor(
                        fluid.CUDAPlace(
                            int(os.getenv("FLAGS_selected_gpus", "0"))
                        )
                    )
                elif heter_worker_device_guard == "XPU":
                    executor = Executor(
                        fluid.XPUPlace(
                            int(os.getenv("FLAGS_selected_xpus", "0"))
                        )
                    )
        return executor

    def _init_server(self, *args, **kwargs):
        if len(args) > 1:
            raise ValueError("init server can only accept 1 args: `dirname`")
        elif len(args) == 1:
            model_dirname = args[0]
        else:
            model_dirname = None

        executor = self._get_executor()
        if (
            self.role_maker._is_heter_worker()
            and self.context["valid_strategy"].a_sync_configs["launch_barrier"]
        ):
            # heter trainers wait until the parameter servers are ready
            wait_server_ready(self.role_maker._get_pserver_endpoints())
        executor.run(fluid.default_startup_program())

        if self.role_maker._is_heter_worker():
            self._init_worker()
            return

        sparse_varnames = self.compiled_strategy.get_sparse_varname_on_ps(False)
        sparse_related_optimize_varnames = []
        for var_name in sparse_varnames:
            sparse_related_optimize_varnames += (
                self.compiled_strategy.get_optimize_varname_on_ps(var_name)
            )
        sparse_related_optimize_varnames = list(
            set(sparse_related_optimize_varnames)
        )
        distributed_varnames = self.compiled_strategy.get_sparse_varname_on_ps(
            True
        )
        distributed_related_optimize_varnames = []
        for var_name in distributed_varnames:
            distributed_related_optimize_varnames += (
                self.compiled_strategy.get_optimize_varname_on_ps(var_name)
            )
        distributed_related_optimize_varnames = list(
            set(distributed_related_optimize_varnames)
        )

        remaining_vars = list(
            filter(
                ParameterServerRuntime.__exclude_vars(
                    sparse_varnames
                    + distributed_varnames
                    + sparse_related_optimize_varnames
                    + distributed_related_optimize_varnames
                ),
                fluid.default_main_program().list_vars(),
            )
        )

        if not model_dirname:
            return

        if not os.path.isdir(model_dirname):
            raise ValueError("There is no directory named '%s'", model_dirname)

        # load dense
        fluid.io.load_vars(
            executor,
            main_program=fluid.default_main_program(),
            dirname=model_dirname,
            vars=remaining_vars,
        )

        # load sparse
        self._load_sparse_params(
            executor=executor,
            dirname=model_dirname,
            varnames=sparse_varnames + sparse_related_optimize_varnames,
        )

        # load large scale
        self._load_distributed_params(
            dirname=model_dirname,
            varnames=distributed_varnames
            + distributed_related_optimize_varnames,
        )

    def _run_server(self):
        executor = self._get_executor()
        executor.run(fluid.default_main_program())

    def _stop_worker(self):
        self._communicator.stop()
        executor = self._get_executor()
        executor.close()

    def _get_optimizer_status(self, op, param_name):
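        # Return the accumulator variable names that belong to `param_name`:
        # `reshaped_names` are accumulators partitioned like the parameter
        # itself, `origin_names` keep their original shape (e.g. Adam's
        # beta pow accumulators).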
        supported_opts = [
            "sgd",
            "adam",
            "adagrad",
            "adamax",
            "momentum",
            "lars_momentum",
            "rmsprop",
            "decayed_adagrad",
            "ftrl",
        ]

        reshaped_val_map = {}
        reshaped_val_map["sgd"] = []
        reshaped_val_map["adam"] = ["moment1_0", "moment2_0"]
        reshaped_val_map["adagrad"] = ["moment_0"]
        reshaped_val_map["adamax"] = ["moment_0", "inf_norm_0"]
        reshaped_val_map["momentum"] = ["velocity_0"]
        reshaped_val_map["lars_momentum"] = ["velocity_0"]
        reshaped_val_map["rmsprop"] = [
            "momentum_0",
            "mean_square_0",
            "mean_grad_0",
        ]
        reshaped_val_map["decayed_adagrad"] = ["moment_0"]
        reshaped_val_map["ftrl"] = ["squared_0", "linear_0"]

        orishaped_val_map = {}
        orishaped_val_map["adam"] = ["beta1_pow_acc_0", "beta2_pow_acc_0"]
        orishaped_val_map["adamax"] = ["beta1_pow_acc_0"]

        if op not in supported_opts:
            raise ValueError(
                "fleet does not support optimizer: {}, only these are supported: {}".format(
                    op, supported_opts
                )
            )

        reshaped_names = [
            param_name + "_" + val for val in reshaped_val_map[op]
        ]

        if op not in orishaped_val_map:
            origin_names = []
        else:
            origin_names = [
                param_name + "_" + val for val in orishaped_val_map[op]
            ]
        return reshaped_names, origin_names

    def _get_optimizer_op(self, param_name):
        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
            _get_optimize_ops,
        )

        opts = _get_optimize_ops(self.origin_main_program)
        for op in opts:
            if (
                "Param" in op.input_names
                and "LearningRate" in op.input_names
                and op.input("Param")[0] == param_name
            ):
                return op

    def _save_dense_params(self, executor, dirname, context, main_program):
        self._communicator.recv()

        prog = Program()
        block = prog.global_block()
        local_vars = []

        for name, var_ctx in context.items():
            if len(var_ctx.origin_varnames()) != 1:
                raise ValueError("Dense can not support split now.")

            varname = var_ctx.origin_varnames()[0]
            local_vars.append(varname)

            optimizer = self._get_optimizer_op(varname)
            reshaped_varnames, origin_varnames = self._get_optimizer_status(
                optimizer.type, varname
            )

            for var_name in [varname] + reshaped_varnames + origin_varnames:
                var = self.origin_main_program.global_block().vars[var_name]
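                # 'recv_save' pulls the variable from the pserver endpoints
                # and writes it directly to `file_path`.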
                block.append_op(
                    type='recv_save',
                    attrs={
                        "trainer_id": self.role_maker._worker_index(),
                        "shape": var.shape,
                        "slice_shapes": [",".join([str(i) for i in var.shape])],
                        "slice_varnames": [var.name],
                        "remote_varnames": [var.name],
                        "is_sparse": False,
                        "endpoints": var_ctx.split_endpoints(),
                        "file_path": os.path.join(dirname, var.name),
                    },
                )

        executor.run(prog)
        return local_vars

    def _save_sparse_params(self, executor, dirname, context, main_program):
        prog = Program()
        block = prog.global_block()
        local_vars = []

        for name, var_ctx in context.items():
            if len(var_ctx.origin_varnames()) != 1:
                raise ValueError("Dense can not support split now.")

            varname = var_ctx.origin_varnames()[0]
            local_vars.append(varname)

            optimizer = self._get_optimizer_op(varname)
            reshaped_varnames, origin_varnames = self._get_optimizer_status(
                optimizer.type, varname
            )

            var = self.origin_main_program.global_block().vars[varname]
            slice_shapes = []
            dims1 = ",".join([str(i) for i in var.shape[1:]])

            for section in var_ctx.sections():
                slice_shapes.append(str(section) + dims1)

            block.append_op(
                type='recv_save',
                attrs={
                    "trainer_id": self.role_maker._worker_index(),
                    "shape": var.shape,
                    "slice_shapes": slice_shapes,
                    "slice_varnames": var_ctx.split_varnames(),
                    "remote_varnames": var_ctx.split_varnames(),
                    "is_sparse": True,
                    "endpoints": var_ctx.split_endpoints(),
                    "pserver_num": len(
                        self.role_maker._get_pserver_endpoints()
                    ),
                    "file_path": os.path.join(dirname, var.name),
                },
            )

            for reshaped_varname in reshaped_varnames:
                var = self.origin_main_program.global_block().vars[
                    reshaped_varname
                ]

                slice_varnames = []
                remote_varnames = []
                for i in range(len(var_ctx.split_varnames())):
                    slice_varnames.append(
                        "{}.block{}".format(reshaped_varname, i)
                    )
                    remote_varnames.append(reshaped_varname)

                block.append_op(
                    type='recv_save',
                    attrs={
                        "trainer_id": self.role_maker._worker_index(),
                        "shape": var.shape,
                        "slice_shapes": slice_shapes,
                        "slice_varnames": slice_varnames,
                        "remote_varnames": remote_varnames,
                        "is_sparse": True,
                        "endpoints": var_ctx.split_endpoints(),
                        "pserver_num": len(
                            self.role_maker._get_pserver_endpoints()
                        ),
                        "file_path": os.path.join(dirname, var.name),
                    },
                )

            for origin_varname in origin_varnames:
                var = self.origin_main_program.global_block().vars[
                    origin_varname
                ]

                block.append_op(
                    type='recv_save',
                    attrs={
                        "trainer_id": self.role_maker._worker_index(),
                        "shape": var.shape,
                        "slice_shapes": [",".join([str(i) for i in var.shape])],
                        "slice_varnames": [origin_varname],
                        "remote_varnames": [origin_varname],
                        "is_sparse": False,
                        "endpoints": var_ctx.split_endpoints()[:1],
                        "file_path": os.path.join(dirname, var.name),
                    },
                )
        executor.run(prog)
        return context.keys()

    def _save_distributed_params(self, executor, dirname, context, mode):
        prog = Program()
        block = prog.global_block()

        for name, var_ctx in context.items():
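            # 'checkpoint_notify' asks each parameter server to persist its
            # shard of the distributed (large-scale) table under `dirname`.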
            block.append_op(
                type='checkpoint_notify',
                attrs={
                    "varname": name,
                    "mode": mode,
                    "slice_varnames": var_ctx.split_varnames(),
                    "remote_varnames": var_ctx.split_varnames(),
                    "endpoints": var_ctx.split_endpoints(),
                    "dirname": dirname,
                },
            )

        executor.run(prog)
        return context.keys()

    def _save_distributed_persistables(
        self, executor, dirname, main_program, mode
    ):
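        # recv_type: 1 = dense, 2 = sparse, 3 = distributed (large-scale) params.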
        dense_ctx = self.compiled_strategy.get_communicator_recv_context(
            recv_type=1, use_origin_program=True
        )

        sparse_ctx = self.compiled_strategy.get_communicator_recv_context(
            recv_type=2, use_origin_program=True
        )

        distributed_ctx = self.compiled_strategy.get_communicator_recv_context(
            recv_type=3, use_origin_program=True
        )

        recv_dense_varnames = self._save_dense_params(
            executor, dirname, dense_ctx, main_program
        )

        recv_sparse_varnames = self._save_sparse_params(
            executor, dirname, sparse_ctx, main_program
        )

        recv_distributed_varnames = self._save_distributed_params(
            executor, dirname, distributed_ctx, mode
        )

        saved_varnames = (
            recv_dense_varnames
            + list(recv_sparse_varnames)
            + list(recv_distributed_varnames)
        )

        remaining_vars = list(
            filter(
                ParameterServerRuntime.__exclude_vars(saved_varnames),
                main_program.list_vars(),
            )
        )

        fluid.io.save_vars(
            executor,
            main_program=main_program,
            dirname=dirname,
            vars=remaining_vars,
        )

    def _ps_inference_save_persistables(
        self, executor, dirname, main_program=None, mode=0, **kwargs
    ):
        """
        This function filters out all variables with `persistable==True` from the
        given `main_program` and then saves these variables to the folder `dirname`
        or file `filename`.

        The `dirname` is used to specify the folder where persistable variables
        are going to be saved. If you would like to save variables in separate
        files, set `filename` None; if you would like to save all variables in a
        single file, use `filename` to specify the file name.
        """

        if isinstance(executor, ParallelExecutor):
            raise TypeError(
                "in fleet.save_persistables() function, executor must be as Executor type, ParallelExecutor is not allowed"
            )

        if not isinstance(executor, Executor):
            raise TypeError(
                "in fleet.save_persistables() function, executor must be as Executor type"
            )

        if main_program is None:
            main_program = self.compiled_strategy.get_origin_ps_main_program()

        if isinstance(main_program, CompiledProgram):
            raise TypeError(
                "in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed"
            )

        self._save_distributed_persistables(
            executor, dirname, main_program, mode
        )

    def _ps_inference_save_inference_model(
        self,
        executor,
        dirname,
        feeded_var_names,
        target_vars,
        main_program=None,
        export_for_deployment=True,
    ):
        """
        Prune the given `main_program` to build a new program especially for inference,
        and then save it and all related parameters to given `dirname` by the `executor`.
        """

        if isinstance(executor, ParallelExecutor):
            raise TypeError(
                "in fleet.save_inference_model() function, executor must be as Executor type, ParallelExecutor is not allowed"
            )

        if not isinstance(executor, Executor):
            raise TypeError(
                "in fleet.save_inference_model() function, executor must be as Executor type"
            )

        if main_program is not None:
            if isinstance(main_program, CompiledProgram):
                raise TypeError(
                    "in fleet.save_inference_model() function, main_program must be as Program type, CompiledProgram is not allowed"
                )
            fluid.io.save_inference_model(
                dirname,
                feeded_var_names,
                target_vars,
                executor,
                main_program,
                None,
                None,
                export_for_deployment,
            )
        else:
            fluid.io.save_inference_model(
                dirname,
                feeded_var_names,
                target_vars,
                executor,
                self.origin_main_program,
                None,
                None,
                export_for_deployment,
                True,
            )

            model_basename = "__model__"
            model_filename = os.path.join(dirname, model_basename)

            with open(model_filename, "rb") as f:
                program_desc_str = f.read()

            program = Program.parse_from_string(program_desc_str)
            program._copy_dist_param_info_from(fluid.default_main_program())
            self._ps_inference_save_persistables(
                executor, dirname, program, mode=0
            )

    def _save_inference_model(self, *args, **kwargs):
        self._ps_inference_save_inference_model(*args, **kwargs)

    def _save_persistables(self, *args, **kwargs):
        self._ps_inference_save_persistables(*args, **kwargs)