io.py 7.3 KB
Newer Older
1
import os
2
import cPickle as pickle
3

4
from paddle.v2.framework.framework import Program, Parameter, g_main_program, \
5 6 7 8
    Variable

# Public API of this module: variable save/load helpers and the
# inference-model (de)serialization entry points.
__all__ = [
    'save_vars',
    'save_params',
    'save_persistables',
    'load_vars',
    'load_params',
    'load_persistables',
    'save_inference_model',
    'load_inference_model',
]


def is_parameter(var):
    """Predicate for save/load helpers: True when `var` is a Parameter."""
    matches = isinstance(var, Parameter)
    return matches


def is_persistable(var):
    """Predicate for save/load helpers: returns the `persistable` flag of `var`."""
    flag = var.persistable
    return flag


def _clone_var_in_block_(block, var):
    """
    Create a variable in `block` that mirrors `var`.

    The copy takes the name, shape, data type, type and LoD level of `var`.
    It is always marked persistable, whatever `var.persistable` is.

    :param block: block in which the copy is created
    :param var: Variable to mirror
    :return: the variable returned by `block.create_var`
    """
    assert isinstance(var, Variable)
    clone_kwargs = dict(
        name=var.name,
        shape=var.shape,
        dtype=var.data_type,
        type=var.type,
        lod_level=var.lod_level,
        persistable=True)
    return block.create_var(**clone_kwargs)


def save_vars(executor, dirname, main_program=None, vars=None, predicate=None):
    """
    Save variables to directory by executor.

    A fresh Program is built with one 'save' op per selected variable. Each
    op writes to `os.path.join(dirname, <variable name>)`. The program is
    then run with `executor`.

    :param executor: executor that runs the generated save program
    :param dirname: directory path the variable files are written into
    :param main_program: program. If vars is None, all variables of this
    program that satisfy `predicate` are saved. Default g_main_program.
    :param vars: variables that need to be saved. If specified, main_program
    and predicate are ignored.
    :param predicate: callable taking a variable and returning a bool; the
    variables for which it returns True are saved. It is handed to the
    built-in `filter`, so None selects every truthy variable.
    :return: None
    :raises TypeError: if vars is None and main_program is not a Program
    """
    selected = vars
    if selected is None:
        program = g_main_program if main_program is None else main_program
        if not isinstance(program, Program):
            raise TypeError("program should be as Program type or None")
        selected = filter(predicate, program.list_vars())

    save_program = Program()
    save_block = save_program.global_block()
    for var in selected:
        cloned = _clone_var_in_block_(save_block, var)
        save_block.append_op(
            type='save',
            inputs={'X': [cloned]},
            outputs={},
            attrs={'file_path': os.path.join(dirname, cloned.name)})
    executor.run(save_program)


def save_params(executor, dirname, main_program=None):
    """
    Save all parameters to directory with executor.

    Selects the variables of `main_program` (g_main_program when None) for
    which `is_parameter` holds and delegates to `save_vars`.
    """
    options = dict(
        dirname=dirname,
        main_program=main_program,
        vars=None,
        predicate=is_parameter)
    save_vars(executor, **options)


def save_persistables(executor, dirname, main_program=None):
    """
    Save all persistables to directory with executor.

    Selects the variables of `main_program` (g_main_program when None) whose
    `persistable` flag is set and delegates to `save_vars`.
    """
    options = dict(
        dirname=dirname,
        main_program=main_program,
        vars=None,
        predicate=is_persistable)
    save_vars(executor, **options)


def load_vars(executor, dirname, main_program=None, vars=None, predicate=None):
    """
    Load variables from directory by executor.

    A fresh Program is built with one 'load' op per selected variable. Each
    op reads from `os.path.join(dirname, <variable name>)`. The program is
    then run with `executor`.

    :param executor: executor that runs the generated load program
    :param dirname: directory path the variable files are read from
    :param main_program: program. If vars is None, all variables of this
    program that satisfy `predicate` are loaded. Default g_main_program.
    :param vars: variables that need to be loaded. If specified,
    main_program and predicate are ignored.
    :param predicate: callable taking a variable and returning a bool; the
    variables for which it returns True are loaded. It is handed to the
    built-in `filter`, so None selects every truthy variable.
    :return: None
    :raises TypeError: if vars is None and main_program is not a Program
    """
    selected = vars
    if selected is None:
        program = g_main_program if main_program is None else main_program
        if not isinstance(program, Program):
            raise TypeError("program's type should be Program")
        selected = filter(predicate, program.list_vars())

    load_prog = Program()
    load_block = load_prog.global_block()
    for var in selected:
        assert isinstance(var, Variable)
        cloned = _clone_var_in_block_(load_block, var)
        load_block.append_op(
            type='load',
            inputs={},
            outputs={"Out": [cloned]},
            attrs={'file_path': os.path.join(dirname, cloned.name)})
    executor.run(load_prog)


