From 2b7cbd1b6810c7d6a5623789da1e0c41ee1ea6f5 Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Thu, 18 May 2023 16:26:41 +0800 Subject: [PATCH] [CINN] Fix TestGelu unittest of CINN (#53859) * [CINN] Fix TestGelu unittest of CINN * pass if_enable_cinn --- .../fluid/tests/unittests/test_activation_op.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index c273cd49549..4e691ae3af9 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -2083,7 +2083,6 @@ class TestGeluApproximate(TestActivation): np.random.seed(1024) x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) out = gelu(x, approximate) - self.enable_cinn = False self.inputs = {'X': x} self.outputs = {'Out': out} @@ -2093,6 +2092,9 @@ class TestGeluApproximate(TestActivation): # cpu device, lower threshold to support 1e-8 for pass the unittest self.rev_comp_rtol = 1e-8 self.rev_comp_atol = 1e-8 + # Cumulative error occurs between comp and cinn, so that we also set cinn_rtol to 1e-8 as rev_comp_rtol = 1e-8 + self.cinn_rtol = 1e-8 + self.cinn_atol = 1e-8 def test_check_output(self): self.check_output(check_prim=True) @@ -2125,9 +2127,12 @@ class TestGelu(TestActivation): # cpu, lower threshold to support 1e-8 for pass the unittest self.rev_comp_rtol = 1e-8 self.rev_comp_atol = 1e-8 + # Cumulative error occurs between comp and cinn, so that we also set cinn_rtol to 1e-8 as rev_comp_rtol = 1e-8 + self.cinn_rtol = 1e-8 + self.cinn_atol = 1e-8 def if_enable_cinn(self): - self.enable_cinn = False + pass def test_check_output(self): self.check_output(check_prim=True) @@ -4028,9 +4033,11 @@ create_test_act_fp16_class(TestRelu, check_prim=True) create_test_act_fp16_class( TestGelu, check_prim=True, - enable_cinn=False, + enable_cinn=True, rev_comp_rtol=1e-3, rev_comp_atol=1e-3, + cinn_rtol=1e-3, + cinn_atol=1e-3, ) create_test_act_fp16_class(TestBRelu) create_test_act_fp16_class(TestRelu6) @@ -4141,9 +4148,11 @@ create_test_act_bf16_class(TestRelu, check_prim=True) create_test_act_bf16_class( TestGelu, check_prim=True, - enable_cinn=False, + enable_cinn=True, rev_comp_rtol=1e-2, rev_comp_atol=1e-2, + cinn_rtol=1e-2, + cinn_atol=1e-2, ) create_test_act_bf16_class(TestBRelu) create_test_act_bf16_class(TestRelu6) -- GitLab