未验证 提交 2e4ac019 编写于 作者: J jiangcheng 提交者: GitHub

[CINN] reopen cumsum prim+cinn unittest (#54494)

上级 57564bdf
......@@ -142,7 +142,7 @@ class TestSumOp1(OpTest):
self.dtype = self.dtype_ = np.float64
def set_enable_cinn(self):
self.enable_cinn = True
pass
def set_attrs_input_output(self):
self.attrs = {'axis': 2}
......@@ -241,7 +241,7 @@ class TestSumOpExclusive1(OpTest):
self.dtype = self.dtype_ = np.float64
def set_enable_cinn(self):
self.enable_cinn = True
pass
def set_attrs_input_output(self):
self.attrs = {'axis': 2, 'exclusive': True}
......@@ -314,8 +314,6 @@ class TestSumOpExclusiveFP16(OpTest):
self.python_api = cumsum_wrapper
self.public_python_api = paddle.cumsum
self.init_dtype()
# TODO(thisjiang): set `True` after reduce+cast at shape=[4, 5, 20, 20], dim=[2]'s fusion bug has fixed
self.enable_cinn = False
self.attrs = {'axis': 2, "exclusive": True}
self.x = np.random.random((4, 5, 20)).astype(self.dtype)
self.out = np.concatenate(
......@@ -381,7 +379,7 @@ class TestSumOpReverseExclusive(OpTest):
self.dtype = self.dtype_ = np.float64
def set_enable_cinn(self):
self.enable_cinn = True
pass
def create_test_fp16_class(parent, max_relative_error=1e-2):
......@@ -390,8 +388,7 @@ def create_test_fp16_class(parent, max_relative_error=1e-2):
self.dtype = self.dtype_ = np.float16
def set_enable_cinn(self):
# TODO(thisjiang): set `pass` after reduce+cast at shape=[4, 5, 20, 20], dim=[2]'s fusion bug has fixed
self.enable_cinn = False
pass
def test_check_output(self):
self.check_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册