diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index 6fc01ce58be57c77144c6558d039430b22d3a746..650bf392bbc73415d9033f8c8134d90fd05f0cc2 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -50,11 +50,12 @@ def main(): parameters=parameters, event_handler=event_handler, batch_size=32, # batch size should be refactor in Data reader - data_types={ # data_types will be removed, It should be in + data_types=[ # data_types will be removed, It should be in # network topology - 'pixel': images.type, - 'label': label.type - }) + ('pixel', images.type), + ('label', label.type)], + reader_dict={'pixel':0, 'label':1} + ) if __name__ == '__main__': diff --git a/python/paddle/v2/__init__.py b/python/paddle/v2/__init__.py index be752731ba232a43933b117085c5da3ee363b035..bf06b5a7e360e9c46fdfc7f819dff46b244dc7f5 100644 --- a/python/paddle/v2/__init__.py +++ b/python/paddle/v2/__init__.py @@ -18,6 +18,7 @@ import parameters import trainer import event import data_type +import data_feeder import py_paddle.swig_paddle as api __all__ = [ diff --git a/python/paddle/v2/tests/test_data_feeder.py b/python/paddle/v2/tests/test_data_feeder.py index dcf433d7d8fdce0a473225c36410ec029b355f37..95a59a5d97fb672ba1e4dc7bead95d6556f37277 100644 --- a/python/paddle/v2/tests/test_data_feeder.py +++ b/python/paddle/v2/tests/test_data_feeder.py @@ -36,7 +36,7 @@ class DataFeederTest(unittest.TestCase): def compare(input): feeder = DataFeeder([('image', data_type.dense_vector(784))], {'image': 0}) - arg = feeder([input]) + arg = feeder(input) output = arg.getSlotValue(0).copyToNumpyMat() input = np.array(input, dtype='float32') self.assertAlmostEqual(input.all(), output.all()) @@ -46,13 +46,17 @@ class DataFeederTest(unittest.TestCase): dim = 784 data = [] for i in xrange(batch_size): - data.append(self.dense_reader(784)) + each_sample = [] + each_sample.append(self.dense_reader(dim)) + data.append(each_sample) compare(data) # test list data = [] for i in xrange(batch_size): - data.append(self.dense_reader(784).tolist()) + each_sample = [] + each_sample.append(self.dense_reader(dim).tolist()) + data.append(each_sample) compare(data) def test_sparse_binary(self): @@ -60,7 +64,9 @@ class DataFeederTest(unittest.TestCase): batch_size = 32 data = [] for i in xrange(batch_size): - data.append([self.sparse_binary_reader(dim, 50)]) + each_sample = [] + each_sample.append(self.sparse_binary_reader(dim, 50)) + data.append(each_sample) feeder = DataFeeder([('input', data_type.sparse_binary_vector(dim))], {'input': 0}) arg = feeder(data) @@ -76,11 +82,13 @@ class DataFeederTest(unittest.TestCase): w = [] data = [] for dat in xrange(batch_size): + each_sample = [] a = self.sparse_binary_reader(dim, 40, non_empty=True) b = self.dense_reader(len(a)).tolist() v.append(a) w.append(b[0]) - data.append([zip(a, b)]) + each_sample.append(zip(a, b)) + data.append(each_sample) feeder = DataFeeder([('input', data_type.sparse_vector(dim))], {'input': 0}) @@ -95,7 +103,9 @@ class DataFeederTest(unittest.TestCase): batch_size = 32 index = [] for i in xrange(batch_size): - index.append([np.random.randint(dim)]) + each_sample = [] + each_sample.append(np.random.randint(dim)) + index.append(each_sample) feeder = DataFeeder([('input', data_type.integer_value(dim))], {'input': 0}) arg = feeder(index) diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 5709c7e886c323b9159f18a52133770ea675fa5b..023ab5e42d25b9f70827b1e2efba985a5442db1f 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -69,7 +69,8 @@ class SGD(ITrainer): test_data_reader=None, event_handler=None, batch_size=32, - data_types=None): + data_types=None, + reader_dict=None): """ Training method. Will train num_passes of input data. @@ -103,13 +104,7 @@ class SGD(ITrainer): gm.start() out_args = api.Arguments.createArguments(0) - data_types_lists = [] - for each in topology.input_layer_names: - if each not in data_types: - raise ValueError() - data_types_lists.append(data_types[each]) - - feeder = DataFeeder(input_types=data_types_lists) + feeder = DataFeeder(data_types, reader_dict) for pass_id in xrange(num_passes): updater.startPass()