提交 cf3fb248 编写于 作者: T tangwei12

add clean checkpoint

上级 192f9a5a
...@@ -529,6 +529,19 @@ def load_checkpoint(executor, checkpoint_dir=None, main_program=None): ...@@ -529,6 +529,19 @@ def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
filename=None) filename=None)
def clean_checkpoint(checkpoint_dir, delete_dir=False):
"""
clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before.
delete_dir only works when the directory is empty, otherwise, OSError is raised.
"""
if checkpoint_dir is None:
checkpoint_dir = os.getcwd()
_lru_delete(checkpoint_dir, max_num_checkpoints=0)
if delete_dir and not os.listdir(checkpoint_dir):
os.rmdir(checkpoint_dir)
def _get_serial_dir(serial, checkpoint_dir): def _get_serial_dir(serial, checkpoint_dir):
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial) serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
return os.path.join(checkpoint_dir, serial_folder) return os.path.join(checkpoint_dir, serial_folder)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册