io.py 25.2 KB
Newer Older
1
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
D
dzhwinter 已提交
2
#
D
dzhwinter 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
D
dzhwinter 已提交
6
#
D
dzhwinter 已提交
7
#     http://www.apache.org/licenses/LICENSE-2.0
D
dzhwinter 已提交
8
#
D
dzhwinter 已提交
9 10 11 12 13 14
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import os
T
tangwei12 已提交
16 17
import time
import shutil
18

19 20
from paddle.fluid.evaluator import Evaluator
from paddle.fluid.framework import Program, Parameter, default_main_program, Variable
K
fix bug  
Kexin Zhao 已提交
21
from . import core
22 23

__all__ = [
T
tangwei12 已提交
24 25
    'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
    'load_persistables', 'save_inference_model', 'load_inference_model',
T
tangwei12 已提交
26
    'get_inference_program', 'save_checkpoint', 'load_checkpoint',
27
    'clean_checkpoint', 'load_persist_vars_without_grad',
28 29
    'load_lookup_table_vars', 'save_persist_vars_without_grad',
    'get_latest_checkpoint_serial'
30 31 32 33
]


def is_parameter(var):
    """Tell whether ``var`` is a ``Parameter``.

    Args:
        var: The variable to inspect.

    Returns:
        bool: ``True`` if ``var`` is a ``Parameter`` instance, else ``False``.
    """
    answer = isinstance(var, Parameter)
    return answer


def is_persistable(var):
    """Tell whether ``var`` should be treated as persistable.

    Feed and fetch holder variables (FEED_MINIBATCH / FETCH_LIST) never
    count as persistable; any other variable reports its own
    ``persistable`` flag.
    """
    holder_types = (core.VarDesc.VarType.FEED_MINIBATCH,
                    core.VarDesc.VarType.FETCH_LIST)
    if var.desc.type() in holder_types:
        return False
    return var.persistable


def _clone_var_in_block_(block, var):
    """Create a persistable copy of ``var``'s description inside ``block``.

    The clone keeps ``var``'s name, shape, dtype, type and lod_level but is
    always marked persistable. Returns whatever ``block.create_var`` returns.
    """
    assert isinstance(var, Variable)
    clone_kwargs = dict(
        name=var.name,
        shape=var.shape,
        dtype=var.dtype,
        type=var.type,
        lod_level=var.lod_level,
        persistable=True)
    return block.create_var(**clone_kwargs)


65 66 67 68 69
def save_vars(executor,
              dirname,
              main_program=None,
              vars=None,
              predicate=None,
              filename=None):
    """
    Save variables to a directory by running a generated program.

    :param executor: executor that runs the generated save program.
    :param dirname: directory the variables are written to.
    :param main_program: program whose variables are filtered by `predicate`
        when `vars` is None. Defaults to default_main_program().
    :param vars: explicit variables to save. When given, `main_program` and
        `predicate` are ignored.
    :param predicate: callable taking a variable and returning a bool; only
        variables for which it returns True are saved.
    :param filename: name of a single file that receives every variable.
        When None, each variable is saved to its own file named after it.

    :return: None
    :raises TypeError: if `main_program` is not a Program.
    """
    if vars is None:
        if main_program is None:
            program = default_main_program()
        else:
            program = main_program
        if not isinstance(program, Program):
            raise TypeError("program should be as Program type or None")
        vars = filter(predicate, program.list_vars())

    save_program = Program()
    save_block = save_program.global_block()

    combined = {}
    for var in vars:
        # NOTE: variables of type RAW are never saved.
        if var.type == core.VarDesc.VarType.RAW:
            continue
        cloned = _clone_var_in_block_(save_block, var)
        if filename is not None:
            combined[cloned.name] = cloned
            continue
        save_block.append_op(
            type='save',
            inputs={'X': [cloned]},
            outputs={},
            attrs={'file_path': os.path.join(dirname, cloned.name)})

    if filename is not None:
        # Name-sorted so the combined file has a deterministic layout.
        ordered = [combined[name] for name in sorted(combined)]
        save_block.append_op(
            type='save_combine',
            inputs={'X': ordered},
            outputs={},
            attrs={'file_path': os.path.join(dirname, filename)})

    executor.run(save_program)


