#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os
T
tangwei12 已提交
16 17
import time
import shutil
18

19 20
from paddle.fluid.evaluator import Evaluator
from paddle.fluid.framework import Program, Parameter, default_main_program, Variable
K
fix bug  
Kexin Zhao 已提交
21
from . import core
22 23

# Public API of this module: save/load for variables, parameters and
# persistables, inference-model export/import, and training checkpoints.
__all__ = [
    'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
    'load_persistables', 'save_inference_model', 'load_inference_model',
    'get_inference_program', 'save_checkpoint', 'restore_checkpoint'
]


def is_parameter(var):
    """Tell whether the given variable is a ``Parameter``.

    Args:
        var : The variable to inspect.

    Returns:
        True when ``var`` is an instance of ``Parameter``, False otherwise.
    """
    return isinstance(var, Parameter)


def is_persistable(var):
    """Tell whether the variable should be persisted by save/load ops.

    Feed/fetch holder variables are never persistable; for everything else
    the variable's own ``persistable`` flag decides.
    """
    holder_types = (core.VarDesc.VarType.FEED_MINIBATCH,
                    core.VarDesc.VarType.FETCH_LIST)
    if var.desc.type() in holder_types:
        return False
    return var.persistable


def _clone_var_in_block_(block, var):
    """Create a variable in ``block`` that mirrors ``var``.

    The clone copies name, shape, dtype, type and LoD level, and is always
    marked persistable so the generated save/load ops can reference it.
    """
    assert isinstance(var, Variable)
    clone = block.create_var(
        name=var.name,
        shape=var.shape,
        dtype=var.dtype,
        type=var.type,
        lod_level=var.lod_level,
        persistable=True)
    return clone


62 63 64 65 66
def save_vars(executor,
              dirname,
              main_program=None,
              vars=None,
              predicate=None,
              filename=None):
    """
    Save variables to ``dirname`` by running a generated save program.

    :param executor: executor that runs the generated save program
    :param dirname: directory path the variables are written into
    :param main_program: program whose variables are filtered with
        ``predicate`` when ``vars`` is None. Defaults to
        ``default_main_program()``.
    :param predicate: callable mapping a variable to bool; a variable is
        saved when it returns True. Only used when ``vars`` is None.
    :param vars: explicit variables to save; when given, ``main_program``
        and ``predicate`` are ignored.
    :param filename: name of a single file all variables are combined into.
        When None, each variable is saved to its own file.

    :return: None
    """
    if vars is None:
        prog = main_program if main_program is not None else default_main_program()
        if not isinstance(prog, Program):
            raise TypeError("program should be as Program type or None")
        save_vars(
            executor,
            dirname=dirname,
            vars=filter(predicate, prog.list_vars()),
            filename=filename)
        return

    save_program = Program()
    save_block = save_program.global_block()

    combined = {}
    for var in vars:
        # NOTE: variables of RAW type carry no tensor data, so skip them.
        if var.type == core.VarDesc.VarType.RAW:
            continue
        cloned = _clone_var_in_block_(save_block, var)
        if filename is None:
            save_block.append_op(
                type='save',
                inputs={'X': [cloned]},
                outputs={},
                attrs={'file_path': os.path.join(dirname, cloned.name)})
        else:
            combined[cloned.name] = cloned

    if filename is not None:
        # Deterministic order: save_combine writes vars sorted by name.
        ordered = [combined[name] for name in sorted(combined.keys())]
        save_block.append_op(
            type='save_combine',
            inputs={'X': ordered},
            outputs={},
            attrs={'file_path': os.path.join(dirname, filename)})

    executor.run(save_program)


128
def save_params(executor, dirname, main_program=None, filename=None):
    """
    Save every Parameter of ``main_program`` under ``dirname``.
    """
    save_vars(
        executor,
        dirname=dirname,
        vars=None,
        main_program=main_program,
        predicate=is_parameter,
        filename=filename)
139 140


141
def save_persistables(executor, dirname, main_program=None, filename=None):
    """
    Save every persistable variable of ``main_program`` under ``dirname``.
    """
    save_vars(
        executor,
        dirname=dirname,
        vars=None,
        main_program=main_program,
        predicate=is_persistable,
        filename=filename)
152 153


