提交 5f5d6a9d 编写于 作者: T tangwei12

optimized checkpoint and save_model

上级 5eea5db9
...@@ -489,9 +489,9 @@ CHECKPOINT_SEPARATOR = "_" ...@@ -489,9 +489,9 @@ CHECKPOINT_SEPARATOR = "_"
def save_checkpoint(executor, def save_checkpoint(executor,
checkpoint_dir=None, checkpoint_dir,
max_num_checkpoints=3, main_program=None,
main_program=None): max_num_checkpoints=3):
""" """
Save Checkpoint will save persistable LodTensor variables from main_program in the checkpoint directory, Save Checkpoint will save persistable LodTensor variables from main_program in the checkpoint directory,
the directory is named by serial number from 0 to (n - 1); save_checkpoint uses an LRU strategy the directory is named by serial number from 0 to (n - 1); save_checkpoint uses an LRU strategy
...@@ -500,12 +500,11 @@ def save_checkpoint(executor, ...@@ -500,12 +500,11 @@ def save_checkpoint(executor,
:param executor :param executor
:param checkpoint_dir :param checkpoint_dir
:param max_num_checkpoints
:param save_interval_secs
:param main_program :param main_program
:param max_num_checkpoints
""" """
if checkpoint_dir is None: if checkpoint_dir is None:
checkpoint_dir = os.getcwd() raise ValueError("The values of 'checkpoint_dir' should not be None")
if not os.path.isdir(checkpoint_dir): if not os.path.isdir(checkpoint_dir):
os.makedirs(checkpoint_dir) os.makedirs(checkpoint_dir)
...@@ -518,7 +517,7 @@ def save_checkpoint(executor, ...@@ -518,7 +517,7 @@ def save_checkpoint(executor,
_lru_delete(checkpoint_dir, max_num_checkpoints) _lru_delete(checkpoint_dir, max_num_checkpoints)
def load_checkpoint(executor, checkpoint_dir=None, main_program=None): def load_checkpoint(executor, checkpoint_dir, main_program=None):
""" """
Load checkpoint from a directory by executor, Load checkpoint from a directory by executor,
it will find the most recently saved checkpoint and load it automatically. it will find the most recently saved checkpoint and load it automatically.
...@@ -529,7 +528,7 @@ def load_checkpoint(executor, checkpoint_dir=None, main_program=None): ...@@ -529,7 +528,7 @@ def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
""" """
if checkpoint_dir is None: if checkpoint_dir is None:
checkpoint_dir = os.getcwd() raise ValueError("The values of 'checkpoint_dir' should not be None")
serial = _get_lastest_checkpoint_dir(checkpoint_dir) serial = _get_lastest_checkpoint_dir(checkpoint_dir)
...@@ -546,7 +545,7 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False): ...@@ -546,7 +545,7 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
delete_dir only works when the directory is empty, otherwise, OSError is raised. delete_dir only works when the directory is empty, otherwise, OSError is raised.
""" """
if checkpoint_dir is None: if checkpoint_dir is None:
checkpoint_dir = os.getcwd() raise ValueError("The values of 'checkpoint_dir' should not be None")
_lru_delete(checkpoint_dir, max_num_checkpoints=0) _lru_delete(checkpoint_dir, max_num_checkpoints=0)
if delete_dir and not os.listdir(checkpoint_dir): if delete_dir and not os.listdir(checkpoint_dir):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册