提交 e901de66 编写于 作者: T tangwei12

update var name

上级 2412dee3
...@@ -23,7 +23,7 @@ from . import core ...@@ -23,7 +23,7 @@ from . import core
__all__ = [ __all__ = [
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params', 'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'load_persistables', 'save_inference_model', 'load_inference_model', 'load_persistables', 'save_inference_model', 'load_inference_model',
'get_inference_program', 'save_checkpoint', 'restore_checkpoint' 'get_inference_program', 'save_checkpoint', 'load_checkpoint'
] ]
...@@ -466,7 +466,7 @@ def save_checkpoint(executor, ...@@ -466,7 +466,7 @@ def save_checkpoint(executor,
Save Variables to Checkpoint Directory Save Variables to Checkpoint Directory
:param dirname :param dirname
:param keep_max :param max_num_checkpoints
:param save_secs :param save_secs
:param main_program :param main_program
""" """
...@@ -495,7 +495,7 @@ def save_checkpoint(executor, ...@@ -495,7 +495,7 @@ def save_checkpoint(executor,
_lru_delete(dirname, max_num_checkpoints) _lru_delete(dirname, max_num_checkpoints)
def restore_checkpoint(executor, dirname=None, main_program=None): def load_checkpoint(executor, dirname=None, main_program=None):
""" """
Load Variables from Checkpint Dir Load Variables from Checkpint Dir
...@@ -544,9 +544,9 @@ def _interval_secs_exceed(dirname, save_interval_secs): ...@@ -544,9 +544,9 @@ def _interval_secs_exceed(dirname, save_interval_secs):
return True return True
def _lru_delete(dirname, keep_max=3): def _lru_delete(dirname, max_num_checkpoints=3):
""" """
retain checkpoint nums with keep_max retain checkpoint nums with max_num_checkpoints
""" """
dirs = os.listdir(dirname) dirs = os.listdir(dirname)
serials = [] serials = []
...@@ -556,11 +556,11 @@ def _lru_delete(dirname, keep_max=3): ...@@ -556,11 +556,11 @@ def _lru_delete(dirname, keep_max=3):
except ValueError: except ValueError:
continue continue
if len(serials) <= keep_max: if len(serials) <= max_num_checkpoints:
return return
serials.sort(reverse=True) serials.sort(reverse=True)
serials = serials[keep_max:] serials = serials[max_num_checkpoints:]
for serial in serials: for serial in serials:
cur_dir = os.path.join(dirname, str(serial)) cur_dir = os.path.join(dirname, str(serial))
shutil.rmtree(cur_dir) shutil.rmtree(cur_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册