io.py 21.0 KB
Newer Older
1
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
D
dzhwinter 已提交
2
#
D
dzhwinter 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
D
dzhwinter 已提交
6
#
D
dzhwinter 已提交
7
#     http://www.apache.org/licenses/LICENSE-2.0
D
dzhwinter 已提交
8
#
D
dzhwinter 已提交
9 10 11 12 13 14
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os
T
tangwei12 已提交
16 17
import time
import shutil
18

19 20
from paddle.fluid.evaluator import Evaluator
from paddle.fluid.framework import Program, Parameter, default_main_program, Variable
K
fix bug  
Kexin Zhao 已提交
21
from . import core
22 23

# Names exported via ``from ... import *``; keep in sync with the public
# save/load/checkpoint helpers defined below.
__all__ = [
    'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
    'load_persistables', 'save_inference_model', 'load_inference_model',
    'get_inference_program', 'save_checkpoint', 'load_checkpoint',
    'clean_checkpoint', 'load_persist_vars_without_grad',
    'save_persist_vars_without_grad'
]


def is_parameter(var):
    """Tell whether ``var`` is a ``Parameter``.

    Args:
        var: The variable to inspect.

    Returns:
        True when ``var`` is an instance of ``Parameter``, False otherwise.
    """
    return isinstance(var, Parameter)


def is_persistable(var):
    """Return True when ``var`` should be persisted by save/load helpers.

    Feed and fetch holder variables are infrastructure, never user state,
    so they are excluded regardless of their ``persistable`` flag.
    """
    var_type = var.desc.type()
    if var_type in (core.VarDesc.VarType.FEED_MINIBATCH,
                    core.VarDesc.VarType.FETCH_LIST):
        return False
    return var.persistable


def _clone_var_in_block_(block, var):
    """Create a variable in ``block`` mirroring ``var``'s metadata.

    The clone copies name/shape/dtype/type/lod_level and is always marked
    persistable so the generated save/load side-programs can reference it.
    """
    assert isinstance(var, Variable)
    return block.create_var(
        name=var.name,
        shape=var.shape,
        dtype=var.dtype,
        type=var.type,
        lod_level=var.lod_level,
        persistable=True)


64 65 66 67 68
def save_vars(executor,
              dirname,
              main_program=None,
              vars=None,
              predicate=None,
              filename=None):
    """
    Save variables to directory by executor.

    :param executor: executor that save variable
    :param dirname: directory path
    :param main_program: program. If vars is None, then filter all variables in this
    program which fit `predicate`. Default default_main_program.
    :param predicate: The Predicate describes a callable that returns a variable
    as a bool. If it returns true, the corresponding input variable will be saved.
    :param vars: variables need to be saved. If vars is specified, program & predicate
    will be ignored
    :param filename: The name of a single file that all vars are saved to.
        If it is None, save variables to separate files.

    :return: None
    """
    if vars is None:
        # No explicit variable list: select variables from the program with
        # `predicate` and recurse into the `vars is not None` branch below.
        if main_program is None:
            main_program = default_main_program()
        if not isinstance(main_program, Program):
            raise TypeError("program should be as Program type or None")

        save_vars(
            executor,
            dirname=dirname,
            vars=filter(predicate, main_program.list_vars()),
            filename=filename)
    else:
        # Build a small side program made only of save ops, then run it.
        save_program = Program()
        save_block = save_program.global_block()

        save_var_map = {}
        for each_var in vars:
            # NOTE: don't save the variable which type is RAW
            if each_var.type == core.VarDesc.VarType.RAW:
                continue
            new_var = _clone_var_in_block_(save_block, each_var)
            if filename is None:
                # One file per variable, named after the variable itself.
                save_block.append_op(
                    type='save',
                    inputs={'X': [new_var]},
                    outputs={},
                    attrs={'file_path': os.path.join(dirname, new_var.name)})
            else:
                save_var_map[new_var.name] = new_var

        if filename is not None:
            # All variables go into a single file; sorting by name makes the
            # on-disk order deterministic and consistent with load_vars.
            save_var_list = []
            for name in sorted(save_var_map.keys()):
                save_var_list.append(save_var_map[name])

            save_block.append_op(
                type='save_combine',
                inputs={'X': save_var_list},
                outputs={},
                attrs={'file_path': os.path.join(dirname, filename)})

        executor.run(save_program)


