未验证 提交 1db9cd46 编写于 作者: T TTerror 提交者: GitHub

fix xpu op test, *test=kunlun (#40862)

上级 d43e8433
......@@ -38,6 +38,7 @@ from paddle.fluid import unique_name
from white_list import op_accuracy_white_list, check_shape_white_list, compile_vs_runtime_white_list, no_check_set_white_list
from white_list import op_threshold_white_list, no_grad_set_white_list
from op_test import OpTest, _set_use_system_allocator, get_numeric_gradient
from xpu.get_test_cover_info import is_empty_grad_op_type
class XPUOpTest(OpTest):
......@@ -108,6 +109,13 @@ class XPUOpTest(OpTest):
check_dygraph=True,
numeric_place=None,
check_eager=False):
if hasattr(self, 'op_type_need_check_grad'):
xpu_version = core.get_xpu_device_version(0)
if is_empty_grad_op_type(xpu_version, self.op_type,
self.in_type_str):
self._check_grad_helper()
return
if place == None:
place = paddle.XPUPlace(0)
......
......@@ -208,7 +208,7 @@ def is_empty_grad_op_type(xpu_version, op, test_type):
if grad_op not in xpu_op_list.keys():
return True
grad_op_types = xpu_op_list[op]
grad_op_types = xpu_op_list[grad_op]
paddle_test_type = type_dict_str_to_paddle[test_type]
if paddle_test_type not in grad_op_types:
return True
......@@ -239,9 +239,11 @@ def create_test_class(func_globals,
continue
class_obj = test_class[1]
cls_name = "{0}_{1}".format(test_class[0], str(test_type))
func_globals[cls_name] = type(
cls_name, (class_obj, ),
{'in_type': type_dict_str_to_numpy[test_type]})
func_globals[cls_name] = type(cls_name, (class_obj, ), {
'in_type': type_dict_str_to_numpy[test_type],
'in_type_str': test_type,
'op_type_need_check_grad': True
})
if hasattr(test_class_obj, 'use_dynamic_create_class'
) and test_class_obj.use_dynamic_create_class:
......@@ -250,6 +252,8 @@ def create_test_class(func_globals,
cls_name = "{0}_{1}".format(dy_class[0], str(test_type))
attr_dict = dy_class[1]
attr_dict['in_type'] = type_dict_str_to_numpy[test_type]
attr_dict['in_type_str'] = test_type
attr_dict['op_type_need_check_grad'] = True
func_globals[cls_name] = type(cls_name, (base_class, ), attr_dict)
record_op_test(op_name, test_type)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册