131
def save_params(executor, dirname, main_program=None, filename=None):
    """
    Save every Parameter of `main_program` to `dirname` with `executor`.

    Shorthand for save_vars(..., vars=None, predicate=is_parameter).
    """
    # Positional order matches save_vars(executor, dirname, main_program,
    # vars, predicate, filename).
    save_vars(executor, dirname, main_program, None, is_parameter, filename)
142 143


144
def save_persistables(executor, dirname, main_program=None, filename=None):
    """
    Save every persistable variable of `main_program` to `dirname`.

    Shorthand for save_vars(..., vars=None, predicate=is_persistable).
    """
    # Positional order matches save_vars(executor, dirname, main_program,
    # vars, predicate, filename).
    save_vars(executor, dirname, main_program, None, is_persistable, filename)
155 156


157 158 159 160 161
def load_vars(executor,
              dirname,
              main_program=None,
              vars=None,
              predicate=None,
              filename=None):
    """
    Load variables from a directory by running a generated program.

    :param executor: executor that runs the generated load program.
    :param dirname: directory the variables are read from.
    :param main_program: program whose variables are filtered by `predicate`
        when `vars` is None. Defaults to default_main_program().
    :param vars: explicit variables to load. When given, `main_program` and
        `predicate` are ignored.
    :param predicate: callable taking a variable and returning a bool; only
        variables for which it returns True are loaded.
    :param filename: name of a single file holding every variable. When
        None, each variable is read from its own file named after it.

    :return: None
    :raises TypeError: if `main_program` is not a Program.
    """
    if vars is None:
        if main_program is None:
            program = default_main_program()
        else:
            program = main_program
        if not isinstance(program, Program):
            raise TypeError("program's type should be Program")
        vars = filter(predicate, program.list_vars())

    load_prog = Program()
    load_block = load_prog.global_block()

    combined = {}
    for var in vars:
        assert isinstance(var, Variable)
        # Variables of type RAW are skipped, mirroring save_vars.
        if var.type == core.VarDesc.VarType.RAW:
            continue
        cloned = _clone_var_in_block_(load_block, var)
        if filename is not None:
            combined[cloned.name] = cloned
            continue
        load_block.append_op(
            type='load',
            inputs={},
            outputs={'Out': [cloned]},
            attrs={'file_path': os.path.join(dirname, cloned.name)})

    if filename is not None:
        # Name-sorted, the same ordering save_vars uses for save_combine.
        ordered = [combined[name] for name in sorted(combined)]
        load_block.append_op(
            type='load_combine',
            inputs={},
            outputs={"Out": ordered},
            attrs={'file_path': os.path.join(dirname, filename)})

    executor.run(load_prog)


223
def load_params(executor, dirname, main_program=None, filename=None):
    """
    Load every Parameter of `main_program` from `dirname` with `executor`.

    Shorthand for load_vars(..., predicate=is_parameter).
    """
    # Positional order matches load_vars(executor, dirname, main_program,
    # vars, predicate, filename).
    load_vars(executor, dirname, main_program, None, is_parameter, filename)
233 234


235
def load_persistables(executor, dirname, main_program=None, filename=None):
    """
    Load every persistable variable of `main_program` from `dirname`.

    Shorthand for load_vars(..., predicate=is_persistable).
    """
    # Positional order matches load_vars(executor, dirname, main_program,
    # vars, predicate, filename).
    load_vars(executor, dirname, main_program, None, is_persistable, filename)
245 246


247 248
def get_inference_program(target_vars, main_program=None):
    """
    Build an inference program that computes `target_vars`.

    Evaluator targets are expanded into their `states` and `metrics`
    variables. `main_program` (default: default_main_program()) is pruned to
    the resulting targets and then passed through inference_optimize().

    :param target_vars: a target, or a list of targets (Variables or
        Evaluators).
    :param main_program: program to prune.
    :return: the optimized inference program.
    """
    if main_program is not None:
        program = main_program
    else:
        program = default_main_program()
    targets = target_vars if isinstance(target_vars, list) else [target_vars]

    prune_targets = []
    for target in targets:
        if not isinstance(target, Evaluator):
            prune_targets.append(target)
            continue
        prune_targets.extend(target.states)
        prune_targets.extend(target.metrics)

    return program.prune(targets=prune_targets).inference_optimize()


