diff --git a/python/paddle/trainer/PyDataProvider2.py b/python/paddle/trainer/PyDataProvider2.py index bd24c68b6fe88eab03c814f8cac70db3880316f4..afde7453a15fc0cc0196efac807dcc407869ec54 100644 --- a/python/paddle/trainer/PyDataProvider2.py +++ b/python/paddle/trainer/PyDataProvider2.py @@ -65,14 +65,18 @@ def sparse_value_slot(dim, seq_type=SequenceType.NO_SEQUENCE): return InputType(dim, seq_type, DataType.SparseValue) -def index_slot(dim, seq_type=SequenceType.NO_SEQUENCE): - return InputType(dim, seq_type, DataType.Index) +def index_slot(ele_range, seq_type=SequenceType.NO_SEQUENCE): + """Data type of integer. + :param ele_range: range of this integer. + """ + return InputType(ele_range, seq_type, DataType.Index) dense_vector = dense_slot sparse_binary_vector = sparse_non_value_slot sparse_vector = sparse_value_slot integer_value = index_slot +integer_value.__doc__ = index_slot.__doc__ def dense_vector_sequence(dim): @@ -99,8 +103,11 @@ def sparse_vector_sub_sequence(dim): return sparse_vector(dim, seq_type=SequenceType.SUB_SEQUENCE) -def integer_value_sequence(dim): - return integer_value(dim, seq_type=SequenceType.SEQUENCE) +def integer_value_sequence(ele_range): + """Data type of a sequence of integer. + :param ele_range: range of each element. + """ + return integer_value(ele_range, seq_type=SequenceType.SEQUENCE) def integer_value_sub_sequence(dim): @@ -108,6 +115,7 @@ def integer_value_sub_sequence(dim): integer_sequence = integer_value_sequence +integer_sequence.__doc__ = integer_value_sequence.__doc__ class SingleSlotWrapper(object): diff --git a/python/paddle/v2/tests/test_data_feeder.py b/python/paddle/v2/tests/test_data_feeder.py index ab2bc5df76cd839b5b0184e9559f0c2e03baf38b..1b1f5aef8b811b320d382d558afb3532932b3ebf 100644 --- a/python/paddle/v2/tests/test_data_feeder.py +++ b/python/paddle/v2/tests/test_data_feeder.py @@ -110,14 +110,14 @@ class DataFeederTest(unittest.TestCase): self.assertAlmostEqual(value.all(), w[i].all()) def test_integer(self): - dim = 100 + ele_range = 100 batch_size = 32 index = [] for i in xrange(batch_size): each_sample = [] - each_sample.append(np.random.randint(dim)) + each_sample.append(np.random.randint(ele_range)) index.append(each_sample) - feeder = DataFeeder([('input', data_type.integer_value(dim))], + feeder = DataFeeder([('input', data_type.integer_value(ele_range))], {'input': 0}) arg = feeder(index) output = arg.getSlotIds(0).copyToNumpyArray() @@ -125,7 +125,7 @@ class DataFeederTest(unittest.TestCase): self.assertEqual(output.all(), index.flatten().all()) def test_integer_sequence(self): - dim = 10000 + ele_range = 10000 batch_size = 32 start = [0] data = [] @@ -133,11 +133,12 @@ class DataFeederTest(unittest.TestCase): each_sample = [] each_sample.append( self.sparse_binary_reader( - dim, 30, non_empty=True)) + ele_range, 30, non_empty=True)) data.append(each_sample) start.append(len(each_sample[0]) + start[-1]) - feeder = DataFeeder([('input', data_type.integer_value_sequence(dim))], - {'input': 0}) + feeder = DataFeeder( + [('input', data_type.integer_value_sequence(ele_range))], + {'input': 0}) arg = feeder(data) output_data = arg.getSlotIds(0).copyToNumpyArray() output_start = arg.getSlotSequenceStartPositions(0).copyToNumpyArray()