#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os
T
tangwei12 已提交
16 17
import time
import shutil
18

19 20
from paddle.fluid.evaluator import Evaluator
from paddle.fluid.framework import Program, Parameter, default_main_program, Variable
K
fix bug  
Kexin Zhao 已提交
21
from . import core
22 23

# Public API exported by ``from <this module> import *``.
__all__ = [
    'save_vars',
    'save_params',
    'save_persistables',
    'load_vars',
    'load_params',
    'load_persistables',
    'save_inference_model',
    'load_inference_model',
    'get_inference_program',
    'save_checkpoint',
    'load_checkpoint',
]


def is_parameter(var):
    """Return whether ``var`` is a Parameter.

    Args:
        var: The variable to inspect.

    Returns:
        bool: ``True`` when ``var`` is an instance of ``Parameter`` (or a
        subclass of it), otherwise ``False``.
    """
    matches = isinstance(var, Parameter)
    return matches


def is_persistable(var):
    """Return whether ``var`` should be treated as persistable.

    Feed holders (FEED_MINIBATCH) and fetch holders (FETCH_LIST) are never
    treated as persistable. Every other variable reports its own
    ``persistable`` attribute.

    Args:
        var: The variable to inspect.

    Returns:
        ``False`` for feed/fetch holder variables, otherwise
        ``var.persistable``.
    """
    holder_types = (core.VarDesc.VarType.FEED_MINIBATCH,
                    core.VarDesc.VarType.FETCH_LIST)
    if var.desc.type() in holder_types:
        return False
    return var.persistable


def _clone_var_in_block_(block, var):
    """Create a persistable copy of ``var``'s description inside ``block``.

    The new variable copies ``name``, ``shape``, ``dtype``, ``type`` and
    ``lod_level`` from ``var``. It is always created with
    ``persistable=True``.

    Args:
        block: The block in which the variable is created.
        var (Variable): The variable whose description is copied.

    Returns:
        Whatever ``block.create_var`` returns for the new variable.
    """
    assert isinstance(var, Variable)
    var_attrs = dict(
        name=var.name,
        shape=var.shape,
        dtype=var.dtype,
        type=var.type,
        lod_level=var.lod_level,
        persistable=True)
    return block.create_var(**var_attrs)


def save_vars(executor,
              dirname,
              main_program=None,
              vars=None,
              predicate=None,
              filename=None):
    """
    Save variables to directory by executor.

    :param executor: executor that runs the generated save program.
    :param dirname: directory path the variables are written into.
    :param main_program: program. Only used when `vars` is None: every
        variable of this program accepted by `predicate` is saved.
        Default default_main_program().
    :param vars: variables to save. If specified, `main_program` and
        `predicate` are ignored. Variables of type RAW are skipped.
    :param predicate: callable that receives a variable and returns a bool;
        only variables for which it returns True are saved.
    :param filename: name of a single file (inside `dirname`) that all
        variables are saved into, ordered by variable name. If None, every
        variable is saved to its own file named after the variable.

    :return: None
    """
    if vars is not None:
        target_program = Program()
        target_block = target_program.global_block()
        combined = {}

        for var in vars:
            # NOTE: don't save the variable which type is RAW
            if var.type == core.VarDesc.VarType.RAW:
                continue
            cloned = _clone_var_in_block_(target_block, var)
            if filename is not None:
                # Collected here; written as one sorted save_combine op below.
                combined[cloned.name] = cloned
                continue
            target_block.append_op(
                type='save',
                inputs={'X': [cloned]},
                outputs={},
                attrs={'file_path': os.path.join(dirname, cloned.name)})

        if filename is not None:
            ordered = [combined[key] for key in sorted(combined)]
            target_block.append_op(
                type='save_combine',
                inputs={'X': ordered},
                outputs={},
                attrs={'file_path': os.path.join(dirname, filename)})

        executor.run(target_program)
        return

    if main_program is None:
        main_program = default_main_program()
    if not isinstance(main_program, Program):
        raise TypeError("program should be as Program type or None")

    # Re-enter with an explicit variable list selected by `predicate`.
    save_vars(
        executor,
        dirname=dirname,
        vars=filter(predicate, main_program.list_vars()),
        filename=filename)


