From c9ca7c3525077425096725924be05f1fa2f93f84 Mon Sep 17 00:00:00 2001 From: Kang Zhao Date: Wed, 15 Mar 2023 16:01:24 +0800 Subject: [PATCH] feat: add rsqrt composite rule (#51432) * feat: add relu composite rule * feat: add relu composite rule, maximum op * feat: add relu composite rule, maximum op * feat: add relu composite rule, polish comments * feat: add relu composite rule, polish comments * feat: add relu composite rule, add python api of relu * feat: add relu composite rule, commit hook * fix: maximum type error & ban cinn test * fix: maximum input sequence bugs * resolve conflicts * fix: code style bugs * add: relu fp16 test * feat: add rsqrt composite rule * feat: add rsqrt composite rule * resolve conflicts of composite rule * fix: delete check eager --- .../fluid/tests/unittests/test_activation_op.py | 12 +++++++++++- python/paddle/incubate/autograd/composite_rules.py | 8 ++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index 50101ba94a0..233139c225b 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -1238,6 +1238,7 @@ class TestSqrtBF16(OpTest): class TestRsqrt(TestActivation): def setUp(self): self.op_type = "rsqrt" + self.prim_op_type = "comp" self.python_api = paddle.rsqrt self.init_dtype() self.init_shape() @@ -1248,14 +1249,23 @@ class TestRsqrt(TestActivation): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + self.enable_cinn = True def init_shape(self): self.shape = [10, 12] + def test_check_output(self): + self.check_output(check_prim=True) + def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', max_relative_error=0.0005) + self.check_grad( + ['X'], + 'Out', + max_relative_error=0.0005, + check_prim=True, + ) ''' diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py index 28c87609ae1..2300cbccfa4 100644 --- a/python/paddle/incubate/autograd/composite_rules.py +++ b/python/paddle/incubate/autograd/composite_rules.py @@ -457,3 +457,11 @@ def unsqueeze_composite(x, axis): ) out = reshape(x, x_shape) return [out, None] + + +@REGISTER_COMPOSITE('rsqrt') +def rsqrt_composite(x): + """define composite rule of op rsqrt.""" + # rsqrt(x) = x^(-0.5) + y = full(x.shape, -0.5, x.dtype) + return pow(x, y) -- GitLab