提交 f3c7fbee 编写于 作者: D dangqingqing

make minst to run

上级 c330af31
...@@ -50,11 +50,12 @@ def main(): ...@@ -50,11 +50,12 @@ def main():
parameters=parameters, parameters=parameters,
event_handler=event_handler, event_handler=event_handler,
batch_size=32, # batch size should be refactor in Data reader batch_size=32, # batch size should be refactor in Data reader
data_types={ # data_types will be removed, It should be in data_types=[ # data_types will be removed, It should be in
# network topology # network topology
'pixel': images.type, ('pixel', images.type),
'label': label.type ('label', label.type)],
}) reader_dict={'pixel':0, 'label':1}
)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -18,6 +18,7 @@ import parameters ...@@ -18,6 +18,7 @@ import parameters
import trainer import trainer
import event import event
import data_type import data_type
import data_feeder
import py_paddle.swig_paddle as api import py_paddle.swig_paddle as api
__all__ = [ __all__ = [
......
...@@ -36,7 +36,7 @@ class DataFeederTest(unittest.TestCase): ...@@ -36,7 +36,7 @@ class DataFeederTest(unittest.TestCase):
def compare(input): def compare(input):
feeder = DataFeeder([('image', data_type.dense_vector(784))], feeder = DataFeeder([('image', data_type.dense_vector(784))],
{'image': 0}) {'image': 0})
arg = feeder([input]) arg = feeder(input)
output = arg.getSlotValue(0).copyToNumpyMat() output = arg.getSlotValue(0).copyToNumpyMat()
input = np.array(input, dtype='float32') input = np.array(input, dtype='float32')
self.assertAlmostEqual(input.all(), output.all()) self.assertAlmostEqual(input.all(), output.all())
...@@ -46,13 +46,17 @@ class DataFeederTest(unittest.TestCase): ...@@ -46,13 +46,17 @@ class DataFeederTest(unittest.TestCase):
dim = 784 dim = 784
data = [] data = []
for i in xrange(batch_size): for i in xrange(batch_size):
data.append(self.dense_reader(784)) each_sample = []
each_sample.append(self.dense_reader(dim))
data.append(each_sample)
compare(data) compare(data)
# test list # test list
data = [] data = []
for i in xrange(batch_size): for i in xrange(batch_size):
data.append(self.dense_reader(784).tolist()) each_sample = []
each_sample.append(self.dense_reader(dim).tolist())
data.append(each_sample)
compare(data) compare(data)
def test_sparse_binary(self): def test_sparse_binary(self):
...@@ -60,7 +64,9 @@ class DataFeederTest(unittest.TestCase): ...@@ -60,7 +64,9 @@ class DataFeederTest(unittest.TestCase):
batch_size = 32 batch_size = 32
data = [] data = []
for i in xrange(batch_size): for i in xrange(batch_size):
data.append([self.sparse_binary_reader(dim, 50)]) each_sample = []
each_sample.append(self.sparse_binary_reader(dim, 50))
data.append(each_sample)
feeder = DataFeeder([('input', data_type.sparse_binary_vector(dim))], feeder = DataFeeder([('input', data_type.sparse_binary_vector(dim))],
{'input': 0}) {'input': 0})
arg = feeder(data) arg = feeder(data)
...@@ -76,11 +82,13 @@ class DataFeederTest(unittest.TestCase): ...@@ -76,11 +82,13 @@ class DataFeederTest(unittest.TestCase):
w = [] w = []
data = [] data = []
for dat in xrange(batch_size): for dat in xrange(batch_size):
each_sample = []
a = self.sparse_binary_reader(dim, 40, non_empty=True) a = self.sparse_binary_reader(dim, 40, non_empty=True)
b = self.dense_reader(len(a)).tolist() b = self.dense_reader(len(a)).tolist()
v.append(a) v.append(a)
w.append(b[0]) w.append(b[0])
data.append([zip(a, b)]) each_sample.append(zip(a, b))
data.append(each_sample)
feeder = DataFeeder([('input', data_type.sparse_vector(dim))], feeder = DataFeeder([('input', data_type.sparse_vector(dim))],
{'input': 0}) {'input': 0})
...@@ -95,7 +103,9 @@ class DataFeederTest(unittest.TestCase): ...@@ -95,7 +103,9 @@ class DataFeederTest(unittest.TestCase):
batch_size = 32 batch_size = 32
index = [] index = []
for i in xrange(batch_size): for i in xrange(batch_size):
index.append([np.random.randint(dim)]) each_sample = []
each_sample.append(np.random.randint(dim))
index.append(each_sample)
feeder = DataFeeder([('input', data_type.integer_value(dim))], feeder = DataFeeder([('input', data_type.integer_value(dim))],
{'input': 0}) {'input': 0})
arg = feeder(index) arg = feeder(index)
......
...@@ -69,7 +69,8 @@ class SGD(ITrainer): ...@@ -69,7 +69,8 @@ class SGD(ITrainer):
test_data_reader=None, test_data_reader=None,
event_handler=None, event_handler=None,
batch_size=32, batch_size=32,
data_types=None): data_types=None,
reader_dict=None):
""" """
Training method. Will train num_passes of input data. Training method. Will train num_passes of input data.
...@@ -103,13 +104,7 @@ class SGD(ITrainer): ...@@ -103,13 +104,7 @@ class SGD(ITrainer):
gm.start() gm.start()
out_args = api.Arguments.createArguments(0) out_args = api.Arguments.createArguments(0)
data_types_lists = [] feeder = DataFeeder(data_types, reader_dict)
for each in topology.input_layer_names:
if each not in data_types:
raise ValueError()
data_types_lists.append(data_types[each])
feeder = DataFeeder(input_types=data_types_lists)
for pass_id in xrange(num_passes): for pass_id in xrange(num_passes):
updater.startPass() updater.startPass()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册