未验证 提交 cbe8e6e9 编写于 作者: Z Zhang Zheng 提交者: GitHub

[AMP OP&Test] Fix the logic of calling infer_dtype func in op test (#52581)

* [AMP OP&Test] Fix the logic of calling infer_dtype func in op test

* add fp16
上级 f5ae67e8
......@@ -626,7 +626,11 @@ class OpTest(unittest.TestCase):
op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
"infer datatype from inputs and outputs for this test case"
if self.is_bfloat16_op():
if self.is_float16_op():
self.dtype = np.float16
self.__class__.dtype = self.dtype
self.output_dtype = np.float16
elif self.is_bfloat16_op():
self.dtype = np.uint16
self.__class__.dtype = self.dtype
self.output_dtype = np.uint16
......@@ -1902,6 +1906,15 @@ class OpTest(unittest.TestCase):
self.__class__.check_prim = True
self.__class__.op_type = self.op_type
# set some flags by the combination of arguments.
if self.is_float16_op():
self.dtype = np.float16
self.__class__.dtype = self.dtype
self.output_dtype = np.float16
elif self.is_bfloat16_op():
self.dtype = np.uint16
self.__class__.dtype = self.dtype
self.output_dtype = np.uint16
else:
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
if (
self.dtype == np.float64
......@@ -2201,6 +2214,15 @@ class OpTest(unittest.TestCase):
self.assertLessEqual(max_diff, max_relative_error, err_msg())
def _check_grad_helper(self):
if self.is_float16_op():
self.dtype = np.float16
self.__class__.dtype = self.dtype
self.output_dtype = np.float16
elif self.is_bfloat16_op():
self.dtype = np.uint16
self.__class__.dtype = self.dtype
self.output_dtype = np.uint16
else:
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
self.__class__.op_type = self.op_type
self.__class__.exist_check_grad = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册