io.py 94.4 KB
Newer Older
1
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
D
dzhwinter 已提交
2
#
D
dzhwinter 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
D
dzhwinter 已提交
6
#
D
dzhwinter 已提交
7
#     http://www.apache.org/licenses/LICENSE-2.0
D
dzhwinter 已提交
8
#
D
dzhwinter 已提交
9 10 11 12 13 14
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os
T
bug fix  
tangwei12 已提交
16
import errno
D
dzhwinter 已提交
17
import warnings
18
import logging
Y
Yang Zhang 已提交
19
import pickle
H
hong 已提交
20
import contextlib
21
from functools import reduce
22
import sys
23
from io import BytesIO
24

H
hong 已提交
25
import numpy as np
26
import math
27
import paddle
28
from paddle.fluid import layers
H
hong 已提交
29
from paddle.fluid.executor import Executor, global_scope
30
from paddle.fluid.evaluator import Evaluator
T
tangwei12 已提交
31
from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable, \
32
    program_guard, dygraph_not_support, static_only
33 34
from paddle.reader import cache, map_readers, buffered, compose, chain, shuffle, \
    ComposeNotAligned, firstn, xmap_readers, multiprocess_reader
35
from .wrapped_decorator import signature_safe_contextmanager
T
tangwei12 已提交
36
from paddle.fluid.compiler import CompiledProgram
37
from paddle.fluid.log_helper import get_logger
S
sneaxiy 已提交
38
from . import reader
39
from . import unique_name
S
sneaxiy 已提交
40
from .reader import *
41 42
from . import dataloader
from .dataloader import *
K
fix bug  
Kexin Zhao 已提交
43
from . import core
44 45
from paddle.utils import deprecated
from paddle.fluid.framework import static_only
46

47 48
# Alias of `paddle.batch`, exported from this module via `__all__`.
batch = paddle.batch

# Public names of this module. The names exported by the `reader` submodule
# (brought into this namespace with `from .reader import *`) are appended.
__all__ = [
    'save_vars', 'save_params', 'save_persistables',
    'load_vars', 'load_params', 'load_persistables',
    'save_inference_model', 'load_inference_model',
    'batch', 'save', 'load',
    'load_program_state', 'set_program_state',
    'get_program_parameter', 'get_program_persistable_vars',
] + reader.__all__

# Module logger at INFO level using a "<time>-<LEVEL>: <message>" layout.
_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
70

71

72
class _open_buffer(object):
    """Minimal context-manager wrapper around an already-open buffer.

    Entering the context yields the wrapped buffer unchanged. This base class
    defines no ``__exit__``; the subclasses provide the exit behavior.
    """

    def __init__(self, buffer):
        # Keep a reference to the caller-owned buffer object.
        self.buffer = buffer

    def __enter__(self):
        wrapped = self.buffer
        return wrapped


class _buffer_reader(_open_buffer):
    """Read-side buffer context that rewinds the buffer if reading fails."""

    def __init__(self, buffer):
        super(_buffer_reader, self).__init__(buffer)
        # Position at which reading starts; restored when the block raises.
        self.initial_tell = self.buffer.tell()

    def __exit__(self, *args):
        # args[0] is the exception type (None when the block succeeded).
        exc_type = args[0]
        if exc_type is None:
            return
        # A failed read puts the file pointer back where it started.
        self.buffer.seek(self.initial_tell)


class _buffer_writer(_open_buffer):
    """Write-side buffer context that flushes the buffer on exit."""

    def __exit__(self, *exc_info):
        # Flush whether or not the block raised; nothing truthy is returned,
        # so any exception still propagates to the caller.
        target = self.buffer
        target.flush()


def _is_file_path(path):
    """Return True when ``path`` is a ``str`` (a filesystem path) rather than
    a buffer object."""
    is_text_path = isinstance(path, str)
    return is_text_path


def _open_file_buffer(path_or_buffer, mode):
    """Return a context manager for reading or writing ``path_or_buffer``.

    A ``str`` path is opened with the built-in :func:`open` using ``mode``.
    Any other object is treated as an already-open buffer and wrapped in
    ``_buffer_writer`` when ``mode`` contains ``'w'``, otherwise in
    ``_buffer_reader`` when ``mode`` contains ``'r'``.

    Raises:
        ValueError: if a buffer is given and ``mode`` contains neither
            ``'w'`` nor ``'r'``.
    """
    if not _is_file_path(path_or_buffer):
        # 'w' is checked first, so a mode containing both picks the writer.
        if 'w' in mode:
            return _buffer_writer(path_or_buffer)
        if 'r' in mode:
            return _buffer_reader(path_or_buffer)
        raise ValueError(
            "Expected 'r' or 'w' in mode but got {}".format(mode))
    return open(path_or_buffer, mode)
115 116 117 118 119 120


def _is_memory_buffer(buffer):
    """Return True if ``buffer`` is an in-memory :class:`io.BytesIO` stream."""
    in_memory = isinstance(buffer, BytesIO)
    return in_memory


121
def is_parameter(var):
    """
    Check whether the given variable is an instance of Parameter.

    Args:
        var(Variable): The variable to be checked.

    Returns:
        bool: True if the given `var` is an instance of Parameter,
        False if not.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid

            paddle.enable_static()
            # Create the parameter first; looking up a name that does not
            # exist in the program would fail.
            w = fluid.layers.create_parameter(shape=[784, 200], dtype='float32', name='fc_w')
            res = fluid.io.is_parameter(w)
    """
    return isinstance(var, Parameter)


def is_persistable(var):
    """
    Check whether the given variable is persistable.

    Variables whose type is FEED_MINIBATCH, FETCH_LIST or READER are always
    reported as not persistable, regardless of their ``persistable`` flag.

    Args:
        var(Variable): The variable to be checked.

    Returns:
        bool: True if the given `var` is persistable,
        False if not.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid

            paddle.enable_static()
            # Create the variable first; looking up a name that does not
            # exist in the program would fail.
            b = fluid.layers.create_parameter(shape=[200], dtype='float32', name='fc_b')
            res = fluid.io.is_persistable(b)
    """
    # These variable types are excluded even if their persistable flag is set.
    if var.desc.type() in (core.VarDesc.VarType.FEED_MINIBATCH,
                           core.VarDesc.VarType.FETCH_LIST,
                           core.VarDesc.VarType.READER):
        return False
    return var.persistable


H
hong 已提交
173
def is_belong_to_optimizer(var):
    """
    Heuristically decide whether ``var`` belongs to the optimizer.

    A variable qualifies only when it is neither a Parameter nor a variable
    whose ``desc.need_check_feed()`` is true, and :func:`is_persistable`
    accepts it. Presumably this selects optimizer state such as accumulators;
    the check itself only applies the rules above.

    Args:
        var(Variable): The variable to be checked.

    Returns:
        bool: True if ``var`` passes the checks above, False otherwise.
    """
    if isinstance(var, Parameter) or var.desc.need_check_feed():
        return False
    return is_persistable(var)
H
hong 已提交
178 179


180
@dygraph_not_support
def get_program_parameter(program):
    """
    Get all the parameters from Program.

    Args:
        program(Program): The Program to get parameters from.

    Returns:
        list: The list contains all parameters in the program, in the order
        returned by ``program.list_vars()``.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid

            paddle.enable_static()
            data = fluid.data(name="img", shape=[64, 784])
            w = fluid.layers.create_parameter(shape=[784, 200], dtype='float32', name='fc_w')
            b = fluid.layers.create_parameter(shape=[200], dtype='float32', name='fc_b')
            list_para = fluid.io.get_program_parameter(fluid.default_main_program())
    """
    return list(filter(is_parameter, program.list_vars()))


206
@dygraph_not_support
def get_program_persistable_vars(program):
    """
    Get all the persistable vars from Program.

    Args:
        program(Program): The Program to get persistable vars from.

    Returns:
        list: The list contains all persistable vars in the program (as
        decided by :func:`is_persistable`), in the order returned by
        ``program.list_vars()``.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid

            paddle.enable_static()
            data = fluid.data(name="img", shape=[64, 784])
            w = fluid.layers.create_parameter(shape=[784, 200], dtype='float32', name='fc_w')
            b = fluid.layers.create_parameter(shape=[200], dtype='float32', name='fc_b')
            list_persistable = fluid.io.get_program_persistable_vars(fluid.default_main_program())
    """
    return list(filter(is_persistable, program.list_vars()))


232 233
def _clone_var_in_block_(block, var):
    """Create a persistable variable in ``block`` mirroring ``var``.

    The clone copies ``name``, ``shape``, ``dtype`` and ``type`` from ``var``
    and is always created with ``persistable=True``. ``lod_level`` is copied
    only when ``var`` is a LOD_TENSOR.

    Returns:
        Whatever ``block.create_var`` returns for the clone.
    """
    assert isinstance(var, Variable)
    is_lod_tensor = var.desc.type() == core.VarDesc.VarType.LOD_TENSOR
    create_kwargs = dict(name=var.name,
                         shape=var.shape,
                         dtype=var.dtype,
                         type=var.type,
                         persistable=True)
    if is_lod_tensor:
        create_kwargs['lod_level'] = var.lod_level
    return block.create_var(**create_kwargs)
247 248


249
@signature_safe_contextmanager
def _load_program_scope(main=None, startup=None, scope=None):
    """Context manager that runs its body under a dedicated program/scope.

    Inside the context the body runs under ``scope_guard(scope)``,
    ``program_guard(main, startup)``, ``unique_name.guard()`` and
    ``_dygraph_guard(None)``, entered in that order. Any falsy argument
    (e.g. None) is replaced by a newly created Program / Scope.
    """
    prog = main or paddle.fluid.Program()
    startup_prog = startup or paddle.fluid.Program()
    scope = scope or paddle.fluid.core.Scope()
    # Multi-item `with` enters left-to-right and exits right-to-left,
    # exactly like the equivalent nested `with` statements.
    with paddle.fluid.scope_guard(scope), \
            paddle.fluid.program_guard(prog, startup_prog), \
            paddle.fluid.unique_name.guard(), \
            paddle.fluid.framework._dygraph_guard(None):
        yield
H
hong 已提交
259 260


261
def _get_valid_program(main_program=None):
    """Resolve ``main_program`` to a :class:`Program`.

    Args:
        main_program(Program|CompiledProgram|None): The program to validate.
            None selects ``default_main_program()``. A CompiledProgram is
            unwrapped to its ``_program`` and a warning is emitted.

    Returns:
        Program: the resolved program.

    Raises:
        TypeError: if a CompiledProgram wraps no program (``_program`` is
            None), or if the resolved object is not a Program.
    """
    if main_program is None:
        main_program = default_main_program()
    elif isinstance(main_program, CompiledProgram):
        main_program = main_program._program
        if main_program is None:
            raise TypeError(
                "The type of input main_program is invalid, expected type is fluid.Program, but received None"
            )
        warnings.warn(
            "The input is a CompiledProgram, this is not recommended.")
    if not isinstance(main_program, Program):
        raise TypeError(
            "The type of input main_program is invalid, expected type is fluid.Program, but received %s"
            % type(main_program))
    return main_program


