Unverified commit 92547a0e, authored by Xiaoxu Chen, committed by GitHub

【Prim】remove gelu precision threshold (#52350)

* [prim] promote gelu precision threshold
Parent 1f3b9ef5
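The diff below swaps the checkers' hand-rolled `np.allclose` + `RuntimeError` pattern for `np.testing.assert_allclose`. A minimal sketch of the before/after pattern (the `ret` / `eager_desire` names and tolerances stand in for the checker attributes used below):

```python
import numpy as np

# Stand-in data: computed outputs vs. eager-mode reference outputs.
ret = [np.array([1.0, 2.0])]
eager_desire = [np.array([1.0, 2.0 + 1e-9])]

for i in range(len(ret)):
    # Before: check by hand, build the message, raise by hand.
    if not np.allclose(ret[i], eager_desire[i], rtol=1e-5, atol=1e-5):
        raise RuntimeError('forward out mismatch at index %d' % i)

    # After: assert_allclose raises on its own, appending err_msg to
    # numpy's elementwise mismatch report (max abs/rel error, etc.).
    np.testing.assert_allclose(
        ret[i],
        eager_desire[i],
        rtol=1e-5,
        atol=1e-5,
        err_msg='forward out mismatch at index %d' % i,
    )
```

Two behavioral differences worth noting: `assert_allclose` raises `AssertionError` rather than `RuntimeError`, and since `err_msg` is an ordinary argument it is built even when the check passes, whereas the old `msg` was only constructed on failure.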
@@ -626,13 +626,12 @@ class PrimForwardChecker:
                 )
                 raise RuntimeError(msg)
             for i in range(len(ret)):
-                if not np.allclose(
+                np.testing.assert_allclose(
                     ret[i],
                     self.eager_desire[i],
                     rtol=self.fw_comp_rtol,
                     atol=self.fw_comp_atol,
-                ):
-                    msg = (
+                    err_msg=(
                         'Check static comp forward api out failed. Mismatch between static comp '
                         'and eager on %s, when enable_fw_comp is %s, the forward api out tensor\'s index is: %d\n'
                         'static comp forward api out tensor:\n%s\n eager forward api out tensor:\n%s\n'
@@ -643,8 +642,8 @@ class PrimForwardChecker:
                             ret[i],
                             self.eager_desire[i],
                         )
-                    )
-                    raise RuntimeError(msg)
+                    ),
+                )
         paddle.disable_static()
         core._set_prim_forward_enabled(False)
@@ -696,10 +695,12 @@ class PrimForwardChecker:
                 )
                 raise RuntimeError(msg)
             for i in range(len(ret)):
-                if not np.allclose(
-                    ret[i], self.eager_desire[i], rtol=rtol, atol=atol
-                ):
-                    msg = (
+                np.testing.assert_allclose(
+                    ret[i],
+                    self.eager_desire[i],
+                    rtol=rtol,
+                    atol=atol,
+                    err_msg=(
                         'Check jit comp forward api out failed. Mismatch between jit comp '
                         'and eager on %s, when enable_fw_comp is %s, the forward api out tensor\'s index is: %d\n'
                         'jit comp forward api out tensor:\n%s\n eager forward api out tensor:\n%s\n'
@@ -710,8 +711,8 @@ class PrimForwardChecker:
                             ret[i],
                             self.eager_desire[i],
                         )
-                    )
-                    raise RuntimeError(msg)
+                    ),
+                )
         core._set_prim_forward_enabled(False)
         net.forward.program_cache.clear()
@@ -781,10 +782,12 @@ class PrimForwardChecker:
                 )
                 raise RuntimeError(msg)
             for i in range(len(ret)):
-                if not np.allclose(
-                    ret[i], self.eager_desire[i], rtol=rtol, atol=atol
-                ):
-                    msg = (
+                np.testing.assert_allclose(
+                    ret[i],
+                    self.eager_desire[i],
+                    rtol=rtol,
+                    atol=atol,
+                    err_msg=(
                         'Check jit comp with cinn forward api out failed. Mismatch between jit comp and eager on %s, '
                         'when enable_fw_comp is %s, enable_cinn is %s, the forward api out tensor\'s index is: %d\n'
                         'jit comp forward api out tensor:\n%s\n eager forward api out tensor:\n%s\n'
@@ -796,8 +799,8 @@ class PrimForwardChecker:
                             ret[i],
                             self.eager_desire[i],
                         )
-                    )
-                    raise RuntimeError(msg)
+                    ),
+                )
         core._set_prim_forward_enabled(False)
         net.forward.program_cache.clear()
@@ -975,13 +978,12 @@ class PrimGradChecker(PrimForwardChecker):
                 )
                 raise RuntimeError(msg)
             for i in range(len(actual_ret)):
-                if not np.allclose(
+                np.testing.assert_allclose(
                     actual_ret[i],
                     self.eager_desire[i],
                     rtol=rtol,
                     atol=atol,
-                ):
-                    msg = (
+                    err_msg=(
                         'Check eager comp grad out failed. Mismatch between eager comp '
                         'and eager on %s, when enable_rev_comp is %s, the eager comp grad out tensor\'s index is: %d\n'
                         'eager comp grad out tensor:\n%s\n eager grad out tensor:\n%s\n'
@@ -992,8 +994,8 @@ class PrimGradChecker(PrimForwardChecker):
                             actual_ret[i],
                             self.eager_desire[i],
                         )
-                    )
-                    raise RuntimeError(msg)
+                    ),
+                )
         core.set_prim_eager_enabled(False)

     def check_static_comp(self):
@@ -1075,10 +1077,12 @@ class PrimGradChecker(PrimForwardChecker):
                 )
                 raise RuntimeError(msg)
             for i in range(len(actual_ret)):
-                if not np.allclose(
-                    actual_ret[i], self.eager_desire[i], rtol=rtol, atol=atol
-                ):
-                    msg = (
+                np.testing.assert_allclose(
+                    actual_ret[i],
+                    self.eager_desire[i],
+                    rtol=rtol,
+                    atol=atol,
+                    err_msg=(
                         'Check static comp grad out failed. Mismatch between static comp '
                         'and eager on %s, when enable_fw_comp is %s, enable_rev_comp is %s, the grad out tensor\'s index is: %d\n'
                         'static comp grad out tensor:\n%s\n eager grad out tensor:\n%s\n'
@@ -1090,8 +1094,8 @@ class PrimGradChecker(PrimForwardChecker):
                             actual_ret[i],
                             self.eager_desire[i],
                         )
-                    )
-                    raise RuntimeError(msg)
+                    ),
+                )
         core._set_prim_forward_enabled(False)
         core._set_prim_backward_enabled(False)
         paddle.disable_static()
@@ -1179,10 +1183,12 @@ class PrimGradChecker(PrimForwardChecker):
                 )
                 raise RuntimeError(msg)
             for i in range(len(ret)):
-                if not np.allclose(
-                    ret[i], self.eager_desire[i], rtol=rtol, atol=atol
-                ):
-                    msg = (
+                np.testing.assert_allclose(
+                    ret[i],
+                    self.eager_desire[i],
+                    rtol=rtol,
+                    atol=atol,
+                    err_msg=(
                         'Check jit comp grad out failed. Mismatch between jit comp '
                         'and eager on %s, when enable_fw_comp is %s, enable_rev_comp is %s, the grad out tensor\'s index is: %d\n'
                         'jit comp grad out tensor:\n%s\n eager grad out tensor:\n%s\n'
@@ -1194,8 +1200,8 @@ class PrimGradChecker(PrimForwardChecker):
                             ret[i],
                             self.eager_desire[i],
                         )
-                    )
-                    raise RuntimeError(msg)
+                    ),
+                )
         core._set_prim_forward_enabled(False)
         core._set_prim_backward_enabled(False)
         net.forward.program_cache.clear()
@@ -1297,10 +1303,12 @@ class PrimGradChecker(PrimForwardChecker):
                 )
                 raise RuntimeError(msg)
             for i in range(len(ret)):
-                if not np.allclose(
-                    ret[i], self.eager_desire[i], rtol=rtol, atol=atol
-                ):
-                    msg = (
+                np.testing.assert_allclose(
+                    ret[i],
+                    self.eager_desire[i],
+                    rtol=rtol,
+                    atol=atol,
+                    err_msg=(
                         'Check jit comp with cinn grad out failed. Mismatch between jit comp with cinn '
                         'and eager on %s, when enable_fw_comp is %s, enable_rev_comp is %s, enable_cinn is %s, '
                         'the grad out tensor\'s index is: %d, jit comp with cinn grad out tensor:\n%s\n eager grad out tensor:\n%s\n'
@@ -1313,8 +1321,9 @@ class PrimGradChecker(PrimForwardChecker):
                             ret[i],
                             self.eager_desire[i],
                         )
-                    )
-                    raise RuntimeError(msg)
+                    ),
+                )
         core._set_prim_forward_enabled(False)
         core._set_prim_backward_enabled(False)
         net.forward.program_cache.clear()
@@ -2038,9 +2038,10 @@ class TestGeluApproximate(TestActivation):
         self.outputs = {'Out': out}
         self.attrs = {"approximate": approximate}
-        # The backward decomposition of gelu is inconsistent with the raw kernel,
-        # lower the threshold to 1e-5 to pass the unit test
-        self.rev_comp_rtol = 1e-5
+        # The backward decomposition of gelu is inconsistent with the raw kernel
+        # on the CPU device; lower the threshold to 1e-8 to pass the unit test
+        self.rev_comp_rtol = 1e-8
+        self.rev_comp_atol = 1e-8

     def test_check_output(self):
         self.check_output(check_prim=True)
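For context on what tightening these attributes means: `np.allclose` / `assert_allclose` accept elementwise when `|actual - desired| <= atol + rtol * |desired|`, so moving `rev_comp_rtol` from 1e-5 to 1e-8 (and setting `rev_comp_atol`) leaves far less room between the gelu backward decomposition and the raw kernel. A small sketch with made-up values:

```python
import numpy as np

# Hypothetical gelu grad value and an error the old tolerance accepted.
desired = np.float64(0.8413447460685429)
actual = desired + 3e-6

# Old threshold: 3e-6 <= 1e-5 * 0.841... holds, so the check passes.
assert np.allclose(actual, desired, rtol=1e-5, atol=0.0)

# New threshold: 3e-6 > 1e-8 + 1e-8 * 0.841..., so the check fails.
assert not np.allclose(actual, desired, rtol=1e-8, atol=1e-8)
```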
@@ -2068,9 +2069,10 @@ class TestGelu(TestActivation):
         self.inputs = {'X': x}
         self.outputs = {'Out': out}
         self.attrs = {"approximate": approximate}
-        # The backward decomposition of gelu is inconsistent with the raw kernel,
-        # lower the threshold to 1e-5 to pass the unit test
-        self.rev_comp_rtol = 1e-5
+        # The backward decomposition of gelu is inconsistent with the raw kernel
+        # on CPU; lower the threshold to 1e-8 to pass the unit test
+        self.rev_comp_rtol = 1e-8
+        self.rev_comp_atol = 1e-8

     def if_enable_cinn(self):
         self.enable_cinn = False
@@ -2104,9 +2106,10 @@ class TestGELUAPI(unittest.TestCase):
         )
         self.enable_cinn = False
-        # The backward decomposition of gelu is inconsistent with the raw kernel,
-        # lower the threshold to 1e-5 to pass the unit test
-        self.rev_comp_rtol = 1e-5
+        # The backward decomposition of gelu is inconsistent with the raw kernel
+        # on CPU; lower the threshold to 1e-8 to pass the unit test
+        self.rev_comp_rtol = 1e-8
+        self.rev_comp_atol = 1e-8

     def test_static_api(self):
         with paddle_static_guard():
@@ -3870,11 +3873,17 @@ def create_test_act_fp16_class(
     check_prim=False,
     enable_cinn=True,
     grad_atol=0.80,
+    **kwargs
 ):
     @unittest.skipIf(
         not paddle.is_compiled_with_cuda(), "core is not compiled with CUDA"
     )
     class TestActFp16(parent):
+        def setUp(self):
+            super().setUp()
+            for k, v in kwargs.items():
+                setattr(self, k, v)
+
         def init_dtype(self):
             self.dtype = np.float16
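The new `**kwargs` hook forwards any extra keyword straight onto the generated test case via `setattr`, so callers can override per-test tolerances (as the `TestGelu` registration in the next hunk does with `rev_comp_rtol` / `rev_comp_atol`) without growing the factory's signature. A self-contained sketch of the pattern, with hypothetical names:

```python
import unittest


def create_fp16_variant(parent, **kwargs):
    # Hypothetical mirror of create_test_act_fp16_class: every extra
    # keyword becomes an attribute on the generated subclass.
    class TestFp16(parent):
        def setUp(self):
            super().setUp()
            for k, v in kwargs.items():
                setattr(self, k, v)

    TestFp16.__name__ = parent.__name__ + 'Fp16'
    return TestFp16


class TestBase(unittest.TestCase):
    def setUp(self):
        self.rev_comp_rtol = 1e-8  # strict default

    def test_tolerance(self):
        # The fp16 variant sees the loosened value set through **kwargs.
        print(type(self).__name__, self.rev_comp_rtol)


# fp16 math is coarser, so the variant loosens the threshold.
TestBaseFp16 = create_fp16_variant(TestBase, rev_comp_rtol=1e-3)
```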
@@ -3937,7 +3946,13 @@ create_test_act_fp16_class(TestAsinh, grad_atol=0.85)
 create_test_act_fp16_class(TestAtanh, grad_atol=0.85)
 create_test_act_fp16_class(TestRound, grad_check=False)
 create_test_act_fp16_class(TestRelu, check_prim=True)
-create_test_act_fp16_class(TestGelu, check_prim=True, enable_cinn=False)
+create_test_act_fp16_class(
+    TestGelu,
+    check_prim=True,
+    enable_cinn=False,
+    rev_comp_rtol=1e-3,
+    rev_comp_atol=1e-3,
+)
 create_test_act_fp16_class(TestBRelu)
 create_test_act_fp16_class(TestRelu6)
 create_test_act_fp16_class(TestSoftRelu, check_dygraph=False, grad_atol=0.85)