提交 dcc332a8 编写于 作者: Y Yu Yang

Follow qq's comments

上级 f3f24604
...@@ -37,7 +37,7 @@ def main(): ...@@ -37,7 +37,7 @@ def main():
trainer = paddle.trainer.SGD(update_equation=adam_optimizer) trainer = paddle.trainer.SGD(update_equation=adam_optimizer)
trainer.train( trainer.train(
train_reader=paddle.reader.batched( reader=paddle.reader.batched(
paddle.reader.shuffle(paddle.dataset.mnist.train_creator(), paddle.reader.shuffle(paddle.dataset.mnist.train_creator(),
buf_size=8192), batch_size=32), buf_size=8192), batch_size=32),
topology=cost, topology=cost,
......
...@@ -28,19 +28,13 @@ class ITrainer(object): ...@@ -28,19 +28,13 @@ class ITrainer(object):
The interface of Trainer. The only exposed method is `train`. The interface of Trainer. The only exposed method is `train`.
""" """
def train(self, def train(self, reader, topology, parameters, event_handler=None):
train_reader_creator,
topology,
parameters,
test_data_reader=None,
event_handler=None):
""" """
train method. train method.
:param train_reader_creator: :param reader:
:param topology: :param topology:
:param parameters: :param parameters:
:param test_data_reader:
:param event_handler: :param event_handler:
:return: :return:
""" """
...@@ -62,7 +56,7 @@ class SGD(ITrainer): ...@@ -62,7 +56,7 @@ class SGD(ITrainer):
self.__optimizer__ = update_equation self.__optimizer__ = update_equation
def train(self, def train(self,
train_reader, reader,
topology, topology,
parameters, parameters,
num_passes=1, num_passes=1,
...@@ -72,7 +66,7 @@ class SGD(ITrainer): ...@@ -72,7 +66,7 @@ class SGD(ITrainer):
""" """
Training method. Will train num_passes of input data. Training method. Will train num_passes of input data.
:param train_reader: :param reader:
:param topology: Network Topology, use one or more Layers to represent it. :param topology: Network Topology, use one or more Layers to represent it.
:param parameters: The parameter pools. :param parameters: The parameter pools.
:param num_passes: The total train passes. :param num_passes: The total train passes.
...@@ -104,7 +98,7 @@ class SGD(ITrainer): ...@@ -104,7 +98,7 @@ class SGD(ITrainer):
for pass_id in xrange(num_passes): for pass_id in xrange(num_passes):
updater.startPass() updater.startPass()
for batch_id, data_batch in enumerate(train_reader()): for batch_id, data_batch in enumerate(reader()):
pass_type = updater.startBatch(len(data_batch)) pass_type = updater.startBatch(len(data_batch))
gm.forwardBackward(feeder(data_batch), out_args, pass_type) gm.forwardBackward(feeder(data_batch), out_args, pass_type)
for each_param in gm.getParameters(): for each_param in gm.getParameters():
...@@ -122,13 +116,11 @@ class SGD(ITrainer): ...@@ -122,13 +116,11 @@ class SGD(ITrainer):
gm.finish() gm.finish()
def __check_train_args__(train_reader, topology, parameters, event_handler, def __check_train_args__(reader, topology, parameters, event_handler, **kwargs):
**kwargs):
""" """
Check train function's argument types Check train function's argument types
""" """
if not callable(train_reader) or not isinstance(train_reader(), if not callable(reader) or not isinstance(reader(), collections.Iterator):
collections.Iterator):
raise TypeError('train_data_reader should be a function, ' raise TypeError('train_data_reader should be a function, '
'which can return a iterator') 'which can return a iterator')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册