279
@dygraph_not_support
def save_vars(executor,
              dirname,
              main_program=None,
              vars=None,
              predicate=None,
              filename=None):
    """
    Save specific variables in the `Program` to files.

    There are two ways to specify the variables to be saved: set variables in
    a list and assign it to the `vars`, or use the `predicate` function to select
    variables that make `predicate(variable) == True`. The first way has a higher priority.

    The `dirname` is used to specify the folder where to save variables.
    If you prefer to save variables in separate files in the `dirname` folder,
    do not set `filename`. If you prefer to save all variables in a single file,
    use `filename` to specify it. If both `dirname` and `filename` are None,
    the variables are serialized to memory and returned.

    Args:
        executor(Executor): The executor to run for saving variables.
        dirname(str, optional): The folder where to save variables.
                            When you need to save the parameter to the memory, set it to None.
        main_program(Program, optional): The program whose variables will be saved.
                                    If it is None, the default main program will
                                    be used automatically.
                                    Default: None
        vars(list[Variable], optional): The variables to be saved. Any iterable
                                        is accepted; it is consumed exactly once.
                                        Default: None
        predicate(function, optional): The function selects the variables that make
                                       `predicate(variable) == True`.
                                       Default: None
        filename(str, optional): If you prefer to save all variables in a single file,
                                 use `filename` to specify it. Otherwise, let `filename` be None.
                                 Default: None

    Returns:
        bytes|None: When saving to files (or when there is nothing to save),
        returns None. When saving to memory, returns the binary string
        containing the saved variables.

    Raises:
        TypeError: If `main_program` is not an instance of Program nor None.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid

            paddle.enable_static()
            main_prog = fluid.Program()
            startup_prog = fluid.Program()
            with fluid.program_guard(main_prog, startup_prog):
                data = fluid.layers.data(name="img", shape=[64, 784], append_batch_size=False)
                w = fluid.layers.create_parameter(shape=[784, 200], dtype='float32', name='fc_w')
                b = fluid.layers.create_parameter(shape=[200], dtype='float32', name='fc_b')
                hidden_w = fluid.layers.matmul(x=data, y=w)
                hidden_b = fluid.layers.elementwise_add(hidden_w, b)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(startup_prog)

            # The first usage: use `vars` to set the saved variables.
            var_list = [w, b]
            path = "./my_paddle_vars"
            fluid.io.save_vars(executor=exe, dirname=path, vars=var_list,
                            filename="vars_file")
            # w and b will be saved in a file named "vars_file".

            # The second usage: use `predicate` to select the saved variable.
            def name_has_fc(var):
                res = "fc" in var.name
                return res
            param_path = "./my_paddle_model"
            fluid.io.save_vars(executor=exe, dirname=param_path, main_program=main_prog, vars=None, predicate = name_has_fc)
            # all variables whose names contain "fc" are saved.
    """
    save_to_memory = dirname is None and filename is None

    main_program = _get_valid_program(main_program)

    if vars is None:
        return save_vars(executor,
                         main_program=main_program,
                         dirname=dirname,
                         vars=list(filter(predicate, main_program.list_vars())),
                         filename=filename)

    params_var_name = "saved_params"
    # Materialize once: `vars` may be a one-shot iterator (e.g. a `filter`
    # object). It is traversed twice below (emptiness check, then the cloning
    # loop); checking `len(list(vars))` would otherwise exhaust it and
    # silently save nothing.
    vars = list(vars)
    # give warning when there is no var in model
    if len(vars) == 0:
        warnings.warn(
            "no variable in your model, please ensure there are any variables in your model to save"
        )
        return None

    save_program = Program()
    save_block = save_program.global_block()

    save_var_map = {}
    for each_var in vars:
        # NOTE: don't save the variable which type is RAW
        if each_var.type == core.VarDesc.VarType.RAW:
            continue
        new_var = _clone_var_in_block_(save_block, each_var)
        if filename is None and save_to_memory is False:
            # One `save` op per variable, each writing dirname/<var name>.
            save_file_path = os.path.join(os.path.normpath(dirname),
                                          new_var.name)
            save_block.append_op(
                type='save',
                inputs={'X': [new_var]},
                outputs={},
                attrs={'file_path': os.path.normpath(save_file_path)})
        else:
            save_var_map[new_var.name] = new_var

    if filename is not None or save_to_memory:
        # A single `save_combine` op writes all variables, ordered by name.
        save_var_list = [save_var_map[name] for name in sorted(save_var_map)]

        save_path = ''
        if save_to_memory is False:
            save_path = os.path.join(os.path.normpath(dirname), filename)

        saved_params = save_block.create_var(type=core.VarDesc.VarType.RAW,
                                             name=params_var_name)
        saved_params.desc.set_persistable(True)
        save_block.append_op(type='save_combine',
                             inputs={'X': save_var_list},
                             outputs={'Y': saved_params},
                             attrs={
                                 'file_path': save_path,
                                 'save_to_memory': save_to_memory
                             })

    # NOTE(zhiqiu): save op will add variable kLookupTablePath in save_program.desc,
    # which leads to diff on save_program and its desc. Call _sync_with_cpp
    # to keep consistency.
    save_program._sync_with_cpp()
    executor.run(save_program)
    if save_to_memory:
        return global_scope().find_var(params_var_name).get_bytes()
424 425


426
@dygraph_not_support
def save_params(executor, dirname, main_program=None, filename=None):
    """
    Save every Parameter of :code:`main_program` either into the folder
    :code:`dirname` (one file per parameter) or into the single file
    :code:`filename`. See :ref:`api_guide_model_save_reader_en` for more
    details.

    :code:`dirname` selects the target folder. Leave :code:`filename` as None
    to write each parameter to its own file; set :code:`filename` to write all
    parameters into one file.

    Note:
        Only Parameters are saved. Variables that training also needs, such as
        the learning rate or the global step, are not Parameters, so training
        can NOT be resumed from :ref:`api_fluid_io_save_params` and
        :ref:`api_fluid_io_load_params` alone. Use
        :ref:`api_fluid_io_save_persistables` and
        :ref:`api_fluid_io_load_persistables` for that.

        To save a model for inference, use
        :ref:`api_fluid_io_save_inference_model` instead; see
        :ref:`api_guide_model_save_reader_en` for more details.

    Args:
        executor(Executor): The executor that runs the saving program. See
                            :ref:`api_guide_executor_en`.
        dirname(str, optional): The saving directory path.
                            Set it to None (together with :code:`filename`)
                            to save the parameters to memory.
        main_program(Program, optional): The program whose parameters are
                                         saved (see :ref:`api_guide_Program_en`).
                                         None means the default main program.
                                         Default: None
        filename(str, optional): The single file that receives all
                                 parameters. None means one file per
                                 parameter.
                                 Default: None

    Returns:
        bytes|None: None when saving to files; when saving to memory, the
        binary string containing the parameters.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid

            paddle.enable_static()
            params_path = "./my_paddle_model"
            image = fluid.data(name='img', shape=[None, 28, 28], dtype='float32')
            label = fluid.data(name='label', shape=[None, 1], dtype='int64')
            predict = fluid.layers.fc(input=image, size=10, act='softmax')

            loss = fluid.layers.cross_entropy(input=predict, label=label)
            avg_loss = paddle.mean(loss)

            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(fluid.default_startup_program())
            fluid.io.save_params(executor=exe, dirname=params_path)
            # The fc layer's weight and bias are saved as separate files
            # under "./my_paddle_model".
    """
    # Parameters are exactly the variables accepted by `is_parameter`.
    selector = is_parameter
    return save_vars(executor,
                     dirname=dirname,
                     main_program=main_program,
                     predicate=selector,
                     filename=filename)
498 499


500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521
def _save_distributed_persistables(executor, dirname, main_program):
    """
    save_persistables for distributed training.
    the method will do things listed below:
    1.save part of persistable variables on trainer.
    2.receive "remote prefetch variables" from parameter servers and merge them.
    3.save "distributed lookup table" on parameter servers.
    4.receive "optimizer variables" from parameter servers and merge them.

    Steps 2-4 are only performed when ``main_program._is_chief`` is true.

    Args:
        executor(Executor): The executor to run for saving parameters.
        dirname(str): The saving directory path.
        main_program(Program): The program whose parameters will be
                            saved. the main_program must be the trainer_program
                            get after transpiler.

    Returns:
        None

    Raises:
        TypeError: if ``main_program`` is not a Program.
        ValueError: if ``main_program`` is not a distributed program.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid

            paddle.enable_static()
            exe = fluid.Executor(fluid.CPUPlace())
            param_path = "./my_paddle_model"
            t = distribute_transpiler.DistributeTranspiler()
            t.transpile(...)
            train_program = t.get_trainer_program()
            _save_distributed_persistables(executor=exe, dirname=param_path, main_program=train_program)
    """

    def __save_remote_params(executor, dirname, remote_params_map):
        """
        Receive params from the pservers through rpc and save them.
        If a param is sliced, its slices are concatenated into one before saving.
        """
        if not remote_params_map:
            return

        prog = Program()
        block = prog.global_block()

        # recv optimize vars from pserver
        for remote_params in remote_params_map.values():
            origin = remote_params[0].origin
            is_slice = remote_params[0].is_slice

            slices = [None] * len(remote_params)
            slice_varnames = [None] * len(remote_params)
            remote_varnames = [None] * len(remote_params)
            endpoints = [None] * len(remote_params)

            for idx, optimizer in enumerate(remote_params):
                block_id = optimizer.block_id
                # Named `var_slice` to avoid shadowing the builtin `slice`.
                var_slice = optimizer.slice
                endpoint = optimizer.endpoint

                # Sliced vars are placed by block id, others by position.
                index = block_id if is_slice else idx
                slices[index] = var_slice
                slice_varnames[index] = "{}.slice.{}".format(var_slice.name, idx)
                remote_varnames[index] = var_slice.name
                endpoints[index] = endpoint

            # Each slice shape is passed as a comma-joined string, e.g. "10,8".
            slice_shapes = [
                ",".join(str(dim) for dim in var_slice.shape)
                for var_slice in slices
            ]

            block.append_op(type='recv_save',
                            attrs={
                                "trainer_id": 0,
                                "shape": origin.shape,
                                "slice_shapes": slice_shapes,
                                "slice_varnames": slice_varnames,
                                "remote_varnames": remote_varnames,
                                "endpoints": endpoints,
                                "file_path": os.path.join(dirname, origin.name)
                            })

        executor.run(prog)

    def __save_distributed_lookup_tables(executor, dirname,
                                         distributed_lookup_table, endpoints):
        """
        because the distributed lookup table may too huge to merge and save at one place,
        it will be saved at parameter server independent respectively.

        the save directory is dirname/"__lookup_table__".

        """
        prog = Program()
        block = prog.global_block()

        # if there is lookup table, the trainer 0 will notify all pserver to save.
        lookup_table_filename = os.path.join(dirname, "__lookup_table__")
        attrs = {
            'epmap': endpoints,
            'dir': lookup_table_filename,
            'lookup_table': distributed_lookup_table,
        }
        block.append_op(type='checkpoint_notify',
                        inputs={},
                        outputs={},
                        attrs=attrs)
        executor.run(prog)

    def __exclude_vars(exclude_var_names=None):
        """Build a predicate accepting persistable vars not named in
        ``exclude_var_names``."""
        # None default instead of a shared mutable list; a set gives O(1)
        # name lookups.
        excluded = set(exclude_var_names or ())

        def is_valid(var):
            if var.name in excluded:
                return False
            # Same rules as the module-level `is_persistable` helper.
            return is_persistable(var)

        return is_valid

    if not isinstance(main_program, Program):
        raise TypeError("'main_program' should be an instance of Program.")

    if not main_program._is_distributed:
        raise ValueError(
            "'_save_distributed_persistables' just be designed for distributed training."
        )

    remote_params_map = main_program._parameters_on_pservers.get_distributed_vars_by_vtypes(
        ["Optimizer", "RemotePrefetch"], groupby=True)

    # Vars living on the pservers are saved separately below, so they are
    # excluded from the trainer-local save.
    exclude_var_names = []
    if remote_params_map:
        exclude_var_names.extend(remote_params_map.keys())

    if main_program._distributed_lookup_table:
        if isinstance(main_program._distributed_lookup_table, list):
            exclude_var_names.extend(main_program._distributed_lookup_table)
        else:
            exclude_var_names.append(main_program._distributed_lookup_table)

    local_vars = list(
        filter(__exclude_vars(exclude_var_names), main_program.list_vars()))
    save_vars(executor,
              main_program=main_program,
              dirname=dirname,
              vars=local_vars)

    if main_program._is_chief:
        if remote_params_map:
            __save_remote_params(executor, dirname, remote_params_map)
        if main_program._distributed_lookup_table:
            __save_distributed_lookup_tables(
                executor, dirname, main_program._distributed_lookup_table,
                main_program._endpoints)


