Unverified commit 145a6cbb, authored by Vvsmile, committed by GitHub

Adjust tolerance without modifying grad (#51459)

Parent 117df481
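The recurring change in this commit is a tolerance floor: instead of overwriting atol (and max_relative_error) with a fixed value, the test framework now only raises the caller-supplied value when it is tighter than the dtype-specific minimum, so looser tolerances requested by individual tests are preserved. A minimal sketch of the idiom; the helper name is illustrative and not part of the patch:

    # The pattern "atol = 1e-1 if atol < 1e-1 else atol" is equivalent to
    # clamping the tolerance from below at a dtype-specific floor.
    def floor_tolerance(tol, floor):
        return tol if tol >= floor else floor

    assert floor_tolerance(1e-3, 1e-1) == 1e-1  # tight request raised to the bfloat16 floor
    assert floor_tolerance(0.5, 1e-1) == 0.5    # looser request kept as-is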
@@ -1717,11 +1717,14 @@ class OpTest(unittest.TestCase):
if hasattr(self, 'force_fp32_output') and getattr(
self, 'force_fp32_output'
):
- atol = 1e-2
+ atol = 1e-2 if atol < 1e-2 else atol
else:
- atol = 2
+ atol = 2 if atol < 2 else atol
else:
- atol = 1e-1
+ atol = 1e-1 if atol < 1e-1 else atol
+ if self.is_float16_op():
+     atol = 1e-3 if atol < 1e-3 else atol
if no_check_set is not None:
if (
@@ -2171,7 +2174,7 @@ class OpTest(unittest.TestCase):
if grad.dtype == np.uint16:
grad = convert_uint16_to_float(grad)
max_relative_error = (
-     0.04 if max_relative_error < 0.04 else max_relative_error
+     0.01 if max_relative_error < 0.01 else max_relative_error
)
fp32_analytic_grads.append(grad)
analytic_grads = fp32_analytic_grads
@@ -2181,11 +2184,16 @@ class OpTest(unittest.TestCase):
if grad.dtype == np.uint16:
grad = convert_uint16_to_float(grad)
max_relative_error = (
-     0.04 if max_relative_error < 0.04 else max_relative_error
+     0.01 if max_relative_error < 0.01 else max_relative_error
)
fp32_numeric_grads.append(grad)
numeric_grads = fp32_numeric_grads
+ if self.is_float16_op():
+     max_relative_error = (
+         0.001 if max_relative_error < 0.001 else max_relative_error
+     )
self._assert_is_close(
numeric_grads,
analytic_grads,
......
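For gradient checks, the same commit lowers the bfloat16 floor on max_relative_error from 0.04 to 0.01 and adds a 0.001 floor for float16 ops, applied after uint16-stored bfloat16 gradients are converted to float. Below is a sketch of the kind of relative-error comparison such a bound feeds into; it is an assumed illustration, not PaddlePaddle's actual _assert_is_close implementation:

    import numpy as np

    # Largest element-wise relative difference between numeric and analytic
    # gradients must stay within max_relative_error (illustrative semantics).
    def grads_are_close(numeric, analytic, max_relative_error):
        numeric = np.asarray(numeric, dtype=np.float64)
        analytic = np.asarray(analytic, dtype=np.float64)
        scale = np.maximum(np.abs(numeric), np.abs(analytic))
        scale = np.where(scale == 0.0, 1.0, scale)  # avoid division by zero
        rel_err = np.abs(numeric - analytic) / scale
        return float(rel_err.max()) <= max_relative_error

    # With the new floors, bfloat16 gradients are judged against at least 0.01
    # and float16 gradients against at least 0.001.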
@@ -1910,14 +1910,14 @@ class OpTest(unittest.TestCase):
if hasattr(self, 'force_fp32_output') and getattr(
self, 'force_fp32_output'
):
- atol = 1e-2
+ atol = 1e-2 if atol < 1e-2 else atol
else:
- atol = 2
+ atol = 2 if atol < 2 else atol
else:
- atol = 1e-1
+ atol = 1e-2 if atol < 1e-2 else atol
if self.is_float16_op():
- atol = 1e-3
+ atol = 1e-3 if atol < 1e-3 else atol
if no_check_set is not None:
if (
@@ -2415,7 +2415,7 @@ class OpTest(unittest.TestCase):
if grad.dtype == np.uint16:
grad = convert_uint16_to_float(grad)
max_relative_error = (
-     0.04 if max_relative_error < 0.04 else max_relative_error
+     0.01 if max_relative_error < 0.01 else max_relative_error
)
fp32_analytic_grads.append(grad)
analytic_grads = fp32_analytic_grads
@@ -2425,11 +2425,16 @@ class OpTest(unittest.TestCase):
if grad.dtype == np.uint16:
grad = convert_uint16_to_float(grad)
max_relative_error = (
-     0.04 if max_relative_error < 0.04 else max_relative_error
+     0.01 if max_relative_error < 0.01 else max_relative_error
)
fp32_numeric_grads.append(grad)
numeric_grads = fp32_numeric_grads
+ if self.is_float16_op():
+     max_relative_error = (
+         0.001 if max_relative_error < 0.001 else max_relative_error
+     )
self._assert_is_close(
numeric_grads,
analytic_grads,
......
@@ -201,6 +201,35 @@ class TestElementwiseDivOpBF16(ElementwiseDivOp):
self.x_shape = [12, 13]
self.y_shape = [12, 13]
+ def test_check_gradient(self):
+     check_list = []
+     check_list.append(
+         {
+             'grad': ['X', 'Y'],
+             'no_grad': None,
+             'val_grad': [self.grad_x, self.grad_y],
+         }
+     )
+     check_list.append(
+         {'grad': ['Y'], 'no_grad': set('X'), 'val_grad': [self.grad_y]}
+     )
+     check_list.append(
+         {'grad': ['X'], 'no_grad': set('Y'), 'val_grad': [self.grad_x]}
+     )
+     for check_option in check_list:
+         check_args = [check_option['grad'], 'Out']
+         check_kwargs = {
+             'no_grad_set': check_option['no_grad'],
+             'user_defined_grads': check_option['val_grad'],
+             'user_defined_grad_outputs': [self.grad_out],
+             'check_dygraph': self.check_dygraph,
+         }
+         if self.place is None:
+             self.check_grad(*check_args, **check_kwargs)
+         else:
+             check_args.insert(0, self.place)
+             self.check_grad_with_place(*check_args, **check_kwargs)
# elementwise_pow does't support bfloat16
def if_check_prim(self):
self.check_prim = False
......
@@ -115,7 +115,7 @@ class TestSumOp_bf16(OpTest):
def test_check_output(self):
place = core.CUDAPlace(0)
- self.check_output_with_place(place, check_eager=True)
+ self.check_output_with_place(place, check_eager=True, atol=0.1)
def test_check_grad(self):
place = core.CUDAPlace(0)
......
@@ -171,7 +171,12 @@ class TestScaleBF16Op(OpTest):
self.check_output(check_eager=True)
def test_check_grad(self):
- self.check_grad(['X'], 'Out', numeric_grad_delta=0.8, check_eager=True)
+ self.check_grad(
+     ['X'],
+     'Out',
+     numeric_grad_delta=0.8,
+     check_eager=True,
+ )
@unittest.skipIf(
......
@@ -356,9 +356,7 @@ class TestSumBF16Op(OpTest):
def test_check_grad(self):
# new dynamic graph mode does not support unit16 type
- self.check_grad(
-     ['x0'], 'Out', numeric_grad_delta=0.5, check_dygraph=False
- )
+ self.check_grad(['x0'], 'Out', check_dygraph=False)
class API_Test_Add_n(unittest.TestCase):
......
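The last hunk drops the numeric_grad_delta=0.5 override from TestSumBF16Op.test_check_grad. numeric_grad_delta controls the perturbation used when check_grad estimates gradients by finite differences; a generic central-difference sketch of what that step does (illustrative only, not Paddle's internal numeric-gradient code):

    import numpy as np

    # Central-difference gradient of a scalar function f at x, with step `delta`
    # playing the role of numeric_grad_delta.
    def numeric_grad(f, x, delta=0.005):
        x = np.asarray(x, dtype=np.float64)
        grad = np.zeros_like(x)
        for i in range(x.size):
            step = np.zeros_like(x)
            step.flat[i] = delta
            grad.flat[i] = (f(x + step) - f(x - step)) / (2.0 * delta)
        return grad

    # Example: the gradient of sum(x**2) is 2*x.
    print(numeric_grad(lambda v: np.sum(v**2), np.array([1.0, -2.0, 3.0])))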