import os
import cPickle as pickle

from paddle.v2.fluid.framework import Program, Parameter, default_main_program, Variable

# Names exported by ``from <this module> import *``.
__all__ = [
    'save_vars',
    'save_params',
    'save_persistables',
    'load_vars',
    'load_params',
    'load_persistables',
    'save_inference_model',
    'load_inference_model',
    'get_inference_program',
]


def is_parameter(var):
    """Check whether the variable is a Parameter.

    This function checks whether the input variable is an instance of
    ``Parameter``.

    Args:
        var: The input variable.

    Returns:
        bool: True if ``var`` is a ``Parameter``, False otherwise.
    """
    return isinstance(var, Parameter)


def is_persistable(var):
    """Check whether the variable is marked as persistable.

    Args:
        var: The input variable; it must expose a ``persistable`` attribute.

    Returns:
        The value of ``var.persistable``, suitable for use as a predicate.
    """
    persistable_flag = var.persistable
    return persistable_flag


def _clone_var_in_block_(block, var):
    """Create a variable in ``block`` that mirrors ``var``.

    The clone copies ``name``, ``shape``, ``dtype``, ``type`` and
    ``lod_level`` from ``var``. It is always created with
    ``persistable=True``, regardless of the source variable's own flag.

    Args:
        block: Block in which the new variable is created (via
            ``block.create_var``).
        var (Variable): Variable to mirror.

    Returns:
        The variable returned by ``block.create_var``.
    """
    assert isinstance(var, Variable)
    return block.create_var(
        name=var.name,
        shape=var.shape,
        dtype=var.dtype,
        type=var.type,
        lod_level=var.lod_level,
        # NOTE(review): presumably forced so the executor keeps the value in
        # its scope, where the save/load/fetch ops built around this clone
        # read or write it -- confirm.
        persistable=True)


def save_vars(executor, dirname, main_program=None, vars=None, predicate=None):
    """
    Save variables to directory by executor.

    A fresh Program containing one 'save' op per variable is built and run
    with ``executor``.

    :param executor: executor used to run the generated save program.
    :param dirname: directory path; each variable is written to
        ``os.path.join(dirname, var.name)``.
    :param main_program: program. If vars is None, then all variables in this
        program that fit `predicate` are saved. Default default_main_program().
    :param vars: variables to be saved. If vars is specified, main_program
        and predicate are ignored.
    :param predicate: callable that takes a variable and returns a bool; only
        variables for which it returns True are saved. It is passed to the
        builtin ``filter``, so None keeps every truthy variable.
    :return: None
    :raises TypeError: if vars is None and main_program is not a Program.
    """
    if vars is None:
        if main_program is None:
            main_program = default_main_program()
        if not isinstance(main_program, Program):
            raise TypeError("program should be as Program type or None")
        # Select the variables and fall through to the saving code below.
        vars = filter(predicate, main_program.list_vars())

    save_program = Program()
    save_block = save_program.global_block()
    for each_var in vars:
        # Clone into a separate program so that running it executes only
        # the 'save' ops appended here.
        new_var = _clone_var_in_block_(save_block, each_var)
        save_block.append_op(
            type='save',
            inputs={'X': [new_var]},
            outputs={},
            attrs={'file_path': os.path.join(dirname, new_var.name)})
    executor.run(save_program)


def save_params(executor, dirname, main_program=None):
    """
    Save all parameters to directory with executor.

    Every variable of ``main_program`` for which ``is_parameter`` is True is
    saved through ``save_vars``.

    :param executor: executor used to run the save program.
    :param dirname: directory path.
    :param main_program: program whose parameters are saved. None lets
        ``save_vars`` pick its default program.
    :return: None
    """
    save_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        vars=None,
        predicate=is_parameter)


def save_persistables(executor, dirname, main_program=None):
    """
    Save all persistables to directory with executor.

    Every variable of ``main_program`` for which ``is_persistable`` is truthy
    is saved through ``save_vars``.

    :param executor: executor used to run the save program.
    :param dirname: directory path.
    :param main_program: program whose persistable variables are saved.
        None lets ``save_vars`` pick its default program.
    :return: None
    """
    save_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        vars=None,
        predicate=is_persistable)


