未验证 提交 75fb2ed9 编写于 作者: Z Zhang Zheng 提交者: GitHub

Replace OpTest.assertTrue(numpy.allclose) to numpy.testing.assert_allclose (#51690)

上级 e335ae29
...@@ -1604,6 +1604,23 @@ class OpTest(unittest.TestCase): ...@@ -1604,6 +1604,23 @@ class OpTest(unittest.TestCase):
raise NotImplementedError("base class, not implement!") raise NotImplementedError("base class, not implement!")
def _compare_numpy(self, name, actual_np, expect_np): def _compare_numpy(self, name, actual_np, expect_np):
if actual_np.shape == expect_np.shape:
np.testing.assert_allclose(
actual_np,
expect_np,
atol=atol,
rtol=self.rtol if hasattr(self, 'rtol') else 1e-5,
equal_nan=equal_nan,
err_msg=(
"Output ("
+ name
+ ") has diff at "
+ str(place)
+ " in "
+ self.checker_name
),
)
return
self.op_test.assertTrue( self.op_test.assertTrue(
np.allclose( np.allclose(
actual_np, actual_np,
...@@ -1777,6 +1794,23 @@ class OpTest(unittest.TestCase): ...@@ -1777,6 +1794,23 @@ class OpTest(unittest.TestCase):
): ):
pass pass
else: else:
if actual_np.shape == expect_np.shape:
np.testing.assert_allclose(
actual_np,
expect_np,
atol=atol,
rtol=self.rtol if hasattr(self, 'rtol') else 1e-5,
equal_nan=equal_nan,
err_msg=(
"Output ("
+ name
+ ") has diff at "
+ str(place)
+ " in "
+ self.checker_name
),
)
return
self.op_test.assertTrue( self.op_test.assertTrue(
np.allclose( np.allclose(
actual_np, actual_np,
......
...@@ -1599,6 +1599,23 @@ class OpTest(unittest.TestCase): ...@@ -1599,6 +1599,23 @@ class OpTest(unittest.TestCase):
raise NotImplementedError("base class, not implement!") raise NotImplementedError("base class, not implement!")
def _compare_numpy(self, name, actual_np, expect_np): def _compare_numpy(self, name, actual_np, expect_np):
if actual_np.shape == expect_np.shape:
np.testing.assert_allclose(
actual_np,
expect_np,
atol=atol,
rtol=self.rtol if hasattr(self, 'rtol') else 1e-5,
equal_nan=equal_nan,
err_msg=(
"Output ("
+ name
+ ") has diff at "
+ str(place)
+ " in "
+ self.checker_name
),
)
return
self.op_test.assertTrue( self.op_test.assertTrue(
np.allclose( np.allclose(
actual_np, actual_np,
...@@ -1810,6 +1827,23 @@ class OpTest(unittest.TestCase): ...@@ -1810,6 +1827,23 @@ class OpTest(unittest.TestCase):
): ):
pass pass
else: else:
if actual_np.shape == expect_np.shape:
np.testing.assert_allclose(
actual_np,
expect_np,
atol=atol,
rtol=self.rtol if hasattr(self, 'rtol') else 1e-5,
equal_nan=equal_nan,
err_msg=(
"Output ("
+ name
+ ") has diff at "
+ str(place)
+ " in "
+ self.checker_name
),
)
return
self.op_test.assertTrue( self.op_test.assertTrue(
np.allclose( np.allclose(
actual_np, actual_np,
......
...@@ -87,7 +87,7 @@ class TestAssignPosOpInt64(op_test.OpTest): ...@@ -87,7 +87,7 @@ class TestAssignPosOpInt64(op_test.OpTest):
self.cum_count = cum_count self.cum_count = cum_count
def test_forward(self): def test_forward(self):
np.allclose = get_redefined_allclose(self.cum_count) np.testing.assert_allclose = get_redefined_allclose(self.cum_count)
self.check_output_with_place(paddle.CUDAPlace(0)) self.check_output_with_place(paddle.CUDAPlace(0))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册