提交 f28f41db 编写于 作者: T tangwei12

update io.py annotations and codes

上级 6db240d7
......@@ -483,11 +483,11 @@ def save_checkpoint(executor,
:param main_program
:param max_num_checkpoints
"""
if checkpoint_dir is None:
raise ValueError("The values of 'checkpoint_dir' should not be None")
if checkpoint_dir.strip() is None:
raise ValueError("'checkpoint_dir' should not be None")
if trainer_args and not isinstance(trainer_args, dict):
raise TypeError("The type of 'trainer_args' should be dict")
if trainer_args:
assert isinstance(trainer_args, dict)
if not os.path.isdir(checkpoint_dir):
os.makedirs(checkpoint_dir)
......@@ -514,11 +514,11 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
:param main_program
"""
if checkpoint_dir is None:
raise ValueError("The values of 'checkpoint_dir' should not be None")
if checkpoint_dir.strip() is None:
raise ValueError("'checkpoint_dir' should not be None")
if serial is None or serial < 0:
raise ValueError("The values of 'serial' should not be None or <0 ")
raise ValueError("'serial' should not be None or <0 ")
if main_program is None:
raise ValueError('main_program should not be None.')
......@@ -536,8 +536,8 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
:param delete_dir
"""
if checkpoint_dir is None:
raise ValueError("The values of 'checkpoint_dir' should not be None")
if checkpoint_dir.strip() is None:
raise ValueError("'checkpoint_dir' should not be None")
_lru_delete(checkpoint_dir, max_num_checkpoints=0)
if delete_dir and not os.listdir(checkpoint_dir):
......@@ -590,8 +590,8 @@ def save_persist_vars_without_grad(executor, dirname, program):
def save_trainer_args(dirname, trainer_id, trainer_args):
if not isinstance(trainer_args, dict):
raise TypeError("The type of 'trainer_args' should be dict")
assert isinstance(trainer_args, dict)
cur_dir = _get_trainer_dir(dirname, trainer_id)
for name, value in trainer_args.iteritems():
......@@ -602,12 +602,11 @@ def save_trainer_args(dirname, trainer_id, trainer_args):
def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
assert isinstance(trainer_args, list)
cur_dir = _get_serial_dir(checkpoint_dir, serial)
cur_dir = _get_trainer_dir(cur_dir, trainer_id)
if not isinstance(trainer_args, list):
raise TypeError("The type of 'trainer_args' should be list")
ret_values = []
for arg in trainer_args:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册