From bb7db754208a7484ced25eb879bd77e7f6fae6c9 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Sat, 25 Feb 2017 10:07:15 +0800 Subject: [PATCH] add testing for duplicate item --- python/paddle/v2/tests/test_data_feeder.py | 23 +++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/python/paddle/v2/tests/test_data_feeder.py b/python/paddle/v2/tests/test_data_feeder.py index 4d5df6e8931..5f67da6a5b3 100644 --- a/python/paddle/v2/tests/test_data_feeder.py +++ b/python/paddle/v2/tests/test_data_feeder.py @@ -176,7 +176,7 @@ class DataFeederTest(unittest.TestCase): self.assertEqual(output_sparse.getSparseRowCols(i), data[i][1]) self.assertEqual(output_index[i], data[i][0]) - # reader returns 3 featreus, but only use 2 features + # reader returns 3 features, but only use 2 features data_types = [('fea0', data_type.dense_vector(100)), ('fea2', data_type.integer_value(10))] feeder = DataFeeder(data_types, {'fea0': 2, 'fea2': 0}) @@ -187,6 +187,27 @@ class DataFeederTest(unittest.TestCase): self.assertEqual(output_dense[i].all(), data[i][2].all()) self.assertEqual(output_index[i], data[i][0]) + # reader returns 3 featreus, one is duplicate data + data_types = [('fea0', data_type.dense_vector(100)), + ('fea1', data_type.sparse_binary_vector(20000)), + ('fea2', data_type.integer_value(10)), + ('fea3', data_type.dense_vector(100))] + feeder = DataFeeder(data_types, + {'fea0': 2, + 'fea1': 1, + 'fea2': 0, + 'fea3': 2}) + arg = feeder(data) + fea0 = arg.getSlotValue(0).copyToNumpyMat() + fea1 = arg.getSlotValue(1) + fea2 = arg.getSlotIds(2).copyToNumpyArray() + fea3 = arg.getSlotValue(3).copyToNumpyMat() + for i in xrange(batch_size): + self.assertEqual(fea0[i].all(), data[i][2].all()) + self.assertEqual(fea1.getSparseRowCols(i), data[i][1]) + self.assertEqual(fea2[i], data[i][0]) + self.assertEqual(fea3[i].all(), data[i][2].all()) + def test_multiple_features_tuple(self): batch_size = 2 data = [] -- GitLab