def save_params(executor, dirname, main_program=None, filename=None):
    """
    Save all parameters to directory with executor.

    Thin wrapper around `save_vars` that uses `is_parameter` as the
    predicate.

    :param executor: executor that runs the save program.
    :param dirname: directory path.
    :param main_program: program whose parameters are saved.
        Default default_main_program().
    :param filename: if given, all parameters go into this single file;
        otherwise each parameter is saved to its own file.
    :return: None
    """
    save_vars(executor, dirname, main_program, None, is_parameter, filename)


def save_persistables(executor, dirname, main_program=None, filename=None):
    """
    Save all persistables to directory with executor.

    Thin wrapper around `save_vars` that uses `is_persistable` as the
    predicate.

    :param executor: executor that runs the save program.
    :param dirname: directory path.
    :param main_program: program whose persistable variables are saved.
        Default default_main_program().
    :param filename: if given, all variables go into this single file;
        otherwise each variable is saved to its own file.
    :return: None
    """
    save_vars(executor, dirname, main_program, None, is_persistable, filename)


def load_vars(executor,
              dirname,
              main_program=None,
              vars=None,
              predicate=None,
              filename=None):
    """
    Load variables from directory by executor.

    :param executor: executor that runs the generated load program.
    :param dirname: directory path the variables are read from.
    :param main_program: program. Only used when `vars` is None: every
        variable of this program accepted by `predicate` is loaded.
        Default default_main_program().
    :param vars: variables to load. If specified, `main_program` and
        `predicate` are ignored. Every item must be a Variable; RAW
        variables are skipped.
    :param predicate: callable that receives a variable and returns a bool;
        only variables for which it returns True are loaded.
    :param filename: name of the single file (inside `dirname`) that all
        variables are loaded from, in variable-name order. If None, every
        variable is loaded from its own file named after the variable.

    :return: None
    """
    if vars is None:
        program = default_main_program() if main_program is None else main_program
        if not isinstance(program, Program):
            raise TypeError("program's type should be Program")
        # Re-enter with an explicit variable list selected by `predicate`.
        load_vars(
            executor,
            dirname=dirname,
            vars=filter(predicate, program.list_vars()),
            filename=filename)
        return

    load_program = Program()
    load_block = load_program.global_block()
    by_name = {}

    for var in vars:
        assert isinstance(var, Variable)
        # RAW variables are skipped, just as save_vars skips them.
        if var.type == core.VarDesc.VarType.RAW:
            continue
        cloned = _clone_var_in_block_(load_block, var)
        if filename is None:
            load_block.append_op(
                type='load',
                inputs={},
                outputs={'Out': [cloned]},
                attrs={'file_path': os.path.join(dirname, cloned.name)})
        else:
            by_name[cloned.name] = cloned

    if filename is not None:
        load_block.append_op(
            type='load_combine',
            inputs={},
            outputs={"Out": [by_name[key] for key in sorted(by_name)]},
            attrs={'file_path': os.path.join(dirname, filename)})

    executor.run(load_program)


def load_params(executor, dirname, main_program=None, filename=None):
    """
    load all parameters from directory by executor.

    Thin wrapper around `load_vars` that uses `is_parameter` as the
    predicate.

    :param executor: executor that runs the load program.
    :param dirname: directory path.
    :param main_program: program whose parameters are loaded.
        Default default_main_program().
    :param filename: if given, all parameters are read from this single
        file; otherwise each parameter is read from its own file.
    :return: None
    """
    load_vars(executor, dirname, main_program, None, is_parameter, filename)


def load_persistables(executor, dirname, main_program=None, filename=None):
    """
    load all persistables from directory by executor.

    Thin wrapper around `load_vars` that uses `is_persistable` as the
    predicate.

    :param executor: executor that runs the load program.
    :param dirname: directory path.
    :param main_program: program whose persistable variables are loaded.
        Default default_main_program().
    :param filename: if given, all variables are read from this single
        file; otherwise each variable is read from its own file.
    :return: None
    """
    load_vars(executor, dirname, main_program, None, is_persistable, filename)