658
@dygraph_not_support
def save_persistables(executor, dirname, main_program=None, filename=None):
    """
    Save every persistable variable of :code:`main_program` into the folder
    :code:`dirname` or into the single file :code:`filename`. See
    :ref:`api_guide_model_save_reader_en` for more details.

    :code:`dirname` selects the folder that receives the persistable
    variables. Leave :code:`filename` as None to write one file per variable,
    or set it to write all variables into a single file.

    Args:
        executor(Executor): The executor that runs the saving program.
                            See :ref:`api_guide_executor_en` for
                            more details.

        dirname(str, optional): The saving directory path.
                            Set it to None (together with :code:`filename`)
                            to save the variables to memory.
        main_program(Program, optional): The program whose persistable
                                         variables are saved. See
                                         :ref:`api_guide_Program_en` for more
                                         details. None means the default main
                                         program.
                                         Default: None.
        filename(str, optional): The single file that receives all variables.
                                 None means one file per variable.
                                 Default: None.

    Returns:
        bytes|None: None when saving to files; when saving to memory, the
        binary string containing the variables.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid

            paddle.enable_static()
            dir_path = "./my_paddle_model"
            file_name = "persistables"
            image = fluid.data(name='img', shape=[None, 28, 28], dtype='float32')
            label = fluid.data(name='label', shape=[None, 1], dtype='int64')
            feeder = fluid.DataFeeder(feed_list=[image, label], place=fluid.CPUPlace())

            predict = fluid.layers.fc(input=image, size=10, act='softmax')
            loss = fluid.layers.cross_entropy(input=predict, label=label)
            avg_loss = paddle.mean(loss)
            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(fluid.default_startup_program())
            fluid.io.save_persistables(executor=exe, dirname=dir_path, filename=file_name)
            # The persistable weights and bias of the fc layer are saved
            # together in the file "persistables" under "./my_paddle_model".
    """
    if not (main_program and main_program._is_distributed):
        # Ordinary (non-distributed) case: select variables by persistability.
        return save_vars(executor,
                         dirname=dirname,
                         main_program=main_program,
                         predicate=is_persistable,
                         filename=filename)
    # Distributed programs take the parameter-server aware path; `filename`
    # is not forwarded there.
    return _save_distributed_persistables(executor,
                                          dirname=dirname,
                                          main_program=main_program)
727 728


729 730 731 732 733
def load_vars(executor,
              dirname,
              main_program=None,
              vars=None,
              predicate=None,
              filename=None):
    """
    :api_attr: Static Graph

    This API loads variables from files by executor.

    There are two ways to specify the variables to be loaded: the first way, set
    variables in a list and assign it to the `vars`; the second way, use the
    `predicate` function to select variables that make `predicate(variable) == True`.
    The first way has a higher priority.

    The `dirname` is used to specify the folder where to load variables.
    If variables were saved in separate files in the folder `dirname`,
    set `filename` None. If all variables were saved in a single file,
    use `filename` to specify it. If `dirname` is None, `filename` is handed
    to the `load_combine` op unchanged with `model_from_memory=True`.

    Args:
        executor(Executor): The executor to run for loading variables.
        dirname(str): The folder where to load the variables.
        main_program(Program, optional): The program whose variables will be loaded.
                                    If it is None, the default main program will
                                    be used automatically.
                                    Default: None
        vars(list[Variable], optional): The list that contains all variables to be loaded.
                                   Default: None
        predicate(function, optional): The function selects variables that make
                                        `predicate(variable) == True`.
                                        Default: None
        filename(str, optional): The file which saved all required variables. If variables
                                were saved in separate files, set it to be None.
                                Default: None

    Returns:
        None

    Raises:
        TypeError: If `main_program` is not a `Program`.
        ValueError: If both `dirname` and `filename` are None while dense
            variables need loading, or if a SelectedRows variable is requested
            together with `filename`, without `dirname`, or cannot be found
            under `dirname`.
        RuntimeError: If a loaded parameter's shape differs from the shape
            declared in the program.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid

            paddle.enable_static()
            main_prog = fluid.Program()
            startup_prog = fluid.Program()
            with fluid.program_guard(main_prog, startup_prog):
                data = fluid.layers.data(name="img", shape=[64, 784], append_batch_size=False)
                w = fluid.layers.create_parameter(shape=[784, 200], dtype='float32', name='fc_w')
                b = fluid.layers.create_parameter(shape=[200], dtype='float32', name='fc_b')
                hidden_w = fluid.layers.matmul(x=data, y=w)
                hidden_b = fluid.layers.elementwise_add(hidden_w, b)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(startup_prog)

            # The first usage: using `vars` to specify the variables.
            path = "./my_paddle_vars"
            var_list = [w, b]
            fluid.io.save_vars(executor=exe, dirname=path, vars=var_list,
                               filename="vars_file")
            fluid.io.load_vars(executor=exe, dirname=path, vars=var_list,
                               filename="vars_file")
            # w and b will be loaded, and they are supposed to
            # be saved in the same file named 'vars_file' in the path "./my_paddle_vars".

            # The second usage: using the `predicate` function to select variables
            param_path = "./my_paddle_model"
            def name_has_fc(var):
                res = "fc" in var.name
                return res
            fluid.io.save_vars(executor=exe, dirname=param_path, main_program=main_prog,
                               vars=None, predicate=name_has_fc)
            fluid.io.load_vars(executor=exe, dirname=param_path, main_program=main_prog,
                               vars=None, predicate=name_has_fc)
            # Load All variables in the `main_program` whose name includes "fc".
            # And all the variables are supposed to be saved in separate files.

    """
    # A None dirname means the combined model data is supplied in memory via
    # `filename` rather than read from disk.
    vars_from_memory = False
    if dirname is not None:
        dirname = os.path.normpath(dirname)
    else:
        vars_from_memory = True

    if main_program is None:
        main_program = default_main_program()
    if not isinstance(main_program, Program):
        raise TypeError(
            "The type of input main_program is invalid, expected type is fluid.Program, but received %s"
            % type(main_program))

    if vars is None:
        # Select the variables with `predicate` and load them explicitly.
        load_vars(executor,
                  dirname=dirname,
                  main_program=main_program,
                  vars=list(filter(predicate, main_program.list_vars())),
                  filename=filename)
        return

    load_prog = Program()
    load_block = load_prog.global_block()

    # Declared shape of every Parameter, compared with the loaded tensor
    # shape once the load program has run.
    orig_para_shape = {}
    # Dense variables to be loaded together by one `load_combine` op.
    load_var_map = {}

    check_vars = []
    sparse_vars = []

    for each_var in vars:
        assert isinstance(each_var, Variable)

        if each_var.type == core.VarDesc.VarType.RAW:
            continue

        if isinstance(each_var, Parameter):
            orig_para_shape[each_var.name] = tuple(each_var.desc.get_shape())

        # SelectedRows variables are handled separately below.
        if each_var.type == core.VarDesc.VarType.SELECTED_ROWS:
            sparse_vars.append(each_var)
            continue

        new_var = _clone_var_in_block_(load_block, each_var)
        check_vars.append(each_var)

        if filename is None:
            if dirname is None:
                raise ValueError(
                    "The directory path and params cannot be None at the same time."
                )
            load_block.append_op(
                type='load',
                inputs={},
                outputs={'Out': [new_var]},
                attrs={'file_path': os.path.join(dirname, new_var.name)})
        else:
            load_var_map[new_var.name] = new_var

    for each_var in sparse_vars:
        assert isinstance(each_var, Variable)

        if filename is not None:
            raise ValueError(
                "SelectedRows can not be loaded with load_combine")
        if dirname is None:
            raise ValueError(
                "The directory path cannot be None when loading SelectedRows variables."
            )

        new_var = _clone_var_in_block_(load_block, each_var)

        var_path = os.path.join(dirname, new_var.name)
        if not os.path.exists(var_path):
            raise ValueError(
                "SelectedRows var {} cannot be found at {}".format(
                    new_var.name, var_path))

        if os.path.isfile(var_path):
            load_block.append_op(type='load',
                                 inputs={},
                                 outputs={'Out': [new_var]},
                                 attrs={'file_path': var_path})
            continue

        # `var_path` is a directory: load every entry whose name starts with
        # the variable name from `<entry>/Param`, then merge the pieces.
        slice_vars = []
        for block_name in os.listdir(var_path):
            if not block_name.startswith(new_var.name):
                continue
            slice_var = load_block.create_var(name=block_name,
                                              type=new_var.type,
                                              shape=new_var.shape,
                                              dtype=new_var.dtype,
                                              persistable=False)
            slice_vars.append(slice_var)
            load_block.append_op(
                type='load',
                inputs={},
                outputs={'Out': [slice_var]},
                attrs={'file_path': os.path.join(var_path, block_name, "Param")})

        load_block.append_op(type='lookup_sparse_table_merge',
                             inputs={'X': slice_vars},
                             outputs={'Out': new_var},
                             attrs={})

    if filename is not None:
        # Variables are ordered by name when read from the combined file.
        load_var_list = [load_var_map[name] for name in sorted(load_var_map)]

        if not vars_from_memory:
            filename = os.path.join(dirname, filename)

        load_block.append_op(type='load_combine',
                             inputs={},
                             outputs={"Out": load_var_list},
                             attrs={
                                 'file_path': filename,
                                 'model_from_memory': vars_from_memory
                             })

    executor.run(load_prog)

    # Check that every loaded Parameter has the shape the program declares.
    for each_var in check_vars:
        if not isinstance(each_var, Parameter):
            continue
        var_temp = paddle.fluid.global_scope().find_var(each_var.name)
        assert var_temp is not None, "Cannot find var: " + each_var.name
        new_shape = (np.array(var_temp.get_tensor())).shape
        assert each_var.name in orig_para_shape, each_var.name + " MUST be in var list"
        orig_shape = orig_para_shape.get(each_var.name)
        if new_shape != orig_shape:
            raise RuntimeError(
                "Variable's shape does not match, the Program requires a parameter with the shape of ({}), "
                "while the loaded parameter (namely [ {} ]) has a shape of ({})."
                .format(orig_shape, each_var.name, new_shape))
