#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
import os

17
from paddle.v2.fluid.evaluator import Evaluator
Y
Yu Yang 已提交
18
from paddle.v2.fluid.framework import Program, Parameter, default_main_program, Variable
K
fix bug  
Kexin Zhao 已提交
19
from . import core
20 21

__all__ = [
22 23 24 25 26 27 28 29 30
    'save_vars',
    'save_params',
    'save_persistables',
    'load_vars',
    'load_params',
    'load_persistables',
    'save_inference_model',
    'load_inference_model',
    'get_inference_program',
31 32 33 34
]


def is_parameter(var):
    """Tell whether a variable is a Parameter.

    Args:
        var : The variable to inspect.

    Returns:
        bool: True when ``var`` is an instance of ``Parameter``,
        False otherwise.
    """
    var_is_parameter = isinstance(var, Parameter)
    return var_is_parameter


def is_persistable(var):
    """Return the ``persistable`` attribute of a variable.

    Args:
        var : The variable to inspect; it must expose ``persistable``.

    Returns:
        The value stored in ``var.persistable``, unchanged.
    """
    persistable_flag = var.persistable
    return persistable_flag


def _clone_var_in_block_(block, var):
    """Create a persistable copy of ``var``'s description inside ``block``.

    The new variable reuses the name, shape, dtype, type and lod_level of
    ``var`` and is always created with ``persistable=True``.

    Args:
        block : The block in which the new variable is created.
        var : The Variable to mirror; anything else fails the assertion.

    Returns:
        The value returned by ``block.create_var``.
    """
    assert isinstance(var, Variable)
    clone_kwargs = {
        'name': var.name,
        'shape': var.shape,
        'dtype': var.dtype,
        'type': var.type,
        'lod_level': var.lod_level,
        'persistable': True,
    }
    return block.create_var(**clone_kwargs)


63
def save_vars(executor, dirname, main_program=None, vars=None, predicate=None):
    """
    Save variables to a directory by running `save` ops on the executor.

    :param executor: executor used to run the generated save program
    :param dirname: directory path. Each `save` op receives
    `os.path.join(dirname, <variable name>)` as its `file_path` attribute.
    :param main_program: program whose variables are filtered with `predicate`
    when `vars` is None. Defaults to default_main_program().
    :param vars: variables to be saved. If given, `main_program` and
    `predicate` are ignored.
    :param predicate: callable applied to each variable of `main_program`;
    the variables for which it returns true are saved. Only used when `vars`
    is None.
    :return: None
    :raises TypeError: if `vars` is None and `main_program` is neither None
    nor a Program.
    """
    if vars is not None:
        # Explicit variable list: build a throw-away program holding one
        # `save` op per variable and run it.
        save_program = Program()
        save_block = save_program.global_block()
        for source_var in vars:
            cloned_var = _clone_var_in_block_(save_block, source_var)
            target_path = os.path.join(dirname, cloned_var.name)
            save_block.append_op(
                type='save',
                inputs={'X': [cloned_var]},
                outputs={},
                attrs={'file_path': target_path})
        executor.run(save_program)
        return

    # No explicit list: select variables from the program, then recurse with
    # the selection so the branch above does the actual work.
    if main_program is None:
        main_program = default_main_program()
    if not isinstance(main_program, Program):
        raise TypeError("program should be as Program type or None")

    selected_vars = filter(predicate, main_program.list_vars())
    save_vars(executor, dirname=dirname, vars=selected_vars)


100
def save_params(executor, dirname, main_program=None):
    """
    Save every parameter of a program to a directory.

    Delegates to `save_vars` with `is_parameter` as the predicate.

    :param executor: executor forwarded to `save_vars`
    :param dirname: directory path forwarded to `save_vars`
    :param main_program: program forwarded to `save_vars`; None is passed
    through unchanged.
    :return: None
    """
    save_vars(
        executor,
        dirname,
        main_program=main_program,
        vars=None,
        predicate=is_parameter)


112
def save_persistables(executor, dirname, main_program=None):
    """
    Save every persistable variable of a program to a directory.

    Delegates to `save_vars` with `is_persistable` as the predicate.

    :param executor: executor forwarded to `save_vars`
    :param dirname: directory path forwarded to `save_vars`
    :param main_program: program forwarded to `save_vars`; None is passed
    through unchanged.
    :return: None
    """
    save_vars(
        executor,
        dirname,
        main_program=main_program,
        vars=None,
        predicate=is_persistable)


