未验证 提交 e1674e8b 编写于 作者: C Charles-hit 提交者: GitHub

add prim test for some ops (#51749)

* add tanh and cast prim test

* fix tanh test

* fix 0-d test

* add sqrt fp16 prim test

* add public_python_api in prim test

* fix test_squeeze2_op

* add tanh prim test

* add dropout prim test

* [Dy2St]Fix clone for test state problem

* clean code

* modify test_cumsum_op

* modify test_cumsum_op

* fix dropout test

* add dropout in cmake

* fix dropout test

---------
Co-authored-by: NAurelius84 <zhangliujie@baidu.com>
上级 20befdef
...@@ -1212,7 +1212,9 @@ set(TEST_CINN_OPS ...@@ -1212,7 +1212,9 @@ set(TEST_CINN_OPS
test_mean_op test_mean_op
test_unsqueeze2_op test_unsqueeze2_op
test_meshgrid_op test_meshgrid_op
test_gather_op) test_gather_op
test_cast_op
test_dropout_op)
foreach(TEST_CINN_OPS ${TEST_CINN_OPS}) foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
if(WITH_CINN) if(WITH_CINN)
......
...@@ -469,9 +469,12 @@ class TestLogSigmoidAPI(unittest.TestCase): ...@@ -469,9 +469,12 @@ class TestLogSigmoidAPI(unittest.TestCase):
class TestTanh(TestActivation, TestParameter): class TestTanh(TestActivation, TestParameter):
def setUp(self): def setUp(self):
self.op_type = "tanh" self.op_type = "tanh"
self.prim_op_type = "prim"
self.python_api = paddle.tanh self.python_api = paddle.tanh
self.public_python_api = paddle.tanh
self.init_dtype() self.init_dtype()
self.init_shape() self.init_shape()
self.if_enable_cinn()
np.random.seed(1024) np.random.seed(1024)
x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype) x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
...@@ -483,7 +486,7 @@ class TestTanh(TestActivation, TestParameter): ...@@ -483,7 +486,7 @@ class TestTanh(TestActivation, TestParameter):
def test_check_grad(self): def test_check_grad(self):
if self.dtype == np.float16: if self.dtype == np.float16:
return return
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out', check_prim=True)
def init_dtype(self): def init_dtype(self):
# TODO If dtype is float64, the output (Out) has diff at CPUPlace # TODO If dtype is float64, the output (Out) has diff at CPUPlace
...@@ -491,11 +494,17 @@ class TestTanh(TestActivation, TestParameter): ...@@ -491,11 +494,17 @@ class TestTanh(TestActivation, TestParameter):
# for now. # for now.
self.dtype = np.float32 self.dtype = np.float32
def if_enable_cinn(self):
pass
class TestTanh_ZeroDim(TestTanh): class TestTanh_ZeroDim(TestTanh):
def init_shape(self): def init_shape(self):
self.shape = [] self.shape = []
def if_enable_cinn(self):
self.enable_cinn = False
class TestTanhAPI(unittest.TestCase): class TestTanhAPI(unittest.TestCase):
# test paddle.tanh, paddle.nn.tanh, paddle.nn.functional.tanh # test paddle.tanh, paddle.nn.tanh, paddle.nn.functional.tanh
...@@ -601,7 +610,7 @@ class TestAtan(TestActivation, TestParameter): ...@@ -601,7 +610,7 @@ class TestAtan(TestActivation, TestParameter):
self.assertEqual(z, z_expected) self.assertEqual(z, z_expected)
class TestAtan_ZeroDim(TestTanh): class TestAtan_ZeroDim(TestAtan):
def init_shape(self): def init_shape(self):
self.shape = [] self.shape = []
...@@ -3910,7 +3919,7 @@ create_test_act_fp16_class(TestTanh) ...@@ -3910,7 +3919,7 @@ create_test_act_fp16_class(TestTanh)
create_test_act_fp16_class(TestTanhshrink) create_test_act_fp16_class(TestTanhshrink)
create_test_act_fp16_class(TestHardShrink) create_test_act_fp16_class(TestHardShrink)
create_test_act_fp16_class(TestSoftshrink) create_test_act_fp16_class(TestSoftshrink)
create_test_act_fp16_class(TestSqrt) create_test_act_fp16_class(TestSqrt, check_prim=True)
create_test_act_fp16_class(TestSqrtComp, check_prim=True) create_test_act_fp16_class(TestSqrtComp, check_prim=True)
create_test_act_fp16_class(TestAbs, check_prim=True) create_test_act_fp16_class(TestAbs, check_prim=True)
create_test_act_fp16_class(TestCeil, grad_check=False) create_test_act_fp16_class(TestCeil, grad_check=False)
......
...@@ -28,33 +28,8 @@ from paddle import fluid ...@@ -28,33 +28,8 @@ from paddle import fluid
from paddle.fluid import Program, core, program_guard from paddle.fluid import Program, core, program_guard
def convert_to_dtype_(dtype):
if dtype == 5:
return core.VarDesc.VarType.FP32
elif dtype == 6:
return core.VarDesc.VarType.FP64
elif dtype == 4:
return core.VarDesc.VarType.FP16
elif dtype == 2:
return core.VarDesc.VarType.INT32
elif dtype == 1:
return core.VarDesc.VarType.INT16
elif dtype == 3:
return core.VarDesc.VarType.INT64
elif dtype == 0:
return core.VarDesc.VarType.BOOL
elif dtype == 22:
return core.VarDesc.VarType.BF16
elif dtype == 20:
return core.VarDesc.VarType.UINT8
elif dtype == 21:
return core.VarDesc.VarType.INT8
elif dtype == np.complex64:
raise ValueError("Not supported dtype %s" % dtype)
def cast_wrapper(x, out_dtype=None): def cast_wrapper(x, out_dtype=None):
return paddle.tensor.cast(x, convert_to_dtype_(out_dtype)) return paddle.cast(x, paddle.dtype(out_dtype))
class TestCastOpFp32ToFp64(OpTest): class TestCastOpFp32ToFp64(OpTest):
...@@ -67,13 +42,15 @@ class TestCastOpFp32ToFp64(OpTest): ...@@ -67,13 +42,15 @@ class TestCastOpFp32ToFp64(OpTest):
'out_dtype': int(core.VarDesc.VarType.FP64), 'out_dtype': int(core.VarDesc.VarType.FP64),
} }
self.op_type = 'cast' self.op_type = 'cast'
self.prim_op_type = "prim"
self.python_api = cast_wrapper self.python_api = cast_wrapper
self.public_python_api = cast_wrapper
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_grad(self): def test_grad(self):
self.check_grad(['X'], ['Out']) self.check_grad(['X'], ['Out'], check_prim=True)
class TestCastOpFp16ToFp32(OpTest): class TestCastOpFp16ToFp32(OpTest):
...@@ -86,12 +63,16 @@ class TestCastOpFp16ToFp32(OpTest): ...@@ -86,12 +63,16 @@ class TestCastOpFp16ToFp32(OpTest):
'out_dtype': int(core.VarDesc.VarType.FP32), 'out_dtype': int(core.VarDesc.VarType.FP32),
} }
self.op_type = 'cast' self.op_type = 'cast'
self.__class__.no_need_check_grad = True self.prim_op_type = "prim"
self.python_api = cast_wrapper self.python_api = cast_wrapper
self.public_python_api = cast_wrapper
def test_check_output(self): def test_check_output(self):
self.check_output(atol=1e-3) self.check_output(atol=1e-3)
def test_grad(self):
self.check_grad(['X'], ['Out'], check_prim=True, only_check_prim=True)
class TestCastOpFp32ToFp16(OpTest): class TestCastOpFp32ToFp16(OpTest):
def setUp(self): def setUp(self):
...@@ -103,12 +84,16 @@ class TestCastOpFp32ToFp16(OpTest): ...@@ -103,12 +84,16 @@ class TestCastOpFp32ToFp16(OpTest):
'out_dtype': int(core.VarDesc.VarType.FP16), 'out_dtype': int(core.VarDesc.VarType.FP16),
} }
self.op_type = 'cast' self.op_type = 'cast'
self.__class__.no_need_check_grad = True self.prim_op_type = "prim"
self.python_api = cast_wrapper self.python_api = cast_wrapper
self.public_python_api = cast_wrapper
def test_check_output(self): def test_check_output(self):
self.check_output(atol=1e-3) self.check_output(atol=1e-3)
def test_grad(self):
self.check_grad(['X'], ['Out'], check_prim=True, only_check_prim=True)
class TestCastOpBf16ToFp32(OpTest): class TestCastOpBf16ToFp32(OpTest):
def setUp(self): def setUp(self):
...@@ -120,12 +105,17 @@ class TestCastOpBf16ToFp32(OpTest): ...@@ -120,12 +105,17 @@ class TestCastOpBf16ToFp32(OpTest):
'out_dtype': int(core.VarDesc.VarType.FP32), 'out_dtype': int(core.VarDesc.VarType.FP32),
} }
self.op_type = 'cast' self.op_type = 'cast'
self.__class__.no_need_check_grad = True self.prim_op_type = "prim"
self.python_api = cast_wrapper self.python_api = cast_wrapper
self.public_python_api = cast_wrapper
self.enable_cinn = False
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_grad(self):
self.check_grad(['X'], ['Out'], check_prim=True, only_check_prim=True)
class TestCastOpFp32ToBf16(OpTest): class TestCastOpFp32ToBf16(OpTest):
def setUp(self): def setUp(self):
...@@ -137,12 +127,17 @@ class TestCastOpFp32ToBf16(OpTest): ...@@ -137,12 +127,17 @@ class TestCastOpFp32ToBf16(OpTest):
'out_dtype': int(core.VarDesc.VarType.BF16), 'out_dtype': int(core.VarDesc.VarType.BF16),
} }
self.op_type = 'cast' self.op_type = 'cast'
self.__class__.no_need_check_grad = True self.prim_op_type = "prim"
self.python_api = cast_wrapper self.python_api = cast_wrapper
self.public_python_api = cast_wrapper
self.enable_cinn = False
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_grad(self):
self.check_grad(['X'], ['Out'], check_prim=True, only_check_prim=True)
class TestCastOpError(unittest.TestCase): class TestCastOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册