未验证 提交 153351e1 编写于 作者: Z Zhang Zheng 提交者: GitHub

[AMP OP&Test] Fix fp16 check_grad when user_defined_grads is not None (#51959)

* [AMP OP&Test] Fix fp16 check_grad when user_defined_grads are not None

* fix cond
上级 b74e00e1
...@@ -2407,6 +2407,17 @@ class OpTest(unittest.TestCase): ...@@ -2407,6 +2407,17 @@ class OpTest(unittest.TestCase):
if numeric_place is None: if numeric_place is None:
numeric_place = place numeric_place = place
if user_defined_grads is None and self.is_fp16_compared_with_fp32():
self.enable_cal_ref_output()
numeric_grads = self._get_gradient(
inputs_to_check,
place,
output_names,
no_grad_set,
user_defined_grad_outputs,
)
self.disable_cal_ref_output()
else:
numeric_grads = user_defined_grads or [ numeric_grads = user_defined_grads or [
get_numeric_gradient( get_numeric_gradient(
numeric_place, numeric_place,
...@@ -2421,17 +2432,6 @@ class OpTest(unittest.TestCase): ...@@ -2421,17 +2432,6 @@ class OpTest(unittest.TestCase):
for input_to_check in inputs_to_check for input_to_check in inputs_to_check
] ]
if self.is_fp16_compared_with_fp32():
self.enable_cal_ref_output()
numeric_grads = self._get_gradient(
inputs_to_check,
place,
output_names,
no_grad_set,
user_defined_grad_outputs,
)
self.disable_cal_ref_output()
analytic_grads = self._get_gradient( analytic_grads = self._get_gradient(
inputs_to_check, inputs_to_check,
place, place,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册