提交 01975ec1 编写于 作者: T tangwei12

add checkpoint in io

上级 5451c78d
......@@ -13,21 +13,17 @@
# limitations under the License.
import os
import time
import shutil
from paddle.fluid.evaluator import Evaluator
from paddle.fluid.framework import Program, Parameter, default_main_program, Variable
from . import core
__all__ = [
'save_vars',
'save_params',
'save_persistables',
'load_vars',
'load_params',
'load_persistables',
'save_inference_model',
'load_inference_model',
'get_inference_program',
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'load_persistables', 'save_inference_model', 'load_inference_model',
'get_inference_program', 'save_checkpoint', 'restore_checkpoint'
]
......@@ -195,6 +191,8 @@ def load_vars(executor,
load_var_map = {}
for each_var in vars:
assert isinstance(each_var, Variable)
if each_var.type == core.VarDesc.VarType.RAW:
continue
new_var = _clone_var_in_block_(load_block, each_var)
if filename is None:
load_block.append_op(
......@@ -457,11 +455,12 @@ def get_parameter_value_by_name(name, executor, program=None):
SUCCESS = "_SUCCESS"
BEGIN_SECS = time.time()
def save_checkpoint(executor,
dirname,
keep_max=10,
keep_max=3,
save_secs=600,
main_program=None):
"""
......@@ -470,38 +469,70 @@ def save_checkpoint(executor,
:param dirname
:param keep_max
:param save_secs
:param main_program
"""
if dirname is None:
raise Exception("save checkpoint dir can not be none")
if not os.path.isdir(dirname):
os.makedirs(dirname)
serial = _get_lastest_checkpoint_dir(dirname) + 1
cur_dir = os.path.join(dirname, serial)
global BEGIN_SECS
if time.time() - BEGIN_SECS < save_secs:
return
BEGIN_SECS = time.time()
serial = _get_lastest_checkpoint_dir(dirname) + 1
cur_dir = os.path.join(dirname, str(serial))
save_persistables(executor, cur_dir, main_program)
_write_success(cur_dir)
_lru_delete(dirname, keep_max)
def restore_checkpoint(dirname, executor, main_program=None):
"""
Load Variables from Checkpint Dir
:param dir
:param dirname
:param executor
:param main_program
"""
if dirname is None and os.path.isdir(dirname):
raise Exception("restore checkpoint can not load variables from %s" %
dirname)
serial = _get_lastest_checkpoint_dir(dirname) + 1
serial = _get_lastest_checkpoint_dir(dirname)
if serial < -1:
if serial < 0:
return
cur_dir = os.path.join(dirname, serial)
cur_dir = os.path.join(dirname, str(serial))
load_persistables(executor, cur_dir, main_program)
def _lru_delete(dirname, keep_max=3):
"""
retain checkpoint nums with keep_max
"""
dirs = os.listdir(dirname)
serials = []
for serial in dirs:
try:
serials.append(int(serial))
except ValueError:
continue
if len(serials) <= keep_max:
return
serials.sort(reverse=True)
serials = serials[keep_max:]
for serial in serials:
cur_dir = os.path.join(dirname, str(serial))
shutil.rmtree(cur_dir)
def _write_success(dirname):
"""
write _SUCCESS to checkpoint dir
"""
success_file = os.path.join(dirname, SUCCESS)
with open(success_file, 'a'):
......@@ -513,7 +544,7 @@ def _get_lastest_checkpoint_dir(checkpoint_dir):
get the biggest number in checkpoint_dir, which has _SUCCESS
"""
if not checkpoint_dir.strip():
return ""
return -1
def has_success(checkpoint_dir, cur_dir):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册