io.py 14.9 KB
Newer Older
1
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
D
dzhwinter 已提交
2
#
D
dzhwinter 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
D
dzhwinter 已提交
6
#
D
dzhwinter 已提交
7
#     http://www.apache.org/licenses/LICENSE-2.0
D
dzhwinter 已提交
8
#
D
dzhwinter 已提交
9 10 11 12 13 14
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
import os

17 18
from paddle.fluid.evaluator import Evaluator
from paddle.fluid.framework import Program, Parameter, default_main_program, Variable
K
fix bug  
Kexin Zhao 已提交
19
from . import core
20 21

__all__ = [
    # Saving / loading of individual variables, parameters and persistables.
    'save_vars', 'save_params', 'save_persistables',
    'load_vars', 'load_params', 'load_persistables',
    # Export / import of pruned inference programs.
    'save_inference_model', 'load_inference_model', 'get_inference_program',
]


def is_parameter(var):
    """Tell whether ``var`` is a Parameter.

    Args:
        var: The variable to inspect.

    Returns:
        bool: True if ``var`` is an instance of ``Parameter``, else False.
    """
    result = isinstance(var, Parameter)
    return result


def is_persistable(var):
    """Tell whether ``var`` should be treated as persistable for save/load.

    Feed and fetch holder variables (type ``FEED_MINIBATCH`` or
    ``FETCH_LIST``) are always reported as non-persistable, whatever their
    own flag says. Every other variable reports ``var.persistable``.

    Args:
        var: The variable to inspect.

    Returns:
        bool: False for feed/fetch holders, otherwise ``var.persistable``.
    """
    var_type = var.desc.type()
    holder_types = (core.VarDesc.VarType.FEED_MINIBATCH,
                    core.VarDesc.VarType.FETCH_LIST)
    if var_type in holder_types:
        return False
    return var.persistable


def _clone_var_in_block_(block, var):
    """Create a persistable copy of ``var``'s description inside ``block``.

    The new variable copies name, shape, dtype, type and lod_level from
    ``var`` and is always created with ``persistable=True``.

    Args:
        block: The block in which the new variable is created.
        var (Variable): The variable whose description is copied.

    Returns:
        Whatever ``block.create_var`` returns for the new variable.
    """
    assert isinstance(var, Variable)
    var_attrs = dict(
        name=var.name,
        shape=var.shape,
        dtype=var.dtype,
        type=var.type,
        lod_level=var.lod_level,
        persistable=True)
    return block.create_var(**var_attrs)


66 67 68 69 70
def save_vars(executor,
              dirname,
              main_program=None,
              vars=None,
              predicate=None,
              filename=None):
    """
    Save variables to a directory by running a generated program on executor.

    :param executor: executor used to run the generated save program.
    :param dirname: directory path the variables are written to.
    :param main_program: program whose variables are filtered by `predicate`
        when `vars` is None. Default default_main_program().
    :param vars: variables to save. If specified, `main_program` and
        `predicate` are ignored.
    :param predicate: callable taking a variable and returning a bool; only
        variables for which it returns True are saved. Used only when `vars`
        is None.
    :param filename: name of a single file (inside `dirname`) that all vars
        are saved to. If it is None, each variable is saved to a separate
        file named after the variable.

    :raises TypeError: if `vars` is None and `main_program` is not a Program.
    :return: None
    """
    if vars is None:
        if main_program is None:
            main_program = default_main_program()
        if not isinstance(main_program, Program):
            raise TypeError("program should be as Program type or None")
        selected = filter(predicate, main_program.list_vars())
        save_vars(executor, dirname=dirname, vars=selected, filename=filename)
        return

    save_program = Program()
    save_block = save_program.global_block()
    combined_vars = {}

    for var in vars:
        # Variables of type RAW are never saved.
        if var.type == core.VarDesc.VarType.RAW:
            continue
        cloned = _clone_var_in_block_(save_block, var)
        if filename is not None:
            combined_vars[cloned.name] = cloned
            continue
        save_block.append_op(
            type='save',
            inputs={'X': [cloned]},
            outputs={},
            attrs={'file_path': os.path.join(dirname, cloned.name)})

    if filename is not None:
        # Order by variable name so the combined file has a deterministic
        # layout (load_vars also sorts by name before load_combine).
        ordered = [combined_vars[key] for key in sorted(combined_vars)]
        save_block.append_op(
            type='save_combine',
            inputs={'X': ordered},
            outputs={},
            attrs={'file_path': os.path.join(dirname, filename)})

    executor.run(save_program)


