io.py 20.8 KB
Newer Older
1
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
D
dzhwinter 已提交
2
#
D
dzhwinter 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
D
dzhwinter 已提交
6
#
D
dzhwinter 已提交
7
#     http://www.apache.org/licenses/LICENSE-2.0
D
dzhwinter 已提交
8
#
D
dzhwinter 已提交
9 10 11 12 13 14
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os
T
tangwei12 已提交
16 17
import time
import shutil
18

19 20
from paddle.fluid.evaluator import Evaluator
from paddle.fluid.framework import Program, Parameter, default_main_program, Variable
K
fix bug  
Kexin Zhao 已提交
21
from . import core
22 23

# Names exported by ``from paddle.fluid.io import *``.
__all__ = [
    'save_vars',
    'save_params',
    'save_persistables',
    'load_vars',
    'load_params',
    'load_persistables',
    'save_inference_model',
    'load_inference_model',
    'get_inference_program',
    'save_checkpoint',
    'load_checkpoint',
    'clean_checkpoint',
    'load_persist_vars_without_grad',
    'save_persist_vars_without_grad',
]


def is_parameter(var):
    """Tell whether ``var`` is a Parameter.

    :param var: the object to inspect.

    :return: True when ``var`` is an instance of ``Parameter``, else False.
    """
    result = isinstance(var, Parameter)
    return result


def is_persistable(var):
    """Tell whether ``var`` should be treated as a persistable variable.

    Feed (FEED_MINIBATCH) and fetch (FETCH_LIST) holder variables are never
    considered persistable here, whatever their ``persistable`` flag says.

    :param var: the variable to inspect.

    :return: False for feed/fetch holders, otherwise ``var.persistable``.
    """
    var_types = core.VarDesc.VarType
    if var.desc.type() == var_types.FEED_MINIBATCH:
        return False
    if var.desc.type() == var_types.FETCH_LIST:
        return False
    return var.persistable


def _clone_var_in_block_(block, var):
    """Create in ``block`` a persistable variable mirroring ``var``.

    The clone copies ``name``, ``shape``, ``dtype``, ``type`` and
    ``lod_level`` from ``var`` and is always created with ``persistable=True``.

    :param block: the block in which the clone is created.
    :param var: the Variable to mirror.

    :return: whatever ``block.create_var`` returns for the clone.
    """
    assert isinstance(var, Variable)
    clone_spec = dict(
        name=var.name,
        shape=var.shape,
        dtype=var.dtype,
        type=var.type,
        lod_level=var.lod_level,
        persistable=True)
    return block.create_var(**clone_spec)


64 65 66 67 68
def save_vars(executor,
              dirname,
              main_program=None,
              vars=None,
              predicate=None,
              filename=None):
    """
    Save variables to a directory by running a save program on the executor.

    :param executor: executor that runs the generated save program
    :param dirname: directory path
    :param main_program: program whose variables are filtered with `predicate`
        when `vars` is None. Default default_main_program().
    :param vars: variables to save. If given, `main_program` and `predicate`
        are ignored.
    :param predicate: callable taking a variable and returning a bool; only
        variables for which it returns true are saved. Used only when `vars`
        is None.
    :param filename: name of a single file that all variables are saved to.
        If it is None, every variable is saved to its own file named after
        the variable.

    :return: None
    :raises TypeError: if `vars` is None and `main_program` is not a Program.
    """
    if vars is None:
        # Resolve the variable list from the program, then re-enter with it.
        if main_program is None:
            main_program = default_main_program()
        if not isinstance(main_program, Program):
            raise TypeError("program should be as Program type or None")

        save_vars(
            executor,
            dirname=dirname,
            vars=filter(predicate, main_program.list_vars()),
            filename=filename)
        return

    one_file_per_var = filename is None
    prog = Program()
    block = prog.global_block()
    combined = {}

    for src_var in vars:
        # NOTE: don't save the variable which type is RAW
        if src_var.type == core.VarDesc.VarType.RAW:
            continue
        cloned = _clone_var_in_block_(block, src_var)
        if one_file_per_var:
            block.append_op(
                type='save',
                inputs={'X': [cloned]},
                outputs={},
                attrs={'file_path': os.path.join(dirname, cloned.name)})
        else:
            combined[cloned.name] = cloned

    if not one_file_per_var:
        # The combined file stores the variables ordered by name.
        ordered = [combined[key] for key in sorted(combined.keys())]
        block.append_op(
            type='save_combine',
            inputs={'X': ordered},
            outputs={},
            attrs={'file_path': os.path.join(dirname, filename)})

    executor.run(prog)


