io.py 23.7 KB
Newer Older
1
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
D
dzhwinter 已提交
2
#
D
dzhwinter 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
D
dzhwinter 已提交
6
#
D
dzhwinter 已提交
7
#     http://www.apache.org/licenses/LICENSE-2.0
D
dzhwinter 已提交
8
#
D
dzhwinter 已提交
9 10 11 12 13 14
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os
T
tangwei12 已提交
16 17
import time
import shutil
18

19 20
from paddle.fluid.evaluator import Evaluator
from paddle.fluid.framework import Program, Parameter, default_main_program, Variable
K
fix bug  
Kexin Zhao 已提交
21
from . import core
22 23

__all__ = [
    # Variable save/load helpers.
    'save_vars',
    'save_params',
    'save_persistables',
    'load_vars',
    'load_params',
    'load_persistables',
    # Inference model helpers.
    'save_inference_model',
    'load_inference_model',
    'get_inference_program',
    # Checkpoint helpers.
    'save_checkpoint',
    'load_checkpoint',
    'clean_checkpoint',
    'load_persist_vars_without_grad',
    'save_persist_vars_without_grad',
]


def is_parameter(var):
    """Tell whether `var` is a Parameter.

    Args:
        var: The variable to inspect.

    Returns:
        True when `var` is an instance of Parameter, otherwise False.
    """
    param_flag = isinstance(var, Parameter)
    return param_flag


def is_persistable(var):
    """Tell whether `var` should be treated as persistable.

    Feed and fetch holder variables (FEED_MINIBATCH / FETCH_LIST) are never
    considered persistable; any other variable reports its own
    `persistable` flag.
    """
    holder_types = (core.VarDesc.VarType.FEED_MINIBATCH,
                    core.VarDesc.VarType.FETCH_LIST)
    if var.desc.type() in holder_types:
        return False
    return var.persistable


def _clone_var_in_block_(block, var):
    """Create a persistable variable in `block` mirroring `var`.

    The new variable copies name, shape, dtype, type and lod_level from
    `var` and is always marked persistable.
    """
    assert isinstance(var, Variable)
    var_attrs = dict(
        name=var.name,
        shape=var.shape,
        dtype=var.dtype,
        type=var.type,
        lod_level=var.lod_level,
        persistable=True)
    return block.create_var(**var_attrs)


64 65 66 67 68
def save_vars(executor,
              dirname,
              main_program=None,
              vars=None,
              predicate=None,
              filename=None):
    """
    Save variables to a directory by running a generated save program.

    :param executor: executor that runs the save program
    :param dirname: directory path the variables are written under
    :param main_program: program whose variables are filtered by `predicate`
        when `vars` is None. Default default_main_program().
    :param predicate: callable applied to each variable of `main_program`;
        a variable is saved when it returns a true value.
    :param vars: variables to save. If specified, `main_program` and
        `predicate` are ignored.
    :param filename: name of a single file (under `dirname`) that all vars
        are saved to. If it is None, each variable goes to its own file
        named after the variable.

    :return: None
    """
    if vars is None:
        program = default_main_program() if main_program is None \
            else main_program
        if not isinstance(program, Program):
            raise TypeError("program should be as Program type or None")
        save_vars(
            executor,
            dirname=dirname,
            vars=filter(predicate, program.list_vars()),
            filename=filename)
        return

    save_program = Program()
    save_block = save_program.global_block()
    combined_vars = {}

    for var in vars:
        # NOTE: variables of type RAW are never saved.
        if var.type == core.VarDesc.VarType.RAW:
            continue
        cloned = _clone_var_in_block_(save_block, var)
        if filename is not None:
            combined_vars[cloned.name] = cloned
            continue
        save_block.append_op(
            type='save',
            inputs={'X': [cloned]},
            outputs={},
            attrs={'file_path': os.path.join(dirname, cloned.name)})

    if filename is not None:
        # Variables are written to the combined file in name order.
        ordered_vars = [combined_vars[name] for name in sorted(combined_vars)]
        save_block.append_op(
            type='save_combine',
            inputs={'X': ordered_vars},
            outputs={},
            attrs={'file_path': os.path.join(dirname, filename)})

    executor.run(save_program)


