未验证 提交 c7848aca 编写于 作者: W WeiXin 提交者: GitHub

[Cherry-Pick]fix test_paddle_save_load and test_paddle_save_load_binary (#32949) (#33008)

    test_paddle_save_load 单测随机挂:使用np.ndarray生成随机数组,可能生成nan,造成做对比时结果不匹配(nan != nan)。改为np.random.randn生成随机数组。

    test_paddle_save_load_binary随机挂: 如果一个字符串不能解析为Program,windows上会有超时风险。解决方法:不在windows平台不加载'不能解析为Program的字符串'。

原始PR:#32949
上级 bdce8a1d
...@@ -412,11 +412,10 @@ class TestSaveLoadAny(unittest.TestCase): ...@@ -412,11 +412,10 @@ class TestSaveLoadAny(unittest.TestCase):
] ]
obj2 = {'k1': obj1, 'k2': state_dict, 'epoch': 123} obj2 = {'k1': obj1, 'k2': state_dict, 'epoch': 123}
obj3 = (paddle.randn( obj3 = (paddle.randn(
[5, 4], dtype='float32'), np.ndarray( [5, 4], dtype='float32'), np.random.randn(3, 4).astype("float32"), {
[3, 4], dtype="float32"), { "state_dict": state_dict,
"state_dict": state_dict, "opt": state_dict
"opt": state_dict })
})
obj4 = (np.random.randn(5, 6), (123, )) obj4 = (np.random.randn(5, 6), (123, ))
path1 = "test_save_load_any_complex_object_dygraph/obj1" path1 = "test_save_load_any_complex_object_dygraph/obj1"
......
...@@ -19,6 +19,7 @@ import numpy as np ...@@ -19,6 +19,7 @@ import numpy as np
import os import os
import sys import sys
import six import six
import platform
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
...@@ -162,12 +163,13 @@ class TestSaveLoadBinaryFormat(unittest.TestCase): ...@@ -162,12 +163,13 @@ class TestSaveLoadBinaryFormat(unittest.TestCase):
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
path = 'test_save_load_error/temp' path = 'test_save_load_error/temp'
paddle.save({}, path, use_binary_format=True) paddle.save({}, path, use_binary_format=True)
# On the Windows platform, when parsing a string that can't be parsed as a `Program`, `desc_.ParseFromString` has a timeout risk.
with self.assertRaises(ValueError): if 'Windows' != platform.system():
path = 'test_save_load_error/temp' with self.assertRaises(ValueError):
with open(path, "w") as f: path = 'test_save_load_error/temp'
f.write('\0') with open(path, "w") as f:
paddle.load(path) f.write('\0')
paddle.load(path)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
temp_lod = fluid.core.LoDTensor() temp_lod = fluid.core.LoDTensor()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册