未验证 提交 6f30b14f 编写于 作者: Z Zhang Zheng 提交者: GitHub

[AMP OP&Test] Modify the logic of comparing grad in bfloat16 (#51345)

* [AMP OP&Test] Modify the logic of comparing grad in bfloat16
上级 04f56338
......@@ -446,6 +446,26 @@ class OpTest(unittest.TestCase):
)
)
def is_float16_op(self):
# self.dtype is the dtype of inputs, and is set in infer_dtype_from_inputs_outputs.
# Make sure this function is called after calling infer_dtype_from_inputs_outputs.
return (
self.dtype == np.float16
or (
hasattr(self, 'output_dtype')
and self.output_dtype == np.float16
)
or (
hasattr(self, 'mkldnn_data_type')
and getattr(self, 'mkldnn_data_type') == "float16"
)
or (
hasattr(self, 'attrs')
and 'mkldnn_data_type' in self.attrs
and self.attrs['mkldnn_data_type'] == 'float16'
)
)
def is_mkldnn_op(self):
return (hasattr(self, "use_mkldnn") and self.use_mkldnn) or (
hasattr(self, "attrs")
......@@ -1868,8 +1888,31 @@ class OpTest(unittest.TestCase):
names,
max_relative_error,
msg_prefix,
atol=1e-5,
):
for a, b, name in zip(numeric_grads, analytic_grads, names):
# Used by bfloat16 for now to solve precision problem
if self.is_bfloat16_op():
if a.size == 0:
self.assertTrue(b.size == 0)
np.testing.assert_allclose(
b,
a,
rtol=max_relative_error,
atol=atol,
equal_nan=False,
err_msg=(
"Operator %s error, %s variable %s (shape: %s, dtype: %s) max gradient diff over limit"
)
% (
self.op_type,
msg_prefix,
name,
str(a.shape),
self.dtype,
),
)
else:
# It asserts np.abs(a - b) / np.abs(a) < max_relative_error, in which
# max_relative_error is 1e-7. According to the value of np.abs(a), we
# change np.abs(a) to achieve dynamic threshold. For example, if
......@@ -1884,8 +1927,12 @@ class OpTest(unittest.TestCase):
not in op_threshold_white_list.NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST
):
abs_a[abs_a < 1e-10] = 1e-3
abs_a[np.logical_and(abs_a > 1e-10, abs_a <= 1e-8)] *= 1e4
abs_a[np.logical_and(abs_a > 1e-8, abs_a <= 1e-6)] *= 1e2
abs_a[
np.logical_and(abs_a > 1e-10, abs_a <= 1e-8)
] *= 1e4
abs_a[
np.logical_and(abs_a > 1e-8, abs_a <= 1e-6)
] *= 1e2
elif self.is_bfloat16_op():
abs_a[abs_a < 1e-2] = 1
else:
......@@ -1950,6 +1997,7 @@ class OpTest(unittest.TestCase):
check_dygraph=True,
check_prim=False,
only_check_prim=False,
atol=1e-5,
):
self._check_grad_helper()
places = self._get_places()
......@@ -1967,6 +2015,7 @@ class OpTest(unittest.TestCase):
check_dygraph=check_dygraph,
check_prim=check_prim,
only_check_prim=only_check_prim,
atol=atol,
)
def check_grad_with_place(
......@@ -1984,6 +2033,7 @@ class OpTest(unittest.TestCase):
check_prim=False,
only_check_prim=False,
numeric_place=None,
atol=1e-5,
):
core._set_prim_all_enabled(False)
core.set_prim_eager_enabled(False)
......@@ -2008,8 +2058,15 @@ class OpTest(unittest.TestCase):
op_attrs = self.attrs if hasattr(self, "attrs") else dict()
self._check_grad_helper()
if self.is_bfloat16_op() and self.is_mkldnn_op():
if self.is_bfloat16_op():
if self.is_mkldnn_op():
check_dygraph = False
atol = 1e-2 if atol < 1e-2 else atol
else:
atol = 1e-1 if atol < 1e-1 else atol
if self.is_float16_op():
atol = 1e-3 if atol < 1e-3 else atol
if (
self.dtype == np.float64
......@@ -2122,6 +2179,7 @@ class OpTest(unittest.TestCase):
inputs_to_check,
max_relative_error,
"Gradient Check On %s" % str(place),
atol=atol,
)
if check_dygraph:
......@@ -2151,6 +2209,7 @@ class OpTest(unittest.TestCase):
inputs_to_check,
max_relative_error,
"Gradient Check On %s" % str(place),
atol=atol,
)
def _find_var_in_dygraph(self, output_vars, name):
......
......@@ -2115,8 +2115,31 @@ class OpTest(unittest.TestCase):
names,
max_relative_error,
msg_prefix,
atol=1e-5,
):
for a, b, name in zip(numeric_grads, analytic_grads, names):
# Used by bfloat16 for now to solve precision problem
if self.is_bfloat16_op():
if a.size == 0:
self.assertTrue(b.size == 0)
np.testing.assert_allclose(
b,
a,
rtol=max_relative_error,
atol=atol,
equal_nan=False,
err_msg=(
"Operator %s error, %s variable %s (shape: %s, dtype: %s) max gradient diff over limit"
)
% (
self.op_type,
msg_prefix,
name,
str(a.shape),
self.dtype,
),
)
else:
# It asserts np.abs(a - b) / np.abs(a) < max_relative_error, in which
# max_relative_error is 1e-7. According to the value of np.abs(a), we
# change np.abs(a) to achieve dynamic threshold. For example, if
......@@ -2132,8 +2155,12 @@ class OpTest(unittest.TestCase):
not in op_threshold_white_list.NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST
):
abs_a[abs_a < 1e-10] = 1e-3
abs_a[np.logical_and(abs_a > 1e-10, abs_a <= 1e-8)] *= 1e4
abs_a[np.logical_and(abs_a > 1e-8, abs_a <= 1e-6)] *= 1e2
abs_a[
np.logical_and(abs_a > 1e-10, abs_a <= 1e-8)
] *= 1e4
abs_a[
np.logical_and(abs_a > 1e-8, abs_a <= 1e-6)
] *= 1e2
elif self.is_bfloat16_op():
abs_a[abs_a < 1e-2] = 1
else:
......@@ -2202,6 +2229,7 @@ class OpTest(unittest.TestCase):
check_eager=False,
check_prim=False,
only_check_prim=False,
atol=1e-5,
):
# disable legacy dygraph check when check_eager is True
if check_eager:
......@@ -2224,6 +2252,7 @@ class OpTest(unittest.TestCase):
check_eager=check_eager,
check_prim=check_prim,
only_check_prim=only_check_prim,
atol=atol,
)
def check_grad_with_place(
......@@ -2242,6 +2271,7 @@ class OpTest(unittest.TestCase):
check_eager=False,
check_prim=False,
only_check_prim=False,
atol=1e-5,
):
core._set_prim_all_enabled(False)
if check_prim:
......@@ -2269,9 +2299,16 @@ class OpTest(unittest.TestCase):
op_attrs = self.attrs if hasattr(self, "attrs") else dict()
self._check_grad_helper()
if self.is_bfloat16_op() and self.is_mkldnn_op():
if self.is_bfloat16_op():
if self.is_mkldnn_op():
check_dygraph = False
check_eager = False
atol = 1e-2 if atol < 1e-2 else atol
else:
atol = 1e-1 if atol < 1e-1 else atol
if self.is_float16_op():
atol = 1e-3 if atol < 1e-3 else atol
if (
self.dtype == np.float64
......@@ -2396,6 +2433,7 @@ class OpTest(unittest.TestCase):
inputs_to_check,
max_relative_error,
"Gradient Check On %s" % str(place),
atol=atol,
)
if check_dygraph:
......@@ -2427,6 +2465,7 @@ class OpTest(unittest.TestCase):
inputs_to_check,
max_relative_error,
"Gradient Check On %s" % str(place),
atol=atol,
)
# ensure switch back eager dygraph
g_disable_legacy_dygraph()
......@@ -2459,6 +2498,7 @@ class OpTest(unittest.TestCase):
inputs_to_check,
max_relative_error,
"Gradient Check On %s" % str(place),
atol=atol,
)
def _find_var_in_dygraph(self, output_vars, name):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册