124
def load_vars(executor, dirname, main_program=None, vars=None, predicate=None):
    """
    Load variables from a directory by running `load` ops on the executor.

    :param executor: executor used to run the generated load program
    :param dirname: directory path. Each `load` op receives
    `os.path.join(dirname, <variable name>)` as its `file_path` attribute.
    :param main_program: program whose variables are filtered with `predicate`
    when `vars` is None. Defaults to default_main_program().
    :param vars: variables to be loaded. If given, `main_program` and
    `predicate` are ignored.
    :param predicate: callable applied to each variable of `main_program`;
    the variables for which it returns true are loaded. Only used when `vars`
    is None.
    :return: None
    :raises TypeError: if `vars` is None and `main_program` is neither None
    nor a Program.
    """
    if vars is not None:
        # Explicit variable list: build a throw-away program holding one
        # `load` op per variable and run it.
        load_prog = Program()
        load_block = load_prog.global_block()
        for source_var in vars:
            assert isinstance(source_var, Variable)
            cloned_var = _clone_var_in_block_(load_block, source_var)
            source_path = os.path.join(dirname, cloned_var.name)
            load_block.append_op(
                type='load',
                inputs={},
                outputs={"Out": [cloned_var]},
                attrs={'file_path': source_path})
        executor.run(load_prog)
        return

    # No explicit list: select variables from the program, then recurse with
    # the selection so the branch above does the actual work.
    if main_program is None:
        main_program = default_main_program()
    if not isinstance(main_program, Program):
        raise TypeError("program's type should be Program")

    selected_vars = filter(predicate, main_program.list_vars())
    load_vars(executor, dirname=dirname, vars=selected_vars)


163
def load_params(executor, dirname, main_program=None):
    """
    Load every parameter of a program from a directory.

    Delegates to `load_vars` with `is_parameter` as the predicate.

    :param executor: executor forwarded to `load_vars`
    :param dirname: directory path forwarded to `load_vars`
    :param main_program: program forwarded to `load_vars`; None is passed
    through unchanged.
    :return: None
    """
    load_vars(
        executor,
        dirname,
        main_program=main_program,
        predicate=is_parameter)
172 173


174
def load_persistables(executor, dirname, main_program=None):
    """
    Load every persistable variable of a program from a directory.

    Delegates to `load_vars` with `is_persistable` as the predicate.

    :param executor: executor forwarded to `load_vars`
    :param dirname: directory path forwarded to `load_vars`
    :param main_program: program forwarded to `load_vars`; None is passed
    through unchanged.
    :return: None
    """
    load_vars(
        executor,
        dirname,
        main_program=main_program,
        predicate=is_persistable)
183 184


185 186
def get_inference_program(target_vars, main_program=None):
    """
    Build an inference program by pruning a program down to the targets.

    :param target_vars: a single target or a list of targets. Any value that
    is not a list is wrapped into a one-element list. For an Evaluator, its
    `states` and `metrics` are used as targets; any other target is used
    as is.
    :param main_program: program to prune. Defaults to default_main_program().
    :return: the result of `prune(targets=...)` followed by
    `inference_optimize()`.
    """
    if main_program is None:
        main_program = default_main_program()
    if not isinstance(target_vars, list):
        target_vars = [target_vars]

    prune_targets = []
    for target in target_vars:
        if not isinstance(target, Evaluator):
            prune_targets.append(target)
            continue
        # An evaluator contributes its state and metric variables.
        prune_targets.extend(target.states)
        prune_targets.extend(target.metrics)

    pruned = main_program.prune(targets=prune_targets)
    return pruned.inference_optimize()