H
hong 已提交
959

960

961
@dygraph_not_support
def load_params(executor, dirname, main_program=None, filename=None):
    """
    :api_attr: Static Graph

    This API filters out all parameters from the given ``main_program``
    and then tries to load these parameters from the directory ``dirname`` or
    the file ``filename``.

    Use the ``dirname`` to specify the directory where parameters were saved. If
    parameters were saved in separate files under the directory `dirname`, set
    ``filename`` as None; if all parameters were saved in a single file, use
    ``filename`` to specify the file name.

    **Note**:
        Some variables are not Parameter while they are necessary for
        training, such as learning rate, global step, etc. So you cannot save and
        continue your training just by using :ref:`api_fluid_io_save_params` and
        :ref:`api_fluid_io_load_params`. Please use :ref:`api_fluid_io_save_persistables`
        and :ref:`api_fluid_io_load_persistables` instead.

        If you want to load the pre-trained model structure and parameters
        for the inference, please use the :ref:`api_fluid_io_load_inference_model` API. You can
        refer to :ref:`api_guide_model_save_reader_en` for more details.

    Args:
        executor(Executor): The executor used for loading parameters.
                            See :ref:`api_guide_executor_en` for more details about it.
        dirname(str): The directory path.
        main_program(Program, optional): The program whose parameters will be
                                    loaded. If it is None, the ``default_main_program``
                                    will be used automatically. See :ref:`api_guide_Program_en`
                                    for more about ``Program``.
                                    Default: None.
        filename(str, optional): The file which saved all parameters. If parameters
                            were saved in separate files, set it to None.
                            Default: None.

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid

            paddle.enable_static()
            exe = fluid.Executor(fluid.CPUPlace())
            param_path = "./my_paddle_model"
            prog = fluid.default_main_program()
            fluid.io.load_params(executor=exe, dirname=param_path,
                                 main_program=prog)
    """
    # Delegate to load_vars, selecting only the Parameter variables.
    load_vars(executor,
              dirname=dirname,
              main_program=main_program,
              predicate=is_parameter,
              filename=filename)
1020 1021


1022
@dygraph_not_support
def load_persistables(executor, dirname, main_program=None, filename=None):
    """
    :api_attr: Static Graph

    Load every variable with ``persistable==True`` from ``main_program``,
    reading them from the directory ``dirname`` or from the single file
    ``filename``.

    If persistable variables (refer to :ref:`api_guide_model_save_reader_en`)
    were saved one file per variable, leave ``filename`` as None; if they were
    all saved into a single file, pass that file's name as ``filename``.

    When ``main_program`` is a distributed program (``_is_distributed`` is
    set), loading is delegated to ``_load_distributed_persistables`` and
    ``filename`` is not used.

    Args:
        executor(Executor): The executor used for loading persistable variables.
                            See :ref:`api_guide_executor_en` for more details about it.
        dirname(str): The directory path.
        main_program(Program, optional): The program whose persistable variables will
                                    be loaded. If it is None, the ``default_main_program``
                                    will be used automatically. See :ref:`api_guide_Program_en`
                                    for more about ``Program``.
                                    Default: None.
        filename(str, optional): The file which saved all persistable variables. If variables
                                 were saved in separate files, set it to None.
                                 Default: None.

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid

            paddle.enable_static()
            exe = fluid.Executor(fluid.CPUPlace())
            param_path = "./my_paddle_model"
            prog = fluid.default_main_program()
            fluid.io.load_persistables(executor=exe, dirname=param_path,
                                       main_program=prog)
    """
    if not (main_program and main_program._is_distributed):
        # Ordinary (non-distributed) case: pick persistables via predicate.
        load_vars(executor,
                  dirname=dirname,
                  main_program=main_program,
                  predicate=is_persistable,
                  filename=filename)
        return

    # Distributed programs are loaded per parameter server from `dirname`.
    _load_distributed_persistables(executor,
                                   dirname=dirname,
                                   main_program=main_program)
1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095


def _load_distributed_persistables(executor, dirname, main_program=None):
    """
    Customized load_persistables for distributed training.
    It should be used on a parameter server.

    Args:
        executor(Executor): The executor used to run the loading program.
        dirname(str): The load directory path.
        main_program(Program): The program whose parameters will be
                            loaded. The main_program must be the pserver_program
                            obtained after transpiling.

    Returns:
        None

    Raises:
        TypeError: If ``main_program`` is not a ``Program``.
        ValueError: If ``main_program`` is not distributed, or its
            ``_ps_endpoint`` is not set.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid

            paddle.enable_static()
            exe = fluid.Executor(fluid.CPUPlace())
            param_path = "./my_paddle_model"
            t = distribute_transpiler.DistributeTranspiler()
            t.transpile(...)
            pserver_prog = t.get_pserver_program(...)
            _load_distributed_persistables(executor=exe, dirname=param_path, main_program=pserver_prog)
    """
    if not isinstance(main_program, Program):
        raise TypeError("'main_program' should be an instance of Program.")

    if not main_program._is_distributed:
        raise ValueError(
            "'_load_distributed_persistables' just be designed for distributed training."
        )

    if not main_program._ps_endpoint:
        raise ValueError(
            "'_load_distributed_persistables' need current_endpoint set in DistributeTranspiler.transpile"
        )

    need_load_vars = main_program._parameters_on_pservers.get_distributed_vars_by_ep(
        main_program._ps_endpoint)

    load_prog = Program()
    load_block = load_prog.global_block()
    # NOTE(review): nothing is ever appended to this list, so the delete_var
    # op below receives an empty input; kept to preserve the built program.
    need_delete_vars = []

    for param in need_load_vars:
        origin_var = param.origin
        # Both cases read from the file named after the original variable.
        file_path = os.path.join(dirname, origin_var.name)

        if param.is_slice:
            # Only a slice of the original variable lives on this endpoint;
            # presumably `seek` is the read offset into the full variable's
            # file and `shape` the slice shape — confirm against the load op.
            slice_var = param.slice
            target = load_block.create_var(name=slice_var.name,
                                           type=slice_var.type,
                                           shape=slice_var.shape,
                                           dtype=slice_var.dtype,
                                           persistable=True)
            attrs = {
                'file_path': file_path,
                'seek': param.offset,
                'shape': target.shape
            }
        else:
            target = load_block.create_var(name=origin_var.name,
                                           type=origin_var.type,
                                           shape=origin_var.shape,
                                           dtype=origin_var.dtype,
                                           persistable=True)
            attrs = {'file_path': file_path}

        load_block.append_op(type='load',
                             inputs={},
                             outputs={'Out': [target]},
                             attrs=attrs)

    load_block.append_op(
        type='delete_var',
        inputs={'X': need_delete_vars},
    )

    executor.run(load_prog)
1178 1179


1180 1181 1182
def prepend_feed_ops(inference_program,
                     feed_target_names,
                     feed_holder_name='feed'):
    """
    Insert one ``feed`` op per entry of ``feed_target_names`` at the front of
    ``inference_program``'s global block.

    Args:
        inference_program(Program): The program whose global block is modified.
        feed_target_names(list[str]): Names of the variables to feed. The index
            of each name becomes the ``col`` attribute of its feed op.
        feed_holder_name(str, optional): Name of the FEED_MINIBATCH holder
            variable created in the global block. Default: 'feed'.

    Raises:
        ValueError: If a name is not a variable of the global block.
    """
    if len(feed_target_names) == 0:
        # Nothing to feed: the program is left untouched.
        return

    block = inference_program.global_block()
    holder = block.create_var(name=feed_holder_name,
                              type=core.VarDesc.VarType.FEED_MINIBATCH,
                              persistable=True)

    for col, var_name in enumerate(feed_target_names):
        if not block.has_var(var_name):
            raise ValueError(
                "The feeded_var_names[{i}]: '{name}' doesn't exist in pruned inference program. "
                "Please check whether '{name}' is a valid feed_var name, or remove it from feeded_var_names "
                "if '{name}' is not involved in the target_vars calculation."
                .format(i=col, name=var_name))
        block._prepend_op(type='feed',
                          inputs={'X': [holder]},
                          outputs={'Out': [block.var(var_name)]},
                          attrs={'col': col})
K
Kexin Zhao 已提交
1203 1204


1205 1206 1207
def append_fetch_ops(inference_program,
                     fetch_target_names,
                     fetch_holder_name='fetch'):
    """
    Append one ``fetch`` op per entry of ``fetch_target_names`` to the end of
    ``inference_program``'s global block.

    Unlike ``prepend_feed_ops``, the FETCH_LIST holder variable is created
    even when ``fetch_target_names`` is empty.

    Args:
        inference_program(Program): The program whose global block is modified.
        fetch_target_names(list[str]): Names of the variables to fetch. The
            index of each name becomes the ``col`` attribute of its fetch op.
        fetch_holder_name(str, optional): Name of the FETCH_LIST holder
            variable created in the global block. Default: 'fetch'.
    """
    block = inference_program.global_block()
    holder = block.create_var(name=fetch_holder_name,
                              type=core.VarDesc.VarType.FETCH_LIST,
                              persistable=True)

    for col, var_name in enumerate(fetch_target_names):
        block.append_op(type='fetch',
                        inputs={'X': [var_name]},
                        outputs={'Out': [holder]},
                        attrs={'col': col})
K
Kexin Zhao 已提交
1218 1219