def load_vars(executor, dirname, main_program=None, vars=None, predicate=None):
    """
    Load variables from directory by executor.

    A fresh Program containing one 'load' op per variable is built and run
    with ``executor``.

    :param executor: executor used to run the generated load program.
    :param dirname: directory path; each variable is read from
        ``os.path.join(dirname, var.name)``.
    :param main_program: program. If vars is None, then all variables in this
        program that fit `predicate` are loaded. Default default_main_program().
    :param vars: variables to be loaded. If vars is specified, main_program
        and predicate are ignored.
    :param predicate: callable that takes a variable and returns a bool; only
        variables for which it returns True are loaded. It is passed to the
        builtin ``filter``, so None keeps every truthy variable.
    :return: None
    :raises TypeError: if vars is None and main_program is not a Program.
    """
    if vars is None:
        if main_program is None:
            main_program = default_main_program()
        if not isinstance(main_program, Program):
            raise TypeError("program's type should be Program")
        # Select the variables and fall through to the loading code below.
        vars = filter(predicate, main_program.list_vars())

    load_prog = Program()
    load_block = load_prog.global_block()
    for each_var in vars:
        assert isinstance(each_var, Variable)
        # Clone into a separate program so that running it executes only
        # the 'load' ops appended here.
        new_var = _clone_var_in_block_(load_block, each_var)
        load_block.append_op(
            type='load',
            inputs={},
            outputs={"Out": [new_var]},
            attrs={'file_path': os.path.join(dirname, new_var.name)})
    executor.run(load_prog)


def load_params(executor, dirname, main_program=None):
    """
    Load all parameters from directory by executor.

    Every variable of ``main_program`` for which ``is_parameter`` is True is
    loaded through ``load_vars``.

    :param executor: executor used to run the load program.
    :param dirname: directory path.
    :param main_program: program whose parameters are loaded. None lets
        ``load_vars`` pick its default program.
    :return: None
    """
    load_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        predicate=is_parameter)


def load_persistables(executor, dirname, main_program=None):
    """
    Load all persistables from directory by executor.

    Every variable of ``main_program`` for which ``is_persistable`` is truthy
    is loaded through ``load_vars``.

    :param executor: executor used to run the load program.
    :param dirname: directory path.
    :param main_program: program whose persistable variables are loaded.
        None lets ``load_vars`` pick its default program.
    :return: None
    """
    load_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        predicate=is_persistable)


def get_inference_program(target_vars, main_program=None):
    """
    Build an inference-only program that computes ``target_vars``.

    ``main_program`` is pruned to the part needed for ``target_vars`` and the
    pruned program is then passed through ``inference_optimize()``.

    :param target_vars: a Variable, or a list/tuple of Variables, to compute.
    :param main_program: program to prune. Default default_main_program().
    :return: the program returned by ``inference_optimize()``.
    """
    if main_program is None:
        main_program = default_main_program()
    if isinstance(target_vars, tuple):
        # A tuple is a sequence of targets; wrapping it as a single element
        # would produce a nested [(v1, v2)] list.
        target_vars = list(target_vars)
    elif not isinstance(target_vars, list):
        target_vars = [target_vars]

    pruned_program = main_program.prune(targets=target_vars)
    inference_program = pruned_program.inference_optimize()
    return inference_program


def save_inference_model(dirname,
                         feeded_var_names,
                         target_vars,
                         executor,
                         main_program=None):
    """
    Build a model especially for inference,
    and save it to directory by the executor.

    Writes into ``dirname`` (created if missing):
      * ``__model__``: a pickle of a dict with keys ``program_desc_str``,
        ``feed_var_names`` and ``fetch_var_names``;
      * ``__model__.dat``: the serialized ProgramDesc of the inference program;
      * the parameters of ``main_program``, via ``save_params``.

    :param dirname: directory path
    :param feeded_var_names: name, or non-empty list of names, of variables
        that need to be fed data during inference
    :param target_vars: Variable, or non-empty list of Variables, from which
        inference results are fetched.
    :param executor: executor used to save the parameters
    :param main_program: original program, which will be pruned to build the
        inference model. Default default_main_program().
    :return: None
    :raises ValueError: if feeded_var_names or target_vars is empty or has
        elements of the wrong type.
    """
    if isinstance(feeded_var_names, basestring):
        feeded_var_names = [feeded_var_names]
    else:
        if not (bool(feeded_var_names) and all(
                isinstance(name, basestring) for name in feeded_var_names)):
            raise ValueError("'feed_var_names' should be a list of str.")

    if isinstance(target_vars, Variable):
        target_vars = [target_vars]
    else:
        if not (bool(target_vars) and all(
                isinstance(var, Variable) for var in target_vars)):
            raise ValueError("'target_vars' should be a list of Variable.")

    if main_program is None:
        main_program = default_main_program()

    if not os.path.isdir(dirname):
        os.makedirs(dirname)

    pruned_program = main_program.prune(targets=target_vars)
    inference_program = pruned_program.inference_optimize()
    fetch_var_names = [v.name for v in target_vars]
    # Serialize once; the same bytes are stored in both model files.
    program_desc_str = inference_program.desc.serialize_to_string()

    model_file_name = os.path.join(dirname, "__model__")
    # Pickle protocol -1 is a binary format, so open the file in binary
    # mode; text mode may translate newlines on some platforms (e.g. Windows).
    with open(model_file_name, "wb") as f:
        pickle.dump({
            "program_desc_str": program_desc_str,
            "feed_var_names": feeded_var_names,
            "fetch_var_names": fetch_var_names
        }, f, -1)

    # Save only programDesc of inference_program in binary format
    # in another file: __model__.dat
    with open(model_file_name + ".dat", "wb") as fp:
        fp.write(program_desc_str)

    # NOTE(review): only Parameters are saved here, although loading checks
    # every persistable variable for a matching file; persistable
    # non-parameters needed at inference time (if any) are not written --
    # confirm this is intended.
    save_params(executor, dirname, main_program)


