提交 f28f41db 编写于 作者: T tangwei12

update io.py annotations and codes

上级 6db240d7
...@@ -483,11 +483,11 @@ def save_checkpoint(executor, ...@@ -483,11 +483,11 @@ def save_checkpoint(executor,
:param main_program :param main_program
:param max_num_checkpoints :param max_num_checkpoints
""" """
if checkpoint_dir is None: if checkpoint_dir.strip() is None:
raise ValueError("The values of 'checkpoint_dir' should not be None") raise ValueError("'checkpoint_dir' should not be None")
if trainer_args and not isinstance(trainer_args, dict): if trainer_args:
raise TypeError("The type of 'trainer_args' should be dict") assert isinstance(trainer_args, dict)
if not os.path.isdir(checkpoint_dir): if not os.path.isdir(checkpoint_dir):
os.makedirs(checkpoint_dir) os.makedirs(checkpoint_dir)
...@@ -514,11 +514,11 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program): ...@@ -514,11 +514,11 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
:param main_program :param main_program
""" """
if checkpoint_dir is None: if checkpoint_dir.strip() is None:
raise ValueError("The values of 'checkpoint_dir' should not be None") raise ValueError("'checkpoint_dir' should not be None")
if serial is None or serial < 0: if serial is None or serial < 0:
raise ValueError("The values of 'serial' should not be None or <0 ") raise ValueError("'serial' should not be None or <0 ")
if main_program is None: if main_program is None:
raise ValueError('main_program should not be None.') raise ValueError('main_program should not be None.')
...@@ -536,8 +536,8 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False): ...@@ -536,8 +536,8 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
:param delete_dir :param delete_dir
""" """
if checkpoint_dir is None: if checkpoint_dir.strip() is None:
raise ValueError("The values of 'checkpoint_dir' should not be None") raise ValueError("'checkpoint_dir' should not be None")
_lru_delete(checkpoint_dir, max_num_checkpoints=0) _lru_delete(checkpoint_dir, max_num_checkpoints=0)
if delete_dir and not os.listdir(checkpoint_dir): if delete_dir and not os.listdir(checkpoint_dir):
...@@ -590,8 +590,8 @@ def save_persist_vars_without_grad(executor, dirname, program): ...@@ -590,8 +590,8 @@ def save_persist_vars_without_grad(executor, dirname, program):
def save_trainer_args(dirname, trainer_id, trainer_args): def save_trainer_args(dirname, trainer_id, trainer_args):
if not isinstance(trainer_args, dict): assert isinstance(trainer_args, dict)
raise TypeError("The type of 'trainer_args' should be dict")
cur_dir = _get_trainer_dir(dirname, trainer_id) cur_dir = _get_trainer_dir(dirname, trainer_id)
for name, value in trainer_args.iteritems(): for name, value in trainer_args.iteritems():
...@@ -602,12 +602,11 @@ def save_trainer_args(dirname, trainer_id, trainer_args): ...@@ -602,12 +602,11 @@ def save_trainer_args(dirname, trainer_id, trainer_args):
def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args): def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
assert isinstance(trainer_args, list)
cur_dir = _get_serial_dir(checkpoint_dir, serial) cur_dir = _get_serial_dir(checkpoint_dir, serial)
cur_dir = _get_trainer_dir(cur_dir, trainer_id) cur_dir = _get_trainer_dir(cur_dir, trainer_id)
if not isinstance(trainer_args, list):
raise TypeError("The type of 'trainer_args' should be list")
ret_values = [] ret_values = []
for arg in trainer_args: for arg in trainer_args:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册