From cbe8e6e9a8a492ff59f24afce8ded226f99fb53d Mon Sep 17 00:00:00 2001
From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com>
Date: Fri, 7 Apr 2023 15:44:20 +0800
Subject: [PATCH] [AMP OP&Test] Fix the logic of calling infer_dtype func in
 op test (#52581)

* [AMP OP&Test] Fix the logic of calling infer_dtype func in op test

* add fp16
---
 .../fluid/tests/unittests/eager_op_test.py | 28 +++++++++++++++++--
 1 file changed, 25 insertions(+), 3 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/eager_op_test.py b/python/paddle/fluid/tests/unittests/eager_op_test.py
index ae75ac02f94..b764a1acd0a 100644
--- a/python/paddle/fluid/tests/unittests/eager_op_test.py
+++ b/python/paddle/fluid/tests/unittests/eager_op_test.py
@@ -626,7 +626,11 @@ class OpTest(unittest.TestCase):
         op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
 
         "infer datatype from inputs and outputs for this test case"
-        if self.is_bfloat16_op():
+        if self.is_float16_op():
+            self.dtype = np.float16
+            self.__class__.dtype = self.dtype
+            self.output_dtype = np.float16
+        elif self.is_bfloat16_op():
             self.dtype = np.uint16
             self.__class__.dtype = self.dtype
             self.output_dtype = np.uint16
@@ -1902,7 +1906,16 @@ class OpTest(unittest.TestCase):
             self.__class__.check_prim = True
             self.__class__.op_type = self.op_type
         # set some flags by the combination of arguments.
-        self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
+        if self.is_float16_op():
+            self.dtype = np.float16
+            self.__class__.dtype = self.dtype
+            self.output_dtype = np.float16
+        elif self.is_bfloat16_op():
+            self.dtype = np.uint16
+            self.__class__.dtype = self.dtype
+            self.output_dtype = np.uint16
+        else:
+            self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
         if (
             self.dtype == np.float64
             and self.op_type
@@ -2201,7 +2214,16 @@ class OpTest(unittest.TestCase):
         self.assertLessEqual(max_diff, max_relative_error, err_msg())
 
     def _check_grad_helper(self):
-        self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
+        if self.is_float16_op():
+            self.dtype = np.float16
+            self.__class__.dtype = self.dtype
+            self.output_dtype = np.float16
+        elif self.is_bfloat16_op():
+            self.dtype = np.uint16
+            self.__class__.dtype = self.dtype
+            self.output_dtype = np.uint16
+        else:
+            self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
         self.__class__.op_type = self.op_type
         self.__class__.exist_check_grad = True
         if self.dtype == np.float64:
--
GitLab
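Note on the change (not part of the patch): the patch pins the test dtype for float16 and bfloat16 ops before ever calling infer_dtype_from_inputs_outputs, and the same if/elif/else block is now duplicated at three call sites. The sketch below shows how those sites could share one helper; it is a minimal illustration only, and the helper name _resolve_test_dtype is an assumption that does not exist in eager_op_test.py.

    import unittest

    import numpy as np


    class OpTest(unittest.TestCase):
        # Hypothetical helper (not in the patch): a single place that decides
        # the dtype used by the test, mirroring the block the patch repeats.
        def _resolve_test_dtype(self):
            if self.is_float16_op():
                # fp16 tests run entirely in float16.
                self.dtype = np.float16
                self.__class__.dtype = self.dtype
                self.output_dtype = np.float16
            elif self.is_bfloat16_op():
                # bfloat16 values are carried as uint16 in numpy.
                self.dtype = np.uint16
                self.__class__.dtype = self.dtype
                self.output_dtype = np.uint16
            else:
                # All other cases still infer the dtype from inputs/outputs.
                self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)

Each of the three touched locations (the infer path, check_output, and _check_grad_helper) would then reduce to a single call to self._resolve_test_dtype().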