# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import warnings

import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.executor import Executor
from paddle.fluid.framework import Program, Variable
from paddle.fluid.parallel_executor import ParallelExecutor

from ..base.private_helper_function import wait_server_ready
from .runtime_base import RuntimeBase

__all__ = []


class ParameterServerRuntime(RuntimeBase):
    def __init__(self):
        super().__init__()
        self._communicator = None

    def _set_basic_info(self, context):
        self.context = context
        self.role_maker = context["role_maker"]
        self.origin_main_program = context["origin_main_program"]
        self.origin_startup_program = context["origin_startup_program"]
        self.async_strategy = self._get_distributed_strategy()
        self.compiled_strategy = self.build_compiled_startegy()

    def _get_distributed_strategy(self):
        strategy = None

        from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import (
            StrategyFactory,
        )

        dist_strategy = self.context["valid_strategy"]
        k_steps = dist_strategy.a_sync_configs["k_steps"]
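        # Map the user-facing (a_sync, k_steps) settings to a runtime strategy:
        # sync when a_sync is off, async when a_sync is on with k_steps == 0,
        # and GEO-SGD when a_sync is on with k_steps > 0.
        #
        # Illustrative configuration sketch (not part of the original source),
        # using the public DistributedStrategy API:
        #
        #     strategy = paddle.distributed.fleet.DistributedStrategy()
        #     strategy.a_sync = True
        #     strategy.a_sync_configs = {"k_steps": 100}  # GEO: push every 100 steps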

        if not dist_strategy.a_sync and k_steps == 0:
            strategy = StrategyFactory.create_sync_strategy()

        if dist_strategy.a_sync and k_steps == 0:
            strategy = StrategyFactory.create_async_strategy()

        if dist_strategy.a_sync and k_steps > 0:
            strategy = StrategyFactory.create_geo_strategy(k_steps)

        if not strategy:
            raise ValueError("k_steps must be invalid value, please check")

        return strategy

    def build_compiled_startegy(self):
        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
            CompileTimeStrategy,
        )

        compiled_config = CompileTimeStrategy(
            self.origin_main_program,
            self.origin_main_program,
            self.async_strategy,
            self.role_maker,
        )
        return compiled_config

    def _load_sparse_params(
        self, executor, dirname, varnames, main_program=None
    ):
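        # Build a temporary program full of sparse_tensor_load ops, one per
        # requested SelectedRows variable under `dirname`, and run it once.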
        assert varnames is not None
        check_vars = []
        load_prog = Program()
        load_block = load_prog.global_block()

        def _in_varnames(var):
            return var.name in varnames

        load_vars = list(
            filter(_in_varnames, fluid.default_main_program().list_vars())
        )
        if main_program is None:
            main_program = self.origin_main_program

        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
            _get_varname_parts,
        )

        for each_var in load_vars:
            assert isinstance(each_var, Variable)

            origin_varname, _, _ = _get_varname_parts(each_var.name)

            new_var = fluid.io._clone_var_in_block_(load_block, each_var)
            var_path = os.path.join(dirname, origin_varname)
            if not os.path.exists(var_path):
                raise ValueError(
                    "SelectedRows var {} cannot be found at {}".format(
                        new_var.name, var_path
                    )
                )

            if os.path.isfile(var_path):
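                # Each parameter server loads only its own shard of the
                # SelectedRows variable; node_index / node_num describe the
                # sharding.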
                load_block.append_op(
                    type='sparse_tensor_load',
                    inputs={},
                    outputs={'Out': [new_var]},
                    attrs={
                        'file_path': os.path.join(dirname, origin_varname),
                        'node_index': self.role_maker._server_index(),
                        'node_num': self.role_maker._server_num(),
                        'shape': each_var.shape,
                    },
                )
            check_vars.append(each_var)

        executor.run(load_prog)

    def _load_distributed_params(self, dirname, varnames):
        from paddle.fluid.communicator import LargeScaleKV
        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
            _get_varname_parts,
        )

        scale_kv = LargeScaleKV()
        for varname in varnames:
            origin_varname, _, _ = _get_varname_parts(varname)
            sparse_dir = os.path.join(dirname, origin_varname, varname)
            scale_kv.load(varname, sparse_dir)

    @staticmethod
    def __exclude_vars(exclude_var_names=[]):
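        # Return a predicate that keeps persistable variables which are not
        # gradients, not the learning-rate counter, not feed/fetch/reader
        # variables, and not listed in `exclude_var_names`.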
        def is_valid(var):
            if var.name in exclude_var_names:
                return False

            from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
                _get_varname_parts,
            )

            origin_varname, _, _ = _get_varname_parts(var.name)
            if origin_varname.endswith("@GRAD"):
                return False

            if origin_varname == "learning_rate_0":
                return False

            if (
                var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH
                or var.desc.type() == core.VarDesc.VarType.FETCH_LIST
                or var.desc.type() == core.VarDesc.VarType.READER
            ):
                return False
            return var.persistable

        return is_valid

    def _init_worker(self):
        def sync_strategy_envs():
            kwargs = {}
            kwargs[
                "pserver_endpoints"
            ] = self.role_maker._get_pserver_endpoints()
            kwargs["trainer_id"] = self.role_maker._worker_index()
            return kwargs

        def geo_strategy_envs():
            from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
                get_sparse_tablenames,
            )

            def get_sparse_attrs():
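                # Encode each sparse table as "name:dim1,dim2,...:init_op&attr...",
                # joined by "#"; the result is handed to the communicator below
                # as the "sparse_attrs" kwarg.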
                opt_init_map = {}
                opt_init_map["gaussian_random"] = ["seed", "mean", "std"]
                opt_init_map["fill_constant"] = ["value"]
                opt_init_map["uniform_random"] = ["seed", "min", "max"]
                opt_init_map["truncated_gaussian_random"] = [
                    "seed",
                    "mean",
                    "std",
                ]

                dist_varnames = get_sparse_tablenames(
                    self.origin_main_program, True
                )
                sparse_varnames = get_sparse_tablenames(
                    self.origin_main_program, False
                )

                if len(dist_varnames) != 0:
                    raise ValueError(
                        "GeoStrategy can not support large scale embeding now, please use fluid.layers.embedding"
                    )

                init_attrs = []
                for value_name in sparse_varnames:
                    value_var = self.origin_main_program.global_block().vars[
                        value_name
                    ]
                    value_attr = [
                        value_name,
                        ",".join([str(dim) for dim in value_var.shape]),
                    ]
                    for op in self.origin_startup_program.global_block().ops:
                        if (
                            op.type in opt_init_map.keys()
                            and value_name == op.output("Out")[0]
                        ):
                            init_attr = [op.type]
                            for attr in opt_init_map[op.type]:
                                init_attr.append(str(op.attr(attr)))
                            value_attr.append("&".join(init_attr))
                            init_attrs.append(":".join(value_attr))
                            break
                return "#".join(init_attrs)

            kwargs = {}
            kwargs["trainers"] = self.role_maker._worker_num()
            kwargs["sparse_attrs"] = get_sparse_attrs()
            return kwargs

        from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import (
            GeoStrategy,
            SyncStrategy,
        )
        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
            _get_lr_ops,
            _has_global_step,
        )

        trainer_config = self.async_strategy.get_trainer_runtime_config()
        print(trainer_config)

        dist_strategy = self.context["valid_strategy"]
        launch_barrier = dist_strategy.a_sync_configs["launch_barrier"]
        if launch_barrier:
            # for trainer wait server ready
            wait_server_ready(self.role_maker._get_pserver_endpoints())

            # for ps-heter mode, wait heter worker ready
            if (
                self.role_maker._is_heter_parameter_server_mode
                and self.role_maker._is_worker()
            ):
                wait_server_ready(self.role_maker._get_heter_worker_endpoints())

        lrs = _has_global_step(_get_lr_ops(self.origin_main_program))

        if lrs:
            kwargs = {"need_global_step": "1"}
        else:
            kwargs = {"need_global_step": "0"}

        if isinstance(self.async_strategy, GeoStrategy):
            geo_kwargs = geo_strategy_envs()
            kwargs.update(geo_kwargs)
        if isinstance(self.async_strategy, SyncStrategy):
            sync_kwargs = sync_strategy_envs()
            kwargs.update(sync_kwargs)

        kwargs = kwargs if kwargs else None

        send_ctx = self.compiled_strategy.get_communicator_send_context()

        if self.compiled_strategy.is_geo_mode():
            recv_ctx = self.compiled_strategy.get_communicator_recv_context(
                recv_type=4
            )
        else:
            recv_ctx = self.compiled_strategy.get_communicator_recv_context(
                recv_type=1
            )

        from paddle.fluid.communicator import Communicator

        self._communicator = Communicator(
            trainer_config.mode, kwargs, trainer_config.get_communicator_flags()
        )
        self._communicator.init_with_ctx(send_ctx, recv_ctx)

        if not self._communicator.is_running():
            self._communicator.start()
        else:
            warnings.warn("communicator has been initialized, skip")

    def _get_executor(self):
        executor = fluid.Executor(fluid.CPUPlace())
        if self.role_maker._is_heter_parameter_server_mode:
            heter_worker_device_guard = (
                self.context["valid_strategy"]
                .a_sync_configs["heter_worker_device_guard"]
                .upper()
            )
            if heter_worker_device_guard not in ["GPU", "XPU", "CPU"]:
                raise ValueError(
                    "Heter Worker Not Support Device {}".format(
                        heter_worker_device_guard
                    )
                )
            if self.role_maker._is_heter_worker():
                if heter_worker_device_guard == "GPU":
                    executor = Executor(
                        fluid.CUDAPlace(
                            int(os.getenv("FLAGS_selected_gpus", "0"))
                        )
                    )
                elif heter_worker_device_guard == "XPU":
                    executor = Executor(
                        fluid.XPUPlace(
                            int(os.getenv("FLAGS_selected_xpus", "0"))
                        )
                    )
        return executor

    def _init_server(self, *args, **kwargs):
        if len(args) > 1:
            raise ValueError("init server can only accept 1 args: `dirname`")
        elif len(args) == 1:
            model_dirname = args[0]
        else:
            model_dirname = None

        executor = self._get_executor()
        if (
            self.role_maker._is_heter_worker()
            and self.context["valid_strategy"].a_sync_configs["launch_barrier"]
        ):
            # for heter trainer wait server ready
            wait_server_ready(self.role_maker._get_pserver_endpoints())
        executor.run(fluid.default_startup_program())

        if self.role_maker._is_heter_worker():
            self._init_worker()
            return

        sparse_varnames = self.compiled_strategy.get_sparse_varname_on_ps(False)
        sparse_related_optimize_varnames = []
        for var_name in sparse_varnames:
            sparse_related_optimize_varnames += (
                self.compiled_strategy.get_optimize_varname_on_ps(var_name)
            )
        sparse_related_optimize_varnames = list(
            set(sparse_related_optimize_varnames)
        )
        distributed_varnames = self.compiled_strategy.get_sparse_varname_on_ps(
            True
        )
        distributed_related_optimize_varnames = []
        for var_name in distributed_varnames:
            distributed_related_optimize_varnames += (
                self.compiled_strategy.get_optimize_varname_on_ps(var_name)
            )
        distributed_related_optimize_varnames = list(
            set(distributed_related_optimize_varnames)
        )

        remaining_vars = list(
            filter(
                ParameterServerRuntime.__exclude_vars(
                    sparse_varnames
                    + distributed_varnames
                    + sparse_related_optimize_varnames
                    + distributed_related_optimize_varnames
                ),
                fluid.default_main_program().list_vars(),
            )
        )

        if not model_dirname:
            return

        if not os.path.isdir(model_dirname):
            raise ValueError("There is no directory named '%s'", model_dirname)

        # load dense
        fluid.io.load_vars(
            executor,
            main_program=fluid.default_main_program(),
            dirname=model_dirname,
            vars=remaining_vars,
        )

        # load sparse
        self._load_sparse_params(
            executor=executor,
            dirname=model_dirname,
            varnames=sparse_varnames + sparse_related_optimize_varnames,
        )

        # load large scale
        self._load_distributed_params(
            dirname=model_dirname,
            varnames=distributed_varnames
            + distributed_related_optimize_varnames,
        )

    def _run_server(self):
        executor = self._get_executor()
        executor.run(fluid.default_main_program())

    def _stop_worker(self):
        self._communicator.stop()
        executor = self._get_executor()
        executor.close()

    def _get_optimizer_status(self, op, param_name):
        supported_opts = [
            "sgd",
            "adam",
            "adagrad",
            "adamax",
            "momentum",
            "lars_momentum",
            "rmsprop",
            "decayed_adagrad",
            "ftrl",
        ]

        reshaped_val_map = {}
        reshaped_val_map["sgd"] = []
        reshaped_val_map["adam"] = ["moment1_0", "moment2_0"]
        reshaped_val_map["adagrad"] = ["moment_0"]
        reshaped_val_map["adamax"] = ["moment_0", "inf_norm_0"]
        reshaped_val_map["momentum"] = ["velocity_0"]
        reshaped_val_map["lars_momentum"] = ["velocity_0"]
        reshaped_val_map["rmsprop"] = [
            "momentum_0",
            "mean_square_0",
            "mean_grad_0",
        ]
        reshaped_val_map["decayed_adagrad"] = ["moment_0"]
        reshaped_val_map["ftrl"] = ["squared_0", "linear_0"]

        orishaped_val_map = {}
        orishaped_val_map["adam"] = ["beta1_pow_acc_0", "beta2_pow_acc_0"]
        orishaped_val_map["adamax"] = ["beta1_pow_acc_0"]

        if op not in supported_opts:
            raise ValueError(
                "fleet does not support optimizer: {}, only these are supported: {}".format(
                    op, supported_opts
                )
            )

        reshaped_names = [
            param_name + "_" + val for val in reshaped_val_map[op]
        ]

        if op not in orishaped_val_map:
            origin_names = []
        else:
            origin_names = [
                param_name + "_" + val for val in orishaped_val_map[op]
            ]
        return reshaped_names, origin_names

    def _get_optimizer_op(self, param_name):
        from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
            _get_optimize_ops,
        )

        opts = _get_optimize_ops(self.origin_main_program)
        for op in opts:
            if (
                "Param" in op.input_names
                and "LearningRate" in op.input_names
                and op.input("Param")[0] == param_name
            ):
                return op

    def _save_dense_params(self, executor, dirname, context, main_program):
        self._communicator.recv()

        prog = Program()
        block = prog.global_block()
        local_vars = []

        for name, var_ctx in context.items():
            if len(var_ctx.origin_varnames()) != 1:
                raise ValueError("Dense can not support split now.")

            varname = var_ctx.origin_varnames()[0]
            local_vars.append(varname)

            optimizer = self._get_optimizer_op(varname)
            reshaped_varnames, origin_varnames = self._get_optimizer_status(
                optimizer.type, varname
            )

            for var_name in [varname] + reshaped_varnames + origin_varnames:
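                # recv_save pulls the variable from the parameter servers and
                # writes it into a single file under `dirname`.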
                var = self.origin_main_program.global_block().vars[var_name]
                block.append_op(
                    type='recv_save',
                    attrs={
                        "trainer_id": self.role_maker._worker_index(),
                        "shape": var.shape,
                        "slice_shapes": [",".join([str(i) for i in var.shape])],
                        "slice_varnames": [var.name],
                        "remote_varnames": [var.name],
                        "is_sparse": False,
                        "endpoints": var_ctx.split_endpoints(),
                        "file_path": os.path.join(dirname, var.name),
                    },
                )

        executor.run(prog)
        return local_vars

    def _save_sparse_params(self, executor, dirname, context, main_program):
        prog = Program()
        block = prog.global_block()
        local_vars = []

        for name, var_ctx in context.items():
            if len(var_ctx.origin_varnames()) != 1:
                raise ValueError("Dense can not support split now.")

            varname = var_ctx.origin_varnames()[0]
            local_vars.append(varname)

            optimizer = self._get_optimizer_op(varname)
            reshaped_varnames, origin_varnames = self._get_optimizer_status(
                optimizer.type, varname
            )

            var = self.origin_main_program.global_block().vars[varname]
            slice_shapes = []
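            # Build one slice shape per pserver section; a section is the
            # number of embedding rows stored on that pserver.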
            dims1 = ",".join([str(i) for i in var.shape[1:]])

            for section in var_ctx.sections():
                slice_shapes.append(str(section) + dims1)

            block.append_op(
                type='recv_save',
                attrs={
                    "trainer_id": self.role_maker._worker_index(),
                    "shape": var.shape,
                    "slice_shapes": slice_shapes,
                    "slice_varnames": var_ctx.split_varnames(),
                    "remote_varnames": var_ctx.split_varnames(),
                    "is_sparse": True,
                    "endpoints": var_ctx.split_endpoints(),
                    "pserver_num": len(
                        self.role_maker._get_pserver_endpoints()
                    ),
                    "file_path": os.path.join(dirname, var.name),
                },
            )

            for reshaped_varname in reshaped_varnames:
                var = self.origin_main_program.global_block().vars[
                    reshaped_varname
                ]

                slice_varnames = []
                remote_varnames = []
                for i in range(len(var_ctx.split_varnames())):
                    slice_varnames.append(
                        "{}.block{}".format(reshaped_varname, i)
                    )
                    remote_varnames.append(reshaped_varname)

                block.append_op(
                    type='recv_save',
                    attrs={
                        "trainer_id": self.role_maker._worker_index(),
                        "shape": var.shape,
                        "slice_shapes": slice_shapes,
                        "slice_varnames": slice_varnames,
                        "remote_varnames": remote_varnames,
                        "is_sparse": True,
                        "endpoints": var_ctx.split_endpoints(),
                        "pserver_num": len(
                            self.role_maker._get_pserver_endpoints()
                        ),
                        "file_path": os.path.join(dirname, var.name),
                    },
                )

            for origin_varname in origin_varnames:
                var = self.origin_main_program.global_block().vars[
                    origin_varname
                ]

                block.append_op(
                    type='recv_save',
                    attrs={
                        "trainer_id": self.role_maker._worker_index(),
                        "shape": var.shape,
                        "slice_shapes": [",".join([str(i) for i in var.shape])],
                        "slice_varnames": [origin_varname],
                        "remote_varnames": [origin_varname],
                        "is_sparse": False,
                        "endpoints": var_ctx.split_endpoints()[:1],
                        "file_path": os.path.join(dirname, var.name),
                    },
                )
        executor.run(prog)
        return context.keys()

    def _save_distributed_params(self, executor, dirname, context, mode):
        prog = Program()
        block = prog.global_block()

        for name, var_ctx in context.items():
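            # checkpoint_notify asks every pserver that holds a slice of this
            # variable to dump its shard into `dirname`.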
            block.append_op(
                type='checkpoint_notify',
                attrs={
                    "varname": name,
                    "mode": mode,
                    "slice_varnames": var_ctx.split_varnames(),
                    "remote_varnames": var_ctx.split_varnames(),
                    "endpoints": var_ctx.split_endpoints(),
                    "dirname": dirname,
                },
            )

        executor.run(prog)
        return context.keys()

    def _save_distributed_persistables(
        self, executor, dirname, main_program, mode
    ):
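        # recv_type selects which parameters to pull: 1 = dense, 2 = sparse,
        # 3 = distributed (large-scale sparse).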
        dense_ctx = self.compiled_strategy.get_communicator_recv_context(
            recv_type=1, use_origin_program=True
        )

        sparse_ctx = self.compiled_strategy.get_communicator_recv_context(
            recv_type=2, use_origin_program=True
        )

        distributed_ctx = self.compiled_strategy.get_communicator_recv_context(
            recv_type=3, use_origin_program=True
        )

        recv_dense_varnames = self._save_dense_params(
            executor, dirname, dense_ctx, main_program
        )

        recv_sparse_varnames = self._save_sparse_params(
            executor, dirname, sparse_ctx, main_program
        )

        recv_distributed_varnames = self._save_distributed_params(
            executor, dirname, distributed_ctx, mode
        )

        saved_varnames = (
            recv_dense_varnames
            + list(recv_sparse_varnames)
            + list(recv_distributed_varnames)
        )

        remaining_vars = list(
            filter(
                ParameterServerRuntime.__exclude_vars(saved_varnames),
                main_program.list_vars(),
            )
        )

        fluid.io.save_vars(
            executor,
            main_program=main_program,
            dirname=dirname,
            vars=remaining_vars,
        )

    def _ps_inference_save_persistables(
        self, executor, dirname, main_program=None, mode=0, **kwargs
    ):
        """
        This function filters out all variables with `persistable==True` from the
        given `main_program` and then saves these variables to the folder `dirname`
        or file `filename`.

        The `dirname` is used to specify the folder where persistable variables
        are going to be saved. If you would like to save variables in separate
        files, set `filename` to None; if you would like to save all variables in a
        single file, use `filename` to specify the file name.
        """

        if isinstance(executor, ParallelExecutor):
            raise TypeError(
                "in fleet.save_persistables() function, executor must be as Executor type, ParallelExecutor is not allowed"
            )

        if not isinstance(executor, Executor):
            raise TypeError(
                "in fleet.save_persistables() function, executor must be as Executor type"
            )

        if main_program is None:
            main_program = self.compiled_strategy.get_origin_ps_main_program()

        if isinstance(main_program, CompiledProgram):
            raise TypeError(
                "in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed"
            )

        self._save_distributed_persistables(
            executor, dirname, main_program, mode
        )

    def _ps_inference_save_inference_model(
        self,
        executor,
        dirname,
        feeded_var_names,
        target_vars,
        main_program=None,
        export_for_deployment=True,
    ):
        """
        Prune the given `main_program` to build a new program especially for inference,
        and then save it and all related parameters to given `dirname` by the `executor`.
        """

        if isinstance(executor, ParallelExecutor):
            raise TypeError(
                "in fleet.save_inference_model() function, executor must be as Executor type, ParallelExecutor is not allowed"
            )

        if not isinstance(executor, Executor):
            raise TypeError(
                "in fleet.save_inference_model() function, executor must be as Executor type"
            )

        if main_program is not None:
            if isinstance(main_program, CompiledProgram):
                raise TypeError(
                    "in fleet.save_inference_model() function, main_program must be as Program type, CompiledProgram is not allowed"
                )
            fluid.io.save_inference_model(
                dirname,
                feeded_var_names,
                target_vars,
                executor,
                main_program,
                None,
                None,
                export_for_deployment,
            )
        else:
            fluid.io.save_inference_model(
                dirname,
                feeded_var_names,
                target_vars,
                executor,
                self.origin_main_program,
                None,
                None,
                export_for_deployment,
                True,
            )

            model_basename = "__model__"
            model_filename = os.path.join(dirname, model_basename)

            with open(model_filename, "rb") as f:
                program_desc_str = f.read()

            program = Program.parse_from_string(program_desc_str)
            program._copy_dist_param_info_from(fluid.default_main_program())
            self._ps_inference_save_persistables(
                executor, dirname, program, mode=0
            )

    def _save_inference_model(self, *args, **kwargs):
        self._ps_inference_save_inference_model(*args, **kwargs)

    def _save_persistables(self, *args, **kwargs):
        self._ps_inference_save_persistables(*args, **kwargs)