diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 5abadc73f76b51cc841ffe235c9b99e0d502ea07..8fcc7787091b4d0ec3b6566be7fd826f3f95d7db 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -483,11 +483,11 @@ def save_checkpoint(executor, :param main_program :param max_num_checkpoints """ - if checkpoint_dir is None: - raise ValueError("The values of 'checkpoint_dir' should not be None") + if checkpoint_dir.strip() is None: + raise ValueError("'checkpoint_dir' should not be None") - if trainer_args and not isinstance(trainer_args, dict): - raise TypeError("The type of 'trainer_args' should be dict") + if trainer_args: + assert isinstance(trainer_args, dict) if not os.path.isdir(checkpoint_dir): os.makedirs(checkpoint_dir) @@ -514,11 +514,11 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program): :param main_program """ - if checkpoint_dir is None: - raise ValueError("The values of 'checkpoint_dir' should not be None") + if checkpoint_dir.strip() is None: + raise ValueError("'checkpoint_dir' should not be None") if serial is None or serial < 0: - raise ValueError("The values of 'serial' should not be None or <0 ") + raise ValueError("'serial' should not be None or <0 ") if main_program is None: raise ValueError('main_program should not be None.') @@ -536,8 +536,8 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False): :param delete_dir """ - if checkpoint_dir is None: - raise ValueError("The values of 'checkpoint_dir' should not be None") + if checkpoint_dir.strip() is None: + raise ValueError("'checkpoint_dir' should not be None") _lru_delete(checkpoint_dir, max_num_checkpoints=0) if delete_dir and not os.listdir(checkpoint_dir): @@ -590,8 +590,8 @@ def save_persist_vars_without_grad(executor, dirname, program): def save_trainer_args(dirname, trainer_id, trainer_args): - if not isinstance(trainer_args, dict): - raise TypeError("The type of 'trainer_args' should be dict") + assert isinstance(trainer_args, dict) + cur_dir = _get_trainer_dir(dirname, trainer_id) for name, value in trainer_args.iteritems(): @@ -602,12 +602,11 @@ def save_trainer_args(dirname, trainer_id, trainer_args): def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args): + assert isinstance(trainer_args, list) + cur_dir = _get_serial_dir(checkpoint_dir, serial) cur_dir = _get_trainer_dir(cur_dir, trainer_id) - if not isinstance(trainer_args, list): - raise TypeError("The type of 'trainer_args' should be list") - ret_values = [] for arg in trainer_args: