未验证 提交 f9130835 编写于 作者: Z zlsh80826 提交者: GitHub

Fix random seed for several unit tests (#44135)

* Fix test_functional_conv2d_transpose random seed

* Fix random seed and use np.testing

* Fix random seed for test_lu_unpack_op

* Fix test_autograd_functional_dynamic random seed
上级 1f7f7193
......@@ -676,4 +676,5 @@ class TestHessianBatchFirst(unittest.TestCase):
if __name__ == "__main__":
np.random.seed(2022)
unittest.main()
......@@ -39,6 +39,7 @@ class TestFunctionalConv2D(TestCase):
self.groups = 1
self.no_bias = False
self.data_format = "NHWC"
np.random.seed(2022)
def prepare(self):
if isinstance(self.filter_shape, int):
......@@ -188,6 +189,7 @@ class TestFunctionalConv2DError(TestCase):
self.groups = 1
self.no_bias = False
self.data_format = "NHWC"
np.random.seed(2022)
def test_exception(self):
self.prepare()
......
......@@ -190,6 +190,9 @@ class TestLU_UnpackOp3(TestLU_UnpackOp):
class TestLU_UnpackAPI(unittest.TestCase):
def setUp(self):
np.random.seed(2022)
def test_dygraph(self):
def run_lu_unpack_dygraph(shape, dtype):
......
......@@ -30,6 +30,9 @@ paddle.enable_static()
class TestVariable(unittest.TestCase):
def setUp(self):
np.random.seed(2022)
def test_np_dtype_convert(self):
DT = core.VarDesc.VarType
convert = convert_np_dtype_to_dtype_
......@@ -486,6 +489,9 @@ class TestVariable(unittest.TestCase):
class TestVariableSlice(unittest.TestCase):
def setUp(self):
np.random.seed(2022)
def _test_item_none(self, place):
data = np.random.rand(2, 3, 4).astype("float32")
prog = paddle.static.Program()
......@@ -545,6 +551,9 @@ class TestVariableSlice(unittest.TestCase):
class TestListIndex(unittest.TestCase):
def setUp(self):
np.random.seed(2022)
def numel(self, shape):
return reduce(lambda x, y: x * y, shape)
......@@ -723,10 +732,10 @@ class TestListIndex(unittest.TestCase):
return
getitem_pp = exe.run(prog, feed={x.name: array}, fetch_list=fetch_list)
print(getitem_pp)
self.assertTrue(np.array_equal(value_np, getitem_pp[0]),
msg='\n numpy:{},\n paddle:{}'.format(
value_np, getitem_pp[0]))
np.testing.assert_allclose(value_np,
getitem_pp[0],
rtol=1e-5,
atol=1e-8)
def test_static_graph_getitem_bool_index(self):
paddle.enable_static()
......@@ -791,9 +800,7 @@ class TestListIndex(unittest.TestCase):
},
fetch_list=fetch_list)
self.assertTrue(np.allclose(array2, setitem_pp[0]),
msg='\n numpy:{},\n paddle:{}'.format(
array2, setitem_pp[0]))
np.testing.assert_allclose(array2, setitem_pp[0], rtol=1e-5, atol=1e-8)
def test_static_graph_setitem_list_index(self):
paddle.enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册