130
def save_params(executor, dirname, main_program=None, filename=None):
    """
    Persist every Parameter of ``main_program`` under ``dirname``.

    Thin wrapper around :func:`save_vars` using :func:`is_parameter`
    as the selection predicate.
    """
    save_vars(
        executor,
        dirname=dirname,
        vars=None,
        main_program=main_program,
        filename=filename,
        predicate=is_parameter)
141 142


143
def save_persistables(executor, dirname, main_program=None, filename=None):
    """
    Persist every persistable variable of ``main_program`` under ``dirname``.

    Thin wrapper around :func:`save_vars` using :func:`is_persistable`
    as the selection predicate.
    """
    save_vars(
        executor,
        dirname=dirname,
        vars=None,
        main_program=main_program,
        filename=filename,
        predicate=is_persistable)
154 155


156 157 158 159 160
def load_vars(executor,
              dirname,
              main_program=None,
              vars=None,
              predicate=None,
              filename=None):
    """
    Load variables from directory by executor.

    :param executor: executor that load variable
    :param dirname: directory path
    :param main_program: program. If vars is None, then filter all variables in this
    program which fit `predicate`. Default default_main_program().
    :param predicate: The Predicate describes a callable that returns a variable
    as a bool. If it returns true, the corresponding input variable will be loaded.
    :param vars: variables need to be loaded. If vars is specified, program &
    predicate will be ignored
    :param filename: The name of the single file that all vars are loaded from.
        If it is None, load variables from separate files.

    :return: None
    """
    if vars is None:
        # No explicit variable list: select variables from the program with
        # `predicate` and recurse into the `vars is not None` branch below.
        if main_program is None:
            main_program = default_main_program()
        if not isinstance(main_program, Program):
            raise TypeError("program's type should be Program")

        load_vars(
            executor,
            dirname=dirname,
            vars=filter(predicate, main_program.list_vars()),
            filename=filename)
    else:
        # Build a small side program made only of load ops, then run it.
        load_prog = Program()
        load_block = load_prog.global_block()

        load_var_map = {}
        for each_var in vars:
            assert isinstance(each_var, Variable)
            # RAW variables carry no tensor data, so there is nothing to load.
            if each_var.type == core.VarDesc.VarType.RAW:
                continue
            new_var = _clone_var_in_block_(load_block, each_var)
            if filename is None:
                # One file per variable, named after the variable itself.
                load_block.append_op(
                    type='load',
                    inputs={},
                    outputs={'Out': [new_var]},
                    attrs={'file_path': os.path.join(dirname, new_var.name)})
            else:
                load_var_map[new_var.name] = new_var

        if filename is not None:
            # All variables come from a single file; sorting by name must
            # match the order used by save_vars' save_combine op.
            load_var_list = []
            for name in sorted(load_var_map.keys()):
                load_var_list.append(load_var_map[name])

            load_block.append_op(
                type='load_combine',
                inputs={},
                outputs={"Out": load_var_list},
                attrs={'file_path': os.path.join(dirname, filename)})

        executor.run(load_prog)


222
def load_params(executor, dirname, main_program=None, filename=None):
    """
    Restore every Parameter of ``main_program`` from ``dirname``.

    Thin wrapper around :func:`load_vars` using :func:`is_parameter`
    as the selection predicate.
    """
    load_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        filename=filename,
        predicate=is_parameter)
232 233


234
def load_persistables(executor, dirname, main_program=None, filename=None):
    """
    Restore every persistable variable of ``main_program`` from ``dirname``.

    Thin wrapper around :func:`load_vars` using :func:`is_persistable`
    as the selection predicate.
    """
    load_vars(
        executor,
        dirname=dirname,
        main_program=main_program,
        filename=filename,
        predicate=is_persistable)
244 245


