未验证 提交 d1a98f0b 编写于 作者: T TTerror 提交者: GitHub

fix xpu op test, *test=kunlun (#40409)

上级 dce87e3d
...@@ -123,17 +123,26 @@ class XPUOpTest(OpTest): ...@@ -123,17 +123,26 @@ class XPUOpTest(OpTest):
return super().check_grad_with_place( return super().check_grad_with_place(
place, inputs_to_check, output_names, no_grad_set, place, inputs_to_check, output_names, no_grad_set,
numeric_grad_delta, in_place, max_relative_error, numeric_grad_delta, in_place, max_relative_error,
user_defined_grads, user_defined_grads, check_dygraph) user_defined_grads, user_defined_grad_outputs, check_dygraph)
a1 = self.get_grad_with_place( a1 = self.get_grad_with_place(
place, inputs_to_check, output_names, no_grad_set=no_grad_set) place,
inputs_to_check,
output_names,
no_grad_set=no_grad_set,
user_defined_grad_outputs=user_defined_grad_outputs)
a2 = self.get_grad_with_place( a2 = self.get_grad_with_place(
place, inputs_to_check, output_names, no_grad_set=no_grad_set) place,
inputs_to_check,
output_names,
no_grad_set=no_grad_set,
user_defined_grad_outputs=user_defined_grad_outputs)
a3 = self.get_grad_with_place( a3 = self.get_grad_with_place(
paddle.CPUPlace(), paddle.CPUPlace(),
inputs_to_check, inputs_to_check,
output_names, output_names,
no_grad_set=no_grad_set) no_grad_set=no_grad_set,
user_defined_grad_outputs=user_defined_grad_outputs)
self._assert_is_close(a1, a2, inputs_to_check, 0.00000001, self._assert_is_close(a1, a2, inputs_to_check, 0.00000001,
"Gradient Check On two xpu") "Gradient Check On two xpu")
self._assert_is_close(a1, a3, inputs_to_check, max_relative_error, self._assert_is_close(a1, a3, inputs_to_check, max_relative_error,
...@@ -147,7 +156,7 @@ class XPUOpTest(OpTest): ...@@ -147,7 +156,7 @@ class XPUOpTest(OpTest):
numeric_grad_delta=0.005, numeric_grad_delta=0.005,
in_place=False, in_place=False,
max_relative_error=0.005, max_relative_error=0.005,
user_defined_grads=None, user_defined_grad_outputs=None,
check_dygraph=True): check_dygraph=True):
self.scope = core.Scope() self.scope = core.Scope()
op_inputs = self.inputs if hasattr(self, "inputs") else dict() op_inputs = self.inputs if hasattr(self, "inputs") else dict()
...@@ -197,6 +206,10 @@ class XPUOpTest(OpTest): ...@@ -197,6 +206,10 @@ class XPUOpTest(OpTest):
if not type(output_names) is list: if not type(output_names) is list:
output_names = [output_names] output_names = [output_names]
analytic_grads = self._get_gradient(inputs_to_check, place, analytic_grads = self._get_gradient(
output_names, no_grad_set) inputs_to_check,
place,
output_names,
no_grad_set,
user_defined_grad_outputs=user_defined_grad_outputs)
return analytic_grads return analytic_grads
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册