From f4c4c6179c9628055e13fc2237851a0fba801702 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 30 Jul 2018 14:18:48 +0800 Subject: [PATCH] optimize unit test of test_split_ids_op --- .../tests/unittests/test_split_ids_op.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_split_ids_op.py b/python/paddle/fluid/tests/unittests/test_split_ids_op.py index 4001877290..ca78613098 100644 --- a/python/paddle/fluid/tests/unittests/test_split_ids_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_ids_op.py @@ -69,17 +69,18 @@ class TestSpliteIds(unittest.TestCase): op = Operator("split_ids", Ids="X", Out=outs_name) - op.run(scope, place) - - for i in range(len(outs)): - expected_rows = expected_out_rows[i] - self.assertEqual(outs[i].rows(), expected_rows) - for j in range(len(expected_rows)): - row = expected_rows[j] - self.assertAlmostEqual( - float(row), np.array(outs[i].get_tensor())[j, 0]) - self.assertAlmostEqual( - float(row + 1), np.array(outs[i].get_tensor())[j, 1]) + for _ in range(3): + op.run(scope, place) + + for i in range(len(outs)): + expected_rows = expected_out_rows[i] + self.assertEqual(outs[i].rows(), expected_rows) + for j in range(len(expected_rows)): + row = expected_rows[j] + self.assertAlmostEqual( + float(row), np.array(outs[i].get_tensor())[j, 0]) + self.assertAlmostEqual( + float(row + 1), np.array(outs[i].get_tensor())[j, 1]) if __name__ == '__main__': -- GitLab