未验证 提交 ec51485f 编写于 作者: G GGBond8488 提交者: GitHub

Open axis=none test for cumsum prim (#51243)

* fix cumsum prim op maker type error

* add axis=None test for cumsum_grad prim

* fix t
上级 468c17ff
......@@ -181,7 +181,7 @@ class TestSumOp4(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class TestSumOp5(OpTest):
......@@ -197,12 +197,15 @@ class TestSumOp5(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class TestSumOp6(OpTest):
def setUp(self):
self.op_type = "cumsum"
self.prim_op_type = "prim"
self.python_api = paddle.cumsum
self.enable_cinn = False
self.attrs = {'axis': -1, 'flatten': True}
self.inputs = {'X': np.random.random((5, 6, 5)).astype("float64")}
self.outputs = {'Out': self.inputs['X'].cumsum()}
......@@ -211,7 +214,7 @@ class TestSumOp6(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class TestSumOp7(OpTest):
......@@ -227,7 +230,7 @@ class TestSumOp7(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class TestCumsumFP16(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册