未验证 提交 397a69d9 编写于 作者: T tangwei12 提交者: GitHub

Merge pull request #10532 from seiriosPlus/checkpoint

add checkpoint util class and implement
......@@ -190,6 +190,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
for (auto &var : sparse_vars) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
rpc_service_->SetCond(1);
// FIXME(typhoonzero): use another condition to sync wait clients get.
rpc_service_->WaitClientGet(fan_in);
......
......@@ -13,21 +13,18 @@
# limitations under the License.
import os
import time
import shutil
from paddle.fluid.evaluator import Evaluator
from paddle.fluid.framework import Program, Parameter, default_main_program, Variable
from . import core
__all__ = [
'save_vars',
'save_params',
'save_persistables',
'load_vars',
'load_params',
'load_persistables',
'save_inference_model',
'load_inference_model',
'get_inference_program',
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'load_persistables', 'save_inference_model', 'load_inference_model',
'get_inference_program', 'save_checkpoint', 'load_checkpoint',
'clean_checkpoint'
]
......@@ -195,6 +192,8 @@ def load_vars(executor,
load_var_map = {}
for each_var in vars:
assert isinstance(each_var, Variable)
if each_var.type == core.VarDesc.VarType.RAW:
continue
new_var = _clone_var_in_block_(load_block, each_var)
if filename is None:
load_block.append_op(
......@@ -454,3 +453,192 @@ def get_parameter_value_by_name(name, executor, program=None):
program = default_main_program()
var = program.global_block().var(name)
return get_parameter_value(var, executor)
SUCCESS_MARK_FILENAME = "_SUCCESS"
CHECKPOINT_PREFIX = "checkpoint"
CHECKPOINT_SEPARATOR = "_"
def save_checkpoint(executor,
checkpoint_dir=None,
max_num_checkpoints=3,
save_interval_secs=600,
main_program=None):
"""
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most,
The interval between two saved checkpoints must greater than save_interval_secs.
:param executor
:param checkpoint_dir
:param max_num_checkpoints
:param save_interval_secs
:param main_program
"""
if checkpoint_dir is None:
checkpoint_dir = os.getcwd()
if not os.path.isdir(checkpoint_dir):
os.makedirs(checkpoint_dir)
serial = _get_lastest_checkpoint_dir(checkpoint_dir)
if serial >= 0 and not _interval_secs_exceed(
_get_serial_dir(serial, checkpoint_dir), save_interval_secs):
return
serial += 1
cur_dir = _get_serial_dir(serial, checkpoint_dir)
save_vars(
executor,
dirname=cur_dir,
main_program=main_program,
vars=None,
predicate=_is_checkpoint_var,
filename=None)
_write_success(cur_dir)
_lru_delete(checkpoint_dir, max_num_checkpoints)
def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
"""
Load checkpoint from a directory by executor,
it will find the most recent saved checkpoint file and load it auto.
:param executor
:param checkpoint_dir
:param main_program
"""
if checkpoint_dir is None:
checkpoint_dir = os.getcwd()
serial = _get_lastest_checkpoint_dir(checkpoint_dir)
if serial < 0:
return
cur_dir = _get_serial_dir(serial, checkpoint_dir)
load_vars(
executor,
dirname=cur_dir,
main_program=main_program,
predicate=_is_checkpoint_var,
filename=None)
def clean_checkpoint(checkpoint_dir, delete_dir=False):
"""
clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before.
delete_dir only works when the directory is empty, otherwise, OSError is raised.
"""
if checkpoint_dir is None:
checkpoint_dir = os.getcwd()
_lru_delete(checkpoint_dir, max_num_checkpoints=0)
if delete_dir and not os.listdir(checkpoint_dir):
os.rmdir(checkpoint_dir)
def _get_serial_dir(serial, checkpoint_dir):
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
return os.path.join(checkpoint_dir, serial_folder)
def _is_checkpoint_var(var):
"""
the checkpoint will not save or load all the variables.
var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.
:param var
"""
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
var.desc.type() == core.VarDesc.VarType.RAW:
return False
if var.name.endswith("@GRAD"):
return False
return var.persistable
def _interval_secs_exceed(dirname, save_interval_secs):
dir_time = os.path.getmtime(dirname)
if save_interval_secs > (time.time() - dir_time):
return False
return True
def _lru_delete(dirname, max_num_checkpoints=3):
dirs = os.listdir(dirname)
serials = []
for serial in dirs:
try:
serials.append(int(serial))
except ValueError:
continue
if len(serials) <= max_num_checkpoints:
return
serials.sort(reverse=True)
serials = serials[max_num_checkpoints:]
for serial in serials:
cur_dir = os.path.join(dirname, str(serial))
shutil.rmtree(cur_dir)
def _write_success(dirname):
"""
write an empty file named "_SUCCESS" in checkpoint dir, indicate this checkpoint is correct.
:param dirname
"""
success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME)
with open(success_file, 'a') as f:
now = time.ctime()
f.write(now)
def _get_lastest_checkpoint_dir(checkpoint_dir):
"""
get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
:param checkpoint_dir
"""
if not checkpoint_dir.strip():
return -1
def has_success(checkpoint_dir, cur_dir):
"""
is _SUCCESS in this dir
"""
_, serial = cur_dir.split(CHECKPOINT_SEPARATOR)
try:
int(serial)
except ValueError:
return -1
if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
return -1
success_path = os.path.join(
_get_serial_dir(serial, checkpoint_dir), SUCCESS_MARK_FILENAME)
if os.path.isfile(success_path):
return int(serial)
if not os.path.isdir(checkpoint_dir):
return -1
current_dir = -1
dirs = os.listdir(checkpoint_dir)
for cur_dir in dirs:
success_num = has_success(checkpoint_dir, cur_dir)
if success_num > current_dir:
current_dir = success_num
return current_dir
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册