130
def save_params(executor, dirname, main_program=None, filename=None):
    """
    Save every Parameter of `main_program` to `dirname` with `executor`.

    Thin wrapper over save_vars using `is_parameter` as the predicate;
    `filename` selects single-file (combined) saving when not None.
    """
    return save_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        vars=None,
        predicate=is_parameter,
        filename=filename)
141 142


143
def save_persistables(executor, dirname, main_program=None, filename=None):
    """
    Save every persistable variable of `main_program` to `dirname`.

    Thin wrapper over save_vars using `is_persistable` as the predicate;
    `filename` selects single-file (combined) saving when not None.
    """
    return save_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        vars=None,
        predicate=is_persistable,
        filename=filename)
154 155


156 157 158 159 160
def load_vars(executor,
              dirname,
              main_program=None,
              vars=None,
              predicate=None,
              filename=None):
    """
    Load variables from a directory by running a generated load program.

    :param executor: executor that runs the load program
    :param dirname: directory path the variables are read from
    :param main_program: program whose variables are filtered by `predicate`
        when `vars` is None. Default default_main_program().
    :param predicate: callable applied to each variable of `main_program`;
        a variable is loaded when it returns a true value.
    :param vars: variables to load. If specified, `main_program` and
        `predicate` are ignored.
    :param filename: name of the single file (under `dirname`) that all vars
        are loaded from. If it is None, each variable is loaded from its own
        file named after the variable.

    :return: None
    """
    if vars is None:
        program = main_program
        if program is None:
            program = default_main_program()
        if not isinstance(program, Program):
            raise TypeError("program's type should be Program")
        load_vars(
            executor,
            dirname=dirname,
            vars=filter(predicate, program.list_vars()),
            filename=filename)
        return

    load_prog = Program()
    load_block = load_prog.global_block()
    combined_vars = {}

    for var in vars:
        assert isinstance(var, Variable)
        # Variables of type RAW are never loaded.
        if var.type == core.VarDesc.VarType.RAW:
            continue
        cloned = _clone_var_in_block_(load_block, var)
        if filename is not None:
            combined_vars[cloned.name] = cloned
            continue
        load_block.append_op(
            type='load',
            inputs={},
            outputs={'Out': [cloned]},
            attrs={'file_path': os.path.join(dirname, cloned.name)})

    if filename is not None:
        # The combined file is read back in the same name order it is saved.
        ordered_vars = [combined_vars[name] for name in sorted(combined_vars)]
        load_block.append_op(
            type='load_combine',
            inputs={},
            outputs={"Out": ordered_vars},
            attrs={'file_path': os.path.join(dirname, filename)})

    executor.run(load_prog)


222
def load_params(executor, dirname, main_program=None, filename=None):
    """
    Load every Parameter of `main_program` from `dirname` with `executor`.

    Thin wrapper over load_vars using `is_parameter` as the predicate.
    """
    return load_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        predicate=is_parameter,
        filename=filename)
232 233


234
def load_persistables(executor, dirname, main_program=None, filename=None):
    """
    Load every persistable variable of `main_program` from `dirname`.

    Thin wrapper over load_vars using `is_persistable` as the predicate.
    """
    return load_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        predicate=is_persistable,
        filename=filename)
244 245


246 247
def get_inference_program(target_vars, main_program=None):
    """
    Prune `main_program` down to `target_vars` and optimize it for inference.

    :param target_vars: a target (or list of targets); an Evaluator target
        contributes its `states` and `metrics` variables.
    :param main_program: program to prune. Default default_main_program().
    :return: the pruned, inference-optimized program
    """
    program = main_program if main_program is not None \
        else default_main_program()
    targets = target_vars if isinstance(target_vars, list) else [target_vars]

    prune_targets = []
    for target in targets:
        if isinstance(target, Evaluator):
            prune_targets.extend(target.states)
            prune_targets.extend(target.metrics)
        else:
            prune_targets.append(target)

    return program.prune(targets=prune_targets).inference_optimize()


263 264 265
def prepend_feed_ops(inference_program,
                     feed_target_names,
                     feed_holder_name='feed'):
    """
    Insert one 'feed' op per name at the front of the global block.

    Feed op for the i-th name reads column i of the persistable
    FEED_MINIBATCH holder variable `feed_holder_name`. Nothing is done when
    `feed_target_names` is empty.
    """
    if not len(feed_target_names):
        return

    block = inference_program.global_block()
    holder = block.create_var(
        name=feed_holder_name,
        type=core.VarDesc.VarType.FEED_MINIBATCH,
        persistable=True)

    for col, target_name in enumerate(feed_target_names):
        block.prepend_op(
            type='feed',
            inputs={'X': [holder]},
            outputs={'Out': [block.var(target_name)]},
            attrs={'col': col})