def load_params(executor, dirname, main_program=None):
    """
    Load all parameters from directory by executor.

    Selects the variables of `main_program` (g_main_program when None) for
    which `is_parameter` holds and delegates to `load_vars`.
    """
    options = dict(
        dirname=dirname,
        main_program=main_program,
        vars=None,
        predicate=is_parameter)
    load_vars(executor, **options)


def load_persistables(executor, dirname, main_program=None):
    """
    Load all persistables from directory by executor.

    Selects the variables of `main_program` (g_main_program when None) whose
    `persistable` flag is set and delegates to `load_vars`.
    """
    options = dict(
        dirname=dirname,
        main_program=main_program,
        vars=None,
        predicate=is_persistable)
    load_vars(executor, **options)


def save_inference_model(dirname,
                         feeded_var_names,
                         target_vars,
                         executor,
                         main_program=None):
    """
    Build a model especially for inference,
    and save it to directory by the executor.

    The pruned program description, the feed variable names and the fetch
    variable names are pickled into `dirname/__model__`. The parameters of
    the original (unpruned) `main_program` are saved alongside it.

    :param dirname: directory path; created if it does not exist
    :param feeded_var_names: Names of variables that need to be fed data
    during inference
    :param target_vars: Variable, or list/tuple of Variables, from which we
    can get inference results.
    :param executor: executor that saves the inference model
    :param main_program: original program, which will be pruned to build
    the inference model. Default g_main_program.

    :return: None
    """
    if main_program is None:
        main_program = g_main_program
    # Accept a single Variable or any list/tuple of Variables. Previously a
    # tuple was wrapped as one element, which is never a valid target.
    if isinstance(target_vars, tuple):
        target_vars = list(target_vars)
    elif not isinstance(target_vars, list):
        target_vars = [target_vars]

    if not os.path.isdir(dirname):
        os.makedirs(dirname)

    pruned_program = main_program.prune(target_vars)
    fetch_var_names = [v.name for v in target_vars]

    model_file_name = os.path.join(dirname, "__model__")
    # Protocol -1 (the highest protocol) is a binary format. The file must
    # be opened in binary mode: text mode can corrupt it on platforms that
    # translate newlines, and on Python 3 it rejects bytes outright.
    with open(model_file_name, "wb") as f:
        pickle.dump({
            "program_desc_str": pruned_program.desc.serialize_to_string(),
            "feed_var_names": feeded_var_names,
            "fetch_var_names": fetch_var_names
        }, f, -1)

    save_params(executor, dirname, main_program)


def load_persistables_if_exist(executor, dirname, main_program=None):
    """
    Load the persistable variables of `main_program` that have a file in
    `dirname`. Persistables with no matching file are skipped.

    :param executor: executor that runs the load program
    :param dirname: directory path containing the variable files
    :param main_program: program whose persistable variables are loaded.
    Default g_main_program.
    :return: None
    :raises ValueError: if `dirname` is not an existing directory
    """
    if not os.path.isdir(dirname):
        # os.walk yields nothing for a missing or non-directory path, so
        # next() below would otherwise escape as a bare StopIteration.
        raise ValueError("There is no directory named '%s'" % dirname)
    # Variables are saved as dirname/<var name>. Only the non-directory
    # entries directly inside dirname are candidates.
    existing_files = set(next(os.walk(dirname))[2])

    def _is_persistable_and_exist_(var):
        return is_persistable(var) and var.name in existing_files

    load_vars(
        executor,
        dirname,
        main_program=main_program,
        vars=None,
        predicate=_is_persistable_and_exist_)


def load_inference_model(dirname, executor):
    """
    Load an inference model from a directory written by save_inference_model.

    :param dirname: directory path
    :param executor: executor that loads the model's persistable variables

    :return: [program, feed_var_names, fetch_vars]
             program: program especially for inference.
             feed_var_names: Names of variables that need to be fed data.
             fetch_vars: Variables from which we can get inference results.
    :raises ValueError: if `dirname` is not an existing directory
    """
    if not os.path.isdir(dirname):
        # Format the message; passing dirname as a second argument would
        # leave '%s' unformatted in the exception text.
        raise ValueError("There is no directory named '%s'" % dirname)

    model_file_name = os.path.join(dirname, "__model__")
    # save_inference_model pickles with protocol -1 (binary), so read in
    # binary mode. `with` closes the handle even if unpickling fails.
    # NOTE(review): unpickling can execute arbitrary code; only load model
    # directories from trusted sources.
    with open(model_file_name, "rb") as f:
        model = pickle.load(f)
    program_desc_str = model["program_desc_str"]
    feed_var_names = model["feed_var_names"]
    fetch_var_names = model["fetch_var_names"]
    program = Program.parse_from_string(program_desc_str)
    load_persistables_if_exist(executor, dirname, program)
    fetch_vars = [program.global_block().var(name) for name in fetch_var_names]

    return [program, feed_var_names, fetch_vars]