提交 e901de66 编写于 作者: T tangwei12

update var name

上级 2412dee3
......@@ -23,7 +23,7 @@ from . import core
__all__ = [
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'load_persistables', 'save_inference_model', 'load_inference_model',
'get_inference_program', 'save_checkpoint', 'restore_checkpoint'
'get_inference_program', 'save_checkpoint', 'load_checkpoint'
]
......@@ -466,7 +466,7 @@ def save_checkpoint(executor,
Save Variables to Checkpoint Directory
:param dirname
:param keep_max
:param max_num_checkpoints
:param save_secs
:param main_program
"""
......@@ -495,7 +495,7 @@ def save_checkpoint(executor,
_lru_delete(dirname, max_num_checkpoints)
def restore_checkpoint(executor, dirname=None, main_program=None):
def load_checkpoint(executor, dirname=None, main_program=None):
"""
Load Variables from Checkpint Dir
......@@ -544,9 +544,9 @@ def _interval_secs_exceed(dirname, save_interval_secs):
return True
def _lru_delete(dirname, keep_max=3):
def _lru_delete(dirname, max_num_checkpoints=3):
"""
retain checkpoint nums with keep_max
retain checkpoint nums with max_num_checkpoints
"""
dirs = os.listdir(dirname)
serials = []
......@@ -556,11 +556,11 @@ def _lru_delete(dirname, keep_max=3):
except ValueError:
continue
if len(serials) <= keep_max:
if len(serials) <= max_num_checkpoints:
return
serials.sort(reverse=True)
serials = serials[keep_max:]
serials = serials[max_num_checkpoints:]
for serial in serials:
cur_dir = os.path.join(dirname, str(serial))
shutil.rmtree(cur_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册