io.py 11.4 KB
Newer Older
D
dzhwinter 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
14
import os
15
import cPickle as pickle
16

Y
Yu Yang 已提交
17
from paddle.v2.fluid.framework import Program, Parameter, default_main_program, Variable
K
fix bug  
Kexin Zhao 已提交
18
from . import core
19 20

__all__ = [
    # Generic save helpers.
    'save_vars', 'save_params', 'save_persistables',
    # Generic load helpers.
    'load_vars', 'load_params', 'load_persistables',
    # Inference-model helpers.
    'save_inference_model', 'load_inference_model', 'get_inference_program',
]


def is_parameter(var):
    """Tell whether the given variable is a Parameter.

    Args:
        var: The variable to inspect.

    Returns:
        bool: True when ``var`` is an instance of ``Parameter``, False
        otherwise.
    """
    param_check = isinstance(var, Parameter)
    return param_check


def is_persistable(var):
    """Check whether the variable is persistable.

    Args:
        var: The input variable; it must expose a ``persistable`` attribute.

    Returns:
        The value of ``var.persistable``.
    """
    return getattr(var, 'persistable')


def _clone_var_in_block_(block, var):
    """Create a persistable copy of ``var``'s description inside ``block``.

    The new variable reuses ``var``'s name, shape, dtype, type and lod_level,
    but is always created with ``persistable=True``.

    :param block: block in which the copy is created (via ``create_var``)
    :param var: source Variable
    :return: whatever ``block.create_var`` returns for the copy
    """
    assert isinstance(var, Variable)
    var_attrs = dict(
        name=var.name,
        shape=var.shape,
        dtype=var.dtype,
        type=var.type,
        lod_level=var.lod_level,
        persistable=True)
    return block.create_var(**var_attrs)


62
def save_vars(executor, dirname, main_program=None, vars=None, predicate=None):
    """
    Save variables to a directory by running one `save` op per variable.

    :param executor: executor that runs the generated save program
    :param dirname: directory path; each variable goes to
        `os.path.join(dirname, <variable name>)`
    :param main_program: program whose variables are filtered with
        `predicate` when `vars` is None. Defaults to default_main_program().
    :param vars: explicit variables to save. When given, `main_program` and
        `predicate` are ignored.
    :param predicate: callable taking a variable and returning a bool; the
        variables for which it returns True are saved. It is applied with
        the built-in `filter`, so None keeps every truthy variable.
    :return: None
    :raises TypeError: if `vars` is None and `main_program` is not a Program
    """
    if vars is None:
        if main_program is None:
            main_program = default_main_program()
        if not isinstance(main_program, Program):
            raise TypeError("program should be as Program type or None")
        # Select the variables here and fall through to the save step below.
        vars = filter(predicate, main_program.list_vars())

    # Build a standalone program holding one save op per selected variable.
    save_program = Program()
    save_block = save_program.global_block()
    for var in vars:
        cloned = _clone_var_in_block_(save_block, var)
        save_block.append_op(
            type='save',
            inputs={'X': [cloned]},
            outputs={},
            attrs={'file_path': os.path.join(dirname, cloned.name)})
    executor.run(save_program)


99
def save_params(executor, dirname, main_program=None):
    """
    Save every Parameter of a program to a directory with the executor.

    :param executor: executor that runs the save ops
    :param dirname: target directory
    :param main_program: program to take the parameters from; None lets
        save_vars fall back to default_main_program()
    :return: None
    """
    save_vars(
        executor,
        dirname,
        predicate=is_parameter,
        main_program=main_program)


111
def save_persistables(executor, dirname, main_program=None):
    """
    Save every persistable variable of a program to a directory.

    :param executor: executor that runs the save ops
    :param dirname: target directory
    :param main_program: program to take the variables from; None lets
        save_vars fall back to default_main_program()
    :return: None
    """
    save_vars(
        executor,
        dirname,
        predicate=is_persistable,
        main_program=main_program)


123
def load_vars(executor, dirname, main_program=None, vars=None, predicate=None):
    """
    Load variables from a directory by running one `load` op per variable.

    :param executor: executor that runs the generated load program
    :param dirname: directory path; each variable is read from
        `os.path.join(dirname, <variable name>)`
    :param main_program: program whose variables are filtered with
        `predicate` when `vars` is None. Defaults to default_main_program().
    :param vars: explicit variables to load. When given, `main_program` and
        `predicate` are ignored.
    :param predicate: callable taking a variable and returning a bool; the
        variables for which it returns True are loaded. It is applied with
        the built-in `filter`, so None keeps every truthy variable.
    :return: None
    :raises TypeError: if `vars` is None and `main_program` is not a Program
    """
    if vars is None:
        if main_program is None:
            main_program = default_main_program()
        if not isinstance(main_program, Program):
            raise TypeError("program's type should be Program")
        # Select the variables here and fall through to the load step below.
        vars = filter(predicate, main_program.list_vars())

    # Build a standalone program holding one load op per selected variable.
    load_prog = Program()
    load_block = load_prog.global_block()
    for var in vars:
        assert isinstance(var, Variable)
        cloned = _clone_var_in_block_(load_block, var)
        load_block.append_op(
            type='load',
            inputs={},
            outputs={"Out": [cloned]},
            attrs={'file_path': os.path.join(dirname, cloned.name)})
    executor.run(load_prog)