130
def save_params(executor, dirname, main_program=None, filename=None):
    """
    Save every Parameter of a program to a directory.

    Delegates to save_vars with `is_parameter` as the predicate.

    :param executor: executor that runs the save program
    :param dirname: directory path
    :param main_program: program whose parameters are saved.
        Default default_main_program().
    :param filename: name of a single file holding all parameters.
        If it is None, parameters are saved to separate files.

    :return: None
    """
    save_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        predicate=is_parameter,
        filename=filename)
141 142


143
def save_persistables(executor, dirname, main_program=None, filename=None):
    """
    Save every persistable variable of a program to a directory.

    Delegates to save_vars with `is_persistable` as the predicate.

    :param executor: executor that runs the save program
    :param dirname: directory path
    :param main_program: program whose persistables are saved.
        Default default_main_program().
    :param filename: name of a single file holding all variables.
        If it is None, variables are saved to separate files.

    :return: None
    """
    save_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        predicate=is_persistable,
        filename=filename)
154 155


156 157 158 159 160
def load_vars(executor,
              dirname,
              main_program=None,
              vars=None,
              predicate=None,
              filename=None):
    """
    Load variables from a directory by running a load program on the executor.

    :param executor: executor that runs the generated load program
    :param dirname: directory path
    :param main_program: program whose variables are filtered with `predicate`
        when `vars` is None. Default default_main_program().
    :param vars: variables to load. If given, `main_program` and `predicate`
        are ignored.
    :param predicate: callable taking a variable and returning a bool; only
        variables for which it returns true are loaded. Used only when `vars`
        is None.
    :param filename: name of the single file that all variables are loaded
        from. If it is None, every variable is loaded from its own file named
        after the variable.

    :return: None
    :raises TypeError: if `vars` is None and `main_program` is not a Program.
    """
    if vars is None:
        # Resolve the variable list from the program, then re-enter with it.
        if main_program is None:
            main_program = default_main_program()
        if not isinstance(main_program, Program):
            raise TypeError("program's type should be Program")

        load_vars(
            executor,
            dirname=dirname,
            vars=filter(predicate, main_program.list_vars()),
            filename=filename)
        return

    one_file_per_var = filename is None
    prog = Program()
    block = prog.global_block()
    combined = {}

    for src_var in vars:
        assert isinstance(src_var, Variable)
        # RAW variables are skipped, mirroring save_vars.
        if src_var.type == core.VarDesc.VarType.RAW:
            continue
        cloned = _clone_var_in_block_(block, src_var)
        if one_file_per_var:
            block.append_op(
                type='load',
                inputs={},
                outputs={'Out': [cloned]},
                attrs={'file_path': os.path.join(dirname, cloned.name)})
        else:
            combined[cloned.name] = cloned

    if not one_file_per_var:
        # Variables are read from the combined file ordered by name.
        ordered = [combined[key] for key in sorted(combined.keys())]
        block.append_op(
            type='load_combine',
            inputs={},
            outputs={"Out": ordered},
            attrs={'file_path': os.path.join(dirname, filename)})

    executor.run(prog)


