未验证 提交 d1a98f0b 编写于 作者: T TTerror 提交者: GitHub

fix xpu op test, *test=kunlun (#40409)

上级 dce87e3d
......@@ -123,17 +123,26 @@ class XPUOpTest(OpTest):
return super().check_grad_with_place(
place, inputs_to_check, output_names, no_grad_set,
numeric_grad_delta, in_place, max_relative_error,
user_defined_grads, user_defined_grads, check_dygraph)
user_defined_grads, user_defined_grad_outputs, check_dygraph)
a1 = self.get_grad_with_place(
place, inputs_to_check, output_names, no_grad_set=no_grad_set)
place,
inputs_to_check,
output_names,
no_grad_set=no_grad_set,
user_defined_grad_outputs=user_defined_grad_outputs)
a2 = self.get_grad_with_place(
place, inputs_to_check, output_names, no_grad_set=no_grad_set)
place,
inputs_to_check,
output_names,
no_grad_set=no_grad_set,
user_defined_grad_outputs=user_defined_grad_outputs)
a3 = self.get_grad_with_place(
paddle.CPUPlace(),
inputs_to_check,
output_names,
no_grad_set=no_grad_set)
no_grad_set=no_grad_set,
user_defined_grad_outputs=user_defined_grad_outputs)
self._assert_is_close(a1, a2, inputs_to_check, 0.00000001,
"Gradient Check On two xpu")
self._assert_is_close(a1, a3, inputs_to_check, max_relative_error,
......@@ -147,7 +156,7 @@ class XPUOpTest(OpTest):
numeric_grad_delta=0.005,
in_place=False,
max_relative_error=0.005,
user_defined_grads=None,
user_defined_grad_outputs=None,
check_dygraph=True):
self.scope = core.Scope()
op_inputs = self.inputs if hasattr(self, "inputs") else dict()
......@@ -197,6 +206,10 @@ class XPUOpTest(OpTest):
if not type(output_names) is list:
output_names = [output_names]
analytic_grads = self._get_gradient(inputs_to_check, place,
output_names, no_grad_set)
analytic_grads = self._get_gradient(
inputs_to_check,
place,
output_names,
no_grad_set,
user_defined_grad_outputs=user_defined_grad_outputs)
return analytic_grads
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册