From 153351e150dd62f40ee7c664f905db6333105585 Mon Sep 17 00:00:00 2001 From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com> Date: Wed, 22 Mar 2023 16:00:30 +0800 Subject: [PATCH] [AMP OP&Test] Fix fp16 check_grad when user_defined_grads is not None (#51959) * [AMP OP&Test] Fix fp16 check_grad when user_defined_grads are not None * fix cond --- .../paddle/fluid/tests/unittests/op_test.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 01e574ef173..0c24ec5406d 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -2407,21 +2407,7 @@ class OpTest(unittest.TestCase): if numeric_place is None: numeric_place = place - numeric_grads = user_defined_grads or [ - get_numeric_gradient( - numeric_place, - self.scope, - self.op, - self.inputs, - input_to_check, - output_names, - delta=numeric_grad_delta, - in_place=in_place, - ) - for input_to_check in inputs_to_check - ] - - if self.is_fp16_compared_with_fp32(): + if user_defined_grads is None and self.is_fp16_compared_with_fp32(): self.enable_cal_ref_output() numeric_grads = self._get_gradient( inputs_to_check, @@ -2431,6 +2417,20 @@ class OpTest(unittest.TestCase): user_defined_grad_outputs, ) self.disable_cal_ref_output() + else: + numeric_grads = user_defined_grads or [ + get_numeric_gradient( + numeric_place, + self.scope, + self.op, + self.inputs, + input_to_check, + output_names, + delta=numeric_grad_delta, + in_place=in_place, + ) + for input_to_check in inputs_to_check + ] analytic_grads = self._get_gradient( inputs_to_check, -- GitLab