284 285 286
def append_fetch_ops(inference_program,
                     fetch_target_names,
                     fetch_holder_name='fetch'):
    """
    Append one 'fetch' op per name at the end of the global block.

    Fetch op for the i-th name writes column i of the persistable
    FETCH_LIST holder variable `fetch_holder_name`.
    """
    block = inference_program.global_block()
    holder = block.create_var(
        name=fetch_holder_name,
        type=core.VarDesc.VarType.FETCH_LIST,
        persistable=True)

    for col, target_name in enumerate(fetch_target_names):
        block.append_op(
            type='fetch',
            inputs={'X': [target_name]},
            outputs={'Out': [holder]},
            attrs={'col': col})


301 302 303 304
def save_inference_model(dirname,
                         feeded_var_names,
                         target_vars,
                         executor,
                         main_program=None,
                         model_filename=None,
                         params_filename=None):
    """
    Build a model especially for inference,
    and save it to directory by the executor.

    The program is cloned, its existing feed/fetch ops are removed, it is
    pruned to `target_vars`, new feed/fetch ops are added, and the result is
    serialized together with its persistable variables.

    :param dirname: directory path; created if it does not exist
    :param feeded_var_names: name (str) or list of names of variables that
        need to be fed data during inference
    :param target_vars: Variable or list of Variables from which we can get
        inference results.
    :param executor: executor that saves the inference model
    :param main_program: original program, which will be pruned to build the
        inference model. Default default_main_program().
    :param model_filename: The name of file to save inference program.
        If not specified, default filename `__model__` will be used.
    :param params_filename: The name of file to save parameters.
        It is used for the case that all parameters are saved in a single
        binary file. If not specified, parameters are considered saved in
        separate files.

    :raises ValueError: if `feeded_var_names` is a non-empty sequence that
        contains a non-str item, or `target_vars` is empty or contains an
        item that is not a Variable.
    :return: None
    """
    if isinstance(feeded_var_names, basestring):
        feeded_var_names = [feeded_var_names]
    elif len(feeded_var_names) > 0:
        if not all(isinstance(name, basestring) for name in feeded_var_names):
            raise ValueError("'feed_var_names' should be a list of str.")

    if isinstance(target_vars, Variable):
        target_vars = [target_vars]
    elif not (bool(target_vars) and all(
            isinstance(var, Variable) for var in target_vars)):
        raise ValueError("'target_vars' should be a list of Variable.")

    if main_program is None:
        main_program = default_main_program()
    copy_program = main_program.clone()

    if not os.path.isdir(dirname):
        os.makedirs(dirname)

    # Clear the is_target information and remove the existing feed and fetch
    # ops. remove_op takes a positional index, and removing while enumerating
    # the same op list would shift the following ops down and skip them.
    # Collect the indices first, then remove back-to-front so the earlier
    # indices stay valid.
    global_block = copy_program.global_block()
    feed_fetch_indices = []
    for i, op in enumerate(global_block.ops):
        op.desc.set_is_target(False)
        if op.type == "feed" or op.type == "fetch":
            feed_fetch_indices.append(i)
    for i in reversed(feed_fetch_indices):
        global_block.remove_op(i)
    copy_program.desc.flush()

    pruned_program = copy_program.prune(targets=target_vars)
    inference_program = pruned_program.inference_optimize()
    fetch_var_names = [v.name for v in target_vars]

    prepend_feed_ops(inference_program, feeded_var_names)
    append_fetch_ops(inference_program, fetch_var_names)

    # Only the base name of a user-supplied file name is used; files always
    # land directly inside `dirname`.
    if model_filename is not None:
        model_filename = os.path.basename(model_filename)
    else:
        model_filename = "__model__"
    model_filename = os.path.join(dirname, model_filename)

    if params_filename is not None:
        params_filename = os.path.basename(params_filename)

    with open(model_filename, "wb") as f:
        f.write(inference_program.desc.serialize_to_string())

    save_persistables(executor, dirname, inference_program, params_filename)
376 377


