提交 8e01f3b9 编写于 作者: T tangwei12

bug fix

上级 620999c9
......@@ -13,6 +13,7 @@
# limitations under the License.
import os
import errno
import time
import shutil
......@@ -881,11 +882,9 @@ def save_checkpoint(executor,
if trainer_args:
assert isinstance(trainer_args, dict)
if not os.path.isdir(checkpoint_dir):
os.makedirs(checkpoint_dir)
is_chief = trainer_id == 0
_make_chekcpoint_dirs(checkpoint_dir)
serial = get_latest_checkpoint_serial(checkpoint_dir) + 1
cur_dir = _get_serial_dir(checkpoint_dir, serial)
......@@ -1251,6 +1250,20 @@ def _is_checkpoint_var(var):
return var.persistable
def _make_chekcpoint_dirs(dirs):
assert dirs is not None
if os.path.isfile(dirs):
raise OSError(errno.ENOTDIR, "dirs path shoule be a Directory.", dirs)
if not os.path.isdir(dirs):
try:
os.makedirs(dirs)
except OSError as err:
if err.errno != errno.EEXIST:
raise err
def _get_dir_serial(dirname):
_, serial = dirname.split(CHECKPOINT_SEPARATOR)
......@@ -1264,38 +1277,27 @@ def _get_dir_serial(dirname):
def _get_serial_dir(dirname, serial):
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
serial_dir = os.path.join(dirname, serial_folder)
if not os.path.isdir(serial_dir):
os.makedirs(serial_dir)
_make_chekcpoint_dirs(serial_dir)
return serial_dir
def _get_model_dir(dirname):
model_dir = os.path.join(dirname, MODEL_DIR)
if not os.path.isdir(model_dir):
os.makedirs(model_dir)
_make_chekcpoint_dirs(model_dir)
return model_dir
def _get_lookuptable_dir(dirname):
lookuptable_dir = os.path.join(dirname, LOOKUP_TABLE_DIR)
if not os.path.isdir(lookuptable_dir):
os.makedirs(lookuptable_dir)
_make_chekcpoint_dirs(lookuptable_dir)
return lookuptable_dir
def _get_trainer_dir(dirname, trainer_id):
trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id)
trainer_dir = os.path.join(dirname, trainer_folder)
if not os.path.isdir(trainer_dir):
os.makedirs(trainer_dir)
_make_chekcpoint_dirs(trainer_dir)
return trainer_dir
......@@ -1314,7 +1316,11 @@ def _scroll_delete(dirname, max_num_checkpoints=3):
serials = serials[max_num_checkpoints:]
for serial in serials:
cur_dir = _get_serial_dir(dirname, serial)
shutil.rmtree(cur_dir)
try:
shutil.rmtree(cur_dir)
except OSError as err:
if err.errno != errno.ENOENT:
raise err
def _write_success(dirname):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册