162
def load_params(executor, dirname, main_program=None):
    """
    Load every Parameter of a program from a directory with the executor.

    :param executor: executor that runs the load ops
    :param dirname: directory holding one file per parameter
    :param main_program: program whose parameters are loaded; None lets
        load_vars fall back to default_main_program()
    :return: None
    """
    load_vars(
        executor,
        dirname=dirname,
        predicate=is_parameter,
        main_program=main_program)
171 172


173
def load_persistables(executor, dirname, main_program=None):
    """
    Load every persistable variable of a program from a directory.

    :param executor: executor that runs the load ops
    :param dirname: directory holding one file per variable
    :param main_program: program whose persistables are loaded; None lets
        load_vars fall back to default_main_program()
    :return: None
    """
    load_vars(
        executor,
        dirname=dirname,
        predicate=is_persistable,
        main_program=main_program)
182 183


184 185
def get_inference_program(target_vars, main_program=None):
    """
    Build an inference-only program from `main_program`.

    The program is pruned to what is needed to compute `target_vars`, and the
    pruned program is then passed through `inference_optimize()`.

    :param target_vars: a single Variable, or a list/tuple of Variables, that
        the inference program must compute
    :param main_program: program to prune. Defaults to default_main_program().
    :return: the pruned, inference-optimized program
    """
    if main_program is None:
        main_program = default_main_program()

    if isinstance(target_vars, tuple):
        # Treat a tuple of variables like a list instead of wrapping the whole
        # tuple as one target.
        target_vars = list(target_vars)
    elif not isinstance(target_vars, list):
        target_vars = [target_vars]

    pruned_program = main_program.prune(targets=target_vars)
    return pruned_program.inference_optimize()


K
Kexin Zhao 已提交
195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221
def prepend_feed_ops(inference_program, feeded_var_names):
    """
    Add `feed` ops for the given variable names to the front of the program.

    A persistable variable named 'feed' of type FEED_MINIBATCH is created in
    the global block. For the i-th name, a `feed` op is prepended that reads
    column `i` of that holder into the variable with that name.
    NOTE(review): if `prepend_op` inserts at index 0, the feed ops end up in
    reverse order of `feeded_var_names`; the `col` attr keeps the mapping.

    :param inference_program: program to modify in place
    :param feeded_var_names: names of existing variables in the global block
    :return: None
    """
    block = inference_program.global_block()
    feed_holder = block.create_var(
        name='feed', type=core.VarDesc.VarType.FEED_MINIBATCH, persistable=True)

    for col, var_name in enumerate(feeded_var_names):
        target_var = block.var(var_name)
        block.prepend_op(
            type='feed',
            inputs={'X': [feed_holder]},
            outputs={'Out': [target_var]},
            attrs={'col': col})


def append_fetch_ops(inference_program, fetch_var_names):
    """
    Add `fetch` ops for the given variable names to the end of the program.

    A persistable variable named 'fetch' of type FETCH_LIST is created in the
    global block. For the i-th name, a `fetch` op is appended that writes
    that variable into column `i` of the holder.

    :param inference_program: program to modify in place
    :param fetch_var_names: names of the variables to fetch
    :return: None
    """
    block = inference_program.global_block()
    fetch_holder = block.create_var(
        name='fetch', type=core.VarDesc.VarType.FETCH_LIST, persistable=True)

    for col, var_name in enumerate(fetch_var_names):
        block.append_op(
            type='fetch',
            inputs={'X': [var_name]},
            outputs={'Out': [fetch_holder]},
            attrs={'col': col})