222
def load_params(executor, dirname, main_program=None, filename=None):
    """
    Load every Parameter of a program from a directory.

    Delegates to load_vars with `is_parameter` as the predicate.

    :param executor: executor that runs the load program
    :param dirname: directory path
    :param main_program: program whose parameters are loaded.
        Default default_main_program().
    :param filename: name of the single file holding all parameters.
        If it is None, parameters are loaded from separate files.

    :return: None
    """
    load_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        vars=None,
        predicate=is_parameter,
        filename=filename)
232 233


234
def load_persistables(executor, dirname, main_program=None, filename=None):
    """
    Load every persistable variable of a program from a directory.

    Delegates to load_vars with `is_persistable` as the predicate.

    :param executor: executor that runs the load program
    :param dirname: directory path
    :param main_program: program whose persistables are loaded.
        Default default_main_program().
    :param filename: name of the single file holding all variables.
        If it is None, variables are loaded from separate files.

    :return: None
    """
    load_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        vars=None,
        predicate=is_persistable,
        filename=filename)
244 245


246 247
def get_inference_program(target_vars, main_program=None):
    """
    Build an inference program by pruning a program down to some targets.

    :param target_vars: a target or a list of targets. An Evaluator target
        contributes its `states` followed by its `metrics`; any other target
        is used as is.
    :param main_program: program to prune. Default default_main_program().

    :return: the pruned program after `inference_optimize()`.
    """
    if main_program is None:
        main_program = default_main_program()

    targets = target_vars if isinstance(target_vars, list) else [target_vars]

    prune_targets = []
    for target in targets:
        if not isinstance(target, Evaluator):
            prune_targets.append(target)
            continue
        prune_targets.extend(target.states)
        prune_targets.extend(target.metrics)

    return main_program.prune(targets=prune_targets).inference_optimize()


263 264 265
def prepend_feed_ops(inference_program,
                     feed_target_names,
                     feed_holder_name='feed'):
    """
    Prepend one `feed` op per feed target to the program's global block.

    Nothing is done when `feed_target_names` is empty. Otherwise a persistable
    FEED_MINIBATCH holder variable is created and, for the i-th name, a `feed`
    op reading column i of the holder into that variable is prepended.

    :param inference_program: program to modify in place
    :param feed_target_names: names of the variables to feed
    :param feed_holder_name: name of the feed holder variable

    :return: None
    """
    if len(feed_target_names) == 0:
        return

    block = inference_program.global_block()
    holder = block.create_var(
        name=feed_holder_name,
        type=core.VarDesc.VarType.FEED_MINIBATCH,
        persistable=True)

    for col, target_name in enumerate(feed_target_names):
        target = block.var(target_name)
        block.prepend_op(
            type='feed',
            inputs={'X': [holder]},
            outputs={'Out': [target]},
            attrs={'col': col})


284 285 286
def append_fetch_ops(inference_program,
                     fetch_target_names,
                     fetch_holder_name='fetch'):
    """
    Append one `fetch` op per fetch target to the program's global block.

    A persistable FETCH_LIST holder variable is created first; for the i-th
    name a `fetch` op writing that variable into column i of the holder is
    appended.

    :param inference_program: program to modify in place
    :param fetch_target_names: names of the variables to fetch
    :param fetch_holder_name: name of the fetch holder variable

    :return: None
    """
    block = inference_program.global_block()
    holder = block.create_var(
        name=fetch_holder_name,
        type=core.VarDesc.VarType.FETCH_LIST,
        persistable=True)

    for col, target_name in enumerate(fetch_target_names):
        block.append_op(
            type='fetch',
            inputs={'X': [target_name]},
            outputs={'Out': [holder]},
            attrs={'col': col})


