提交 d96b4427 编写于 作者: T tangwei12

rename checkpoint folder to checkpoint_serial

上级 9d985340
......@@ -455,10 +455,12 @@ def get_parameter_value_by_name(name, executor, program=None):
SUCCESS_MARK_FILENAME = "_SUCCESS"
CHECKPOINT_PREFIX = "checkpoint"
CHECKPOINT_SEPARATOR = "_"
def save_checkpoint(executor,
dirname=None,
checkpoint_dir=None,
max_num_checkpoints=3,
save_interval_secs=600,
main_program=None):
......@@ -466,26 +468,27 @@ def save_checkpoint(executor,
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most,
The interval time between two save_checkpoint must great than or equal to save_interval_secs.
The interval between two saved checkpoints must be greater than save_interval_secs.
:param dirname
:param executor
:param checkpoint_dir
:param max_num_checkpoints
:param save_secs
:param save_interval_secs
:param main_program
"""
if dirname is None:
dirname = os.getcwd()
if checkpoint_dir is None:
checkpoint_dir = os.getcwd()
if not os.path.isdir(dirname):
os.makedirs(dirname)
if not os.path.isdir(checkpoint_dir):
os.makedirs(checkpoint_dir)
serial = _get_lastest_checkpoint_dir(dirname)
serial = _get_lastest_checkpoint_dir(checkpoint_dir)
if serial >= 0 and not _interval_secs_exceed(
os.path.join(dirname, str(serial)), save_interval_secs):
_get_serial_dir(serial, checkpoint_dir), save_interval_secs):
return
serial = serial + 1
cur_dir = os.path.join(dirname, str(serial))
serial += 1
cur_dir = _get_serial_dir(serial, checkpoint_dir)
save_vars(
executor,
......@@ -495,27 +498,28 @@ def save_checkpoint(executor,
predicate=_is_checkpoint_var,
filename=None)
_write_success(cur_dir)
_lru_delete(dirname, max_num_checkpoints)
_lru_delete(checkpoint_dir, max_num_checkpoints)
def load_checkpoint(executor, dirname=None, main_program=None):
def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
"""
Load checkpoint from a directory by executor,
it will find latest checkpoint file and load it auto.
it will find the most recently saved checkpoint and load it automatically.
:param executor
:param dirname
:param checkpoint_dir
:param main_program
"""
if dirname is None:
dirname = os.getcwd()
if checkpoint_dir is None:
checkpoint_dir = os.getcwd()
serial = _get_lastest_checkpoint_dir(dirname)
serial = _get_lastest_checkpoint_dir(checkpoint_dir)
if serial < 0:
return
cur_dir = os.path.join(dirname, str(serial))
cur_dir = _get_serial_dir(serial, checkpoint_dir)
load_vars(
executor,
......@@ -525,6 +529,11 @@ def load_checkpoint(executor, dirname=None, main_program=None):
filename=None)
def _get_serial_dir(serial, checkpoint_dir):
    """
    Build the path of the checkpoint sub-directory for a serial number.

    The folder name is CHECKPOINT_PREFIX, CHECKPOINT_SEPARATOR and the
    string form of the serial number, concatenated in that order.

    :param serial: serial number of the checkpoint, converted with str()
    :param checkpoint_dir: directory that the serial folder is joined onto
    :return: checkpoint_dir joined with the serial folder name
    """
    folder_name = CHECKPOINT_SEPARATOR.join([CHECKPOINT_PREFIX, str(serial)])
    return os.path.join(checkpoint_dir, folder_name)
def _is_checkpoint_var(var):
"""
the checkpoint will not save or load all the variables.
......@@ -577,7 +586,8 @@ def _write_success(dirname):
"""
success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME)
with open(success_file, 'a'):
pass
now = time.ctime()
success_file.write(now)
def _get_lastest_checkpoint_dir(checkpoint_dir):
......@@ -593,18 +603,20 @@ def _get_lastest_checkpoint_dir(checkpoint_dir):
"""
is _SUCCESS in this dir
"""
if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
return -1
_, serial = cur_dir.split(CHECKPOINT_SEPARATOR)
try:
int(cur_dir)
int(serial)
except ValueError:
return -1
success_path = os.path.join(checkpoint_dir, cur_dir,
SUCCESS_MARK_FILENAME)
if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
return -1
success_path = os.path.join(
_get_serial_dir(serial, checkpoint_dir), SUCCESS_MARK_FILENAME)
if os.path.isfile(success_path):
return int(cur_dir)
return int(serial)
if not os.path.isdir(checkpoint_dir):
return -1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册