未验证 提交 ec008a71 编写于 作者: R Roc 提交者: GitHub

[AMP OP & Test] Tril & Triu (#52411)

上级 648f58aa
...@@ -14,10 +14,11 @@ ...@@ -14,10 +14,11 @@
import unittest import unittest
import numpy as np import numpy as np
from eager_op_test import OpTest from eager_op_test import OpTest, convert_float_to_uint16
import paddle import paddle
from paddle import fluid, tensor from paddle import fluid, tensor
from paddle.fluid import core
from paddle.fluid.framework import Program, program_guard from paddle.fluid.framework import Program, program_guard
...@@ -49,20 +50,58 @@ class TrilTriuOpDefaultTest(OpTest): ...@@ -49,20 +50,58 @@ class TrilTriuOpDefaultTest(OpTest):
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
def init_dtype(self):
self.dtype = np.float64
def initTestCase(self):
self.init_dtype()
self.real_op_type = np.random.choice(['triu', 'tril'])
self.diagonal = None
self.X = np.arange(1, 101, dtype=self.dtype).reshape([10, -1])
class TrilTriuOpDefaultTestFP16(TrilTriuOpDefaultTest):
def init_dtype(self):
self.dtype = np.float16
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
'not supported bf16',
)
class TrilTriuOpDefaultTestBF16(TrilTriuOpDefaultTest):
def init_dtype(self):
self.dtype = np.uint16
def setUp(self):
super().setUp()
self.outputs["Out"] = convert_float_to_uint16(self.outputs["Out"])
self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
def initTestCase(self): def initTestCase(self):
self.init_dtype()
self.real_op_type = np.random.choice(['triu', 'tril']) self.real_op_type = np.random.choice(['triu', 'tril'])
self.diagonal = None self.diagonal = None
self.X = np.arange(1, 101, dtype="float64").reshape([10, -1]) self.X = np.arange(1, 101, dtype="float32").reshape([10, -1])
def test_check_output(self):
self.check_output_with_place(core.CUDAPlace(0))
def test_check_grad_normal(self):
self.check_grad_with_place(
core.CUDAPlace(0), ['X'], 'Out', numeric_grad_delta=0.05
)
def case_generator(op_type, Xshape, diagonal, expected): def case_generator(op_type, Xshape, diagonal, expected, dtype):
""" """
Generate testcases with the params shape of X, diagonal and op_type. Generate testcases with the params shape of X, diagonal and op_type.
If arg`expercted` is 'success', it will register an Optest case and expect to pass. If arg`expercted` is 'success', it will register an Optest case and expect to pass.
Otherwise, it will register an API case and check the expect failure. Otherwise, it will register an API case and check the expect failure.
""" """
cls_name = "{}_{}_shape_{}_diag_{}".format( cls_name = "{}_{}_shape_{}_diag_{}_dtype_{}".format(
expected, op_type, Xshape, diagonal expected, op_type, Xshape, diagonal, dtype
) )
errmsg = { errmsg = {
"diagonal: TypeError": "diagonal in {} must be a python Int".format( "diagonal: TypeError": "diagonal in {} must be a python Int".format(
...@@ -93,7 +132,34 @@ def case_generator(op_type, Xshape, diagonal, expected): ...@@ -93,7 +132,34 @@ def case_generator(op_type, Xshape, diagonal, expected):
self.diagonal = diagonal self.diagonal = diagonal
self.X = np.random.random(Xshape).astype("float64") self.X = np.random.random(Xshape).astype("float64")
CLASS = locals()['SuccessCase' if expected == "success" else 'FailureCase'] class SuccessCaseFP16(TrilTriuOpDefaultTestFP16):
def initTestCase(self):
self.init_dtype()
self.real_op_type = op_type
self.diagonal = diagonal
self.X = np.random.random(Xshape).astype("float16")
class SuccessCaseBF16(TrilTriuOpDefaultTestBF16):
def initTestCase(self):
self.init_dtype()
self.real_op_type = op_type
self.diagonal = diagonal
self.X = np.random.random(Xshape).astype("float32")
if dtype == "float64":
CLASS = locals()[
'SuccessCase' if expected == "success" else 'FailureCase'
]
elif dtype == "float16":
CLASS = locals()[
'SuccessCaseFP16' if expected == "success" else 'FailureCase'
]
elif dtype == "bfloat16":
CLASS = locals()[
'SuccessCaseBF16' if expected == "success" else 'FailureCase'
]
else:
raise ValueError(f"Not supported dtype {dtype}")
CLASS.__name__ = cls_name CLASS.__name__ = cls_name
globals()[cls_name] = CLASS globals()[cls_name] = CLASS
...@@ -119,11 +185,14 @@ cases = { ...@@ -119,11 +185,14 @@ cases = {
(2020,): [None], (2020,): [None],
}, },
} }
for _op_type in ['tril', 'triu']: for dtype in ["float64", "float16", "bfloat16"]:
for _op_type in ['tril', 'triu']:
for _expected, _params in cases.items(): for _expected, _params in cases.items():
for _Xshape, _diaglist in _params.items(): for _Xshape, _diaglist in _params.items():
[ [
case_generator(_op_type, _Xshape, _diagonal, _expected) case_generator(
_op_type, _Xshape, _diagonal, _expected, dtype
)
for _diagonal in _diaglist for _diagonal in _diaglist
] ]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册