未验证 提交 aa47297a 编写于 作者: L lkylkylky 提交者: GitHub

fix unittests for eignvalsh (#39841)

上级 fb635089
...@@ -60,6 +60,10 @@ class TestEigvalshGPUCase(unittest.TestCase): ...@@ -60,6 +60,10 @@ class TestEigvalshGPUCase(unittest.TestCase):
self.dtype = "float32" self.dtype = "float32"
np.random.seed(123) np.random.seed(123)
self.x_np = np.random.random(self.x_shape).astype(self.dtype) self.x_np = np.random.random(self.x_shape).astype(self.dtype)
if (paddle.version.cuda() >= "11.6"):
self.rtol = 5e-6
self.atol = 6e-5
else:
self.rtol = 1e-5 self.rtol = 1e-5
self.atol = 1e-5 self.atol = 1e-5
...@@ -75,23 +79,29 @@ class TestEigvalshGPUCase(unittest.TestCase): ...@@ -75,23 +79,29 @@ class TestEigvalshGPUCase(unittest.TestCase):
class TestEigvalshAPI(unittest.TestCase): class TestEigvalshAPI(unittest.TestCase):
def setUp(self): def setUp(self):
self.init_input_shape() self.x_shape = [5, 5]
self.dtype = "float32" self.dtype = "float32"
self.UPLO = 'L' self.UPLO = 'L'
self.rtol = 1e-6 if (paddle.version.cuda() >= "11.6"):
self.atol = 1e-6 self.rtol = 5e-6
self.atol = 6e-5
else:
self.rtol = 1e-5
self.atol = 1e-5
self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \ self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace() else paddle.CPUPlace()
np.random.seed(123) np.random.seed(123)
self.init_input_data()
def init_input_data(self):
self.real_data = np.random.random(self.x_shape).astype(self.dtype) self.real_data = np.random.random(self.x_shape).astype(self.dtype)
self.complex_data = np.random.random(self.x_shape).astype( complex_data = np.random.random(self.x_shape).astype(
self.dtype) + 1J * np.random.random(self.x_shape).astype(self.dtype) self.dtype) + 1J * np.random.random(self.x_shape).astype(self.dtype)
self.trans_dims = list(range(len(self.x_shape) - 2)) + [ self.trans_dims = list(range(len(self.x_shape) - 2)) + [
len(self.x_shape) - 1, len(self.x_shape) - 2 len(self.x_shape) - 1, len(self.x_shape) - 2
] ]
self.complex_symm = np.divide(
def init_input_shape(self): complex_data + np.conj(complex_data.transpose(self.trans_dims)), 2)
self.x_shape = [5, 5]
def compare_result(self, actual_w, expected_w): def compare_result(self, actual_w, expected_w):
np.testing.assert_allclose( np.testing.assert_allclose(
...@@ -122,9 +132,9 @@ class TestEigvalshAPI(unittest.TestCase): ...@@ -122,9 +132,9 @@ class TestEigvalshAPI(unittest.TestCase):
output_w = paddle.linalg.eigvalsh(input_x) output_w = paddle.linalg.eigvalsh(input_x)
exe = paddle.static.Executor(self.place) exe = paddle.static.Executor(self.place)
expected_w = exe.run(main_prog, expected_w = exe.run(main_prog,
feed={"input_x": self.complex_data}, feed={"input_x": self.complex_symm},
fetch_list=[output_w]) fetch_list=[output_w])
actual_w = np.linalg.eigvalsh(self.complex_data) actual_w = np.linalg.eigvalsh(self.complex_symm)
self.compare_result(actual_w, expected_w[0]) self.compare_result(actual_w, expected_w[0])
def test_in_static_mode(self): def test_in_static_mode(self):
...@@ -139,14 +149,14 @@ class TestEigvalshAPI(unittest.TestCase): ...@@ -139,14 +149,14 @@ class TestEigvalshAPI(unittest.TestCase):
actual_w = paddle.linalg.eigvalsh(input_real_data) actual_w = paddle.linalg.eigvalsh(input_real_data)
self.compare_result(actual_w, expected_w) self.compare_result(actual_w, expected_w)
input_complex_data = paddle.to_tensor(self.complex_data) input_complex_symm = paddle.to_tensor(self.complex_symm)
expected_w = np.linalg.eigvalsh(self.complex_data) expected_w = np.linalg.eigvalsh(self.complex_symm)
actual_w = paddle.linalg.eigvalsh(input_complex_data) actual_w = paddle.linalg.eigvalsh(input_complex_symm)
self.compare_result(actual_w, expected_w) self.compare_result(actual_w, expected_w)
def test_eigvalsh_grad(self): def test_eigvalsh_grad(self):
paddle.disable_static(self.place) paddle.disable_static(self.place)
x = paddle.to_tensor(self.complex_data, stop_gradient=False) x = paddle.to_tensor(self.complex_symm, stop_gradient=False)
w = paddle.linalg.eigvalsh(x) w = paddle.linalg.eigvalsh(x)
(w.sum()).backward() (w.sum()).backward()
np.testing.assert_allclose( np.testing.assert_allclose(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册