diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 845e8c9ca2765e8cb62c9d9289da1764f5a0fcfb..239736aad08540c7aa18059da038f6078e97e1db 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -455,10 +455,12 @@ def get_parameter_value_by_name(name, executor, program=None): SUCCESS_MARK_FILENAME = "_SUCCESS" +CHECKPOINT_PREFIX = "checkpoint" +CHECKPOINT_SEPARATOR = "_" def save_checkpoint(executor, - dirname=None, + checkpoint_dir=None, max_num_checkpoints=3, save_interval_secs=600, main_program=None): @@ -466,26 +468,27 @@ def save_checkpoint(executor, Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory, the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most, - The interval time between two save_checkpoint must great than or equal to save_interval_secs. + The interval between two saved checkpoints must greater than save_interval_secs. - :param dirname + :param executor + :param checkpoint_dir :param max_num_checkpoints - :param save_secs + :param save_interval_secs :param main_program """ - if dirname is None: - dirname = os.getcwd() + if checkpoint_dir is None: + checkpoint_dir = os.getcwd() - if not os.path.isdir(dirname): - os.makedirs(dirname) + if not os.path.isdir(checkpoint_dir): + os.makedirs(checkpoint_dir) - serial = _get_lastest_checkpoint_dir(dirname) + serial = _get_lastest_checkpoint_dir(checkpoint_dir) if serial >= 0 and not _interval_secs_exceed( - os.path.join(dirname, str(serial)), save_interval_secs): + _get_serial_dir(serial, checkpoint_dir), save_interval_secs): return - serial = serial + 1 - cur_dir = os.path.join(dirname, str(serial)) + serial += 1 + cur_dir = _get_serial_dir(serial, checkpoint_dir) save_vars( executor, @@ -495,27 +498,28 @@ def save_checkpoint(executor, predicate=_is_checkpoint_var, filename=None) _write_success(cur_dir) - _lru_delete(dirname, max_num_checkpoints) + _lru_delete(checkpoint_dir, max_num_checkpoints) -def load_checkpoint(executor, dirname=None, main_program=None): +def load_checkpoint(executor, checkpoint_dir=None, main_program=None): """ Load checkpoint from a directory by executor, - it will find latest checkpoint file and load it auto. + it will find the most recent saved checkpoint file and load it auto. :param executor - :param dirname + :param checkpoint_dir :param main_program """ - if dirname is None: - dirname = os.getcwd() + if checkpoint_dir is None: + checkpoint_dir = os.getcwd() - serial = _get_lastest_checkpoint_dir(dirname) + serial = _get_lastest_checkpoint_dir(checkpoint_dir) if serial < 0: return - cur_dir = os.path.join(dirname, str(serial)) + + cur_dir = _get_serial_dir(serial, checkpoint_dir) load_vars( executor, @@ -525,6 +529,11 @@ def load_checkpoint(executor, dirname=None, main_program=None): filename=None) +def _get_serial_dir(serial, checkpoint_dir): + serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial) + return os.path.join(checkpoint_dir, serial_folder) + + def _is_checkpoint_var(var): """ the checkpoint will not save or load all the variables. @@ -577,7 +586,8 @@ def _write_success(dirname): """ success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME) with open(success_file, 'a'): - pass + now = time.ctime() + success_file.write(now) def _get_lastest_checkpoint_dir(checkpoint_dir): @@ -593,18 +603,20 @@ def _get_lastest_checkpoint_dir(checkpoint_dir): """ is _SUCCESS in this dir """ - if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)): - return -1 + _, serial = cur_dir.split(CHECKPOINT_SEPARATOR) try: - int(cur_dir) + int(serial) except ValueError: return -1 - success_path = os.path.join(checkpoint_dir, cur_dir, - SUCCESS_MARK_FILENAME) + if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)): + return -1 + + success_path = os.path.join( + _get_serial_dir(serial, checkpoint_dir), SUCCESS_MARK_FILENAME) if os.path.isfile(success_path): - return int(cur_dir) + return int(serial) if not os.path.isdir(checkpoint_dir): return -1