1220 1221
@static_only
@deprecated(since="2.0.0", update_to="paddle.static.save_inference_model")
def save_inference_model(dirname,
                         feeded_var_names,
                         target_vars,
                         executor,
                         main_program=None,
                         model_filename=None,
                         params_filename=None,
                         export_for_deployment=True,
                         program_only=False,
                         clip_extra=True):
    """
    Prune the given `main_program` to build a new program especially for inference,
    and then save it and all related parameters to given `dirname` .
    If you just want to save parameters of your trained model, please use the
    :ref:`api_fluid_io_save_params` . You can refer to :ref:`api_guide_model_save_reader_en`
    for more details.

    Note:
        The :code:`dirname` is used to specify the folder where inference model
        structure and parameters are going to be saved. If you would like to save params of
        Program in separate files, set `params_filename` None; if you would like to save all
        params of Program in a single file, use `params_filename` to specify the file name.

    Args:
        dirname(str): The directory path to save the inference model.
        feeded_var_names(list[str]): list of string. Names of variables that need to be fed
                                     data during inference.
        target_vars(list[Variable]): list of Variable. Variables from which we can get
                                     inference results.
        executor(Executor): The executor that saves the inference model. You can refer
                            to :ref:`api_guide_executor_en` for more details.
        main_program(Program, optional): The original program, which will be pruned to
                                         build the inference model. If is set None,
                                         the global default :code:`_main_program_` will be used.
                                         Default: None.
        model_filename(str, optional): The name of file to save the inference program
                                       itself. If is set None, a default filename
                                       :code:`__model__` will be used.
        params_filename(str, optional): The name of file to save all related parameters.
                                        If it is set None, parameters will be saved
                                        in separate files .
        export_for_deployment(bool, optional): If True, programs are modified to only support
                                     direct inference deployment. Otherwise,
                                     more information will be stored for flexible
                                     optimization and re-training. Currently, only
                                     True is supported.
                                     Default: True.
        program_only(bool, optional): If True, It will save inference program only, and do not
                                      save params of Program.
                                      Default: False.
        clip_extra(bool, optional): Forwarded as ``clip_extra`` to
                                    ``Program._remove_training_info`` when the
                                    program is serialized.
                                    Default: True.

    Returns:
        list, The fetch variables' name list.

    Raises:
        ValueError: If ``feeded_var_names`` is not a list of str, or
            ``target_vars`` is not a list of Variable (checked only when
            ``export_for_deployment`` is True).

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid

            paddle.enable_static()
            path = "./infer_model"

            # User defined network, here a softmax regression example
            image = fluid.data(name='img', shape=[None, 28, 28], dtype='float32')
            label = fluid.data(name='label', shape=[None, 1], dtype='int64')
            feeder = fluid.DataFeeder(feed_list=[image, label], place=fluid.CPUPlace())
            predict = fluid.layers.fc(input=image, size=10, act='softmax')

            loss = fluid.layers.cross_entropy(input=predict, label=label)
            avg_loss = paddle.mean(loss)

            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(fluid.default_startup_program())

            # Feed data and train process

            # Save inference model. Note we don't save label and loss in this example
            fluid.io.save_inference_model(dirname=path,
                                          feeded_var_names=['img'],
                                          target_vars=[predict],
                                          executor=exe)

            # In this example, save_inference_model will prune the default
            # main program according to the network's input node (img) and output node(predict).
            # The pruned inference program is going to be saved in the "./infer_model/__model__"
            # and parameters are going to be saved in separate files under folder
            # "./infer_model".

    """
    if isinstance(feeded_var_names, str):
        feeded_var_names = [feeded_var_names]
    elif export_for_deployment:
        if len(feeded_var_names) > 0 and not all(
                isinstance(name, str) for name in feeded_var_names):
            raise ValueError("'feeded_var_names' should be a list of str.")

    if isinstance(target_vars, Variable):
        target_vars = [target_vars]
    elif export_for_deployment:
        if not (bool(target_vars)
                and all(isinstance(var, Variable) for var in target_vars)):
            raise ValueError("'target_vars' should be a list of Variable.")

    main_program = _get_valid_program(main_program)

    # Clear the device attribute of EVERY op (previously a `break` after the
    # first auc op left later ops untouched), and remind the user once to
    # reset auc states if the program contains an auc op.
    device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
    auc_warned = False
    for op in main_program.global_block().ops:
        op._set_attr(device_attr_name, "")
        if op.type == 'auc' and not auc_warned:
            warnings.warn(
                "please ensure that you have set the auc states to zeros before saving inference model"
            )
            auc_warned = True

    # Work on a private copy of the caller's list.
    target_vars = list(target_vars)
    target_var_name_list = [var.name for var in target_vars]

    # When a pserver and a trainer run on the same machine, mkdir may conflict,
    # so an already-existing directory is tolerated.
    save_dirname = os.path.normpath(dirname)
    try:
        os.makedirs(save_dirname)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

    if model_filename is not None:
        model_basename = os.path.basename(model_filename)
    else:
        model_basename = "__model__"
    model_path = os.path.join(save_dirname, model_basename)

    # When export_for_deployment is true, we modify the program online so that
    # it can only be loaded for inference directly. If it's false, the whole
    # original program and related meta are saved so that future usage can be
    # more flexible.
    origin_program = main_program.clone()

    if export_for_deployment:
        main_program = main_program.clone()
        global_block = main_program.global_block()

        # Drop existing feed/fetch ops; fresh ones are added after pruning.
        need_to_remove_op_index = []
        for i, op in enumerate(global_block.ops):
            op.desc.set_is_target(False)
            if op.type == "feed" or op.type == "fetch":
                need_to_remove_op_index.append(i)

        # Remove from the back so earlier indices stay valid.
        for index in reversed(need_to_remove_op_index):
            global_block._remove_op(index)

        main_program.desc.flush()

        main_program = main_program._prune_with_input(
            feeded_var_names=feeded_var_names, targets=target_vars)
        main_program = main_program._inference_optimize(prune_read_op=True)

        # Make sure every target survives pruning so it can be fetched.
        pruned_block = main_program.global_block()
        for target_v in target_vars:
            if not pruned_block.has_var(target_v.name):
                pruned_block.create_var(name=target_v.name,
                                        shape=target_v.shape,
                                        dtype=target_v.dtype,
                                        persistable=target_v.persistable)

        prepend_feed_ops(main_program, feeded_var_names)
        append_fetch_ops(main_program, target_var_name_list)

        main_program.desc._set_version()
        paddle.fluid.core.save_op_version_info(main_program.desc)
        with open(model_path, "wb") as f:
            f.write(
                main_program._remove_training_info(
                    clip_extra=clip_extra).desc.serialize_to_string())
    else:
        # TODO(panyx0718): Save more information so that it can also be used
        # for training and more flexible post-processing.
        with open(model_path + ".main_program", "wb") as f:
            f.write(
                main_program._remove_training_info(
                    clip_extra=clip_extra).desc.serialize_to_string())

    if program_only:
        warnings.warn(
            "save_inference_model specified the param `program_only` to True, It will not save params of Program."
        )
        return target_var_name_list

    main_program._copy_dist_param_info_from(origin_program)

    if params_filename is not None:
        params_filename = os.path.basename(params_filename)

    save_persistables(executor, save_dirname, main_program, params_filename)
    return target_var_name_list
X
fix  
Xin Pan 已提交
1429

1430

1431 1432
@static_only
@deprecated(since="2.0.0", update_to="paddle.static.load_inference_model")
def load_inference_model(dirname,
                         executor,
                         model_filename=None,
                         params_filename=None,
                         pserver_endpoints=None):
    """
    Load the inference model from a given directory. By this API, you can get the model
    structure(Inference Program) and model parameters. If you just want to load
    parameters of the pre-trained model, please use the :ref:`api_fluid_io_load_params` API.
    You can refer to :ref:`api_guide_model_save_reader_en` for more details.

    Args:
        dirname(str): One of the following:
          - The given directory path.
          - Set to None when reading the model from memory.
        executor(Executor): The executor to run for loading inference model.
                            See :ref:`api_guide_executor_en` for more details about it.
        model_filename(str, optional): One of the following:
          - The name of file to load the inference program.
          - If it is None, the default filename ``__model__`` will be used.
          - When ``dirname`` is ``None``, it must be set to a string containing model.
          Default: ``None``.
        params_filename(str, optional): It is only used for the case that all
            parameters were saved in a single binary file. One of the following:
          - The name of file to load all parameters.
          - When ``dirname`` is ``None``, it must be set to a string containing all the parameters.
          - If parameters were saved in separate files, set it as ``None``.
            Default: ``None``.

        pserver_endpoints(list, optional): It is only needed by the distributed inference.
                                    If using a distributed look up table during the training,
                                    this table is also needed by the inference process. Its value is
                                    a list of pserver endpoints.

    Returns:
        list: The return of this API is a list with three elements:
        (program, feed_target_names, fetch_targets). The `program` is a
        ``Program`` (refer to :ref:`api_guide_Program_en`), which is used for inference.
        The `feed_target_names` is a list of ``str``, which contains names of variables
        that need to feed data in the inference program. The `fetch_targets` is a list of
        ``Variable`` (refer to :ref:`api_guide_Program_en`). It contains variables from which
        we can get inference results.

    Raises:
        ValueError: If ``dirname`` is given but is not an existing directory; if
            ``dirname`` is ``None`` and ``params_filename`` or ``model_filename``
            is ``None``; or if the version of the loaded program is not supported.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid
            import numpy as np

            paddle.enable_static()
            # Build the model
            main_prog = fluid.Program()
            startup_prog = fluid.Program()
            with fluid.program_guard(main_prog, startup_prog):
                data = fluid.layers.data(name="img", shape=[64, 784], append_batch_size=False)
                w = fluid.layers.create_parameter(shape=[784, 200], dtype='float32')
                b = fluid.layers.create_parameter(shape=[200], dtype='float32')
                hidden_w = fluid.layers.matmul(x=data, y=w)
                hidden_b = fluid.layers.elementwise_add(hidden_w, b)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(startup_prog)

            # Save the inference model
            path = "./infer_model"
            fluid.io.save_inference_model(dirname=path, feeded_var_names=['img'],
                         target_vars=[hidden_b], executor=exe, main_program=main_prog)

            # Demo one. Not need to set the distributed look up table, because the
            # training doesn't use a distributed look up table.
            [inference_program, feed_target_names, fetch_targets] = (
                fluid.io.load_inference_model(dirname=path, executor=exe))
            tensor_img = np.array(np.random.random((1, 64, 784)), dtype=np.float32)
            results = exe.run(inference_program,
                          feed={feed_target_names[0]: tensor_img},
                          fetch_list=fetch_targets)

            # Demo two. If the training uses a distributed look up table, the pserver
            # endpoints list should be supported when loading the inference model.
            # The below is just an example.
            endpoints = ["127.0.0.1:2023","127.0.0.1:2024"]
            [dist_inference_program, dist_feed_target_names, dist_fetch_targets] = (
                fluid.io.load_inference_model(dirname=path,
                                              executor=exe,
                                              pserver_endpoints=endpoints))

            # In this example, the inference program was saved in the file
            # "./infer_model/__model__" and parameters were saved in
            # separate files under the directory "./infer_model".
            # By the inference program, feed_target_names and
            # fetch_targets, we can use an executor to run the inference
            # program for getting the inference result.
    """
    if dirname is not None:
        load_dirname = os.path.normpath(dirname)
        if not os.path.isdir(load_dirname):
            raise ValueError("There is no directory named '%s'" % dirname)

        if model_filename is None:
            model_filename = '__model__'

        # Only the base name of model_filename is kept; the program file is
        # always read from inside load_dirname.
        model_filename = os.path.join(load_dirname,
                                      os.path.basename(model_filename))

        # params_filename is likewise reduced to its base name before being
        # handed to load_persistables together with load_dirname.
        if params_filename is not None:
            params_filename = os.path.basename(params_filename)

        with open(model_filename, "rb") as f:
            program_desc_str = f.read()
    else:
        # In-memory mode: model_filename / params_filename carry the
        # serialized program and parameters themselves, not file names.
        if params_filename is None:
            raise ValueError(
                "The path of params cannot be None when the directory path is None."
            )
        if model_filename is None:
            raise ValueError(
                "The model_filename cannot be None when the directory path is None."
            )
        load_dirname = dirname
        program_desc_str = model_filename

    program = Program.parse_from_string(program_desc_str)
    if not core._is_program_version_supported(program._version()):
        raise ValueError("Unsupported program version: %d\n" %
                         program._version())
    # TODO: binary (parameter) data also needs versioning; only the program
    # version is checked above.
    load_persistables(executor, load_dirname, program, params_filename)

    if pserver_endpoints:
        program = _endpoints_replacement(program, pserver_endpoints)

    feed_target_names = program.desc.get_feed_target_names()
    fetch_target_names = program.desc.get_fetch_target_names()
    fetch_targets = [
        program.global_block().var(name) for name in fetch_target_names
    ]

    return [program, feed_target_names, fetch_targets]
