未验证 提交 c9ca7c35 编写于 作者: K Kang Zhao 提交者: GitHub

feat: add rsqrt composite rule (#51432)

* feat: add relu composite rule

* feat: add relu composite rule, maximum op

* feat: add relu composite rule, maximum op

* feat: add relu composite rule, polish comments

* feat: add relu composite rule, polish comments

* feat: add relu composite rule, add python api of relu

* feat: add relu composite rule, commit hook

* fix: maximum type error & ban cinn test

* fix: maximum input sequence bugs

* resolve conflicts

* fix: code style bugs

* add: relu fp16 test

* feat: add rsqrt composite rule

* feat: add rsqrt composite rule

* resolve conflicts of composite rule

* fix: delete check eager
上级 0e492e43
......@@ -1238,6 +1238,7 @@ class TestSqrtBF16(OpTest):
class TestRsqrt(TestActivation):
def setUp(self):
self.op_type = "rsqrt"
self.prim_op_type = "comp"
self.python_api = paddle.rsqrt
self.init_dtype()
self.init_shape()
......@@ -1248,14 +1249,23 @@ class TestRsqrt(TestActivation):
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.enable_cinn = True
def init_shape(self):
self.shape = [10, 12]
def test_check_output(self):
self.check_output(check_prim=True)
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', max_relative_error=0.0005)
self.check_grad(
['X'],
'Out',
max_relative_error=0.0005,
check_prim=True,
)
'''
......
......@@ -457,3 +457,11 @@ def unsqueeze_composite(x, axis):
)
out = reshape(x, x_shape)
return [out, None]
@REGISTER_COMPOSITE('rsqrt')
def rsqrt_composite(x):
"""define composite rule of op rsqrt."""
# rsqrt(x) = x^(-0.5)
y = full(x.shape, -0.5, x.dtype)
return pow(x, y)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册