未验证 提交 aa47297a 编写于 作者: L lkylkylky 提交者: GitHub

fix unittests for eignvalsh (#39841)

上级 fb635089
......@@ -60,8 +60,12 @@ class TestEigvalshGPUCase(unittest.TestCase):
self.dtype = "float32"
np.random.seed(123)
self.x_np = np.random.random(self.x_shape).astype(self.dtype)
self.rtol = 1e-5
self.atol = 1e-5
if (paddle.version.cuda() >= "11.6"):
self.rtol = 5e-6
self.atol = 6e-5
else:
self.rtol = 1e-5
self.atol = 1e-5
def test_check_output_gpu(self):
if paddle.is_compiled_with_cuda():
......@@ -75,23 +79,29 @@ class TestEigvalshGPUCase(unittest.TestCase):
class TestEigvalshAPI(unittest.TestCase):
def setUp(self):
self.init_input_shape()
self.x_shape = [5, 5]
self.dtype = "float32"
self.UPLO = 'L'
self.rtol = 1e-6
self.atol = 1e-6
if (paddle.version.cuda() >= "11.6"):
self.rtol = 5e-6
self.atol = 6e-5
else:
self.rtol = 1e-5
self.atol = 1e-5
self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
np.random.seed(123)
self.init_input_data()
def init_input_data(self):
self.real_data = np.random.random(self.x_shape).astype(self.dtype)
self.complex_data = np.random.random(self.x_shape).astype(
complex_data = np.random.random(self.x_shape).astype(
self.dtype) + 1J * np.random.random(self.x_shape).astype(self.dtype)
self.trans_dims = list(range(len(self.x_shape) - 2)) + [
len(self.x_shape) - 1, len(self.x_shape) - 2
]
def init_input_shape(self):
self.x_shape = [5, 5]
self.complex_symm = np.divide(
complex_data + np.conj(complex_data.transpose(self.trans_dims)), 2)
def compare_result(self, actual_w, expected_w):
np.testing.assert_allclose(
......@@ -122,9 +132,9 @@ class TestEigvalshAPI(unittest.TestCase):
output_w = paddle.linalg.eigvalsh(input_x)
exe = paddle.static.Executor(self.place)
expected_w = exe.run(main_prog,
feed={"input_x": self.complex_data},
feed={"input_x": self.complex_symm},
fetch_list=[output_w])
actual_w = np.linalg.eigvalsh(self.complex_data)
actual_w = np.linalg.eigvalsh(self.complex_symm)
self.compare_result(actual_w, expected_w[0])
def test_in_static_mode(self):
......@@ -139,14 +149,14 @@ class TestEigvalshAPI(unittest.TestCase):
actual_w = paddle.linalg.eigvalsh(input_real_data)
self.compare_result(actual_w, expected_w)
input_complex_data = paddle.to_tensor(self.complex_data)
expected_w = np.linalg.eigvalsh(self.complex_data)
actual_w = paddle.linalg.eigvalsh(input_complex_data)
input_complex_symm = paddle.to_tensor(self.complex_symm)
expected_w = np.linalg.eigvalsh(self.complex_symm)
actual_w = paddle.linalg.eigvalsh(input_complex_symm)
self.compare_result(actual_w, expected_w)
def test_eigvalsh_grad(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.complex_data, stop_gradient=False)
x = paddle.to_tensor(self.complex_symm, stop_gradient=False)
w = paddle.linalg.eigvalsh(x)
(w.sum()).backward()
np.testing.assert_allclose(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册