From cefbf8000d440759f8ff374320580308e6f83e3f Mon Sep 17 00:00:00 2001 From: taixiurong Date: Thu, 23 Jun 2022 17:41:20 +0800 Subject: [PATCH] =?UTF-8?q?xpu-paddlepaddle-30=20[=E4=BB=BB=E5=8A=A1]=20dr?= =?UTF-8?q?opout=20paddle=E5=8D=95=E6=B5=8B,=20test=3Dkunlun=20(#43716)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../unittests/xpu/get_test_cover_info.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py b/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py index 33a8482346..0c3056ca8a 100644 --- a/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py +++ b/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py @@ -84,7 +84,7 @@ type_dict_str_to_numpy = { xpu_test_op_white_list = [] xpu_test_type_white_list = ['float64'] -xpu_test_op_type_white_list = [] +xpu_test_op_type_white_list = ['dropout_float16', 'dropout_grad_float16'] xpu_test_device_op_white_list = [] xpu_test_device_op_type_white_list = [] @@ -159,8 +159,10 @@ def make_xpu_op_list(xpu_version): for op_type in type_list: if op_type == paddle.bfloat16: op_type = paddle.bfloat16 - if op_type in type_white_list or op_type not in type_dict_paddle_to_str.keys( - ): + + if type_dict_paddle_to_str[ + op_type] in type_white_list or op_type not in type_dict_paddle_to_str.keys( + ): continue device_op_type_name = device_op_name + '_' + type_dict_paddle_to_str[ @@ -187,10 +189,14 @@ def get_xpu_op_support_types(op_name, dev_id=0): type_dict_paddle_to_str[paddle.bfloat16]) else: support_type_str_list.append(type_dict_paddle_to_str[stype]) - type_white_list = get_type_white_list() - return [ - stype for stype in support_type_str_list if stype not in type_white_list - ] + ops = make_xpu_op_list(xpu_version) + support_types = [] + for stype in support_type_str_list: + op_name_type = op_name + "_" + stype + if op_name_type in ops: + support_types.append(stype) + + return support_types def record_op_test(op_name, test_type): -- GitLab