264 265 266
def prepend_feed_ops(inference_program,
                     feed_target_names,
                     feed_holder_name='feed'):
    """
    Insert one `feed` op per name at the front of the global block.

    A persistable FEED_MINIBATCH holder variable named `feed_holder_name` is
    created; the op for feed_target_names[i] reads column i of that holder.
    Nothing is changed when `feed_target_names` is empty.
    """
    if len(feed_target_names) == 0:
        # Nothing to feed: do not even create the holder variable.
        return

    block = inference_program.global_block()
    holder = block.create_var(
        name=feed_holder_name,
        type=core.VarDesc.VarType.FEED_MINIBATCH,
        persistable=True)

    for col, target_name in enumerate(feed_target_names):
        block.prepend_op(
            type='feed',
            inputs={'X': [holder]},
            outputs={'Out': [block.var(target_name)]},
            attrs={'col': col})


285 286 287
def append_fetch_ops(inference_program,
                     fetch_target_names,
                     fetch_holder_name='fetch'):
    """
    Append one `fetch` op per name to the end of the global block.

    A persistable FETCH_LIST holder variable named `fetch_holder_name` is
    always created; the op for fetch_target_names[i] writes column i of it.
    """
    block = inference_program.global_block()
    holder = block.create_var(
        name=fetch_holder_name,
        type=core.VarDesc.VarType.FETCH_LIST,
        persistable=True)

    for col, target_name in enumerate(fetch_target_names):
        block.append_op(
            type='fetch',
            inputs={'X': [target_name]},
            outputs={'Out': [holder]},
            attrs={'col': col})


302 303 304 305
def save_inference_model(dirname,
                         feeded_var_names,
                         target_vars,
                         executor,
                         main_program=None,
                         model_filename=None,
                         params_filename=None):
    """
    Build a model especially for inference,
    and save it to directory by the executor.

    A clone of `main_program` is stripped of existing feed/fetch ops, pruned
    to `target_vars`, optimized for inference, and given fresh feed/fetch
    ops. The program description and its persistable variables are then
    written under `dirname` (created if missing).

    :param dirname: directory path
    :param feeded_var_names: Names of variables that need to be fed data
        during inference (a single name or a list of names).
    :param target_vars: Variables from which we can get inference results
        (a single Variable or a non-empty list of Variables).
    :param executor: executor that saves the inference model
    :param main_program: original program, which will be pruned to build
        the inference model. Default default_main_program().
    :param model_filename: The name of file to save inference program.
        If not specified, default filename `__model__` will be used.
        Only its base name is used; the file always lands in `dirname`.
    :param params_filename: The name of file to save parameters.
        It is used for the case that all parameters are saved in a single
        binary file. If not specified, parameters are considered saved in
        separate files. Only its base name is used.

    :return: None
    :raises ValueError: if `feeded_var_names` or `target_vars` have the
        wrong element types.
    """
    if isinstance(feeded_var_names, basestring):
        feeded_var_names = [feeded_var_names]
    elif len(feeded_var_names) > 0:
        if not all(isinstance(name, basestring) for name in feeded_var_names):
            raise ValueError("'feed_var_names' should be a list of str.")

    if isinstance(target_vars, Variable):
        target_vars = [target_vars]
    elif not (bool(target_vars) and all(
            isinstance(var, Variable) for var in target_vars)):
        raise ValueError("'target_vars' should be a list of Variable.")

    if main_program is None:
        main_program = default_main_program()
    copy_program = main_program.clone()

    if not os.path.isdir(dirname):
        os.makedirs(dirname)

    # Clear the is_target information and remove any existing feed/fetch ops.
    # remove_op takes an index into global_block.ops, so removing while
    # enumerating that list shifts later ops down and skips the op right
    # after each removed one. Collect the indices first, then delete from
    # the highest index down so the remaining indices stay valid.
    global_block = copy_program.global_block()
    stale_op_indices = []
    for i, op in enumerate(global_block.ops):
        op.desc.set_is_target(False)
        if op.type == "feed" or op.type == "fetch":
            stale_op_indices.append(i)
    for i in reversed(stale_op_indices):
        global_block.remove_op(i)
    copy_program.desc.flush()

    pruned_program = copy_program.prune(targets=target_vars)
    inference_program = pruned_program.inference_optimize()
    fetch_var_names = [v.name for v in target_vars]

    prepend_feed_ops(inference_program, feeded_var_names)
    append_fetch_ops(inference_program, fetch_var_names)

    if model_filename is not None:
        model_filename = os.path.basename(model_filename)
    else:
        model_filename = "__model__"
    model_filename = os.path.join(dirname, model_filename)

    if params_filename is not None:
        params_filename = os.path.basename(params_filename)

    with open(model_filename, "wb") as f:
        f.write(inference_program.desc.serialize_to_string())

    save_persistables(executor, dirname, inference_program, params_filename)
