提交 27b71751 编写于 作者: T tangwei12

update python annotation

上级 e901de66
...@@ -463,8 +463,11 @@ def save_checkpoint(executor, ...@@ -463,8 +463,11 @@ def save_checkpoint(executor,
save_interval_secs=600, save_interval_secs=600,
main_program=None): main_program=None):
""" """
Save Variables to Checkpoint Directory Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most,
The interval time between two save_checkpoint must great than or equal to save_interval_secs.
:param dirname :param dirname
:param max_num_checkpoints :param max_num_checkpoints
:param save_secs :param save_secs
...@@ -489,7 +492,7 @@ def save_checkpoint(executor, ...@@ -489,7 +492,7 @@ def save_checkpoint(executor,
dirname=cur_dir, dirname=cur_dir,
main_program=main_program, main_program=main_program,
vars=None, vars=None,
predicate=is_checkpoint_var, predicate=_is_checkpoint_var,
filename=None) filename=None)
_write_success(cur_dir) _write_success(cur_dir)
_lru_delete(dirname, max_num_checkpoints) _lru_delete(dirname, max_num_checkpoints)
...@@ -497,10 +500,11 @@ def save_checkpoint(executor, ...@@ -497,10 +500,11 @@ def save_checkpoint(executor,
def load_checkpoint(executor, dirname=None, main_program=None): def load_checkpoint(executor, dirname=None, main_program=None):
""" """
Load Variables from Checkpint Dir Load checkpoint from directory by executor,
it will find lastest checkpoint file and load it auto.
:param dirname
:param executor :param executor
:param dirname
:param main_program :param main_program
""" """
...@@ -517,14 +521,16 @@ def load_checkpoint(executor, dirname=None, main_program=None): ...@@ -517,14 +521,16 @@ def load_checkpoint(executor, dirname=None, main_program=None):
executor, executor,
dirname=cur_dir, dirname=cur_dir,
main_program=main_program, main_program=main_program,
predicate=is_checkpoint_var, predicate=_is_checkpoint_var,
filename=None) filename=None)
def is_checkpoint_var(var): def _is_checkpoint_var(var):
""" """
VarType will fliter out FEED_MINIBATCH FETCH_LIST RAW checkpoint will not save or load all the variables.
VarName will fliter out Gradient var type is FEED_MINIBATCH/FETCH_LIST/RAW and var name is end with @GRAD are discarded.
:param var
""" """
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
...@@ -545,9 +551,6 @@ def _interval_secs_exceed(dirname, save_interval_secs): ...@@ -545,9 +551,6 @@ def _interval_secs_exceed(dirname, save_interval_secs):
def _lru_delete(dirname, max_num_checkpoints=3): def _lru_delete(dirname, max_num_checkpoints=3):
"""
retain checkpoint nums with max_num_checkpoints
"""
dirs = os.listdir(dirname) dirs = os.listdir(dirname)
serials = [] serials = []
for serial in dirs: for serial in dirs:
...@@ -568,7 +571,9 @@ def _lru_delete(dirname, max_num_checkpoints=3): ...@@ -568,7 +571,9 @@ def _lru_delete(dirname, max_num_checkpoints=3):
def _write_success(dirname): def _write_success(dirname):
""" """
write _SUCCESS to checkpoint dir write an empty _SUCCESS file to checkpoint dir, indicate this checkpoint is correct.
:param dirname
""" """
success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME) success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME)
with open(success_file, 'a'): with open(success_file, 'a'):
...@@ -577,7 +582,9 @@ def _write_success(dirname): ...@@ -577,7 +582,9 @@ def _write_success(dirname):
def _get_lastest_checkpoint_dir(checkpoint_dir): def _get_lastest_checkpoint_dir(checkpoint_dir):
""" """
get the biggest number in checkpoint_dir, which has _SUCCESS get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
:param checkpoint_dir
""" """
if not checkpoint_dir.strip(): if not checkpoint_dir.strip():
return -1 return -1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册