diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index d666116b6ad27907ffedb5fcddc76519ccbc1683..437f83a900b22bc3f9f78a4b8a5880855bc1877b 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -1195,7 +1195,9 @@ set(TEST_CINN_OPS test_flip test_triangular_solve_op test_scatter_nd_op - test_strided_slice_op) + test_strided_slice_op + test_instance_norm_op + test_cumsum_op) foreach(TEST_CINN_OPS ${TEST_CINN_OPS}) if(WITH_CINN) diff --git a/test/legacy_test/test_cumsum_op.py b/test/legacy_test/test_cumsum_op.py index b2031a792bd0ec2a4b4e4ee3f09245ca6fe39bea..60d5855fb0ad2b63bc3661e958a04abedab773c6 100644 --- a/test/legacy_test/test_cumsum_op.py +++ b/test/legacy_test/test_cumsum_op.py @@ -314,7 +314,8 @@ class TestSumOpExclusiveFP16(OpTest): self.python_api = cumsum_wrapper self.public_python_api = paddle.cumsum self.init_dtype() - self.enable_cinn = True + # TODO(thisjiang): set `True` after reduce+cast at shape=[4, 5, 20, 20], dim=[2]'s fusion bug has fixed + self.enable_cinn = False self.attrs = {'axis': 2, "exclusive": True} self.x = np.random.random((4, 5, 20)).astype(self.dtype) self.out = np.concatenate( @@ -389,7 +390,8 @@ def create_test_fp16_class(parent, max_relative_error=1e-2): self.dtype = self.dtype_ = np.float16 def set_enable_cinn(self): - self.enable_cinn = True + # TODO(thisjiang): set `pass` after reduce+cast at shape=[4, 5, 20, 20], dim=[2]'s fusion bug has fixed + self.enable_cinn = False def test_check_output(self): self.check_output() diff --git a/test/legacy_test/test_instance_norm_op.py b/test/legacy_test/test_instance_norm_op.py index 84905484d07f2fc2020c850bf00526a577ba0be7..fe8ed7bf150c1424dc74f1ef60fc0eb0550586ec 100644 --- a/test/legacy_test/test_instance_norm_op.py +++ b/test/legacy_test/test_instance_norm_op.py @@ -103,6 +103,8 @@ class TestInstanceNormOp(OpTest): self.fw_comp_atol = 1e-6 self.rev_comp_rtol = 1e-4 self.rev_comp_atol = 1e-4 + self.cinn_rtol = 1e-4 + self.cinn_atol = 1e-4 self.init_test_case() ref_y_np, ref_mean_np, ref_var_np_tmp = _reference_instance_norm_naive( self.x_np, diff --git a/test/legacy_test/test_mean_op.py b/test/legacy_test/test_mean_op.py index 78bcc7a46bc9b8e59fea593055a311e1fce4f7f4..9e5e7fc17c949826614fc92b9ea776a28dcd364c 100644 --- a/test/legacy_test/test_mean_op.py +++ b/test/legacy_test/test_mean_op.py @@ -284,9 +284,6 @@ class TestReduceMeanOpShape6D(TestReduceMeanOp): def set_attrs(self): self.shape = [2, 3, 4, 5, 6, 7] - def if_enable_cinn(self): - self.enable_cinn = False - class TestReduceMeanOpShape6DBF16(TestReduceMeanBF16Op): def set_attrs(self): @@ -298,9 +295,6 @@ class TestReduceMeanOpShape6DFP16(TestReduceMeanOp): self.shape = [2, 3, 4, 5, 6, 7] self.dtype = 'float16' - def if_enable_cinn(self): - self.enable_cinn = False - class TestReduceMeanOpAxisAll(TestReduceMeanOp): def set_attrs(self):