From 5f5d6a9dc7eaf2e1c5b069454497d11a28701ddb Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 29 May 2018 16:01:26 +0800 Subject: [PATCH] optimized checkpoint and save_model --- python/paddle/fluid/io.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index aa039bdfa..bd3c2e3d9 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -489,9 +489,9 @@ CHECKPOINT_SEPARATOR = "_" def save_checkpoint(executor, - checkpoint_dir=None, - max_num_checkpoints=3, - main_program=None): + checkpoint_dir, + main_program=None, + max_num_checkpoints=3): """ Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory, the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy @@ -500,12 +500,11 @@ def save_checkpoint(executor, :param executor :param checkpoint_dir - :param max_num_checkpoints - :param save_interval_secs :param main_program + :param max_num_checkpoints """ if checkpoint_dir is None: - checkpoint_dir = os.getcwd() + raise ValueError("The values of 'checkpoint_dir' should not be None") if not os.path.isdir(checkpoint_dir): os.makedirs(checkpoint_dir) @@ -518,7 +517,7 @@ def save_checkpoint(executor, _lru_delete(checkpoint_dir, max_num_checkpoints) -def load_checkpoint(executor, checkpoint_dir=None, main_program=None): +def load_checkpoint(executor, checkpoint_dir, main_program=None): """ Load checkpoint from a directory by executor, it will find the most recent saved checkpoint file and load it auto. @@ -529,7 +528,7 @@ def load_checkpoint(executor, checkpoint_dir=None, main_program=None): """ if checkpoint_dir is None: - checkpoint_dir = os.getcwd() + raise ValueError("The values of 'checkpoint_dir' should not be None") serial = _get_lastest_checkpoint_dir(checkpoint_dir) @@ -546,7 +545,7 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False): delete_dir only works when the directory is empty, otherwise, OSError is raised. """ if checkpoint_dir is None: - checkpoint_dir = os.getcwd() + raise ValueError("The values of 'checkpoint_dir' should not be None") _lru_delete(checkpoint_dir, max_num_checkpoints=0) if delete_dir and not os.listdir(checkpoint_dir): -- GitLab