def get_inference_program(target_vars, main_program=None):
    """
    Build a pruned, inference-optimized program that computes `target_vars`.

    :param target_vars: a target or a list of targets. A target that is an
        Evaluator contributes its `states` and `metrics`; any other target
        is used as-is. A non-list argument is wrapped in a one-element list.
    :param main_program: program to prune. Default default_main_program().
    :return: the result of `inference_optimize()` on the pruned program.
    """
    program = default_main_program() if main_program is None else main_program
    targets = target_vars if isinstance(target_vars, list) else [target_vars]

    prune_targets = []
    for target in targets:
        if not isinstance(target, Evaluator):
            prune_targets.append(target)
            continue
        prune_targets.extend(target.states)
        prune_targets.extend(target.metrics)

    return program.prune(targets=prune_targets).inference_optimize()


def prepend_feed_ops(inference_program,
                     feed_target_names,
                     feed_holder_name='feed'):
    """
    Prepend one `feed` op per name to the global block of `inference_program`.

    Creates a persistable FEED_MINIBATCH holder variable named
    `feed_holder_name`. The op for `feed_target_names[i]` takes that holder
    as input `X` and the named variable as output `Out`, with attribute
    `col=i`. Does nothing when `feed_target_names` is empty.
    """
    if len(feed_target_names) == 0:
        return

    block = inference_program.global_block()
    holder = block.create_var(
        name=feed_holder_name,
        type=core.VarDesc.VarType.FEED_MINIBATCH,
        persistable=True)

    for col, target_name in enumerate(feed_target_names):
        target_var = block.var(target_name)
        block.prepend_op(
            type='feed',
            inputs={'X': [holder]},
            outputs={'Out': [target_var]},
            attrs={'col': col})


def append_fetch_ops(inference_program,
                     fetch_target_names,
                     fetch_holder_name='fetch'):
    """
    Append one `fetch` op per name to the global block of `inference_program`.

    Creates a persistable FETCH_LIST holder variable named
    `fetch_holder_name`, even when `fetch_target_names` is empty. The op for
    `fetch_target_names[i]` takes that name as input `X` and the holder as
    output `Out`, with attribute `col=i`.
    """
    block = inference_program.global_block()
    holder = block.create_var(
        name=fetch_holder_name,
        type=core.VarDesc.VarType.FETCH_LIST,
        persistable=True)

    for col, target_name in enumerate(fetch_target_names):
        block.append_op(
            type='fetch',
            inputs={'X': [target_name]},
            outputs={'Out': [holder]},
            attrs={'col': col})


