diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 54506e97ed5c9a23f5a1e9624391f466c1c498d6..502386016cfad731ce21b65be9d0975fe9b7d72e 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -454,3 +454,90 @@ def get_parameter_value_by_name(name, executor, program=None): program = default_main_program() var = program.global_block().var(name) return get_parameter_value(var, executor) + + +SUCCESS = "_SUCCESS" + + +def save_checkpoint(executor, + dirname, + keep_max=10, + save_secs=600, + main_program=None): + """ + Save Variables to Checkpint Dir + + :param dirname + :param keep_max + :param save_secs + """ + if dirname is None: + raise Exception("save checkpoint dir can not be none") + + if not os.path.isdir(dirname): + os.makedirs(dirname) + serial = _get_lastest_checkpoint_dir(dirname) + 1 + + cur_dir = os.path.join(dirname, serial) + save_persistables(executor, cur_dir, main_program) + _write_success(cur_dir) + + +def restore_checkpoint(dirname, executor, main_program=None): + """ + Load Variables from Checkpint Dir + + :param dir + """ + if dirname is None and os.path.isdir(dirname): + raise Exception("restore checkpoint can not load variables from %s" % + dirname) + serial = _get_lastest_checkpoint_dir(dirname) + 1 + + if serial < -1: + return + cur_dir = os.path.join(dirname, serial) + load_persistables(executor, cur_dir, main_program) + + +def _write_success(dirname): + """ + """ + success_file = os.path.join(dirname, SUCCESS) + with open(success_file, 'a'): + pass + + +def _get_lastest_checkpoint_dir(checkpoint_dir): + """ + get the biggest number in checkpoint_dir, which has _SUCCESS + """ + if not checkpoint_dir.strip(): + return "" + + def has_success(checkpoint_dir, cur_dir): + """ + is _SUCCESS in this dir + """ + if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)): + return -1 + + try: + int(cur_dir) + except ValueError: + return -1 + + success_path = os.path.join(checkpoint_dir, cur_dir, SUCCESS) + if os.path.isfile(success_path): + return int(cur_dir) + + if not os.path.isdir(checkpoint_dir): + return -1 + + current_dir = -1 + dirs = os.listdir(checkpoint_dir) + for cur_dir in dirs: + success_num = has_success(checkpoint_dir, cur_dir) + if success_num > current_dir: + current_dir = success_num + return current_dir