未验证 提交 539fb0d7 编写于 作者: C crystal 提交者: GitHub

Fix unittests for eigh op (#39568)

* fix eigh test

* modify atol and rtol
上级 6c358a7c
......@@ -59,6 +59,10 @@ class TestEighGPUCase(unittest.TestCase):
self.dtype = "float32"
np.random.seed(123)
self.x_np = np.random.random(self.x_shape).astype(self.dtype)
if (paddle.version.cuda() >= "11.6"):
self.rtol = 5e-6
self.atol = 6e-5
else:
self.rtol = 1e-5
self.atol = 1e-5
......@@ -79,23 +83,30 @@ class TestEighGPUCase(unittest.TestCase):
class TestEighAPI(unittest.TestCase):
def setUp(self):
self.init_input_shape()
self.dtype = "float32"
self.init_input_data()
self.UPLO = 'L'
self.rtol = 1e-6
self.atol = 1e-6
if (paddle.version.cuda() >= "11.6"):
self.rtol = 5e-6
self.atol = 6e-5
else:
self.rtol = 1e-5
self.atol = 1e-5
self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
np.random.seed(123)
def init_input_data(self):
self.x_shape = [5, 5]
self.dtype = "float32"
self.real_data = np.random.random(self.x_shape).astype(self.dtype)
self.complex_data = np.random.random(self.x_shape).astype(
complex_data = np.random.random(self.x_shape).astype(
self.dtype) + 1J * np.random.random(self.x_shape).astype(self.dtype)
self.trans_dims = list(range(len(self.x_shape) - 2)) + [
len(self.x_shape) - 1, len(self.x_shape) - 2
]
def init_input_shape(self):
self.x_shape = [5, 5]
#build a random conjugate matrix
self.complex_symm = np.divide(
complex_data + np.conj(complex_data.transpose(self.trans_dims)), 2)
def compare_result(self, actual_w, actual_v, expected_w, expected_v):
np.testing.assert_allclose(
......@@ -129,9 +140,9 @@ class TestEighAPI(unittest.TestCase):
exe = paddle.static.Executor(self.place)
expected_w, expected_v = exe.run(
main_prog,
feed={"input_x": self.complex_data},
feed={"input_x": self.complex_symm},
fetch_list=[output_w, output_v])
actual_w, actual_v = np.linalg.eigh(self.complex_data)
actual_w, actual_v = np.linalg.eigh(self.complex_symm)
self.compare_result(actual_w, actual_v, expected_w, expected_v)
def test_in_static_mode(self):
......@@ -146,14 +157,14 @@ class TestEighAPI(unittest.TestCase):
actual_w, actual_v = paddle.linalg.eigh(input_real_data)
self.compare_result(actual_w, actual_v.numpy(), expected_w, expected_v)
input_complex_data = paddle.to_tensor(self.complex_data)
expected_w, expected_v = np.linalg.eigh(self.complex_data)
input_complex_data = paddle.to_tensor(self.complex_symm)
expected_w, expected_v = np.linalg.eigh(self.complex_symm)
actual_w, actual_v = paddle.linalg.eigh(input_complex_data)
self.compare_result(actual_w, actual_v.numpy(), expected_w, expected_v)
def test_eigh_grad(self):
paddle.disable_static()
x = paddle.to_tensor(self.complex_data, stop_gradient=False)
x = paddle.to_tensor(self.complex_symm, stop_gradient=False)
w, v = paddle.linalg.eigh(x)
(w.sum() + paddle.abs(v).sum()).backward()
np.testing.assert_allclose(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册