diff --git a/python/paddle/fluid/dygraph/checkpoint.py b/python/paddle/fluid/dygraph/checkpoint.py index 6bcd94b45ce064b6b1ae6b3e213214a6749b6aa9..d6c99a65851062218daab068304fba07640bff98 100644 --- a/python/paddle/fluid/dygraph/checkpoint.py +++ b/python/paddle/fluid/dygraph/checkpoint.py @@ -174,6 +174,9 @@ def _save_var_to_file(stat_dict, optimizers, file_dir, file_name): def _load_var_from_file(file_dir): + if not os.path.exists(file_dir): + raise IOError("{} not exist".format(file_dir)) + def walk_filename(file_dir): base_path = os.path.join(file_dir) var_name_list = [] diff --git a/python/paddle/fluid/tests/unittests/test_imperative_checkpoint.py b/python/paddle/fluid/tests/unittests/test_imperative_checkpoint.py index 25d490f6797f3ae63308eb3e449d371864d9b28f..609662cf9880795b7f1ff57efb1205ac1eda0e72 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_checkpoint.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_checkpoint.py @@ -150,6 +150,10 @@ class TestDygraphCheckpoint(unittest.TestCase): dy_param_init_value[param.name] = param.numpy() restore, _ = fluid.dygraph.load_persistables("save_dir") + + self.assertRaises(IOError, fluid.dygraph.load_persistables, + "not_exist_dir") + mnist.load_dict(restore) self.assertEqual(len(dy_param_init_value), len(restore))