未验证 提交 079f41c8 编写于 作者: K Kang Zhao 提交者: GitHub

feat: add relu composite (#50819)

* feat: add relu composite rule

* feat: add relu composite rule, maximum op

* feat: add relu composite rule, maximum op

* feat: add relu composite rule, polish comments

* feat: add relu composite rule, polish comments

* feat: add relu composite rule, add python api of relu

* feat: add relu composite rule, commit hook

* fix: maximum type error & ban cinn test

* fix: maximum input sequence bugs

* resolve conflicts

* fix: code style bugs

* add: relu fp16 test
上级 47846730
...@@ -1711,8 +1711,11 @@ class TestRound_ZeroDim(TestRound): ...@@ -1711,8 +1711,11 @@ class TestRound_ZeroDim(TestRound):
class TestRelu(TestActivation): class TestRelu(TestActivation):
def setUp(self): def setUp(self):
self.op_type = "relu" self.op_type = "relu"
self.python_api = paddle.nn.functional.relu
self.prim_op_type = "comp"
self.init_dtype() self.init_dtype()
self.init_shape() self.init_shape()
self.skip_cinn()
np.random.seed(1024) np.random.seed(1024)
if self.dtype == np.uint16: if self.dtype == np.uint16:
...@@ -1733,13 +1736,22 @@ class TestRelu(TestActivation): ...@@ -1733,13 +1736,22 @@ class TestRelu(TestActivation):
def test_check_grad(self): def test_check_grad(self):
if self.dtype == np.float16: if self.dtype == np.float16:
return return
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out', check_prim=True)
def test_check_output(self):
self.check_output(check_prim=True)
def skip_cinn(self):
self.enable_cinn = False
class TestRelu_ZeroDim(TestRelu): class TestRelu_ZeroDim(TestRelu):
def init_shape(self): def init_shape(self):
self.shape = [] self.shape = []
def skip_cinn(self):
self.enable_cinn = False
class TestReluAPI(unittest.TestCase): class TestReluAPI(unittest.TestCase):
# test paddle.nn.ReLU, paddle.nn.functional.relu # test paddle.nn.ReLU, paddle.nn.functional.relu
...@@ -3826,7 +3838,7 @@ create_test_act_fp16_class(TestAcosh, grad_atol=0.85) ...@@ -3826,7 +3838,7 @@ create_test_act_fp16_class(TestAcosh, grad_atol=0.85)
create_test_act_fp16_class(TestAsinh, grad_atol=0.85) create_test_act_fp16_class(TestAsinh, grad_atol=0.85)
create_test_act_fp16_class(TestAtanh, grad_atol=0.85) create_test_act_fp16_class(TestAtanh, grad_atol=0.85)
create_test_act_fp16_class(TestRound, grad_check=False) create_test_act_fp16_class(TestRound, grad_check=False)
create_test_act_fp16_class(TestRelu) create_test_act_fp16_class(TestRelu, check_prim=True)
create_test_act_fp16_class(TestGelu) create_test_act_fp16_class(TestGelu)
create_test_act_fp16_class(TestBRelu) create_test_act_fp16_class(TestBRelu)
create_test_act_fp16_class(TestRelu6) create_test_act_fp16_class(TestRelu6)
......
...@@ -330,3 +330,10 @@ def fill_any_like(x, fill_value, dtype, place=None): ...@@ -330,3 +330,10 @@ def fill_any_like(x, fill_value, dtype, place=None):
dtype = dtypes.dtype(dtype) dtype = dtypes.dtype(dtype)
val = full(x.shape, fill_value, dtype) val = full(x.shape, fill_value, dtype)
return val return val
@REGISTER_COMPOSITE('relu')
def relu_composite(x):
"""define composite rule of op relu."""
# relu(x) = max(x, 0)
return maximum(x, zeros_like(x))
...@@ -60,6 +60,7 @@ from paddle.tensor import tanh # noqa: F401 ...@@ -60,6 +60,7 @@ from paddle.tensor import tanh # noqa: F401
from paddle.tensor import uniform # noqa: F401 from paddle.tensor import uniform # noqa: F401
from paddle.tensor import zeros # noqa: F401 from paddle.tensor import zeros # noqa: F401
from paddle.tensor.creation import assign # noqa: F401 from paddle.tensor.creation import assign # noqa: F401
from paddle.tensor.creation import zeros_like # noqa: F401
from paddle.tensor.manipulation import cast # noqa: F401 from paddle.tensor.manipulation import cast # noqa: F401
from paddle.tensor.math import maximum # noqa: F401 from paddle.tensor.math import maximum # noqa: F401
from paddle.tensor.math import minimum # noqa: F401 from paddle.tensor.math import minimum # noqa: F401
...@@ -89,9 +90,9 @@ math_op = [ ...@@ -89,9 +90,9 @@ math_op = [
'logcumsumexp', 'logcumsumexp',
'logit', 'logit',
'max', 'max',
'maximum',
'min', 'min',
'minimum', 'minimum',
'maximum'
] ]
trigonometric_op = [ trigonometric_op = [
...@@ -126,5 +127,6 @@ others = [ ...@@ -126,5 +127,6 @@ others = [
'concat', 'concat',
'uniform', 'uniform',
'greater_equal', 'greater_equal',
'zeros_like',
] ]
""" """
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册