未验证 提交 1def9e05 编写于 作者: W WeiXin 提交者: GitHub

TestSaveLoadLargeParameters use cpu place. (#33756)

* TestSaveLoadLargeParameters use cpu place.

* edit unittest
上级 68c1fe8c
......@@ -95,6 +95,7 @@ class TestSaveLoadLargeParameters(unittest.TestCase):
def test_large_parameters_paddle_save(self):
# enable dygraph mode
paddle.disable_static()
paddle.set_device("cpu")
# create network
layer = LayerWithLargeParameters()
save_dict = layer.state_dict()
......@@ -103,11 +104,10 @@ class TestSaveLoadLargeParameters(unittest.TestCase):
"layer.pdparams")
protocol = 4
paddle.save(save_dict, path, protocol=protocol)
dict_load = paddle.load(path)
dict_load = paddle.load(path, return_numpy=True)
# compare results before and after saving
for key, value in save_dict.items():
self.assertTrue(
np.array_equal(dict_load[key].numpy(), value.numpy()))
self.assertTrue(np.array_equal(dict_load[key], value.numpy()))
class TestSaveLoadPickle(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册