301 302 303 304
def save_inference_model(dirname,
                         feeded_var_names,
                         target_vars,
                         executor,
                         main_program=None,
                         model_filename=None,
                         params_filename=None):
    """
    Build a model especially for inference,
    and save it to directory by the executor.

    :param dirname: directory path; created if it does not exist
    :param feeded_var_names: a name or a list of names of the variables that
        need to be fed data during inference
    :param target_vars: a Variable or a non-empty list of Variables from which
        we can get inference results.
    :param executor: executor that save inference model
    :param main_program: original program, which will be pruned to build the
        inference model. Default default_main_program().
    :param model_filename: The name of file to save inference program.
        Only its base name is used. If not specified, default filename
        `__model__` will be used.
    :param params_filename: The name of file to save parameters.
        Only its base name is used. It is used for the case that all
        parameters are saved in a single binary file. If not specified,
        parameters are considered saved in separate files.

    :return: None
    :raises ValueError: if `feeded_var_names` is not made of strings or
        `target_vars` is not a non-empty list of Variable.
    """
    if isinstance(feeded_var_names, basestring):
        feeded_var_names = [feeded_var_names]
    elif len(feeded_var_names) > 0:
        # An empty list is allowed (no feed ops are generated).
        if not all(isinstance(name, basestring) for name in feeded_var_names):
            raise ValueError("'feed_var_names' should be a list of str.")

    if isinstance(target_vars, Variable):
        target_vars = [target_vars]
    elif not (bool(target_vars) and all(
            isinstance(var, Variable) for var in target_vars)):
        raise ValueError("'target_vars' should be a list of Variable.")

    if main_program is None:
        main_program = default_main_program()
    copy_program = main_program.clone()

    if not os.path.isdir(dirname):
        os.makedirs(dirname)

    # Clear the is_target information and remove the existed feed and fetch op.
    # Walk the indices backwards: removing op i then never shifts the position
    # of an op that has not been visited yet. Iterating forwards while removing
    # skipped the op following every removed one.
    # NOTE(review): assumes remove_op(i) drops the i-th entry of
    # global_block.ops -- confirm against Block.remove_op.
    global_block = copy_program.global_block()
    for i in reversed(range(len(global_block.ops))):
        op = global_block.ops[i]
        op.desc.set_is_target(False)
        if op.type == "feed" or op.type == "fetch":
            global_block.remove_op(i)
    copy_program.desc.flush()

    pruned_program = copy_program.prune(targets=target_vars)
    inference_program = pruned_program.inference_optimize()
    fetch_var_names = [v.name for v in target_vars]

    prepend_feed_ops(inference_program, feeded_var_names)
    append_fetch_ops(inference_program, fetch_var_names)

    # Only the base name is kept so the model always lands inside dirname.
    if model_filename is not None:
        model_filename = os.path.basename(model_filename)
    else:
        model_filename = "__model__"
    model_filename = os.path.join(dirname, model_filename)

    if params_filename is not None:
        params_filename = os.path.basename(params_filename)

    with open(model_filename, "wb") as f:
        f.write(inference_program.desc.serialize_to_string())

    save_persistables(executor, dirname, inference_program, params_filename)
376 377


378 379 380 381
def load_inference_model(dirname,
                         executor,
                         model_filename=None,
                         params_filename=None):
    """
    Load inference model from a directory

    :param dirname: directory path
    :param executor: executor that load inference model
    :param model_filename: The name of file to load inference program.
        Only its base name is used. If not specified, default filename
        `__model__` will be used.
    :param params_filename: The name of file to load parameters.
        Only its base name is used. It is used for the case that all
        parameters are saved in a single binary file. If not specified,
        parameters are considered saved in separate files.

    :return: [program, feed_target_names, fetch_targets]
             program: program especially for inference.
             feed_target_names: Names of variables that need to feed data
             fetch_targets: Variables from which we can get inference results.
    :raises ValueError: if `dirname` is not an existing directory.
    """
    if not os.path.isdir(dirname):
        # Format the name into the message; passing it as a second argument
        # left the literal "%s" in the error text.
        raise ValueError("There is no directory named '%s'" % dirname)

    if model_filename is not None:
        model_filename = os.path.basename(model_filename)
    else:
        model_filename = "__model__"
    model_filename = os.path.join(dirname, model_filename)

    if params_filename is not None:
        params_filename = os.path.basename(params_filename)

    with open(model_filename, "rb") as f:
        program_desc_str = f.read()

    program = Program.parse_from_string(program_desc_str)
    load_persistables(executor, dirname, program, params_filename)

    feed_target_names = program.desc.get_feed_target_names()
    fetch_target_names = program.desc.get_fetch_target_names()
    fetch_targets = [
        program.global_block().var(name) for name in fetch_target_names
    ]

    return [program, feed_target_names, fetch_targets]
