未验证 提交 099b3d25 编写于 作者: J jiangcheng 提交者: GitHub

[CINN] reopen mean/cumsum/instance_norm op's prim+CINN test (#54406)

* [CINN] reopen mean/cumsum/instance_norm op's prim+CINN test

* remove repeat test_mean_op in cmake
上级 791963ab
...@@ -1195,7 +1195,9 @@ set(TEST_CINN_OPS ...@@ -1195,7 +1195,9 @@ set(TEST_CINN_OPS
test_flip test_flip
test_triangular_solve_op test_triangular_solve_op
test_scatter_nd_op test_scatter_nd_op
test_strided_slice_op) test_strided_slice_op
test_instance_norm_op
test_cumsum_op)
foreach(TEST_CINN_OPS ${TEST_CINN_OPS}) foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
if(WITH_CINN) if(WITH_CINN)
......
...@@ -314,7 +314,8 @@ class TestSumOpExclusiveFP16(OpTest): ...@@ -314,7 +314,8 @@ class TestSumOpExclusiveFP16(OpTest):
self.python_api = cumsum_wrapper self.python_api = cumsum_wrapper
self.public_python_api = paddle.cumsum self.public_python_api = paddle.cumsum
self.init_dtype() self.init_dtype()
self.enable_cinn = True # TODO(thisjiang): set `True` after reduce+cast at shape=[4, 5, 20, 20], dim=[2]'s fusion bug has fixed
self.enable_cinn = False
self.attrs = {'axis': 2, "exclusive": True} self.attrs = {'axis': 2, "exclusive": True}
self.x = np.random.random((4, 5, 20)).astype(self.dtype) self.x = np.random.random((4, 5, 20)).astype(self.dtype)
self.out = np.concatenate( self.out = np.concatenate(
...@@ -389,7 +390,8 @@ def create_test_fp16_class(parent, max_relative_error=1e-2): ...@@ -389,7 +390,8 @@ def create_test_fp16_class(parent, max_relative_error=1e-2):
self.dtype = self.dtype_ = np.float16 self.dtype = self.dtype_ = np.float16
def set_enable_cinn(self): def set_enable_cinn(self):
self.enable_cinn = True # TODO(thisjiang): set `pass` after reduce+cast at shape=[4, 5, 20, 20], dim=[2]'s fusion bug has fixed
self.enable_cinn = False
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
......
...@@ -103,6 +103,8 @@ class TestInstanceNormOp(OpTest): ...@@ -103,6 +103,8 @@ class TestInstanceNormOp(OpTest):
self.fw_comp_atol = 1e-6 self.fw_comp_atol = 1e-6
self.rev_comp_rtol = 1e-4 self.rev_comp_rtol = 1e-4
self.rev_comp_atol = 1e-4 self.rev_comp_atol = 1e-4
self.cinn_rtol = 1e-4
self.cinn_atol = 1e-4
self.init_test_case() self.init_test_case()
ref_y_np, ref_mean_np, ref_var_np_tmp = _reference_instance_norm_naive( ref_y_np, ref_mean_np, ref_var_np_tmp = _reference_instance_norm_naive(
self.x_np, self.x_np,
......
...@@ -284,9 +284,6 @@ class TestReduceMeanOpShape6D(TestReduceMeanOp): ...@@ -284,9 +284,6 @@ class TestReduceMeanOpShape6D(TestReduceMeanOp):
def set_attrs(self): def set_attrs(self):
self.shape = [2, 3, 4, 5, 6, 7] self.shape = [2, 3, 4, 5, 6, 7]
def if_enable_cinn(self):
self.enable_cinn = False
class TestReduceMeanOpShape6DBF16(TestReduceMeanBF16Op): class TestReduceMeanOpShape6DBF16(TestReduceMeanBF16Op):
def set_attrs(self): def set_attrs(self):
...@@ -298,9 +295,6 @@ class TestReduceMeanOpShape6DFP16(TestReduceMeanOp): ...@@ -298,9 +295,6 @@ class TestReduceMeanOpShape6DFP16(TestReduceMeanOp):
self.shape = [2, 3, 4, 5, 6, 7] self.shape = [2, 3, 4, 5, 6, 7]
self.dtype = 'float16' self.dtype = 'float16'
def if_enable_cinn(self):
self.enable_cinn = False
class TestReduceMeanOpAxisAll(TestReduceMeanOp): class TestReduceMeanOpAxisAll(TestReduceMeanOp):
def set_attrs(self): def set_attrs(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册