246 247
def get_inference_program(target_vars, main_program=None):
    """Prune ``main_program`` down to an inference-only program.

    ``target_vars`` may be a single variable/Evaluator or a list of them;
    Evaluators contribute their state and metric variables as prune targets.
    Returns the pruned, inference-optimized program.
    """
    if main_program is None:
        main_program = default_main_program()
    if not isinstance(target_vars, list):
        target_vars = [target_vars]

    prune_targets = []
    for target in target_vars:
        if isinstance(target, Evaluator):
            # An Evaluator is a bundle of variables, not a variable itself.
            prune_targets.extend(target.states)
            prune_targets.extend(target.metrics)
        else:
            prune_targets.append(target)

    pruned = main_program.prune(targets=prune_targets)
    return pruned.inference_optimize()


263 264 265
def prepend_feed_ops(inference_program,
                     feed_target_names,
                     feed_holder_name='feed'):
    """Prepend one ``feed`` op per feed target to ``inference_program``.

    All feed ops read from a single FEED_MINIBATCH holder variable named
    ``feed_holder_name``; column ``i`` of the holder feeds the variable
    named ``feed_target_names[i]``. No-op when there are no feed targets.
    """
    if len(feed_target_names) == 0:
        return

    global_block = inference_program.global_block()
    feed_var = global_block.create_var(
        name=feed_holder_name,
        type=core.VarDesc.VarType.FEED_MINIBATCH,
        persistable=True)

    for i, name in enumerate(feed_target_names):
        out = global_block.var(name)
        global_block.prepend_op(
            type='feed',
            inputs={'X': [feed_var]},
            outputs={'Out': [out]},
            attrs={'col': i})


284 285 286
def append_fetch_ops(inference_program,
                     fetch_target_names,
                     fetch_holder_name='fetch'):
    """Append one ``fetch`` op per fetch target to ``inference_program``.

    All fetch ops write into a single FETCH_LIST holder variable named
    ``fetch_holder_name``; the variable named ``fetch_target_names[i]`` is
    stored in column ``i`` of the holder.
    """
    global_block = inference_program.global_block()
    fetch_var = global_block.create_var(
        name=fetch_holder_name,
        type=core.VarDesc.VarType.FETCH_LIST,
        persistable=True)

    for i, name in enumerate(fetch_target_names):
        global_block.append_op(
            type='fetch',
            inputs={'X': [name]},
            outputs={'Out': [fetch_var]},
            attrs={'col': i})


301 302 303 304
def save_inference_model(dirname,
                         feeded_var_names,
                         target_vars,
                         executor,
                         main_program=None,
                         model_filename=None,
                         params_filename=None):
    """
    Build a model especially for inference,
    and save it to directory by the executor.

    :param dirname: directory path
    :param feeded_var_names: Names of variables that need to be feeded data during inference
    :param target_vars: Variables from which we can get inference results.
    :param executor: executor that save inference model
    :param main_program: original program, which will be pruned to build the inference model.
            Default default_main_program().
    :param model_filename: The name of file to save inference program.
        If not specified, default filename `__model__` will be used.
    :param params_filename: The name of file to save parameters.
        It is used for the case that all parameters are saved in a single binary file.
        If not specified, parameters are considered saved in separate files.

    :return: None
    :raises ValueError: if feeded_var_names / target_vars are not (lists of)
        the expected types.
    """
    if isinstance(feeded_var_names, basestring):
        feeded_var_names = [feeded_var_names]
    else:
        if len(feeded_var_names) > 0:
            if not (bool(feeded_var_names) and all(
                    isinstance(name, basestring) for name in feeded_var_names)):
                raise ValueError("'feed_var_names' should be a list of str.")

    if isinstance(target_vars, Variable):
        target_vars = [target_vars]
    else:
        if not (bool(target_vars) and all(
                isinstance(var, Variable) for var in target_vars)):
            raise ValueError("'target_vars' should be a list of Variable.")

    if main_program is None:
        main_program = default_main_program()
    # Work on a clone so the caller's program is left untouched.
    copy_program = main_program.clone()

    if not os.path.isdir(dirname):
        os.makedirs(dirname)

    # Clear the is_target information and remove the existed feed and fetch op.
    # BUGFIX: iterate indices in reverse so remove_op() does not shift the
    # position of not-yet-visited ops (a forward enumerate would skip the op
    # immediately following each removed one).
    global_block = copy_program.global_block()
    for i in range(len(global_block.ops) - 1, -1, -1):
        op = global_block.ops[i]
        op.desc.set_is_target(False)
        if op.type == "feed" or op.type == "fetch":
            global_block.remove_op(i)
    copy_program.desc.flush()

    pruned_program = copy_program.prune(targets=target_vars)
    inference_program = pruned_program.inference_optimize()
    fetch_var_names = [v.name for v in target_vars]

    # Re-insert fresh feed/fetch ops matching the requested interface.
    prepend_feed_ops(inference_program, feeded_var_names)
    append_fetch_ops(inference_program, fetch_var_names)

    if model_filename is not None:
        # basename() guards against path components smuggled in the filename.
        model_filename = os.path.basename(model_filename)
    else:
        model_filename = "__model__"
    model_filename = os.path.join(dirname, model_filename)

    if params_filename is not None:
        params_filename = os.path.basename(params_filename)

    with open(model_filename, "wb") as f:
        f.write(inference_program.desc.serialize_to_string())

    save_persistables(executor, dirname, inference_program, params_filename)