377 378


379 380 381 382
def load_inference_model(dirname,
                         executor,
                         model_filename=None,
                         params_filename=None):
    """
    Load an inference model from a directory.

    :param dirname: directory path
    :param executor: executor used to load the persistable variables
    :param model_filename: The name of file to load inference program.
        If not specified, default filename `__model__` will be used.
        Only its base name is used; the file is always read from `dirname`.
    :param params_filename: The name of file to load parameters.
        It is used for the case that all parameters are saved in a single
        binary file. If not specified, parameters are considered saved in
        separate files. Only its base name is used.

    :return: [program, feed_target_names, fetch_targets]
             program: program especially for inference.
             feed_target_names: Names of variables that need to feed data
             fetch_targets: Variables from which we can get inference results.
    :raises ValueError: if `dirname` is not an existing directory.
    """
    if not os.path.isdir(dirname):
        # Format the message here. Passing `dirname` as a second ValueError
        # argument left the '%s' placeholder unfilled.
        raise ValueError("There is no directory named '%s'" % dirname)

    if model_filename is not None:
        model_filename = os.path.basename(model_filename)
    else:
        model_filename = "__model__"
    model_filename = os.path.join(dirname, model_filename)

    if params_filename is not None:
        params_filename = os.path.basename(params_filename)

    with open(model_filename, "rb") as f:
        program_desc_str = f.read()

    program = Program.parse_from_string(program_desc_str)
    load_persistables(executor, dirname, program, params_filename)

    feed_target_names = program.desc.get_feed_target_names()
    fetch_target_names = program.desc.get_fetch_target_names()
    fetch_targets = [
        program.global_block().var(name) for name in fetch_target_names
    ]

    return [program, feed_target_names, fetch_targets]
X
xuwei06 已提交
424 425 426 427 428 429 430 431


def get_parameter_value(para, executor):
    """
    Fetch the current value of parameter `para` with `executor`.

    A throw-away program holding only a persistable clone of `para` is run,
    and the first fetched result is returned.

    :param para: the Parameter to read.
    :param executor: executor used to run the fetch program.
    :return: the fetched value (documented upstream as a LoDTensor).
    """
    assert is_parameter(para)

    fetch_program = Program()
    cloned = _clone_var_in_block_(fetch_program.global_block(), para)
    results = executor.run(fetch_program, feed={}, fetch_list=[cloned])
    return results[0]


def get_parameter_value_by_name(name, executor, program=None):
    """
    Fetch the current value of the parameter called `name`.

    :param name: the name of the parameter.
    :param executor: executor used to retrieve the value.
    :param program: program whose global block holds the variable.
        Default default_main_program().
    :return: the value returned by get_parameter_value.
    """
    source = default_main_program() if program is None else program
    parameter = source.global_block().var(name)
    return get_parameter_value(parameter, executor)
T
tangwei12 已提交
458 459


T
tangwei12 已提交
460
# --- Checkpoint directory layout -------------------------------------------
# Marker file written (by _write_success) after a save completes.
SUCCESS_MARK_FILENAME = "_SUCCESS"
# Serial directories are named CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + n.
CHECKPOINT_PREFIX = "checkpoint"
# Sub-directory of a serial directory that holds the saved variables.
MODEL_DIR = "__model__"
# Sub-directory of a serial directory used for lookup-table files.
LOOKUP_TABLE_DIR = "__lookup_table__"
# Per-trainer sub-directory prefix, e.g. "trainer_0".
TRAINER_PREFIX = "trainer"
# Used in lookup-table file names: "<table>_pserver_<id>".
PSERVER_PREFIX = "pserver"
# Joins a prefix and its numeric id in directory and file names.
CHECKPOINT_SEPARATOR = "_"
T
tangwei12 已提交
467 468 469


