#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from paddle.fluid.evaluator import Evaluator
from paddle.fluid.framework import Program, Parameter, default_main_program, Variable
from . import core

__all__ = [
    'save_vars',
    'save_params',
    'save_persistables',
    'load_vars',
    'load_params',
    'load_persistables',
    'save_inference_model',
    'load_inference_model',
    'get_inference_program',
]


def is_parameter(var):
    """Check whether the variable is a Parameter.

    This function checks whether the input variable is a Parameter.

    Args:
        var : The input variable.

    Returns:
        bool: True if the variable is a Parameter, False otherwise.
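
    Examples:
        A minimal sketch (the `fluid` import and `prog` below are only
        illustrative; any Variable obtained from a Program works):

        .. code-block:: python

            import paddle.fluid as fluid

            prog = fluid.default_main_program()
            for var in prog.list_vars():
                print("%s: %s" % (var.name, fluid.io.is_parameter(var)))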
    """
    return isinstance(var, Parameter)


def is_persistable(var):
    if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
            var.desc.type() == core.VarDesc.VarType.FETCH_LIST:
        return False
    return var.persistable


def _clone_var_in_block_(block, var):
    assert isinstance(var, Variable)
    return block.create_var(
        name=var.name,
        shape=var.shape,
        dtype=var.dtype,
        type=var.type,
        lod_level=var.lod_level,
        persistable=True)


def save_vars(executor,
              dirname,
              main_program=None,
              vars=None,
              predicate=None,
              filename=None):
    """
    Save variables to a directory with the given executor.

    :param executor: the executor that runs the save operations
    :param dirname: directory path
    :param main_program: the program whose variables are saved. If vars is None,
        all variables in this program that match `predicate` are saved.
        Default: default_main_program().
    :param predicate: a callable that takes a variable and returns a bool. If it
        returns True, the corresponding variable will be saved.
    :param vars: the variables to be saved. If vars is specified, main_program
        and predicate are ignored.
    :param filename: the name of a single file that all variables are saved to.
        If it is None, each variable is saved to a separate file.

    :return: None
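
    Example (a minimal sketch; the executor, place and directory below are
    illustrative assumptions, and the default main program is assumed to have
    been built and initialized already):

    .. code-block:: python

        import paddle.fluid as fluid

        exe = fluid.Executor(fluid.CPUPlace())
        # Save every parameter of the default main program to ./my_model,
        # one file per variable.
        fluid.io.save_vars(executor=exe,
                           dirname="./my_model",
                           main_program=None,
                           vars=None,
                           predicate=fluid.io.is_parameter,
                           filename=None)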
    """
    if vars is None:
        if main_program is None:
            main_program = default_main_program()
        if not isinstance(main_program, Program):
            raise TypeError("main_program should be an instance of Program or None")

        save_vars(
            executor,
            dirname=dirname,
            vars=filter(predicate, main_program.list_vars()),
            filename=filename)
    else:
        save_program = Program()
        save_block = save_program.global_block()

        save_var_map = {}
        for each_var in vars:
            # NOTE: don't save variables whose type is RAW
            if each_var.type == core.VarDesc.VarType.RAW:
                continue
            new_var = _clone_var_in_block_(save_block, each_var)
            if filename is None:
                save_block.append_op(
                    type='save',
                    inputs={'X': [new_var]},
                    outputs={},
                    attrs={'file_path': os.path.join(dirname, new_var.name)})
            else:
                save_var_map[new_var.name] = new_var

        if filename is not None:
            save_var_list = []
            for name in sorted(save_var_map.keys()):
                save_var_list.append(save_var_map[name])

            save_block.append_op(
                type='save_combine',
                inputs={'X': save_var_list},
                outputs={},
                attrs={'file_path': os.path.join(dirname, filename)})

        executor.run(save_program)


def save_params(executor, dirname, main_program=None, filename=None):
    """
    Save all parameters of the program to the directory with the executor.
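
    Example (a minimal sketch; the executor, place and directory name below
    are illustrative assumptions, and the default main program is assumed to
    have been built and initialized already):

    .. code-block:: python

        import paddle.fluid as fluid

        exe = fluid.Executor(fluid.CPUPlace())
        # Writes one file per parameter of the default main program
        # under ./params_dir.
        fluid.io.save_params(executor=exe, dirname="./params_dir")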
    """
    save_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        vars=None,
        predicate=is_parameter,
        filename=filename)


def save_persistables(executor, dirname, main_program=None, filename=None):
    """
    Save all persistable variables of the program to the directory with the executor.
    """
    save_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        vars=None,
        predicate=is_persistable,
        filename=filename)


def load_vars(executor,
              dirname,
              main_program=None,
              vars=None,
              predicate=None,
              filename=None):
    """
    Load variables from a directory with the given executor.

    :param executor: the executor that runs the load operations
    :param dirname: directory path
    :param main_program: the program whose variables are loaded. If vars is None,
        all variables in this program that match `predicate` are loaded.
        Default: default_main_program().
    :param predicate: a callable that takes a variable and returns a bool. If it
        returns True, the corresponding variable will be loaded.
    :param vars: the variables to be loaded. If vars is specified, main_program
        and predicate are ignored.
    :param filename: the name of the single file that all variables are loaded from.
        If it is None, each variable is loaded from a separate file.

    :return: None
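
    Example (a minimal sketch; the executor and directory are illustrative
    assumptions and should point at files written earlier by save_vars):

    .. code-block:: python

        import paddle.fluid as fluid

        exe = fluid.Executor(fluid.CPUPlace())
        # Load the previously saved parameters of the default main program,
        # one file per variable.
        fluid.io.load_vars(executor=exe,
                           dirname="./my_model",
                           predicate=fluid.io.is_parameter)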
    """
    if vars is None:
        if main_program is None:
            main_program = default_main_program()
        if not isinstance(main_program, Program):
            raise TypeError("main_program should be an instance of Program")

        load_vars(
            executor,
            dirname=dirname,
            vars=filter(predicate, main_program.list_vars()),
            filename=filename)
    else:
        load_prog = Program()
        load_block = load_prog.global_block()

        load_var_map = {}
        for each_var in vars:
            assert isinstance(each_var, Variable)
            new_var = _clone_var_in_block_(load_block, each_var)
            if filename is None:
                load_block.append_op(
                    type='load',
                    inputs={},
                    outputs={'Out': [new_var]},
                    attrs={'file_path': os.path.join(dirname, new_var.name)})
            else:
                load_var_map[new_var.name] = new_var

        if filename is not None:
            load_var_list = []
            for name in sorted(load_var_map.keys()):
                load_var_list.append(load_var_map[name])

            load_block.append_op(
                type='load_combine',
                inputs={},
                outputs={"Out": load_var_list},
                attrs={'file_path': os.path.join(dirname, filename)})

        executor.run(load_prog)


def load_params(executor, dirname, main_program=None, filename=None):
    """
    Load all parameters of the program from the directory with the executor.
    """
    load_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        predicate=is_parameter,
        filename=filename)


def load_persistables(executor, dirname, main_program=None, filename=None):
    """
    Load all persistable variables of the program from the directory with the executor.
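
    Example (a minimal sketch; the executor and directory are illustrative
    assumptions, e.g. a directory produced earlier by save_persistables):

    .. code-block:: python

        import paddle.fluid as fluid

        exe = fluid.Executor(fluid.CPUPlace())
        # Restore every persistable variable of the default main program.
        fluid.io.load_persistables(executor=exe, dirname="./persist_dir")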
    """
    load_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        predicate=is_persistable,
        filename=filename)


def get_inference_program(target_vars, main_program=None):
    if main_program is None:
            main_program = default_main_program()
    if not isinstance(target_vars, list):
        target_vars = [target_vars]
    vars = []
    for var in target_vars:
        if isinstance(var, Evaluator):
            vars.extend(var.states)
            vars.extend(var.metrics)
        else:
            vars.append(var)
    pruned_program = main_program.prune(targets=vars)
    inference_program = pruned_program.inference_optimize()
    return inference_program


def prepend_feed_ops(inference_program,
                     feed_target_names,
                     feed_holder_name='feed'):
    if len(feed_target_names) == 0:
        return

    global_block = inference_program.global_block()
    feed_var = global_block.create_var(
        name=feed_holder_name,
        type=core.VarDesc.VarType.FEED_MINIBATCH,
        persistable=True)

    for i, name in enumerate(feed_target_names):
        out = global_block.var(name)
        global_block.prepend_op(
            type='feed',
            inputs={'X': [feed_var]},
            outputs={'Out': [out]},
            attrs={'col': i})


def append_fetch_ops(inference_program,
                     fetch_target_names,
                     fetch_holder_name='fetch'):
    global_block = inference_program.global_block()
    fetch_var = global_block.create_var(
        name=fetch_holder_name,
        type=core.VarDesc.VarType.FETCH_LIST,
        persistable=True)

    for i, name in enumerate(fetch_target_names):
        global_block.append_op(
            type='fetch',
            inputs={'X': [name]},
            outputs={'Out': [fetch_var]},
            attrs={'col': i})


def save_inference_model(dirname,
                         feeded_var_names,
                         target_vars,
                         executor,
                         main_program=None,
                         model_filename=None,
                         params_filename=None):
    """
    Build a model specifically for inference,
    and save it to the directory with the executor.

    :param dirname: directory path
    :param feeded_var_names: names of the variables that need to be fed data during inference
    :param target_vars: variables from which we can get the inference results
    :param executor: the executor that saves the inference model
    :param main_program: the original program, which will be pruned to build the
            inference model. Default: default_main_program().
    :param model_filename: the name of the file to save the inference program to.
        If not specified, the default filename `__model__` will be used.
    :param params_filename: the name of the file to save the parameters to.
        It is used for the case that all parameters are saved in a single binary file.
        If not specified, parameters are saved in separate files.

    :return: None
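
    Example (a minimal sketch; the network, executor and paths below are
    illustrative assumptions, not part of this module):

    .. code-block:: python

        import paddle.fluid as fluid

        exe = fluid.Executor(fluid.CPUPlace())
        # `image` and `predict` are assumed to be the input name and the
        # output variable of a network that was already trained with `exe`.
        fluid.io.save_inference_model(dirname="./infer_model",
                                      feeded_var_names=['image'],
                                      target_vars=[predict],
                                      executor=exe)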
    """
    if isinstance(feeded_var_names, basestring):
        feeded_var_names = [feeded_var_names]
    else:
        if len(feeded_var_names) > 0:
            if not (bool(feeded_var_names) and all(
                    isinstance(name, basestring) for name in feeded_var_names)):
                raise ValueError("'feed_var_names' should be a list of str.")

    if isinstance(target_vars, Variable):
        target_vars = [target_vars]
    else:
        if not (bool(target_vars) and all(
                isinstance(var, Variable) for var in target_vars)):
            raise ValueError("'target_vars' should be a list of Variable.")

    if main_program is None:
        main_program = default_main_program()
    copy_program = main_program.clone()

    if not os.path.isdir(dirname):
        os.makedirs(dirname)

    # Clear the is_target information and remove the existing feed and fetch ops
    global_block = copy_program.global_block()
    # Iterate in reverse so that removing an op does not shift the indices
    # of the ops that have not been visited yet.
    for i in range(len(global_block.ops) - 1, -1, -1):
        op = global_block.ops[i]
        op.desc.set_is_target(False)
        if op.type == "feed" or op.type == "fetch":
            global_block.remove_op(i)
    copy_program.desc.flush()

    pruned_program = copy_program.prune(targets=target_vars)
    inference_program = pruned_program.inference_optimize()
    fetch_var_names = [v.name for v in target_vars]

    prepend_feed_ops(inference_program, feeded_var_names)
    append_fetch_ops(inference_program, fetch_var_names)

    if model_filename is not None:
        model_filename = os.path.basename(model_filename)
    else:
        model_filename = "__model__"
    model_filename = os.path.join(dirname, model_filename)

    if params_filename is not None:
        params_filename = os.path.basename(params_filename)

    with open(model_filename, "wb") as f:
        f.write(inference_program.desc.serialize_to_string())

    save_persistables(executor, dirname, inference_program, params_filename)


def load_inference_model(dirname,
                         executor,
                         model_filename=None,
                         params_filename=None):
    """
    Load the inference model from a directory.

    :param dirname: directory path
    :param executor: the executor that loads the inference model
    :param model_filename: the name of the file to load the inference program from.
        If not specified, the default filename `__model__` will be used.
    :param params_filename: the name of the file to load the parameters from.
        It is used for the case that all parameters are saved in a single binary file.
        If not specified, parameters are considered saved in separate files.

    :return: [program, feed_target_names, fetch_targets]
             program: the program especially built for inference.
             feed_target_names: Names of variables that need to be fed data
             fetch_targets: Variables from which we can get inference results.
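
    Example (a minimal sketch; the executor, directory and input shape are
    illustrative assumptions; the directory is expected to have been written
    by save_inference_model):

    .. code-block:: python

        import numpy
        import paddle.fluid as fluid

        exe = fluid.Executor(fluid.CPUPlace())
        [program, feed_names, fetch_targets] = fluid.io.load_inference_model(
            dirname="./infer_model", executor=exe)
        # The feed data must match the shape/dtype the model was trained
        # with; (1, 784) is only an example.
        data = numpy.random.random(size=(1, 784)).astype('float32')
        results = exe.run(program,
                          feed={feed_names[0]: data},
                          fetch_list=fetch_targets)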
    """
    if not os.path.isdir(dirname):
        raise ValueError("There is no directory named '%s'" % dirname)

    if model_filename is not None:
        model_filename = os.path.basename(model_filename)
    else:
        model_filename = "__model__"
    model_filename = os.path.join(dirname, model_filename)

    if params_filename is not None:
        params_filename = os.path.basename(params_filename)

    with open(model_filename, "rb") as f:
        program_desc_str = f.read()

    program = Program.parse_from_string(program_desc_str)
    load_persistables(executor, dirname, program, params_filename)

    feed_target_names = program.desc.get_feed_target_names()
    fetch_target_names = program.desc.get_fetch_target_names()
    fetch_targets = [
        program.global_block().var(name) for name in fetch_target_names
    ]

    return [program, feed_target_names, fetch_targets]


def get_parameter_value(para, executor):
    """
    Get the LoDTensor for the parameter

    :param executor: executor for retrieving the value
    :param para: the given parameter

    :return: the LoDTensor for the parameter
    """
    assert is_parameter(para)

    get_program = Program()
    block = get_program.global_block()
    new_var = _clone_var_in_block_(block, para)
    return executor.run(get_program, feed={}, fetch_list=[new_var])[0]


def get_parameter_value_by_name(name, executor, program=None):
    """
    Get the LoDTensor for the parameter with the given name

    :param executor: executor for retrieving the value
    :param name: the name of the parameter
    :param program: the program where the variable is found
            Default default_main_program().

    :return: the LoDTensor for the variable
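
    Example (a minimal sketch; the executor and the parameter name "fc_0.w_0"
    are illustrative assumptions; use the name of a parameter that actually
    exists in your program):

    .. code-block:: python

        import paddle.fluid as fluid

        exe = fluid.Executor(fluid.CPUPlace())
        tensor = fluid.io.get_parameter_value_by_name("fc_0.w_0", exe)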
    """
    if program is None:
        program = default_main_program()
    var = program.global_block().var(name)
    return get_parameter_value(var, executor)


SUCCESS = "_SUCCESS"


def save_checkpoint(executor,
                    dirname,
                    keep_max=10,
                    save_secs=600,
                    main_program=None):
    """
    Save variables to a checkpoint directory.

    :param executor: the executor that saves the variables
    :param dirname: root directory that holds the serial-numbered checkpoint directories
    :param keep_max: maximum number of checkpoints to keep
    :param save_secs: interval in seconds between two checkpoints
    :param main_program: the program whose persistable variables are saved
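
    Example (a minimal sketch; the executor and directory below are
    illustrative assumptions):

    .. code-block:: python

        import paddle.fluid as fluid

        exe = fluid.Executor(fluid.CPUPlace())
        # Writes serial-numbered checkpoints (0, 1, 2, ...) under
        # ./checkpoints; restore_checkpoint reloads the latest one.
        fluid.io.save_checkpoint(executor=exe, dirname="./checkpoints")
        fluid.io.restore_checkpoint(dirname="./checkpoints", executor=exe)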
    """
    if dirname is None:
        raise Exception("save checkpoint dir can not be None")

    if not os.path.isdir(dirname):
        os.makedirs(dirname)
    serial = _get_lastest_checkpoint_dir(dirname) + 1

    cur_dir = os.path.join(dirname, str(serial))
    save_persistables(executor, cur_dir, main_program)
    _write_success(cur_dir)


def restore_checkpoint(dirname, executor, main_program=None):
    """
    Load variables from the latest checkpoint directory.

    :param dirname: root directory that holds the serial-numbered checkpoint directories
    :param executor: the executor that loads the variables
    :param main_program: the program whose persistable variables are restored
    """
    if dirname is None or not os.path.isdir(dirname):
        raise Exception("restore checkpoint can not load variables from %s" %
                        dirname)
    serial = _get_lastest_checkpoint_dir(dirname)

    if serial < 0:
        return
    cur_dir = os.path.join(dirname, str(serial))
    load_persistables(executor, cur_dir, main_program)


def _write_success(dirname):
    """
    Write an empty _SUCCESS marker file into the given checkpoint directory.
    """
    success_file = os.path.join(dirname, SUCCESS)
    with open(success_file, 'a'):
        pass


def _get_lastest_checkpoint_dir(checkpoint_dir):
    """
    Get the biggest checkpoint serial number in checkpoint_dir whose directory contains _SUCCESS.
    """
    if not checkpoint_dir.strip():
        return -1

    def has_success(checkpoint_dir, cur_dir):
        """
        is _SUCCESS in this dir
        """
        if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
            return -1

        try:
            int(cur_dir)
        except ValueError:
            return -1

        success_path = os.path.join(checkpoint_dir, cur_dir, SUCCESS)
        if os.path.isfile(success_path):
            return int(cur_dir)

    if not os.path.isdir(checkpoint_dir):
        return -1

    current_dir = -1
    dirs = os.listdir(checkpoint_dir)
    for cur_dir in dirs:
        success_num = has_success(checkpoint_dir, cur_dir)
        if success_num > current_dir:
            current_dir = success_num
    return current_dir