diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index 650bf392bbc73415d9033f8c8134d90fd05f0cc2..67e36e6b201f3858edb0a4047b45a9e60fb3a9a5 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -1,15 +1,6 @@ import numpy import paddle.v2 as paddle -import mnist_util - - -def train_reader(): - train_file = './data/raw_data/train' - generator = mnist_util.read_from_mnist(train_file) - for item in generator: - yield item - def main(): paddle.init(use_gpu=False, trainer_count=1) @@ -45,17 +36,24 @@ def main(): trainer = paddle.trainer.SGD(update_equation=adam_optimizer) - trainer.train(train_data_reader=train_reader, - topology=cost, - parameters=parameters, - event_handler=event_handler, - batch_size=32, # batch size should be refactor in Data reader - data_types=[ # data_types will be removed, It should be in - # network topology - ('pixel', images.type), - ('label', label.type)], - reader_dict={'pixel':0, 'label':1} - ) + reader = paddle.reader.batched( + paddle.reader.shuffle( + paddle.dataset.mnist.train_creator(), buf_size=8192), + batch_size=32) + + trainer.train( + train_reader=paddle.reader.batched( + paddle.reader.shuffle(paddle.dataset.mnist.train_creator(), + buf_size=8192), batch_size=32), + topology=cost, + parameters=parameters, + event_handler=event_handler, + data_types=[ # data_types will be removed, It should be in + # network topology + ('pixel', images.type), + ('label', label.type)], + reader_dict={'pixel': 0, 'label': 1} + ) if __name__ == '__main__': diff --git a/python/paddle/v2/__init__.py b/python/paddle/v2/__init__.py index 1122bcb5e45727cf78031259817a84f9e36a3163..cc130caa15f64576307a3a2e6562af0231b9f8dd 100644 --- a/python/paddle/v2/__init__.py +++ b/python/paddle/v2/__init__.py @@ -20,12 +20,13 @@ import event import data_type import data_feeder from . import dataset +from . import reader import attr import py_paddle.swig_paddle as api __all__ = [ 'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer', - 'event', 'data_type', 'attr', 'data_feeder', 'dataset' + 'event', 'data_type', 'attr', 'data_feeder', 'dataset', 'reader' ] diff --git a/python/paddle/v2/reader/decorator.py b/python/paddle/v2/reader/decorator.py index 9f4234358f469b9741dfef96aab3853d5972b5f4..d83e2f4577f20bcd2220a1a29b394d925126922e 100644 --- a/python/paddle/v2/reader/decorator.py +++ b/python/paddle/v2/reader/decorator.py @@ -14,7 +14,7 @@ __all__ = [ 'map_readers', 'buffered', 'compose', 'chain', 'shuffle', - 'ComposeNotAligned' + 'ComposeNotAligned', 'batched' ] from Queue import Queue @@ -191,3 +191,25 @@ def buffered(reader, size): e = q.get() return data_reader + + +def batched(reader, batch_size): + """ + Create a batched reader. + :param reader: the data reader to read from. + :param batch_size: batch_size + :return: the batched reader. + """ + + def __impl__(): + r = reader() + batch = [] + for instance in r: + batch.append(instance) + if len(batch) == batch_size: + yield batch + batch = [] + if batch: + yield batch + + return __impl__ diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 89415787eb2e36f29cfed8bb5144558d82337fb0..9b45ee69f39f3fc1e75683f31cb336d82554698c 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -29,7 +29,7 @@ class ITrainer(object): """ def train(self, - train_data_reader, + train_reader_creator, topology, parameters, test_data_reader=None, @@ -37,7 +37,7 @@ class ITrainer(object): """ train method. - :param train_data_reader: + :param train_reader_creator: :param topology: :param parameters: :param test_data_reader: @@ -62,27 +62,23 @@ class SGD(ITrainer): self.__optimizer__ = update_equation def train(self, - train_reader_creator, + train_reader, topology, parameters, num_passes=1, - test_data_reader=None, event_handler=None, - batch_size=32, data_types=None, reader_dict=None): """ Training method. Will train num_passes of input data. - :param train_reader_creator: + :param train_reader: :param topology: Network Topology, use one or more Layers to represent it. :param parameters: The parameter pools. :param num_passes: The total train passes. - :param test_data_reader: :param event_handler: Event handler. A method will be invoked when event occurred. :type event_handler: (BaseEvent) => None - :param batch_size: Not important, will be removed after data refactor. :param data_types: Not important, will be removed after data refactor. :return: """ @@ -108,9 +104,7 @@ class SGD(ITrainer): for pass_id in xrange(num_passes): updater.startPass() - for batch_id, data_batch in enumerate( - __data_reader_to_batch__(train_reader_creator, batch_size, - topology)): + for batch_id, data_batch in enumerate(train_reader()): pass_type = updater.startBatch(len(data_batch)) gm.forwardBackward(feeder(data_batch), out_args, pass_type) for each_param in gm.getParameters(): @@ -128,51 +122,16 @@ class SGD(ITrainer): gm.finish() -def __data_reader_to_batch__(reader, batch_size, topology): - """ - This function is not important, and will be removed when data refactored. - """ - - def input_reorder(func): - for item in func(): - retv = [] - for __layer_name__ in topology.input_layer_names: - retv.append(item[__layer_name__]) - yield retv - - return __generator_to_batch__(input_reorder(reader), batch_size=batch_size) - - -def __generator_to_batch__(generator, batch_size): - """ - This function is not important, and will be removed when data refactored. - """ - ret_val = list() - for each_item in generator: - ret_val.append(each_item) - if len(ret_val) == batch_size: - yield ret_val - ret_val = list() - if len(ret_val) != 0: - yield ret_val - - -def __check_train_args__(train_data_reader, topology, parameters, - test_data_reader, event_handler, **kwargs): +def __check_train_args__(train_reader, topology, parameters, event_handler, + **kwargs): """ Check train function's argument types """ - if not callable(train_data_reader) or not isinstance(train_data_reader(), - collections.Iterator): + if not callable(train_reader) or not isinstance(train_reader(), + collections.Iterator): raise TypeError('train_data_reader should be a function, ' 'which can return a iterator') - if test_data_reader is not None: - if not callable(test_data_reader) or not isinstance( - test_data_reader(), collections.Iterator): - raise TypeError('test_data_reader should be a function, which can ' - 'return a iterator') - if not isinstance(topology, ModelConfig): raise TypeError('topology should be a model config')