def save_checkpoint(executor,
                    checkpoint_dir,
                    trainer_id,
                    trainer_args=None,
                    main_program=None,
                    max_num_checkpoints=3):
    """
    Save a new checkpoint serial under `checkpoint_dir`.

    The serial number is one past the latest serial that has a `_SUCCESS`
    marker (see get_latest_checkpoint_serial). Every trainer writes
    `trainer_args` into its own `trainer_<id>` sub-directory. The chief
    trainer (trainer_id == 0) also saves the checkpoint variables of
    `main_program` and runs a checkpoint_notify program for the
    lookup-table directory. Afterwards only the `max_num_checkpoints`
    highest-numbered serial directories are kept.

    :param executor: executor used to run the save programs.
    :param checkpoint_dir: root checkpoint directory; created if missing.
    :param trainer_id: current trainer id; 0 marks the chief trainer.
    :param trainer_args: optional dict of extra trainer state. Each value is
        written as str(value) to a file named after its key. None means
        there are no extra arguments.
    :param main_program: program whose variables the chief saves; when None,
        save_vars falls back to default_main_program().
    :param max_num_checkpoints: number of serial directories to keep.
    :raises ValueError: if `checkpoint_dir` is None.
    """
    if checkpoint_dir is None:
        raise ValueError("'checkpoint_dir' should not be None")

    # save_trainer_args asserts it receives a dict. Treat the default None
    # as "no arguments" instead of failing on every default call.
    if trainer_args is None:
        trainer_args = {}
    assert isinstance(trainer_args, dict)

    if not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    serial = get_latest_checkpoint_serial(checkpoint_dir) + 1
    cur_dir = _get_serial_dir(checkpoint_dir, serial)

    save_trainer_args(cur_dir, trainer_id, trainer_args)

    if trainer_id == 0:
        save_persist_vars_without_grad(executor, cur_dir, main_program)
        save_pserver_vars_by_notify(executor, cur_dir, "")

    _scroll_delete(checkpoint_dir, max_num_checkpoints)
T
tangwei12 已提交
506 507


T
tangwei12 已提交
508
def load_checkpoint(executor, checkpoint_dir, serial, main_program):
    """
    Load the checkpoint stored under serial number `serial`.

    Only the checkpoint variables of `main_program` are loaded, from the
    `__model__` sub-directory of `checkpoint_<serial>`. The serial is not
    discovered automatically; callers can use get_latest_checkpoint_serial
    to pick one.

    :param executor: executor used to run the load program.
    :param checkpoint_dir: root checkpoint directory.
    :param serial: serial number of the checkpoint to load (>= 0).
    :param main_program: program whose variables are loaded.
    :raises ValueError: if `checkpoint_dir` or `main_program` is None, or
        `serial` is None or negative.
    """
    if checkpoint_dir is None:
        raise ValueError("'checkpoint_dir' should not be None")
    if serial is None or serial < 0:
        raise ValueError("'serial' should not be None or <0 ")
    if main_program is None:
        raise ValueError('main_program should not be None.')

    serial_dir = _get_serial_dir(checkpoint_dir, serial)
    load_persist_vars_without_grad(
        executor, serial_dir, main_program, has_model_dir=True)
T
tangwei12 已提交
530 531


T
tangwei12 已提交
532 533
def clean_checkpoint(checkpoint_dir, delete_dir=False):
    """
    Remove every checkpoint serial directory under `checkpoint_dir`.

    Meant to be called once training has finished normally. When
    `delete_dir` is True, `checkpoint_dir` itself is also removed, but only
    if it is empty after the cleanup; otherwise it is left in place.

    :param checkpoint_dir: root checkpoint directory.
    :param delete_dir: also remove `checkpoint_dir` if it ends up empty.
    :raises ValueError: if `checkpoint_dir` is None.
    """
    if checkpoint_dir is None:
        raise ValueError("'checkpoint_dir' should not be None")

    _scroll_delete(checkpoint_dir, max_num_checkpoints=0)

    if not delete_dir:
        return
    if not os.listdir(checkpoint_dir):
        os.rmdir(checkpoint_dir)


