提交 d96b4427 编写于 作者: T tangwei12

rename checkpoint folder to checkpoint_serial

上级 9d985340
...@@ -455,10 +455,12 @@ def get_parameter_value_by_name(name, executor, program=None): ...@@ -455,10 +455,12 @@ def get_parameter_value_by_name(name, executor, program=None):
SUCCESS_MARK_FILENAME = "_SUCCESS" SUCCESS_MARK_FILENAME = "_SUCCESS"
CHECKPOINT_PREFIX = "checkpoint"
CHECKPOINT_SEPARATOR = "_"
def save_checkpoint(executor, def save_checkpoint(executor,
dirname=None, checkpoint_dir=None,
max_num_checkpoints=3, max_num_checkpoints=3,
save_interval_secs=600, save_interval_secs=600,
main_program=None): main_program=None):
...@@ -466,26 +468,27 @@ def save_checkpoint(executor, ...@@ -466,26 +468,27 @@ def save_checkpoint(executor,
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory, Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most, to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most,
The interval time between two save_checkpoint must great than or equal to save_interval_secs. The interval between two saved checkpoints must greater than save_interval_secs.
:param dirname :param executor
:param checkpoint_dir
:param max_num_checkpoints :param max_num_checkpoints
:param save_secs :param save_interval_secs
:param main_program :param main_program
""" """
if dirname is None: if checkpoint_dir is None:
dirname = os.getcwd() checkpoint_dir = os.getcwd()
if not os.path.isdir(dirname): if not os.path.isdir(checkpoint_dir):
os.makedirs(dirname) os.makedirs(checkpoint_dir)
serial = _get_lastest_checkpoint_dir(dirname) serial = _get_lastest_checkpoint_dir(checkpoint_dir)
if serial >= 0 and not _interval_secs_exceed( if serial >= 0 and not _interval_secs_exceed(
os.path.join(dirname, str(serial)), save_interval_secs): _get_serial_dir(serial, checkpoint_dir), save_interval_secs):
return return
serial = serial + 1 serial += 1
cur_dir = os.path.join(dirname, str(serial)) cur_dir = _get_serial_dir(serial, checkpoint_dir)
save_vars( save_vars(
executor, executor,
...@@ -495,27 +498,28 @@ def save_checkpoint(executor, ...@@ -495,27 +498,28 @@ def save_checkpoint(executor,
predicate=_is_checkpoint_var, predicate=_is_checkpoint_var,
filename=None) filename=None)
_write_success(cur_dir) _write_success(cur_dir)
_lru_delete(dirname, max_num_checkpoints) _lru_delete(checkpoint_dir, max_num_checkpoints)
def load_checkpoint(executor, dirname=None, main_program=None): def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
""" """
Load checkpoint from a directory by executor, Load checkpoint from a directory by executor,
it will find latest checkpoint file and load it auto. it will find the most recent saved checkpoint file and load it auto.
:param executor :param executor
:param dirname :param checkpoint_dir
:param main_program :param main_program
""" """
if dirname is None: if checkpoint_dir is None:
dirname = os.getcwd() checkpoint_dir = os.getcwd()
serial = _get_lastest_checkpoint_dir(dirname) serial = _get_lastest_checkpoint_dir(checkpoint_dir)
if serial < 0: if serial < 0:
return return
cur_dir = os.path.join(dirname, str(serial))
cur_dir = _get_serial_dir(serial, checkpoint_dir)
load_vars( load_vars(
executor, executor,
...@@ -525,6 +529,11 @@ def load_checkpoint(executor, dirname=None, main_program=None): ...@@ -525,6 +529,11 @@ def load_checkpoint(executor, dirname=None, main_program=None):
filename=None) filename=None)
def _get_serial_dir(serial, checkpoint_dir):
    """
    Build the path of the sub-directory that holds one checkpoint.

    The folder name is CHECKPOINT_PREFIX, CHECKPOINT_SEPARATOR and the
    string form of the serial number joined together; the folder lives
    directly under checkpoint_dir.

    :param serial: serial number of the checkpoint
    :param checkpoint_dir: directory that contains all checkpoint folders
    :return: checkpoint_dir joined with the serial folder name
    """
    name_parts = [CHECKPOINT_PREFIX, str(serial)]
    return os.path.join(checkpoint_dir, CHECKPOINT_SEPARATOR.join(name_parts))
def _is_checkpoint_var(var): def _is_checkpoint_var(var):
""" """
the checkpoint will not save or load all the variables. the checkpoint will not save or load all the variables.
...@@ -577,7 +586,8 @@ def _write_success(dirname): ...@@ -577,7 +586,8 @@ def _write_success(dirname):
""" """
success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME) success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME)
with open(success_file, 'a'): with open(success_file, 'a'):
pass now = time.ctime()
success_file.write(now)
def _get_lastest_checkpoint_dir(checkpoint_dir): def _get_lastest_checkpoint_dir(checkpoint_dir):
...@@ -593,18 +603,20 @@ def _get_lastest_checkpoint_dir(checkpoint_dir): ...@@ -593,18 +603,20 @@ def _get_lastest_checkpoint_dir(checkpoint_dir):
""" """
is _SUCCESS in this dir is _SUCCESS in this dir
""" """
if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)): _, serial = cur_dir.split(CHECKPOINT_SEPARATOR)
return -1
try: try:
int(cur_dir) int(serial)
except ValueError: except ValueError:
return -1 return -1
success_path = os.path.join(checkpoint_dir, cur_dir, if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
SUCCESS_MARK_FILENAME) return -1
success_path = os.path.join(
_get_serial_dir(serial, checkpoint_dir), SUCCESS_MARK_FILENAME)
if os.path.isfile(success_path): if os.path.isfile(success_path):
return int(cur_dir) return int(serial)
if not os.path.isdir(checkpoint_dir): if not os.path.isdir(checkpoint_dir):
return -1 return -1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册