未验证 提交 ec008a71 编写于 作者: R Roc 提交者: GitHub

[AMP OP & Test] Tril & Triu (#52411)

上级 648f58aa
......@@ -14,10 +14,11 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle import fluid, tensor
from paddle.fluid import core
from paddle.fluid.framework import Program, program_guard
......@@ -49,20 +50,58 @@ class TrilTriuOpDefaultTest(OpTest):
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out')
def init_dtype(self):
self.dtype = np.float64
def initTestCase(self):
self.init_dtype()
self.real_op_type = np.random.choice(['triu', 'tril'])
self.diagonal = None
self.X = np.arange(1, 101, dtype=self.dtype).reshape([10, -1])
class TrilTriuOpDefaultTestFP16(TrilTriuOpDefaultTest):
def init_dtype(self):
self.dtype = np.float16
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
'not supported bf16',
)
class TrilTriuOpDefaultTestBF16(TrilTriuOpDefaultTest):
def init_dtype(self):
self.dtype = np.uint16
def setUp(self):
super().setUp()
self.outputs["Out"] = convert_float_to_uint16(self.outputs["Out"])
self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
def initTestCase(self):
self.init_dtype()
self.real_op_type = np.random.choice(['triu', 'tril'])
self.diagonal = None
self.X = np.arange(1, 101, dtype="float64").reshape([10, -1])
self.X = np.arange(1, 101, dtype="float32").reshape([10, -1])
def test_check_output(self):
self.check_output_with_place(core.CUDAPlace(0))
def test_check_grad_normal(self):
self.check_grad_with_place(
core.CUDAPlace(0), ['X'], 'Out', numeric_grad_delta=0.05
)
def case_generator(op_type, Xshape, diagonal, expected):
def case_generator(op_type, Xshape, diagonal, expected, dtype):
"""
Generate testcases with the params shape of X, diagonal and op_type.
If arg`expercted` is 'success', it will register an Optest case and expect to pass.
Otherwise, it will register an API case and check the expect failure.
"""
cls_name = "{}_{}_shape_{}_diag_{}".format(
expected, op_type, Xshape, diagonal
cls_name = "{}_{}_shape_{}_diag_{}_dtype_{}".format(
expected, op_type, Xshape, diagonal, dtype
)
errmsg = {
"diagonal: TypeError": "diagonal in {} must be a python Int".format(
......@@ -93,7 +132,34 @@ def case_generator(op_type, Xshape, diagonal, expected):
self.diagonal = diagonal
self.X = np.random.random(Xshape).astype("float64")
CLASS = locals()['SuccessCase' if expected == "success" else 'FailureCase']
class SuccessCaseFP16(TrilTriuOpDefaultTestFP16):
def initTestCase(self):
self.init_dtype()
self.real_op_type = op_type
self.diagonal = diagonal
self.X = np.random.random(Xshape).astype("float16")
class SuccessCaseBF16(TrilTriuOpDefaultTestBF16):
def initTestCase(self):
self.init_dtype()
self.real_op_type = op_type
self.diagonal = diagonal
self.X = np.random.random(Xshape).astype("float32")
if dtype == "float64":
CLASS = locals()[
'SuccessCase' if expected == "success" else 'FailureCase'
]
elif dtype == "float16":
CLASS = locals()[
'SuccessCaseFP16' if expected == "success" else 'FailureCase'
]
elif dtype == "bfloat16":
CLASS = locals()[
'SuccessCaseBF16' if expected == "success" else 'FailureCase'
]
else:
raise ValueError(f"Not supported dtype {dtype}")
CLASS.__name__ = cls_name
globals()[cls_name] = CLASS
......@@ -119,11 +185,14 @@ cases = {
(2020,): [None],
},
}
for _op_type in ['tril', 'triu']:
for dtype in ["float64", "float16", "bfloat16"]:
for _op_type in ['tril', 'triu']:
for _expected, _params in cases.items():
for _Xshape, _diaglist in _params.items():
[
case_generator(_op_type, _Xshape, _diagonal, _expected)
case_generator(
_op_type, _Xshape, _diagonal, _expected, dtype
)
for _diagonal in _diaglist
]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册