提交 2412dee3 编写于 作者: T tangwei12

code optimized

上级 06aa23b0
...@@ -454,17 +454,16 @@ def get_parameter_value_by_name(name, executor, program=None): ...@@ -454,17 +454,16 @@ def get_parameter_value_by_name(name, executor, program=None):
return get_parameter_value(var, executor) return get_parameter_value(var, executor)
SUCCESS = "_SUCCESS" SUCCESS_MARK_FILENAME = "_SUCCESS"
BEGIN_SECS = None
def save_checkpoint(executor, def save_checkpoint(executor,
dirname, dirname=None,
keep_max=3, max_num_checkpoints=3,
save_secs=600, save_interval_secs=600,
main_program=None): main_program=None):
""" """
Save Variables to Checkpint Dir Save Variables to Checkpoint Directory
:param dirname :param dirname
:param keep_max :param keep_max
...@@ -472,20 +471,19 @@ def save_checkpoint(executor, ...@@ -472,20 +471,19 @@ def save_checkpoint(executor,
:param main_program :param main_program
""" """
if dirname is None: if dirname is None:
raise Exception("save checkpoint dir can not be none") dirname = os.getcwd()
if not os.path.isdir(dirname): if not os.path.isdir(dirname):
os.makedirs(dirname) os.makedirs(dirname)
global BEGIN_SECS serial = _get_lastest_checkpoint_dir(dirname)
if BEGIN_SECS is not None: if serial >= 0 and not _interval_secs_exceed(
if time.time() - BEGIN_SECS < save_secs: os.path.join(dirname, str(serial)), save_interval_secs):
return return
BEGIN_SECS = time.time()
serial = _get_lastest_checkpoint_dir(dirname) + 1 serial = serial + 1
cur_dir = os.path.join(dirname, str(serial)) cur_dir = os.path.join(dirname, str(serial))
# save_persistables(executor, cur_dir, main_program)
save_vars( save_vars(
executor, executor,
dirname=cur_dir, dirname=cur_dir,
...@@ -494,10 +492,10 @@ def save_checkpoint(executor, ...@@ -494,10 +492,10 @@ def save_checkpoint(executor,
predicate=is_checkpoint_var, predicate=is_checkpoint_var,
filename=None) filename=None)
_write_success(cur_dir) _write_success(cur_dir)
_lru_delete(dirname, keep_max) _lru_delete(dirname, max_num_checkpoints)
def restore_checkpoint(dirname, executor, main_program=None): def restore_checkpoint(executor, dirname=None, main_program=None):
""" """
Load Variables from Checkpint Dir Load Variables from Checkpint Dir
...@@ -505,15 +503,16 @@ def restore_checkpoint(dirname, executor, main_program=None): ...@@ -505,15 +503,16 @@ def restore_checkpoint(dirname, executor, main_program=None):
:param executor :param executor
:param main_program :param main_program
""" """
if dirname is None and os.path.isdir(dirname):
raise Exception("restore checkpoint can not load variables from %s" % if dirname is None:
dirname) dirname = os.getcwd()
serial = _get_lastest_checkpoint_dir(dirname) serial = _get_lastest_checkpoint_dir(dirname)
if serial < 0: if serial < 0:
return return
cur_dir = os.path.join(dirname, str(serial)) cur_dir = os.path.join(dirname, str(serial))
# load_persistables(executor, cur_dir, main_program)
load_vars( load_vars(
executor, executor,
dirname=cur_dir, dirname=cur_dir,
...@@ -523,6 +522,10 @@ def restore_checkpoint(dirname, executor, main_program=None): ...@@ -523,6 +522,10 @@ def restore_checkpoint(dirname, executor, main_program=None):
def is_checkpoint_var(var): def is_checkpoint_var(var):
"""
VarType will fliter out FEED_MINIBATCH FETCH_LIST RAW
VarName will fliter out Gradient
"""
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
var.desc.type() == core.VarDesc.VarType.RAW: var.desc.type() == core.VarDesc.VarType.RAW:
...@@ -534,6 +537,13 @@ def is_checkpoint_var(var): ...@@ -534,6 +537,13 @@ def is_checkpoint_var(var):
return var.persistable return var.persistable
def _interval_secs_exceed(dirname, save_interval_secs):
    """
    Tell whether enough time has passed since `dirname` was last modified.

    :param dirname: path whose modification time (os.path.getmtime) is read
    :param save_interval_secs: minimum number of seconds that must have elapsed
    :return: False while the elapsed time is still smaller than
             save_interval_secs, True otherwise
    """
    modified_at = os.path.getmtime(dirname)
    elapsed = time.time() - modified_at
    # Equivalent to "elapsed >= save_interval_secs" for ordinary numbers;
    # kept as a negated '>' so the comparison matches the original exactly.
    return not (save_interval_secs > elapsed)
def _lru_delete(dirname, keep_max=3): def _lru_delete(dirname, keep_max=3):
""" """
retain checkpoint nums with keep_max retain checkpoint nums with keep_max
...@@ -560,7 +570,7 @@ def _write_success(dirname): ...@@ -560,7 +570,7 @@ def _write_success(dirname):
""" """
write _SUCCESS to checkpoint dir write _SUCCESS to checkpoint dir
""" """
success_file = os.path.join(dirname, SUCCESS) success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME)
with open(success_file, 'a'): with open(success_file, 'a'):
pass pass
...@@ -584,7 +594,8 @@ def _get_lastest_checkpoint_dir(checkpoint_dir): ...@@ -584,7 +594,8 @@ def _get_lastest_checkpoint_dir(checkpoint_dir):
except ValueError: except ValueError:
return -1 return -1
success_path = os.path.join(checkpoint_dir, cur_dir, SUCCESS) success_path = os.path.join(checkpoint_dir, cur_dir,
SUCCESS_MARK_FILENAME)
if os.path.isfile(success_path): if os.path.isfile(success_path):
return int(cur_dir) return int(cur_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册