X
xuwei06 已提交
423 424 425 426 427 428 429 430


def get_parameter_value(para, executor):
    """
    Fetch the current value of a parameter through the executor.

    A fresh program holding only a clone of `para` is run with an empty feed
    and the clone as the single fetch target.

    :param para: the given parameter; must satisfy is_parameter
    :param executor: executor for retrieving the value

    :return: the first (and only) fetched result for the parameter
    """
    assert is_parameter(para)

    fetch_program = Program()
    fetch_target = _clone_var_in_block_(fetch_program.global_block(), para)
    fetched = executor.run(fetch_program, feed={}, fetch_list=[fetch_target])
    return fetched[0]


def get_parameter_value_by_name(name, executor, program=None):
    """
    Fetch the current value of the parameter with the given name.

    :param name: the name of the parameter
    :param executor: executor for retrieving the value
    :param program: the program in whose global block the variable is looked
        up. Default default_main_program().

    :return: the result of get_parameter_value for the named variable
    """
    source_program = default_main_program() if program is None else program
    parameter = source_program.global_block().var(name)
    return get_parameter_value(parameter, executor)
T
tangwei12 已提交
457 458


459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485
def load_persist_vars_without_grad(executor, dirname, program):
    """
    Load the checkpoint variables of `program` from `dirname`, one file per
    variable.

    Variables are selected with `_is_checkpoint_var`, so variables whose name
    ends with "@GRAD" are not loaded.

    :param executor: executor that runs the load program
    :param dirname: directory path
    :param program: program whose variables are loaded

    :return: None
    """
    load_vars(
        executor,
        dirname=dirname,
        main_program=program,
        vars=None,
        predicate=_is_checkpoint_var,
        filename=None)


def save_persist_vars_without_grad(executor, dirname, program):
    """
    Save the checkpoint variables of `program` to `dirname`, one file per
    variable.

    Variables are selected with `_is_checkpoint_var`, so variables whose name
    ends with "@GRAD" are not saved.

    :param executor: executor that runs the save program
    :param dirname: directory path
    :param program: program whose variables are saved

    :return: None
    """
    save_vars(
        executor,
        dirname=dirname,
        main_program=program,
        predicate=_is_checkpoint_var,
        filename=None)


T
tangwei12 已提交
486
# Name of the marker file that _write_success creates inside a checkpoint
# directory; _get_lastest_checkpoint_dir only accepts directories containing it.
SUCCESS_MARK_FILENAME = "_SUCCESS"
487 488
# A checkpoint sub-directory is named
# CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + <serial>, e.g. "checkpoint_0"
# (see _get_serial_dir).
CHECKPOINT_PREFIX = "checkpoint"
CHECKPOINT_SEPARATOR = "_"
T
tangwei12 已提交
489 490 491


