From f29c0ca1d9a316411ccff0fcd7dfcec5970b483c Mon Sep 17 00:00:00 2001 From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com> Date: Wed, 22 Mar 2023 10:52:31 +0800 Subject: [PATCH] [AMP OP&Test] Fix the rtol setting rules in bfloat16 forward (#51875) --- .../paddle/fluid/tests/unittests/eager_op_test.py | 13 ++++--------- python/paddle/fluid/tests/unittests/op_test.py | 13 ++++--------- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/eager_op_test.py b/python/paddle/fluid/tests/unittests/eager_op_test.py index e42a83f0653..fa2eca18138 100644 --- a/python/paddle/fluid/tests/unittests/eager_op_test.py +++ b/python/paddle/fluid/tests/unittests/eager_op_test.py @@ -1727,11 +1727,9 @@ class OpTest(unittest.TestCase): judge whether convert current output and expect to uint16. return True | False """ - if actual_np.dtype == np.uint16 and expect_np.dtype in [ - np.float32, - np.float64, - ]: - actual_np = convert_uint16_to_float(actual_np) + if actual_np.dtype == np.uint16: + if expect_np.dtype in [np.float32, np.float64]: + actual_np = convert_uint16_to_float(actual_np) self.rtol = 1.0e-2 elif actual_np.dtype == np.float16: self.rtol = 1.0e-3 @@ -1828,10 +1826,7 @@ class OpTest(unittest.TestCase): ) def convert_uint16_to_float_ifneed(self, actual_np, expect_np): - if actual_np.dtype == np.uint16 and expect_np.dtype in [ - np.float32, - np.float64, - ]: + if actual_np.dtype == np.uint16: self.rtol = 1.0e-2 elif actual_np.dtype == np.float16: self.rtol = 1.0e-3 diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 0c2f88490e6..01e574ef173 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -1723,11 +1723,9 @@ class OpTest(unittest.TestCase): judge whether convert current output and expect to uint16. return True | False """ - if actual_np.dtype == np.uint16 and expect_np.dtype in [ - np.float32, - np.float64, - ]: - actual_np = convert_uint16_to_float(actual_np) + if actual_np.dtype == np.uint16: + if expect_np.dtype in [np.float32, np.float64]: + actual_np = convert_uint16_to_float(actual_np) self.rtol = 1.0e-2 elif actual_np.dtype == np.float16: self.rtol = 1.0e-3 @@ -1787,10 +1785,7 @@ class OpTest(unittest.TestCase): return imperative_expect, imperative_expect_t def convert_uint16_to_float_ifneed(self, actual_np, expect_np): - if actual_np.dtype == np.uint16 and expect_np.dtype in [ - np.float32, - np.float64, - ]: + if actual_np.dtype == np.uint16: self.rtol = 1.0e-2 elif actual_np.dtype == np.float16: self.rtol = 1.0e-3 -- GitLab