未验证 提交 4f1fef60 编写于 作者: T TTerror 提交者: GitHub

refactor unittest for kunlun (#38772)

* refactor unittests for kunlun

* refactor unittests for kunlun, test=kunlun
上级 4e23ba32
......@@ -74,6 +74,32 @@ bool is_in_xpu_black_list(const std::string& op_name) {
return false;
}
std::vector<vartype::Type> get_xpu_op_support_type(const std::string& op_name,
XPUVersion version) {
std::vector<vartype::Type> res;
auto& ops = version == XPU1 ? get_kl1_ops() : get_kl2_ops();
if (ops.find(op_name) != ops.end()) {
XPUKernelSet& type_set = ops[op_name];
for (auto& item : type_set) {
res.push_back(item.data_type_);
}
}
return res;
}
XPUOpListMap get_xpu_op_list(XPUVersion version) {
XPUOpListMap res;
auto& ops = version == XPU1 ? get_kl1_ops() : get_kl2_ops();
for (auto& op : ops) {
std::vector<vartype::Type> op_vartypes;
for (auto& item : op.second) {
op_vartypes.push_back(item.data_type_);
}
res[op.first] = std::move(op_vartypes);
}
return res;
}
} // namespace platform
} // namespace paddle
#endif
......@@ -12,6 +12,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/op_kernel_type.h"
......@@ -19,10 +20,17 @@ namespace paddle {
namespace platform {
using pOpKernelType = paddle::framework::OpKernelType;
using vartype = paddle::framework::proto::VarType;
using XPUOpListMap =
std::unordered_map<std::string, std::vector<vartype::Type>>;
bool is_xpu_support_op(const std::string& op_name, const pOpKernelType& type);
bool is_in_xpu_black_list(const std::string& op_name);
std::vector<vartype::Type> get_xpu_op_support_type(const std::string& op_name,
XPUVersion version);
XPUOpListMap get_xpu_op_list(XPUVersion version);
} // namespace platform
} // namespace paddle
#endif
......@@ -129,6 +129,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/device/xpu/xpu_info.h"
#include "paddle/fluid/platform/device/xpu/xpu_op_list.h"
#endif
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
......@@ -1762,6 +1763,13 @@ All parameter, weight, gradient are variables in Paddle.
m.def("get_xpu_device_count", platform::GetXPUDeviceCount);
m.def("get_xpu_device_version",
[](int device_id) { return platform::get_xpu_version(device_id); });
m.def("get_xpu_device_op_support_types",
[](const std::string &op_name, platform::XPUVersion version) {
return platform::get_xpu_op_support_type(op_name, version);
});
m.def("get_xpu_device_op_list", [](platform::XPUVersion version) {
return platform::get_xpu_op_list(version);
});
m.def("is_float16_supported", [](const platform::XPUPlace &place) -> bool {
// XPUs with Compute Capability > xpu2 support float16 and bfloat16
return platform::get_xpu_version(place.device) > platform::XPUVersion::XPU1;
......
......@@ -1725,6 +1725,7 @@ function parallel_test_base_xpu() {
EOF
set +x
export XPU_OP_LIST_DIR=$tmp_dir
ut_startTime_s=`date +%s`
test_cases=$(ctest -N -V | grep "_xpu" ) # cases list which would be run exclusively
get_quickly_disable_ut||disable_ut_quickly='disable_ut' # indicate whether the case was in quickly disable list
......@@ -1747,6 +1748,8 @@ set -x
if [[ "$EXIT_CODE" != "0" ]]; then
exit 8;
fi
python ${PADDLE_ROOT}/build/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py
unset XPU_OP_LIST_DIR
fi
}
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import inspect
import os
import fcntl
import paddle
import paddle.fluid.core as core
type_dict_paddle_to_str = {
paddle.bool: 'bool',
paddle.uint8: 'uint8',
paddle.int8: 'int8',
paddle.int16: 'int16',
paddle.int32: 'int32',
paddle.int64: 'int64',
paddle.float16: 'float16',
paddle.float32: 'float32',
paddle.float64: 'float64',
paddle.complex128: 'complex128',
paddle.complex64: 'complex64',
}
type_dict_str_to_paddle = {
'int32': paddle.int32,
'int64': paddle.int64,
'float32': paddle.float32,
'float16': paddle.float16,
'bool': paddle.bool,
'uint8': paddle.uint8,
'int8': paddle.int8,
'complex128': paddle.complex128,
'complex64': paddle.complex64,
'int16': paddle.int16,
}
xpu_test_op_white_list = []
xpu_test_type_white_list = []
xpu_test_op_type_white_list = []
xpu_test_device_op_white_list = []
xpu_test_device_op_type_white_list = []
class XPUOpTestWrapper(object):
def create_classes(self):
base_class = None
classes = []
return base_class, classes
def get_op_white_list():
op_white_list = xpu_test_op_white_list
if os.getenv('XPU_TEST_OP_WHITE_LIST') is not None:
op_white_list.extend(
os.getenv('XPU_TEST_OP_WHITE_LIST').strip().split(','))
return list(set(op_white_list))
def get_type_white_list():
type_white_list = xpu_test_type_white_list
if os.getenv('XPU_TEST_TYPE_WHITE_LIST') is not None:
type_white_list.extend(
os.getenv('XPU_TEST_TYPE_WHITE_LIST').strip().split(','))
return list(set(type_white_list))
def get_op_type_white_list():
op_type_white_list = xpu_test_op_type_white_list
if os.getenv('XPU_TEST_OP_TYPE_WHITE_LIST') is not None:
op_type_white_list.extend(
os.getenv('XPU_TEST_OP_TYPE_WHITE_LIST').strip().split(','))
return list(set(op_type_white_list))
def get_device_op_white_list():
device_op_white_list = xpu_test_device_op_white_list
if os.getenv('XPU_TEST_DEVICE_OP_WHITE_LIST') is not None:
device_op_white_list.extend(
os.getenv('XPU_TEST_DEVICE_OP_WHITE_LIST').strip().split(','))
return list(set(device_op_white_list))
def get_device_op_type_white_list():
device_op_type_white_list = xpu_test_device_op_type_white_list
if os.getenv('XPU_TEST_DEVICE_OP_TYPE_WHITE_LIST') is not None:
device_op_type_white_list.extend(
os.getenv('XPU_TEST_DEVICE_OP_TYPE_WHITE_LIST').strip().split(','))
return list(set(device_op_type_white_list))
def make_xpu_op_list(xpu_version):
ops = []
raw_op_list = core.get_xpu_device_op_list(xpu_version)
version_str = "xpu2" if xpu_version == core.XPUVersion.XPU2 else "xpu1"
op_white_list = get_op_white_list()
type_white_list = get_type_white_list()
op_type_white_list = get_op_type_white_list()
device_op_white_list = get_device_op_white_list()
device_op_type_white_list = get_device_op_type_white_list()
print('op_white_list:', op_white_list)
print('type_white_list:', type_white_list)
print('op_type_white_list:', op_type_white_list)
print('device_op_white_list:', device_op_white_list)
print('device_op_type_white_list:', device_op_type_white_list)
for op_name, type_list in raw_op_list.items():
device_op_name = version_str + '_' + op_name
if op_name in op_white_list or device_op_name in device_op_white_list:
continue
for op_type in type_list:
if op_type in type_white_list or op_type not in type_dict_paddle_to_str.keys(
):
continue
device_op_type_name = device_op_name + '_' + type_dict_paddle_to_str[
op_type]
if device_op_type_name in device_op_type_white_list:
continue
op_type_name = op_name + '_' + type_dict_paddle_to_str[op_type]
if op_type_name in op_type_white_list:
continue
ops.append(op_type_name)
return ops
def get_xpu_op_support_types(op_name, dev_id=0):
xpu_version = core.get_xpu_device_version(dev_id)
support_type_list = core.get_xpu_device_op_support_types(op_name,
xpu_version)
support_type_str_list = [
type_dict_paddle_to_str[x] for x in support_type_list
]
return support_type_str_list
def record_op_test(op_name, test_type):
dirname = os.getenv('XPU_OP_LIST_DIR')
filename = 'xpu_op_test'
if dirname is not None:
filename = os.path.join(dirname, filename)
with open(filename, 'a') as f:
fcntl.flock(f, fcntl.LOCK_EX)
f.write(op_name + '_' + test_type + '\n')
def is_empty_grad_op_type(xpu_version, op, test_type):
xpu_op_list = core.get_xpu_device_op_list(xpu_version)
grad_op = op + '_grad'
if grad_op not in xpu_op_list.keys():
return True
grad_op_types = xpu_op_list[op]
paddle_test_type = type_dict_str_to_paddle[test_type]
if paddle_test_type not in grad_op_types:
return True
return False
def create_test_class(func_globals,
test_class,
test_type,
test_grad=True,
ignore_deivce_version=[],
test_deivce_version=[]):
xpu_version = core.get_xpu_device_version(0)
if xpu_version in ignore_deivce_version:
return
if len(test_deivce_version) != 0 and xpu_version not in test_deivce_version:
return
test_class_obj = test_class()
register_classes = inspect.getmembers(test_class_obj, inspect.isclass)
op_name = test_class_obj.op_name
no_grad = is_empty_grad_op_type(xpu_version, op_name, test_type)
for test_class in register_classes:
if test_class[0] == '__class__':
continue
class_obj = test_class[1]
cls_name = "{0}_{1}".format(test_class[0], str(test_type))
func_globals[cls_name] = type(cls_name, (class_obj, ),
{'in_type': test_type})
if hasattr(test_class_obj, 'use_dynamic_create_class'
) and test_class_obj.use_dynamic_create_class:
base_class, dynamic_classes = test_class_obj.dynamic_create_class()
for dy_class in dynamic_classes:
cls_name = "{0}_{1}".format(dy_class[0], str(test_type))
attr_dict = dy_class[1]
attr_dict['in_type'] = test_type
func_globals[cls_name] = type(cls_name, (base_class, ), attr_dict)
record_op_test(op_name, test_type)
if not no_grad:
record_op_test(op_name + '_grad', test_type)
def get_test_cover_info():
xpu_version = core.get_xpu_device_version(0)
version_str = "xpu2" if xpu_version == core.XPUVersion.XPU2 else "xpu1"
xpu_op_list = make_xpu_op_list(xpu_version)
xpu_op_covered = []
dirname = os.getenv('XPU_OP_LIST_DIR')
filename = 'xpu_op_test'
if dirname is not None:
filename = os.path.join(dirname, filename)
if os.path.exists(filename) and os.path.isfile(filename):
with open(filename) as f:
for line in f:
test_op_name = line.strip()
if test_op_name in xpu_op_list:
xpu_op_covered.append(test_op_name)
diff_list = list(set(xpu_op_list).difference(set(xpu_op_covered)))
total_len = len(set(xpu_op_list))
covered_len = len(set(xpu_op_covered))
print('{} test: {}/{}'.format(version_str, covered_len, total_len))
if (len(diff_list) != 0):
print("These ops need to be tested on {0}! ops:{1}".format(
version_str, ','.join(diff_list)))
if __name__ == '__main__':
get_test_cover_info()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid import compiler, Program, program_guard
import op_test
from op_test import OpTest, skip_check_grad_ci
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static()
def huber_loss_forward(val, delta):
abs_val = abs(val)
if abs_val <= delta:
return 0.5 * val * val
else:
return delta * (abs_val - 0.5 * delta)
# 1.动态生成不同参数的测试case,wrapper类中必须实现dynamic_create_class方法
# self.use_dynamic_create_class置为True
class XPUTestArgsortOp1(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'argsort'
self.use_dynamic_create_class = True
def dynamic_create_class(self):
base_class = self.TestArgsortOp
classes = []
for descending in [True, False]:
for axis in [0, 1, 2, -1, -2]:
class_name = 'XPUTestArgsortOp_axis_' + str(axis)
attr_dict = {'init_axis': axis, 'descending': descending}
classes.append([class_name, attr_dict])
return base_class, classes
class TestArgsortOp(XPUOpTest):
def setUp(self):
self.set_xpu()
self.op_type = "argsort"
self.place = paddle.XPUPlace(0)
self.dtype = self.in_type
self.input_shape = (2, 2, 2, 3, 3)
self.axis = -1
self.descending = False
if self.in_type == 'float32':
self.x = np.random.random(self.input_shape).astype(self.dtype)
else:
self.x = np.random.randint(
low=-1000, high=1000,
size=self.input_shape).astype(self.dtype)
self.inputs = {"X": self.x}
self.attrs = {"axis": self.axis, "descending": self.descending}
self.get_output()
self.outputs = {"Out": self.sorted_x, "Indices": self.indices}
def get_output(self):
if self.descending:
self.indices = np.flip(
np.argsort(
self.x, kind='heapsort', axis=self.axis),
self.axis)
self.sorted_x = np.flip(
np.sort(
self.x, kind='heapsort', axis=self.axis), self.axis)
else:
self.indices = np.argsort(
self.x, kind='heapsort', axis=self.axis)
self.sorted_x = np.sort(self.x, kind='heapsort', axis=self.axis)
def set_xpu(self):
self.__class__.use_xpu = True
self.__class__.no_need_check_grad = True
def test_check_output(self):
self.check_output_with_place(self.place)
# 2. 为不同参数的测试case定义一个测试类,self.use_dynamic_create_class需要置为False
class XPUTestArgsortOp2(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'argsort'
self.use_dynamic_create_class = False
class TestArgsortOp(XPUOpTest):
def setUp(self):
self.set_xpu()
self.op_type = "argsort"
self.place = paddle.XPUPlace(0)
self.init_dtype()
self.init_inputshape()
self.init_axis()
self.init_direction()
if self.in_type == 'float32':
self.x = np.random.random(self.input_shape).astype(self.dtype)
else:
self.x = np.random.randint(
low=-1000, high=1000,
size=self.input_shape).astype(self.dtype)
self.inputs = {"X": self.x}
self.attrs = {"axis": self.axis, "descending": self.descending}
self.get_output()
self.outputs = {"Out": self.sorted_x, "Indices": self.indices}
def get_output(self):
if self.descending:
self.indices = np.flip(
np.argsort(
self.x, kind='heapsort', axis=self.axis),
self.axis)
self.sorted_x = np.flip(
np.sort(
self.x, kind='heapsort', axis=self.axis), self.axis)
else:
self.indices = np.argsort(
self.x, kind='heapsort', axis=self.axis)
self.sorted_x = np.sort(self.x, kind='heapsort', axis=self.axis)
def set_xpu(self):
self.__class__.use_xpu = True
self.__class__.no_need_check_grad = True
def init_inputshape(self):
self.input_shape = (2, 2, 2, 3, 3)
def init_dtype(self):
self.dtype = self.in_type
def init_axis(self):
self.axis = -1
def test_check_output(self):
self.check_output_with_place(self.place)
def init_direction(self):
self.descending = False
class TestArgsortOpAxis0XPU(TestArgsortOp):
def init_axis(self):
self.axis = 0
class TestArgsortOpAxis1XPU(TestArgsortOp):
def init_axis(self):
self.axis = 1
class TestArgsortOpAxis2XPU(TestArgsortOp):
def init_axis(self):
self.axis = 2
class TestArgsortOpAxisNeg1XPU(TestArgsortOp):
def init_axis(self):
self.axis = -1
class TestArgsortOpAxisNeg2XPU(TestArgsortOp):
def init_axis(self):
self.axis = -2
class TestArgsortOpDescendingAxisXPU(TestArgsortOp):
def init_direction(self):
self.descending = True
class TestArgsortOpDescendingAxis0XPU(TestArgsortOpAxis0XPU):
def init_direction(self):
self.descending = True
class TestArgsortOpDescendingAxis1XPU(TestArgsortOpAxis1XPU):
def init_direction(self):
self.descending = True
class TestArgsortOpDescendingAxis2XPU(TestArgsortOpAxis2XPU):
def init_direction(self):
self.descending = True
class TestArgsortOpDescendingAxisNeg1XPU(TestArgsortOpAxisNeg1XPU):
def init_direction(self):
self.descending = True
class TestArgsortOpDescendingAxisNeg2XPU(TestArgsortOpAxisNeg2XPU):
def init_direction(self):
self.descending = True
support_types = get_xpu_op_support_types('argsort')
for stype in support_types:
create_test_class(globals(), XPUTestArgsortOp1, stype)
create_test_class(globals(), XPUTestArgsortOp2, stype)
class XPUTestHuberLossOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'huber_loss'
self.use_dynamic_create_class = False
class TestHuberLossOp(XPUOpTest):
def setUp(self):
self.set_xpu()
self.op_type = 'huber_loss'
self.place = paddle.XPUPlace(0)
self.init_dtype()
self.set_inputs()
self.set_attrs()
self.set_outputs()
def set_inputs(self):
shape = self.set_shape()
x = np.random.uniform(0, 1., shape).astype(self.dtype)
y = np.random.uniform(0, 1., shape).astype(self.dtype)
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(x),
'Y': OpTest.np_dtype_to_fluid_dtype(y)
}
def set_attrs(self):
self.attrs = {'delta': 0.5}
def set_outputs(self):
delta = self.attrs['delta']
shape = self.set_shape()
residual = self.inputs['Y'] - self.inputs['X']
loss = np.vectorize(huber_loss_forward)(residual,
delta).astype(self.dtype)
self.outputs = {'Residual': residual, 'Out': loss.reshape(shape)}
def set_shape(self):
return (100, 1)
def set_xpu(self):
self.__class__.use_xpu = True
def init_dtype(self):
self.dtype = self.in_type
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['X', 'Y'], 'Out')
def test_check_grad_ingore_x(self):
self.check_grad_with_place(
self.place, ['Y'], 'Out', no_grad_set=set("residual"))
def test_check_grad_ingore_y(self):
self.check_grad_with_place(
self.place, ['X'], 'Out', no_grad_set=set('residual'))
class TestHuberLossOp1(TestHuberLossOp):
def set_shape(self):
return (640)
class TestHuberLossOp2(TestHuberLossOp):
def set_shape(self):
return (10, 10)
class TestHuberLossOp3(TestHuberLossOp):
def set_shape(self):
return (10, 10, 1)
support_types = get_xpu_op_support_types('huber_loss')
for stype in support_types:
create_test_class(globals(), XPUTestHuberLossOp, stype)
create_test_class(
globals(),
XPUTestHuberLossOp,
stype,
ignore_deivce_version=[core.XPUVersion.XPU1])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册