未验证 提交 f29c0ca1 编写于 作者: Z Zhang Zheng 提交者: GitHub

[AMP OP&Test] Fix the rtol setting rules in bfloat16 forward (#51875)

上级 75fb2ed9
...@@ -1727,11 +1727,9 @@ class OpTest(unittest.TestCase): ...@@ -1727,11 +1727,9 @@ class OpTest(unittest.TestCase):
judge whether convert current output and expect to uint16. judge whether convert current output and expect to uint16.
return True | False return True | False
""" """
if actual_np.dtype == np.uint16 and expect_np.dtype in [ if actual_np.dtype == np.uint16:
np.float32, if expect_np.dtype in [np.float32, np.float64]:
np.float64, actual_np = convert_uint16_to_float(actual_np)
]:
actual_np = convert_uint16_to_float(actual_np)
self.rtol = 1.0e-2 self.rtol = 1.0e-2
elif actual_np.dtype == np.float16: elif actual_np.dtype == np.float16:
self.rtol = 1.0e-3 self.rtol = 1.0e-3
...@@ -1828,10 +1826,7 @@ class OpTest(unittest.TestCase): ...@@ -1828,10 +1826,7 @@ class OpTest(unittest.TestCase):
) )
def convert_uint16_to_float_ifneed(self, actual_np, expect_np): def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
if actual_np.dtype == np.uint16 and expect_np.dtype in [ if actual_np.dtype == np.uint16:
np.float32,
np.float64,
]:
self.rtol = 1.0e-2 self.rtol = 1.0e-2
elif actual_np.dtype == np.float16: elif actual_np.dtype == np.float16:
self.rtol = 1.0e-3 self.rtol = 1.0e-3
......
...@@ -1723,11 +1723,9 @@ class OpTest(unittest.TestCase): ...@@ -1723,11 +1723,9 @@ class OpTest(unittest.TestCase):
judge whether convert current output and expect to uint16. judge whether convert current output and expect to uint16.
return True | False return True | False
""" """
if actual_np.dtype == np.uint16 and expect_np.dtype in [ if actual_np.dtype == np.uint16:
np.float32, if expect_np.dtype in [np.float32, np.float64]:
np.float64, actual_np = convert_uint16_to_float(actual_np)
]:
actual_np = convert_uint16_to_float(actual_np)
self.rtol = 1.0e-2 self.rtol = 1.0e-2
elif actual_np.dtype == np.float16: elif actual_np.dtype == np.float16:
self.rtol = 1.0e-3 self.rtol = 1.0e-3
...@@ -1787,10 +1785,7 @@ class OpTest(unittest.TestCase): ...@@ -1787,10 +1785,7 @@ class OpTest(unittest.TestCase):
return imperative_expect, imperative_expect_t return imperative_expect, imperative_expect_t
def convert_uint16_to_float_ifneed(self, actual_np, expect_np): def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
if actual_np.dtype == np.uint16 and expect_np.dtype in [ if actual_np.dtype == np.uint16:
np.float32,
np.float64,
]:
self.rtol = 1.0e-2 self.rtol = 1.0e-2
elif actual_np.dtype == np.float16: elif actual_np.dtype == np.float16:
self.rtol = 1.0e-3 self.rtol = 1.0e-3
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册