378 379 380 381
def load_inference_model(dirname,
                         executor,
                         model_filename=None,
                         params_filename=None):
    """
    Load inference model from a directory

    :param dirname: directory path
    :param executor: executor that load inference model
    :param model_filename: The name of file to load inference program.
        If not specified, default filename `__model__` will be used.
    :param params_filename: The name of file to load parameters.
        It is used for the case that all parameters are saved in a single
        binary file. If not specified, parameters are considered saved in
        separate files.

    :raises ValueError: if `dirname` is not an existing directory
    :return: [program, feed_target_names, fetch_targets]
             program: program especially for inference.
             feed_target_names: Names of variables that need to feed data
             fetch_targets: Variables from which we can get inference results.
    """
    if not os.path.isdir(dirname):
        # Format the path into the message; passing it as a second exception
        # argument would leave the '%s' placeholder unfilled.
        raise ValueError("There is no directory named '%s'" % dirname)

    # Only the base name of a user-supplied file name is used; files are
    # always looked up directly inside `dirname`.
    if model_filename is not None:
        model_filename = os.path.basename(model_filename)
    else:
        model_filename = "__model__"
    model_filename = os.path.join(dirname, model_filename)

    if params_filename is not None:
        params_filename = os.path.basename(params_filename)

    with open(model_filename, "rb") as f:
        program_desc_str = f.read()

    program = Program.parse_from_string(program_desc_str)
    load_persistables(executor, dirname, program, params_filename)

    feed_target_names = program.desc.get_feed_target_names()
    fetch_target_names = program.desc.get_fetch_target_names()
    fetch_targets = [
        program.global_block().var(name) for name in fetch_target_names
    ]

    return [program, feed_target_names, fetch_targets]
X
xuwei06 已提交
423 424 425 426 427 428 429 430


def get_parameter_value(para, executor):
    """
    Get the LoDTensor for the parameter

    A one-off program holding a persistable clone of `para` is run and the
    clone is fetched.

    :param executor: executor for retrieving the value
    :param para: the given parameter; must satisfy is_parameter

    :return: the LoDTensor for the parameter
    """
    assert is_parameter(para)

    fetch_program = Program()
    fetched_var = _clone_var_in_block_(fetch_program.global_block(), para)
    fetched = executor.run(fetch_program, feed={}, fetch_list=[fetched_var])
    return fetched[0]


def get_parameter_value_by_name(name, executor, program=None):
    """
    Get the LoDTensor for the parameter with the given name

    :param executor: executor for retrieving the value
    :param name: the name of the parameter
    :param program: the program whose global block holds the variable.
            Default default_main_program().

    :return: the LoDTensor for the variable
    """
    source_program = default_main_program() if program is None else program
    param_var = source_program.global_block().var(name)
    return get_parameter_value(param_var, executor)
T
tangwei12 已提交
457 458


T
tangwei12 已提交
459
# Marker file written into a directory after its checkpoint data is saved.
SUCCESS_MARK_FILENAME = "_SUCCESS"
# Prefix of the per-serial checkpoint directories, e.g. "checkpoint_0".
CHECKPOINT_PREFIX = "checkpoint"
# Sub-directory of a serial directory that holds the saved variables.
MODEL_DIR = "__model__"
# Prefix of the per-trainer sub-directories, e.g. "trainer_0".
TRAINER_PREFIX = "trainer"
# Separator between a directory prefix and its number.
CHECKPOINT_SEPARATOR = "_"
T
tangwei12 已提交
464 465 466