376 377


378 379 380 381
def load_inference_model(dirname,
                         executor,
                         model_filename=None,
                         params_filename=None):
    """
    Load inference model from a directory

    :param dirname: directory path
    :param executor: executor that load inference model
    :param model_filename: The name of file to load inference program.
        If not specified, default filename `__model__` will be used.
    :param params_filename: The name of file to load parameters.
        It is used for the case that all parameters are saved in a single binary file.
        If not specified, parameters are considered saved in separate files.

    :return: [program, feed_target_names, fetch_targets]
             program: program especially for inference.
             feed_target_names: Names of variables that need to feed data
             fetch_targets: Variables from which we can get inference results.
    :raises ValueError: if ``dirname`` is not an existing directory.
    """
    if not os.path.isdir(dirname):
        # BUGFIX: the original passed `dirname` as a second argument to
        # ValueError, so the '%s' placeholder was never substituted.
        raise ValueError("There is no directory named '%s'" % dirname)

    if model_filename is not None:
        # basename() guards against path components smuggled in the filename.
        model_filename = os.path.basename(model_filename)
    else:
        model_filename = "__model__"
    model_filename = os.path.join(dirname, model_filename)

    if params_filename is not None:
        params_filename = os.path.basename(params_filename)

    with open(model_filename, "rb") as f:
        program_desc_str = f.read()

    program = Program.parse_from_string(program_desc_str)
    load_persistables(executor, dirname, program, params_filename)

    feed_target_names = program.desc.get_feed_target_names()
    fetch_target_names = program.desc.get_fetch_target_names()
    fetch_targets = [
        program.global_block().var(name) for name in fetch_target_names
    ]

    return [program, feed_target_names, fetch_targets]
X
xuwei06 已提交
423 424 425 426 427 428 429 430


def get_parameter_value(para, executor):
    """
    Get the LoDTensor for the parameter

    :param executor: executor for retrieving the value
    :param para: the given parameter

    :return: the LoDTensor for the parameter
    """
    assert is_parameter(para)

    # Run a one-off program whose only job is to fetch the parameter value.
    get_program = Program()
    block = get_program.global_block()
    new_var = _clone_var_in_block_(block, para)
    return executor.run(get_program, feed={}, fetch_list=[new_var])[0]


def get_parameter_value_by_name(name, executor, program=None):
    """
    Get the LoDTensor for the parameter with the given name.

    :param executor: executor for retrieving the value
    :param name: the name of the parameter
    :param program: the program where the variable is found;
            Default default_main_program().

    :return: the LoDTensor for the variable
    """
    if program is None:
        program = default_main_program()
    return get_parameter_value(program.global_block().var(name), executor)
T
tangwei12 已提交
457 458