202 203 204
def prepend_feed_ops(inference_program,
                     feed_target_names,
                     feed_holder_name='feed'):
    """
    Prepend one `feed` op per feed target to the program's global block.

    A persistable FEED_MINIBATCH variable named `feed_holder_name` is created
    as the shared input of every `feed` op. The op for the i-th name gets the
    attribute `col = i`. Because each op is prepended, the ops end up in the
    block in the reverse order of `feed_target_names`.

    :param inference_program: program whose global block is modified
    :param feed_target_names: names of existing variables in the global block
    that receive the fed data
    :param feed_holder_name: name of the feed holder variable
    :return: None
    """
    block = inference_program.global_block()
    holder = block.create_var(
        name=feed_holder_name,
        type=core.VarDesc.VarType.FEED_MINIBATCH,
        persistable=True)

    for column, target_name in enumerate(feed_target_names):
        target_var = block.var(target_name)
        block.prepend_op(
            type='feed',
            inputs={'X': [holder]},
            outputs={'Out': [target_var]},
            attrs={'col': column})


220 221 222
def append_fetch_ops(inference_program,
                     fetch_target_names,
                     fetch_holder_name='fetch'):
    """
    Append one `fetch` op per fetch target to the program's global block.

    A persistable FETCH_LIST variable named `fetch_holder_name` is created as
    the shared output of every `fetch` op. The op for the i-th name gets the
    attribute `col = i` and takes the name itself as its `X` input.

    :param inference_program: program whose global block is modified
    :param fetch_target_names: names of the variables to fetch
    :param fetch_holder_name: name of the fetch holder variable
    :return: None
    """
    block = inference_program.global_block()
    holder = block.create_var(
        name=fetch_holder_name,
        type=core.VarDesc.VarType.FETCH_LIST,
        persistable=True)

    for column, target_name in enumerate(fetch_target_names):
        block.append_op(
            type='fetch',
            inputs={'X': [target_name]},
            outputs={'Out': [holder]},
            attrs={'col': column})


237 238 239 240
def save_inference_model(dirname,
                         feeded_var_names,
                         target_vars,
                         executor,
                         main_program=None):
    """
    Build a model especially for inference,
    and save it to directory by the executor.

    The pruned and optimized program is serialized to the file `__model__`
    inside `dirname`, and the parameters of `main_program` are saved next to
    it with `save_params`.

    :param dirname: directory path; created if it does not exist yet
    :param feeded_var_names: name, or non-empty list of names, of the
            variables that need to be fed data during inference
    :param target_vars: Variable, or non-empty list of Variables, from which
            we can get inference results.
    :param executor: executor that saves the inference model
    :param main_program: original program, which will be pruned to build the
            inference model. Default default_main_program().

    :return: None
    :raises ValueError: if `feeded_var_names` is not a string or a non-empty
            collection of strings, or if `target_vars` is not a Variable or a
            non-empty collection of Variables.
    """
    # `basestring` only exists on Python 2; fall back to `str` so the
    # validation below does not die with NameError on Python 3.
    try:
        string_types = basestring
    except NameError:
        string_types = str

    if isinstance(feeded_var_names, string_types):
        feeded_var_names = [feeded_var_names]
    elif not (bool(feeded_var_names) and all(
            isinstance(name, string_types) for name in feeded_var_names)):
        # The message names the actual parameter so callers can find it.
        raise ValueError("'feeded_var_names' should be a list of str.")

    if isinstance(target_vars, Variable):
        target_vars = [target_vars]
    elif not (bool(target_vars) and all(
            isinstance(var, Variable) for var in target_vars)):
        raise ValueError("'target_vars' should be a list of Variable.")

    if main_program is None:
        main_program = default_main_program()

    if not os.path.isdir(dirname):
        os.makedirs(dirname)

    pruned_program = main_program.prune(targets=target_vars)
    inference_program = pruned_program.inference_optimize()
    fetch_var_names = [v.name for v in target_vars]

    # Record the feed/fetch interface inside the program itself so that it
    # can be recovered from the serialized model.
    prepend_feed_ops(inference_program, feeded_var_names)
    append_fetch_ops(inference_program, fetch_var_names)

    model_file_name = os.path.join(dirname, "__model__")
    with open(model_file_name, "wb") as f:
        f.write(inference_program.desc.serialize_to_string())

    save_params(executor, dirname, main_program)
287 288


