# io.py
import os
import cPickle as pickle

from paddle.v2.fluid.framework import Program, Parameter, g_main_program, \
    Variable

# Public API exported by `from <module> import *`.
__all__ = [
    'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
    'load_persistables', "save_inference_model", "load_inference_model"
]


def is_parameter(var):
    """Return True when `var` is an instance of `Parameter`, else False."""
    return True if isinstance(var, Parameter) else False


def is_persistable(var):
    """Return the `persistable` attribute of `var` as-is."""
    flag = var.persistable
    return flag


def _clone_var_in_block_(block, var):
    """
    Create a copy of `var`'s description inside `block`.

    The copy takes the name, shape, data type, variable type and LoD level
    of `var`. It is always created with `persistable=True`, whatever the
    original's flag.

    :param block: block to create the variable in (via `block.create_var`)
    :param var: the Variable to copy; must be a `Variable` instance
    :return: whatever `block.create_var` returns
    """
    assert isinstance(var, Variable)
    clone_attrs = dict(
        name=var.name,
        shape=var.shape,
        dtype=var.data_type,
        type=var.type,
        lod_level=var.lod_level,
        persistable=True)
    return block.create_var(**clone_attrs)


def save_vars(executor, dirname, main_program=None, vars=None, predicate=None):
    """
    Save variables to a directory by running `save` ops with the executor.

    :param executor: executor used to run the generated save program
    :param dirname: directory path; each variable is written to
        `os.path.join(dirname, var.name)`
    :param main_program: program. If `vars` is None, every variable in this
        program for which `predicate` returns True is saved.
        Default g_main_program.
    :param predicate: callable that takes a variable and returns a bool. If
        it returns True, the variable is saved. When None, `filter` keeps
        every truthy variable.
    :param vars: variables to be saved. If specified, `main_program` and
        `predicate` are ignored.
    :return: None
    :raises TypeError: if `vars` is None and `main_program` is not a Program.
    """
    if vars is None:
        if main_program is None:
            main_program = g_main_program
        if not isinstance(main_program, Program):
            raise TypeError("program should be as Program type or None")

        save_vars(
            executor,
            dirname=dirname,
            vars=filter(predicate, main_program.list_vars()))
    else:
        # Build a separate program holding one `save` op per variable; the
        # clones keep the original names (presumably so the executor finds
        # the existing values by name in its scope -- confirm).
        save_program = Program()
        save_block = save_program.global_block()
        for each_var in vars:
            new_var = _clone_var_in_block_(save_block, each_var)
            save_block.append_op(
                type='save',
                inputs={'X': [new_var]},
                outputs={},
                attrs={'file_path': os.path.join(dirname, new_var.name)})
        executor.run(save_program)


def save_params(executor, dirname, main_program=None):
    """
    Save every Parameter of `main_program` into `dirname` with `executor`.

    Delegates to `save_vars`, filtering the program's variables with
    `is_parameter`. When `main_program` is None, `save_vars` falls back to
    g_main_program.
    """
    selector = is_parameter
    save_vars(executor, dirname=dirname, main_program=main_program,
              vars=None, predicate=selector)


def save_persistables(executor, dirname, main_program=None):
    """
    Save every persistable variable of `main_program` into `dirname`.

    Delegates to `save_vars`, filtering the program's variables with
    `is_persistable`. When `main_program` is None, `save_vars` falls back to
    g_main_program.
    """
    selector = is_persistable
    save_vars(executor, dirname=dirname, main_program=main_program,
              vars=None, predicate=selector)


def load_vars(executor, dirname, main_program=None, vars=None, predicate=None):
    """
    Load variables from a directory by running `load` ops with the executor.

    :param executor: executor used to run the generated load program
    :param dirname: directory path; each variable is read from
        `os.path.join(dirname, var.name)`
    :param main_program: program. If `vars` is None, every variable in this
        program for which `predicate` returns True is loaded.
        Default g_main_program.
    :param predicate: callable that takes a variable and returns a bool. If
        it returns True, the variable is loaded. When None, `filter` keeps
        every truthy variable.
    :param vars: variables to be loaded. If specified, `main_program` and
        `predicate` are ignored.
    :return: None
    :raises TypeError: if `vars` is None and `main_program` is not a Program.
    """
    if vars is None:
        if main_program is None:
            main_program = g_main_program
        if not isinstance(main_program, Program):
            raise TypeError("program's type should be Program")

        load_vars(
            executor,
            dirname=dirname,
            vars=filter(predicate, main_program.list_vars()))
    else:
        # Build a separate program holding one `load` op per variable; the
        # clones keep the original names (presumably so the executor stores
        # the loaded values under those names in its scope -- confirm).
        load_prog = Program()
        load_block = load_prog.global_block()
        for each_var in vars:
            assert isinstance(each_var, Variable)
            new_var = _clone_var_in_block_(load_block, each_var)
            load_block.append_op(
                type='load',
                inputs={},
                outputs={"Out": [new_var]},
                attrs={'file_path': os.path.join(dirname, new_var.name)})

        executor.run(load_prog)


def load_params(executor, dirname, main_program=None):
    """
    Load every Parameter of `main_program` from `dirname` with `executor`.

    Delegates to `load_vars`, filtering the program's variables with
    `is_parameter`. When `main_program` is None, `load_vars` falls back to
    g_main_program.
    """
    selector = is_parameter
    load_vars(executor, dirname=dirname, main_program=main_program,
              vars=None, predicate=selector)