459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485
def load_persist_vars_without_grad(executor, dirname, program):
    """
    Load checkpoint-relevant variables of ``program`` from ``dirname``.

    Gradient variables (names ending with "@GRAD") and feed/fetch/RAW
    variables are skipped via the ``_is_checkpoint_var`` predicate.
    """
    load_vars(
        executor,
        dirname=dirname,
        main_program=program,
        filename=None,
        predicate=_is_checkpoint_var)


def save_persist_vars_without_grad(executor, dirname, program):
    """
    Save checkpoint-relevant variables of ``program`` into ``dirname``.

    Gradient variables (names ending with "@GRAD") and feed/fetch/RAW
    variables are skipped via the ``_is_checkpoint_var`` predicate.
    """
    save_vars(
        executor,
        dirname=dirname,
        main_program=program,
        vars=None,
        filename=None,
        predicate=_is_checkpoint_var)


T
tangwei12 已提交
486
# Marker file written into a checkpoint directory once it is fully saved;
# directories missing it are treated as incomplete and ignored on load.
SUCCESS_MARK_FILENAME = "_SUCCESS"
# Checkpoint directories are named CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR
# + <serial number>, e.g. "checkpoint_0".
CHECKPOINT_PREFIX = "checkpoint"
CHECKPOINT_SEPARATOR = "_"
T
tangwei12 已提交
489 490 491


def save_checkpoint(executor,
                    checkpoint_dir=None,
                    max_num_checkpoints=3,
                    save_interval_secs=600,
                    main_program=None):
    """
    Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
    the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
    to keep numbers of checkpoint directory,  the numbers of checkpoint directory are max_num_checkpoints at most,
    The interval between two saved checkpoints must greater than save_interval_secs.

    :param executor: executor used to run the generated save program
    :param checkpoint_dir: checkpoint root directory; defaults to the current
        working directory when None
    :param max_num_checkpoints: maximum number of checkpoint directories kept
    :param save_interval_secs: minimum number of seconds between two saves
    :param main_program: program whose checkpoint variables are saved
    """
    if checkpoint_dir is None:
        checkpoint_dir = os.getcwd()

    if not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    serial = _get_lastest_checkpoint_dir(checkpoint_dir)
    # Throttle: skip saving when the newest checkpoint is too recent.
    if serial >= 0 and not _interval_secs_exceed(
            _get_serial_dir(serial, checkpoint_dir), save_interval_secs):
        return

    serial += 1
    cur_dir = _get_serial_dir(serial, checkpoint_dir)

    # BUGFIX: the original called load_persist_vars_without_grad here, which
    # *reads* variables instead of writing them — save_checkpoint must save.
    save_persist_vars_without_grad(executor, cur_dir, main_program)
    # Mark the checkpoint complete only after all variables are on disk.
    _write_success(cur_dir)
    # Evict the oldest checkpoints beyond max_num_checkpoints.
    _lru_delete(checkpoint_dir, max_num_checkpoints)
T
tangwei12 已提交
525 526


527
def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
    """
    Restore ``main_program`` from the most recent complete checkpoint
    found under ``checkpoint_dir``; silently does nothing when no complete
    checkpoint exists.

    :param executor: executor used to run the generated load program
    :param checkpoint_dir: checkpoint root directory; defaults to the current
        working directory when None
    :param main_program: program whose checkpoint variables are restored
    """
    if checkpoint_dir is None:
        checkpoint_dir = os.getcwd()

    serial = _get_lastest_checkpoint_dir(checkpoint_dir)
    if serial < 0:
        # No complete checkpoint available — nothing to restore.
        return

    load_persist_vars_without_grad(
        executor, _get_serial_dir(serial, checkpoint_dir), main_program)
T
tangwei12 已提交
547 548


T
tangwei12 已提交
549 550 551 552 553 554 555 556 557 558 559 560 561
def clean_checkpoint(checkpoint_dir, delete_dir=False):
    """
    clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before.
    delete_dir only works when the directory is empty, otherwise, OSError is raised.
    """
    if checkpoint_dir is None:
        checkpoint_dir = os.getcwd()

    # Keeping zero checkpoints removes every serial directory.
    _lru_delete(checkpoint_dir, max_num_checkpoints=0)

    # Only remove the root itself when it ended up empty.
    if delete_dir and not os.listdir(checkpoint_dir):
        os.rmdir(checkpoint_dir)


