From 099b3d25b08eae5f51e2b6322ec55092acc5870d Mon Sep 17 00:00:00 2001 From: jiangcheng Date: Wed, 7 Jun 2023 15:17:08 +0800 Subject: [PATCH] [CINN] reopen mean/cumsum/instance_norm op's prim+CINN test (#54406) * [CINN] reopen mean/cumsum/instance_norm op's prim+CINN test * remove repeat test_mean_op in cmake --- test/legacy_test/CMakeLists.txt | 4 +++- test/legacy_test/test_cumsum_op.py | 6 ++++-- test/legacy_test/test_instance_norm_op.py | 2 ++ test/legacy_test/test_mean_op.py | 6 ------ 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index d666116b6ad..437f83a900b 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -1195,7 +1195,9 @@ set(TEST_CINN_OPS test_flip test_triangular_solve_op test_scatter_nd_op - test_strided_slice_op) + test_strided_slice_op + test_instance_norm_op + test_cumsum_op) foreach(TEST_CINN_OPS ${TEST_CINN_OPS}) if(WITH_CINN) diff --git a/test/legacy_test/test_cumsum_op.py b/test/legacy_test/test_cumsum_op.py index b2031a792bd..60d5855fb0a 100644 --- a/test/legacy_test/test_cumsum_op.py +++ b/test/legacy_test/test_cumsum_op.py @@ -314,7 +314,8 @@ class TestSumOpExclusiveFP16(OpTest): self.python_api = cumsum_wrapper self.public_python_api = paddle.cumsum self.init_dtype() - self.enable_cinn = True + # TODO(thisjiang): set `True` after reduce+cast at shape=[4, 5, 20, 20], dim=[2]'s fusion bug has fixed + self.enable_cinn = False self.attrs = {'axis': 2, "exclusive": True} self.x = np.random.random((4, 5, 20)).astype(self.dtype) self.out = np.concatenate( @@ -389,7 +390,8 @@ def create_test_fp16_class(parent, max_relative_error=1e-2): self.dtype = self.dtype_ = np.float16 def set_enable_cinn(self): - self.enable_cinn = True + # TODO(thisjiang): set `pass` after reduce+cast at shape=[4, 5, 20, 20], dim=[2]'s fusion bug has fixed + self.enable_cinn = False def test_check_output(self): self.check_output() diff --git a/test/legacy_test/test_instance_norm_op.py b/test/legacy_test/test_instance_norm_op.py index 84905484d07..fe8ed7bf150 100644 --- a/test/legacy_test/test_instance_norm_op.py +++ b/test/legacy_test/test_instance_norm_op.py @@ -103,6 +103,8 @@ class TestInstanceNormOp(OpTest): self.fw_comp_atol = 1e-6 self.rev_comp_rtol = 1e-4 self.rev_comp_atol = 1e-4 + self.cinn_rtol = 1e-4 + self.cinn_atol = 1e-4 self.init_test_case() ref_y_np, ref_mean_np, ref_var_np_tmp = _reference_instance_norm_naive( self.x_np, diff --git a/test/legacy_test/test_mean_op.py b/test/legacy_test/test_mean_op.py index 78bcc7a46bc..9e5e7fc17c9 100644 --- a/test/legacy_test/test_mean_op.py +++ b/test/legacy_test/test_mean_op.py @@ -284,9 +284,6 @@ class TestReduceMeanOpShape6D(TestReduceMeanOp): def set_attrs(self): self.shape = [2, 3, 4, 5, 6, 7] - def if_enable_cinn(self): - self.enable_cinn = False - class TestReduceMeanOpShape6DBF16(TestReduceMeanBF16Op): def set_attrs(self): @@ -298,9 +295,6 @@ class TestReduceMeanOpShape6DFP16(TestReduceMeanOp): self.shape = [2, 3, 4, 5, 6, 7] self.dtype = 'float16' - def if_enable_cinn(self): - self.enable_cinn = False - class TestReduceMeanOpAxisAll(TestReduceMeanOp): def set_attrs(self): -- GitLab