diff --git a/python/paddle/fluid/tests/unittests/eager_op_test.py b/python/paddle/fluid/tests/unittests/eager_op_test.py
index ae75ac02f9450ef1dd50edb6062ee10d1a8fe4a9..b764a1acd0a96633eb623c67ac5402f8ac685027 100644
--- a/python/paddle/fluid/tests/unittests/eager_op_test.py
+++ b/python/paddle/fluid/tests/unittests/eager_op_test.py
@@ -626,7 +626,11 @@ class OpTest(unittest.TestCase):
         op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
 
         "infer datatype from inputs and outputs for this test case"
-        if self.is_bfloat16_op():
+        if self.is_float16_op():
+            self.dtype = np.float16
+            self.__class__.dtype = self.dtype
+            self.output_dtype = np.float16
+        elif self.is_bfloat16_op():
             self.dtype = np.uint16
             self.__class__.dtype = self.dtype
             self.output_dtype = np.uint16
@@ -1902,7 +1906,16 @@ class OpTest(unittest.TestCase):
         self.__class__.check_prim = True
         self.__class__.op_type = self.op_type
         # set some flags by the combination of arguments.
-        self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
+        if self.is_float16_op():
+            self.dtype = np.float16
+            self.__class__.dtype = self.dtype
+            self.output_dtype = np.float16
+        elif self.is_bfloat16_op():
+            self.dtype = np.uint16
+            self.__class__.dtype = self.dtype
+            self.output_dtype = np.uint16
+        else:
+            self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
         if (
             self.dtype == np.float64
             and self.op_type
@@ -2201,7 +2214,16 @@ class OpTest(unittest.TestCase):
         self.assertLessEqual(max_diff, max_relative_error, err_msg())
 
     def _check_grad_helper(self):
-        self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
+        if self.is_float16_op():
+            self.dtype = np.float16
+            self.__class__.dtype = self.dtype
+            self.output_dtype = np.float16
+        elif self.is_bfloat16_op():
+            self.dtype = np.uint16
+            self.__class__.dtype = self.dtype
+            self.output_dtype = np.uint16
+        else:
+            self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
         self.__class__.op_type = self.op_type
         self.__class__.exist_check_grad = True
         if self.dtype == np.float64:
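
Note: the second and third hunks repeat the same float16/bfloat16 dispatch verbatim at two call sites that previously delegated to infer_dtype_from_inputs_outputs. A minimal sketch of how that repeated branch could be pulled into one OpTest method follows; the helper name _resolve_test_dtype is hypothetical and not part of this patch.

    import numpy as np

    def _resolve_test_dtype(self):
        # Hypothetical helper (not in the patch), intended as an OpTest
        # method: centralizes the dtype dispatch duplicated in the
        # check-output setup and in _check_grad_helper.
        if self.is_float16_op():
            # float16 ops are exercised directly in np.float16.
            self.dtype = np.float16
            self.__class__.dtype = self.dtype
            self.output_dtype = np.float16
        elif self.is_bfloat16_op():
            # numpy has no bfloat16 dtype; Paddle's tests carry bfloat16
            # values as raw uint16 bit patterns.
            self.dtype = np.uint16
            self.__class__.dtype = self.dtype
            self.output_dtype = np.uint16
        else:
            # Neither low-precision flag is set: infer the dtype from the
            # test case's declared inputs and outputs.
            self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)

With such a helper, the bodies added in the second and third hunks would reduce to self._resolve_test_dtype(). The first hunk is different: it sits inside infer_dtype_from_inputs_outputs itself, so its else branch falls through to the original inference logic rather than calling the method, and it cannot use this helper without recursing.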