def save_checkpoint(executor,
                    checkpoint_dir=None,
                    max_num_checkpoints=3,
                    main_program=None):
    """
    Save a new checkpoint of main_program under the checkpoint directory.

    The checkpoint is written to a sub-directory whose serial number is one
    more than the latest successful checkpoint found. A success marker file is
    written into it afterwards, and then the checkpoints with the smallest
    serial numbers are deleted so that at most max_num_checkpoints remain.

    :param executor: executor that runs the save program
    :param checkpoint_dir: directory holding the checkpoints; created if
        missing. Default the current working directory.
    :param max_num_checkpoints: maximum number of checkpoints to keep
    :param main_program: program whose variables are saved; passed through to
        save_persist_vars_without_grad

    :return: None
    """
    root_dir = os.getcwd() if checkpoint_dir is None else checkpoint_dir

    if not os.path.isdir(root_dir):
        os.makedirs(root_dir)

    next_serial = _get_lastest_checkpoint_dir(root_dir) + 1
    serial_dir = _get_serial_dir(root_dir, next_serial)

    save_persist_vars_without_grad(executor, serial_dir, main_program)
    # The marker is written last, after the variables have been saved.
    _write_success(serial_dir)
    _lru_delete(root_dir, max_num_checkpoints)
T
tangwei12 已提交
519 520


521
def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
    """
    Load the most recent successful checkpoint from a directory.

    Does nothing when no successful checkpoint is found.

    :param executor: executor that runs the load program
    :param checkpoint_dir: directory holding the checkpoints.
        Default the current working directory.
    :param main_program: program whose variables are loaded; passed through
        to load_persist_vars_without_grad

    :return: None
    """
    root_dir = os.getcwd() if checkpoint_dir is None else checkpoint_dir

    latest_serial = _get_lastest_checkpoint_dir(root_dir)
    if not latest_serial < 0:
        serial_dir = _get_serial_dir(root_dir, latest_serial)
        load_persist_vars_without_grad(executor, serial_dir, main_program)
T
tangwei12 已提交
541 542


T
tangwei12 已提交
543 544 545 546 547 548 549 550 551 552 553 554 555
def clean_checkpoint(checkpoint_dir, delete_dir=False):
    """
    Delete the checkpoints saved under a checkpoint directory.

    Intended to be called when training exits normally. Checkpoints are
    removed by calling _lru_delete with a limit of zero.

    :param checkpoint_dir: directory holding the checkpoints.
        Default the current working directory.
    :param delete_dir: when true, also remove checkpoint_dir itself, but only
        if it is empty after the checkpoints are deleted; a non-empty
        directory is left in place.

    :return: None
    """
    if checkpoint_dir is None:
        checkpoint_dir = os.getcwd()
    _lru_delete(checkpoint_dir, max_num_checkpoints=0)

    if delete_dir:
        leftovers = os.listdir(checkpoint_dir)
        if not leftovers:
            os.rmdir(checkpoint_dir)


T
tangwei12 已提交
556
def _is_checkpoint_var(var):
    """
    Tell whether a variable takes part in checkpoint save/load.

    Variables of type FEED_MINIBATCH, FETCH_LIST or RAW, and variables whose
    name ends with "@GRAD", are excluded.

    :param var: the variable to inspect

    :return: False for excluded variables, otherwise ``var.persistable``
    """
    var_types = core.VarDesc.VarType
    if var.desc.type() == var_types.FEED_MINIBATCH:
        return False
    if var.desc.type() == var_types.FETCH_LIST:
        return False
    if var.desc.type() == var_types.RAW:
        return False

    is_gradient = var.name.endswith("@GRAD")
    return False if is_gradient else var.persistable
T
tangwei12 已提交
572 573


T
tangwei12 已提交
574 575 576 577 578 579 580 581 582 583 584 585 586 587
def _get_dir_serial(dirname):
    """
    Extract the serial number from a checkpoint directory name.

    A valid name is CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + <serial>,
    e.g. "checkpoint_3". Any other name yields -1 instead of raising, because
    the checkpoint directory may contain arbitrary unrelated entries.

    :param dirname: base name of a directory entry (no path components)

    :return: the non-negative serial number, or -1 if dirname is not a
        checkpoint directory name
    """
    # rpartition never raises, unlike unpacking the result of split(), which
    # failed for names with zero or more than one separator.
    prefix, separator, serial = dirname.rpartition(CHECKPOINT_SEPARATOR)
    if not separator or prefix != CHECKPOINT_PREFIX:
        return -1

    try:
        serial_num = int(serial)
    except ValueError:
        return -1
    # Negative numbers are not valid serials; -1 is reserved for "invalid".
    return serial_num if serial_num >= 0 else -1


