diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 48ac320089aad1f5fa5fe3f327cf28e2c90ad1a1..3e693ed7170530c5ca5cf8820e469146c2eb0c02 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -190,6 +190,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, for (auto &var : sparse_vars) { var->GetMutable()->mutable_rows()->clear(); } + rpc_service_->SetCond(1); // FIXME(typhoonzero): use another condition to sync wait clients get. rpc_service_->WaitClientGet(fan_in); diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 54506e97ed5c9a23f5a1e9624391f466c1c498d6..8e58e5eb794e1bb507ab05394a1f7b57a1d2ed42 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -13,21 +13,18 @@ # limitations under the License. import os +import time +import shutil from paddle.fluid.evaluator import Evaluator from paddle.fluid.framework import Program, Parameter, default_main_program, Variable from . 
# NOTE(review): the patch fragments that preceded this code (the `__all__`
# update and the RAW-type skip inside `load_vars`) belong to definitions that
# are only partially visible here and are therefore not reproduced. Make sure
# `save_checkpoint`, `load_checkpoint` and `clean_checkpoint` stay exported
# through `__all__`.

# Name of the marker file written into a checkpoint directory once all of its
# variables have been saved.
SUCCESS_MARK_FILENAME = "_SUCCESS"
# Checkpoint directories are named "<CHECKPOINT_PREFIX><CHECKPOINT_SEPARATOR><serial>".
CHECKPOINT_PREFIX = "checkpoint"
CHECKPOINT_SEPARATOR = "_"


def save_checkpoint(executor,
                    checkpoint_dir=None,
                    max_num_checkpoints=3,
                    save_interval_secs=600,
                    main_program=None):
    """
    Save the checkpoint variables of a program into a new serial directory.

    Every checkpoint is stored in its own sub-directory of `checkpoint_dir`
    named `checkpoint_<serial>`, where the serial is one greater than the
    newest complete checkpoint found (starting at 0). After the variables are
    written a `_SUCCESS` marker is created, and all but the newest
    `max_num_checkpoints` checkpoint directories are removed.

    Nothing is saved when the newest complete checkpoint directory was
    modified less than `save_interval_secs` seconds ago.

    :param executor: executor forwarded to `save_vars`
    :param checkpoint_dir: root directory of the checkpoints; defaults to the
        current working directory and is created when missing
    :param max_num_checkpoints: number of newest checkpoints to keep
    :param save_interval_secs: minimum number of seconds between two saves
    :param main_program: program forwarded to `save_vars`
    """
    if checkpoint_dir is None:
        checkpoint_dir = os.getcwd()

    if not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    serial = _get_lastest_checkpoint_dir(checkpoint_dir)
    if serial >= 0 and not _interval_secs_exceed(
            _get_serial_dir(serial, checkpoint_dir), save_interval_secs):
        # The previous checkpoint is still recent enough.
        return

    serial += 1
    cur_dir = _get_serial_dir(serial, checkpoint_dir)

    save_vars(
        executor,
        dirname=cur_dir,
        main_program=main_program,
        vars=None,
        predicate=_is_checkpoint_var,
        filename=None)
    # Mark the checkpoint as complete only after every variable was written,
    # so that a partially written directory is never picked up when loading.
    _write_success(cur_dir)
    _lru_delete(checkpoint_dir, max_num_checkpoints)


def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
    """
    Load the newest complete checkpoint found in `checkpoint_dir`.

    A checkpoint is complete when its directory contains the `_SUCCESS`
    marker. When no complete checkpoint exists the function returns without
    loading anything.

    :param executor: executor forwarded to `load_vars`
    :param checkpoint_dir: root directory of the checkpoints; defaults to the
        current working directory
    :param main_program: program forwarded to `load_vars`
    """
    if checkpoint_dir is None:
        checkpoint_dir = os.getcwd()

    serial = _get_lastest_checkpoint_dir(checkpoint_dir)
    if serial < 0:
        return

    cur_dir = _get_serial_dir(serial, checkpoint_dir)

    load_vars(
        executor,
        dirname=cur_dir,
        main_program=main_program,
        predicate=_is_checkpoint_var,
        filename=None)


