提交 08e5f0ae 编写于 作者: T tangwei12

rename need_load_checkpoint to get_latest_checkpoint_serial

上级 c06f43bb
...@@ -25,7 +25,7 @@ __all__ = [ ...@@ -25,7 +25,7 @@ __all__ = [
'load_persistables', 'save_inference_model', 'load_inference_model', 'load_persistables', 'save_inference_model', 'load_inference_model',
'get_inference_program', 'save_checkpoint', 'load_checkpoint', 'get_inference_program', 'save_checkpoint', 'load_checkpoint',
'clean_checkpoint', 'load_persist_vars_without_grad', 'clean_checkpoint', 'load_persist_vars_without_grad',
'save_persist_vars_without_grad' 'save_persist_vars_without_grad', 'get_latest_checkpoint_serial'
] ]
...@@ -503,7 +503,7 @@ def save_checkpoint(executor, ...@@ -503,7 +503,7 @@ def save_checkpoint(executor,
_lru_delete(checkpoint_dir, max_num_checkpoints) _lru_delete(checkpoint_dir, max_num_checkpoints)
def need_load_checkpoint(checkpoint_dir): def get_latest_checkpoint_serial(checkpoint_dir):
""" """
If the directory have checkpoint files, it will return lastest checkpoint directory serial number If the directory have checkpoint files, it will return lastest checkpoint directory serial number
......
...@@ -146,7 +146,7 @@ class Trainer(object): ...@@ -146,7 +146,7 @@ class Trainer(object):
"The checkpoint_config shoule be an instance of CheckpointConfig" "The checkpoint_config shoule be an instance of CheckpointConfig"
) )
else: else:
self.checkpoint.load_serial = io.need_load_checkpoint( self.checkpoint.load_serial = io.get_latest_checkpoint_serial(
self.checkpoint.checkpoint_dir) self.checkpoint.checkpoint_dir)
self.scope = core.Scope() self.scope = core.Scope()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册