Unverified commit 5d70ba6d authored by zxcd, committed by GitHub

add silu composite rule (#50838)

* add silu composite rule

* fix code style.

* add silu fp16 unit test.
Parent: bbf2bc2b
@@ -324,6 +324,9 @@ class TestSigmoidBF16_ZeroDim(TestSigmoidBF16):
 class TestSilu(TestActivation):
     def setUp(self):
         self.op_type = "silu"
+        self.prim_op_type = "comp"
+        self.enable_cinn = False
+        self.python_api = paddle.nn.functional.silu
         self.init_dtype()
         self.init_shape()
@@ -340,7 +343,7 @@ class TestSilu(TestActivation):
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_prim=True)


 class TestSilu_ZeroDim(TestSilu):
@@ -348,6 +351,36 @@ class TestSilu_ZeroDim(TestSilu):
         self.shape = []


+class TestSiluFP16(TestActivation):
+    def setUp(self):
+        self.op_type = "silu"
+        self.prim_op_type = "comp"
+        self.enable_cinn = False
+        self.only_prim = True
+        self.python_api = paddle.nn.functional.silu
+        self.init_dtype()
+        self.init_shape()
+
+        np.random.seed(1024)
+        x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
+        out = x / (np.exp(-x) + 1)
+
+        self.inputs = {'X': x}
+        self.outputs = {'Out': out}
+
+    def init_dtype(self):
+        self.dtype = np.float16
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', check_prim=True)
+
+    def test_check_output(self):
+        check_eager = False
+        if hasattr(self, 'check_eager'):
+            check_eager = self.check_eager
+        self.check_output(check_eager=check_eager, check_prim=True)
+
+
 class TestSiluAPI(unittest.TestCase):
     # test paddle.nn.Silu, paddle.nn.functional.silu
     def setUp(self):
@@ -3644,6 +3677,7 @@ create_test_act_fp16_class(TestActivation)
 create_test_act_fp16_class(TestExpm1)
 create_test_act_fp16_class(TestSigmoid)
 create_test_act_fp16_class(TestSilu)
+create_test_act_fp16_class(TestSiluFP16)
 create_test_act_fp16_class(TestLogSigmoid)
 create_test_act_fp16_class(TestTanh)
 create_test_act_fp16_class(TestTanhshrink)
......
@@ -258,3 +258,14 @@ def bernoulli(shape, dtype, p, seed=0):
         ),
         dtype,
     )
+
+
+@REGISTER_COMPOSITE('silu')
+def silu_composite(x):
+    """
+    define composite rule of op silu
+    res = x / (1 + exp(-x))
+    """
+    sum_temp = 1 + exp(-x)
+    res = x / sum_temp
+    return res
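For reference, the composite rule registered above is the standard SiLU identity silu(x) = x * sigmoid(x) = x / (1 + exp(-x)), which is also the formula the new FP16 test uses for its expected output. Below is a minimal NumPy-only sketch (not part of this patch; the helper names silu_reference and silu_decomposed are ours) that checks the decomposed form against the plain definition:

import numpy as np

def silu_reference(x):
    # silu(x) = x * sigmoid(x)
    return x * (1.0 / (1.0 + np.exp(-x)))

def silu_decomposed(x):
    # same decomposition as silu_composite: x / (1 + exp(-x))
    return x / (1.0 + np.exp(-x))

x = np.random.uniform(-1, 1, size=(3, 4)).astype('float32')
np.testing.assert_allclose(
    silu_reference(x), silu_decomposed(x), rtol=1e-6, atol=1e-6
)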