def save_checkpoint(executor,
                    checkpoint_dir,
                    trainer_id,
                    is_chief=False,
                    trainer_args=None,
                    main_program=None,
                    max_num_checkpoints=3):
    """
    Save a checkpoint into a new serial directory under `checkpoint_dir`.

    The new serial is one greater than the latest serial whose
    `__model__/_SUCCESS` marker exists (0 when there is none). Every trainer
    writes its `trainer_args` into `checkpoint_<serial>/trainer_<trainer_id>/`.
    Only the chief additionally saves the persistable, non-gradient variables
    of `main_program` into `checkpoint_<serial>/__model__/`. Afterwards only the
    `max_num_checkpoints` highest-numbered serial directories are kept.

    :param executor: executor used to run the save program
    :param checkpoint_dir: root checkpoint directory; created if missing
    :param trainer_id: id of this trainer, used to name its sub-directory
    :param is_chief: whether this trainer also saves the model variables
    :param trainer_args: dict of name -> value, written one file per entry.
        None (the default) means no extra arguments are written.
    :param main_program: program whose variables the chief saves.
        Default default_main_program().
    :param max_num_checkpoints: number of serial directories to keep
    :raises ValueError: if `checkpoint_dir` is None
    :raises TypeError: if `trainer_args` is neither None nor a dict
    """
    if checkpoint_dir is None:
        raise ValueError("The values of 'checkpoint_dir' should not be None")

    # save_trainer_args requires a dict, so the default None is normalized to
    # an empty dict; validation happens before any directory is created.
    if trainer_args is None:
        trainer_args = {}
    if not isinstance(trainer_args, dict):
        raise TypeError("The type of 'trainer_args' should be dict")

    if not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    serial = _get_lastest_checkpoint_dir(checkpoint_dir) + 1
    cur_dir = _get_serial_dir(checkpoint_dir, serial)

    save_trainer_args(cur_dir, trainer_id, trainer_args)

    if is_chief:
        save_persist_vars_without_grad(executor, cur_dir, main_program)

    _lru_delete(checkpoint_dir, max_num_checkpoints)
T
tangwei12 已提交
504 505


T
tangwei12 已提交
506
def need_load_checkpoint(checkpoint_dir):
    """
    Return the latest complete checkpoint serial in `checkpoint_dir`.

    :param checkpoint_dir: root checkpoint directory
    :return: the serial number, or None when no complete checkpoint exists
    """
    latest = _get_lastest_checkpoint_dir(checkpoint_dir)
    return None if latest < 0 else latest


def load_checkpoint(executor, checkpoint_dir, serial, main_program):
    """
    Load the checkpoint with the given serial from a directory by executor.

    The persistable, non-gradient variables of `main_program` are loaded from
    `<checkpoint_dir>/checkpoint_<serial>/__model__/`. This function does not
    search for a serial itself; need_load_checkpoint returns the latest
    complete one.

    :param executor: executor that runs the load program
    :param checkpoint_dir: root checkpoint directory
    :param serial: serial number of the checkpoint to load; must be >= 0
    :param main_program: program whose variables are loaded
    :raises ValueError: if `checkpoint_dir` or `main_program` is None, or
        `serial` is None or negative
    """
    if checkpoint_dir is None:
        raise ValueError("The values of 'checkpoint_dir' should not be None")

    if serial is None or serial < 0:
        raise ValueError("The values of 'serial' should not be None or <0 ")

    if main_program is None:
        raise ValueError("The values of 'main_program' should not be None")

    cur_dir = _get_serial_dir(checkpoint_dir, serial)
    load_persist_vars_without_grad(executor, cur_dir, main_program)
T
tangwei12 已提交
541 542


T
tangwei12 已提交
543 544 545 546
def clean_checkpoint(checkpoint_dir, delete_dir=False):
    """
    Clean the checkpoint dir by deleting every checkpoint serial directory in
    it. A trainer is expected to call this when training exits normally.

    :param checkpoint_dir: root checkpoint directory; a directory that does
        not exist is treated as already clean
    :param delete_dir: if True, also remove `checkpoint_dir` itself, but only
        when it is empty after cleaning; a non-empty directory is left in
        place and no error is raised
    :raises ValueError: if `checkpoint_dir` is None
    """
    if checkpoint_dir is None:
        raise ValueError("The values of 'checkpoint_dir' should not be None")

    # Nothing to clean; _lru_delete would otherwise fail in os.listdir.
    if not os.path.isdir(checkpoint_dir):
        return

    _lru_delete(checkpoint_dir, max_num_checkpoints=0)

    if delete_dir and not os.listdir(checkpoint_dir):
        os.rmdir(checkpoint_dir)


T
tangwei12 已提交
560 561 562 563
def load_persist_vars_without_grad(executor, dirname, program, nest=True):
    """
    Load the checkpoint variables of `program` from a directory by executor.

    Only variables accepted by _is_checkpoint_var are loaded. That excludes
    names containing "@GRAD", ".trainer_" or ".block", feed/fetch/RAW
    variables, and non-persistable variables.

    :param executor: executor that runs the load program
    :param dirname: directory to load from
    :param program: program whose variables are loaded
    :param nest: when True, load from `<dirname>/__model__` (created if it
        does not exist) instead of `dirname` itself
    """
    source_dir = _get_model_dir(dirname) if nest else dirname

    load_vars(
        executor,
        dirname=source_dir,
        main_program=program,
        predicate=_is_checkpoint_var,
        filename=None)


