From f91308356a7ec8a85c7a9946437b021973588a9a Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Fri, 15 Jul 2022 13:38:13 +0800 Subject: [PATCH] Fix random seed for several unit tests (#44135) * Fix test_functional_conv2d_transpose random seed * Fix random seed and use np.testing * Fix random seed for test_lu_unpack_op * Fix test_autograd_functional_dynamic random seed --- .../test_autograd_functional_dynamic.py | 1 + .../test_functional_conv2d_transpose.py | 2 ++ .../tests/unittests/test_lu_unpack_op.py | 3 +++ .../fluid/tests/unittests/test_variable.py | 21 ++++++++++++------- 4 files changed, 20 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_dynamic.py b/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_dynamic.py index 6c67b78d6a5..4b615804525 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_dynamic.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_autograd_functional_dynamic.py @@ -676,4 +676,5 @@ class TestHessianBatchFirst(unittest.TestCase): if __name__ == "__main__": + np.random.seed(2022) unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_functional_conv2d_transpose.py b/python/paddle/fluid/tests/unittests/test_functional_conv2d_transpose.py index d1b9c689257..dce6a37c6bb 100644 --- a/python/paddle/fluid/tests/unittests/test_functional_conv2d_transpose.py +++ b/python/paddle/fluid/tests/unittests/test_functional_conv2d_transpose.py @@ -39,6 +39,7 @@ class TestFunctionalConv2D(TestCase): self.groups = 1 self.no_bias = False self.data_format = "NHWC" + np.random.seed(2022) def prepare(self): if isinstance(self.filter_shape, int): @@ -188,6 +189,7 @@ class TestFunctionalConv2DError(TestCase): self.groups = 1 self.no_bias = False self.data_format = "NHWC" + np.random.seed(2022) def test_exception(self): self.prepare() diff --git a/python/paddle/fluid/tests/unittests/test_lu_unpack_op.py b/python/paddle/fluid/tests/unittests/test_lu_unpack_op.py index 1757adef8e3..97773c70e17 100644 --- a/python/paddle/fluid/tests/unittests/test_lu_unpack_op.py +++ b/python/paddle/fluid/tests/unittests/test_lu_unpack_op.py @@ -190,6 +190,9 @@ class TestLU_UnpackOp3(TestLU_UnpackOp): class TestLU_UnpackAPI(unittest.TestCase): + def setUp(self): + np.random.seed(2022) + def test_dygraph(self): def run_lu_unpack_dygraph(shape, dtype): diff --git a/python/paddle/fluid/tests/unittests/test_variable.py b/python/paddle/fluid/tests/unittests/test_variable.py index 87802b83415..5fb220da609 100644 --- a/python/paddle/fluid/tests/unittests/test_variable.py +++ b/python/paddle/fluid/tests/unittests/test_variable.py @@ -30,6 +30,9 @@ paddle.enable_static() class TestVariable(unittest.TestCase): + def setUp(self): + np.random.seed(2022) + def test_np_dtype_convert(self): DT = core.VarDesc.VarType convert = convert_np_dtype_to_dtype_ @@ -486,6 +489,9 @@ class TestVariable(unittest.TestCase): class TestVariableSlice(unittest.TestCase): + def setUp(self): + np.random.seed(2022) + def _test_item_none(self, place): data = np.random.rand(2, 3, 4).astype("float32") prog = paddle.static.Program() @@ -545,6 +551,9 @@ class TestVariableSlice(unittest.TestCase): class TestListIndex(unittest.TestCase): + def setUp(self): + np.random.seed(2022) + def numel(self, shape): return reduce(lambda x, y: x * y, shape) @@ -723,10 +732,10 @@ class TestListIndex(unittest.TestCase): return getitem_pp = exe.run(prog, feed={x.name: array}, fetch_list=fetch_list) - print(getitem_pp) - self.assertTrue(np.array_equal(value_np, getitem_pp[0]), - msg='\n numpy:{},\n paddle:{}'.format( - value_np, getitem_pp[0])) + np.testing.assert_allclose(value_np, + getitem_pp[0], + rtol=1e-5, + atol=1e-8) def test_static_graph_getitem_bool_index(self): paddle.enable_static() @@ -791,9 +800,7 @@ class TestListIndex(unittest.TestCase): }, fetch_list=fetch_list) - self.assertTrue(np.allclose(array2, setitem_pp[0]), - msg='\n numpy:{},\n paddle:{}'.format( - array2, setitem_pp[0])) + np.testing.assert_allclose(array2, setitem_pp[0], rtol=1e-5, atol=1e-8) def test_static_graph_setitem_list_index(self): paddle.enable_static() -- GitLab