提交 5451c78d 编写于 作者: T tangwei12

add checkpoint in io

上级 e130bf37
...@@ -454,3 +454,90 @@ def get_parameter_value_by_name(name, executor, program=None): ...@@ -454,3 +454,90 @@ def get_parameter_value_by_name(name, executor, program=None):
program = default_main_program() program = default_main_program()
var = program.global_block().var(name) var = program.global_block().var(name)
return get_parameter_value(var, executor) return get_parameter_value(var, executor)
SUCCESS = "_SUCCESS"
def save_checkpoint(executor,
dirname,
keep_max=10,
save_secs=600,
main_program=None):
"""
Save Variables to Checkpint Dir
:param dirname
:param keep_max
:param save_secs
"""
if dirname is None:
raise Exception("save checkpoint dir can not be none")
if not os.path.isdir(dirname):
os.makedirs(dirname)
serial = _get_lastest_checkpoint_dir(dirname) + 1
cur_dir = os.path.join(dirname, serial)
save_persistables(executor, cur_dir, main_program)
_write_success(cur_dir)
def restore_checkpoint(dirname, executor, main_program=None):
"""
Load Variables from Checkpint Dir
:param dir
"""
if dirname is None and os.path.isdir(dirname):
raise Exception("restore checkpoint can not load variables from %s" %
dirname)
serial = _get_lastest_checkpoint_dir(dirname) + 1
if serial < -1:
return
cur_dir = os.path.join(dirname, serial)
load_persistables(executor, cur_dir, main_program)
def _write_success(dirname):
"""
"""
success_file = os.path.join(dirname, SUCCESS)
with open(success_file, 'a'):
pass
def _get_lastest_checkpoint_dir(checkpoint_dir):
"""
get the biggest number in checkpoint_dir, which has _SUCCESS
"""
if not checkpoint_dir.strip():
return ""
def has_success(checkpoint_dir, cur_dir):
"""
is _SUCCESS in this dir
"""
if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
return -1
try:
int(cur_dir)
except ValueError:
return -1
success_path = os.path.join(checkpoint_dir, cur_dir, SUCCESS)
if os.path.isfile(success_path):
return int(cur_dir)
if not os.path.isdir(checkpoint_dir):
return -1
current_dir = -1
dirs = os.listdir(checkpoint_dir)
for cur_dir in dirs:
success_num = has_success(checkpoint_dir, cur_dir)
if success_num > current_dir:
current_dir = success_num
return current_dir
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册