def save_inference_model(dirname,
                         feeded_var_names,
                         target_vars,
                         executor,
                         main_program=None,
                         model_filename=None,
                         params_filename=None):
    """
    Build a model especially for inference,
    and save it to directory by the executor.

    The steps are:

    1. Clone `main_program` and clear every op's is_target flag.
    2. Remove the clone's existing feed/fetch ops.
    3. Prune the clone to `target_vars` and run `inference_optimize()`.
    4. Add fresh feed ops for `feeded_var_names` and fetch ops for
       `target_vars`.
    5. Write the serialized program to `model_filename` and save its
       persistable variables with `save_persistables`.

    :param dirname: directory path; created if it does not exist.
    :param feeded_var_names: Names of variables that need to be feeded data
        during inference. A single string is treated as a one-element list.
    :param target_vars: Variables from which we can get inference results.
        A single Variable is treated as a one-element list.
    :param executor: executor that save inference model
    :param main_program: original program, which will be pruned to build the
        inference model. Default default_main_program().
    :param model_filename: The name of file to save inference program.
        If not specified, default filename `__model__` will be used.
        Only its base name is kept; the file is always written in `dirname`.
    :param params_filename: The name of file to save parameters.
        It is used for the case that all parameters are saved in a single
        binary file. If not specified, parameters are considered saved in
        separate files. Only its base name is kept.

    :raises ValueError: if `feeded_var_names` is neither a string nor a list
        of strings, or `target_vars` is neither a Variable nor a non-empty
        list of Variables.
    :return: None
    """
    if isinstance(feeded_var_names, basestring):
        feeded_var_names = [feeded_var_names]
    elif len(feeded_var_names) > 0:
        if not all(isinstance(name, basestring) for name in feeded_var_names):
            raise ValueError("'feed_var_names' should be a list of str.")

    if isinstance(target_vars, Variable):
        target_vars = [target_vars]
    elif not (bool(target_vars) and all(
            isinstance(var, Variable) for var in target_vars)):
        raise ValueError("'target_vars' should be a list of Variable.")

    if main_program is None:
        main_program = default_main_program()
    copy_program = main_program.clone()

    if not os.path.isdir(dirname):
        os.makedirs(dirname)

    # Clear the is_target information and remove the existing feed and fetch
    # ops. Indices are collected first and removed back-to-front. Calling
    # remove_op(i) inside the enumerate loop shifts the remaining ops if it
    # also deletes from `global_block.ops` (as fluid's Block does). The op
    # after each removed one is then skipped: consecutive feed/fetch ops
    # survive and that op keeps its is_target flag. Back-to-front removal
    # keeps every collected index valid either way.
    global_block = copy_program.global_block()
    feed_fetch_indices = []
    for i, op in enumerate(global_block.ops):
        op.desc.set_is_target(False)
        if op.type == "feed" or op.type == "fetch":
            feed_fetch_indices.append(i)
    for i in reversed(feed_fetch_indices):
        global_block.remove_op(i)
    copy_program.desc.flush()

    pruned_program = copy_program.prune(targets=target_vars)
    inference_program = pruned_program.inference_optimize()
    fetch_var_names = [v.name for v in target_vars]

    prepend_feed_ops(inference_program, feeded_var_names)
    append_fetch_ops(inference_program, fetch_var_names)

    if model_filename is not None:
        model_filename = os.path.basename(model_filename)
    else:
        model_filename = "__model__"
    model_filename = os.path.join(dirname, model_filename)

    if params_filename is not None:
        params_filename = os.path.basename(params_filename)

    with open(model_filename, "wb") as f:
        f.write(inference_program.desc.serialize_to_string())

    save_persistables(executor, dirname, inference_program, params_filename)


def load_inference_model(dirname,
                         executor,
                         model_filename=None,
                         params_filename=None):
    """
    Load inference model from a directory

    :param dirname: directory path
    :param executor: executor that load inference model
    :param model_filename: The name of file to load inference program.
        If not specified, default filename `__model__` will be used.
        Only its base name is kept; the file is always read from `dirname`.
    :param params_filename: The name of file to load parameters.
        It is used for the case that all parameters are saved in a single
        binary file. If not specified, parameters are considered saved in
        separate files. Only its base name is kept.

    :raises ValueError: if `dirname` is not an existing directory.
    :return: [program, feed_target_names, fetch_targets]
             program: program especially for inference.
             feed_target_names: Names of variables that need to feed data
             fetch_targets: Variables from which we can get inference results.
    """
    if not os.path.isdir(dirname):
        # Format eagerly. Passing `dirname` as a second constructor argument
        # leaves '%s' unfilled and renders the message as a tuple.
        raise ValueError("There is no directory named '%s'" % dirname)

    if model_filename is not None:
        model_filename = os.path.basename(model_filename)
    else:
        model_filename = "__model__"
    model_filename = os.path.join(dirname, model_filename)

    if params_filename is not None:
        params_filename = os.path.basename(params_filename)

    with open(model_filename, "rb") as f:
        program_desc_str = f.read()

    program = Program.parse_from_string(program_desc_str)
    load_persistables(executor, dirname, program, params_filename)

    feed_target_names = program.desc.get_feed_target_names()
    fetch_target_names = program.desc.get_fetch_target_names()
    fetch_targets = [
        program.global_block().var(name) for name in fetch_target_names
    ]

    return [program, feed_target_names, fetch_targets]


def get_parameter_value(para, executor):
    """
    Get the LoDTensor for the parameter

    Builds a fresh program whose global block holds only a copy of `para`'s
    description. The function then runs that program through `executor`,
    fetching the copied variable.

    :param para: the given parameter; must satisfy `is_parameter`.
    :param executor: executor for retrieving the value
    :return: the first element of the fetch results, i.e. the value of the
        parameter
    """
    assert is_parameter(para)

    fetch_program = Program()
    fetch_var = _clone_var_in_block_(fetch_program.global_block(), para)
    results = executor.run(fetch_program, feed={}, fetch_list=[fetch_var])
    return results[0]


