Unverified · Commit 6f30b14f · Authored by: Zhang Zheng · Committed by: GitHub

[AMP OP&Test] Modify the logic of comparing grad in bfloat16 (#51345)

* [AMP OP&Test] Modify the logic of comparing grad in bfloat16
Parent 04f56338
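In short: for bfloat16 kernels the gradient check now compares analytic and numeric gradients with np.testing.assert_allclose, using the test's max_relative_error as rtol plus an absolute tolerance atol, instead of rescaling np.abs(a) into a dynamic threshold. A minimal, self-contained sketch of that comparison (the gradient values and tolerances below are illustrative only, not taken from any real op):

import numpy as np

# Numeric (finite-difference) and analytic gradients; values are made up.
numeric_grad = np.array([0.5, 1.5, 3.0], dtype=np.float32)
analytic_grad = np.array([0.504, 1.49, 3.02], dtype=np.float32)

# Elementwise check: |analytic - numeric| <= atol + rtol * |numeric|.
np.testing.assert_allclose(
    analytic_grad,
    numeric_grad,
    rtol=1e-2,   # plays the role of max_relative_error
    atol=1e-1,   # absolute floor, clamped per dtype by this patch
    equal_nan=False,
)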
@@ -446,6 +446,26 @@ class OpTest(unittest.TestCase):
            )
        )

def is_float16_op(self):
# self.dtype is the dtype of inputs, and is set in infer_dtype_from_inputs_outputs.
# Make sure this function is called after calling infer_dtype_from_inputs_outputs.
return (
self.dtype == np.float16
or (
hasattr(self, 'output_dtype')
and self.output_dtype == np.float16
)
or (
hasattr(self, 'mkldnn_data_type')
and getattr(self, 'mkldnn_data_type') == "float16"
)
or (
hasattr(self, 'attrs')
and 'mkldnn_data_type' in self.attrs
and self.attrs['mkldnn_data_type'] == 'float16'
)
)
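
For a quick illustration of the three signals this helper accepts, here is a runnable stand-in (the _FakeOp class is hypothetical and only mimics the attributes an OpTest instance would carry):

import numpy as np

class _FakeOp:
    # Hypothetical stand-in for an OpTest instance, for illustration only.
    dtype = np.float16                          # signal: input dtype
    attrs = {'mkldnn_data_type': 'float16'}     # signal: oneDNN attrs dict

op = _FakeOp()
is_fp16 = (
    op.dtype == np.float16
    or (hasattr(op, 'output_dtype') and op.output_dtype == np.float16)  # signal: output dtype
    or (hasattr(op, 'mkldnn_data_type') and op.mkldnn_data_type == "float16")
    or (
        'mkldnn_data_type' in getattr(op, 'attrs', {})
        and op.attrs['mkldnn_data_type'] == 'float16'
    )
)
assert is_fp16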

    def is_mkldnn_op(self):
        return (hasattr(self, "use_mkldnn") and self.use_mkldnn) or (
            hasattr(self, "attrs")
@@ -1868,67 +1888,94 @@ class OpTest(unittest.TestCase):
        names,
        max_relative_error,
        msg_prefix,
        atol=1e-5,
    ):
        for a, b, name in zip(numeric_grads, analytic_grads, names):
            # Used by bfloat16 for now to solve precision problem
            if self.is_bfloat16_op():
                if a.size == 0:
                    self.assertTrue(b.size == 0)
                np.testing.assert_allclose(
                    b,
                    a,
                    rtol=max_relative_error,
                    atol=atol,
                    equal_nan=False,
                    err_msg=(
                        "Operator %s error, %s variable %s (shape: %s, dtype: %s) max gradient diff over limit"
                    )
                    % (
                        self.op_type,
                        msg_prefix,
                        name,
                        str(a.shape),
                        self.dtype,
                    ),
                )
            else:
                # It asserts np.abs(a - b) / np.abs(a) < max_relative_error, in which
                # max_relative_error is 1e-7. According to the value of np.abs(a), we
                # change np.abs(a) to achieve dynamic threshold. For example, if
                # the value of np.abs(a) is between 1e-10 and 1e-8, we set np.abs(a)*=1e4.
                # Therefore, it asserts np.abs(a - b) / (np.abs(a)*1e4) < max_relative_error,
                # which is the same as np.abs(a - b) / np.abs(a) < max_relative_error*1e4.
                abs_a = np.abs(a)
                if abs_a.ndim > 0:
                    if (
                        self.dtype == np.float64
                        and self.op_type
                        not in op_threshold_white_list.NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST
                    ):
                        abs_a[abs_a < 1e-10] = 1e-3
                        abs_a[
                            np.logical_and(abs_a > 1e-10, abs_a <= 1e-8)
                        ] *= 1e4
                        abs_a[
                            np.logical_and(abs_a > 1e-8, abs_a <= 1e-6)
                        ] *= 1e2
                    elif self.is_bfloat16_op():
                        abs_a[abs_a < 1e-2] = 1
                    else:
                        abs_a[abs_a < 1e-3] = 1
                elif abs_a.ndim == 0:
                    if (
                        self.dtype == np.float64
                        and self.op_type
                        not in op_threshold_white_list.NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST
                    ):
                        if abs_a < 1e-10:
                            abs_a = 1e-3
                        elif abs_a > 1e-10 and abs_a <= 1e-8:
                            abs_a = abs_a * 1e4
                        elif abs_a > 1e-8 and abs_a <= 1e-6:
                            abs_a = abs_a * 1e2
                    elif self.is_bfloat16_op():
                        abs_a = 1 if abs_a < 1e-2 else abs_a
                    else:
                        abs_a = 1 if abs_a < 1e-3 else abs_a

                diff_mat = np.abs(a - b) / abs_a
                max_diff = np.max(diff_mat)

                def err_msg():
                    offset = np.argmax(diff_mat > max_relative_error)
                    return (
                        "Operator %s error, %s variable %s (shape: %s, dtype: %s) max gradient diff %e over limit %e, "
                        "the first error element is %d, expected %e, but got %e."
                    ) % (
                        self.op_type,
                        msg_prefix,
                        name,
                        str(a.shape),
                        self.dtype,
                        max_diff,
                        max_relative_error,
                        offset,
                        a.flatten()[offset],
                        b.flatten()[offset],
                    )

                self.assertLessEqual(max_diff, max_relative_error, err_msg())

    def _check_grad_helper(self):
        self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
@@ -1950,6 +1997,7 @@ class OpTest(unittest.TestCase):
        check_dygraph=True,
        check_prim=False,
        only_check_prim=False,
        atol=1e-5,
    ):
        self._check_grad_helper()
        places = self._get_places()
@@ -1967,6 +2015,7 @@ class OpTest(unittest.TestCase):
                check_dygraph=check_dygraph,
                check_prim=check_prim,
                only_check_prim=only_check_prim,
                atol=atol,
            )

    def check_grad_with_place(
@@ -1984,6 +2033,7 @@ class OpTest(unittest.TestCase):
        check_prim=False,
        only_check_prim=False,
        numeric_place=None,
        atol=1e-5,
    ):
        core._set_prim_all_enabled(False)
        core.set_prim_eager_enabled(False)
@@ -2008,8 +2058,15 @@ class OpTest(unittest.TestCase):
        op_attrs = self.attrs if hasattr(self, "attrs") else dict()
        self._check_grad_helper()

        if self.is_bfloat16_op():
            if self.is_mkldnn_op():
                check_dygraph = False
                atol = 1e-2 if atol < 1e-2 else atol
            else:
                atol = 1e-1 if atol < 1e-1 else atol

        if self.is_float16_op():
            atol = 1e-3 if atol < 1e-3 else atol

        if (
            self.dtype == np.float64
@@ -2122,6 +2179,7 @@ class OpTest(unittest.TestCase):
            inputs_to_check,
            max_relative_error,
            "Gradient Check On %s" % str(place),
            atol=atol,
        )

        if check_dygraph:
@@ -2151,6 +2209,7 @@ class OpTest(unittest.TestCase):
                inputs_to_check,
                max_relative_error,
                "Gradient Check On %s" % str(place),
                atol=atol,
            )

    def _find_var_in_dygraph(self, output_vars, name):
...
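Taken together, the hunks above only ever loosen the absolute tolerance for low-precision kernels; a user-supplied atol is never tightened. A hedged sketch of that clamping rule (the helper name is hypothetical, not part of the patch):

def _clamp_atol(atol, is_bfloat16, is_mkldnn, is_float16):
    # bfloat16 oneDNN kernels get at least 1e-2, other bfloat16 kernels at
    # least 1e-1, and float16 kernels at least 1e-3, mirroring the diff above.
    if is_bfloat16:
        atol = max(atol, 1e-2 if is_mkldnn else 1e-1)
    if is_float16:
        atol = max(atol, 1e-3)
    return atol

assert _clamp_atol(1e-5, True, True, False) == 1e-2    # bfloat16 + oneDNN
assert _clamp_atol(1e-5, True, False, False) == 1e-1   # pure bfloat16
assert _clamp_atol(1e-5, False, False, True) == 1e-3   # float16
assert _clamp_atol(0.5, True, False, False) == 0.5     # looser user atol is kept

The remaining hunks below apply the same change to the OpTest variant that still supports check_eager.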
@@ -2115,71 +2115,98 @@ class OpTest(unittest.TestCase):
        names,
        max_relative_error,
        msg_prefix,
        atol=1e-5,
    ):
        for a, b, name in zip(numeric_grads, analytic_grads, names):
            # Used by bfloat16 for now to solve precision problem
            if self.is_bfloat16_op():
                if a.size == 0:
                    self.assertTrue(b.size == 0)
                np.testing.assert_allclose(
                    b,
                    a,
                    rtol=max_relative_error,
                    atol=atol,
                    equal_nan=False,
                    err_msg=(
                        "Operator %s error, %s variable %s (shape: %s, dtype: %s) max gradient diff over limit"
                    )
                    % (
                        self.op_type,
                        msg_prefix,
                        name,
                        str(a.shape),
                        self.dtype,
                    ),
                )
            else:
                # It asserts np.abs(a - b) / np.abs(a) < max_relative_error, in which
                # max_relative_error is 1e-7. According to the value of np.abs(a), we
                # change np.abs(a) to achieve dynamic threshold. For example, if
                # the value of np.abs(a) is between 1e-10 and 1e-8, we set np.abs(a)*=1e4.
                # Therefore, it asserts np.abs(a - b) / (np.abs(a)*1e4) < max_relative_error,
                # which is the same as np.abs(a - b) / np.abs(a) < max_relative_error*1e4.
                abs_a = np.abs(a)
                if abs_a.ndim > 0:
                    if (
                        self.dtype == np.float64
                        and self.op_type
                        not in op_threshold_white_list.NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST
                    ):
                        abs_a[abs_a < 1e-10] = 1e-3
                        abs_a[
                            np.logical_and(abs_a > 1e-10, abs_a <= 1e-8)
                        ] *= 1e4
                        abs_a[
                            np.logical_and(abs_a > 1e-8, abs_a <= 1e-6)
                        ] *= 1e2
                    elif self.is_bfloat16_op():
                        abs_a[abs_a < 1e-2] = 1
                    else:
                        abs_a[abs_a < 1e-3] = 1
                elif abs_a.ndim == 0:
                    if (
                        self.dtype == np.float64
                        and self.op_type
                        not in op_threshold_white_list.NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST
                    ):
                        if abs_a < 1e-10:
                            abs_a = 1e-3
                        elif abs_a > 1e-10 and abs_a <= 1e-8:
                            abs_a = abs_a * 1e4
                        elif abs_a > 1e-8 and abs_a <= 1e-6:
                            abs_a = abs_a * 1e2
                    elif self.is_bfloat16_op():
                        abs_a = 1 if abs_a < 1e-2 else abs_a
                    else:
                        abs_a = 1 if abs_a < 1e-3 else abs_a

                if self.dtype == np.bool:
                    diff_mat = np.abs(a ^ b) / abs_a
                else:
                    diff_mat = np.abs(a - b) / abs_a
                max_diff = np.max(diff_mat)

                def err_msg():
                    offset = np.argmax(diff_mat > max_relative_error)
                    return (
                        "Operator %s error, %s variable %s (shape: %s, dtype: %s) max gradient diff %e over limit %e, "
                        "the first error element is %d, expected %e, but got %e."
                    ) % (
                        self.op_type,
                        msg_prefix,
                        name,
                        str(a.shape),
                        self.dtype,
                        max_diff,
                        max_relative_error,
                        offset,
                        a.flatten()[offset],
                        b.flatten()[offset],
                    )

                self.assertLessEqual(max_diff, max_relative_error, err_msg())
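One detail kept only in this variant: boolean gradients are compared with XOR before the relative check, so any mismatching element shows up as a nonzero diff. A tiny illustration (the arrays are made up):

import numpy as np

a = np.array([True, False, True])
b = np.array([True, True, True])
# a ^ b is True exactly where the two arrays disagree; np.abs keeps it boolean,
# and the later division by abs_a (floored to 1) turns each mismatch into 1.0.
mismatch = np.abs(a ^ b)
print(mismatch.astype(np.int32))   # [0 1 0]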

    def _check_grad_helper(self):
        self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
@@ -2202,6 +2229,7 @@ class OpTest(unittest.TestCase):
        check_eager=False,
        check_prim=False,
        only_check_prim=False,
        atol=1e-5,
    ):
        # disable legacy dygraph check when check_eager is True
        if check_eager:
@@ -2224,6 +2252,7 @@ class OpTest(unittest.TestCase):
                check_eager=check_eager,
                check_prim=check_prim,
                only_check_prim=only_check_prim,
                atol=atol,
            )

    def check_grad_with_place(
@@ -2242,6 +2271,7 @@ class OpTest(unittest.TestCase):
        check_eager=False,
        check_prim=False,
        only_check_prim=False,
        atol=1e-5,
    ):
        core._set_prim_all_enabled(False)
        if check_prim:
@@ -2269,9 +2299,16 @@ class OpTest(unittest.TestCase):
        op_attrs = self.attrs if hasattr(self, "attrs") else dict()
        self._check_grad_helper()

        if self.is_bfloat16_op():
            if self.is_mkldnn_op():
                check_dygraph = False
                check_eager = False
                atol = 1e-2 if atol < 1e-2 else atol
            else:
                atol = 1e-1 if atol < 1e-1 else atol

        if self.is_float16_op():
            atol = 1e-3 if atol < 1e-3 else atol

        if (
            self.dtype == np.float64
@@ -2396,6 +2433,7 @@ class OpTest(unittest.TestCase):
            inputs_to_check,
            max_relative_error,
            "Gradient Check On %s" % str(place),
            atol=atol,
        )

        if check_dygraph:
@@ -2427,6 +2465,7 @@ class OpTest(unittest.TestCase):
                inputs_to_check,
                max_relative_error,
                "Gradient Check On %s" % str(place),
                atol=atol,
            )
            # ensure switch back eager dygraph
            g_disable_legacy_dygraph()
@@ -2459,6 +2498,7 @@ class OpTest(unittest.TestCase):
                inputs_to_check,
                max_relative_error,
                "Gradient Check On %s" % str(place),
                atol=atol,
            )

    def _find_var_in_dygraph(self, output_vars, name):
...