未验证 提交 69a04209 编写于 作者: Z zhangxiaoci 提交者: GitHub

refactor range unittest for kunlun (#39800)

*test=kunlun
上级 b089e7cd
......@@ -20,57 +20,70 @@ import numpy as np
import sys
sys.path.append("..")
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static()
class TestRangeOp(XPUOpTest):
def setUp(self):
self.op_type = "range"
self.init_config()
self.inputs = {
'Start': np.array([self.case[0]]).astype(self.dtype),
'End': np.array([self.case[1]]).astype(self.dtype),
'Step': np.array([self.case[2]]).astype(self.dtype)
}
class XPUTestRangeOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = "range"
self.use_dynamic_create_class = False
self.outputs = {
'Out': np.arange(self.case[0], self.case[1],
self.case[2]).astype(self.dtype)
}
class TestRangeOp(XPUOpTest):
def setUp(self):
self.set_xpu()
self.op_type = "range"
self.init_dtype()
self.init_config()
self.inputs = {
'Start': np.array([self.case[0]]).astype(self.dtype),
'End': np.array([self.case[1]]).astype(self.dtype),
'Step': np.array([self.case[2]]).astype(self.dtype)
}
def init_config(self):
self.dtype = np.float32
self.case = (0, 1, 0.2)
self.outputs = {
'Out': np.arange(self.case[0], self.case[1],
self.case[2]).astype(self.dtype)
}
def test_check_output(self):
place = paddle.XPUPlace(0)
self.check_output_with_place(place, check_dygraph=False)
def set_xpu(self):
self.__class__.no_need_check_grad = True
def init_dtype(self):
self.dtype = self.in_type
class TestFloatRangeOpCase0(TestRangeOp):
def init_config(self):
self.dtype = np.float32
self.case = (0, 5, 1)
def init_config(self):
self.case = (0, 1, 0.2) if self.dtype == np.float32 else (0, 5, 1)
def test_check_output(self):
place = paddle.XPUPlace(0)
self.check_output_with_place(place, check_dygraph=False)
class TestInt32RangeOpCase0(TestRangeOp):
def init_config(self):
self.dtype = np.int32
self.case = (0, 5, 2)
class TestRangeOpCase0(TestRangeOp):
def init_config(self):
self.case = (0, 5, 1)
class TestRangeOpCase1(TestRangeOp):
def init_config(self):
self.case = (0, 5, 2)
class TestInt32RangeOpCase1(TestRangeOp):
def init_config(self):
self.dtype = np.int32
self.case = (10, 1, -2)
class TestRangeOpCase2(TestRangeOp):
def init_config(self):
self.case = (10, 1, -2)
class TestRangeOpCase3(TestRangeOp):
def init_config(self):
self.case = (-1, -10, -2)
class TestInt32RangeOpCase2(TestRangeOp):
def init_config(self):
self.dtype = np.int32
self.case = (-1, -10, -2)
class TestRangeOpCase4(TestRangeOp):
def init_config(self):
self.case = (10, -10, -11)
support_types = get_xpu_op_support_types("range")
for stype in support_types:
create_test_class(globals(), XPUTestRangeOp, stype)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册