From 0c6eef3e58b3cdc182d1d8531eb227abc065857f Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Tue, 17 Apr 2018 15:48:05 +0800 Subject: [PATCH] add split by ref test --- python/paddle/fluid/tests/unittests/test_split_op.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_split_op.py b/python/paddle/fluid/tests/unittests/test_split_op.py index 887bdfe8b36..5a7123c36b1 100644 --- a/python/paddle/fluid/tests/unittests/test_split_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_op.py @@ -19,7 +19,6 @@ from op_test import OpTest class TestSplitOp(OpTest): def setUp(self): - self.op_type = "split" axis = 1 x = np.random.random((4, 5, 6)).astype('float32') out = np.split(x, [2, 3], axis) @@ -28,6 +27,9 @@ class TestSplitOp(OpTest): self.outputs = {'Out': [('out%d' % i, out[i]) \ for i in xrange(len(out))]} + def _set_op_type(self): + self.op_type = "split" + def test_check_output(self): self.check_output() @@ -35,5 +37,10 @@ class TestSplitOp(OpTest): self.check_grad(['X'], ['out0', 'out1', 'out2']) +class TestSplitByrefOp(OpTest): + def _set_op_type(self): + self.op_type = "split_byref" + + if __name__ == '__main__': unittest.main() -- GitLab