Commit 46d2ca57 authored by Yu Yang

Combine Reader=>Feeder together.

Parent c26431ba
import numpy import numpy
import paddle.v2 as paddle import paddle.v2 as paddle
import mnist_util
def train_reader():
    """Generator over MNIST training instances.

    Reads the raw MNIST training split from disk via ``mnist_util`` and
    yields one instance at a time, so callers can stream the data set
    without loading it all into memory.
    """
    data_path = './data/raw_data/train'
    for sample in mnist_util.read_from_mnist(data_path):
        yield sample
def main(): def main():
paddle.init(use_gpu=False, trainer_count=1) paddle.init(use_gpu=False, trainer_count=1)
...@@ -45,17 +36,24 @@ def main(): ...@@ -45,17 +36,24 @@ def main():
trainer = paddle.trainer.SGD(update_equation=adam_optimizer) trainer = paddle.trainer.SGD(update_equation=adam_optimizer)
trainer.train(train_data_reader=train_reader, reader = paddle.reader.batched(
topology=cost, paddle.reader.shuffle(
parameters=parameters, paddle.dataset.mnist.train_creator(), buf_size=8192),
event_handler=event_handler, batch_size=32)
batch_size=32, # batch size should be refactor in Data reader
data_types=[ # data_types will be removed, It should be in trainer.train(
# network topology train_reader=paddle.reader.batched(
('pixel', images.type), paddle.reader.shuffle(paddle.dataset.mnist.train_creator(),
('label', label.type)], buf_size=8192), batch_size=32),
reader_dict={'pixel':0, 'label':1} topology=cost,
) parameters=parameters,
event_handler=event_handler,
data_types=[ # data_types will be removed, It should be in
# network topology
('pixel', images.type),
('label', label.type)],
reader_dict={'pixel': 0, 'label': 1}
)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -20,12 +20,13 @@ import event ...@@ -20,12 +20,13 @@ import event
import data_type import data_type
import data_feeder import data_feeder
from . import dataset from . import dataset
from . import reader
import attr import attr
import py_paddle.swig_paddle as api import py_paddle.swig_paddle as api
__all__ = [ __all__ = [
'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer', 'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer',
'event', 'data_type', 'attr', 'data_feeder', 'dataset' 'event', 'data_type', 'attr', 'data_feeder', 'dataset', 'reader'
] ]
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
__all__ = [ __all__ = [
'map_readers', 'buffered', 'compose', 'chain', 'shuffle', 'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
'ComposeNotAligned' 'ComposeNotAligned', 'batched'
] ]
from Queue import Queue from Queue import Queue
...@@ -191,3 +191,25 @@ def buffered(reader, size): ...@@ -191,3 +191,25 @@ def buffered(reader, size):
e = q.get() e = q.get()
return data_reader return data_reader
def batched(reader, batch_size):
    """
    Create a batched reader.

    Wraps ``reader`` so that consecutive instances are grouped into lists
    of ``batch_size`` items. A trailing partial batch (fewer than
    ``batch_size`` instances) is yielded as well, so no data is dropped.

    :param reader: the data reader to read from.
    :param batch_size: batch_size
    :return: the batched reader.
    """

    def __impl__():
        current = []
        for item in reader():
            current.append(item)
            if len(current) >= batch_size:
                yield current
                current = []
        # Flush the final, possibly short, batch.
        if len(current) > 0:
            yield current

    return __impl__
