#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import errno
import warnings
18
import logging
import pickle
import contextlib
from functools import reduce
import sys
from io import BytesIO

import numpy as np
import math
import paddle
from paddle.fluid import layers
from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.framework import (
    Program,
    Parameter,
    default_main_program,
    default_startup_program,
    Variable,
    program_guard,
    dygraph_not_support,
    static_only,
)
from paddle.reader import (
    cache,
    map_readers,
    buffered,
    compose,
    chain,
    shuffle,
    ComposeNotAligned,
    firstn,
    xmap_readers,
    multiprocess_reader,
)
from .wrapped_decorator import signature_safe_contextmanager
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.log_helper import get_logger
from . import reader
from . import unique_name
from .reader import *
from . import dataloader
from .dataloader import *
from . import core
from paddle.utils import deprecated

__all__ = [
    'save_inference_model',
    'load_inference_model',
] + reader.__all__


_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)


def prepend_feed_ops(
    inference_program, feed_target_names, feed_holder_name='feed'
):
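    """Prepend one `feed` op per name in ``feed_target_names`` to
    ``inference_program``, so that at runtime the executor copies column
    ``i`` of the feed holder variable into the corresponding feed target.
    Does nothing when ``feed_target_names`` is empty.
    """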
    if len(feed_target_names) == 0:
        return

    global_block = inference_program.global_block()
    feed_var = global_block.create_var(
        name=feed_holder_name,
        type=core.VarDesc.VarType.FEED_MINIBATCH,
        persistable=True,
    )

    for i, name in enumerate(feed_target_names):
        if not global_block.has_var(name):
            raise ValueError(
                "The feeded_var_names[{i}]: '{name}' doesn't exist in pruned inference program. "
                "Please check whether '{name}' is a valid feed_var name, or remove it from feeded_var_names "
                "if '{name}' is not involved in the target_vars calculation.".format(
                    i=i, name=name
                )
            )
        out = global_block.var(name)
        global_block._prepend_op(
            type='feed',
            inputs={'X': [feed_var]},
            outputs={'Out': [out]},
            attrs={'col': i},
        )


def append_fetch_ops(
    inference_program, fetch_target_names, fetch_holder_name='fetch'
):
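    """Append one `fetch` op per name in ``fetch_target_names`` to
    ``inference_program``, so that at runtime the executor copies each fetch
    target into column ``i`` of the fetch holder variable.
    """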
    global_block = inference_program.global_block()
    fetch_var = global_block.create_var(
        name=fetch_holder_name,
        type=core.VarDesc.VarType.FETCH_LIST,
        persistable=True,
    )

    for i, name in enumerate(fetch_target_names):
        global_block.append_op(
            type='fetch',
            inputs={'X': [name]},
            outputs={'Out': [fetch_var]},
            attrs={'col': i},
        )


@static_only
@deprecated(since="2.0.0", update_to="paddle.static.save_inference_model")
def save_inference_model(
    dirname,
    feeded_var_names,
    target_vars,
    executor,
    main_program=None,
    model_filename=None,
    params_filename=None,
    export_for_deployment=True,
    program_only=False,
    clip_extra=True,
    legacy_format=False,
):
    """
    Prune the given `main_program` to build a new program especially for inference,
    and then save it and all related parameters to the given `dirname`.
    If you just want to save parameters of your trained model, please use the
    :ref:`api_fluid_io_save_params` . You can refer to :ref:`api_guide_model_save_reader_en`
    for more details.

    Note:
        The :code:`dirname` is used to specify the folder where inference model
        structure and parameters are going to be saved. If you would like to save params of
        Program in separate files, set `params_filename` to None; if you would like to save all
        params of Program in a single file, use `params_filename` to specify the file name.

    Args:
        dirname(str): The directory path to save the inference model.
        feeded_var_names(list[str]): list of string. Names of variables that need to be fed
                                     data during inference.
        target_vars(list[Variable]): list of Variable. Variables from which we can get
                                     inference results.
        executor(Executor): The executor that saves the inference model. You can refer
                            to :ref:`api_guide_executor_en` for more details.
        main_program(Program, optional): The original program, which will be pruned to
                                         build the inference model. If it is set to None,
                                         the global default :code:`_main_program_` will be used.
                                         Default: None.
        model_filename(str, optional): The name of file to save the inference program
                                       itself. If it is set to None, a default filename
                                       :code:`__model__` will be used.
        params_filename(str, optional): The name of file to save all related parameters.
                                        If it is set to None, parameters will be saved
                                        in separate files.
        export_for_deployment(bool, optional): If True, programs are modified to only support
                                     direct inference deployment. Otherwise,
                                     more information will be stored for flexible
                                     optimization and re-training. Currently, only
                                     True is supported.
                                     Default: True.
        program_only(bool, optional): If True, it will save the inference program only, and will not
                                      save params of Program.
                                      Default: False.
        legacy_format(bool, optional): Whether to save program in legacy format.
                                       Default: False.

    Returns:
        list, The fetch variables' name list.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid

            paddle.enable_static()
            path = "./infer_model"

            # User defined network, here a softmax regression example
            image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
            label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
            feeder = fluid.DataFeeder(feed_list=[image, label], place=fluid.CPUPlace())
            predict = paddle.static.nn.fc(x=image, size=10, activation='softmax')

            loss = paddle.nn.functional.cross_entropy(
                input=predict, label=label,
                reduction='none', use_softmax=False
            )
            avg_loss = paddle.mean(loss)

            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(fluid.default_startup_program())

            # Feed data and train process

            # Save inference model. Note we don't save label and loss in this example
            fluid.io.save_inference_model(dirname=path,
                                          feeded_var_names=['img'],
                                          target_vars=[predict],
                                          executor=exe)

            # In this example, save_inference_model will prune the default
            # main program according to the network's input node (img) and output node (predict).
            # The pruned inference program is going to be saved in the "./infer_model/__model__"
            # and parameters are going to be saved in separate files under folder
            # "./infer_model".

    """
    if isinstance(feeded_var_names, str):
        feeded_var_names = [feeded_var_names]
    elif export_for_deployment:
        if len(feeded_var_names) > 0:
            # TODO(paddle-dev): polish these code blocks
            if not (
                bool(feeded_var_names)
                and all(isinstance(name, str) for name in feeded_var_names)
            ):
                raise ValueError("'feeded_var_names' should be a list of str.")

    if isinstance(target_vars, Variable):
        target_vars = [target_vars]
    elif export_for_deployment:
        if not (
            bool(target_vars)
            and all(isinstance(var, Variable) for var in target_vars)
        ):
            raise ValueError("'target_vars' should be a list of Variable.")

    main_program = paddle.static.io._get_valid_program(main_program)

    # remind user to set auc_states to zeros if the program contains auc op
    all_ops = main_program.global_block().ops
    for op in all_ops:
        # clear device of Op
        device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
        op._set_attr(device_attr_name, "")
        if op.type == 'auc':
            warnings.warn(
                "please ensure that you have set the auc states to zeros before saving inference model"
            )
            break

    with program_guard(main_program):
        uniq_target_vars = []
        for i, var in enumerate(target_vars):
            uniq_target_vars.append(var)
        target_vars = uniq_target_vars
    target_var_name_list = [var.name for var in target_vars]

    # when a pserver and a trainer are running on the same machine, mkdir may conflict
    save_dirname = dirname
    try:
        save_dirname = os.path.normpath(dirname)
        os.makedirs(save_dirname)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

    if model_filename is not None:
        model_basename = os.path.basename(model_filename)
    else:
        model_basename = "__model__"
    model_basename = os.path.join(save_dirname, model_basename)

    # When export_for_deployment is true, we modify the program online so that
    # it can only be loaded for inference directly. If it's false, the whole
    # original program and related meta are saved so that future usage can be
    # more flexible.

    origin_program = main_program.clone()

    if export_for_deployment:
        main_program = main_program.clone()
        global_block = main_program.global_block()
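        # Strip feed/fetch ops left over from training or a previous save so
        # that fresh ones matching the requested targets can be inserted below.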
        need_to_remove_op_index = []
        for i, op in enumerate(global_block.ops):
            op.desc.set_is_target(False)
            if op.type == "feed" or op.type == "fetch":
                need_to_remove_op_index.append(i)

        for index in need_to_remove_op_index[::-1]:
            global_block._remove_op(index)

        main_program.desc.flush()

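        # Prune the program down to the subgraph that connects the feed
        # variables to the requested targets, then drop read (data loading)
        # ops that are only needed during training.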
        main_program = main_program._prune_with_input(
            feeded_var_names=feeded_var_names, targets=target_vars
        )
        main_program = main_program._inference_optimize(prune_read_op=True)
        fetch_var_names = [v.name for v in target_vars]

        for target_v in target_vars:
            if not main_program.global_block().has_var(target_v.name):
                main_program.global_block().create_var(
                    name=target_v.name,
                    shape=target_v.shape,
                    dtype=target_v.dtype,
                    persistable=target_v.persistable,
                )

        prepend_feed_ops(main_program, feeded_var_names)
        append_fetch_ops(main_program, fetch_var_names)

        with open(model_basename, "wb") as f:
            f.write(
                main_program._remove_training_info(
                    clip_extra=clip_extra
                ).desc.serialize_to_string()
            )
    else:
        # TODO(panyx0718): Save more information so that it can also be used
        # for training and more flexible post-processing.
        with open(model_basename + ".main_program", "wb") as f:
            f.write(
                main_program._remove_training_info(
                    clip_extra=clip_extra
                ).desc.serialize_to_string()
            )

    if program_only:
        warnings.warn(
            "save_inference_model specified the param `program_only` to True, It will not save params of Program."
        )
        return target_var_name_list

    main_program._copy_dist_param_info_from(origin_program)

    if params_filename is not None:
        params_filename = os.path.basename(params_filename)

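    # Persist all persistable variables (the parameters) of the pruned
    # program, either as separate files or as a single `params_filename` file.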
    paddle.distributed.io.save_persistables(
        executor, save_dirname, main_program, params_filename
    )
    return target_var_name_list


@static_only
@deprecated(since="2.0.0", update_to="paddle.static.load_inference_model")
def load_inference_model(
    dirname,
    executor,
    model_filename=None,
    params_filename=None,
    pserver_endpoints=None,
):
    """
    Load the inference model from a given directory. By this API, you can get the model
    structure (Inference Program) and model parameters. If you just want to load
    parameters of the pre-trained model, please use the :ref:`api_fluid_io_load_params` API.
    You can refer to :ref:`api_guide_model_save_reader_en` for more details.

    Args:
        dirname(str): One of the following:
          - The given directory path.
          - Set to None when reading the model from memory.
        executor(Executor): The executor to run for loading inference model.
                            See :ref:`api_guide_executor_en` for more details about it.
        model_filename(str, optional): One of the following:
          - The name of file to load the inference program.
          - If it is None, the default filename ``__model__`` will be used.
          - When ``dirname`` is ``None``, it must be set to a string containing model.
          Default: ``None``.
        params_filename(str, optional): It is only used for the case that all
            parameters were saved in a single binary file. One of the following:
          - The name of file to load all parameters.
          - When ``dirname`` is ``None``, it must be set to a string containing all the parameters.
          - If parameters were saved in separate files, set it as ``None``.
            Default: ``None``.

        pserver_endpoints(list, optional): It is only needed by the distributed inference.
                                    If using a distributed look up table during the training,
                                    this table is also needed by the inference process. Its value is
                                    a list of pserver endpoints.

    Returns:
        list: The return of this API is a list with three elements:
        (program, feed_target_names, fetch_targets). The `program` is a
        ``Program`` (refer to :ref:`api_guide_Program_en`), which is used for inference.
        The `feed_target_names` is a list of ``str``, which contains names of variables
        that need to feed data in the inference program. The `fetch_targets` is a list of
        ``Variable`` (refer to :ref:`api_guide_Program_en`). It contains variables from which
        we can get inference results.


    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid
            import numpy as np

            paddle.enable_static()
            # Build the model
            main_prog = fluid.Program()
            startup_prog = fluid.Program()
            with fluid.program_guard(main_prog, startup_prog):
                data = paddle.static.data(name="img", shape=[-1, 64, 784])
                w = paddle.create_parameter(shape=[784, 200], dtype='float32')
                b = paddle.create_parameter(shape=[200], dtype='float32')
                hidden_w = paddle.matmul(x=data, y=w)
                hidden_b = paddle.add(hidden_w, b)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(startup_prog)

            # Save the inference model
            path = "./infer_model"
            fluid.io.save_inference_model(dirname=path, feeded_var_names=['img'],
                         target_vars=[hidden_b], executor=exe, main_program=main_prog)

            # Demo one. No need to set the distributed look up table, because the
            # training doesn't use a distributed look up table.
            [inference_program, feed_target_names, fetch_targets] = (
                fluid.io.load_inference_model(dirname=path, executor=exe))
            tensor_img = np.array(np.random.random((1, 64, 784)), dtype=np.float32)
            results = exe.run(inference_program,
                          feed={feed_target_names[0]: tensor_img},
                          fetch_list=fetch_targets)

436 437 438
            # Demo two. If the training uses a distributed look up table, the pserver
            # endpoints list should be supported when loading the inference model.
            # The below is just an example.
439
            endpoints = ["127.0.0.1:2023","127.0.0.1:2024"]
440
            [dist_inference_program, dist_feed_target_names, dist_fetch_targets] = (
441 442
                fluid.io.load_inference_model(dirname=path,
                                              executor=exe,
443
                                              pserver_endpoints=endpoints))

            # In this example, the inference program was saved in the file
            # "./infer_model/__model__" and parameters were saved in
            # separate files under the directory "./infer_model".
            # By the inference program, feed_target_names and
            # fetch_targets, we can use an executor to run the inference
            # program for getting the inference result.
    """
    load_from_memory = False
    if dirname is not None:
        load_dirname = os.path.normpath(dirname)
        if not os.path.isdir(load_dirname):
            raise ValueError("There is no directory named '%s'" % dirname)

        if model_filename is None:
            model_filename = '__model__'

        model_filename = os.path.join(
            load_dirname, os.path.basename(model_filename)
        )

        if params_filename is not None:
            params_filename = os.path.basename(params_filename)

        with open(model_filename, "rb") as f:
            program_desc_str = f.read()
    else:
        load_from_memory = True
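        # dirname is None: model_filename and params_filename are expected to
        # carry the serialized program and parameter bytes themselves.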
        if params_filename is None:
            raise ValueError(
                "The path of params cannot be None when the directory path is None."
            )
        load_dirname = dirname
        program_desc_str = model_filename

    program = Program.parse_from_string(program_desc_str)
    if not core._is_program_version_supported(program._version()):
        raise ValueError(
            "Unsupported program version: %d\n" % program._version()
        )
    # Binary data also needs versioning.
    paddle.distributed.io.load_persistables(
        executor, load_dirname, program, params_filename
    )

    if pserver_endpoints:
        program = _endpoints_replacement(program, pserver_endpoints)

    feed_target_names = program.desc.get_feed_target_names()
    fetch_target_names = program.desc.get_fetch_target_names()
    fetch_targets = [
        program.global_block().var(name) for name in fetch_target_names
    ]

    return [program, feed_target_names, fetch_targets]


def _endpoints_replacement(program, endpoints):
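    """Point every op that carries an `epmap` attribute (e.g. the ops of a
    distributed lookup table) at the given pserver ``endpoints``.
    """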
    ENDPOINT_MAP = "epmap"
    for op in program.global_block().ops:
        if op.has_attr(ENDPOINT_MAP):
            op._set_attr(ENDPOINT_MAP, endpoints)
    program._sync_with_cpp()
    return program