Unverified commit 92547a0e, authored by Xiaoxu Chen, committed by GitHub

[Prim] remove gelu precision threshold (#52350)

* [prim] promote gelu precision threshold
Parent 1f3b9ef5
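The mechanical change repeated throughout the checker below is replacing the hand-rolled `if not np.allclose(...): msg = (...); raise RuntimeError(msg)` pattern with `np.testing.assert_allclose`, which raises on mismatch by itself and prints the supplied `err_msg` together with a detailed element-wise report. A minimal standalone sketch of the before/after pattern (the arrays and tolerances are illustrative, not taken from the test suite):

import numpy as np

actual = np.array([1.0, 2.0, 3.0])
desired = np.array([1.0, 2.0, 3.0000001])

# Before: manual comparison plus a hand-built RuntimeError.
if not np.allclose(actual, desired, rtol=1e-5, atol=1e-8):
    raise RuntimeError('Check forward api out failed: mismatch')

# After: assert_allclose raises AssertionError on mismatch and
# appends a per-element mismatch report to err_msg.
np.testing.assert_allclose(
    actual,
    desired,
    rtol=1e-5,
    atol=1e-8,
    err_msg='Check forward api out failed: mismatch',
)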
@@ -626,13 +626,12 @@ class PrimForwardChecker:
                 )
                 raise RuntimeError(msg)
             for i in range(len(ret)):
-                if not np.allclose(
+                np.testing.assert_allclose(
                     ret[i],
                     self.eager_desire[i],
                     rtol=self.fw_comp_rtol,
                     atol=self.fw_comp_atol,
-                ):
-                    msg = (
+                    err_msg=(
                         'Check static comp forward api out failed. Mismatch between static comp '
                         'and eager on %s, when enable_fw_comp is %s,the forward api out tensor\'s index is : %d \n'
                         'static comp forward api out tensor:\n%s\n eager forward api out tensor:\n%s\n'
@@ -643,8 +642,8 @@ class PrimForwardChecker:
                             ret[i],
                             self.eager_desire[i],
                         )
-                    )
-                    raise RuntimeError(msg)
+                    ),
+                )
         paddle.disable_static()
         core._set_prim_forward_enabled(False)
@@ -696,10 +695,12 @@ class PrimForwardChecker:
                 )
                 raise RuntimeError(msg)
             for i in range(len(ret)):
-                if not np.allclose(
-                    ret[i], self.eager_desire[i], rtol=rtol, atol=atol
-                ):
-                    msg = (
+                np.testing.assert_allclose(
+                    ret[i],
+                    self.eager_desire[i],
+                    rtol=rtol,
+                    atol=atol,
+                    err_msg=(
                         'Check jit comp forward api out failed. Mismatch between jit comp '
                         'and eager on %s, when enable_fw_comp is %s,the forward api out tensor\'s index is : %d \n'
                         'jit comp forward api out tensor:\n%s\n eager forward api out tensor:\n%s\n'
@@ -710,8 +711,8 @@ class PrimForwardChecker:
                             ret[i],
                             self.eager_desire[i],
                         )
-                    )
-                    raise RuntimeError(msg)
+                    ),
+                )
         core._set_prim_forward_enabled(False)
         net.forward.program_cache.clear()
@@ -781,10 +782,12 @@ class PrimForwardChecker:
                 )
                 raise RuntimeError(msg)
             for i in range(len(ret)):
-                if not np.allclose(
-                    ret[i], self.eager_desire[i], rtol=rtol, atol=atol
-                ):
-                    msg = (
+                np.testing.assert_allclose(
+                    ret[i],
+                    self.eager_desire[i],
+                    rtol=rtol,
+                    atol=atol,
+                    err_msg=(
                         'Check jit comp with cinn forward api out failed. Mismatch between jit comp and eager on %s, '
                         'when enable_fw_comp is %s, enable_cinn is %s, the forward api out tensor\'s index is : %d \n'
                         'jit comp forward api out tensor:\n%s\n eager forward api out tensor:\n%s\n'
@@ -796,8 +799,8 @@ class PrimForwardChecker:
                             ret[i],
                             self.eager_desire[i],
                         )
-                    )
-                    raise RuntimeError(msg)
+                    ),
+                )
         core._set_prim_forward_enabled(False)
         net.forward.program_cache.clear()
@@ -975,13 +978,12 @@ class PrimGradChecker(PrimForwardChecker):
                 )
                 raise RuntimeError(msg)
             for i in range(len(actual_ret)):
-                if not np.allclose(
+                np.testing.assert_allclose(
                     actual_ret[i],
                     self.eager_desire[i],
                     rtol=atol,
                     atol=rtol,
-                ):
-                    msg = (
+                    err_msg=(
                         'Check eager comp grad out failed. Mismatch between eager comp '
                         'and eager on %s, when enable_rev_comp is %s,the eager comp grad out tensor\'s index is : %d \n'
                         'eager comp grad out tensor:\n%s\n eager grad out tensor:\n%s\n'
@@ -992,8 +994,8 @@ class PrimGradChecker(PrimForwardChecker):
                             actual_ret[i],
                             self.eager_desire[i],
                         )
-                    )
-                    raise RuntimeError(msg)
+                    ),
+                )
         core.set_prim_eager_enabled(False)

     def check_static_comp(self):
@@ -1075,10 +1077,12 @@ class PrimGradChecker(PrimForwardChecker):
                 )
                 raise RuntimeError(msg)
             for i in range(len(actual_ret)):
-                if not np.allclose(
-                    actual_ret[i], self.eager_desire[i], rtol=rtol, atol=atol
-                ):
-                    msg = (
+                np.testing.assert_allclose(
+                    actual_ret[i],
+                    self.eager_desire[i],
+                    rtol=rtol,
+                    atol=atol,
+                    err_msg=(
                         'Check static comp grad out failed. Mismatch between static comp '
                         'and eager on %s, when enable_fw_comp is %s,enable_rev_comp is %s,the forward api out tensor\'s index is : %d \n'
                         'static comp grad out tensor:\n%s\n eager grad out tensor:\n%s\n'
@@ -1090,8 +1094,8 @@ class PrimGradChecker(PrimForwardChecker):
                             actual_ret[i],
                             self.eager_desire[i],
                         )
-                    )
-                    raise RuntimeError(msg)
+                    ),
+                )
         core._set_prim_forward_enabled(False)
         core._set_prim_backward_enabled(False)
         paddle.disable_static()
@@ -1179,10 +1183,12 @@ class PrimGradChecker(PrimForwardChecker):
                 )
                 raise RuntimeError(msg)
             for i in range(len(ret)):
-                if not np.allclose(
-                    ret[i], self.eager_desire[i], rtol=rtol, atol=atol
-                ):
-                    msg = (
+                np.testing.assert_allclose(
+                    ret[i],
+                    self.eager_desire[i],
+                    rtol=rtol,
+                    atol=atol,
+                    err_msg=(
                         'Check jit comp grad out failed. Mismatch between jit comp '
                         'and eager on %s, when enable_fw_comp is %s, enable_rev_comp is %s,the grad out tensor\'s index is : %d \n'
                         'jit comp grad out tensor:\n%s\n eager grad out out tensor:\n%s\n'
@@ -1194,8 +1200,8 @@ class PrimGradChecker(PrimForwardChecker):
                             ret[i],
                             self.eager_desire[i],
                         )
-                    )
-                    raise RuntimeError(msg)
+                    ),
+                )
         core._set_prim_forward_enabled(False)
         core._set_prim_backward_enabled(False)
         net.forward.program_cache.clear()
@@ -1297,10 +1303,12 @@ class PrimGradChecker(PrimForwardChecker):
                 )
                 raise RuntimeError(msg)
             for i in range(len(ret)):
-                if not np.allclose(
-                    ret[i], self.eager_desire[i], rtol=rtol, atol=atol
-                ):
-                    msg = (
+                np.testing.assert_allclose(
+                    ret[i],
+                    self.eager_desire[i],
+                    rtol=rtol,
+                    atol=atol,
+                    err_msg=(
                         'Check jit comp with cinn grad out failed. Mismatch between jit comp with cinn '
                         'and eager on %s, when enable_fw_comp is %s, enable_rev_comp is %s, enable_cinn is %s,'
                         'the grad out tensor\'s index is : %d ,jit comp with cinn grad out tensor:\n%s\n eager grad out out tensor:\n%s\n'
@@ -1313,8 +1321,9 @@ class PrimGradChecker(PrimForwardChecker):
                             ret[i],
                             self.eager_desire[i],
                         )
-                    )
-                    raise RuntimeError(msg)
+                    ),
+                )
         core._set_prim_forward_enabled(False)
         core._set_prim_backward_enabled(False)
         net.forward.program_cache.clear()
@@ -2038,9 +2038,10 @@ class TestGeluApproximate(TestActivation):
         self.outputs = {'Out': out}
         self.attrs = {"approximate": approximate}

-        # The backward decomposite of gelu is inconsistent with raw kernel,
-        # lower threshold to support 1e-5 for pass the unittest
-        self.rev_comp_rtol = 1e-5
+        # The backward decomposite of gelu is inconsistent with raw kernel on
+        # cpu device, lower threshold to support 1e-8 for pass the unittest
+        self.rev_comp_rtol = 1e-8
+        self.rev_comp_atol = 1e-8

     def test_check_output(self):
         self.check_output(check_prim=True)
@@ -2068,9 +2069,10 @@ class TestGelu(TestActivation):
         self.inputs = {'X': x}
         self.outputs = {'Out': out}
         self.attrs = {"approximate": approximate}
-        # The backward decomposite of gelu is inconsistent with raw kernel,
-        # lower threshold to support 1e-5 for pass the unittest
-        self.rev_comp_rtol = 1e-5
+        # The backward decomposite of gelu is inconsistent with raw kernel on
+        # cpu, lower threshold to support 1e-8 for pass the unittest
+        self.rev_comp_rtol = 1e-8
+        self.rev_comp_atol = 1e-8

     def if_enable_cinn(self):
         self.enable_cinn = False
@@ -2104,9 +2106,10 @@ class TestGELUAPI(unittest.TestCase):
         )
         self.enable_cinn = False

-        # The backward decomposite of gelu is inconsistent with raw kernel,
-        # lower threshold to support 1e-5 for pass the unittest
-        self.rev_comp_rtol = 1e-5
+        # The backward decomposite of gelu is inconsistent with raw kernel on
+        # cpu, lower threshold to support 1e-8 for pass the unittest
+        self.rev_comp_rtol = 1e-8
+        self.rev_comp_atol = 1e-8

     def test_static_api(self):
         with paddle_static_guard():
@@ -3870,11 +3873,17 @@ def create_test_act_fp16_class(
     check_prim=False,
     enable_cinn=True,
     grad_atol=0.80,
+    **kwargs
 ):
     @unittest.skipIf(
         not paddle.is_compiled_with_cuda(), "core is not compiled with CUDA"
     )
     class TestActFp16(parent):
+        def setUp(self):
+            super().setUp()
+            for k, v in kwargs.items():
+                setattr(self, k, v)
+
         def init_dtype(self):
             self.dtype = np.float16
@@ -3937,7 +3946,13 @@ create_test_act_fp16_class(TestAsinh, grad_atol=0.85)
 create_test_act_fp16_class(TestAtanh, grad_atol=0.85)
 create_test_act_fp16_class(TestRound, grad_check=False)
 create_test_act_fp16_class(TestRelu, check_prim=True)
-create_test_act_fp16_class(TestGelu, check_prim=True, enable_cinn=False)
+create_test_act_fp16_class(
+    TestGelu,
+    check_prim=True,
+    enable_cinn=False,
+    rev_comp_rtol=1e-3,
+    rev_comp_atol=1e-3,
+)
 create_test_act_fp16_class(TestBRelu)
 create_test_act_fp16_class(TestRelu6)
 create_test_act_fp16_class(TestSoftRelu, check_dygraph=False, grad_atol=0.85)
......
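The last two hunks are the delivery mechanism for the per-test FP16 tolerances: create_test_act_fp16_class now accepts **kwargs and copies each entry onto the test instance in setUp, which is how the TestGelu registration relaxes rev_comp_rtol/rev_comp_atol from the fp32 default of 1e-8 to 1e-3. A reduced, self-contained sketch of that forwarding pattern (make_fp16_case and Base are illustrative stand-ins, not Paddle code):

import unittest


def make_fp16_case(parent, **kwargs):
    class TestFp16(parent):
        def setUp(self):
            super().setUp()
            # Extra keyword arguments become instance attributes,
            # overriding class-level defaults such as rev_comp_rtol.
            for k, v in kwargs.items():
                setattr(self, k, v)

        def test_tolerance(self):
            # The per-case override wins over the class default.
            self.assertEqual(self.rev_comp_rtol, 1e-3)

    return TestFp16


class Base(unittest.TestCase):
    rev_comp_rtol = 1e-8  # class default used by the fp32 tests

    def setUp(self):
        pass


TestGeluFp16 = make_fp16_case(Base, rev_comp_rtol=1e-3)

if __name__ == '__main__':
    unittest.main()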