X
xuwei06 已提交
1572 1573


T
tangwei12 已提交
1574 1575 1576
def _endpoints_replacement(program, endpoints):
    ENDPOINT_MAP = "epmap"
    for op in program.global_block().ops:
T
tangwei12 已提交
1577 1578
        if op.has_attr(ENDPOINT_MAP):
            op.set_attr(ENDPOINT_MAP, endpoints)
T
fix  
tangwei12 已提交
1579
    program._sync_with_cpp()
T
tangwei12 已提交
1580
    return program
T
tangwei12 已提交
1581 1582


X
xuwei06 已提交
1583 1584
def get_parameter_value(para, executor):
    """
    Get the LoDTensor value of the given parameter.

    A throw-away Program is built containing a clone of ``para``, and the
    executor fetches that variable's value.

    Args:
        para(Parameter): The parameter to get value from.
        executor(Executor): The executor to run for retrieving the value.

    Returns:
        numpy.array: The given parameter's values.

    Raises:
        AssertionError: If the `para` is not an instance of Parameter.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid

            paddle.enable_static()
            exe = fluid.Executor(fluid.CPUPlace())
            param = fluid.default_main_program().global_block().var('fc.w')
            p = fluid.io.get_parameter_value(param, exe)

    """
    # Raise explicitly rather than using ``assert`` so that the documented
    # check still runs under ``python -O``. The exception type is unchanged.
    if not is_parameter(para):
        raise AssertionError("The input variable is not parameter.")

    get_program = Program()
    block = get_program.global_block()
    new_var = _clone_var_in_block_(block, para)
    return executor.run(get_program, feed={}, fetch_list=[new_var])[0]


def get_parameter_value_by_name(name, executor, program=None):
    """
    Get the LoDTensor value of a certain parameter by its name.

    Args:
        name(str): The parameter's name.
        executor(Executor): The executor to run for retrieving the value.
        program(Program | None): The program where to find the parameter.
                               If it's set to be None, the function will
                               try to find the parameter in the default
                               main program.

    Returns:
        numpy.array: The parameter's values.

    Raises:
        AssertionError: If the variable found under ``name`` is not a
            Parameter (raised by ``get_parameter_value``).

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid

            paddle.enable_static()
            exe = fluid.Executor(fluid.CPUPlace())
            p = fluid.io.get_parameter_value_by_name('fc.w', exe)
    """
    if program is None:
        program = default_main_program()
    var = program.global_block().var(name)
    return get_parameter_value(var, executor)
1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660


def _save_persistable_nodes(executor, dirname, graph):
    """
    Save persistable nodes to the given directory by the executor.

    Each name in ``graph.all_persistable_nodes()`` is handled once (the first
    node seen wins). Nodes whose variable type is RAW or READER are skipped.
    For every remaining node, a matching variable is created in a temporary
    Program, and those variables are passed to ``save_vars``.

    Args:
        executor(Executor): The executor to run for saving node values.
        dirname(str): The directory path.
        graph(IrGraph): All the required persistable nodes in the graph will be saved.
    """
    # dict preserves insertion order, so first-seen order is kept.
    first_node_by_name = {}
    for candidate in graph.all_persistable_nodes():
        first_node_by_name.setdefault(candidate.name(), candidate)

    tmp_program = Program()
    skipped_types = (core.VarDesc.VarType.RAW, core.VarDesc.VarType.READER)
    vars_to_save = []
    for candidate in first_node_by_name.values():
        desc = candidate.var()
        if desc.type() in skipped_types:
            continue
        created = tmp_program.global_block().create_var(
            name=desc.name(),
            shape=desc.shape(),
            dtype=desc.dtype(),
            type=desc.type(),
            lod_level=desc.lod_level(),
            persistable=desc.persistable())
        vars_to_save.append(created)
    save_vars(executor=executor, dirname=dirname, vars=vars_to_save)


def _load_persistable_nodes(executor, dirname, graph):
    """
    Load persistable node values from the given directory by the executor.

    Each name in ``graph.all_persistable_nodes()`` is handled once (the first
    node seen wins). Nodes whose variable type is RAW or READER are skipped.
    If a node has no file ``os.path.join(dirname, <name>)``, a warning is
    logged and the node is skipped instead of failing the whole load.

    Args:
        executor(Executor): The executor to run for loading node values.
        dirname(str): The directory path.
        graph(IrGraph): All the required persistable nodes in the graph will be loaded.
    """
    persistable_node_names = set()
    persistable_nodes = []
    all_persistable_nodes = graph.all_persistable_nodes()
    for node in all_persistable_nodes:
        name = node.name()
        if name not in persistable_node_names:
            persistable_node_names.add(name)
            persistable_nodes.append(node)
    program = Program()
    var_list = []

    def _exist(var):
        return os.path.exists(os.path.join(dirname, var.name))

    for node in persistable_nodes:
        var_desc = node.var()
        if var_desc.type() == core.VarDesc.VarType.RAW or \
                var_desc.type() == core.VarDesc.VarType.READER:
            continue
        var = program.global_block().create_var(
            name=var_desc.name(),
            shape=var_desc.shape(),
            dtype=var_desc.dtype(),
            type=var_desc.type(),
            lod_level=var_desc.lod_level(),
            persistable=var_desc.persistable())
        if _exist(var):
            var_list.append(var)
        else:
            # Logger.warn is a deprecated alias; use warning with lazy args.
            _logger.warning("Cannot find the var %s!!!", node.name())
    load_vars(executor=executor, dirname=dirname, vars=var_list)
H
hong 已提交
1723 1724


W
WeiXin 已提交
1725
def _unpack_saved_dict(saved_obj, protocol):
1726 1727
    temp_saved_obj = {}
    unpack_infor = {}
W
WeiXin 已提交
1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747
    # When pickle protocol=2 or protocol=3 the serialized object cannot be larger than 4G.
    if 1 < protocol < 4:
        if isinstance(saved_obj, dict):
            for key, value in saved_obj.items():
                if isinstance(value, np.ndarray):
                    MAX_NUMBER_OF_ELEMENT = int(
                        (2**30 - 1) / value.dtype.itemsize)
                    num_element = np.prod(value.shape)
                    if num_element > MAX_NUMBER_OF_ELEMENT:
                        unpack_infor[key] = {}
                        unpack_infor[key]["OriginShape"] = value.shape
                        unpack_infor[key]["slices"] = []
                        value = value.flatten()
                        for i in range(
                                int(
                                    math.ceil(num_element * 1.0 /
                                              MAX_NUMBER_OF_ELEMENT))):
                            part_name = key + "@@." + str(i)
                            unpack_infor[key]["slices"].append(part_name)
                            temp_saved_obj[part_name] = value[
1748 1749 1750
                                i *
                                MAX_NUMBER_OF_ELEMENT:MAX_NUMBER_OF_ELEMENT *
                                (i + 1)]
1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762

    if unpack_infor:
        for key, value in unpack_infor.items():
            if key in saved_obj:
                saved_obj.pop(key)
                for part in value['slices']:
                    saved_obj[part] = temp_saved_obj[part]
        saved_obj['UnpackBigParamInfor@@'] = unpack_infor
    return saved_obj


def _pack_loaded_dict(load_obj):
W
WeiXin 已提交
1763 1764 1765 1766 1767 1768
    if isinstance(load_obj, dict):
        unpack_info = 'UnpackBigParamInfor@@'
        if unpack_info in load_obj:
            removes = []
            for key, value in load_obj[unpack_info].items():
                slices = [load_obj[part] for part in value["slices"]]
1769 1770
                load_obj[key] = np.concatenate(slices).reshape(
                    value["OriginShape"])
W
WeiXin 已提交
1771 1772 1773 1774 1775
                removes += value["slices"]
            for key in removes:
                load_obj.pop(key)
            load_obj.pop(unpack_info)

1776 1777 1778
    return load_obj


1779
@static_only
def _legacy_save(param_dict, model_path, protocol=2):
    """
    Pickle ``param_dict`` to ``model_path`` with the given pickle ``protocol``.

    Values that are ``core.VarBase`` / ``core.eager.Tensor`` are converted with
    ``.numpy()``, ``core.LoDTensor`` values with ``np.array``. Any other value
    is stored unchanged.

    Args:
        param_dict(dict): Mapping from name to a tensor-like value.
        model_path(str|object): A file path, or whatever ``_open_file_buffer``
            accepts as a write target.
        protocol(int, optional): Pickle protocol. Default: 2.
    """

    def _to_numpy(value):
        if isinstance(value, (core.VarBase, core.eager.Tensor)):
            return value.numpy()
        if isinstance(value, core.LoDTensor):
            return np.array(value)
        return value

    converted = {}
    for name in param_dict:
        converted[name] = _to_numpy(param_dict[name])

    chunked_write = (_is_file_path(model_path) and sys.platform == 'darwin'
                     and sys.version_info.major == 3)
    if not chunked_write:
        with _open_file_buffer(model_path, 'wb') as f:
            pickle.dump(converted, f, protocol=protocol)
        return

    # Workaround for macOS + Python 3: writing a value larger than 4GB in one
    # call is buggy, so the pickled bytes are written in 1GB chunks.
    payload = pickle.dumps(converted, protocol=protocol)
    chunk_size = 2**30
    with open(model_path, 'wb') as f:
        for start in range(0, len(payload), chunk_size):
            f.write(payload[start:start + chunk_size])


