diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index c638da67c825d4b6d6aec306830509755af144e1..9e0bc425f0e34d37f5002b2ebc7e81247fd3fd51 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -529,6 +529,19 @@ def load_checkpoint(executor, checkpoint_dir=None, main_program=None): filename=None) +def clean_checkpoint(checkpoint_dir, delete_dir=False): + """ + clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before. + delete_dir only works when the directory is empty, otherwise, OSError is raised. + """ + if checkpoint_dir is None: + checkpoint_dir = os.getcwd() + _lru_delete(checkpoint_dir, max_num_checkpoints=0) + + if delete_dir and not os.listdir(checkpoint_dir): + os.rmdir(checkpoint_dir) + + def _get_serial_dir(serial, checkpoint_dir): serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial) return os.path.join(checkpoint_dir, serial_folder)