154 155 156 157 158
def load_vars(executor,
              dirname,
              main_program=None,
              vars=None,
              predicate=None,
              filename=None):
    """
    Load variables from ``dirname`` by running a generated load program.

    :param executor: executor that runs the generated load program
    :param dirname: directory path the variables are read from
    :param main_program: program whose variables are filtered with
        ``predicate`` when ``vars`` is None. Defaults to
        ``default_main_program()``.
    :param predicate: callable mapping a variable to bool; a variable is
        loaded when it returns True. Only used when ``vars`` is None.
    :param vars: explicit variables to load; when given, ``main_program``
        and ``predicate`` are ignored.
    :param filename: name of the single file all variables are read from.
        When None, each variable is read from its own file.

    :return: None
    """
    if vars is None:
        prog = main_program if main_program is not None else default_main_program()
        if not isinstance(prog, Program):
            raise TypeError("program's type should be Program")
        load_vars(
            executor,
            dirname=dirname,
            vars=filter(predicate, prog.list_vars()),
            filename=filename)
        return

    load_prog = Program()
    load_block = load_prog.global_block()

    combined = {}
    for var in vars:
        assert isinstance(var, Variable)
        # Variables of RAW type carry no tensor data, so skip them.
        if var.type == core.VarDesc.VarType.RAW:
            continue
        cloned = _clone_var_in_block_(load_block, var)
        if filename is None:
            load_block.append_op(
                type='load',
                inputs={},
                outputs={'Out': [cloned]},
                attrs={'file_path': os.path.join(dirname, cloned.name)})
        else:
            combined[cloned.name] = cloned

    if filename is not None:
        # Same name-sorted order the save_combine op used when writing.
        ordered = [combined[name] for name in sorted(combined.keys())]
        load_block.append_op(
            type='load_combine',
            inputs={},
            outputs={"Out": ordered},
            attrs={'file_path': os.path.join(dirname, filename)})

    executor.run(load_prog)


220
def load_params(executor, dirname, main_program=None, filename=None):
    """
    Load every Parameter of ``main_program`` from ``dirname``.
    """
    load_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        predicate=is_parameter,
        filename=filename)
230 231


232
def load_persistables(executor, dirname, main_program=None, filename=None):
    """
    Load every persistable variable of ``main_program`` from ``dirname``.
    """
    load_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        predicate=is_persistable,
        filename=filename)
242 243


244 245
def get_inference_program(target_vars, main_program=None):
    """Build an inference-only program pruned down to ``target_vars``.

    ``Evaluator`` targets are expanded into their state and metric
    variables before pruning. ``main_program`` defaults to
    ``default_main_program()``.
    """
    if main_program is None:
        main_program = default_main_program()
    targets = target_vars if isinstance(target_vars, list) else [target_vars]

    expanded = []
    for target in targets:
        if isinstance(target, Evaluator):
            # An Evaluator contributes its internal state and metric vars.
            expanded.extend(target.states)
            expanded.extend(target.metrics)
        else:
            expanded.append(target)

    return main_program.prune(targets=expanded).inference_optimize()


261 262 263
def prepend_feed_ops(inference_program,
                     feed_target_names,
                     feed_holder_name='feed'):
    """Prepend one ``feed`` op per target name to the global block.

    Does nothing when ``feed_target_names`` is empty, so no feed holder
    variable is created in that case.
    """
    if not feed_target_names:
        return

    global_block = inference_program.global_block()
    feed_holder = global_block.create_var(
        name=feed_holder_name,
        type=core.VarDesc.VarType.FEED_MINIBATCH,
        persistable=True)

    for col, name in enumerate(feed_target_names):
        target = global_block.var(name)
        global_block.prepend_op(
            type='feed',
            inputs={'X': [feed_holder]},
            outputs={'Out': [target]},
            attrs={'col': col})


282 283 284
def append_fetch_ops(inference_program,
                     fetch_target_names,
                     fetch_holder_name='fetch'):
    """Append one ``fetch`` op per target name to the global block."""
    global_block = inference_program.global_block()
    fetch_holder = global_block.create_var(
        name=fetch_holder_name,
        type=core.VarDesc.VarType.FETCH_LIST,
        persistable=True)

    for col, name in enumerate(fetch_target_names):
        global_block.append_op(
            type='fetch',
            inputs={'X': [name]},
            outputs={'Out': [fetch_holder]},
            attrs={'col': col})


