Unverified · Commit 2b7cbd1b · authored by HongyuJia, committed by GitHub

[CINN] Fix TestGelu unittest of CINN (#53859)

* [CINN] Fix TestGelu unittest of CINN

* pass if_enable_cinn
Parent 9b0f621c
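
The tightened cinn_rtol/cinn_atol values in the diff below presumably feed a NumPy-style closeness check between the CINN output and the reference output. A minimal sketch of how such a combined relative/absolute tolerance behaves (plain numpy.allclose, which passes when |a - b| <= atol + rtol * |b|; this is an illustration of the tolerance semantics, not the OpTest implementation):

import numpy as np

# Hypothetical reference output and a CINN result carrying a small
# cumulative error from kernel fusion.
ref = np.float64(1.0)

# With rtol = atol = 1e-8 the allowed deviation around ref is
# atol + rtol * |ref| = 2e-8, so a 5e-9 drift still passes ...
print(np.allclose(ref + 5e-9, ref, rtol=1e-8, atol=1e-8))  # True

# ... while a 1e-7 drift fails the check.
print(np.allclose(ref + 1e-7, ref, rtol=1e-8, atol=1e-8))  # False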
@@ -2083,7 +2083,6 @@ class TestGeluApproximate(TestActivation):
         np.random.seed(1024)
         x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
         out = gelu(x, approximate)
-        self.enable_cinn = False
         self.inputs = {'X': x}
         self.outputs = {'Out': out}
@@ -2093,6 +2092,9 @@ class TestGeluApproximate(TestActivation):
         # cpu device, lower threshold to support 1e-8 for pass the unittest
         self.rev_comp_rtol = 1e-8
         self.rev_comp_atol = 1e-8
+        # Cumulative error occurs between comp and cinn, so that we also set cinn_rtol to 1e-8 as rev_comp_rtol = 1e-8
+        self.cinn_rtol = 1e-8
+        self.cinn_atol = 1e-8

     def test_check_output(self):
         self.check_output(check_prim=True)
@@ -2125,9 +2127,12 @@ class TestGelu(TestActivation):
         # cpu, lower threshold to support 1e-8 for pass the unittest
         self.rev_comp_rtol = 1e-8
         self.rev_comp_atol = 1e-8
+        # Cumulative error occurs between comp and cinn, so that we also set cinn_rtol to 1e-8 as rev_comp_rtol = 1e-8
+        self.cinn_rtol = 1e-8
+        self.cinn_atol = 1e-8

     def if_enable_cinn(self):
-        self.enable_cinn = False
+        pass

     def test_check_output(self):
         self.check_output(check_prim=True)
@@ -4028,9 +4033,11 @@ create_test_act_fp16_class(TestRelu, check_prim=True)
 create_test_act_fp16_class(
     TestGelu,
     check_prim=True,
-    enable_cinn=False,
+    enable_cinn=True,
     rev_comp_rtol=1e-3,
     rev_comp_atol=1e-3,
+    cinn_rtol=1e-3,
+    cinn_atol=1e-3,
 )
 create_test_act_fp16_class(TestBRelu)
 create_test_act_fp16_class(TestRelu6)
@@ -4141,9 +4148,11 @@ create_test_act_bf16_class(TestRelu, check_prim=True)
 create_test_act_bf16_class(
     TestGelu,
     check_prim=True,
-    enable_cinn=False,
+    enable_cinn=True,
     rev_comp_rtol=1e-2,
     rev_comp_atol=1e-2,
+    cinn_rtol=1e-2,
+    cinn_atol=1e-2,
 )
 create_test_act_bf16_class(TestBRelu)
 create_test_act_bf16_class(TestRelu6)
......
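
The TestGelu hunk above replaces a hard self.enable_cinn = False with pass, so the test stops opting out of CINN and inherits the framework default. A hedged sketch of that hook pattern (OpTestBase and its setUp logic are assumptions for illustration, not the actual OpTest source):

class OpTestBase:
    def setUp(self):
        self.enable_cinn = True   # assumed default: CINN checking on
        self.if_enable_cinn()     # subclass hook that may opt out

    def if_enable_cinn(self):
        pass                      # base class keeps the default

class TestGelu(OpTestBase):
    def if_enable_cinn(self):
        pass                      # after this commit: no longer disables CINN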