未验证 提交 099b3d25 编写于 作者: J jiangcheng 提交者: GitHub

[CINN] reopen mean/cumsum/instance_norm op's prim+CINN test (#54406)

* [CINN] reopen mean/cumsum/instance_norm op's prim+CINN test

* remove repeat test_mean_op in cmake
上级 791963ab
......@@ -1195,7 +1195,9 @@ set(TEST_CINN_OPS
test_flip
test_triangular_solve_op
test_scatter_nd_op
test_strided_slice_op)
test_strided_slice_op
test_instance_norm_op
test_cumsum_op)
foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
if(WITH_CINN)
......
......@@ -314,7 +314,8 @@ class TestSumOpExclusiveFP16(OpTest):
self.python_api = cumsum_wrapper
self.public_python_api = paddle.cumsum
self.init_dtype()
self.enable_cinn = True
# TODO(thisjiang): set `True` after reduce+cast at shape=[4, 5, 20, 20], dim=[2]'s fusion bug has fixed
self.enable_cinn = False
self.attrs = {'axis': 2, "exclusive": True}
self.x = np.random.random((4, 5, 20)).astype(self.dtype)
self.out = np.concatenate(
......@@ -389,7 +390,8 @@ def create_test_fp16_class(parent, max_relative_error=1e-2):
self.dtype = self.dtype_ = np.float16
def set_enable_cinn(self):
self.enable_cinn = True
# TODO(thisjiang): set `pass` after reduce+cast at shape=[4, 5, 20, 20], dim=[2]'s fusion bug has fixed
self.enable_cinn = False
def test_check_output(self):
self.check_output()
......
......@@ -103,6 +103,8 @@ class TestInstanceNormOp(OpTest):
self.fw_comp_atol = 1e-6
self.rev_comp_rtol = 1e-4
self.rev_comp_atol = 1e-4
self.cinn_rtol = 1e-4
self.cinn_atol = 1e-4
self.init_test_case()
ref_y_np, ref_mean_np, ref_var_np_tmp = _reference_instance_norm_naive(
self.x_np,
......
......@@ -284,9 +284,6 @@ class TestReduceMeanOpShape6D(TestReduceMeanOp):
def set_attrs(self):
self.shape = [2, 3, 4, 5, 6, 7]
def if_enable_cinn(self):
self.enable_cinn = False
class TestReduceMeanOpShape6DBF16(TestReduceMeanBF16Op):
def set_attrs(self):
......@@ -298,9 +295,6 @@ class TestReduceMeanOpShape6DFP16(TestReduceMeanOp):
self.shape = [2, 3, 4, 5, 6, 7]
self.dtype = 'float16'
def if_enable_cinn(self):
self.enable_cinn = False
class TestReduceMeanOpAxisAll(TestReduceMeanOp):
def set_attrs(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册