Unverified commit e1674e8b authored by Charles-hit, committed by GitHub

add prim test for some ops (#51749)

* add tanh and cast prim test

* fix tanh test

* fix 0-d test

* add sqrt fp16 prim test

* add public_python_api in prim test

* fix test_squeeze2_op

* add tanh prim test

* add dropout prim test

* [Dy2St]Fix clone for test state problem

* clean code

* modify test_cumsum_op

* modify test_cumsum_op

* fix dropout test

* add dropout in cmake

* fix dropout test

---------
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Parent 20befdef
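The diffs below all follow the same prim-test pattern: an OpTest case sets prim_op_type and public_python_api in setUp(), then passes check_prim=True to check_grad(). The following is a minimal illustrative sketch of that pattern, not taken from this commit; it assumes Paddle's OpTest base class is importable as op_test (the exact module path differs between Paddle versions).

# Illustrative sketch only: the prim-test pattern used throughout this commit.
import unittest

import numpy as np
import paddle
from op_test import OpTest  # assumed import path; varies by Paddle version


class TestTanhPrimExample(OpTest):
    def setUp(self):
        self.op_type = "tanh"
        self.prim_op_type = "prim"            # also check the composite (prim) rules
        self.python_api = paddle.tanh         # dygraph API used by the checker
        self.public_python_api = paddle.tanh  # public API used for the prim check
        x = np.random.uniform(0.1, 1, [12, 17]).astype("float32")
        self.inputs = {'X': x}
        self.outputs = {'Out': np.tanh(x)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # check_prim=True additionally validates the gradient produced by the
        # decomposed primitive backward against the reference gradient.
        self.check_grad(['X'], 'Out', check_prim=True)


if __name__ == '__main__':
    unittest.main()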
@@ -1212,7 +1212,9 @@ set(TEST_CINN_OPS
     test_mean_op
     test_unsqueeze2_op
     test_meshgrid_op
-    test_gather_op)
+    test_gather_op
+    test_cast_op
+    test_dropout_op)

 foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
   if(WITH_CINN)
@@ -469,9 +469,12 @@ class TestLogSigmoidAPI(unittest.TestCase):
 class TestTanh(TestActivation, TestParameter):
     def setUp(self):
         self.op_type = "tanh"
+        self.prim_op_type = "prim"
         self.python_api = paddle.tanh
+        self.public_python_api = paddle.tanh
         self.init_dtype()
         self.init_shape()
+        self.if_enable_cinn()

         np.random.seed(1024)
         x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
@@ -483,7 +486,7 @@ class TestTanh(TestActivation, TestParameter):
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_prim=True)

     def init_dtype(self):
         # TODO If dtype is float64, the output (Out) has diff at CPUPlace
@@ -491,11 +494,17 @@ class TestTanh(TestActivation, TestParameter):
         # for now.
         self.dtype = np.float32

+    def if_enable_cinn(self):
+        pass
+

 class TestTanh_ZeroDim(TestTanh):
     def init_shape(self):
         self.shape = []

+    def if_enable_cinn(self):
+        self.enable_cinn = False
+

 class TestTanhAPI(unittest.TestCase):
     # test paddle.tanh, paddle.nn.tanh, paddle.nn.functional.tanh
@@ -601,7 +610,7 @@ class TestAtan(TestActivation, TestParameter):
         self.assertEqual(z, z_expected)


-class TestAtan_ZeroDim(TestTanh):
+class TestAtan_ZeroDim(TestAtan):
     def init_shape(self):
         self.shape = []
@@ -3910,7 +3919,7 @@ create_test_act_fp16_class(TestTanh)
 create_test_act_fp16_class(TestTanhshrink)
 create_test_act_fp16_class(TestHardShrink)
 create_test_act_fp16_class(TestSoftshrink)
-create_test_act_fp16_class(TestSqrt)
+create_test_act_fp16_class(TestSqrt, check_prim=True)
 create_test_act_fp16_class(TestSqrtComp, check_prim=True)
 create_test_act_fp16_class(TestAbs, check_prim=True)
 create_test_act_fp16_class(TestCeil, grad_check=False)
@@ -28,33 +28,8 @@ from paddle import fluid
 from paddle.fluid import Program, core, program_guard


-def convert_to_dtype_(dtype):
-    if dtype == 5:
-        return core.VarDesc.VarType.FP32
-    elif dtype == 6:
-        return core.VarDesc.VarType.FP64
-    elif dtype == 4:
-        return core.VarDesc.VarType.FP16
-    elif dtype == 2:
-        return core.VarDesc.VarType.INT32
-    elif dtype == 1:
-        return core.VarDesc.VarType.INT16
-    elif dtype == 3:
-        return core.VarDesc.VarType.INT64
-    elif dtype == 0:
-        return core.VarDesc.VarType.BOOL
-    elif dtype == 22:
-        return core.VarDesc.VarType.BF16
-    elif dtype == 20:
-        return core.VarDesc.VarType.UINT8
-    elif dtype == 21:
-        return core.VarDesc.VarType.INT8
-    elif dtype == np.complex64:
-        raise ValueError("Not supported dtype %s" % dtype)
-
-
 def cast_wrapper(x, out_dtype=None):
-    return paddle.tensor.cast(x, convert_to_dtype_(out_dtype))
+    return paddle.cast(x, paddle.dtype(out_dtype))


 class TestCastOpFp32ToFp64(OpTest):
@@ -67,13 +42,15 @@ class TestCastOpFp32ToFp64(OpTest):
             'out_dtype': int(core.VarDesc.VarType.FP64),
         }
         self.op_type = 'cast'
+        self.prim_op_type = "prim"
         self.python_api = cast_wrapper
+        self.public_python_api = cast_wrapper

     def test_check_output(self):
         self.check_output()

     def test_grad(self):
-        self.check_grad(['X'], ['Out'])
+        self.check_grad(['X'], ['Out'], check_prim=True)


 class TestCastOpFp16ToFp32(OpTest):
@@ -86,12 +63,16 @@ class TestCastOpFp16ToFp32(OpTest):
             'out_dtype': int(core.VarDesc.VarType.FP32),
         }
         self.op_type = 'cast'
-        self.__class__.no_need_check_grad = True
+        self.prim_op_type = "prim"
         self.python_api = cast_wrapper
+        self.public_python_api = cast_wrapper

     def test_check_output(self):
         self.check_output(atol=1e-3)

+    def test_grad(self):
+        self.check_grad(['X'], ['Out'], check_prim=True, only_check_prim=True)
+

 class TestCastOpFp32ToFp16(OpTest):
     def setUp(self):
@@ -103,12 +84,16 @@ class TestCastOpFp32ToFp16(OpTest):
             'out_dtype': int(core.VarDesc.VarType.FP16),
         }
         self.op_type = 'cast'
-        self.__class__.no_need_check_grad = True
+        self.prim_op_type = "prim"
         self.python_api = cast_wrapper
+        self.public_python_api = cast_wrapper

     def test_check_output(self):
         self.check_output(atol=1e-3)

+    def test_grad(self):
+        self.check_grad(['X'], ['Out'], check_prim=True, only_check_prim=True)
+

 class TestCastOpBf16ToFp32(OpTest):
     def setUp(self):
@@ -120,12 +105,17 @@ class TestCastOpBf16ToFp32(OpTest):
             'out_dtype': int(core.VarDesc.VarType.FP32),
         }
         self.op_type = 'cast'
-        self.__class__.no_need_check_grad = True
+        self.prim_op_type = "prim"
         self.python_api = cast_wrapper
+        self.public_python_api = cast_wrapper
+        self.enable_cinn = False

     def test_check_output(self):
         self.check_output()

+    def test_grad(self):
+        self.check_grad(['X'], ['Out'], check_prim=True, only_check_prim=True)
+

 class TestCastOpFp32ToBf16(OpTest):
     def setUp(self):
@@ -137,12 +127,17 @@ class TestCastOpFp32ToBf16(OpTest):
             'out_dtype': int(core.VarDesc.VarType.BF16),
         }
         self.op_type = 'cast'
-        self.__class__.no_need_check_grad = True
+        self.prim_op_type = "prim"
         self.python_api = cast_wrapper
+        self.public_python_api = cast_wrapper
+        self.enable_cinn = False

     def test_check_output(self):
         self.check_output()

+    def test_grad(self):
+        self.check_grad(['X'], ['Out'], check_prim=True, only_check_prim=True)
+

 class TestCastOpError(unittest.TestCase):
     def test_errors(self):
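A hedged aside on the cast changes above: the hand-rolled convert_to_dtype_ mapping could be dropped because the simplified cast_wrapper passes the integer VarDesc.VarType code straight to paddle.dtype(). A minimal sanity check under that assumption (the script below is illustrative and not part of the commit):

# Assumes paddle.dtype() accepts integer VarDesc.VarType codes, as the new
# cast_wrapper in the diff implies.
import numpy as np
import paddle
from paddle.fluid import core


def cast_wrapper(x, out_dtype=None):
    # Same body as the simplified wrapper in the diff above.
    return paddle.cast(x, paddle.dtype(out_dtype))


if __name__ == '__main__':
    x = paddle.to_tensor(np.ones([2, 3], dtype='float32'))
    out = cast_wrapper(x, int(core.VarDesc.VarType.FP64))
    print(out.dtype)  # expected: paddle.float64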