289
def load_persistables_if_exist(executor, dirname, main_program=None):
    """
    Load the persistable variables that have a matching file in `dirname`.

    Only files directly inside `dirname` are considered; a persistable
    variable is loaded when a file with the variable's name is among them.

    :param executor: executor forwarded to `load_vars`
    :param dirname: directory path to look in and load from
    :param main_program: program forwarded to `load_vars`; None is passed
    through unchanged.
    :return: None
    """
    # The first tuple yielded by os.walk describes `dirname` itself; its
    # third element lists the plain files it contains.
    _, _, top_level_files = next(os.walk(dirname))
    existing_files = set(top_level_files)

    def _is_persistable_and_exist_(var):
        return var.name in existing_files if is_persistable(var) else False

    load_vars(
        executor,
        dirname,
        main_program=main_program,
        vars=None,
        predicate=_is_persistable_and_exist_)


307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324
def get_feed_targets_names(program):
    """
    Collect the output names of the `feed` ops in a program's global block.

    :param program: program whose global block is scanned
    :return: list holding the first `Out` name of every `feed` op, in the
    reverse of the order in which the ops appear in the block.
    """
    block_ops = program.global_block().ops
    names = [
        op.desc.output('Out')[0] for op in block_ops
        if op.desc.type() == 'feed'
    ]
    # Equivalent to inserting each name at the front while scanning.
    names.reverse()
    return names


def get_fetch_targets_names(program):
    """
    Collect the input names of the `fetch` ops in a program's global block.

    :param program: program whose global block is scanned
    :return: list holding the first `X` name of every `fetch` op, in the
    order in which the ops appear in the block.
    """
    block_ops = program.global_block().ops
    return [
        op.desc.input('X')[0] for op in block_ops
        if op.desc.type() == 'fetch'
    ]


325 326 327 328 329 330 331
def load_inference_model(dirname, executor):
    """
    Load inference model from a directory

    The program is parsed from the file `__model__` inside `dirname`; the
    persistable variables that have a file in `dirname` are then loaded with
    `load_persistables_if_exist`.

    :param dirname: directory path
    :param executor: executor that loads the inference model

    :return: [program, feed_target_names, fetch_targets]
             program: program especially for inference.
             feed_target_names: Names of variables that need to feed data
             fetch_targets: Variables from which we can get inference results.
    :raises ValueError: if `dirname` is not an existing directory.
    """
    if not os.path.isdir(dirname):
        # Interpolate the name into the message; passing it as a second
        # argument would leave the '%s' placeholder unformatted.
        raise ValueError("There is no directory named '%s'" % (dirname, ))

    model_file_name = os.path.join(dirname, "__model__")
    with open(model_file_name, "rb") as f:
        program_desc_str = f.read()

    program = Program.parse_from_string(program_desc_str)
    load_persistables_if_exist(executor, dirname, program)

    # The feed/fetch interface is recovered from the ops stored in the
    # program itself.
    feed_target_names = get_feed_targets_names(program)
    fetch_target_names = get_fetch_targets_names(program)
    fetch_targets = [
        program.global_block().var(name) for name in fetch_target_names
    ]

    return [program, feed_target_names, fetch_targets]
X
xuwei06 已提交
354 355 356 357 358 359 360 361 362 363


def get_parameter_value(para, executor):
    """
    Get the LoDTensor for the parameter

    A fresh program containing only a clone of `para` is run with an empty
    feed, and the first fetched result is returned.

    :param para: the given parameter; anything that is not a Parameter fails
    the assertion
    :param executor: executor for retrieving the value
    :return: the LoDTensor for the parameter
    """
    assert is_parameter(para)

    fetch_program = Program()
    cloned_para = _clone_var_in_block_(fetch_program.global_block(), para)
    fetched = executor.run(fetch_program, feed={}, fetch_list=[cloned_para])
    return fetched[0]


def get_parameter_value_by_name(name, executor, program=None):
    """
    Get the LoDTensor for the parameter with the given name

    The variable is looked up by name in the global block of `program` and
    passed to `get_parameter_value`.

    :param name: the name of the parameter
    :param executor: executor for retrieving the value
    :param program: the program where the variable is found.
            Default default_main_program().
    :return: the LoDTensor for the variable
    """
    source_program = default_main_program() if program is None else program
    parameter = source_program.global_block().var(name)
    return get_parameter_value(parameter, executor)