未验证 提交 bb5dd203 编写于 作者: J jiangcheng 提交者: GitHub

[CINN] reopen some prim with cinn single test (#51081)

上级 ff7ce2ff
...@@ -120,7 +120,7 @@ class TestExpPrimFp32(OpTest): ...@@ -120,7 +120,7 @@ class TestExpPrimFp32(OpTest):
self.shape = [12, 17] self.shape = [12, 17]
def skip_cinn(self): def skip_cinn(self):
self.enable_cinn = False self.enable_cinn = True
def set_only_prim(self): def set_only_prim(self):
pass pass
...@@ -145,7 +145,7 @@ class TestExpPrimFp16(TestExpPrimFp32): ...@@ -145,7 +145,7 @@ class TestExpPrimFp16(TestExpPrimFp32):
self.check_grad(['X'], 'Out', check_prim=True) self.check_grad(['X'], 'Out', check_prim=True)
def skip_cinn(self): def skip_cinn(self):
self.enable_cinn = False self.enable_cinn = True
class TestExpPrim_ZeroDim(TestExpPrimFp32): class TestExpPrim_ZeroDim(TestExpPrimFp32):
...@@ -325,7 +325,7 @@ class TestSilu(TestActivation): ...@@ -325,7 +325,7 @@ class TestSilu(TestActivation):
def setUp(self): def setUp(self):
self.op_type = "silu" self.op_type = "silu"
self.prim_op_type = "comp" self.prim_op_type = "comp"
self.enable_cinn = False self.enable_cinn = True
self.python_api = paddle.nn.functional.silu self.python_api = paddle.nn.functional.silu
self.init_dtype() self.init_dtype()
self.init_shape() self.init_shape()
...@@ -349,13 +349,14 @@ class TestSilu(TestActivation): ...@@ -349,13 +349,14 @@ class TestSilu(TestActivation):
class TestSilu_ZeroDim(TestSilu): class TestSilu_ZeroDim(TestSilu):
def init_shape(self): def init_shape(self):
self.shape = [] self.shape = []
self.enable_cinn = False
class TestSiluFP16(TestActivation): class TestSiluFP16(TestActivation):
def setUp(self): def setUp(self):
self.op_type = "silu" self.op_type = "silu"
self.prim_op_type = "comp" self.prim_op_type = "comp"
self.enable_cinn = False self.enable_cinn = True
self.only_prim = True self.only_prim = True
self.python_api = paddle.nn.functional.silu self.python_api = paddle.nn.functional.silu
self.init_dtype() self.init_dtype()
...@@ -1199,7 +1200,7 @@ class TestSqrtPrimFp32(TestActivation): ...@@ -1199,7 +1200,7 @@ class TestSqrtPrimFp32(TestActivation):
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out} self.outputs = {'Out': out}
self.enable_cinn = False self.enable_cinn = True
def test_check_grad(self): def test_check_grad(self):
if self.dtype == np.float16: if self.dtype == np.float16:
...@@ -1216,11 +1217,13 @@ class TestSqrtPrimFp32(TestActivation): ...@@ -1216,11 +1217,13 @@ class TestSqrtPrimFp32(TestActivation):
class TestSqrt_ZeroDim(TestSqrt): class TestSqrt_ZeroDim(TestSqrt):
def init_shape(self): def init_shape(self):
self.shape = [] self.shape = []
self.enable_cinn = False
class TestSqrtPrim_ZeroDim(TestSqrt): class TestSqrtPrim_ZeroDim(TestSqrt):
def init_shape(self): def init_shape(self):
self.shape = [] self.shape = []
self.enable_cinn = False
def init_dtype(self): def init_dtype(self):
self.dtype = np.float32 self.dtype = np.float32
...@@ -1527,6 +1530,8 @@ class TestSin(TestActivation, TestParameter): ...@@ -1527,6 +1530,8 @@ class TestSin(TestActivation, TestParameter):
self.op_type = "sin" self.op_type = "sin"
self.init_dtype() self.init_dtype()
self.init_shape() self.init_shape()
# prim not support now
self.enable_cinn = False
np.random.seed(1024) np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
......
...@@ -35,7 +35,7 @@ class TestExpandV2OpRank1(OpTest): ...@@ -35,7 +35,7 @@ class TestExpandV2OpRank1(OpTest):
self.attrs = {'shape': self.shape} self.attrs = {'shape': self.shape}
output = np.tile(self.inputs['X'], self.expand_times) output = np.tile(self.inputs['X'], self.expand_times)
self.outputs = {'Out': output} self.outputs = {'Out': output}
self.enable_cinn = False self.enable_cinn = True
def init_data(self): def init_data(self):
self.ori_shape = [100] self.ori_shape = [100]
......
...@@ -61,7 +61,7 @@ class TestFillAnyLikeOpFloat32(TestFillAnyLikeOp): ...@@ -61,7 +61,7 @@ class TestFillAnyLikeOpFloat32(TestFillAnyLikeOp):
self.value = 0.0 self.value = 0.0
def skip_cinn(self): def skip_cinn(self):
self.enable_cinn = False self.enable_cinn = True
@unittest.skipIf( @unittest.skipIf(
...@@ -96,7 +96,7 @@ class TestFillAnyLikeOpValue1(TestFillAnyLikeOp): ...@@ -96,7 +96,7 @@ class TestFillAnyLikeOpValue1(TestFillAnyLikeOp):
self.value = 1.0 self.value = 1.0
def skip_cinn(self): def skip_cinn(self):
self.enable_cinn = False self.enable_cinn = True
class TestFillAnyLikeOpValue2(TestFillAnyLikeOp): class TestFillAnyLikeOpValue2(TestFillAnyLikeOp):
...@@ -104,7 +104,7 @@ class TestFillAnyLikeOpValue2(TestFillAnyLikeOp): ...@@ -104,7 +104,7 @@ class TestFillAnyLikeOpValue2(TestFillAnyLikeOp):
self.value = 1e-10 self.value = 1e-10
def skip_cinn(self): def skip_cinn(self):
self.enable_cinn = False self.enable_cinn = True
class TestFillAnyLikeOpValue3(TestFillAnyLikeOp): class TestFillAnyLikeOpValue3(TestFillAnyLikeOp):
...@@ -112,7 +112,7 @@ class TestFillAnyLikeOpValue3(TestFillAnyLikeOp): ...@@ -112,7 +112,7 @@ class TestFillAnyLikeOpValue3(TestFillAnyLikeOp):
self.value = 1e-100 self.value = 1e-100
def skip_cinn(self): def skip_cinn(self):
self.enable_cinn = False self.enable_cinn = True
class TestFillAnyLikeOpType(TestFillAnyLikeOp): class TestFillAnyLikeOpType(TestFillAnyLikeOp):
...@@ -136,7 +136,7 @@ class TestFillAnyLikeOpType(TestFillAnyLikeOp): ...@@ -136,7 +136,7 @@ class TestFillAnyLikeOpType(TestFillAnyLikeOp):
self.skip_cinn() self.skip_cinn()
def skip_cinn(self): def skip_cinn(self):
self.enable_cinn = False self.enable_cinn = True
class TestFillAnyLikeOpFloat16(TestFillAnyLikeOp): class TestFillAnyLikeOpFloat16(TestFillAnyLikeOp):
...@@ -144,7 +144,7 @@ class TestFillAnyLikeOpFloat16(TestFillAnyLikeOp): ...@@ -144,7 +144,7 @@ class TestFillAnyLikeOpFloat16(TestFillAnyLikeOp):
self.dtype = np.float16 self.dtype = np.float16
def skip_cinn(self): def skip_cinn(self):
self.enable_cinn = False self.enable_cinn = True
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -28,15 +28,18 @@ class TestFlattenOp(OpTest): ...@@ -28,15 +28,18 @@ class TestFlattenOp(OpTest):
self.prim_op_type = "comp" self.prim_op_type = "comp"
self.start_axis = 0 self.start_axis = 0
self.stop_axis = -1 self.stop_axis = -1
self.skip_cinn()
self.init_test_case() self.init_test_case()
self.inputs = {"X": np.random.random(self.in_shape).astype("float64")} self.inputs = {"X": np.random.random(self.in_shape).astype("float64")}
self.init_attrs() self.init_attrs()
self.enable_cinn = False
self.outputs = { self.outputs = {
"Out": self.inputs["X"].reshape(self.new_shape), "Out": self.inputs["X"].reshape(self.new_shape),
"XShape": np.random.random(self.in_shape).astype("float32"), "XShape": np.random.random(self.in_shape).astype("float32"),
} }
def skip_cinn(self):
self.enable_cinn = True
def test_check_output(self): def test_check_output(self):
self.check_output( self.check_output(
no_check_set=["XShape"], check_eager=True, check_prim=True no_check_set=["XShape"], check_eager=True, check_prim=True
...@@ -135,6 +138,9 @@ class TestFlattenOp_6(TestFlattenOp): ...@@ -135,6 +138,9 @@ class TestFlattenOp_6(TestFlattenOp):
self.stop_axis = -1 self.stop_axis = -1
self.new_shape = (1,) self.new_shape = (1,)
def skip_cinn(self):
self.enable_cinn = False
def init_attrs(self): def init_attrs(self):
self.attrs = { self.attrs = {
"start_axis": self.start_axis, "start_axis": self.start_axis,
......
...@@ -145,7 +145,7 @@ class TestFullLikeOp2(TestFullLikeOp1): ...@@ -145,7 +145,7 @@ class TestFullLikeOp2(TestFullLikeOp1):
self.dtype = np.float64 self.dtype = np.float64
def skip_cinn(self): def skip_cinn(self):
self.enable_cinn = False self.enable_cinn = True
class TestFullLikeOp3(TestFullLikeOp1): class TestFullLikeOp3(TestFullLikeOp1):
...@@ -155,7 +155,7 @@ class TestFullLikeOp3(TestFullLikeOp1): ...@@ -155,7 +155,7 @@ class TestFullLikeOp3(TestFullLikeOp1):
self.dtype = np.int64 self.dtype = np.int64
def skip_cinn(self): def skip_cinn(self):
self.enable_cinn = False self.enable_cinn = True
@unittest.skipIf( @unittest.skipIf(
......
...@@ -32,8 +32,7 @@ class TestSumOp(OpTest): ...@@ -32,8 +32,7 @@ class TestSumOp(OpTest):
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")} self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
self.outputs = {'Out': self.inputs['X'].sum(axis=0)} self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
self.attrs = {'dim': [0]} self.attrs = {'dim': [0]}
# reduce doesn't support float64 in cinn self.enable_cinn = True
self.enable_cinn = False
def test_check_output(self): def test_check_output(self):
self.check_output(check_eager=True) self.check_output(check_eager=True)
...@@ -55,8 +54,7 @@ class TestSumOpFp32(OpTest): ...@@ -55,8 +54,7 @@ class TestSumOpFp32(OpTest):
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim'])) 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
} }
self.gradient = self.calc_gradient() self.gradient = self.calc_gradient()
# error occurred in cinn self.enable_cinn = True
self.enable_cinn = False
def test_check_output(self): def test_check_output(self):
self.check_output(check_eager=True) self.check_output(check_eager=True)
...@@ -151,7 +149,7 @@ class TestSumOp_fp16_withInt(OpTest): ...@@ -151,7 +149,7 @@ class TestSumOp_fp16_withInt(OpTest):
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim'])) 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
} }
self.gradient = self.calc_gradient() self.gradient = self.calc_gradient()
self.enable_cinn = False self.enable_cinn = True
def test_check_output(self): def test_check_output(self):
self.check_output(check_eager=True) self.check_output(check_eager=True)
...@@ -182,7 +180,7 @@ class TestSumOp5D(OpTest): ...@@ -182,7 +180,7 @@ class TestSumOp5D(OpTest):
self.attrs = {'dim': [0]} self.attrs = {'dim': [0]}
self.outputs = {'Out': self.inputs['X'].sum(axis=0)} self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
# error occurred in cinn # error occurred in cinn
self.enable_cinn = False self.enable_cinn = True
def test_check_output(self): def test_check_output(self):
self.check_output(check_eager=True) self.check_output(check_eager=True)
...@@ -202,7 +200,7 @@ class TestSumOp6D(OpTest): ...@@ -202,7 +200,7 @@ class TestSumOp6D(OpTest):
self.attrs = {'dim': [0]} self.attrs = {'dim': [0]}
self.outputs = {'Out': self.inputs['X'].sum(axis=0)} self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
# error occurred in cinn # error occurred in cinn
self.enable_cinn = False self.enable_cinn = True
def test_check_output(self): def test_check_output(self):
self.check_output(check_eager=True) self.check_output(check_eager=True)
...@@ -678,8 +676,7 @@ class Test1DReduce(OpTest): ...@@ -678,8 +676,7 @@ class Test1DReduce(OpTest):
self.prim_op_type = "prim" self.prim_op_type = "prim"
self.inputs = {'X': np.random.random(120).astype("float64")} self.inputs = {'X': np.random.random(120).astype("float64")}
self.outputs = {'Out': self.inputs['X'].sum(axis=0)} self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
# reduce doesn't support float64 in cinn. self.enable_cinn = True
self.enable_cinn = False
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -696,8 +693,7 @@ class Test2DReduce0(Test1DReduce): ...@@ -696,8 +693,7 @@ class Test2DReduce0(Test1DReduce):
self.attrs = {'dim': [0]} self.attrs = {'dim': [0]}
self.inputs = {'X': np.random.random((20, 10)).astype("float64")} self.inputs = {'X': np.random.random((20, 10)).astype("float64")}
self.outputs = {'Out': self.inputs['X'].sum(axis=0)} self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
# reduce doesn't support float64 in cinn. self.enable_cinn = True
self.enable_cinn = False
class Test2DReduce1(Test1DReduce): class Test2DReduce1(Test1DReduce):
...@@ -710,8 +706,7 @@ class Test2DReduce1(Test1DReduce): ...@@ -710,8 +706,7 @@ class Test2DReduce1(Test1DReduce):
self.outputs = { self.outputs = {
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim'])) 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
} }
# reduce doesn't support float64 in cinn. self.enable_cinn = True
self.enable_cinn = False
class Test3DReduce0(Test1DReduce): class Test3DReduce0(Test1DReduce):
...@@ -724,8 +719,7 @@ class Test3DReduce0(Test1DReduce): ...@@ -724,8 +719,7 @@ class Test3DReduce0(Test1DReduce):
self.outputs = { self.outputs = {
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim'])) 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
} }
# reduce doesn't support float64 in cinn. self.enable_cinn = True
self.enable_cinn = False
class Test3DReduce1(Test1DReduce): class Test3DReduce1(Test1DReduce):
...@@ -738,8 +732,7 @@ class Test3DReduce1(Test1DReduce): ...@@ -738,8 +732,7 @@ class Test3DReduce1(Test1DReduce):
self.outputs = { self.outputs = {
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim'])) 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
} }
# reduce doesn't support float64 in cinn. self.enable_cinn = True
self.enable_cinn = False
class Test3DReduce2(Test1DReduce): class Test3DReduce2(Test1DReduce):
...@@ -752,8 +745,7 @@ class Test3DReduce2(Test1DReduce): ...@@ -752,8 +745,7 @@ class Test3DReduce2(Test1DReduce):
self.outputs = { self.outputs = {
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim'])) 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
} }
# reduce doesn't support float64 in cinn. self.enable_cinn = True
self.enable_cinn = False
class Test3DReduce3(Test1DReduce): class Test3DReduce3(Test1DReduce):
...@@ -766,8 +758,7 @@ class Test3DReduce3(Test1DReduce): ...@@ -766,8 +758,7 @@ class Test3DReduce3(Test1DReduce):
self.outputs = { self.outputs = {
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim'])) 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
} }
# reduce doesn't support float64 in cinn. self.enable_cinn = True
self.enable_cinn = False
class Test8DReduce0(Test1DReduce): class Test8DReduce0(Test1DReduce):
...@@ -800,8 +791,7 @@ class TestKeepDimReduce(Test1DReduce): ...@@ -800,8 +791,7 @@ class TestKeepDimReduce(Test1DReduce):
axis=tuple(self.attrs['dim']), keepdims=self.attrs['keep_dim'] axis=tuple(self.attrs['dim']), keepdims=self.attrs['keep_dim']
) )
} }
# reduce doesn't support float64 in cinn. self.enable_cinn = True
self.enable_cinn = False
class TestKeepDim8DReduce(Test1DReduce): class TestKeepDim8DReduce(Test1DReduce):
...@@ -897,8 +887,7 @@ class TestReduceSumWithDimOne(OpTest): ...@@ -897,8 +887,7 @@ class TestReduceSumWithDimOne(OpTest):
axis=tuple(self.attrs['dim']), keepdims=True axis=tuple(self.attrs['dim']), keepdims=True
) )
} }
# reduce doesn't support float64 in cinn self.enable_cinn = True
self.enable_cinn = False
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -919,8 +908,7 @@ class TestReduceSumWithNumelOne(OpTest): ...@@ -919,8 +908,7 @@ class TestReduceSumWithNumelOne(OpTest):
axis=tuple(self.attrs['dim']), keepdims=False axis=tuple(self.attrs['dim']), keepdims=False
) )
} }
# reduce doesn't support float64 in cinn self.enable_cinn = True
self.enable_cinn = False
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -937,8 +925,7 @@ class TestReduceAll(OpTest): ...@@ -937,8 +925,7 @@ class TestReduceAll(OpTest):
self.inputs = {'X': np.random.random((100, 1, 1)).astype("float64")} self.inputs = {'X': np.random.random((100, 1, 1)).astype("float64")}
self.attrs = {'reduce_all': True, 'keep_dim': False} self.attrs = {'reduce_all': True, 'keep_dim': False}
self.outputs = {'Out': self.inputs['X'].sum()} self.outputs = {'Out': self.inputs['X'].sum()}
# reduce doesn't support float64 in cinn self.enable_cinn = True
self.enable_cinn = False
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -955,8 +942,7 @@ class TestReduceAllFp32(OpTest): ...@@ -955,8 +942,7 @@ class TestReduceAllFp32(OpTest):
self.inputs = {'X': np.random.random((100, 1, 1)).astype("float32")} self.inputs = {'X': np.random.random((100, 1, 1)).astype("float32")}
self.attrs = {'reduce_all': True, 'keep_dim': False} self.attrs = {'reduce_all': True, 'keep_dim': False}
self.outputs = {'Out': self.inputs['X'].sum()} self.outputs = {'Out': self.inputs['X'].sum()}
# reduce doesn't support float64 in cinn self.enable_cinn = True
self.enable_cinn = False
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -973,7 +959,7 @@ class Test1DReduceWithAxes1(OpTest): ...@@ -973,7 +959,7 @@ class Test1DReduceWithAxes1(OpTest):
self.inputs = {'X': np.random.random(100).astype("float64")} self.inputs = {'X': np.random.random(100).astype("float64")}
self.attrs = {'dim': [0], 'keep_dim': False} self.attrs = {'dim': [0], 'keep_dim': False}
self.outputs = {'Out': self.inputs['X'].sum(axis=0)} self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
self.enable_cinn = False self.enable_cinn = True
def test_check_output(self): def test_check_output(self):
self.check_output(check_prim=True) self.check_output(check_prim=True)
...@@ -996,6 +982,7 @@ class TestReduceWithDtype(OpTest): ...@@ -996,6 +982,7 @@ class TestReduceWithDtype(OpTest):
'out_dtype': int(convert_np_dtype_to_dtype_(np.float64)), 'out_dtype': int(convert_np_dtype_to_dtype_(np.float64)),
} }
) )
# cinn op_mapper not support in_dtype/out_dtype attr
self.enable_cinn = False self.enable_cinn = False
def test_check_output(self): def test_check_output(self):
...@@ -1025,6 +1012,7 @@ class TestReduceWithDtype1(TestReduceWithDtype): ...@@ -1025,6 +1012,7 @@ class TestReduceWithDtype1(TestReduceWithDtype):
'out_dtype': int(convert_np_dtype_to_dtype_(np.float64)), 'out_dtype': int(convert_np_dtype_to_dtype_(np.float64)),
} }
) )
# cinn op_mapper not support in_dtype/out_dtype attr
self.enable_cinn = False self.enable_cinn = False
def test_check_output(self): def test_check_output(self):
...@@ -1048,6 +1036,7 @@ class TestReduceWithDtype2(TestReduceWithDtype): ...@@ -1048,6 +1036,7 @@ class TestReduceWithDtype2(TestReduceWithDtype):
'out_dtype': int(convert_np_dtype_to_dtype_(np.float64)), 'out_dtype': int(convert_np_dtype_to_dtype_(np.float64)),
} }
) )
# cinn op_mapper not support in_dtype/out_dtype attr
self.enable_cinn = False self.enable_cinn = False
def test_check_output(self): def test_check_output(self):
......
...@@ -34,7 +34,7 @@ class TestSliceOp(OpTest): ...@@ -34,7 +34,7 @@ class TestSliceOp(OpTest):
self.op_type = "slice" self.op_type = "slice"
self.prim_op_type = "prim" self.prim_op_type = "prim"
self.python_api = paddle.slice self.python_api = paddle.slice
self.enable_cinn = False self.enable_cinn = True
self.config() self.config()
self.inputs = {'Input': self.input} self.inputs = {'Input': self.input}
self.outputs = {'Out': self.out} self.outputs = {'Out': self.out}
...@@ -74,7 +74,7 @@ class TestCase1(TestSliceOp): ...@@ -74,7 +74,7 @@ class TestCase1(TestSliceOp):
class TestCase2(TestSliceOp): class TestCase2(TestSliceOp):
def config(self): def config(self):
self.enable_cinn = False self.enable_cinn = True
self.input = np.random.random([3, 4, 5, 6]).astype("float64") self.input = np.random.random([3, 4, 5, 6]).astype("float64")
self.starts = [-3, 0, 2] self.starts = [-3, 0, 2]
self.ends = [3, 100, -1] self.ends = [3, 100, -1]
...@@ -114,7 +114,7 @@ class TestSliceZerosShapeTensor(OpTest): ...@@ -114,7 +114,7 @@ class TestSliceZerosShapeTensor(OpTest):
# 1.2 with attr(decrease) # 1.2 with attr(decrease)
class TestSliceOp_decs_dim(OpTest): class TestSliceOp_decs_dim(OpTest):
def setUp(self): def setUp(self):
self.enable_cinn = False self.enable_cinn = True
self.op_type = "slice" self.op_type = "slice"
self.prim_op_type = "prim" self.prim_op_type = "prim"
self.python_api = paddle.slice self.python_api = paddle.slice
...@@ -149,7 +149,7 @@ class TestSliceOp_decs_dim(OpTest): ...@@ -149,7 +149,7 @@ class TestSliceOp_decs_dim(OpTest):
class TestSliceOp_decs_dim_2(TestSliceOp_decs_dim): class TestSliceOp_decs_dim_2(TestSliceOp_decs_dim):
def config(self): def config(self):
self.enable_cinn = False self.enable_cinn = True
self.input = np.random.random([3, 4, 5, 6]).astype("float64") self.input = np.random.random([3, 4, 5, 6]).astype("float64")
self.starts = [1, 0, 2] self.starts = [1, 0, 2]
self.ends = [2, 1, 4] self.ends = [2, 1, 4]
...@@ -161,7 +161,7 @@ class TestSliceOp_decs_dim_2(TestSliceOp_decs_dim): ...@@ -161,7 +161,7 @@ class TestSliceOp_decs_dim_2(TestSliceOp_decs_dim):
class TestSliceOp_decs_dim_3(TestSliceOp_decs_dim): class TestSliceOp_decs_dim_3(TestSliceOp_decs_dim):
def config(self): def config(self):
self.enable_cinn = False self.enable_cinn = True
self.input = np.random.random([3, 4, 5, 6]).astype("float64") self.input = np.random.random([3, 4, 5, 6]).astype("float64")
self.starts = [-1, 0, 2] self.starts = [-1, 0, 2]
self.ends = [1000000, 1, 4] self.ends = [1000000, 1, 4]
...@@ -185,7 +185,7 @@ class TestSliceOp_decs_dim_4(TestSliceOp_decs_dim): ...@@ -185,7 +185,7 @@ class TestSliceOp_decs_dim_4(TestSliceOp_decs_dim):
class TestSliceOp_decs_dim_5(TestSliceOp_decs_dim): class TestSliceOp_decs_dim_5(TestSliceOp_decs_dim):
def config(self): def config(self):
self.enable_cinn = False self.enable_cinn = True
self.input = np.random.random([3, 4, 5, 6]).astype("float64") self.input = np.random.random([3, 4, 5, 6]).astype("float64")
self.starts = [-1] self.starts = [-1]
self.ends = [1000000] self.ends = [1000000]
...@@ -198,7 +198,7 @@ class TestSliceOp_decs_dim_5(TestSliceOp_decs_dim): ...@@ -198,7 +198,7 @@ class TestSliceOp_decs_dim_5(TestSliceOp_decs_dim):
# test_6 with test_2 with test_3 # test_6 with test_2 with test_3
class TestSliceOp_decs_dim_6(TestSliceOp_decs_dim): class TestSliceOp_decs_dim_6(TestSliceOp_decs_dim):
def config(self): def config(self):
self.enable_cinn = False self.enable_cinn = True
self.input = np.random.random([3, 4, 5, 6]).astype("float64") self.input = np.random.random([3, 4, 5, 6]).astype("float64")
self.starts = [0, 1, 2, 3] self.starts = [0, 1, 2, 3]
self.ends = [1, 2, 3, 4] self.ends = [1, 2, 3, 4]
...@@ -484,7 +484,7 @@ class TestSliceOp_starts_OneTensor_ends_ListTensor(OpTest): ...@@ -484,7 +484,7 @@ class TestSliceOp_starts_OneTensor_ends_ListTensor(OpTest):
) )
class TestFP16(OpTest): class TestFP16(OpTest):
def setUp(self): def setUp(self):
self.enable_cinn = False self.enable_cinn = True
self.op_type = "slice" self.op_type = "slice"
self.prim_op_type = "prim" self.prim_op_type = "prim"
self.python_api = paddle.slice self.python_api = paddle.slice
......
...@@ -73,7 +73,7 @@ class TestSoftmaxOp(OpTest): ...@@ -73,7 +73,7 @@ class TestSoftmaxOp(OpTest):
'use_cudnn': self.use_cudnn, 'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn, 'use_mkldnn': self.use_mkldnn,
} }
self.enable_cinn = False self.enable_cinn = True
def init_kernel_type(self): def init_kernel_type(self):
pass pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册