From 10643b4ea69f9880e06593889e96c60c799feaa1 Mon Sep 17 00:00:00 2001 From: Ghost Under Moon Date: Mon, 26 Aug 2019 12:10:54 +0800 Subject: [PATCH] fix- raise io error when user load from non-existed dir test=develop (#19384) This PR fix problem with issue #18096 , which raise an error for user to specify the error about load dir is wrong --- python/paddle/fluid/dygraph/checkpoint.py | 3 +++ .../fluid/tests/unittests/test_imperative_checkpoint.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/python/paddle/fluid/dygraph/checkpoint.py b/python/paddle/fluid/dygraph/checkpoint.py index 6bcd94b45c..d6c99a6585 100644 --- a/python/paddle/fluid/dygraph/checkpoint.py +++ b/python/paddle/fluid/dygraph/checkpoint.py @@ -174,6 +174,9 @@ def _save_var_to_file(stat_dict, optimizers, file_dir, file_name): def _load_var_from_file(file_dir): + if not os.path.exists(file_dir): + raise IOError("{} not exist".format(file_dir)) + def walk_filename(file_dir): base_path = os.path.join(file_dir) var_name_list = [] diff --git a/python/paddle/fluid/tests/unittests/test_imperative_checkpoint.py b/python/paddle/fluid/tests/unittests/test_imperative_checkpoint.py index 25d490f679..609662cf98 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_checkpoint.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_checkpoint.py @@ -150,6 +150,10 @@ class TestDygraphCheckpoint(unittest.TestCase): dy_param_init_value[param.name] = param.numpy() restore, _ = fluid.dygraph.load_persistables("save_dir") + + self.assertRaises(IOError, fluid.dygraph.load_persistables, + "not_exist_dir") + mnist.load_dict(restore) self.assertEqual(len(dy_param_init_value), len(restore)) -- GitLab