132
def save_params(executor, dirname, main_program=None, filename=None):
    """
    Save all parameters of a program to a directory with executor.

    :param executor: executor used to run the save program.
    :param dirname: target directory path.
    :param main_program: program whose parameters are saved.
        Default default_main_program().
    :param filename: if given, all parameters go into this single file inside
        `dirname`; otherwise each parameter gets its own file.
    """
    save_vars(executor, dirname=dirname, main_program=main_program,
              vars=None, predicate=is_parameter, filename=filename)
143 144


145
def save_persistables(executor, dirname, main_program=None, filename=None):
    """
    Save all persistable variables of a program to a directory with executor.

    :param executor: executor used to run the save program.
    :param dirname: target directory path.
    :param main_program: program whose persistable variables are saved.
        Default default_main_program().
    :param filename: if given, all variables go into this single file inside
        `dirname`; otherwise each variable gets its own file.
    """
    save_vars(executor, dirname=dirname, main_program=main_program,
              vars=None, predicate=is_persistable, filename=filename)
156 157


158 159 160 161 162
def load_vars(executor,
              dirname,
              main_program=None,
              vars=None,
              predicate=None,
              filename=None):
    """
    Load variables from directory by executor.

    :param executor: executor used to run the generated load program.
    :param dirname: directory path the variables are read from.
    :param main_program: program whose variables are filtered by `predicate`
        when `vars` is None. Default default_main_program().
    :param vars: variables to load. If specified, `main_program` and
        `predicate` are ignored.
    :param predicate: callable taking a variable and returning a bool; only
        variables for which it returns True are loaded. Used only when `vars`
        is None.
    :param filename: The name of the single file (inside `dirname`) that all
        vars are loaded from. If it is None, each variable is loaded from a
        separate file named after the variable.

    :raises TypeError: if `vars` is None and `main_program` is not a Program.
    :return: None
    """
    if vars is None:
        if main_program is None:
            main_program = default_main_program()
        if not isinstance(main_program, Program):
            raise TypeError("program's type should be Program")

        load_vars(
            executor,
            dirname=dirname,
            vars=filter(predicate, main_program.list_vars()),
            filename=filename)
    else:
        load_prog = Program()
        load_block = load_prog.global_block()

        load_var_map = {}
        for each_var in vars:
            assert isinstance(each_var, Variable)
            # save_vars skips variables of type RAW, so nothing was written
            # for them (neither a separate file nor a save_combine entry).
            # Skip them here too, so the loads mirror what was saved.
            if each_var.type == core.VarDesc.VarType.RAW:
                continue
            new_var = _clone_var_in_block_(load_block, each_var)
            if filename is None:
                load_block.append_op(
                    type='load',
                    inputs={},
                    outputs={'Out': [new_var]},
                    attrs={'file_path': os.path.join(dirname, new_var.name)})
            else:
                load_var_map[new_var.name] = new_var

        if filename is not None:
            # Same name-sorted order that save_vars uses for save_combine.
            load_var_list = [
                load_var_map[name] for name in sorted(load_var_map)
            ]
            load_block.append_op(
                type='load_combine',
                inputs={},
                outputs={"Out": load_var_list},
                attrs={'file_path': os.path.join(dirname, filename)})

        executor.run(load_prog)


222
def load_params(executor, dirname, main_program=None, filename=None):
    """
    Load all parameters of a program from a directory by executor.

    :param executor: executor used to run the load program.
    :param dirname: source directory path.
    :param main_program: program whose parameters are loaded.
        Default default_main_program().
    :param filename: if given, all parameters are read from this single file
        inside `dirname`; otherwise each parameter is read from its own file.
    """
    load_vars(executor, dirname=dirname, main_program=main_program,
              predicate=is_parameter, filename=filename)
232 233


234
def load_persistables(executor, dirname, main_program=None, filename=None):
    """
    Load all persistable variables of a program from a directory by executor.

    :param executor: executor used to run the load program.
    :param dirname: source directory path.
    :param main_program: program whose persistable variables are loaded.
        Default default_main_program().
    :param filename: if given, all variables are read from this single file
        inside `dirname`; otherwise each variable is read from its own file.
    """
    load_vars(executor, dirname=dirname, main_program=main_program,
              predicate=is_persistable, filename=filename)
244 245


