diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index b1748f0ad0a39ab551dbc29309150366a0f6fd29..01debaff56a61ed3abf3239147d93ee89d05d99a 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -454,17 +454,16 @@ def get_parameter_value_by_name(name, executor, program=None): return get_parameter_value(var, executor) -SUCCESS = "_SUCCESS" -BEGIN_SECS = None +SUCCESS_MARK_FILENAME = "_SUCCESS" def save_checkpoint(executor, - dirname, - keep_max=3, - save_secs=600, + dirname=None, + max_num_checkpoints=3, + save_interval_secs=600, main_program=None): """ - Save Variables to Checkpint Dir + Save Variables to Checkpoint Directory :param dirname :param keep_max @@ -472,20 +471,19 @@ def save_checkpoint(executor, :param main_program """ if dirname is None: - raise Exception("save checkpoint dir can not be none") + dirname = os.getcwd() if not os.path.isdir(dirname): os.makedirs(dirname) - global BEGIN_SECS - if BEGIN_SECS is not None: - if time.time() - BEGIN_SECS < save_secs: - return - BEGIN_SECS = time.time() + serial = _get_lastest_checkpoint_dir(dirname) + if serial >= 0 and not _interval_secs_exceed( + os.path.join(dirname, str(serial)), save_interval_secs): + return - serial = _get_lastest_checkpoint_dir(dirname) + 1 + serial = serial + 1 cur_dir = os.path.join(dirname, str(serial)) - # save_persistables(executor, cur_dir, main_program) + save_vars( executor, dirname=cur_dir, @@ -494,10 +492,10 @@ def save_checkpoint(executor, predicate=is_checkpoint_var, filename=None) _write_success(cur_dir) - _lru_delete(dirname, keep_max) + _lru_delete(dirname, max_num_checkpoints) -def restore_checkpoint(dirname, executor, main_program=None): +def restore_checkpoint(executor, dirname=None, main_program=None): """ Load Variables from Checkpint Dir @@ -505,15 +503,16 @@ def restore_checkpoint(dirname, executor, main_program=None): :param executor :param main_program """ - if dirname is None and os.path.isdir(dirname): - raise Exception("restore checkpoint can not load variables from %s" % - dirname) + + if dirname is None: + dirname = os.getcwd() + serial = _get_lastest_checkpoint_dir(dirname) if serial < 0: return cur_dir = os.path.join(dirname, str(serial)) - # load_persistables(executor, cur_dir, main_program) + load_vars( executor, dirname=cur_dir, @@ -523,6 +522,10 @@ def restore_checkpoint(dirname, executor, main_program=None): def is_checkpoint_var(var): + """ + VarType will fliter out FEED_MINIBATCH FETCH_LIST RAW + VarName will fliter out Gradient + """ if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ var.desc.type() == core.VarDesc.VarType.RAW: @@ -534,6 +537,13 @@ def is_checkpoint_var(var): return var.persistable +def _interval_secs_exceed(dirname, save_interval_secs): + dir_time = os.path.getmtime(dirname) + if save_interval_secs > (time.time() - dir_time): + return False + return True + + def _lru_delete(dirname, keep_max=3): """ retain checkpoint nums with keep_max @@ -560,7 +570,7 @@ def _write_success(dirname): """ write _SUCCESS to checkpoint dir """ - success_file = os.path.join(dirname, SUCCESS) + success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME) with open(success_file, 'a'): pass @@ -584,7 +594,8 @@ def _get_lastest_checkpoint_dir(checkpoint_dir): except ValueError: return -1 - success_path = os.path.join(checkpoint_dir, cur_dir, SUCCESS) + success_path = os.path.join(checkpoint_dir, cur_dir, + SUCCESS_MARK_FILENAME) if os.path.isfile(success_path): return int(cur_dir)