222 223 224 225
def save_inference_model(dirname,
                         feeded_var_names,
                         target_vars,
                         executor,
                         main_program=None):
    """
    Build a model especially for inference, and save it to a directory by
    the executor.

    Files written into `dirname`:
      * `__model__`: a pickled dict with the serialized pruned program
        (without feed/fetch ops) plus the feed and fetch variable names;
      * `__model__.dat`: the serialized ProgramDesc of the same program with
        feed ops prepended and fetch ops appended;
      * the parameters of `main_program`, written by save_params.

    :param dirname: directory path; created if it does not exist
    :param feeded_var_names: a name (str) or list of names of the variables
        that need to be fed data during inference
    :param target_vars: a Variable or list of Variables from which we can get
        inference results
    :param executor: executor that saves the parameters
    :param main_program: original program, which will be pruned to build the
        inference model. Defaults to default_main_program().
    :return: None
    :raises ValueError: if `feeded_var_names` is empty or holds non-str items,
        or if `target_vars` is empty or holds non-Variable items
    """
    if isinstance(feeded_var_names, basestring):
        feeded_var_names = [feeded_var_names]
    elif not (bool(feeded_var_names) and all(
            isinstance(name, basestring) for name in feeded_var_names)):
        raise ValueError("'feed_var_names' should be a list of str.")

    if isinstance(target_vars, Variable):
        target_vars = [target_vars]
    elif not (bool(target_vars) and all(
            isinstance(var, Variable) for var in target_vars)):
        raise ValueError("'target_vars' should be a list of Variable.")

    if main_program is None:
        main_program = default_main_program()

    if not os.path.isdir(dirname):
        os.makedirs(dirname)

    pruned_program = main_program.prune(targets=target_vars)
    inference_program = pruned_program.inference_optimize()
    fetch_var_names = [v.name for v in target_vars]

    model_file_name = os.path.join(dirname, "__model__")
    # Protocol -1 selects the highest pickle protocol, which is a binary
    # format, so the file must be opened in binary mode; text mode can
    # corrupt it on platforms that translate line endings.
    with open(model_file_name, "wb") as f:
        pickle.dump({
            "program_desc_str": inference_program.desc.serialize_to_string(),
            "feed_var_names": feeded_var_names,
            "fetch_var_names": fetch_var_names
        }, f, -1)

    # Save only programDesc of inference_program in binary format
    # in another file: __model__.dat. Feed/fetch ops are added only after the
    # pickled copy above has been written.
    prepend_feed_ops(inference_program, feeded_var_names)
    append_fetch_ops(inference_program, fetch_var_names)

    with open(model_file_name + ".dat", "wb") as fp:
        fp.write(inference_program.desc.serialize_to_string())

    save_params(executor, dirname, main_program)
281 282


283
def load_persistables_if_exist(executor, dirname, main_program=None):
    """
    Load the persistable variables of a program that have a saved file.

    Only files directly inside `dirname` (the first entry yielded by
    `os.walk`) are considered. Persistable variables with no file of the same
    name are skipped.

    :param executor: executor that runs the load ops
    :param dirname: directory holding one file per saved variable
    :param main_program: program whose variables are inspected; None lets
        load_vars fall back to default_main_program()
    :return: None
    """
    _, _, top_level_files = next(os.walk(dirname))
    available = set(top_level_files)

    def _is_persistable_and_exists(var):
        if is_persistable(var):
            return var.name in available
        return False

    load_vars(
        executor,
        dirname,
        main_program=main_program,
        vars=None,
        predicate=_is_persistable_and_exists)


def load_inference_model(dirname, executor):
    """
    Load an inference model from a directory.

    Reads the pickled `__model__` file, rebuilds the program from it, and
    loads the persistable variables that have a saved file in `dirname`.

    :param dirname: directory path
    :param executor: executor that loads the persistable variables

    :return: [program, feed_var_names, fetch_vars]
             program: program especially for inference.
             feed_var_names: Names of variables that need to feed data
             fetch_vars: Variables from which we can get inference results.
    :raises ValueError: if `dirname` is not an existing directory
    """
    if not os.path.isdir(dirname):
        raise ValueError("There is no directory named '%s'" % dirname)

    model_file_name = os.path.join(dirname, "__model__")
    # NOTE(review): unpickling can execute arbitrary code; only load models
    # from trusted directories. The file holds a binary (protocol -1) pickle,
    # so it is opened in binary mode, and `with` closes the handle.
    with open(model_file_name, "rb") as f:
        model = pickle.load(f)
    program_desc_str = model["program_desc_str"]
    feed_var_names = model["feed_var_names"]
    fetch_var_names = model["fetch_var_names"]
    program = Program.parse_from_string(program_desc_str)
    load_persistables_if_exist(executor, dirname, program)
    fetch_vars = [program.global_block().var(name) for name in fetch_var_names]

    return [program, feed_var_names, fetch_vars]
X
xuwei06 已提交
326 327 328 329 330 331 332 333 334 335


def get_parameter_value(para, executor):
    """
    Get the LoDTensor value of a parameter.

    Runs a throwaway program that contains only a persistable clone of `para`,
    with that clone in the fetch list, and returns the first fetched result.

    :param para: the given parameter; must satisfy is_parameter()
    :param executor: executor for retrieving the value
    :return: the LoDTensor for the parameter
    """
    assert is_parameter(para)

    fetch_program = Program()
    cloned = _clone_var_in_block_(fetch_program.global_block(), para)
    fetched = executor.run(fetch_program, feed={}, fetch_list=[cloned])
    return fetched[0]


def get_parameter_value_by_name(name, executor, program=None):
    """
    Get the LoDTensor for the parameter with the given name.

    :param name: the name of the parameter, looked up in the program's
        global block
    :param executor: executor for retrieving the value
    :param program: the program where the variable is found.
            Default default_main_program().
    :return: the LoDTensor for the variable
    """
    source_program = default_main_program() if program is None else program
    param = source_program.global_block().var(name)
    return get_parameter_value(param, executor)