From 1def9e05656496c15f24dd134c8f669d23923a8e Mon Sep 17 00:00:00 2001 From: WeiXin Date: Thu, 24 Jun 2021 14:46:18 +0800 Subject: [PATCH] TestSaveLoadLargeParameters use cpu place. (#33756) * TestSaveLoadLargeParameters use cpu place. * edit unittest --- .../paddle/fluid/tests/unittests/test_paddle_save_load.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py index 77aa4ae36b..727ac36898 100644 --- a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py @@ -95,6 +95,7 @@ class TestSaveLoadLargeParameters(unittest.TestCase): def test_large_parameters_paddle_save(self): # enable dygraph mode paddle.disable_static() + paddle.set_device("cpu") # create network layer = LayerWithLargeParameters() save_dict = layer.state_dict() @@ -103,11 +104,10 @@ class TestSaveLoadLargeParameters(unittest.TestCase): "layer.pdparams") protocol = 4 paddle.save(save_dict, path, protocol=protocol) - dict_load = paddle.load(path) + dict_load = paddle.load(path, return_numpy=True) # compare results before and after saving for key, value in save_dict.items(): - self.assertTrue( - np.array_equal(dict_load[key].numpy(), value.numpy())) + self.assertTrue(np.array_equal(dict_load[key], value.numpy())) class TestSaveLoadPickle(unittest.TestCase): -- GitLab