#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
import shutil

from paddle.fluid.evaluator import Evaluator
from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable
from . import core

__all__ = [
    'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
    'load_persistables', 'save_inference_model', 'load_inference_model',
    'get_inference_program', 'save_checkpoint', 'load_checkpoint',
    'clean_checkpoint', 'load_persist_vars_without_grad',
    'save_persist_vars_without_grad', 'get_latest_checkpoint_serial'
]


def is_parameter(var):
    """Check whether the variable is a Parameter.

    Args:
        var : The input variable.

    Returns:
        bool: True if the variable is a Parameter, False otherwise.
    """
    return isinstance(var, Parameter)


def is_persistable(var):
    if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
            var.desc.type() == core.VarDesc.VarType.FETCH_LIST:
        return False
    return var.persistable


def _clone_var_in_block_(block, var):
    assert isinstance(var, Variable)
    return block.create_var(
        name=var.name,
        shape=var.shape,
        dtype=var.dtype,
        type=var.type,
        lod_level=var.lod_level,
        persistable=True)


def save_vars(executor,
              dirname,
              main_program=None,
              vars=None,
              predicate=None,
              filename=None):
    """
    Save variables to a directory by executor.

    :param executor: executor that saves the variables
    :param dirname: directory path
    :param main_program: program. If vars is None, all variables in this
        program that satisfy `predicate` will be saved. Default: default_main_program().
    :param predicate: a callable that takes a variable and returns a bool.
        If it returns True, the corresponding variable will be saved.
    :param vars: variables to be saved. If vars is specified, main_program
        and predicate will be ignored.
    :param filename: the name of a single file that all variables are saved to.
        If it is None, each variable is saved to a separate file.

    :return: None
    """
    if vars is None:
        if main_program is None:
            main_program = default_main_program()
        if not isinstance(main_program, Program):
            raise TypeError("main_program should be an instance of Program or None")

        save_vars(
            executor,
            dirname=dirname,
            vars=filter(predicate, main_program.list_vars()),
            filename=filename)
    else:
        save_program = Program()
        save_block = save_program.global_block()

        save_var_map = {}
        for each_var in vars:
            # NOTE: don't save variables whose type is RAW
            if each_var.type == core.VarDesc.VarType.RAW:
                continue
            new_var = _clone_var_in_block_(save_block, each_var)
            if filename is None:
                save_block.append_op(
                    type='save',
                    inputs={'X': [new_var]},
                    outputs={},
                    attrs={'file_path': os.path.join(dirname, new_var.name)})
            else:
                save_var_map[new_var.name] = new_var

        if filename is not None:
            save_var_list = []
            for name in sorted(save_var_map.keys()):
                save_var_list.append(save_var_map[name])

            save_block.append_op(
                type='save_combine',
                inputs={'X': save_var_list},
                outputs={},
                attrs={'file_path': os.path.join(dirname, filename)})

        executor.run(save_program)
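
# Usage sketch for save_vars (illustrative, not executed here). Assumes a
# started fluid.Executor `exe` and a built default main program; the dirname
# and the name filter are hypothetical.
#
#   import paddle.fluid as fluid
#
#   exe = fluid.Executor(fluid.CPUPlace())
#   fluid.io.save_vars(
#       executor=exe,
#       dirname="./saved_vars",
#       main_program=fluid.default_main_program(),
#       predicate=lambda var: var.persistable,
#       filename=None)  # one file per variable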


def save_params(executor, dirname, main_program=None, filename=None):
    """
    Save all parameters of the main_program to a directory with the executor.
    """
    save_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        vars=None,
        predicate=is_parameter,
        filename=filename)


def save_persistables(executor, dirname, main_program=None, filename=None):
    """
    Save all persistable variables of the main_program to a directory with the executor.
    """
    save_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        vars=None,
        predicate=is_persistable,
        filename=filename)
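
# Usage sketch for save_params / save_persistables (illustrative). Assumes a
# started fluid.Executor `exe` and a trained default main program; the paths
# are hypothetical.
#
#   exe = fluid.Executor(fluid.CPUPlace())
#   # Parameters only, one file per parameter:
#   fluid.io.save_params(exe, dirname="./params")
#   # Every persistable variable (parameters plus optimizer state), combined
#   # into a single binary file:
#   fluid.io.save_persistables(exe, dirname="./persist", filename="all_vars")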


def load_vars(executor,
              dirname,
              main_program=None,
              vars=None,
              predicate=None,
              filename=None):
    """
    Load variables from a directory by executor.

    :param executor: executor that loads the variables
    :param dirname: directory path
    :param main_program: program. If vars is None, all variables in this
        program that satisfy `predicate` will be loaded. Default: default_main_program().
    :param predicate: a callable that takes a variable and returns a bool.
        If it returns True, the corresponding variable will be loaded.
    :param vars: variables to be loaded. If vars is specified, main_program
        and predicate will be ignored.
    :param filename: the name of the single file that all variables are loaded from.
        If it is None, each variable is loaded from a separate file.

    :return: None
    """
    if vars is None:
        if main_program is None:
            main_program = default_main_program()
        if not isinstance(main_program, Program):
            raise TypeError("main_program should be an instance of Program")

        load_vars(
            executor,
            dirname=dirname,
            vars=filter(predicate, main_program.list_vars()),
            filename=filename)
    else:
        load_prog = Program()
        load_block = load_prog.global_block()

        load_var_map = {}
        for each_var in vars:
            assert isinstance(each_var, Variable)
            if each_var.type == core.VarDesc.VarType.RAW:
                continue
            new_var = _clone_var_in_block_(load_block, each_var)
            if filename is None:
                load_block.append_op(
                    type='load',
                    inputs={},
                    outputs={'Out': [new_var]},
                    attrs={'file_path': os.path.join(dirname, new_var.name)})
            else:
                load_var_map[new_var.name] = new_var

        if filename is not None:
            load_var_list = []
            for name in sorted(load_var_map.keys()):
                load_var_list.append(load_var_map[name])

            load_block.append_op(
                type='load_combine',
                inputs={},
                outputs={"Out": load_var_list},
                attrs={'file_path': os.path.join(dirname, filename)})

        executor.run(load_prog)
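
# Usage sketch for load_vars (illustrative): reload the variables saved by the
# save_vars sketch above. The directory and filter are hypothetical and must
# match what was used at save time.
#
#   exe = fluid.Executor(fluid.CPUPlace())
#   fluid.io.load_vars(
#       executor=exe,
#       dirname="./saved_vars",
#       main_program=fluid.default_main_program(),
#       predicate=lambda var: var.persistable,
#       filename=None)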


def load_params(executor, dirname, main_program=None, filename=None):
    """
    Load all parameters of the main_program from a directory by executor.
    """
    load_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        predicate=is_parameter,
        filename=filename)


def load_persistables(executor, dirname, main_program=None, filename=None):
    """
    Load all persistable variables of the main_program from a directory by executor.
    """
    load_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        predicate=is_persistable,
        filename=filename)
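
# Usage sketch for load_params / load_persistables (illustrative). The program
# must already define the same variables (same names and shapes) that were
# saved; the paths are hypothetical.
#
#   exe = fluid.Executor(fluid.CPUPlace())
#   fluid.io.load_params(exe, dirname="./params")
#   fluid.io.load_persistables(exe, dirname="./persist", filename="all_vars")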


def get_inference_program(target_vars, main_program=None):
    if main_program is None:
        main_program = default_main_program()
    if not isinstance(target_vars, list):
        target_vars = [target_vars]
    vars = []
    for var in target_vars:
        if isinstance(var, Evaluator):
            vars.extend(var.states)
            vars.extend(var.metrics)
        else:
            vars.append(var)
    pruned_program = main_program.prune(targets=vars)
    inference_program = pruned_program.inference_optimize()
    return inference_program


def prepend_feed_ops(inference_program,
                     feed_target_names,
                     feed_holder_name='feed'):
    if len(feed_target_names) == 0:
        return

    global_block = inference_program.global_block()
    feed_var = global_block.create_var(
        name=feed_holder_name,
        type=core.VarDesc.VarType.FEED_MINIBATCH,
        persistable=True)

    for i, name in enumerate(feed_target_names):
        out = global_block.var(name)
        global_block.prepend_op(
            type='feed',
            inputs={'X': [feed_var]},
            outputs={'Out': [out]},
            attrs={'col': i})


def append_fetch_ops(inference_program,
                     fetch_target_names,
                     fetch_holder_name='fetch'):
    global_block = inference_program.global_block()
    fetch_var = global_block.create_var(
        name=fetch_holder_name,
        type=core.VarDesc.VarType.FETCH_LIST,
        persistable=True)

    for i, name in enumerate(fetch_target_names):
        global_block.append_op(
            type='fetch',
            inputs={'X': [name]},
            outputs={'Out': [fetch_var]},
            attrs={'col': i})


def save_inference_model(dirname,
                         feeded_var_names,
                         target_vars,
                         executor,
                         main_program=None,
                         model_filename=None,
                         params_filename=None):
    """
    Build a model especially for inference,
    and save it to the directory by the executor.

    :param dirname: directory path
    :param feeded_var_names: names of variables that need to be fed data during inference
    :param target_vars: variables from which we can get inference results
    :param executor: executor that saves the inference model
    :param main_program: original program, which will be pruned to build the inference model.
            Default: default_main_program().
    :param model_filename: the name of the file to save the inference program to.
        If not specified, the default filename `__model__` will be used.
    :param params_filename: the name of the file to save parameters to.
        It is used for the case that all parameters are saved in a single binary file.
        If not specified, parameters are considered saved in separate files.

    :return: None
    """
    if isinstance(feeded_var_names, basestring):
        feeded_var_names = [feeded_var_names]
    else:
        if len(feeded_var_names) > 0:
            if not (bool(feeded_var_names) and all(
                    isinstance(name, basestring) for name in feeded_var_names)):
                raise ValueError("'feeded_var_names' should be a list of str.")

    if isinstance(target_vars, Variable):
        target_vars = [target_vars]
    else:
        if not (bool(target_vars) and all(
                isinstance(var, Variable) for var in target_vars)):
            raise ValueError("'target_vars' should be a list of Variable.")

    if main_program is None:
        main_program = default_main_program()
    copy_program = main_program.clone()

    if not os.path.isdir(dirname):
        os.makedirs(dirname)

    # Clear the is_target information and remove the existing feed and fetch ops.
    # Iterate in reverse so that removing an op does not shift the indices of
    # the ops that have not been visited yet.
    global_block = copy_program.global_block()
    for i in range(len(global_block.ops) - 1, -1, -1):
        op = global_block.ops[i]
        op.desc.set_is_target(False)
        if op.type == "feed" or op.type == "fetch":
            global_block.remove_op(i)
    copy_program.desc.flush()

    pruned_program = copy_program.prune(targets=target_vars)
    inference_program = pruned_program.inference_optimize()
    fetch_var_names = [v.name for v in target_vars]

    prepend_feed_ops(inference_program, feeded_var_names)
    append_fetch_ops(inference_program, fetch_var_names)

    if model_filename is not None:
        model_filename = os.path.basename(model_filename)
    else:
        model_filename = "__model__"
    model_filename = os.path.join(dirname, model_filename)

    if params_filename is not None:
        params_filename = os.path.basename(params_filename)

    with open(model_filename, "wb") as f:
        f.write(inference_program.desc.serialize_to_string())

    save_persistables(executor, dirname, inference_program, params_filename)
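
# Usage sketch for save_inference_model (illustrative). Assumes a network with
# an input variable named "image" and a prediction variable `predict`; the
# names and the directory are hypothetical.
#
#   exe = fluid.Executor(fluid.CPUPlace())
#   fluid.io.save_inference_model(
#       dirname="./inference_model",
#       feeded_var_names=["image"],
#       target_vars=[predict],
#       executor=exe)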


def load_inference_model(dirname,
                         executor,
                         model_filename=None,
                         params_filename=None):
    """
    Load an inference model from a directory.

    :param dirname: directory path
    :param executor: executor that loads the inference model
    :param model_filename: the name of the file to load the inference program from.
        If not specified, the default filename `__model__` will be used.
    :param params_filename: the name of the file to load parameters from.
        It is used for the case that all parameters are saved in a single binary file.
        If not specified, parameters are considered saved in separate files.

    :return: [program, feed_target_names, fetch_targets]
             program: program especially for inference
             feed_target_names: names of variables that need to be fed data
             fetch_targets: variables from which we can get inference results
    """
    if not os.path.isdir(dirname):
        raise ValueError("There is no directory named '%s'" % dirname)

    if model_filename is not None:
        model_filename = os.path.basename(model_filename)
    else:
        model_filename = "__model__"
    model_filename = os.path.join(dirname, model_filename)

    if params_filename is not None:
        params_filename = os.path.basename(params_filename)

    with open(model_filename, "rb") as f:
        program_desc_str = f.read()

    program = Program.parse_from_string(program_desc_str)
    load_persistables(executor, dirname, program, params_filename)

    feed_target_names = program.desc.get_feed_target_names()
    fetch_target_names = program.desc.get_fetch_target_names()
    fetch_targets = [
        program.global_block().var(name) for name in fetch_target_names
    ]

    return [program, feed_target_names, fetch_targets]
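
# Usage sketch for load_inference_model (illustrative), pairing with the
# save_inference_model sketch above; `infer_data` is a hypothetical input
# ndarray shaped like the "image" variable.
#
#   exe = fluid.Executor(fluid.CPUPlace())
#   [program, feed_names, fetch_targets] = fluid.io.load_inference_model(
#       dirname="./inference_model", executor=exe)
#   results = exe.run(program,
#                     feed={feed_names[0]: infer_data},
#                     fetch_list=fetch_targets)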


def get_parameter_value(para, executor):
    """
    Get the LoDTensor for the parameter

    :param executor: executor for retrieving the value
    :param para: the given parameter

    :return: the LoDTensor for the parameter
    """
    assert is_parameter(para)

    get_program = Program()
    block = get_program.global_block()
    new_var = _clone_var_in_block_(block, para)
    return executor.run(get_program, feed={}, fetch_list=[new_var])[0]


def get_parameter_value_by_name(name, executor, program=None):
    """
    Get the LoDTensor for the parameter with the given name

    :param executor: executor for retrieving the value
    :param name: the name of the parameter
    :param program: the program where the variable is found.
            Default: default_main_program().

    :return: the LoDTensor for the variable
    """
    if program is None:
        program = default_main_program()
    var = program.global_block().var(name)
    return get_parameter_value(var, executor)
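
# Usage sketch for get_parameter_value_by_name (illustrative): fetch a weight
# tensor after training. The parameter name "fc_0.w_0" is a hypothetical
# auto-generated name.
#
#   exe = fluid.Executor(fluid.CPUPlace())
#   weight = fluid.io.get_parameter_value_by_name("fc_0.w_0", exe)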


SUCCESS_MARK_FILENAME = "_SUCCESS"
CHECKPOINT_PREFIX = "checkpoint"
MODEL_DIR = "__model__"
TRAINER_PREFIX = "trainer"
CHECKPOINT_SEPARATOR = "_"


def save_checkpoint(executor,
                    checkpoint_dir,
                    trainer_id,
                    trainer_args=None,
                    main_program=None,
                    max_num_checkpoints=3):
    """
    save_checkpoint saves the persistable LoDTensor variables of main_program
    into a checkpoint directory. Each checkpoint is stored in a subdirectory
    named by a serial number that starts at 0 and increases with every save;
    at most max_num_checkpoints serial directories are kept, and the oldest
    ones are deleted first (an LRU strategy).

    :param executor: executor that saves the values
    :param checkpoint_dir: the checkpoint directory
    :param trainer_id: current trainer id; if the id is 0, the trainer is the chief
    :param trainer_args: a dict of trainer arguments saved alongside the checkpoint
    :param main_program: all variables in this program will be saved
    :param max_num_checkpoints: keep at most this many checkpoint serial directories
    """
    if checkpoint_dir is None:
        raise ValueError("'checkpoint_dir' should not be None")

    if trainer_args:
        assert isinstance(trainer_args, dict)

    if not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    serial = get_latest_checkpoint_serial(checkpoint_dir) + 1
    cur_dir = _get_serial_dir(checkpoint_dir, serial)

    save_trainer_args(cur_dir, trainer_id, trainer_args)

    if trainer_id == 0:
        save_persist_vars_without_grad(executor, cur_dir, main_program)

    _scroll_delete(checkpoint_dir, max_num_checkpoints)
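
# Usage sketch for save_checkpoint (illustrative), called periodically inside
# a training loop. The trainer_args keys shown here are hypothetical.
#
#   if step % 100 == 0:
#       fluid.io.save_checkpoint(
#           executor=exe,
#           checkpoint_dir="./checkpoints",
#           trainer_id=0,
#           trainer_args={"step_id": step, "epoch_id": epoch},
#           main_program=fluid.default_main_program(),
#           max_num_checkpoints=3)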


def load_checkpoint(executor, checkpoint_dir, serial, main_program):
    """
    Load a checkpoint from the given serial subdirectory of checkpoint_dir
    by executor, restoring all persistable variables of main_program.

    :param executor: executor that loads the values
    :param checkpoint_dir: the checkpoint directory
    :param serial: the serial folder in the checkpoint directory to be loaded
    :param main_program: all variables in this program will be loaded
    """

    if checkpoint_dir is None:
        raise ValueError("'checkpoint_dir' should not be None")

    if serial is None or serial < 0:
        raise ValueError("'serial' should not be None or < 0")

    if main_program is None:
        raise ValueError("'main_program' should not be None")

    cur_dir = _get_serial_dir(checkpoint_dir, serial)
    load_persist_vars_without_grad(executor, cur_dir, main_program, True)
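
# Usage sketch for load_checkpoint (illustrative), restoring the newest valid
# checkpoint before resuming training:
#
#   serial = fluid.io.get_latest_checkpoint_serial("./checkpoints")
#   if serial >= 0:
#       fluid.io.load_checkpoint(
#           executor=exe,
#           checkpoint_dir="./checkpoints",
#           serial=serial,
#           main_program=fluid.default_main_program())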


def clean_checkpoint(checkpoint_dir, delete_dir=False):
    """
    Clean the checkpoint dir. When training exits normally, the trainer calls
    clean_checkpoint to delete the checkpoint directories saved earlier.
    delete_dir only works when the directory is empty; otherwise an OSError is raised.

    :param checkpoint_dir
    :param delete_dir
    """

    if checkpoint_dir is None:
        raise ValueError("'checkpoint_dir' should not be None")
    _scroll_delete(checkpoint_dir, max_num_checkpoints=0)

    if delete_dir and not os.listdir(checkpoint_dir):
        os.rmdir(checkpoint_dir)


def load_persist_vars_without_grad(executor,
                                   dirname,
                                   program,
                                   has_model_dir=False):
    """
    load_persist_vars_without_grad loads variables from a directory by an
    executor; variables whose names end with "@GRAD" will not be loaded.

    :param executor: executor that loads the values
    :param dirname: the checkpoint directory
    :param program: all variables in this program will be loaded
    :param has_model_dir: if True, load variables from a subdirectory named __model__
    """

    if has_model_dir:
        dirname = _get_model_dir(dirname)

    load_vars(
        executor,
        dirname=dirname,
        main_program=program,
        predicate=_is_checkpoint_var,
        filename=None)


def save_persist_vars_without_grad(executor, dirname, program):
    """
    save_persist_vars_without_grad saves variables to a directory by an
    executor; variables whose names end with "@GRAD" will not be saved.

    :param executor: executor that saves the values
    :param dirname: the checkpoint directory
    :param program: all variables in this program will be saved
    """
    cur_dir = _get_model_dir(dirname)
    save_vars(
        executor,
        dirname=cur_dir,
        main_program=program,
        vars=None,
        predicate=_is_checkpoint_var,
        filename=None)
    _write_success(cur_dir)


def save_trainer_args(dirname, trainer_id, trainer_args):
    assert isinstance(trainer_args, dict)

    cur_dir = _get_trainer_dir(dirname, trainer_id)

    for name, value in trainer_args.iteritems():
        args_file = os.path.join(cur_dir, name)
        with open(args_file, 'w') as f:
            f.write(str(value))
    _write_success(cur_dir)


def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
    assert isinstance(trainer_args, list)

    cur_dir = _get_serial_dir(checkpoint_dir, serial)
    cur_dir = _get_trainer_dir(cur_dir, trainer_id)

    ret_values = []

    for arg in trainer_args:
        cur_file = os.path.join(cur_dir, arg)
        with open(cur_file, 'r') as f:
            contents = f.read()
            ret_values.append(contents.strip())
    return ret_values
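
# Round-trip sketch for the trainer-args helpers (illustrative): values saved
# per trainer come back as strings, in the order requested.
#
#   # `cur_dir` is a checkpoint serial directory such as ./checkpoints/checkpoint_1
#   save_trainer_args(cur_dir, 0, {"step_id": 200, "epoch_id": 5})
#   step_id, epoch_id = load_trainer_args(
#       "./checkpoints", 1, 0, ["step_id", "epoch_id"])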


def _is_checkpoint_var(var):
    """
    Checkpoints do not save or load every variable: variables whose type is
    FEED_MINIBATCH, FETCH_LIST, or RAW, and variables whose names end with
    "@GRAD", are discarded.

    :param var
    """
    if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
            var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
            var.desc.type() == core.VarDesc.VarType.RAW:
        return False
    # Names containing "@GRAD" denote gradient variables; checkpoints do not save them.
    if "@GRAD" in var.name:
        return False
    # Names containing ".trainer_" denote distributed-training variables; checkpoints do not save them.
    if ".trainer_" in var.name:
        return False
    # Names containing ".block" denote distributed-training variables; checkpoints do not save them.
    if ".block" in var.name:
        return False

    return var.persistable


def _get_dir_serial(dirname):
    _, serial = dirname.split(CHECKPOINT_SEPARATOR)

    try:
        serial_num = int(serial)
    except ValueError:
        serial_num = -1
    return serial_num


def _get_serial_dir(dirname, serial):
    serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
    serial_dir = os.path.join(dirname, serial_folder)

    if not os.path.isdir(serial_dir):
        os.makedirs(serial_dir)

    return serial_dir


def _get_model_dir(dirname):
    model_dir = os.path.join(dirname, MODEL_DIR)

    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)

    return model_dir


def _get_trainer_dir(dirname, trainer_id):
    trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id)
    trainer_dir = os.path.join(dirname, trainer_folder)

    if not os.path.isdir(trainer_dir):
        os.makedirs(trainer_dir)

    return trainer_dir


def _scroll_delete(dirname, max_num_checkpoints=3):
    dirs = os.listdir(dirname)
    serial_map = {}
    for serial in dirs:
        serial_num = _get_dir_serial(serial)
        serial_map[serial_num] = serial

    if len(serial_map.keys()) <= max_num_checkpoints:
        return

    serials = serial_map.keys()
    serials.sort(reverse=True)
    serials = serials[max_num_checkpoints:]
    for serial in serials:
        cur_dir = _get_serial_dir(dirname, serial)
        shutil.rmtree(cur_dir)


def _write_success(dirname):
    """
    Write a file named "_SUCCESS" in the checkpoint dir to indicate that this
    checkpoint is complete.

    :param dirname
    """
    success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME)
    with open(success_file, 'a') as f:
        now = time.ctime()
        f.write(now)


def get_latest_checkpoint_serial(checkpoint_dir):
    """
    Get the serial number of the latest valid checkpoint directory; a serial
    directory is valid only if the _SUCCESS file exists in it.

    :param checkpoint_dir
    """
    if not checkpoint_dir:
        return -1

    def has_success(checkpoint_dir, cur_dir):
        """
        Return the serial number of cur_dir if it contains _SUCCESS, else -1.
        """

        serial = _get_dir_serial(cur_dir)
        if serial == -1 or not os.path.isdir(
                os.path.join(checkpoint_dir, cur_dir)):
            return -1

        success_path = os.path.join(
            _get_serial_dir(checkpoint_dir, serial), MODEL_DIR,
            SUCCESS_MARK_FILENAME)
        if os.path.isfile(success_path):
            return serial
        return -1

    if not os.path.isdir(checkpoint_dir):
        return -1

    current_dir = -1
    dirs = os.listdir(checkpoint_dir)
    for cur_dir in dirs:
        success_num = has_success(checkpoint_dir, cur_dir)
        if success_num > current_dir:
            current_dir = success_num
    return current_dir


def get_test_program(filelist, program=None, startup_program=None):
    """
    Transpile current train program to a program to read test dataset
    if the program is using reader ops like "open_files_op".
    """

    def _copy_reader_var_(block, var, new_name=None):
        if new_name is None:
            new_name = var.name
        new_var = block.create_var(
            name=str(new_name), type=core.VarDesc.VarType.READER)
        new_var.desc.set_shapes(var.desc.shapes())
        new_var.desc.set_dtypes(var.desc.dtypes())
        new_var.persistable = True
        return new_var

    def _get_test_reader_name(train_reader_name):
        return train_reader_name + "_test"

    def _is_reader_op(op):
        block = op.block
        if "Out" in op.output_names:
            reader_out = block.vars[op.output("Out")[0]]
            if reader_out.type == core.VarDesc.VarType.READER:
                return True
        return False

    if program is None:
        program = default_main_program()
    if startup_program is None:
        startup_program = default_startup_program()
    startup_block = startup_program.global_block()

    # 1. find out the original reader var name
    startup_reader_op_list = []

    for op in startup_block.ops:
        if _is_reader_op(op):
            startup_reader_op_list.append(op)

    if len(startup_reader_op_list) == 0:
        return program

    root_reader_op = startup_reader_op_list[0]
    train_test_reader_map = {}
    # 2. add operators to startup to read open and read test data files
    for op in startup_reader_op_list:
        assert (len(op.output("Out")) == 1)
        train_reader_name = op.output("Out")[0]
        train_reader = startup_block.vars[train_reader_name]
        test_reader = _copy_reader_var_(
            startup_block,
            train_reader,
            new_name=_get_test_reader_name(train_reader_name))
        train_test_reader_map[train_reader.name] = test_reader

        test_op_inputs = {}
        for name in op.input_names:
            train_arg_names = op.input(name)
            test_arg_vars = []
            for arg_name in train_arg_names:
                arg_var = train_test_reader_map[
                    arg_name] if name == "UnderlyingReader" else startup_block.vars[
                        arg_name]
                test_arg_vars.append(arg_var)
            test_op_inputs[name] = test_arg_vars

        test_op = startup_block.append_op(
            type=op.type,
            inputs=test_op_inputs,
            outputs={'Out': [test_reader]},
            attrs=op.attrs)
        # set the root reader op's "file_names" attr so it reads the test files
        if op.type == root_reader_op.type:
            test_op.set_attr("file_names", filelist)
        if op.type == "create_multi_pass_reader":
            test_op.set_attr("pass_num", 1)

    # 3. rename reader vars in inference program to different name
    #    to avoid read from train data.
    main_block = program.global_block()
    for var in main_block.vars.values():
        if var.type == core.VarDesc.VarType.READER:
            main_block.rename_var(
                str(var.name), str(_get_test_reader_name(var.name)))

    for op in main_block.ops:
        if op.type == root_reader_op.type:
            op.set_attr("file_names", filelist)
        if op.type == "create_multi_pass_reader":
            op.set_attr("pass_num", 1)

    startup_program.sync_with_cpp()
    program.sync_with_cpp()

    return program