def get_parameter_value_by_name(name, executor, program=None):
    """
    Get the LoDTensor for the parameter with the given name

    :param name: the name of the parameter, looked up in the global block.
    :param executor: executor for retrieving the value
    :param program: the program where the variable is found
            Default default_main_program().

    :return: the LoDTensor for the variable, as returned by
        `get_parameter_value`
    """
    source_program = default_main_program() if program is None else program
    param = source_program.global_block().var(name)
    return get_parameter_value(param, executor)


# Each checkpoint sub-directory is named
# CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + <serial>, e.g. "checkpoint_0".
CHECKPOINT_PREFIX = "checkpoint"
CHECKPOINT_SEPARATOR = "_"
# Marker file written into a checkpoint directory after its variables
# have been saved.
SUCCESS_MARK_FILENAME = "_SUCCESS"


def save_checkpoint(executor,
                    checkpoint_dir=None,
                    max_num_checkpoints=3,
                    save_interval_secs=600,
                    main_program=None):
    """
    Save the checkpoint variables of `main_program` as a new numbered
    checkpoint.

    Each checkpoint goes into its own sub-directory of `checkpoint_dir` (see
    `_get_serial_dir`). Its serial is one more than the latest checkpoint
    that has a success marker, starting from 0. Nothing is written when that
    latest checkpoint directory was modified less than `save_interval_secs`
    seconds ago. After a save, a marker file is written and `_lru_delete` is
    asked to prune checkpoints beyond `max_num_checkpoints`.

    :param executor: executor that runs the save program.
    :param checkpoint_dir: root checkpoint directory, created if missing.
        Defaults to the current working directory.
    :param max_num_checkpoints: how many of the newest checkpoints to keep.
    :param save_interval_secs: minimum number of seconds between two saved
        checkpoints.
    :param main_program: program whose variables are saved; passed through
        to `save_vars` (None means its default program).
    """
    root = os.getcwd() if checkpoint_dir is None else checkpoint_dir
    if not os.path.isdir(root):
        os.makedirs(root)

    latest = _get_lastest_checkpoint_dir(root)
    too_soon = latest >= 0 and not _interval_secs_exceed(
        _get_serial_dir(latest, root), save_interval_secs)
    if too_soon:
        return

    next_dir = _get_serial_dir(latest + 1, root)
    save_vars(
        executor,
        dirname=next_dir,
        main_program=main_program,
        vars=None,
        predicate=_is_checkpoint_var,
        filename=None)
    _write_success(next_dir)
    _lru_delete(root, max_num_checkpoints)


def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
    """
    Load the most recent checkpoint with a success marker from
    `checkpoint_dir`.

    Does nothing when `_get_lastest_checkpoint_dir` finds no such
    checkpoint.

    :param executor: executor that runs the load program.
    :param checkpoint_dir: root checkpoint directory. Defaults to the current
        working directory.
    :param main_program: program whose variables are loaded; passed through
        to `load_vars` (None means its default program).
    """
    root = checkpoint_dir if checkpoint_dir is not None else os.getcwd()
    latest = _get_lastest_checkpoint_dir(root)
    if latest >= 0:
        load_vars(
            executor,
            dirname=_get_serial_dir(latest, root),
            main_program=main_program,
            predicate=_is_checkpoint_var,
            filename=None)


def _get_serial_dir(serial, checkpoint_dir):
    """
    Return the path of checkpoint number `serial` inside `checkpoint_dir`.

    The sub-directory name is CHECKPOINT_PREFIX, CHECKPOINT_SEPARATOR and
    `str(serial)` concatenated.
    """
    folder_name = "".join(
        [CHECKPOINT_PREFIX, CHECKPOINT_SEPARATOR, str(serial)])
    return os.path.join(checkpoint_dir, folder_name)


def _is_checkpoint_var(var):
    """
    Decide whether `var` should be saved to / loaded from a checkpoint.

    The following variables are excluded:

    * variables of type FEED_MINIBATCH, FETCH_LIST or RAW;
    * variables whose name ends with "@GRAD".

    Every other variable is included exactly when it is persistable.

    :param var: the variable to check.
    :return: False for excluded variables, otherwise `var.persistable`.
    """
    skipped_types = (core.VarDesc.VarType.FEED_MINIBATCH,
                     core.VarDesc.VarType.FETCH_LIST,
                     core.VarDesc.VarType.RAW)
    is_skipped = var.desc.type() in skipped_types or var.name.endswith("@GRAD")
    return False if is_skipped else var.persistable


def _interval_secs_exceed(dirname, save_interval_secs):
    """
    Return True when at least `save_interval_secs` seconds have elapsed since
    `dirname` was last modified (per `os.path.getmtime`), else False.
    """
    last_modified = os.path.getmtime(dirname)
    elapsed = time.time() - last_modified
    return not (elapsed < save_interval_secs)


def _lru_delete(dirname, max_num_checkpoints=3):
    """
    Delete old checkpoint directories under `dirname`.

    Only the `max_num_checkpoints` checkpoints with the highest serial
    numbers are kept.

    Only directories named CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR +
    <integer> (the names produced by `_get_serial_dir`) are considered.
    Any other file or directory in `dirname` is left untouched.

    :param dirname: root checkpoint directory.
    :param max_num_checkpoints: number of newest checkpoints to keep.
    """
    prefix = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR
    checkpoints = []
    for entry in os.listdir(dirname):
        # The serial is the suffix after the prefix, not the whole entry
        # name; `int(entry)` would reject every real checkpoint directory.
        if not entry.startswith(prefix):
            continue
        try:
            serial = int(entry[len(prefix):])
        except ValueError:
            continue
        path = os.path.join(dirname, entry)
        if os.path.isdir(path):
            checkpoints.append((serial, path))

    if len(checkpoints) <= max_num_checkpoints:
        return

    # The highest serial is the newest checkpoint. Keep the first
    # `max_num_checkpoints` and remove the rest, using the real on-disk path.
    checkpoints.sort(key=lambda item: item[0], reverse=True)
    for _, path in checkpoints[max_num_checkpoints:]:
        shutil.rmtree(path)


def _write_success(dirname):
    """
    Mark the checkpoint in `dirname` as successfully written.

    Appends the current local time (the string from `time.ctime()`) to the
    SUCCESS_MARK_FILENAME file inside `dirname`. The file is created if it
    does not exist yet.

    :param dirname: checkpoint directory to mark.
    """
    marker_path = os.path.join(dirname, SUCCESS_MARK_FILENAME)
    with open(marker_path, 'a') as marker:
        marker.write(time.ctime())


def _get_lastest_checkpoint_dir(checkpoint_dir):
    """
    Get the highest checkpoint serial in `checkpoint_dir` whose directory
    contains the SUCCESS_MARK_FILENAME file.

    Only sub-directories named CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR +
    <integer> are considered. Unrelated files and directories are ignored.

    :param checkpoint_dir: root checkpoint directory.
    :return: the largest such serial, or -1 in any of these cases:
        `checkpoint_dir` is blank, it is not a directory, or it holds no
        checkpoint with a success marker.
    """
    if not checkpoint_dir.strip():
        return -1

    def has_success(checkpoint_dir, cur_dir):
        """
        Return the serial of `cur_dir` if it is a checkpoint directory with a
        success marker, otherwise -1.
        """
        prefix = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR
        # Require the exact "<prefix><int>" form. Unrelated entries such as
        # "data" or "a_b_c" would break a two-way split/unpack.
        if not cur_dir.startswith(prefix):
            return -1
        try:
            serial = int(cur_dir[len(prefix):])
        except ValueError:
            return -1

        serial_dir = os.path.join(checkpoint_dir, cur_dir)
        if not os.path.isdir(serial_dir):
            return -1

        success_path = os.path.join(serial_dir, SUCCESS_MARK_FILENAME)
        # Always return an int: an implicit None would make the caller's
        # comparison fail on Python 3.
        return serial if os.path.isfile(success_path) else -1

    if not os.path.isdir(checkpoint_dir):
        return -1

    current_dir = -1
    for cur_dir in os.listdir(checkpoint_dir):
        current_dir = max(current_dir, has_success(checkpoint_dir, cur_dir))
    return current_dir