未验证 提交 1def9e05 编写于 作者: W WeiXin 提交者: GitHub

TestSaveLoadLargeParameters use cpu place. (#33756)

* TestSaveLoadLargeParameters use cpu place.

* edit unittest
上级 68c1fe8c
...@@ -95,6 +95,7 @@ class TestSaveLoadLargeParameters(unittest.TestCase): ...@@ -95,6 +95,7 @@ class TestSaveLoadLargeParameters(unittest.TestCase):
def test_large_parameters_paddle_save(self): def test_large_parameters_paddle_save(self):
# enable dygraph mode # enable dygraph mode
paddle.disable_static() paddle.disable_static()
paddle.set_device("cpu")
# create network # create network
layer = LayerWithLargeParameters() layer = LayerWithLargeParameters()
save_dict = layer.state_dict() save_dict = layer.state_dict()
...@@ -103,11 +104,10 @@ class TestSaveLoadLargeParameters(unittest.TestCase): ...@@ -103,11 +104,10 @@ class TestSaveLoadLargeParameters(unittest.TestCase):
"layer.pdparams") "layer.pdparams")
protocol = 4 protocol = 4
paddle.save(save_dict, path, protocol=protocol) paddle.save(save_dict, path, protocol=protocol)
dict_load = paddle.load(path) dict_load = paddle.load(path, return_numpy=True)
# compare results before and after saving # compare results before and after saving
for key, value in save_dict.items(): for key, value in save_dict.items():
self.assertTrue( self.assertTrue(np.array_equal(dict_load[key], value.numpy()))
np.array_equal(dict_load[key].numpy(), value.numpy()))
class TestSaveLoadPickle(unittest.TestCase): class TestSaveLoadPickle(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册