...@@ -29,7 +29,7 @@ class ITrainer(object): ...@@ -29,7 +29,7 @@ class ITrainer(object):
""" """
def train(self, def train(self,
train_data_reader, train_reader_creator,
topology, topology,
parameters, parameters,
test_data_reader=None, test_data_reader=None,
...@@ -37,7 +37,7 @@ class ITrainer(object): ...@@ -37,7 +37,7 @@ class ITrainer(object):
""" """
train method. train method.
:param train_data_reader: :param train_reader_creator:
:param topology: :param topology:
:param parameters: :param parameters:
:param test_data_reader: :param test_data_reader:
...@@ -62,27 +62,23 @@ class SGD(ITrainer): ...@@ -62,27 +62,23 @@ class SGD(ITrainer):
self.__optimizer__ = update_equation self.__optimizer__ = update_equation
def train(self, def train(self,
train_reader_creator, train_reader,
topology, topology,
parameters, parameters,
num_passes=1, num_passes=1,
test_data_reader=None,
event_handler=None, event_handler=None,
batch_size=32,
data_types=None, data_types=None,
reader_dict=None): reader_dict=None):
""" """
Training method. Will train num_passes of input data. Training method. Will train num_passes of input data.
:param train_reader_creator: :param train_reader:
:param topology: Network Topology, use one or more Layers to represent it. :param topology: Network Topology, use one or more Layers to represent it.
:param parameters: The parameter pools. :param parameters: The parameter pools.
:param num_passes: The total train passes. :param num_passes: The total train passes.
:param test_data_reader:
:param event_handler: Event handler. A method will be invoked when event :param event_handler: Event handler. A method will be invoked when event
occurred. occurred.
:type event_handler: (BaseEvent) => None :type event_handler: (BaseEvent) => None
:param batch_size: Not important, will be removed after data refactor.
:param data_types: Not important, will be removed after data refactor. :param data_types: Not important, will be removed after data refactor.
:return: :return:
""" """
...@@ -108,9 +104,7 @@ class SGD(ITrainer): ...@@ -108,9 +104,7 @@ class SGD(ITrainer):
for pass_id in xrange(num_passes): for pass_id in xrange(num_passes):
updater.startPass() updater.startPass()
for batch_id, data_batch in enumerate( for batch_id, data_batch in enumerate(train_reader()):
__data_reader_to_batch__(train_reader_creator, batch_size,
topology)):
pass_type = updater.startBatch(len(data_batch)) pass_type = updater.startBatch(len(data_batch))
gm.forwardBackward(feeder(data_batch), out_args, pass_type) gm.forwardBackward(feeder(data_batch), out_args, pass_type)
for each_param in gm.getParameters(): for each_param in gm.getParameters():
...@@ -128,51 +122,16 @@ class SGD(ITrainer): ...@@ -128,51 +122,16 @@ class SGD(ITrainer):
gm.finish() gm.finish()
def __data_reader_to_batch__(reader, batch_size, topology): def __check_train_args__(train_reader, topology, parameters, event_handler,
""" **kwargs):
This function is not important, and will be removed when data refactored.
"""
def input_reorder(func):
for item in func():
retv = []
for __layer_name__ in topology.input_layer_names:
retv.append(item[__layer_name__])
yield retv
return __generator_to_batch__(input_reorder(reader), batch_size=batch_size)
def __generator_to_batch__(generator, batch_size):
"""
This function is not important, and will be removed when data refactored.
"""
ret_val = list()
for each_item in generator:
ret_val.append(each_item)
if len(ret_val) == batch_size:
yield ret_val
ret_val = list()
if len(ret_val) != 0:
yield ret_val
def __check_train_args__(train_data_reader, topology, parameters,
test_data_reader, event_handler, **kwargs):
""" """
Check train function's argument types Check train function's argument types
""" """
if not callable(train_data_reader) or not isinstance(train_data_reader(), if not callable(train_reader) or not isinstance(train_reader(),
collections.Iterator): collections.Iterator):
raise TypeError('train_data_reader should be a function, ' raise TypeError('train_data_reader should be a function, '
'which can return a iterator') 'which can return a iterator')
if test_data_reader is not None:
if not callable(test_data_reader) or not isinstance(
test_data_reader(), collections.Iterator):
raise TypeError('test_data_reader should be a function, which can '
'return a iterator')
if not isinstance(topology, ModelConfig): if not isinstance(topology, ModelConfig):
raise TypeError('topology should be a model config') raise TypeError('topology should be a model config')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment