未验证 提交 f29c0ca1 编写于 作者: Z Zhang Zheng 提交者: GitHub

[AMP OP&Test] Fix the rtol setting rules in bfloat16 forward (#51875)

上级 75fb2ed9
......@@ -1727,11 +1727,9 @@ class OpTest(unittest.TestCase):
judge whether convert current output and expect to uint16.
return True | False
"""
if actual_np.dtype == np.uint16 and expect_np.dtype in [
np.float32,
np.float64,
]:
actual_np = convert_uint16_to_float(actual_np)
if actual_np.dtype == np.uint16:
if expect_np.dtype in [np.float32, np.float64]:
actual_np = convert_uint16_to_float(actual_np)
self.rtol = 1.0e-2
elif actual_np.dtype == np.float16:
self.rtol = 1.0e-3
......@@ -1828,10 +1826,7 @@ class OpTest(unittest.TestCase):
)
def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
if actual_np.dtype == np.uint16 and expect_np.dtype in [
np.float32,
np.float64,
]:
if actual_np.dtype == np.uint16:
self.rtol = 1.0e-2
elif actual_np.dtype == np.float16:
self.rtol = 1.0e-3
......
......@@ -1723,11 +1723,9 @@ class OpTest(unittest.TestCase):
judge whether convert current output and expect to uint16.
return True | False
"""
if actual_np.dtype == np.uint16 and expect_np.dtype in [
np.float32,
np.float64,
]:
actual_np = convert_uint16_to_float(actual_np)
if actual_np.dtype == np.uint16:
if expect_np.dtype in [np.float32, np.float64]:
actual_np = convert_uint16_to_float(actual_np)
self.rtol = 1.0e-2
elif actual_np.dtype == np.float16:
self.rtol = 1.0e-3
......@@ -1787,10 +1785,7 @@ class OpTest(unittest.TestCase):
return imperative_expect, imperative_expect_t
def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
if actual_np.dtype == np.uint16 and expect_np.dtype in [
np.float32,
np.float64,
]:
if actual_np.dtype == np.uint16:
self.rtol = 1.0e-2
elif actual_np.dtype == np.float16:
self.rtol = 1.0e-3
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册