246 247
def get_inference_program(target_vars, main_program=None):
    """
    Build a pruned, inference-optimized program that computes `target_vars`.

    :param target_vars: a target, or a list/tuple of targets. A target is a
        Variable or an Evaluator; for an Evaluator, its `states` and
        `metrics` are used as the prune targets.
    :param main_program: program to prune. Default default_main_program().

    :return: the result of `inference_optimize()` on the pruned program.
    """
    if main_program is None:
        main_program = default_main_program()
    if isinstance(target_vars, tuple):
        # A tuple of targets is treated the same as a list of targets.
        target_vars = list(target_vars)
    elif not isinstance(target_vars, list):
        target_vars = [target_vars]

    prune_targets = []
    for var in target_vars:
        if isinstance(var, Evaluator):
            prune_targets.extend(var.states)
            prune_targets.extend(var.metrics)
        else:
            prune_targets.append(var)
    pruned_program = main_program.prune(targets=prune_targets)
    inference_program = pruned_program.inference_optimize()
    return inference_program


263 264 265
def prepend_feed_ops(inference_program,
                     feed_target_names,
                     feed_holder_name='feed'):
    """
    Insert one 'feed' op per feed target at the front of the global block.

    A persistable FEED_MINIBATCH holder variable named `feed_holder_name` is
    created. The op for feed_target_names[i] reads the holder with attribute
    col=i and writes the global-block variable of that name. Because every
    op is prepended, the ops end up in reverse order of `feed_target_names`.
    The `col` attribute, not the op position, ties each op to its input.

    :param inference_program: program modified in place.
    :param feed_target_names: names of the variables to be fed, in col order.
    :param feed_holder_name: name of the feed holder variable.
    """
    block = inference_program.global_block()
    holder = block.create_var(
        name=feed_holder_name,
        type=core.VarDesc.VarType.FEED_MINIBATCH,
        persistable=True)

    col = 0
    for target_name in feed_target_names:
        target_var = block.var(target_name)
        block.prepend_op(
            type='feed',
            inputs={'X': [holder]},
            outputs={'Out': [target_var]},
            attrs={'col': col})
        col += 1


281 282 283
def append_fetch_ops(inference_program,
                     fetch_target_names,
                     fetch_holder_name='fetch'):
    """
    Append one 'fetch' op per fetch target to the end of the global block.

    A persistable FETCH_LIST holder variable named `fetch_holder_name` is
    created. The op for fetch_target_names[i] reads that variable name and
    writes into the holder with attribute col=i.

    :param inference_program: program modified in place.
    :param fetch_target_names: names of the variables to fetch, in col order.
    :param fetch_holder_name: name of the fetch holder variable.
    """
    block = inference_program.global_block()
    holder = block.create_var(
        name=fetch_holder_name,
        type=core.VarDesc.VarType.FETCH_LIST,
        persistable=True)

    col = 0
    for target_name in fetch_target_names:
        block.append_op(
            type='fetch',
            inputs={'X': [target_name]},
            outputs={'Out': [holder]},
            attrs={'col': col})
        col += 1


