未验证 提交 2b7cbd1b 编写于 作者: H HongyuJia 提交者: GitHub

[CINN] Fix TestGelu unittest of CINN (#53859)

* [CINN] Fix TestGelu unittest of CINN

* pass if_enable_cinn
上级 9b0f621c
......@@ -2083,7 +2083,6 @@ class TestGeluApproximate(TestActivation):
np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
out = gelu(x, approximate)
self.enable_cinn = False
self.inputs = {'X': x}
self.outputs = {'Out': out}
......@@ -2093,6 +2092,9 @@ class TestGeluApproximate(TestActivation):
# cpu device, lower threshold to support 1e-8 for pass the unittest
self.rev_comp_rtol = 1e-8
self.rev_comp_atol = 1e-8
# Cumulative error occurs between comp and cinn, so that we also set cinn_rtol to 1e-8 as rev_comp_rtol = 1e-8
self.cinn_rtol = 1e-8
self.cinn_atol = 1e-8
def test_check_output(self):
self.check_output(check_prim=True)
......@@ -2125,9 +2127,12 @@ class TestGelu(TestActivation):
# cpu, lower threshold to support 1e-8 for pass the unittest
self.rev_comp_rtol = 1e-8
self.rev_comp_atol = 1e-8
# Cumulative error occurs between comp and cinn, so that we also set cinn_rtol to 1e-8 as rev_comp_rtol = 1e-8
self.cinn_rtol = 1e-8
self.cinn_atol = 1e-8
def if_enable_cinn(self):
self.enable_cinn = False
pass
def test_check_output(self):
self.check_output(check_prim=True)
......@@ -4028,9 +4033,11 @@ create_test_act_fp16_class(TestRelu, check_prim=True)
create_test_act_fp16_class(
TestGelu,
check_prim=True,
enable_cinn=False,
enable_cinn=True,
rev_comp_rtol=1e-3,
rev_comp_atol=1e-3,
cinn_rtol=1e-3,
cinn_atol=1e-3,
)
create_test_act_fp16_class(TestBRelu)
create_test_act_fp16_class(TestRelu6)
......@@ -4141,9 +4148,11 @@ create_test_act_bf16_class(TestRelu, check_prim=True)
create_test_act_bf16_class(
TestGelu,
check_prim=True,
enable_cinn=False,
enable_cinn=True,
rev_comp_rtol=1e-2,
rev_comp_atol=1e-2,
cinn_rtol=1e-2,
cinn_atol=1e-2,
)
create_test_act_bf16_class(TestBRelu)
create_test_act_bf16_class(TestRelu6)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册