299 300 301 302
def save_inference_model(dirname,
                         feeded_var_names,
                         target_vars,
                         executor,
                         main_program=None,
                         model_filename=None,
                         params_filename=None):
    """
    Build a model especially for inference,
    and save it to directory by the executor.

    :param dirname: directory path
    :param feeded_var_names: Names of variables that need to be feeded data
        during inference; a single name or a list of names.
    :param target_vars: Variables from which we can get inference results;
        a single Variable or a list of Variables.
    :param executor: executor that save inference model
    :param main_program: original program, which will be pruned to build the
        inference model. Default default_main_program().
    :param model_filename: The name of file to save inference program.
        If not specified, default filename `__model__` will be used.
    :param params_filename: The name of file to save parameters.
        It is used for the case that all parameters are saved in a single
        binary file. If not specified, parameters are considered saved in
        separate files.

    :return: None
    """
    if isinstance(feeded_var_names, basestring):
        feeded_var_names = [feeded_var_names]
    else:
        if len(feeded_var_names) > 0:
            if not (bool(feeded_var_names) and all(
                    isinstance(name, basestring) for name in feeded_var_names)):
                raise ValueError("'feed_var_names' should be a list of str.")

    if isinstance(target_vars, Variable):
        target_vars = [target_vars]
    else:
        if not (bool(target_vars) and all(
                isinstance(var, Variable) for var in target_vars)):
            raise ValueError("'target_vars' should be a list of Variable.")

    if main_program is None:
        main_program = default_main_program()
    copy_program = main_program.clone()

    if not os.path.isdir(dirname):
        os.makedirs(dirname)

    # Clear the is_target information and remove the existed feed and fetch op.
    # BUGFIX: the original removed ops with `remove_op(i)` while enumerating
    # `global_block.ops` forward; each removal shifts the indices, so the op
    # right after a removed feed/fetch was skipped (and kept its target flag).
    # Do a flag pass first, then remove by index in reverse order.
    global_block = copy_program.global_block()
    for op in global_block.ops:
        op.desc.set_is_target(False)
    for i in range(len(global_block.ops) - 1, -1, -1):
        op = global_block.ops[i]
        if op.type == "feed" or op.type == "fetch":
            global_block.remove_op(i)
    copy_program.desc.flush()

    pruned_program = copy_program.prune(targets=target_vars)
    inference_program = pruned_program.inference_optimize()
    fetch_var_names = [v.name for v in target_vars]

    prepend_feed_ops(inference_program, feeded_var_names)
    append_fetch_ops(inference_program, fetch_var_names)

    if model_filename is not None:
        model_filename = os.path.basename(model_filename)
    else:
        model_filename = "__model__"
    model_filename = os.path.join(dirname, model_filename)

    if params_filename is not None:
        params_filename = os.path.basename(params_filename)

    with open(model_filename, "wb") as f:
        f.write(inference_program.desc.serialize_to_string())

    save_persistables(executor, dirname, inference_program, params_filename)
374 375


376 377 378 379
def load_inference_model(dirname,
                         executor,
                         model_filename=None,
                         params_filename=None):
    """
    Load inference model from a directory.

    :param dirname: directory path
    :param executor: executor that load inference model
    :param model_filename: The name of file to load inference program.
        If not specified, default filename `__model__` will be used.
    :param params_filename: The name of file to load parameters.
        It is used for the case that all parameters are saved in a single
        binary file. If not specified, parameters are considered saved in
        separate files.

    :return: [program, feed_target_names, fetch_targets]
             program: program especially for inference.
             feed_target_names: Names of variables that need to feed data
             fetch_targets: Variables from which we can get inference results.

    :raises ValueError: if ``dirname`` is not an existing directory.
    """
    if not os.path.isdir(dirname):
        # BUGFIX: the message was passed as ValueError("...'%s'", dirname),
        # which stores the path as a second arg and never interpolates it.
        raise ValueError("There is no directory named '%s'" % dirname)

    if model_filename is not None:
        model_filename = os.path.basename(model_filename)
    else:
        model_filename = "__model__"
    model_filename = os.path.join(dirname, model_filename)

    if params_filename is not None:
        params_filename = os.path.basename(params_filename)

    with open(model_filename, "rb") as f:
        program_desc_str = f.read()

    program = Program.parse_from_string(program_desc_str)
    load_persistables(executor, dirname, program, params_filename)

    feed_target_names = program.desc.get_feed_target_names()
    fetch_target_names = program.desc.get_fetch_target_names()
    fetch_targets = [
        program.global_block().var(name) for name in fetch_target_names
    ]

    return [program, feed_target_names, fetch_targets]
X
xuwei06 已提交
421 422 423 424 425 426 427 428


def get_parameter_value(para, executor):
    """
    Fetch the LoDTensor value of a parameter via a one-off fetch program.

    :param executor: executor for retrieving the value
    :param para: the given parameter

    :return: the LoDTensor for the parameter
    """
    assert is_parameter(para)

    fetch_prog = Program()
    fetch_block = fetch_prog.global_block()
    cloned = _clone_var_in_block_(fetch_block, para)
    return executor.run(fetch_prog, feed={}, fetch_list=[cloned])[0]