def load_persistables(executor, dirname, main_program=None):
    """
    Load every persistable variable of `main_program` from `dirname`.

    Delegates to `load_vars`, filtering the program's variables with
    `is_persistable`. When `main_program` is None, `load_vars` falls back to
    g_main_program.
    """
    selector = is_persistable
    load_vars(executor, dirname=dirname, main_program=main_program,
              vars=None, predicate=selector)


def save_inference_model(dirname,
                         feeded_var_names,
                         target_vars,
                         executor,
                         main_program=None):
    """
    Build a model especially for inference, and save it to a directory by
    the executor.

    The pruned program description and the feed/fetch variable names are
    pickled into `<dirname>/__model__`. The parameters of `main_program`
    are then saved into the same directory via `save_params`.

    :param dirname: directory path; created if it does not exist
    :param feeded_var_names: Names of variables that need to be fed data
        during inference
    :param target_vars: Variable, or list of Variables, from which we can get
        inference results.
    :param executor: executor that saves the inference model
    :param main_program: original program, which will be pruned to build the
        inference model. Default g_main_program.

    :return: None
    """
    if main_program is None:
        main_program = g_main_program
    if not isinstance(target_vars, list):
        target_vars = [target_vars]

    if not os.path.isdir(dirname):
        os.makedirs(dirname)

    pruned_program = main_program.prune(target_vars)
    fetch_var_names = [v.name for v in target_vars]

    model_file_name = os.path.join(dirname, "__model__")
    # Protocol -1 selects the highest pickle protocol, which is binary, so
    # the file must be opened in binary mode: text mode corrupts it on
    # Windows and is rejected outright by Python 3's pickle.
    with open(model_file_name, "wb") as f:
        pickle.dump({
            "program_desc_str": pruned_program.desc.serialize_to_string(),
            "feed_var_names": feeded_var_names,
            "fetch_var_names": fetch_var_names
        }, f, -1)

    # NOTE: saves the parameters of the original (unpruned) main_program.
    save_params(executor, dirname, main_program)


def load_persistables_if_exist(executor, dirname, main_program=None):
    """
    Load the persistable variables of `main_program` that have a file in
    `dirname`.

    Only the non-directory entries directly inside `dirname` (the first
    tuple yielded by `os.walk`) are considered. A persistable variable is
    loaded when an entry with exactly its name exists; otherwise it is
    skipped.

    NOTE(review): if `dirname` cannot be walked, os.walk yields nothing and
    `next` raises StopIteration.
    """
    _, _, present_files = next(os.walk(dirname))
    present = frozenset(present_files)

    def _persistable_with_file(var):
        # Select only persistable variables that are backed by a file.
        return var.name in present if is_persistable(var) else False

    load_vars(
        executor,
        dirname,
        main_program=main_program,
        vars=None,
        predicate=_persistable_with_file)


def load_inference_model(dirname, executor):
    """
    Load an inference model from a directory written by
    save_inference_model.

    :param dirname: directory path
    :param executor: executor that loads the inference model

    :return: [program, feed_var_names, fetch_vars]
             program: program especially for inference.
             feed_var_names: Names of variables that need to feed data
             fetch_vars: Variables from which we can get inference results.
    :raises ValueError: if `dirname` is not an existing directory.
    """
    if not os.path.isdir(dirname):
        raise ValueError("There is no directory named '%s'" % dirname)

    model_file_name = os.path.join(dirname, "__model__")
    # The model file is a binary pickle, so read it in binary mode and close
    # it deterministically.
    # NOTE(review): unpickling can execute arbitrary code; only load models
    # from trusted directories.
    with open(model_file_name, "rb") as f:
        model = pickle.load(f)
    program_desc_str = model["program_desc_str"]
    feed_var_names = model["feed_var_names"]
    fetch_var_names = model["fetch_var_names"]
    program = Program.parse_from_string(program_desc_str)
    load_persistables_if_exist(executor, dirname, program)
    fetch_vars = [program.global_block().var(name) for name in fetch_var_names]

    return [program, feed_var_names, fetch_vars]


def get_parameter_value(para, executor):
    """
    Fetch the LoDTensor holding the value of a parameter.

    A throwaway program containing only a clone of `para` is built, and
    `executor` runs it with an empty feed, fetching that clone.

    :param para: the given parameter
    :param executor: executor used to retrieve the value
    :return: the first fetched result, i.e. the LoDTensor for the parameter
    """
    fetch_program = Program()
    fetch_var = _clone_var_in_block_(fetch_program.global_block(), para)
    results = executor.run(fetch_program, feed={}, fetch_list=[fetch_var])
    return results[0]


def get_parameter_value_by_name(name, executor, program=None):
    """
    Fetch the LoDTensor for the parameter with the given name.

    :param name: the name of the parameter
    :param executor: executor used to retrieve the value
    :param program: program whose global block holds the parameter.
        Default g_main_program.
    :return: the LoDTensor for the parameter (via get_parameter_value)
    :raises AssertionError: if the named variable is not a Parameter (only
        checked when assertions are enabled).
    """
    source_program = g_main_program if program is None else program
    param = source_program.global_block().var(name)
    assert is_parameter(param)
    return get_parameter_value(param, executor)