def _get_serial_dir(dirname, serial):
    """
    Build the path of the checkpoint sub-directory for a serial number.

    :param dirname: the checkpoint root directory
    :param serial: the serial number of the checkpoint

    :return: dirname joined with "<CHECKPOINT_PREFIX><CHECKPOINT_SEPARATOR><serial>"
    """
    name_parts = [CHECKPOINT_PREFIX, CHECKPOINT_SEPARATOR, str(serial)]
    return os.path.join(dirname, "".join(name_parts))
T
tangwei12 已提交
588 589


T
tangwei12 已提交
590
def _lru_delete(dirname, max_num_checkpoints=3):
    """
    Delete old checkpoints so that at most max_num_checkpoints remain.

    Despite the name, selection is by serial number: the checkpoints with the
    largest serial numbers are kept and the rest are removed recursively.
    Entries of dirname that are not checkpoint directories are ignored; they
    are neither counted nor deleted.

    :param dirname: the checkpoint root directory
    :param max_num_checkpoints: number of checkpoints to keep

    :return: None
    """
    serials = set()
    for entry in os.listdir(dirname):
        serial_num = _get_dir_serial(entry)
        # -1 marks a non-checkpoint entry; counting it used to inflate the
        # total and could lead to removing a non-existent "checkpoint_-1".
        if serial_num < 0:
            continue
        if os.path.isdir(_get_serial_dir(dirname, serial_num)):
            serials.add(serial_num)

    if len(serials) <= max_num_checkpoints:
        return

    # sorted() works on Python 2 and 3, unlike dict.keys().sort().
    newest_first = sorted(serials, reverse=True)
    for serial_num in newest_first[max_num_checkpoints:]:
        shutil.rmtree(_get_serial_dir(dirname, serial_num))


T
tangwei12 已提交
608 609
def _write_success(dirname):
    """
    Mark a checkpoint directory as successfully written.

    Appends the current time (``time.ctime()``) to the file named
    SUCCESS_MARK_FILENAME inside dirname, creating the file if needed.

    :param dirname: the checkpoint directory to mark

    :return: None
    """
    marker_path = os.path.join(dirname, SUCCESS_MARK_FILENAME)
    with open(marker_path, 'a') as marker:
        marker.write(time.ctime())
T
tangwei12 已提交
618 619 620 621


def _get_lastest_checkpoint_dir(checkpoint_dir):
    """
    Find the serial number of the latest successful checkpoint.

    A checkpoint counts as successful when its directory exists and contains
    the SUCCESS_MARK_FILENAME marker file.

    :param checkpoint_dir: the checkpoint root directory

    :return: the largest serial number among successful checkpoints, or -1 if
        checkpoint_dir is blank, is not a directory, or holds no successful
        checkpoint
    """
    if not checkpoint_dir.strip():
        return -1

    def has_success(checkpoint_dir, cur_dir):
        """
        Return the serial of cur_dir if it is a successful checkpoint, else -1.
        """
        serial = _get_dir_serial(cur_dir)
        if serial == -1:
            return -1

        if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
            return -1

        success_path = os.path.join(
            _get_serial_dir(checkpoint_dir, serial), SUCCESS_MARK_FILENAME)
        if os.path.isfile(success_path):
            return serial
        # Always return an int: falling through returned None, and comparing
        # None with an int raises TypeError on Python 3.
        return -1

    if not os.path.isdir(checkpoint_dir):
        return -1

    current_dir = -1
    for cur_dir in os.listdir(checkpoint_dir):
        success_num = has_success(checkpoint_dir, cur_dir)
        if success_num > current_dir:
            current_dir = success_num
    return current_dir