def get_parameter_value_by_name(name, executor, program=None):
    """
    Fetch the LoDTensor for the parameter with the given name.

    :param executor: executor for retrieving the value
    :param name: the name of the parameter
    :param program: the program where the variable is found.
            Default default_main_program().

    :return: the LoDTensor for the variable
    """
    if program is None:
        program = default_main_program()
    return get_parameter_value(program.global_block().var(name), executor)
T
tangwei12 已提交
455 456 457


# Marker filename written into a checkpoint dir once its save completed.
SUCCESS = "_SUCCESS"
# Module-global timestamp of the last checkpoint save; used by
# save_checkpoint to throttle how often checkpoints are written.
BEGIN_SECS = time.time()
T
tangwei12 已提交
459 460 461 462


def save_checkpoint(executor,
                    dirname,
                    keep_max=3,
                    save_secs=600,
                    main_program=None):
    """
    Save persistable variables into a serial-numbered checkpoint dir.

    A checkpoint is written at most once every ``save_secs`` seconds; only
    the ``keep_max`` newest checkpoints are retained.

    :param executor: executor that saves the variables
    :param dirname: root directory holding the numbered checkpoint dirs
    :param keep_max: number of checkpoints to retain
    :param save_secs: minimum seconds between two checkpoint saves
    :param main_program: program whose persistables are saved
    """
    if dirname is None:
        raise Exception("save checkpoint dir can not be none")

    if not os.path.isdir(dirname):
        os.makedirs(dirname)

    # Throttle: skip the save unless save_secs have elapsed since the last one.
    global BEGIN_SECS
    if time.time() - BEGIN_SECS < save_secs:
        return
    BEGIN_SECS = time.time()

    serial = _get_lastest_checkpoint_dir(dirname) + 1
    cur_dir = os.path.join(dirname, str(serial))

    save_persistables(executor, cur_dir, main_program)
    _write_success(cur_dir)
    _lru_delete(dirname, keep_max)
T
tangwei12 已提交
490 491 492 493 494 495


def restore_checkpoint(dirname, executor, main_program=None):
    """
    Load variables from the newest successful checkpoint under ``dirname``.

    :param dirname: root directory holding the numbered checkpoint dirs
    :param executor: executor that loads the variables
    :param main_program: program whose persistables are restored

    :raises Exception: if ``dirname`` is None or not an existing directory.
    """
    # BUGFIX: the original guard was `dirname is None and os.path.isdir(dirname)`,
    # which crashes on os.path.isdir(None) when dirname is None and never
    # raises when the directory is simply missing. The intended check is
    # "None OR not a directory".
    if dirname is None or not os.path.isdir(dirname):
        raise Exception("restore checkpoint can not load variables from %s" %
                        dirname)

    serial = _get_lastest_checkpoint_dir(dirname)

    # No successful checkpoint recorded yet: nothing to restore.
    if serial < 0:
        return

    cur_dir = os.path.join(dirname, str(serial))
    load_persistables(executor, cur_dir, main_program)


T
tangwei12 已提交
511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532
def _lru_delete(dirname, keep_max=3):
    """
    retain checkpoint nums with keep_max
    """
    dirs = os.listdir(dirname)
    serials = []
    for serial in dirs:
        try:
            serials.append(int(serial))
        except ValueError:
            continue

    if len(serials) <= keep_max:
        return

    serials.sort(reverse=True)
    serials = serials[keep_max:]
    for serial in serials:
        cur_dir = os.path.join(dirname, str(serial))
        shutil.rmtree(cur_dir)


T
tangwei12 已提交
533 534
def _write_success(dirname):
    """
    Drop the _SUCCESS marker file into ``dirname``.

    The marker signals that the checkpoint in this directory completed.
    """
    marker = os.path.join(dirname, SUCCESS)
    with open(marker, 'a'):
        pass


def _get_lastest_checkpoint_dir(checkpoint_dir):
    """
    get the biggest number in checkpoint_dir, which has _SUCCESS
    """
    if not checkpoint_dir.strip():
T
tangwei12 已提交
547
        return -1
T
tangwei12 已提交
548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574

    def has_success(checkpoint_dir, cur_dir):
        """
        is _SUCCESS in this dir
        """
        if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
            return -1

        try:
            int(cur_dir)
        except ValueError:
            return -1

        success_path = os.path.join(checkpoint_dir, cur_dir, SUCCESS)
        if os.path.isfile(success_path):
            return int(cur_dir)

    if not os.path.isdir(checkpoint_dir):
        return -1

    current_dir = -1
    dirs = os.listdir(checkpoint_dir)
    for cur_dir in dirs:
        success_num = has_success(checkpoint_dir, cur_dir)
        if success_num > current_dir:
            current_dir = success_num
    return current_dir