diff --git a/python/paddle/fluid/tests/unittests/op_test_xpu.py b/python/paddle/fluid/tests/unittests/op_test_xpu.py index 50ea065209422d2c972e480fbbd9a9442b5e5c25..6c964a828eed7eb01bce68b81baab61c66c5cf43 100644 --- a/python/paddle/fluid/tests/unittests/op_test_xpu.py +++ b/python/paddle/fluid/tests/unittests/op_test_xpu.py @@ -123,17 +123,26 @@ class XPUOpTest(OpTest): return super().check_grad_with_place( place, inputs_to_check, output_names, no_grad_set, numeric_grad_delta, in_place, max_relative_error, - user_defined_grads, user_defined_grads, check_dygraph) + user_defined_grads, user_defined_grad_outputs, check_dygraph) a1 = self.get_grad_with_place( - place, inputs_to_check, output_names, no_grad_set=no_grad_set) + place, + inputs_to_check, + output_names, + no_grad_set=no_grad_set, + user_defined_grad_outputs=user_defined_grad_outputs) a2 = self.get_grad_with_place( - place, inputs_to_check, output_names, no_grad_set=no_grad_set) + place, + inputs_to_check, + output_names, + no_grad_set=no_grad_set, + user_defined_grad_outputs=user_defined_grad_outputs) a3 = self.get_grad_with_place( paddle.CPUPlace(), inputs_to_check, output_names, - no_grad_set=no_grad_set) + no_grad_set=no_grad_set, + user_defined_grad_outputs=user_defined_grad_outputs) self._assert_is_close(a1, a2, inputs_to_check, 0.00000001, "Gradient Check On two xpu") self._assert_is_close(a1, a3, inputs_to_check, max_relative_error, @@ -147,7 +156,7 @@ class XPUOpTest(OpTest): numeric_grad_delta=0.005, in_place=False, max_relative_error=0.005, - user_defined_grads=None, + user_defined_grad_outputs=None, check_dygraph=True): self.scope = core.Scope() op_inputs = self.inputs if hasattr(self, "inputs") else dict() @@ -197,6 +206,10 @@ class XPUOpTest(OpTest): if not type(output_names) is list: output_names = [output_names] - analytic_grads = self._get_gradient(inputs_to_check, place, - output_names, no_grad_set) + analytic_grads = self._get_gradient( + inputs_to_check, + place, + output_names, + no_grad_set, + user_defined_grad_outputs=user_defined_grad_outputs) return analytic_grads