提交 bb7db754 编写于 作者: D dangqingqing

add testing for duplicate item

上级 c109e3bf
......@@ -176,7 +176,7 @@ class DataFeederTest(unittest.TestCase):
self.assertEqual(output_sparse.getSparseRowCols(i), data[i][1])
self.assertEqual(output_index[i], data[i][0])
# reader returns 3 featreus, but only use 2 features
# reader returns 3 features, but only use 2 features
data_types = [('fea0', data_type.dense_vector(100)),
('fea2', data_type.integer_value(10))]
feeder = DataFeeder(data_types, {'fea0': 2, 'fea2': 0})
......@@ -187,6 +187,27 @@ class DataFeederTest(unittest.TestCase):
self.assertEqual(output_dense[i].all(), data[i][2].all())
self.assertEqual(output_index[i], data[i][0])
# reader returns 3 featreus, one is duplicate data
data_types = [('fea0', data_type.dense_vector(100)),
('fea1', data_type.sparse_binary_vector(20000)),
('fea2', data_type.integer_value(10)),
('fea3', data_type.dense_vector(100))]
feeder = DataFeeder(data_types,
{'fea0': 2,
'fea1': 1,
'fea2': 0,
'fea3': 2})
arg = feeder(data)
fea0 = arg.getSlotValue(0).copyToNumpyMat()
fea1 = arg.getSlotValue(1)
fea2 = arg.getSlotIds(2).copyToNumpyArray()
fea3 = arg.getSlotValue(3).copyToNumpyMat()
for i in xrange(batch_size):
self.assertEqual(fea0[i].all(), data[i][2].all())
self.assertEqual(fea1.getSparseRowCols(i), data[i][1])
self.assertEqual(fea2[i], data[i][0])
self.assertEqual(fea3[i].all(), data[i][2].all())
def test_multiple_features_tuple(self):
batch_size = 2
data = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册