From 01975ec1c749c9576a1124a7f029234caa86e0ed Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 21 May 2018 16:53:59 +0800 Subject: [PATCH] add checkpoint in io --- python/paddle/fluid/io.py | 65 +++++++++++++++++++++++++++++---------- 1 file changed, 48 insertions(+), 17 deletions(-) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 502386016cf..83c32fe9d6e 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -13,21 +13,17 @@ # limitations under the License. import os +import time +import shutil from paddle.fluid.evaluator import Evaluator from paddle.fluid.framework import Program, Parameter, default_main_program, Variable from . import core __all__ = [ - 'save_vars', - 'save_params', - 'save_persistables', - 'load_vars', - 'load_params', - 'load_persistables', - 'save_inference_model', - 'load_inference_model', - 'get_inference_program', + 'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params', + 'load_persistables', 'save_inference_model', 'load_inference_model', + 'get_inference_program', 'save_checkpoint', 'restore_checkpoint' ] @@ -195,6 +191,8 @@ def load_vars(executor, load_var_map = {} for each_var in vars: assert isinstance(each_var, Variable) + if each_var.type == core.VarDesc.VarType.RAW: + continue new_var = _clone_var_in_block_(load_block, each_var) if filename is None: load_block.append_op( @@ -457,11 +455,12 @@ def get_parameter_value_by_name(name, executor, program=None): SUCCESS = "_SUCCESS" +BEGIN_SECS = time.time() def save_checkpoint(executor, dirname, - keep_max=10, + keep_max=3, save_secs=600, main_program=None): """ @@ -470,38 +469,70 @@ def save_checkpoint(executor, :param dirname :param keep_max :param save_secs + :param main_program """ if dirname is None: raise Exception("save checkpoint dir can not be none") if not os.path.isdir(dirname): os.makedirs(dirname) - serial = _get_lastest_checkpoint_dir(dirname) + 1 - cur_dir = os.path.join(dirname, serial) + global BEGIN_SECS + if time.time() - BEGIN_SECS < save_secs: + return + BEGIN_SECS = time.time() + + serial = _get_lastest_checkpoint_dir(dirname) + 1 + cur_dir = os.path.join(dirname, str(serial)) save_persistables(executor, cur_dir, main_program) _write_success(cur_dir) + _lru_delete(dirname, keep_max) def restore_checkpoint(dirname, executor, main_program=None): """ Load Variables from Checkpint Dir - :param dir + :param dirname + :param executor + :param main_program """ if dirname is None and os.path.isdir(dirname): raise Exception("restore checkpoint can not load variables from %s" % dirname) - serial = _get_lastest_checkpoint_dir(dirname) + 1 + serial = _get_lastest_checkpoint_dir(dirname) - if serial < -1: + if serial < 0: return - cur_dir = os.path.join(dirname, serial) + cur_dir = os.path.join(dirname, str(serial)) load_persistables(executor, cur_dir, main_program) +def _lru_delete(dirname, keep_max=3): + """ + retain checkpoint nums with keep_max + """ + dirs = os.listdir(dirname) + serials = [] + for serial in dirs: + try: + serials.append(int(serial)) + except ValueError: + continue + + if len(serials) <= keep_max: + return + + serials.sort(reverse=True) + serials = serials[keep_max:] + for serial in serials: + cur_dir = os.path.join(dirname, str(serial)) + shutil.rmtree(cur_dir) + + def _write_success(dirname): """ + write _SUCCESS to checkpoint dir """ success_file = os.path.join(dirname, SUCCESS) with open(success_file, 'a'): @@ -513,7 +544,7 @@ def _get_lastest_checkpoint_dir(checkpoint_dir): get the biggest number in checkpoint_dir, which has _SUCCESS """ if not checkpoint_dir.strip(): - return "" + return -1 def has_success(checkpoint_dir, cur_dir): """ -- GitLab