提交 01975ec1 编写于 作者: T tangwei12

add checkpoint in io

上级 5451c78d
...@@ -13,21 +13,17 @@ ...@@ -13,21 +13,17 @@
# limitations under the License. # limitations under the License.
import os import os
import time
import shutil
from paddle.fluid.evaluator import Evaluator from paddle.fluid.evaluator import Evaluator
from paddle.fluid.framework import Program, Parameter, default_main_program, Variable from paddle.fluid.framework import Program, Parameter, default_main_program, Variable
from . import core from . import core
__all__ = [ __all__ = [
'save_vars', 'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'save_params', 'load_persistables', 'save_inference_model', 'load_inference_model',
'save_persistables', 'get_inference_program', 'save_checkpoint', 'restore_checkpoint'
'load_vars',
'load_params',
'load_persistables',
'save_inference_model',
'load_inference_model',
'get_inference_program',
] ]
...@@ -195,6 +191,8 @@ def load_vars(executor, ...@@ -195,6 +191,8 @@ def load_vars(executor,
load_var_map = {} load_var_map = {}
for each_var in vars: for each_var in vars:
assert isinstance(each_var, Variable) assert isinstance(each_var, Variable)
if each_var.type == core.VarDesc.VarType.RAW:
continue
new_var = _clone_var_in_block_(load_block, each_var) new_var = _clone_var_in_block_(load_block, each_var)
if filename is None: if filename is None:
load_block.append_op( load_block.append_op(
...@@ -457,11 +455,12 @@ def get_parameter_value_by_name(name, executor, program=None): ...@@ -457,11 +455,12 @@ def get_parameter_value_by_name(name, executor, program=None):
SUCCESS = "_SUCCESS" SUCCESS = "_SUCCESS"
BEGIN_SECS = time.time()
def save_checkpoint(executor, def save_checkpoint(executor,
dirname, dirname,
keep_max=10, keep_max=3,
save_secs=600, save_secs=600,
main_program=None): main_program=None):
""" """
...@@ -470,38 +469,70 @@ def save_checkpoint(executor, ...@@ -470,38 +469,70 @@ def save_checkpoint(executor,
:param dirname :param dirname
:param keep_max :param keep_max
:param save_secs :param save_secs
:param main_program
""" """
if dirname is None: if dirname is None:
raise Exception("save checkpoint dir can not be none") raise Exception("save checkpoint dir can not be none")
if not os.path.isdir(dirname): if not os.path.isdir(dirname):
os.makedirs(dirname) os.makedirs(dirname)
serial = _get_lastest_checkpoint_dir(dirname) + 1
cur_dir = os.path.join(dirname, serial) global BEGIN_SECS
if time.time() - BEGIN_SECS < save_secs:
return
BEGIN_SECS = time.time()
serial = _get_lastest_checkpoint_dir(dirname) + 1
cur_dir = os.path.join(dirname, str(serial))
save_persistables(executor, cur_dir, main_program) save_persistables(executor, cur_dir, main_program)
_write_success(cur_dir) _write_success(cur_dir)
_lru_delete(dirname, keep_max)
def restore_checkpoint(dirname, executor, main_program=None):
    """
    Load persistable variables from the latest checkpoint under `dirname`.

    The latest checkpoint is the serial returned by
    `_get_lastest_checkpoint_dir(dirname)`; its sub-directory
    `dirname/<serial>` is handed to `load_persistables`. If that serial is
    negative (no usable checkpoint was found) the function returns without
    loading anything.

    :param dirname: directory that contains the serial-numbered checkpoint dirs
    :param executor: executor forwarded to `load_persistables`
    :param main_program: program forwarded to `load_persistables`
    :raises Exception: if `dirname` is None or is not an existing directory
    """
    # Both conditions are failure cases, so they must be joined with `or not`:
    # a missing argument and a path that is not a directory are each rejected
    # before any checkpoint lookup is attempted.
    if dirname is None or not os.path.isdir(dirname):
        raise Exception("restore checkpoint can not load variables from %s" %
                        dirname)

    serial = _get_lastest_checkpoint_dir(dirname)

    # A negative serial means there is nothing to restore.
    if serial < 0:
        return

    cur_dir = os.path.join(dirname, str(serial))
    load_persistables(executor, cur_dir, main_program)
def _lru_delete(dirname, keep_max=3):
"""
retain checkpoint nums with keep_max
"""
dirs = os.listdir(dirname)
serials = []
for serial in dirs:
try:
serials.append(int(serial))
except ValueError:
continue
if len(serials) <= keep_max:
return
serials.sort(reverse=True)
serials = serials[keep_max:]
for serial in serials:
cur_dir = os.path.join(dirname, str(serial))
shutil.rmtree(cur_dir)
def _write_success(dirname): def _write_success(dirname):
""" """
write _SUCCESS to checkpoint dir
""" """
success_file = os.path.join(dirname, SUCCESS) success_file = os.path.join(dirname, SUCCESS)
with open(success_file, 'a'): with open(success_file, 'a'):
...@@ -513,7 +544,7 @@ def _get_lastest_checkpoint_dir(checkpoint_dir): ...@@ -513,7 +544,7 @@ def _get_lastest_checkpoint_dir(checkpoint_dir):
get the biggest number in checkpoint_dir, which has _SUCCESS get the biggest number in checkpoint_dir, which has _SUCCESS
""" """
if not checkpoint_dir.strip(): if not checkpoint_dir.strip():
return "" return -1
def has_success(checkpoint_dir, cur_dir): def has_success(checkpoint_dir, cur_dir):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册