未验证 提交 153351e1 编写于 作者: Z Zhang Zheng 提交者: GitHub

[AMP OP&Test] Fix fp16 check_grad when user_defined_grads is not None (#51959)

* [AMP OP&Test] Fix fp16 check_grad when user_defined_grads are not None

* fix cond
上级 b74e00e1
...@@ -2407,21 +2407,7 @@ class OpTest(unittest.TestCase): ...@@ -2407,21 +2407,7 @@ class OpTest(unittest.TestCase):
if numeric_place is None: if numeric_place is None:
numeric_place = place numeric_place = place
numeric_grads = user_defined_grads or [ if user_defined_grads is None and self.is_fp16_compared_with_fp32():
get_numeric_gradient(
numeric_place,
self.scope,
self.op,
self.inputs,
input_to_check,
output_names,
delta=numeric_grad_delta,
in_place=in_place,
)
for input_to_check in inputs_to_check
]
if self.is_fp16_compared_with_fp32():
self.enable_cal_ref_output() self.enable_cal_ref_output()
numeric_grads = self._get_gradient( numeric_grads = self._get_gradient(
inputs_to_check, inputs_to_check,
...@@ -2431,6 +2417,20 @@ class OpTest(unittest.TestCase): ...@@ -2431,6 +2417,20 @@ class OpTest(unittest.TestCase):
user_defined_grad_outputs, user_defined_grad_outputs,
) )
self.disable_cal_ref_output() self.disable_cal_ref_output()
else:
numeric_grads = user_defined_grads or [
get_numeric_gradient(
numeric_place,
self.scope,
self.op,
self.inputs,
input_to_check,
output_names,
delta=numeric_grad_delta,
in_place=in_place,
)
for input_to_check in inputs_to_check
]
analytic_grads = self._get_gradient( analytic_grads = self._get_gradient(
inputs_to_check, inputs_to_check,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册