T
tangwei12 已提交
550 551 552 553
def load_persist_vars_without_grad(executor,
                                   dirname,
                                   program,
                                   has_model_dir=False):
    """
    Load the checkpoint variables of `program` from a directory.

    Variables are selected with _is_checkpoint_var. That skips feed, fetch
    and RAW variables, names containing "@GRAD", ".trainer_" or ".block",
    and non-persistable variables. Each variable is read from its own file.

    :param executor: executor used to run the load program.
    :param dirname: directory containing the saved variables.
    :param program: program whose variables are loaded.
    :param has_model_dir: when True, read from the `__model__`
        sub-directory of `dirname` instead of `dirname` itself.
    """
    source_dir = _get_model_dir(dirname) if has_model_dir else dirname
    load_vars(
        executor,
        dirname=source_dir,
        main_program=program,
        predicate=_is_checkpoint_var,
        filename=None)


575 576 577 578 579 580 581 582
def load_lookup_table_vars(executor,
                           dirname,
                           pserver_id,
                           table_name,
                           main_program=None):
    """
    Load the lookup table `table_name` saved for parameter server `pserver_id`.

    The table is read through load_vars (single-file mode) from
    `<dirname>/__lookup_table__/<table_name>_pserver_<pserver_id>`.

    :param executor: executor used to run the load program.
    :param dirname: checkpoint serial directory containing `__lookup_table__`.
    :param pserver_id: id of the parameter server whose table is loaded.
    :param table_name: name of the table variable, or the Variable itself.
    :param main_program: program used to resolve `table_name` to a Variable
        when a name is given. Defaults to default_main_program().
    """
    if isinstance(table_name, Variable):
        table_var = table_name
    else:
        if main_program is None:
            main_program = default_main_program()
        # load_vars needs an iterable of Variables. Passing the bare name
        # string made it iterate characters and fail its isinstance check.
        table_var = main_program.global_block().var(table_name)

    lookup_table_dir = os.path.join(dirname, LOOKUP_TABLE_DIR)
    table_file = CHECKPOINT_SEPARATOR.join(
        [table_var.name, PSERVER_PREFIX, str(pserver_id)])

    load_vars(
        executor, lookup_table_dir, vars=[table_var], filename=table_file)


T
tangwei12 已提交
583 584 585 586
def save_persist_vars_without_grad(executor, dirname, program):
    """
    Save the checkpoint variables of `program` into `dirname`/__model__.

    Variables are selected with _is_checkpoint_var, so gradients ("@GRAD")
    and the other excluded kinds are skipped. Each variable goes to its own
    file, and a `_SUCCESS` marker is then written into the `__model__`
    directory.

    :param executor: executor used to run the save program.
    :param dirname: checkpoint serial directory.
    :param program: program whose variables are saved.
    """
    model_dir = _get_model_dir(dirname)
    save_vars(
        executor,
        dirname=model_dir,
        main_program=program,
        predicate=_is_checkpoint_var,
        filename=None)
    _write_success(model_dir)