def load_persistables_if_exist(executor, dirname, main_program=None):
    """
    Load the persistable variables of ``main_program`` that have a file in
    ``dirname``; persistable variables without a matching file are skipped.

    Only the non-directory entries directly inside ``dirname`` (the first
    level yielded by ``os.walk``) are considered.

    :param executor: executor used to run the load program.
    :param dirname: directory path.
    :param main_program: program whose variables are considered. None lets
        ``load_vars`` pick its default program.
    :return: None
    :raises ValueError: if ``dirname`` is not a directory (``os.walk`` would
        yield nothing and ``next`` would raise a bare StopIteration).
    """
    if not os.path.isdir(dirname):
        raise ValueError("There is no directory named '%s'" % dirname)
    existing_files = set(next(os.walk(dirname))[2])

    def _is_persistable_and_exist_(var):
        return is_persistable(var) and var.name in existing_files

    load_vars(
        executor,
        dirname,
        main_program=main_program,
        vars=None,
        predicate=_is_persistable_and_exist_)


def load_inference_model(dirname, executor):
    """
    Load inference model from a directory

    Reads the ``__model__`` pickle in ``dirname``, rebuilds the program from
    its serialized ProgramDesc, and loads every persistable variable of that
    program that has a file in ``dirname``.

    :param dirname: directory path
    :param executor: executor used to load the variables

    :return: [program, feed_var_names, fetch_vars]
             program: program especially for inference.
             feed_var_names: Names of variables that need to be fed data
             fetch_vars: Variables from which we can get inference results.
    :raises ValueError: if ``dirname`` is not a directory.
    """
    if not os.path.isdir(dirname):
        # %-format the message so the path actually appears in it.
        raise ValueError("There is no directory named '%s'" % dirname)

    model_file_name = os.path.join(dirname, "__model__")
    # NOTE: unpickling can execute arbitrary code; only load models from
    # trusted directories. Binary mode matches the binary pickle protocol.
    with open(model_file_name, "rb") as f:
        model = pickle.load(f)
    program_desc_str = model["program_desc_str"]
    feed_var_names = model["feed_var_names"]
    fetch_var_names = model["fetch_var_names"]
    program = Program.parse_from_string(program_desc_str)
    load_persistables_if_exist(executor, dirname, program)
    fetch_vars = [program.global_block().var(name) for name in fetch_var_names]

    return [program, feed_var_names, fetch_vars]


def get_parameter_value(para, executor):
    """
    Get the LoDTensor for the parameter

    Builds a throw-away Program holding only a clone of ``para`` and fetches
    that clone with ``executor``. NOTE(review): presumably the value comes
    from the executor's scope under ``para.name``, so the parameter must
    already be initialized or loaded there -- confirm.

    :param para: the given parameter
    :param executor: executor for retrieving the value
    :return: the LoDTensor for the parameter (first element of the fetch
        results)
    """
    assert is_parameter(para)

    get_program = Program()
    block = get_program.global_block()
    new_var = _clone_var_in_block_(block, para)
    return executor.run(get_program, feed={}, fetch_list=[new_var])[0]


def get_parameter_value_by_name(name, executor, program=None):
    """
    Get the LoDTensor for the parameter with the given name

    The variable is looked up in the global block of ``program`` and then
    passed to ``get_parameter_value``.

    :param name: the name of the parameter
    :param executor: executor for retrieving the value
    :param program: the program where the variable is found.
            Default default_main_program().
    :return: the LoDTensor for the variable
    """
    if program is None:
        program = default_main_program()
    var = program.global_block().var(name)
    return get_parameter_value(var, executor)