def save_persist_vars_without_grad(executor, dirname, program):
    """
    Save the checkpoint variables of `program` into `<dirname>/__model__`.

    Only variables accepted by _is_checkpoint_var are saved. That excludes
    names containing "@GRAD", ".trainer_" or ".block", feed/fetch/RAW
    variables, and non-persistable variables. A _SUCCESS marker is written
    into the model directory afterwards.

    :param executor: executor that runs the save program
    :param dirname: checkpoint serial directory
    :param program: program whose variables are saved
    """
    model_dir = _get_model_dir(dirname)
    save_vars(
        executor,
        dirname=model_dir,
        main_program=program,
        vars=None,
        predicate=_is_checkpoint_var,
        filename=None)
    _write_success(model_dir)


def save_trainer_args(dirname, trainer_id, trainer_args):
    """
    Write each entry of `trainer_args` to its own file in the trainer folder.

    The folder is `<dirname>/trainer_<trainer_id>` (created if missing). After
    all files are written, a _SUCCESS marker is added to it.

    :param dirname: checkpoint serial directory
    :param trainer_id: id of the trainer; selects its sub-directory
    :param trainer_args: dict mapping file name -> value; each value is
        written as `str(value)`
    :raises TypeError: if `trainer_args` is not a dict
    """
    if not isinstance(trainer_args, dict):
        raise TypeError("The type of 'trainer_args' should be dict")
    cur_dir = _get_trainer_dir(dirname, trainer_id)

    # dict.items() works on Python 2 and 3; iteritems() is Python 2 only.
    for name, value in trainer_args.items():
        args_file = os.path.join(cur_dir, name)
        with open(args_file, 'w') as f:
            f.write(str(value))
    _write_success(cur_dir)


def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
    """
    Read back the trainer arguments saved by save_trainer_args.

    :param checkpoint_dir: root checkpoint directory
    :param serial: checkpoint serial number
    :param trainer_id: id of the trainer whose arguments are read
    :param trainer_args: list of argument (file) names to read
    :raises TypeError: if `trainer_args` is not a list
    :return: list of the stripped file contents, in `trainer_args` order
    """
    trainer_dir = _get_trainer_dir(
        _get_serial_dir(checkpoint_dir, serial), trainer_id)

    if not isinstance(trainer_args, list):
        raise TypeError("The type of 'trainer_args' should be list")

    def _read_arg(arg_name):
        with open(os.path.join(trainer_dir, arg_name), 'r') as handle:
            return handle.read().strip()

    return [_read_arg(arg) for arg in trainer_args]


T
tangwei12 已提交
631
def _is_checkpoint_var(var):
    """
    Decide whether `var` belongs in a checkpoint.

    Rejected are variables of type FEED_MINIBATCH, FETCH_LIST or RAW, and
    variables whose name contains "@GRAD", ".trainer_" or ".block". Any
    other variable is accepted when it is persistable.

    :param var: the variable to inspect
    """
    excluded_types = (core.VarDesc.VarType.FEED_MINIBATCH,
                      core.VarDesc.VarType.FETCH_LIST,
                      core.VarDesc.VarType.RAW)
    if var.desc.type() in excluded_types:
        return False

    excluded_markers = ("@GRAD", ".trainer_", ".block")
    if any(marker in var.name for marker in excluded_markers):
        return False

    return var.persistable
T
tangwei12 已提交
653 654


T
tangwei12 已提交
655 656 657 658 659 660 661 662 663 664 665 666 667
def _get_dir_serial(dirname):
    """
    Parse the serial number out of a directory name such as "checkpoint_3".

    :param dirname: a directory base name, e.g. an entry of os.listdir
    :return: the integer after the separator, or -1 when the name does not
        contain exactly one separator or its suffix is not an integer
    """
    parts = dirname.split(CHECKPOINT_SEPARATOR)
    # Unpacking the split directly would raise for names like "foo" or
    # "a_b_c"; such names simply carry no serial.
    if len(parts) != 2:
        return -1

    try:
        return int(parts[1])
    except ValueError:
        return -1


