diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index ac26991d41dd6615a898e15f1b3fd7dfc77f1f76..3a7b68a682d04ef5050eeb1c536f83a3f2c71bf6 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -463,8 +463,11 @@ def save_checkpoint(executor, save_interval_secs=600, main_program=None): """ - Save Variables to Checkpoint Directory - + Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory, + directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy + to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most, + The interval time between two save_checkpoint must great than or equal to save_interval_secs. + :param dirname :param max_num_checkpoints :param save_secs @@ -489,7 +492,7 @@ def save_checkpoint(executor, dirname=cur_dir, main_program=main_program, vars=None, - predicate=is_checkpoint_var, + predicate=_is_checkpoint_var, filename=None) _write_success(cur_dir) _lru_delete(dirname, max_num_checkpoints) @@ -497,10 +500,11 @@ def save_checkpoint(executor, def load_checkpoint(executor, dirname=None, main_program=None): """ - Load Variables from Checkpint Dir + Load checkpoint from directory by executor, + it will find lastest checkpoint file and load it auto. - :param dirname :param executor + :param dirname :param main_program """ @@ -517,14 +521,16 @@ def load_checkpoint(executor, dirname=None, main_program=None): executor, dirname=cur_dir, main_program=main_program, - predicate=is_checkpoint_var, + predicate=_is_checkpoint_var, filename=None) -def is_checkpoint_var(var): +def _is_checkpoint_var(var): """ - VarType will fliter out FEED_MINIBATCH FETCH_LIST RAW - VarName will fliter out Gradient + checkpoint will not save or load all the variables. + var type is FEED_MINIBATCH/FETCH_LIST/RAW and var name is end with @GRAD are discarded. + + :param var """ if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ @@ -545,9 +551,6 @@ def _interval_secs_exceed(dirname, save_interval_secs): def _lru_delete(dirname, max_num_checkpoints=3): - """ - retain checkpoint nums with max_num_checkpoints - """ dirs = os.listdir(dirname) serials = [] for serial in dirs: @@ -568,7 +571,9 @@ def _lru_delete(dirname, max_num_checkpoints=3): def _write_success(dirname): """ - write _SUCCESS to checkpoint dir + write an empty _SUCCESS file to checkpoint dir, indicate this checkpoint is correct. + + :param dirname """ success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME) with open(success_file, 'a'): @@ -577,7 +582,9 @@ def _write_success(dirname): def _get_lastest_checkpoint_dir(checkpoint_dir): """ - get the biggest number in checkpoint_dir, which has _SUCCESS + get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory + + :param checkpoint_dir """ if not checkpoint_dir.strip(): return -1