T
tangwei12 已提交
603 604 605 606 607 608 609 610 611 612 613 614 615
def save_pserver_vars_by_notify(executor, dirname, epmap):
    """
    Run a `checkpoint_notify` op for this checkpoint serial.

    The op's `dir` attribute is the `__lookup_table__` sub-directory of
    `dirname`, which is created if missing. Presumably the op asks the
    parameter servers to save their state there; its exact semantics live
    in the op implementation, not in this file.

    :param executor: executor used to run the notify program.
    :param dirname: checkpoint serial directory.
    :param epmap: endpoints for the op's `epmap` attribute. A falsy value
        (e.g. "" or None) leaves the attribute as None, as before.
    """
    cur_dir = _get_lookuptable_dir(dirname)

    checkpoint_notify_program = Program()
    checkpoint_notify_block = checkpoint_notify_program.global_block()

    attrs = {
        # The argument used to be ignored (always None).
        'epmap': epmap if epmap else None,
        'dir': cur_dir,
    }

    # `outputs=` is the keyword used by every other append_op call in this
    # module; the previous `output=` did not match it.
    checkpoint_notify_block.append_op(
        type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
    executor.run(checkpoint_notify_program)


T
tangwei12 已提交
620
def save_trainer_args(dirname, trainer_id, trainer_args):
    """
    Write `trainer_args` into the `trainer_<trainer_id>` sub-directory.

    Each key becomes a file name and str(value) its contents. A `_SUCCESS`
    marker is written into the trainer directory afterwards.

    :param dirname: checkpoint serial directory.
    :param trainer_id: id of the trainer whose arguments are saved.
    :param trainer_args: dict of argument name -> value, or None for no
        arguments (only the directory and marker are written).
    """
    if trainer_args is None:
        trainer_args = {}
    assert isinstance(trainer_args, dict)

    cur_dir = _get_trainer_dir(dirname, trainer_id)

    # items() works on both Python 2 and 3; iteritems() is Python-2-only.
    for name, value in trainer_args.items():
        args_file = os.path.join(cur_dir, name)
        with open(args_file, 'w') as f:
            f.write(str(value))
    _write_success(cur_dir)


def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
    """
    Read back trainer arguments written by save_trainer_args.

    :param checkpoint_dir: root checkpoint directory.
    :param serial: serial number of the checkpoint to read.
    :param trainer_id: id of the trainer whose arguments are read.
    :param trainer_args: list of argument names (file names) to read.
    :return: list with the whitespace-stripped contents of each file, in the
        order given by `trainer_args`.
    """
    assert isinstance(trainer_args, list)

    trainer_dir = _get_trainer_dir(
        _get_serial_dir(checkpoint_dir, serial), trainer_id)

    def _read(arg_name):
        with open(os.path.join(trainer_dir, arg_name), 'r') as handle:
            return handle.read().strip()

    return [_read(arg) for arg in trainer_args]


T
tangwei12 已提交
648
def _is_checkpoint_var(var):
    """
    Decide whether `var` belongs in a checkpoint.

    Rejected: FEED_MINIBATCH / FETCH_LIST / RAW variables, gradients (name
    contains "@GRAD"), and distributed-training variables (name contains
    ".trainer_" or ".block"). Anything else is accepted iff persistable.

    :param var: the variable to test.
    :return: bool
    """
    skipped_types = (core.VarDesc.VarType.FEED_MINIBATCH,
                     core.VarDesc.VarType.FETCH_LIST,
                     core.VarDesc.VarType.RAW)
    if var.desc.type() in skipped_types:
        return False

    # Gradient variables and distributed-training variables are excluded.
    skipped_name_parts = ("@GRAD", ".trainer_", ".block")
    if any(part in var.name for part in skipped_name_parts):
        return False

    return var.persistable
T
tangwei12 已提交
671 672


T
tangwei12 已提交
673 674 675 676 677 678 679 680 681 682 683 684
def _get_dir_serial(dirname):
    """
    Parse the serial number out of a checkpoint directory name.

    "checkpoint_<n>" yields int(n). A name that is not CHECKPOINT_PREFIX
    followed by CHECKPOINT_SEPARATOR and an integer yields -1 instead of
    raising. The old tuple-unpacking of split() raised ValueError for names
    with zero or several separators, and accepted foreign names like
    "foo_3".

    :param dirname: base name of a directory entry (not a full path).
    :return: int serial, or -1 if `dirname` is not a checkpoint directory.
    """
    prefix, separator, serial = dirname.rpartition(CHECKPOINT_SEPARATOR)
    if not separator or prefix != CHECKPOINT_PREFIX:
        return -1

    try:
        return int(serial)
    except ValueError:
        return -1


def _get_serial_dir(dirname, serial):
    """
    Return the path of checkpoint serial `serial` inside `dirname`.

    The directory (named CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + serial)
    is created first if it does not exist.
    """
    folder_name = CHECKPOINT_SEPARATOR.join([CHECKPOINT_PREFIX, str(serial)])
    path = os.path.join(dirname, folder_name)
    if os.path.isdir(path):
        return path
    os.makedirs(path)
    return path


def _get_model_dir(dirname):
    """Return `dirname`/MODEL_DIR, creating that directory if missing."""
    path = os.path.join(dirname, MODEL_DIR)
    if os.path.isdir(path):
        return path
    os.makedirs(path)
    return path


T
tangwei12 已提交
702 703 704 705 706 707 708 709 710
def _get_lookuptable_dir(dirname):
    """Return `dirname`/LOOKUP_TABLE_DIR, creating that directory if missing."""
    path = os.path.join(dirname, LOOKUP_TABLE_DIR)
    if os.path.isdir(path):
        return path
    os.makedirs(path)
    return path


T
tangwei12 已提交
711 712 713 714 715 716 717 718
def _get_trainer_dir(dirname, trainer_id):
    """
    Return the `trainer_<trainer_id>` sub-directory of `dirname`.

    The directory is created first if it does not exist.
    """
    folder_name = CHECKPOINT_SEPARATOR.join([TRAINER_PREFIX, str(trainer_id)])
    path = os.path.join(dirname, folder_name)
    if os.path.isdir(path):
        return path
    os.makedirs(path)
    return path
T
tangwei12 已提交
719 720


T
tangwei12 已提交
721
def _scroll_delete(dirname, max_num_checkpoints=3):
    """
    Keep only the `max_num_checkpoints` highest-numbered serial directories.

    Entries of `dirname` that are not directories, or whose names do not
    parse to a serial (_get_dir_serial returns -1), are ignored and never
    deleted. The remaining serial directories are sorted from highest to
    lowest serial, and everything past the first `max_num_checkpoints` is
    removed with shutil.rmtree. With max_num_checkpoints=0 every serial
    directory is removed.

    :param dirname: root checkpoint directory.
    :param max_num_checkpoints: number of serial directories to keep.
    """
    serial_map = {}
    for entry in os.listdir(dirname):
        if not os.path.isdir(os.path.join(dirname, entry)):
            continue
        serial_num = _get_dir_serial(entry)
        # -1 means "not a checkpoint directory"; never count or delete it.
        if serial_num < 0:
            continue
        serial_map[serial_num] = entry

    if len(serial_map) <= max_num_checkpoints:
        return

    # sorted() instead of keys().sort(), which fails on Python 3.
    stale_serials = sorted(serial_map, reverse=True)[max_num_checkpoints:]
    for serial in stale_serials:
        # Remove the entry that was actually listed. Going through
        # _get_serial_dir could create a directory just to delete it.
        shutil.rmtree(os.path.join(dirname, serial_map[serial]))


T
tangwei12 已提交
739 740
def _write_success(dirname):
    """
    Append the current time.ctime() string to the `_SUCCESS` marker file.

    The marker lives directly in `dirname` and is created if absent; its
    presence is what get_latest_checkpoint_serial looks for.

    :param dirname: directory to mark.
    """
    marker_path = os.path.join(dirname, SUCCESS_MARK_FILENAME)
    with open(marker_path, 'a') as marker:
        timestamp = time.ctime()
        marker.write(timestamp)
T
tangwei12 已提交
749 750


T
tangwei12 已提交
751
def get_latest_checkpoint_serial(checkpoint_dir):
    """
    Return the highest checkpoint serial in `checkpoint_dir` that completed.

    An entry counts only if it is a directory whose name parses to a serial
    and which contains `__model__/_SUCCESS`. Returns -1 when
    `checkpoint_dir` is empty or None, does not exist, or has no completed
    checkpoint.

    :param checkpoint_dir: root checkpoint directory.
    :return: int serial, or -1.
    """
    if not checkpoint_dir:
        return -1

    def has_success(checkpoint_dir, cur_dir):
        """
        Return the serial of entry `cur_dir` if it holds a completed
        checkpoint, otherwise -1.
        """
        serial = _get_dir_serial(cur_dir)
        serial_path = os.path.join(checkpoint_dir, cur_dir)
        if serial == -1 or not os.path.isdir(serial_path):
            return -1

        # Check the listed entry directly. _get_serial_dir would create a
        # directory as a side effect of this read-only query.
        success_path = os.path.join(serial_path, MODEL_DIR,
                                    SUCCESS_MARK_FILENAME)
        if os.path.isfile(success_path):
            return serial
        # Explicit -1. The old implicit None only compared as "smaller" on
        # Python 2 and raises TypeError on Python 3.
        return -1

    if not os.path.isdir(checkpoint_dir):
        return -1

    current_dir = -1
    for cur_dir in os.listdir(checkpoint_dir):
        success_num = has_success(checkpoint_dir, cur_dir)
        if success_num > current_dir:
            current_dir = success_num
    return current_dir