562 563 564 565 566
def _get_serial_dir(serial, checkpoint_dir):
    """Return the path of the checkpoint sub-directory for ``serial``."""
    folder = CHECKPOINT_SEPARATOR.join([CHECKPOINT_PREFIX, str(serial)])
    return os.path.join(checkpoint_dir, folder)


T
tangwei12 已提交
567
def _is_checkpoint_var(var):
    """
    the checkpoint will not save or load all the variables.
    var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.

    :param var: variable to test
    """
    excluded_types = (core.VarDesc.VarType.FEED_MINIBATCH,
                      core.VarDesc.VarType.FETCH_LIST,
                      core.VarDesc.VarType.RAW)
    if var.desc.type() in excluded_types:
        return False
    if var.name.endswith("@GRAD"):
        return False
    return var.persistable
T
tangwei12 已提交
583 584


T
tangwei12 已提交
585 586 587 588 589 590 591
def _interval_secs_exceed(dirname, save_interval_secs):
    dir_time = os.path.getmtime(dirname)
    if save_interval_secs > (time.time() - dir_time):
        return False
    return True


T
tangwei12 已提交
592
def _lru_delete(dirname, max_num_checkpoints=3):
T
tangwei12 已提交
593 594 595 596 597 598 599 600
    dirs = os.listdir(dirname)
    serials = []
    for serial in dirs:
        try:
            serials.append(int(serial))
        except ValueError:
            continue

T
tangwei12 已提交
601
    if len(serials) <= max_num_checkpoints:
T
tangwei12 已提交
602 603 604
        return

    serials.sort(reverse=True)
T
tangwei12 已提交
605
    serials = serials[max_num_checkpoints:]
T
tangwei12 已提交
606 607 608 609 610
    for serial in serials:
        cur_dir = os.path.join(dirname, str(serial))
        shutil.rmtree(cur_dir)


T
tangwei12 已提交
611 612
def _write_success(dirname):
    """
    write an empty file named "_SUCCESS" in checkpoint dir, indicate this checkpoint is correct.

    :param dirname: checkpoint directory to mark as complete
    """
    marker_path = os.path.join(dirname, SUCCESS_MARK_FILENAME)
    # Append mode so re-marking an existing checkpoint never truncates it.
    with open(marker_path, 'a') as marker:
        marker.write(time.ctime())
T
tangwei12 已提交
621 622 623 624


def _get_lastest_checkpoint_dir(checkpoint_dir):
    """
    get the latest complete checkpoint serial in checkpoint directory; a
    checkpoint only counts when its _SUCCESS marker file exists.

    :param checkpoint_dir: checkpoint root directory
    :return: the highest complete serial number, or -1 when none exists
    """
    if not checkpoint_dir.strip():
        return -1

    def has_success(checkpoint_dir, cur_dir):
        """
        Return the serial number when ``cur_dir`` is a complete checkpoint
        directory (named ``<prefix>_<int>`` and containing _SUCCESS), else -1.
        """
        try:
            # BUGFIX: the unpacking itself raises ValueError for entries
            # without exactly one separator (e.g. "foo" or "a_b_c"); the
            # original let that propagate and crash the scan.
            _, serial = cur_dir.split(CHECKPOINT_SEPARATOR)
            int(serial)
        except ValueError:
            return -1

        if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
            return -1

        success_path = os.path.join(
            _get_serial_dir(serial, checkpoint_dir), SUCCESS_MARK_FILENAME)
        if os.path.isfile(success_path):
            return int(serial)
        # BUGFIX: the original fell through and returned None here, which
        # made the `success_num > current_dir` comparison below rely on
        # Python 2's None ordering (TypeError on Python 3).
        return -1

    if not os.path.isdir(checkpoint_dir):
        return -1

    current_dir = -1
    for cur_dir in os.listdir(checkpoint_dir):
        success_num = has_success(checkpoint_dir, cur_dir)
        if success_num > current_dir:
            current_dir = success_num
    return current_dir