未验证 提交 cefbf800 编写于 作者: T taixiurong 提交者: GitHub

xpu-paddlepaddle-30 [任务] dropout paddle单测, test=kunlun (#43716)

上级 d4b44015
......@@ -84,7 +84,7 @@ type_dict_str_to_numpy = {
xpu_test_op_white_list = []
xpu_test_type_white_list = ['float64']
xpu_test_op_type_white_list = []
xpu_test_op_type_white_list = ['dropout_float16', 'dropout_grad_float16']
xpu_test_device_op_white_list = []
xpu_test_device_op_type_white_list = []
......@@ -159,8 +159,10 @@ def make_xpu_op_list(xpu_version):
for op_type in type_list:
if op_type == paddle.bfloat16:
op_type = paddle.bfloat16
if op_type in type_white_list or op_type not in type_dict_paddle_to_str.keys(
):
if type_dict_paddle_to_str[
op_type] in type_white_list or op_type not in type_dict_paddle_to_str.keys(
):
continue
device_op_type_name = device_op_name + '_' + type_dict_paddle_to_str[
......@@ -187,10 +189,14 @@ def get_xpu_op_support_types(op_name, dev_id=0):
type_dict_paddle_to_str[paddle.bfloat16])
else:
support_type_str_list.append(type_dict_paddle_to_str[stype])
type_white_list = get_type_white_list()
return [
stype for stype in support_type_str_list if stype not in type_white_list
]
ops = make_xpu_op_list(xpu_version)
support_types = []
for stype in support_type_str_list:
op_name_type = op_name + "_" + stype
if op_name_type in ops:
support_types.append(stype)
return support_types
def record_op_test(op_name, test_type):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册