未验证 提交 acb13825 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] Support allclose and linalg_cond to eager mode (#41545) (#41691)

上级 8663376f
......@@ -16,6 +16,7 @@ import paddle
import paddle.fluid as fluid
import unittest
import numpy as np
from paddle.fluid.framework import _test_eager_guard
class TestAllcloseLayer(unittest.TestCase):
......@@ -95,7 +96,7 @@ class TestAllcloseLayer(unittest.TestCase):
with fluid.program_guard(main, startup):
self.allclose_check(use_cuda=True, dtype='float64')
def test_dygraph_mode(self):
def func_dygraph_mode(self):
x_1 = np.array([10000., 1e-07]).astype("float32")
y_1 = np.array([10000.1, 1e-08]).astype("float32")
x_2 = np.array([10000., 1e-08]).astype("float32")
......@@ -171,9 +172,14 @@ class TestAllcloseLayer(unittest.TestCase):
x_v_5 = paddle.to_tensor(x_5)
y_v_5 = paddle.to_tensor(y_5)
ret_5 = paddle.allclose(
x_v_5, y_v_5, rtol=0.01, atol=0.0, name='test_8')
x_v_5, y_v_5, rtol=0.015, atol=0.0, name='test_8')
self.assertEqual(ret_5.numpy()[0], True)
def test_dygraph_mode(self):
with _test_eager_guard():
self.func_dygraph_mode()
self.func_dygraph_mode()
if __name__ == "__main__":
unittest.main()
......@@ -18,6 +18,7 @@ import unittest
import numpy as np
import paddle
import paddle.static as static
from paddle.fluid.framework import _test_eager_guard
p_list_n_n = ("fro", "nuc", 1, -1, np.inf, -np.inf)
p_list_m_n = (None, 2, -2)
......@@ -89,16 +90,21 @@ class API_TestStaticCond(unittest.TestCase):
class API_TestDygraphCond(unittest.TestCase):
def test_out(self):
def func_out(self):
paddle.disable_static()
# test calling results of 'cond' in dynamic mode
x_list_n_n, x_list_m_n = gen_input()
test_dygraph_assert_true(self, x_list_n_n, p_list_n_n + p_list_m_n)
test_dygraph_assert_true(self, x_list_m_n, p_list_m_n)
def test_out(self):
with _test_eager_guard():
self.func_out()
self.func_out()
class TestCondAPIError(unittest.TestCase):
def test_dygraph_api_error(self):
def func_dygraph_api_error(self):
paddle.disable_static()
# test raising errors when 'cond' is called in dygraph mode
p_list_error = ('fro_', '_nuc', -0.7, 0, 1.5, 3)
......@@ -113,6 +119,11 @@ class TestCondAPIError(unittest.TestCase):
x_tensor = paddle.to_tensor(x)
self.assertRaises(ValueError, paddle.linalg.cond, x_tensor, p)
def test_dygraph_api_error(self):
with _test_eager_guard():
self.func_dygraph_api_error()
self.func_dygraph_api_error()
def test_static_api_error(self):
paddle.enable_static()
# test raising errors when 'cond' is called in static mode
......@@ -149,13 +160,18 @@ class TestCondAPIError(unittest.TestCase):
class TestCondEmptyTensorInput(unittest.TestCase):
def test_dygraph_empty_tensor_input(self):
def func_dygraph_empty_tensor_input(self):
paddle.disable_static()
# test calling results of 'cond' when input is an empty tensor in dynamic mode
x_list_n_n, x_list_m_n = gen_empty_input()
test_dygraph_assert_true(self, x_list_n_n, p_list_n_n + p_list_m_n)
test_dygraph_assert_true(self, x_list_m_n, p_list_m_n)
def test_dygraph_empty_tensor_input(self):
with _test_eager_guard():
self.func_dygraph_empty_tensor_input()
self.func_dygraph_empty_tensor_input()
if __name__ == "__main__":
paddle.enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册