@static_only
def save(program, model_path, protocol=4, **configs):
    """

    This function saves parameters, optimizer information and network description to model_path.

    The parameters contain all the trainable Tensors and are saved to a file with suffix ".pdparams".
    The optimizer information contains all the Tensors used by the optimizer. For Adam optimizer, it contains beta1, beta2, momentum etc. All the information is saved to a file with suffix ".pdopt". The ".pdopt" file is always written; if the optimizer has no Tensor to save (like SGD), it holds an empty dict.
    The network description is the description of the program. It's only used for deployment. The description will be saved to a file with suffix ".pdmodel".

    Args:
        program(Program) : The program to save.
        model_path(str): the file prefix to save the program. The format is "dirname/file_prefix". If file_prefix is an empty str, an exception will be raised.
        protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5.
                                 Default: 4
        configs(dict, optional) : optional keyword arguments. The deprecated key
                                  ``pickle_protocol`` overrides ``protocol``.

    Returns:
        None

    Raises:
        AssertionError: If the file prefix of ``model_path`` is empty.
        ValueError: If ``protocol`` is not an int, or is not in [2, 4].

    Examples:
        .. code-block:: python

            import paddle
            import paddle.static as static

            paddle.enable_static()

            x = static.data(name="x", shape=[10, 10], dtype='float32')
            y = static.nn.fc(x, 10)
            z = static.nn.fc(y, 10)

            place = paddle.CPUPlace()
            exe = static.Executor(place)
            exe.run(static.default_startup_program())
            prog = static.default_main_program()

            static.save(prog, "./temp")
    """

    base_name = os.path.basename(model_path)
    # Raise explicitly (same exception type as before) so this check is not
    # stripped under ``python -O``. An empty prefix would otherwise produce
    # hidden files such as "dirname/.pdparams".
    if base_name == "":
        raise AssertionError(
            "The input model_path MUST be format of dirname/filename [dirname\\filename in Windows system], but received model_path is empty string."
        )
    if 'pickle_protocol' in configs:
        protocol = configs['pickle_protocol']
        warnings.warn(
            "'pickle_protocol' is a deprecated argument. Please use 'protocol' instead."
        )

    if not isinstance(protocol, int):
        raise ValueError("The 'protocol' MUST be `int`, but received {}".format(
            type(protocol)))

    if protocol < 2 or protocol > 4:
        raise ValueError(
            "Expected 1<'protocol'<5, but received protocol={}".format(
                protocol))

    dir_name = os.path.dirname(model_path)
    if dir_name and not os.path.exists(dir_name):
        os.makedirs(dir_name)

    def get_tensor(var):
        t = global_scope().find_var(var.name).get_tensor()
        return np.array(t)

    parameter_list = list(filter(is_parameter, program.list_vars()))
    param_dict = {p.name: get_tensor(p) for p in parameter_list}

    # Split arrays too large for pickle protocols 2/3 into slices.
    param_dict = _unpack_saved_dict(param_dict, protocol)

    # When a value of the dict is larger than 4GB, there is a bug on macOS
    # with Python 3, so the pickled bytes are written in 1GB chunks.
    if sys.platform == 'darwin' and sys.version_info.major == 3:
        pickle_bytes = pickle.dumps(param_dict, protocol=protocol)
        with open(model_path + ".pdparams", 'wb') as f:
            max_bytes = 2**30
            for i in range(0, len(pickle_bytes), max_bytes):
                f.write(pickle_bytes[i:i + max_bytes])
    else:
        with open(model_path + ".pdparams", 'wb') as f:
            pickle.dump(param_dict, f, protocol=protocol)

    optimizer_var_list = list(
        filter(is_belong_to_optimizer, program.list_vars()))

    opt_dict = {p.name: get_tensor(p) for p in optimizer_var_list}
    with open(model_path + ".pdopt", 'wb') as f:
        pickle.dump(opt_dict, f, protocol=protocol)

    # NOTE(review): the version is set on the clone ``main_program``, but the
    # ".pdmodel" file below is serialized from ``program.desc``; the clone is
    # otherwise unused. Kept unchanged pending confirmation of the intent.
    main_program = program.clone()
    program.desc.flush()
    main_program.desc._set_version()
    paddle.fluid.core.save_op_version_info(program.desc)

    with open(model_path + ".pdmodel", "wb") as f:
        f.write(program.desc.serialize_to_string())


1903 1904 1905 1906 1907 1908
def _pickle_loads_mac(path, f):
    pickle_bytes = bytearray(0)
    file_size = os.path.getsize(path)
    max_bytes = 2**30
    for _ in range(0, file_size, max_bytes):
        pickle_bytes += f.read(max_bytes)
T
tianshuo78520a 已提交
1909
    load_result = pickle.loads(pickle_bytes, encoding='latin1')
1910 1911 1912
    return load_result


1913
@static_only
def load(program, model_path, executor=None, var_list=None):
    """
    :api_attr: Static Graph

    This function gets parameters and optimizer information from program, and then gets the corresponding values from file.
    An exception will be thrown if the shape or dtype of the parameters does not match.

    This function can also load model files saved with [ save_params, save_persistables, save_vars ].
    var_list can not be None when loading a single model file
    ( filename is not None when save_params, save_persistables or save_vars is called ).

    Args:
        program(Program): The program will be loaded
        model_path(str): The file prefix where the program is stored
        executor(Executor, optional): The executor used for initializing the parameters
                                      when the startup program has not been run.
        var_list(list|tuple, optional): The Tensor list/tuple to load a single model file saved with
                                  [ save_params, save_persistables, save_vars ].
                                  Default: None

    Returns:
        None

    Raises:
        ValueError: If the ".pdparams" file is missing and ``executor`` is None,
            or ``model_path`` is a single file and ``var_list`` is None.
        LookupError: If a variable in ``var_list`` is not in ``program``.
        RuntimeError: If loading a model file saved with
            [ save_params, save_persistables, save_vars ] fails.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.static as static

            paddle.enable_static()

            x = static.data(name="x", shape=[10, 10], dtype='float32')
            y = static.nn.fc(x, 10)
            z = static.nn.fc(y, 10)

            place = paddle.CPUPlace()
            exe = static.Executor(place)
            exe.run(static.default_startup_program())
            prog = static.default_main_program()

            static.save(prog, "./temp")
            static.load(prog, "./temp")
    """

    assert executor is None or isinstance(executor, Executor)

    model_prefix = model_path
    if model_prefix.endswith(".pdparams"):
        model_prefix = model_prefix[:-9]
    elif model_prefix.endswith(".pdopt"):
        model_prefix = model_prefix[:-6]
    elif model_prefix.endswith(".pdmodel"):
        model_prefix = model_prefix[:-8]

    parameter_file_name = model_prefix + ".pdparams"

    if not os.path.exists(parameter_file_name):
        # model file save by fluid.save not found, try to load model file saved with
        # [save_vars, save_params, save_persistables]
        _logger.debug(
            "{} not found, try to load model file saved with [ save_params, save_persistables, save_vars ]"
            .format(parameter_file_name))
        if executor is None:
            raise ValueError(
                "executor is required when loading model file saved with [ save_params, save_persistables, save_vars ]"
            )

        # A set gives O(1) membership checks in the loop below.
        if var_list is not None:
            var_list_names = {var.name for var in var_list}
        else:
            var_list_names = None

        if os.path.isdir(model_path):
            binary_file_set = set()
            for root, dirs, files in os.walk(model_path, topdown=False):
                for f in files:
                    binary_file_set.add(
                        os.path.join(root, f).replace("\\", "/"))
            program_var_list = list(program.list_vars())
            loaded_var_list = []
            for var in program_var_list:
                var_path = os.path.join(model_path, var.name).replace("\\", "/")
                load_condition = var_list_names is None or var.name in var_list_names
                if var_path in binary_file_set and load_condition:
                    loaded_var_list.append(var)
                    binary_file_set.remove(var_path)
            if len(binary_file_set) > 0:
                unused_var_list = " ".join(list(binary_file_set))
                _logger.warning("variable file [ %s ] not used",
                                unused_var_list)
            try:
                load_vars(executor=executor,
                          dirname=model_path,
                          vars=loaded_var_list)
            except RuntimeError as e:
                _logger.error(e)
                raise
            except Exception as e:
                # Catch Exception (not a bare except) so KeyboardInterrupt and
                # SystemExit propagate; chain the original cause.
                raise RuntimeError(
                    "Failed to load model file, please make sure model file is saved with the "
                    "following APIs: save_params, save_persistables, save_vars"
                ) from e

            return
        elif os.path.isfile(model_path):
            if var_list is None:
                raise ValueError(
                    "var_list is required when loading model file saved with [ save_params, save_persistables, save_vars ]"
                )
            program_var_name_set = {var.name for var in program.list_vars()}

            # check all the variables included in program
            for var in var_list:
                if var.name not in program_var_name_set:
                    raise LookupError(
                        "loaded var [{}] is not in program variable list".
                        format(var.name))

            dir_name, file_name = os.path.split(model_path)
            try:
                load_vars(executor=executor,
                          dirname=dir_name,
                          vars=var_list,
                          filename=file_name)
            except RuntimeError as e:
                _logger.error(e)
                raise
            except Exception as e:
                raise RuntimeError("Failed to load model file , please make sure model file is saved with the " \
                                   "following APIs: [ save_params, save_persistables, save_vars ]. " \
                                   "When these API called, filename CANNOT be None") from e

            return

    def set_var(var, ndarray):
        # Write ndarray into the scope tensor on the same device it lives on.
        t = global_scope().find_var(var.name).get_tensor()
        p = t._place()
        if p.is_cpu_place():
            place = paddle.fluid.CPUPlace()
        elif p.is_cuda_pinned_place():
            place = paddle.fluid.CUDAPinnedPlace()
        elif p.is_xpu_place():
            p = paddle.fluid.core.Place()
            p.set_place(t._place())
            place = paddle.fluid.XPUPlace(p.xpu_device_id())
        elif p.is_npu_place():
            p = paddle.fluid.core.Place()
            p.set_place(t._place())
            place = paddle.fluid.NPUPlace(p.npu_device_id())
        elif p.is_mlu_place():
            p = paddle.fluid.core.Place()
            p.set_place(t._place())
            place = paddle.fluid.MLUPlace(p.mlu_device_id())
        else:
            p = paddle.fluid.core.Place()
            p.set_place(t._place())
            place = paddle.fluid.CUDAPlace(p.gpu_device_id())

        t.set(ndarray, place)

    parameter_list = list(filter(is_parameter, program.list_vars()))

    if executor:
        paddle.fluid.core._create_loaded_parameter(parameter_list,
                                                   global_scope(),
                                                   executor._default_executor)
    # NOTE(review): pickle.load can execute arbitrary code; only load model
    # files from trusted sources.
    with open(parameter_file_name, 'rb') as f:

        # When a value of the dict is larger than 4GB, there is a bug on
        # macOS with Python 3, so the file is read in chunks there.
        if sys.platform == 'darwin' and sys.version_info.major == 3:
            load_dict = _pickle_loads_mac(parameter_file_name, f)
        else:
            load_dict = pickle.load(f, encoding='latin1')
        load_dict = _pack_loaded_dict(load_dict)
    for v in parameter_list:
        assert v.name in load_dict, \
            "Can not find [{}] in model file [{}]".format(
                v.name, parameter_file_name)
        set_var(v, load_dict[v.name])

    optimizer_var_list = list(
        filter(is_belong_to_optimizer, program.list_vars()))

    if len(optimizer_var_list) > 0:
        opt_file_name = model_prefix + ".pdopt"
        assert os.path.exists(opt_file_name), \
            "Optimizer file [{}] not exits".format(opt_file_name)

        if executor:
            paddle.fluid.core._create_loaded_parameter(
                optimizer_var_list, global_scope(), executor._default_executor)

        with open(opt_file_name, 'rb') as f:
            load_dict = pickle.load(f, encoding='latin1')
        for v in optimizer_var_list:
            assert v.name in load_dict, \
                "Can not find [{}] in model file [{}]".format(
                    v.name, opt_file_name)
            set_var(v, load_dict[v.name])
