diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index c273cd4954941876b55895a21bb4b45073b63454..4e691ae3af90d8e2347bb76b6b2e0f0787076523 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -2083,7 +2083,6 @@ class TestGeluApproximate(TestActivation): np.random.seed(1024) x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) out = gelu(x, approximate) - self.enable_cinn = False self.inputs = {'X': x} self.outputs = {'Out': out} @@ -2093,6 +2092,9 @@ class TestGeluApproximate(TestActivation): # cpu device, lower threshold to support 1e-8 for pass the unittest self.rev_comp_rtol = 1e-8 self.rev_comp_atol = 1e-8 + # Cumulative error occurs between comp and cinn, so that we also set cinn_rtol to 1e-8 as rev_comp_rtol = 1e-8 + self.cinn_rtol = 1e-8 + self.cinn_atol = 1e-8 def test_check_output(self): self.check_output(check_prim=True) @@ -2125,9 +2127,12 @@ class TestGelu(TestActivation): # cpu, lower threshold to support 1e-8 for pass the unittest self.rev_comp_rtol = 1e-8 self.rev_comp_atol = 1e-8 + # Cumulative error occurs between comp and cinn, so that we also set cinn_rtol to 1e-8 as rev_comp_rtol = 1e-8 + self.cinn_rtol = 1e-8 + self.cinn_atol = 1e-8 def if_enable_cinn(self): - self.enable_cinn = False + pass def test_check_output(self): self.check_output(check_prim=True) @@ -4028,9 +4033,11 @@ create_test_act_fp16_class(TestRelu, check_prim=True) create_test_act_fp16_class( TestGelu, check_prim=True, - enable_cinn=False, + enable_cinn=True, rev_comp_rtol=1e-3, rev_comp_atol=1e-3, + cinn_rtol=1e-3, + cinn_atol=1e-3, ) create_test_act_fp16_class(TestBRelu) create_test_act_fp16_class(TestRelu6) @@ -4141,9 +4148,11 @@ create_test_act_bf16_class(TestRelu, check_prim=True) create_test_act_bf16_class( TestGelu, check_prim=True, - enable_cinn=False, + enable_cinn=True, rev_comp_rtol=1e-2, rev_comp_atol=1e-2, + cinn_rtol=1e-2, + cinn_atol=1e-2, ) create_test_act_bf16_class(TestBRelu) create_test_act_bf16_class(TestRelu6)