未验证 提交 b8352611 编写于 作者: Z zzk0 提交者: GitHub

[CINN] Enable check_cinn on some tests (#54261)

* [CINN] Enable check_cinn

* add CMakeLists.txt
上级 0a0fbb1a
...@@ -1189,7 +1189,13 @@ set(TEST_CINN_OPS ...@@ -1189,7 +1189,13 @@ set(TEST_CINN_OPS
test_tile_op test_tile_op
test_roll_op test_roll_op
test_sum_op test_sum_op
test_elementwise_min_op) test_elementwise_min_op
test_arg_min_max_op
test_reverse_op
test_flip
test_triangular_solve_op
test_scatter_nd_op
test_strided_slice_op)
foreach(TEST_CINN_OPS ${TEST_CINN_OPS}) foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
if(WITH_CINN) if(WITH_CINN)
......
...@@ -42,7 +42,7 @@ class BaseTestCase(OpTest): ...@@ -42,7 +42,7 @@ class BaseTestCase(OpTest):
self.outputs = {'Out': np.argmax(self.x, axis=self.axis)} self.outputs = {'Out': np.argmax(self.x, axis=self.axis)}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_cinn=True)
class TestCase0(BaseTestCase): class TestCase0(BaseTestCase):
......
...@@ -100,10 +100,10 @@ class TestFlipOp(OpTest): ...@@ -100,10 +100,10 @@ class TestFlipOp(OpTest):
self.attrs = {"axis": self.axis} self.attrs = {"axis": self.axis}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_cinn=True)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(["X"], "Out") self.check_grad(["X"], "Out", check_cinn=True)
def init_test_case(self): def init_test_case(self):
self.in_shape = (6, 4, 2, 3) self.in_shape = (6, 4, 2, 3)
...@@ -167,12 +167,12 @@ def create_test_fp16_class(parent): ...@@ -167,12 +167,12 @@ def create_test_fp16_class(parent):
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
if core.is_float16_supported(place): if core.is_float16_supported(place):
self.check_output_with_place(place) self.check_output_with_place(place, check_cinn=True)
def test_check_grad(self): def test_check_grad(self):
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
if core.is_float16_supported(place): if core.is_float16_supported(place):
self.check_grad_with_place(place, ["X"], "Out") self.check_grad_with_place(place, ["X"], "Out", check_cinn=True)
cls_name = "{}_{}".format(parent.__name__, "FP16OP") cls_name = "{}_{}".format(parent.__name__, "FP16OP")
TestFlipFP16.__name__ = cls_name TestFlipFP16.__name__ = cls_name
......
...@@ -37,10 +37,10 @@ class TestReverseOp(OpTest): ...@@ -37,10 +37,10 @@ class TestReverseOp(OpTest):
self.outputs = {'Out': out} self.outputs = {'Out': out}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_cinn=True)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out', check_cinn=True)
class TestCase0(TestReverseOp): class TestCase0(TestReverseOp):
......
...@@ -93,7 +93,7 @@ class TestScatterNdAddSimpleOp(OpTest): ...@@ -93,7 +93,7 @@ class TestScatterNdAddSimpleOp(OpTest):
self.dtype = np.float64 self.dtype = np.float64
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_cinn=True)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X', 'Updates'], 'Out', check_prim=True) self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
...@@ -169,7 +169,7 @@ class TestScatterNdAddWithEmptyIndex(OpTest): ...@@ -169,7 +169,7 @@ class TestScatterNdAddWithEmptyIndex(OpTest):
self.dtype = np.float64 self.dtype = np.float64
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_cinn=True)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X', 'Updates'], 'Out', check_prim=True) self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
...@@ -248,7 +248,7 @@ class TestScatterNdAddWithHighRankSame(OpTest): ...@@ -248,7 +248,7 @@ class TestScatterNdAddWithHighRankSame(OpTest):
self.dtype = np.float64 self.dtype = np.float64
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_cinn=True)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X', 'Updates'], 'Out', check_prim=True) self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
...@@ -311,7 +311,7 @@ class TestScatterNdAddWithHighRankDiff(OpTest): ...@@ -311,7 +311,7 @@ class TestScatterNdAddWithHighRankDiff(OpTest):
self.outputs = {'Out': expect_np} self.outputs = {'Out': expect_np}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_cinn=True)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X', 'Updates'], 'Out', check_prim=True) self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
......
...@@ -96,10 +96,10 @@ class TestStrideSliceOp(OpTest): ...@@ -96,10 +96,10 @@ class TestStrideSliceOp(OpTest):
} }
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_cinn=True)
def test_check_grad(self): def test_check_grad(self):
self.check_grad({'Input'}, 'Out') self.check_grad({'Input'}, 'Out', check_cinn=True)
def initTestCase(self): def initTestCase(self):
self.input = np.random.rand(100) self.input = np.random.rand(100)
...@@ -1032,10 +1032,10 @@ class TestStrideSliceFP16Op(OpTest): ...@@ -1032,10 +1032,10 @@ class TestStrideSliceFP16Op(OpTest):
} }
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_cinn=True)
def test_check_grad(self): def test_check_grad(self):
self.check_grad({'Input'}, 'Out') self.check_grad({'Input'}, 'Out', check_cinn=True)
def initTestCase(self): def initTestCase(self):
self.input = np.random.rand(100) self.input = np.random.rand(100)
......
...@@ -64,10 +64,10 @@ class TestTriangularSolveOp(OpTest): ...@@ -64,10 +64,10 @@ class TestTriangularSolveOp(OpTest):
self.outputs = {'Out': self.output} self.outputs = {'Out': self.output}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_cinn=True)
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out') self.check_grad(['X', 'Y'], 'Out', check_cinn=True)
# 2D(broadcast) + 3D, test 'transpose' # 2D(broadcast) + 3D, test 'transpose'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册