2112 2113


H
hong 已提交
2114
def load_program_state(model_path, var_list=None):
    """

    Load program state from local file

    Two on-disk layouts are supported:

    * the layout written by ``static.save``: ``<prefix>.pdparams`` plus an
      optional ``<prefix>.pdopt``. When the ``.pdopt`` file exists, its
      entries are merged into the returned dict.
    * the layout written by [ save_params, save_persistables, save_vars ].
      This is used as a fallback when ``<prefix>.pdparams`` does not exist.
      ``model_path`` is then either a directory of per-variable files or a
      single combined file.

    Args:
        model_path(str): The path prefix of the saved program. A trailing
                         ".pdparams", ".pdopt" or ".pdmodel" suffix is
                         ignored when building the prefix.
        var_list(list|tuple, optional): The Tensor list/tuple to load saved with
                                  [ save_params, save_persistables, save_vars ].
                                  Default: None.
                                  Only the metadata of these variables (name,
                                  shape, dtype, type, lod_level) is read; the
                                  variables themselves are not modified.
                                  Required when model_path is a single file in
                                  the fallback layout. When None and
                                  model_path is a directory, every file under
                                  it is tried as a variable. Files that fail to
                                  load are skipped with a RuntimeWarning.
    Returns:
        state_dict(dict): the dict store Parameter and optimizer information,
        keyed by variable name.

    Raises:
        ValueError: If the fallback layout is used, model_path is a file and
            var_list is None.
        TypeError: If an element of var_list is not a Variable.
        RuntimeError: If the variables given in var_list cannot be loaded
            from the fallback layout.

    Examples:

        .. code-block:: python

            import paddle
            import paddle.static as static

            paddle.enable_static()

            x = static.data(name="x", shape=[10, 10], dtype='float32')
            y = static.nn.fc(x, 10)
            z = static.nn.fc(y, 10)

            place = paddle.CPUPlace()
            exe = static.Executor(place)
            exe.run(static.default_startup_program())
            prog = static.default_main_program()

            static.save(prog, "./temp")
            program_state = static.load_program_state("./temp")
    """

    def _read_pickled(file_name):
        # Unpickle file_name. Both .pdparams and .pdopt go through here, so
        # they share the same platform workaround.
        with open(file_name, 'rb') as f:
            # When value of dict is larger than 4GB, there is a bug on 'MAC python3'
            if sys.platform == 'darwin' and sys.version_info.major == 3:
                return _pickle_loads_mac(file_name, f)
            return pickle.load(f, encoding='latin1')

    model_prefix = model_path
    for suffix in (".pdparams", ".pdopt", ".pdmodel"):
        if model_prefix.endswith(suffix):
            model_prefix = model_prefix[:-len(suffix)]
            break

    parameter_file_name = model_prefix + ".pdparams"
    if os.path.exists(parameter_file_name):
        # Layout written by static.save / fluid.save.
        para_dict = _pack_loaded_dict(_read_pickled(parameter_file_name))

        opt_file_name = model_prefix + ".pdopt"
        if os.path.exists(opt_file_name):
            para_dict.update(_read_pickled(opt_file_name))

        return para_dict

    # model file saved with fluid.save is not found, try to load model file saved with
    # [save_vars, save_params, save_persistables]
    _logger.debug(
        "{} not found, try to load model file saved with [ save_params, save_persistables, save_vars ]"
        .format(parameter_file_name))

    if var_list is None and os.path.isfile(model_path):
        raise ValueError(
            "var_list can not be None when model_path is a file type")

    with _load_program_scope():
        load_prog = Program()
        load_block = load_prog.global_block()

        def clone_var_to_block(block, var):
            # Re-declare `var` (by metadata only) as a persistable variable
            # in `block`, so load_vars can fill it.
            if not isinstance(var, Variable):
                raise TypeError("value in var_list must be variable")
            return block.create_var(
                name=var.name,
                shape=var.shape,
                dtype=var.dtype,
                type=var.type,
                lod_level=var.lod_level if var.desc.type()
                == core.VarDesc.VarType.LOD_TENSOR else None,
                persistable=True)

        def _load_vars_with_try_catch(exe,
                                      dirname,
                                      target_vars,
                                      filename,
                                      raise_error=True):
            # Return True on success. On failure, raise RuntimeError, or
            # (raise_error=False) warn and return False.
            try:
                load_vars(executor=exe,
                          dirname=dirname,
                          vars=target_vars,
                          filename=filename)
                return True
            except Exception as e:
                error_str = "Failed to load model/variables `%s`, please make sure " \
                            "model/variables file is saved with the following APIs: " \
                            "save_params, save_persistables, save_vars."
                filenames = [var.name for var in target_vars
                             ] if filename is None else filename
                if raise_error:
                    raise RuntimeError(error_str % filenames) from e
                warnings.warn(error_str % filenames, RuntimeWarning)
            return False

        place = paddle.fluid.CPUPlace()
        exe = paddle.fluid.Executor(place)

        if os.path.isfile(model_path):
            # when model_path is file, var_list cannot be None
            dir_name, file_name = os.path.split(model_path)
            loaded_var_list = [
                clone_var_to_block(load_block, var) for var in var_list
            ]
            _load_vars_with_try_catch(exe, dir_name, loaded_var_list,
                                      file_name)
        elif var_list is not None:
            loaded_var_list = [
                clone_var_to_block(load_block, var) for var in var_list
            ]
            _load_vars_with_try_catch(exe, model_path, loaded_var_list, None)
        else:
            # Every file under model_path is a candidate variable, named by
            # its path relative to model_path (always with '/' separators).
            var_name_list = [
                os.path.relpath(os.path.join(root, f),
                                model_path).replace("\\", "/")
                for root, _, files in os.walk(model_path, topdown=False)
                for f in files
            ]
            loaded_var_list = []
            for var_name in var_name_list:
                # NOTE(chenweihang): If identify which files the user wants
                # to load from the disk, we load these variables one by one.
                # If a file does not exist, we only warn the user that the
                # file may be an irrelevant file, but does not throw an error
                # to ensure that other legal variables can be loaded.
                temp_var = load_block.create_var(name=var_name,
                                                 persistable=True)
                if _load_vars_with_try_catch(exe, model_path, [temp_var],
                                             None, False):
                    loaded_var_list.append(temp_var)

        # Read the values while the temporary scope is still active.
        return {
            var.name: np.asarray(
                paddle.fluid.global_scope().find_var(var.name).get_tensor())
            for var in loaded_var_list
        }


2278
@static_only
def set_program_state(program, state_dict):
    """
    Set program parameter from state_dict

    An exception will throw if shape or dtype of the parameters is not match.

    NOTICE: This function MUST called after run start_up_program

    Every persistable variable of ``program`` must already exist in the
    global scope. A variable whose name appears in ``state_dict`` is
    overwritten with the new value. The write targets the same kind of
    place its tensor currently reports: CPU, CUDA pinned, GPU, XPU, NPU or
    MLU. Entries of ``state_dict`` that match no persistable variable are
    reported with a warning.

    Args:
        program(Program): The program to be set
        state_dict(dict): the dict store Parameter and optimizer information
    Returns:
        None

    Raises:
        AssertionError: If a persistable variable of ``program`` is not found
            in the global scope. Also raised if a value in ``state_dict``
            differs in shape or dtype from the current tensor.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.static as static

            paddle.enable_static()

            x = static.data(name="x", shape=[10, 10], dtype='float32')
            y = static.nn.fc(x, 10)
            z = static.nn.fc(y, 10)

            place = paddle.CPUPlace()
            exe = static.Executor(place)
            exe.run(static.default_startup_program())
            prog = static.default_main_program()

            static.save(prog, "./temp")
            program_state = static.load_program_state("./temp")

            static.set_program_state(prog, program_state)
    """
    state_dict = _pack_loaded_dict(state_dict)
    parameter_list = list(filter(is_persistable, program.list_vars()))

    used_para_names = set()
    for para in parameter_list:
        var_temp = paddle.fluid.global_scope().find_var(para.name)
        assert var_temp is not None, \
            "Variable [ {} ] Not found, Please make sure run startup program".format(para.name)
        if para.name not in state_dict:
            continue

        # set value from state dict
        ten = var_temp.get_tensor()
        orig_para_np = np.array(ten)
        new_para_np = state_dict[para.name]
        assert orig_para_np.shape == new_para_np.shape, \
            "Parameter's shape does not match, the Program requires a parameter with the shape of ({}), " \
            "while the loaded parameter (namely [ {} ]) has a shape of  ({})." \
                .format(orig_para_np.shape, para.name, new_para_np.shape)
        assert orig_para_np.dtype == new_para_np.dtype, \
            "Parameter's data type does not match, the Program requires a parameter with a dtype of ({}), " \
            "while the loaded parameter (namely [ {} ]) has a dtype of  ({})." \
                .format(orig_para_np.dtype, para.name, new_para_np.dtype)

        # Write the new value to the same kind of place the tensor reports,
        # so it is not implicitly moved to CPU.
        ten_place = ten._place()
        py_place = paddle.fluid.CPUPlace()
        if ten_place.is_cuda_pinned_place():
            # Previously this assigned an unused local (`place`), so pinned
            # tensors silently fell back to CPUPlace.
            py_place = paddle.fluid.CUDAPinnedPlace()
        elif ten_place.is_gpu_place():
            p = paddle.fluid.core.Place()
            p.set_place(ten_place)
            py_place = paddle.fluid.CUDAPlace(p.gpu_device_id())
        elif ten_place.is_xpu_place():
            p = paddle.fluid.core.Place()
            p.set_place(ten_place)
            py_place = paddle.fluid.XPUPlace(p.xpu_device_id())
        elif ten_place.is_npu_place():
            p = paddle.fluid.core.Place()
            p.set_place(ten_place)
            py_place = paddle.fluid.NPUPlace(p.npu_device_id())
        elif ten_place.is_mlu_place():
            p = paddle.fluid.core.Place()
            p.set_place(ten_place)
            py_place = paddle.fluid.MLUPlace(p.mlu_device_id())

        ten.set(new_para_np, py_place)

        used_para_names.add(para.name)

    unused_para_list = [k for k in state_dict if k not in used_para_names]
    if unused_para_list:
        warnings.warn(
            "This list is not set, Because of Paramerter not found in program. There are: {}"
            .format(" ".join(unused_para_list)))