提交 99b9eafe 编写于 作者: H helinwang 提交者: GitHub

Merge pull request #1515 from helinwang/int_seq

change argument name of data_type.integer_value/integer_value_sequenc…
...@@ -65,14 +65,18 @@ def sparse_value_slot(dim, seq_type=SequenceType.NO_SEQUENCE): ...@@ -65,14 +65,18 @@ def sparse_value_slot(dim, seq_type=SequenceType.NO_SEQUENCE):
return InputType(dim, seq_type, DataType.SparseValue) return InputType(dim, seq_type, DataType.SparseValue)
def index_slot(dim, seq_type=SequenceType.NO_SEQUENCE): def index_slot(value_range, seq_type=SequenceType.NO_SEQUENCE):
return InputType(dim, seq_type, DataType.Index) """Data type of integer.
:param value_range: range of this integer.
"""
return InputType(value_range, seq_type, DataType.Index)
dense_vector = dense_slot dense_vector = dense_slot
sparse_binary_vector = sparse_non_value_slot sparse_binary_vector = sparse_non_value_slot
sparse_vector = sparse_value_slot sparse_vector = sparse_value_slot
integer_value = index_slot integer_value = index_slot
integer_value.__doc__ = index_slot.__doc__
def dense_vector_sequence(dim): def dense_vector_sequence(dim):
...@@ -99,8 +103,11 @@ def sparse_vector_sub_sequence(dim): ...@@ -99,8 +103,11 @@ def sparse_vector_sub_sequence(dim):
return sparse_vector(dim, seq_type=SequenceType.SUB_SEQUENCE) return sparse_vector(dim, seq_type=SequenceType.SUB_SEQUENCE)
def integer_value_sequence(dim): def integer_value_sequence(value_range):
return integer_value(dim, seq_type=SequenceType.SEQUENCE) """Data type of a sequence of integer.
:param value_range: range of each element.
"""
return integer_value(value_range, seq_type=SequenceType.SEQUENCE)
def integer_value_sub_sequence(dim): def integer_value_sub_sequence(dim):
...@@ -108,6 +115,7 @@ def integer_value_sub_sequence(dim): ...@@ -108,6 +115,7 @@ def integer_value_sub_sequence(dim):
integer_sequence = integer_value_sequence integer_sequence = integer_value_sequence
integer_sequence.__doc__ = integer_value_sequence.__doc__
class SingleSlotWrapper(object): class SingleSlotWrapper(object):
......
...@@ -110,14 +110,14 @@ class DataFeederTest(unittest.TestCase): ...@@ -110,14 +110,14 @@ class DataFeederTest(unittest.TestCase):
self.assertAlmostEqual(value.all(), w[i].all()) self.assertAlmostEqual(value.all(), w[i].all())
def test_integer(self): def test_integer(self):
dim = 100 value_range = 100
batch_size = 32 batch_size = 32
index = [] index = []
for i in xrange(batch_size): for i in xrange(batch_size):
each_sample = [] each_sample = []
each_sample.append(np.random.randint(dim)) each_sample.append(np.random.randint(value_range))
index.append(each_sample) index.append(each_sample)
feeder = DataFeeder([('input', data_type.integer_value(dim))], feeder = DataFeeder([('input', data_type.integer_value(value_range))],
{'input': 0}) {'input': 0})
arg = feeder(index) arg = feeder(index)
output = arg.getSlotIds(0).copyToNumpyArray() output = arg.getSlotIds(0).copyToNumpyArray()
...@@ -125,7 +125,7 @@ class DataFeederTest(unittest.TestCase): ...@@ -125,7 +125,7 @@ class DataFeederTest(unittest.TestCase):
self.assertEqual(output.all(), index.flatten().all()) self.assertEqual(output.all(), index.flatten().all())
def test_integer_sequence(self): def test_integer_sequence(self):
dim = 10000 value_range = 10000
batch_size = 32 batch_size = 32
start = [0] start = [0]
data = [] data = []
...@@ -133,10 +133,11 @@ class DataFeederTest(unittest.TestCase): ...@@ -133,10 +133,11 @@ class DataFeederTest(unittest.TestCase):
each_sample = [] each_sample = []
each_sample.append( each_sample.append(
self.sparse_binary_reader( self.sparse_binary_reader(
dim, 30, non_empty=True)) value_range, 30, non_empty=True))
data.append(each_sample) data.append(each_sample)
start.append(len(each_sample[0]) + start[-1]) start.append(len(each_sample[0]) + start[-1])
feeder = DataFeeder([('input', data_type.integer_value_sequence(dim))], feeder = DataFeeder(
[('input', data_type.integer_value_sequence(value_range))],
{'input': 0}) {'input': 0})
arg = feeder(data) arg = feeder(data)
output_data = arg.getSlotIds(0).copyToNumpyArray() output_data = arg.getSlotIds(0).copyToNumpyArray()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册