提交 f3c7fbee 编写于 作者: D dangqingqing

make minst to run

上级 c330af31
......@@ -50,11 +50,12 @@ def main():
parameters=parameters,
event_handler=event_handler,
batch_size=32, # batch size should be refactor in Data reader
data_types={ # data_types will be removed, It should be in
data_types=[ # data_types will be removed, It should be in
# network topology
'pixel': images.type,
'label': label.type
})
('pixel', images.type),
('label', label.type)],
reader_dict={'pixel':0, 'label':1}
)
if __name__ == '__main__':
......
......@@ -18,6 +18,7 @@ import parameters
import trainer
import event
import data_type
import data_feeder
import py_paddle.swig_paddle as api
__all__ = [
......
......@@ -36,7 +36,7 @@ class DataFeederTest(unittest.TestCase):
def compare(input):
feeder = DataFeeder([('image', data_type.dense_vector(784))],
{'image': 0})
arg = feeder([input])
arg = feeder(input)
output = arg.getSlotValue(0).copyToNumpyMat()
input = np.array(input, dtype='float32')
self.assertAlmostEqual(input.all(), output.all())
......@@ -46,13 +46,17 @@ class DataFeederTest(unittest.TestCase):
dim = 784
data = []
for i in xrange(batch_size):
data.append(self.dense_reader(784))
each_sample = []
each_sample.append(self.dense_reader(dim))
data.append(each_sample)
compare(data)
# test list
data = []
for i in xrange(batch_size):
data.append(self.dense_reader(784).tolist())
each_sample = []
each_sample.append(self.dense_reader(dim).tolist())
data.append(each_sample)
compare(data)
def test_sparse_binary(self):
......@@ -60,7 +64,9 @@ class DataFeederTest(unittest.TestCase):
batch_size = 32
data = []
for i in xrange(batch_size):
data.append([self.sparse_binary_reader(dim, 50)])
each_sample = []
each_sample.append(self.sparse_binary_reader(dim, 50))
data.append(each_sample)
feeder = DataFeeder([('input', data_type.sparse_binary_vector(dim))],
{'input': 0})
arg = feeder(data)
......@@ -76,11 +82,13 @@ class DataFeederTest(unittest.TestCase):
w = []
data = []
for dat in xrange(batch_size):
each_sample = []
a = self.sparse_binary_reader(dim, 40, non_empty=True)
b = self.dense_reader(len(a)).tolist()
v.append(a)
w.append(b[0])
data.append([zip(a, b)])
each_sample.append(zip(a, b))
data.append(each_sample)
feeder = DataFeeder([('input', data_type.sparse_vector(dim))],
{'input': 0})
......@@ -95,7 +103,9 @@ class DataFeederTest(unittest.TestCase):
batch_size = 32
index = []
for i in xrange(batch_size):
index.append([np.random.randint(dim)])
each_sample = []
each_sample.append(np.random.randint(dim))
index.append(each_sample)
feeder = DataFeeder([('input', data_type.integer_value(dim))],
{'input': 0})
arg = feeder(index)
......
......@@ -69,7 +69,8 @@ class SGD(ITrainer):
test_data_reader=None,
event_handler=None,
batch_size=32,
data_types=None):
data_types=None,
reader_dict=None):
"""
Training method. Will train num_passes of input data.
......@@ -103,13 +104,7 @@ class SGD(ITrainer):
gm.start()
out_args = api.Arguments.createArguments(0)
data_types_lists = []
for each in topology.input_layer_names:
if each not in data_types:
raise ValueError()
data_types_lists.append(data_types[each])
feeder = DataFeeder(input_types=data_types_lists)
feeder = DataFeeder(data_types, reader_dict)
for pass_id in xrange(num_passes):
updater.startPass()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册