未验证 提交 f9130835 编写于 作者: Z zlsh80826 提交者: GitHub

Fix random seed for several unit tests (#44135)

* Fix test_functional_conv2d_transpose random seed

* Fix random seed and use np.testing

* Fix random seed for test_lu_unpack_op

* Fix test_autograd_functional_dynamic random seed
上级 1f7f7193
...@@ -676,4 +676,5 @@ class TestHessianBatchFirst(unittest.TestCase): ...@@ -676,4 +676,5 @@ class TestHessianBatchFirst(unittest.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
np.random.seed(2022)
unittest.main() unittest.main()
...@@ -39,6 +39,7 @@ class TestFunctionalConv2D(TestCase): ...@@ -39,6 +39,7 @@ class TestFunctionalConv2D(TestCase):
self.groups = 1 self.groups = 1
self.no_bias = False self.no_bias = False
self.data_format = "NHWC" self.data_format = "NHWC"
np.random.seed(2022)
def prepare(self): def prepare(self):
if isinstance(self.filter_shape, int): if isinstance(self.filter_shape, int):
...@@ -188,6 +189,7 @@ class TestFunctionalConv2DError(TestCase): ...@@ -188,6 +189,7 @@ class TestFunctionalConv2DError(TestCase):
self.groups = 1 self.groups = 1
self.no_bias = False self.no_bias = False
self.data_format = "NHWC" self.data_format = "NHWC"
np.random.seed(2022)
def test_exception(self): def test_exception(self):
self.prepare() self.prepare()
......
...@@ -190,6 +190,9 @@ class TestLU_UnpackOp3(TestLU_UnpackOp): ...@@ -190,6 +190,9 @@ class TestLU_UnpackOp3(TestLU_UnpackOp):
class TestLU_UnpackAPI(unittest.TestCase): class TestLU_UnpackAPI(unittest.TestCase):
def setUp(self):
np.random.seed(2022)
def test_dygraph(self): def test_dygraph(self):
def run_lu_unpack_dygraph(shape, dtype): def run_lu_unpack_dygraph(shape, dtype):
......
...@@ -30,6 +30,9 @@ paddle.enable_static() ...@@ -30,6 +30,9 @@ paddle.enable_static()
class TestVariable(unittest.TestCase): class TestVariable(unittest.TestCase):
def setUp(self):
np.random.seed(2022)
def test_np_dtype_convert(self): def test_np_dtype_convert(self):
DT = core.VarDesc.VarType DT = core.VarDesc.VarType
convert = convert_np_dtype_to_dtype_ convert = convert_np_dtype_to_dtype_
...@@ -486,6 +489,9 @@ class TestVariable(unittest.TestCase): ...@@ -486,6 +489,9 @@ class TestVariable(unittest.TestCase):
class TestVariableSlice(unittest.TestCase): class TestVariableSlice(unittest.TestCase):
def setUp(self):
np.random.seed(2022)
def _test_item_none(self, place): def _test_item_none(self, place):
data = np.random.rand(2, 3, 4).astype("float32") data = np.random.rand(2, 3, 4).astype("float32")
prog = paddle.static.Program() prog = paddle.static.Program()
...@@ -545,6 +551,9 @@ class TestVariableSlice(unittest.TestCase): ...@@ -545,6 +551,9 @@ class TestVariableSlice(unittest.TestCase):
class TestListIndex(unittest.TestCase): class TestListIndex(unittest.TestCase):
def setUp(self):
np.random.seed(2022)
def numel(self, shape): def numel(self, shape):
return reduce(lambda x, y: x * y, shape) return reduce(lambda x, y: x * y, shape)
...@@ -723,10 +732,10 @@ class TestListIndex(unittest.TestCase): ...@@ -723,10 +732,10 @@ class TestListIndex(unittest.TestCase):
return return
getitem_pp = exe.run(prog, feed={x.name: array}, fetch_list=fetch_list) getitem_pp = exe.run(prog, feed={x.name: array}, fetch_list=fetch_list)
print(getitem_pp) np.testing.assert_allclose(value_np,
self.assertTrue(np.array_equal(value_np, getitem_pp[0]), getitem_pp[0],
msg='\n numpy:{},\n paddle:{}'.format( rtol=1e-5,
value_np, getitem_pp[0])) atol=1e-8)
def test_static_graph_getitem_bool_index(self): def test_static_graph_getitem_bool_index(self):
paddle.enable_static() paddle.enable_static()
...@@ -791,9 +800,7 @@ class TestListIndex(unittest.TestCase): ...@@ -791,9 +800,7 @@ class TestListIndex(unittest.TestCase):
}, },
fetch_list=fetch_list) fetch_list=fetch_list)
self.assertTrue(np.allclose(array2, setitem_pp[0]), np.testing.assert_allclose(array2, setitem_pp[0], rtol=1e-5, atol=1e-8)
msg='\n numpy:{},\n paddle:{}'.format(
array2, setitem_pp[0]))
def test_static_graph_setitem_list_index(self): def test_static_graph_setitem_list_index(self):
paddle.enable_static() paddle.enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册