Unverified commit 079f41c8, authored by Kang Zhao, committed by GitHub

feat: add relu composite (#50819)

* feat: add relu composite rule

* feat: add relu composite rule, maximum op

* feat: add relu composite rule, maximum op

* feat: add relu composite rule, polish comments

* feat: add relu composite rule, polish comments

* feat: add relu composite rule, add python api of relu

* feat: add relu composite rule, commit hook

* fix: maximum type error & ban cinn test

* fix: maximum input sequence bugs

* resolve conflicts

* fix: code style bugs

* add: relu fp16 test
Parent 47846730
@@ -1711,8 +1711,11 @@ class TestRound_ZeroDim(TestRound):
 class TestRelu(TestActivation):
     def setUp(self):
         self.op_type = "relu"
+        self.python_api = paddle.nn.functional.relu
+        self.prim_op_type = "comp"
         self.init_dtype()
         self.init_shape()
+        self.skip_cinn()
 
         np.random.seed(1024)
         if self.dtype == np.uint16:
@@ -1733,13 +1736,22 @@ class TestRelu(TestActivation):
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_prim=True)
+
+    def test_check_output(self):
+        self.check_output(check_prim=True)
+
+    def skip_cinn(self):
+        self.enable_cinn = False
 
 
 class TestRelu_ZeroDim(TestRelu):
     def init_shape(self):
         self.shape = []
 
+    def skip_cinn(self):
+        self.enable_cinn = False
+
 
 class TestReluAPI(unittest.TestCase):
     # test paddle.nn.ReLU, paddle.nn.functional.relu
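For context, check_prim=True makes these checks also run against the op's composite (primitive) lowering. Below is a rough eager-mode analogue of the gradient equivalence being asserted, as an illustrative sketch only, not the actual OpTest machinery; the shape and tolerance are arbitrary:

import numpy as np
import paddle

data = np.random.randn(4, 8).astype("float32")

# Gradient through the fused relu kernel.
x_ref = paddle.to_tensor(data, stop_gradient=False)
paddle.nn.functional.relu(x_ref).sum().backward()

# Gradient through the composite form relu(x) = maximum(x, 0).
x_comp = paddle.to_tensor(data, stop_gradient=False)
paddle.maximum(x_comp, paddle.zeros_like(x_comp)).sum().backward()

# Subgradient conventions can differ only at exact zeros, which
# random float data will essentially never produce.
np.testing.assert_allclose(x_comp.grad.numpy(), x_ref.grad.numpy(), rtol=1e-6)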
@@ -3826,7 +3838,7 @@ create_test_act_fp16_class(TestAcosh, grad_atol=0.85)
 create_test_act_fp16_class(TestAsinh, grad_atol=0.85)
 create_test_act_fp16_class(TestAtanh, grad_atol=0.85)
 create_test_act_fp16_class(TestRound, grad_check=False)
-create_test_act_fp16_class(TestRelu)
+create_test_act_fp16_class(TestRelu, check_prim=True)
 create_test_act_fp16_class(TestGelu)
 create_test_act_fp16_class(TestBRelu)
 create_test_act_fp16_class(TestRelu6)
@@ -330,3 +330,10 @@ def fill_any_like(x, fill_value, dtype, place=None):
     dtype = dtypes.dtype(dtype)
     val = full(x.shape, fill_value, dtype)
     return val
+
+
+@REGISTER_COMPOSITE('relu')
+def relu_composite(x):
+    """define composite rule of op relu."""
+    # relu(x) = max(x, 0)
+    return maximum(x, zeros_like(x))
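The rule itself is easy to sanity-check against the public eager API; a minimal sketch, separate from this patch:

import numpy as np
import paddle

x = paddle.to_tensor(np.random.randn(3, 5).astype("float32"))

# relu(x) = max(x, 0), elementwise.
composite = paddle.maximum(x, paddle.zeros_like(x))
reference = paddle.nn.functional.relu(x)

np.testing.assert_allclose(composite.numpy(), reference.numpy())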
@@ -60,6 +60,7 @@ from paddle.tensor import tanh # noqa: F401
 from paddle.tensor import uniform # noqa: F401
 from paddle.tensor import zeros # noqa: F401
 from paddle.tensor.creation import assign # noqa: F401
+from paddle.tensor.creation import zeros_like # noqa: F401
 from paddle.tensor.manipulation import cast # noqa: F401
 from paddle.tensor.math import maximum # noqa: F401
 from paddle.tensor.math import minimum # noqa: F401
@@ -89,9 +90,9 @@ math_op = [
     'logcumsumexp',
     'logit',
     'max',
+    'maximum',
     'min',
     'minimum',
-    'maximum'
 ]
 
 trigonometric_op = [
@@ -126,5 +127,6 @@ others = [
     'concat',
     'uniform',
     'greater_equal',
+    'zeros_like',
 ]
"""
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册