diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 32f53ebe388b948e7e7a2fc11137ce309c0ad469..8cc25e8623717aa34c4f4c2d4aaf2d880fc66895 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -13,6 +13,7 @@ # limitations under the License. import os +import errno import time import shutil @@ -881,11 +882,9 @@ def save_checkpoint(executor, if trainer_args: assert isinstance(trainer_args, dict) - if not os.path.isdir(checkpoint_dir): - os.makedirs(checkpoint_dir) - is_chief = trainer_id == 0 + _make_chekcpoint_dirs(checkpoint_dir) serial = get_latest_checkpoint_serial(checkpoint_dir) + 1 cur_dir = _get_serial_dir(checkpoint_dir, serial) @@ -1251,6 +1250,20 @@ def _is_checkpoint_var(var): return var.persistable +def _make_chekcpoint_dirs(dirs): + assert dirs is not None + + if os.path.isfile(dirs): + raise OSError(errno.ENOTDIR, "dirs path shoule be a Directory.", dirs) + + if not os.path.isdir(dirs): + try: + os.makedirs(dirs) + except OSError as err: + if err.errno != errno.EEXIST: + raise err + + def _get_dir_serial(dirname): _, serial = dirname.split(CHECKPOINT_SEPARATOR) @@ -1264,38 +1277,27 @@ def _get_dir_serial(dirname): def _get_serial_dir(dirname, serial): serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial) serial_dir = os.path.join(dirname, serial_folder) - - if not os.path.isdir(serial_dir): - os.makedirs(serial_dir) + _make_chekcpoint_dirs(serial_dir) return serial_dir def _get_model_dir(dirname): model_dir = os.path.join(dirname, MODEL_DIR) - - if not os.path.isdir(model_dir): - os.makedirs(model_dir) - + _make_chekcpoint_dirs(model_dir) return model_dir def _get_lookuptable_dir(dirname): lookuptable_dir = os.path.join(dirname, LOOKUP_TABLE_DIR) - - if not os.path.isdir(lookuptable_dir): - os.makedirs(lookuptable_dir) - + _make_chekcpoint_dirs(lookuptable_dir) return lookuptable_dir def _get_trainer_dir(dirname, trainer_id): trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id) trainer_dir = os.path.join(dirname, trainer_folder) - - if not os.path.isdir(trainer_dir): - os.makedirs(trainer_dir) - + _make_chekcpoint_dirs(trainer_dir) return trainer_dir @@ -1314,7 +1316,11 @@ def _scroll_delete(dirname, max_num_checkpoints=3): serials = serials[max_num_checkpoints:] for serial in serials: cur_dir = _get_serial_dir(dirname, serial) - shutil.rmtree(cur_dir) + try: + shutil.rmtree(cur_dir) + except OSError as err: + if err.errno != errno.ENOENT: + raise err def _write_success(dirname):