def clean_checkpoint(checkpoint_dir, delete_dir=False):
    """
    Remove every checkpoint directory stored below `checkpoint_dir`.

    Only directories named `checkpoint_<serial>` are removed; any other entry
    is left untouched.

    :param checkpoint_dir: root directory of the checkpoints; defaults to the
        current working directory when None
    :param delete_dir: also remove `checkpoint_dir` itself, but only when it
        is empty after the checkpoints have been removed
    """
    if checkpoint_dir is None:
        checkpoint_dir = os.getcwd()
    _lru_delete(checkpoint_dir, max_num_checkpoints=0)

    if delete_dir and not os.listdir(checkpoint_dir):
        os.rmdir(checkpoint_dir)


def _get_serial_dir(serial, checkpoint_dir):
    """
    Return the path of the checkpoint directory that belongs to `serial`.

    :param serial: serial number of the checkpoint
    :param checkpoint_dir: root directory of the checkpoints
    """
    serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
    return os.path.join(checkpoint_dir, serial_folder)


def _get_checkpoint_serial(dir_name):
    """
    Return the serial number encoded in a checkpoint directory name.

    Only names that `_get_serial_dir` can produce are accepted, i.e.
    `checkpoint_<n>` with a canonical non-negative integer `n`; for any other
    name -1 is returned.

    :param dir_name: bare entry name (not a path)
    """
    prefix, separator, serial = dir_name.rpartition(CHECKPOINT_SEPARATOR)
    if not separator or prefix != CHECKPOINT_PREFIX:
        return -1

    try:
        value = int(serial)
    except ValueError:
        return -1

    # Reject spellings such as "03", "+3" or "-1": they would not map back to
    # the same directory name through _get_serial_dir.
    if value < 0 or str(value) != serial:
        return -1
    return value


def _is_checkpoint_var(var):
    """
    Tell whether a variable belongs to a checkpoint.

    Variables of type FEED_MINIBATCH, FETCH_LIST or RAW and variables whose
    name ends with "@GRAD" are never saved or loaded; every other variable is
    included exactly when it is persistable.

    :param var: variable to test
    """
    if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
            var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
            var.desc.type() == core.VarDesc.VarType.RAW:
        return False

    if var.name.endswith("@GRAD"):
        return False

    return var.persistable


def _interval_secs_exceed(dirname, save_interval_secs):
    """
    Tell whether `dirname` was last modified at least `save_interval_secs`
    seconds ago.

    :param dirname: path whose modification time is inspected
    :param save_interval_secs: required age in seconds
    """
    elapsed = time.time() - os.path.getmtime(dirname)
    return save_interval_secs <= elapsed


def _lru_delete(dirname, max_num_checkpoints=3):
    """
    Keep the newest `max_num_checkpoints` checkpoint directories below
    `dirname` and delete the older ones.

    "Newest" means highest serial number. Entries that are not checkpoint
    directories are ignored.

    :param dirname: root directory of the checkpoints
    :param max_num_checkpoints: number of checkpoints to keep; values below
        zero are treated as zero
    """
    serials = []
    for name in os.listdir(dirname):
        serial = _get_checkpoint_serial(name)
        if serial >= 0 and os.path.isdir(os.path.join(dirname, name)):
            serials.append(serial)

    keep = max(max_num_checkpoints, 0)
    if len(serials) <= keep:
        return

    serials.sort(reverse=True)
    for serial in serials[keep:]:
        shutil.rmtree(_get_serial_dir(serial, dirname))


def _write_success(dirname):
    """
    Create the "_SUCCESS" marker in a checkpoint directory to indicate that
    the checkpoint is complete. The current time is written into the marker.

    :param dirname: checkpoint directory that receives the marker
    """
    success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME)
    with open(success_file, 'a') as f:
        now = time.ctime()
        f.write(now)


def _get_lastest_checkpoint_dir(checkpoint_dir):
    """
    Return the highest serial among the complete checkpoints below
    `checkpoint_dir`, or -1 when there is none.

    A checkpoint is complete when its directory contains the "_SUCCESS"
    marker. Entries that do not follow the checkpoint naming scheme are
    ignored.

    :param checkpoint_dir: root directory of the checkpoints
    """
    if not checkpoint_dir or not checkpoint_dir.strip():
        return -1

    if not os.path.isdir(checkpoint_dir):
        return -1

    latest = -1
    for name in os.listdir(checkpoint_dir):
        serial = _get_checkpoint_serial(name)
        if serial <= latest:
            continue
        # The marker can only be a file if `name` is a directory, so no
        # separate directory check is required.
        success_path = os.path.join(checkpoint_dir, name,
                                    SUCCESS_MARK_FILENAME)
        if os.path.isfile(success_path):
            latest = serial
    return latest