def _get_serial_dir(dirname, serial):
    """
    Return `<dirname>/checkpoint_<serial>`, creating it if it does not exist.

    :param dirname: root checkpoint directory
    :param serial: checkpoint serial number
    """
    folder_name = CHECKPOINT_SEPARATOR.join([CHECKPOINT_PREFIX, str(serial)])
    serial_dir = os.path.join(dirname, folder_name)

    if os.path.isdir(serial_dir):
        return serial_dir
    os.makedirs(serial_dir)
    return serial_dir


def _get_model_dir(dirname):
    """Return `<dirname>/` + MODEL_DIR, creating it if it does not exist."""
    model_dir = os.path.join(dirname, MODEL_DIR)
    if os.path.isdir(model_dir):
        return model_dir
    os.makedirs(model_dir)
    return model_dir


def _get_trainer_dir(dirname, trainer_id):
    """
    Return `<dirname>/trainer_<trainer_id>`, creating it if it does not exist.

    :param dirname: checkpoint serial directory
    :param trainer_id: id of the trainer
    """
    folder_name = CHECKPOINT_SEPARATOR.join([TRAINER_PREFIX, str(trainer_id)])
    trainer_dir = os.path.join(dirname, folder_name)

    if os.path.isdir(trainer_dir):
        return trainer_dir
    os.makedirs(trainer_dir)
    return trainer_dir
T
tangwei12 已提交
693 694


T
tangwei12 已提交
695
def _lru_delete(dirname, max_num_checkpoints=3):
T
tangwei12 已提交
696
    dirs = os.listdir(dirname)
T
tangwei12 已提交
697
    serial_map = {}
T
tangwei12 已提交
698
    for serial in dirs:
T
tangwei12 已提交
699 700
        serial_num = _get_dir_serial(serial)
        serial_map[serial_num] = serial
T
tangwei12 已提交
701

T
tangwei12 已提交
702
    if len(serial_map.keys()) <= max_num_checkpoints:
T
tangwei12 已提交
703 704
        return

T
tangwei12 已提交
705
    serials = serial_map.keys()
T
tangwei12 已提交
706
    serials.sort(reverse=True)
T
tangwei12 已提交
707
    serials = serials[max_num_checkpoints:]
T
tangwei12 已提交
708
    for serial in serials:
T
tangwei12 已提交
709
        cur_dir = _get_serial_dir(dirname, serial)
T
tangwei12 已提交
710 711 712
        shutil.rmtree(cur_dir)


T
tangwei12 已提交
713 714
def _write_success(dirname):
    """
    Mark `dirname` as complete via its "_SUCCESS" file.

    The current time is appended to the file, which is created if it does not
    exist yet. Readers such as _get_lastest_checkpoint_dir only check that
    the file exists; its content is informational.

    :param dirname: directory to mark
    """
    success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME)
    with open(success_file, 'a') as f:
        # One timestamp per line so repeated marks stay readable.
        f.write(time.ctime() + "\n")
T
tangwei12 已提交
723 724 725 726


def _get_lastest_checkpoint_dir(checkpoint_dir):
    """
T
tangwei12 已提交
727 728 729
    get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory

    :param checkpoint_dir
T
tangwei12 已提交
730 731
    """
    if not checkpoint_dir.strip():
T
tangwei12 已提交
732
        return -1
T
tangwei12 已提交
733 734 735 736 737 738

    def has_success(checkpoint_dir, cur_dir):
        """
        is _SUCCESS in this dir
        """

T
tangwei12 已提交
739 740
        serial = _get_dir_serial(cur_dir)
        if serial == -1:
T
tangwei12 已提交
741 742
            return -1

743 744 745 746
        if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
            return -1

        success_path = os.path.join(
T
tangwei12 已提交
747 748
            _get_serial_dir(checkpoint_dir, serial), MODEL_DIR,
            SUCCESS_MARK_FILENAME)
T
tangwei12 已提交
749
        if os.path.isfile(success_path):
T
tangwei12 已提交
750
            return serial
T
tangwei12 已提交
751 752 753 754 755 756 757 758 759 760 761

    if not os.path.isdir(checkpoint_dir):
        return -1

    current_dir = -1
    dirs = os.listdir(checkpoint_dir)
    for cur_dir in dirs:
        success_num = has_success(checkpoint_dir, cur_dir)
        if success_num > current_dir:
            current_dir = success_num
    return current_dir