Unverified commit cbe8e6e9, authored by Zhang Zheng, committed by GitHub

[AMP OP&Test] Fix the logic of calling infer_dtype func in op test (#52581)

* [AMP OP&Test] Fix the logic of calling infer_dtype func in op test

* add fp16
Parent f5ae67e8
@@ -626,7 +626,11 @@ class OpTest(unittest.TestCase):
         op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
         "infer datatype from inputs and outputs for this test case"
-        if self.is_bfloat16_op():
+        if self.is_float16_op():
+            self.dtype = np.float16
+            self.__class__.dtype = self.dtype
+            self.output_dtype = np.float16
+        elif self.is_bfloat16_op():
             self.dtype = np.uint16
             self.__class__.dtype = self.dtype
             self.output_dtype = np.uint16
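The same if/elif/else precedence is repeated in the check-output and check-grad paths below. As a rough standalone sketch of the idea (the helper name resolve_test_dtype is hypothetical; in the commit the branching is written inline in OpTest): an explicit FP16 test is honored first, then a BF16 test, and infer_dtype_from_inputs_outputs only runs as a fallback.

import numpy as np

def resolve_test_dtype(test):
    # Hypothetical sketch of the branching this commit adds inline in OpTest.
    if test.is_float16_op():
        # An explicit FP16 test keeps float16 instead of having its dtype re-inferred.
        test.dtype = np.float16
        test.__class__.dtype = test.dtype
        test.output_dtype = np.float16
    elif test.is_bfloat16_op():
        # bfloat16 values are carried in uint16 numpy arrays.
        test.dtype = np.uint16
        test.__class__.dtype = test.dtype
        test.output_dtype = np.uint16
    else:
        # Fallback: infer the dtype from the declared inputs and outputs.
        test.infer_dtype_from_inputs_outputs(test.inputs, test.outputs)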
@@ -1902,7 +1906,16 @@ class OpTest(unittest.TestCase):
         self.__class__.check_prim = True
         self.__class__.op_type = self.op_type
         # set some flags by the combination of arguments.
-        self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
+        if self.is_float16_op():
+            self.dtype = np.float16
+            self.__class__.dtype = self.dtype
+            self.output_dtype = np.float16
+        elif self.is_bfloat16_op():
+            self.dtype = np.uint16
+            self.__class__.dtype = self.dtype
+            self.output_dtype = np.uint16
+        else:
+            self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
         if (
             self.dtype == np.float64
             and self.op_type
@@ -2201,7 +2214,16 @@ class OpTest(unittest.TestCase):
         self.assertLessEqual(max_diff, max_relative_error, err_msg())

     def _check_grad_helper(self):
-        self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
+        if self.is_float16_op():
+            self.dtype = np.float16
+            self.__class__.dtype = self.dtype
+            self.output_dtype = np.float16
+        elif self.is_bfloat16_op():
+            self.dtype = np.uint16
+            self.__class__.dtype = self.dtype
+            self.output_dtype = np.uint16
+        else:
+            self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
         self.__class__.op_type = self.op_type
         self.__class__.exist_check_grad = True
         if self.dtype == np.float64:
...
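For illustration, a hypothetical FP16 op test that would exercise the new branches. The op, shapes, and the assumption that is_float16_op() keys off the test setting self.dtype = np.float16 are not taken from this commit, and the OpTest import path may differ between Paddle versions.

import unittest

import numpy as np

from eager_op_test import OpTest  # assumed import path for Paddle's op-test base class


class TestScaleFP16Op(OpTest):
    # Hypothetical test case: with this commit, check_output()/check_grad()
    # no longer overwrite self.dtype via infer_dtype_from_inputs_outputs.
    def setUp(self):
        self.op_type = "scale"
        self.dtype = np.float16
        x = np.random.random((4, 8)).astype(self.dtype)
        self.inputs = {'X': x}
        self.attrs = {'scale': 2.0}
        self.outputs = {'Out': x * 2.0}

    def test_check_output(self):
        self.check_output()


if __name__ == '__main__':
    unittest.main()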