未验证 提交 ec51485f 编写于 作者: G GGBond8488 提交者: GitHub

Open axis=none test for cumsum prim (#51243)

* fix cumsum prim op maker type error

* add axis=None test for cumsum_grad prim

* fix t
上级 468c17ff
...@@ -181,7 +181,7 @@ class TestSumOp4(OpTest): ...@@ -181,7 +181,7 @@ class TestSumOp4(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out', check_prim=True)
class TestSumOp5(OpTest): class TestSumOp5(OpTest):
...@@ -197,12 +197,15 @@ class TestSumOp5(OpTest): ...@@ -197,12 +197,15 @@ class TestSumOp5(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out', check_prim=True)
class TestSumOp6(OpTest): class TestSumOp6(OpTest):
def setUp(self): def setUp(self):
self.op_type = "cumsum" self.op_type = "cumsum"
self.prim_op_type = "prim"
self.python_api = paddle.cumsum
self.enable_cinn = False
self.attrs = {'axis': -1, 'flatten': True} self.attrs = {'axis': -1, 'flatten': True}
self.inputs = {'X': np.random.random((5, 6, 5)).astype("float64")} self.inputs = {'X': np.random.random((5, 6, 5)).astype("float64")}
self.outputs = {'Out': self.inputs['X'].cumsum()} self.outputs = {'Out': self.inputs['X'].cumsum()}
...@@ -211,7 +214,7 @@ class TestSumOp6(OpTest): ...@@ -211,7 +214,7 @@ class TestSumOp6(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out', check_prim=True)
class TestSumOp7(OpTest): class TestSumOp7(OpTest):
...@@ -227,7 +230,7 @@ class TestSumOp7(OpTest): ...@@ -227,7 +230,7 @@ class TestSumOp7(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out', check_prim=True)
class TestCumsumFP16(unittest.TestCase): class TestCumsumFP16(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册