提交 27b71751 编写于 作者: T tangwei12

update python annotation

上级 e901de66
......@@ -463,8 +463,11 @@ def save_checkpoint(executor,
save_interval_secs=600,
main_program=None):
"""
Save Variables to Checkpoint Directory
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most,
The interval time between two save_checkpoint must great than or equal to save_interval_secs.
:param dirname
:param max_num_checkpoints
:param save_secs
......@@ -489,7 +492,7 @@ def save_checkpoint(executor,
dirname=cur_dir,
main_program=main_program,
vars=None,
predicate=is_checkpoint_var,
predicate=_is_checkpoint_var,
filename=None)
_write_success(cur_dir)
_lru_delete(dirname, max_num_checkpoints)
......@@ -497,10 +500,11 @@ def save_checkpoint(executor,
def load_checkpoint(executor, dirname=None, main_program=None):
"""
Load Variables from Checkpint Dir
Load checkpoint from directory by executor,
it will find lastest checkpoint file and load it auto.
:param dirname
:param executor
:param dirname
:param main_program
"""
......@@ -517,14 +521,16 @@ def load_checkpoint(executor, dirname=None, main_program=None):
executor,
dirname=cur_dir,
main_program=main_program,
predicate=is_checkpoint_var,
predicate=_is_checkpoint_var,
filename=None)
def is_checkpoint_var(var):
def _is_checkpoint_var(var):
"""
VarType will fliter out FEED_MINIBATCH FETCH_LIST RAW
VarName will fliter out Gradient
checkpoint will not save or load all the variables.
var type is FEED_MINIBATCH/FETCH_LIST/RAW and var name is end with @GRAD are discarded.
:param var
"""
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
......@@ -545,9 +551,6 @@ def _interval_secs_exceed(dirname, save_interval_secs):
def _lru_delete(dirname, max_num_checkpoints=3):
"""
retain checkpoint nums with max_num_checkpoints
"""
dirs = os.listdir(dirname)
serials = []
for serial in dirs:
......@@ -568,7 +571,9 @@ def _lru_delete(dirname, max_num_checkpoints=3):
def _write_success(dirname):
"""
write _SUCCESS to checkpoint dir
write an empty _SUCCESS file to checkpoint dir, indicate this checkpoint is correct.
:param dirname
"""
success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME)
with open(success_file, 'a'):
......@@ -577,7 +582,9 @@ def _write_success(dirname):
def _get_lastest_checkpoint_dir(checkpoint_dir):
"""
get the biggest number in checkpoint_dir, which has _SUCCESS
get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
:param checkpoint_dir
"""
if not checkpoint_dir.strip():
return -1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册