298 299 300 301
def save_inference_model(dirname,
                         feeded_var_names,
                         target_vars,
                         executor,
                         main_program=None,
                         model_filename=None,
                         params_filename=None):
    """
    Build a model especially for inference,
    and save it to directory by the executor.

    NOTE: `main_program` is modified in place. The is_target flag of every op
    in its global block is cleared, and its existing feed/fetch ops are removed.

    :param dirname: directory path; created if it does not exist.
    :param feeded_var_names: a name (str) or a non-empty list of names of
        variables that need to be fed data during inference.
    :param target_vars: a Variable or a non-empty list of Variables from
        which we can get inference results.
    :param executor: executor that saves the inference model.
    :param main_program: original program, which will be pruned to build the
        inference model. Default default_main_program().
    :param model_filename: The name of file to save inference program.
        If not specified, default filename `__model__` will be used.
        Only its basename is used; the file is placed inside `dirname`.
    :param params_filename: The name of file to save parameters.
        It is used for the case that all parameters are saved in a single
        binary file. If not specified, parameters are saved in separate
        files. Only its basename is used.

    :raises ValueError: if `feeded_var_names` is neither a str nor a
        non-empty list of str, or `target_vars` is neither a Variable nor a
        non-empty list of Variables.
    :return: None
    """
    # `basestring` only exists on Python 2; fall back to `str` on Python 3.
    try:
        string_types = basestring
    except NameError:
        string_types = str

    if isinstance(feeded_var_names, string_types):
        feeded_var_names = [feeded_var_names]
    elif not (bool(feeded_var_names) and all(
            isinstance(name, string_types) for name in feeded_var_names)):
        raise ValueError("'feed_var_names' should be a list of str.")

    if isinstance(target_vars, Variable):
        target_vars = [target_vars]
    elif not (bool(target_vars) and all(
            isinstance(var, Variable) for var in target_vars)):
        raise ValueError("'target_vars' should be a list of Variable.")

    if main_program is None:
        main_program = default_main_program()
    # NOTE: an alias, not a copy -- the edits below also affect main_program.
    copy_program = main_program

    if not os.path.isdir(dirname):
        os.makedirs(dirname)

    # Clear the is_target information and remove the existing feed and fetch
    # ops. Indices are collected first and removed back-to-front. Removing
    # while enumerating would shift later ops down and skip the op that
    # follows each removed one.
    global_block = copy_program.global_block()
    feed_fetch_indices = []
    for i, op in enumerate(global_block.ops):
        op.desc.set_is_target(False)
        if op.type == "feed" or op.type == "fetch":
            feed_fetch_indices.append(i)
    for i in reversed(feed_fetch_indices):
        global_block.remove_op(i)
    copy_program.desc.flush()

    pruned_program = copy_program.prune(targets=target_vars)
    inference_program = pruned_program.inference_optimize()
    fetch_var_names = [v.name for v in target_vars]

    prepend_feed_ops(inference_program, feeded_var_names)
    append_fetch_ops(inference_program, fetch_var_names)

    if model_filename is not None:
        model_filename = os.path.basename(model_filename)
    else:
        model_filename = "__model__"
    model_filename = os.path.join(dirname, model_filename)

    if params_filename is not None:
        params_filename = os.path.basename(params_filename)

    with open(model_filename, "wb") as f:
        f.write(inference_program.desc.serialize_to_string())

    save_persistables(executor, dirname, inference_program, params_filename)
372 373


374 375 376 377
def load_inference_model(dirname,
                         executor,
                         model_filename=None,
                         params_filename=None):
    """
    Load inference model from a directory

    :param dirname: directory path
    :param executor: executor that load inference model
    :param model_filename: The name of file to load inference program.
        If not specified, default filename `__model__` will be used.
        Only its basename is used; the file is looked up inside `dirname`.
    :param params_filename: The name of file to load parameters.
        It is used for the case that all parameters are saved in a single binary file.
        If not specified, parameters are considered saved in separate files.

    :raises ValueError: if `dirname` is not an existing directory.
    :return: [program, feed_target_names, fetch_targets]
             program: program especially for inference.
             feed_target_names: Names of variables that need to feed data
             fetch_targets: Variables from which we can get inference results.
    """
    if not os.path.isdir(dirname):
        # Format the path into the message. Passing it as a second
        # constructor argument would leave '%s' unformatted.
        raise ValueError("There is no directory named '%s'" % dirname)

    if model_filename is not None:
        model_filename = os.path.basename(model_filename)
    else:
        model_filename = "__model__"
    model_filename = os.path.join(dirname, model_filename)

    if params_filename is not None:
        params_filename = os.path.basename(params_filename)

    with open(model_filename, "rb") as f:
        program_desc_str = f.read()

    program = Program.parse_from_string(program_desc_str)
    load_persistables(executor, dirname, program, params_filename)

    feed_target_names = program.desc.get_feed_target_names()
    fetch_target_names = program.desc.get_fetch_target_names()
    fetch_targets = [
        program.global_block().var(name) for name in fetch_target_names
    ]

    return [program, feed_target_names, fetch_targets]
X
xuwei06 已提交
419 420 421 422 423 424 425 426


def get_parameter_value(para, executor):
    """
    Get the LoDTensor for the parameter

    :param para: the given parameter (must satisfy is_parameter()).
    :param executor: executor for retrieving the value.

    :return: the LoDTensor for the parameter
    """
    assert is_parameter(para)

    fetch_program = Program()
    fetched_var = _clone_var_in_block_(fetch_program.global_block(), para)
    results = executor.run(fetch_program, feed={}, fetch_list=[fetched_var])
    return results[0]


def get_parameter_value_by_name(name, executor, program=None):
    """
    Get the LoDTensor for the parameter with the given name

    :param name: the name of the parameter
    :param executor: executor for retrieving the value
    :param program: the program whose global block holds the variable.
            Default default_main_program().

    :return: the LoDTensor for the variable
    """
    source_program = default_main_program() if program is None else program
    param_var = source_program.global_block().var(name)
    return get_parameter_value(param_var, executor)