提交 2412dee3 编写于 作者: T tangwei12

code optimized

上级 06aa23b0
......@@ -454,17 +454,16 @@ def get_parameter_value_by_name(name, executor, program=None):
return get_parameter_value(var, executor)
SUCCESS = "_SUCCESS"
BEGIN_SECS = None
SUCCESS_MARK_FILENAME = "_SUCCESS"
def save_checkpoint(executor,
dirname,
keep_max=3,
save_secs=600,
dirname=None,
max_num_checkpoints=3,
save_interval_secs=600,
main_program=None):
"""
Save Variables to Checkpint Dir
Save Variables to Checkpoint Directory
:param dirname
:param keep_max
......@@ -472,20 +471,19 @@ def save_checkpoint(executor,
:param main_program
"""
if dirname is None:
raise Exception("save checkpoint dir can not be none")
dirname = os.getcwd()
if not os.path.isdir(dirname):
os.makedirs(dirname)
global BEGIN_SECS
if BEGIN_SECS is not None:
if time.time() - BEGIN_SECS < save_secs:
serial = _get_lastest_checkpoint_dir(dirname)
if serial >= 0 and not _interval_secs_exceed(
os.path.join(dirname, str(serial)), save_interval_secs):
return
BEGIN_SECS = time.time()
serial = _get_lastest_checkpoint_dir(dirname) + 1
serial = serial + 1
cur_dir = os.path.join(dirname, str(serial))
# save_persistables(executor, cur_dir, main_program)
save_vars(
executor,
dirname=cur_dir,
......@@ -494,10 +492,10 @@ def save_checkpoint(executor,
predicate=is_checkpoint_var,
filename=None)
_write_success(cur_dir)
_lru_delete(dirname, keep_max)
_lru_delete(dirname, max_num_checkpoints)
def restore_checkpoint(dirname, executor, main_program=None):
def restore_checkpoint(executor, dirname=None, main_program=None):
"""
Load Variables from Checkpoint Dir
......@@ -505,15 +503,16 @@ def restore_checkpoint(dirname, executor, main_program=None):
:param executor
:param main_program
"""
if dirname is None and os.path.isdir(dirname):
raise Exception("restore checkpoint can not load variables from %s" %
dirname)
if dirname is None:
dirname = os.getcwd()
serial = _get_lastest_checkpoint_dir(dirname)
if serial < 0:
return
cur_dir = os.path.join(dirname, str(serial))
# load_persistables(executor, cur_dir, main_program)
load_vars(
executor,
dirname=cur_dir,
......@@ -523,6 +522,10 @@ def restore_checkpoint(dirname, executor, main_program=None):
def is_checkpoint_var(var):
"""
VarType will filter out FEED_MINIBATCH FETCH_LIST RAW
VarName will filter out Gradient
"""
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
var.desc.type() == core.VarDesc.VarType.RAW:
......@@ -534,6 +537,13 @@ def is_checkpoint_var(var):
return var.persistable
def _interval_secs_exceed(dirname, save_interval_secs):
    """
    Check whether the save interval has elapsed since dirname was modified

    :param dirname: path whose modification time is read via os.path.getmtime
    :param save_interval_secs: minimum number of seconds between two saves
    :return: False while fewer than save_interval_secs seconds have passed
             since the modification time of dirname, True otherwise
    :raises OSError: propagated from os.path.getmtime when dirname cannot
                     be read (for example, it does not exist)
    """
    modified_at = os.path.getmtime(dirname)
    # Same truth table as "if interval > elapsed: return False; return True".
    return not save_interval_secs > (time.time() - modified_at)
def _lru_delete(dirname, keep_max=3):
"""
retain checkpoint nums with keep_max
......@@ -560,7 +570,7 @@ def _write_success(dirname):
"""
write _SUCCESS to checkpoint dir
"""
success_file = os.path.join(dirname, SUCCESS)
success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME)
with open(success_file, 'a'):
pass
......@@ -584,7 +594,8 @@ def _get_lastest_checkpoint_dir(checkpoint_dir):
except ValueError:
return -1
success_path = os.path.join(checkpoint_dir, cur_dir, SUCCESS)
success_path = os.path.join(checkpoint_dir, cur_dir,
SUCCESS_MARK_FILENAME)
if os.path.isfile(success_path):
return int(cur_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册