From dcc332a89001a01fb982938cc6c408b50ebd895a Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 27 Feb 2017 15:12:53 +0800 Subject: [PATCH] Follow qq's comments --- demo/mnist/api_train_v2.py | 2 +- python/paddle/v2/trainer.py | 22 +++++++--------------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index d5eeb053188..4d95780f06c 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -37,7 +37,7 @@ def main(): trainer = paddle.trainer.SGD(update_equation=adam_optimizer) trainer.train( - train_reader=paddle.reader.batched( + reader=paddle.reader.batched( paddle.reader.shuffle(paddle.dataset.mnist.train_creator(), buf_size=8192), batch_size=32), topology=cost, diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 9b45ee69f39..cbeef0306b7 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -28,19 +28,13 @@ class ITrainer(object): The interface of Trainer. The only exposed method is `train`. """ - def train(self, - train_reader_creator, - topology, - parameters, - test_data_reader=None, - event_handler=None): + def train(self, reader, topology, parameters, event_handler=None): """ train method. - :param train_reader_creator: + :param reader: :param topology: :param parameters: - :param test_data_reader: :param event_handler: :return: """ @@ -62,7 +56,7 @@ class SGD(ITrainer): self.__optimizer__ = update_equation def train(self, - train_reader, + reader, topology, parameters, num_passes=1, @@ -72,7 +66,7 @@ class SGD(ITrainer): """ Training method. Will train num_passes of input data. - :param train_reader: + :param reader: :param topology: Network Topology, use one or more Layers to represent it. :param parameters: The parameter pools. :param num_passes: The total train passes. @@ -104,7 +98,7 @@ class SGD(ITrainer): for pass_id in xrange(num_passes): updater.startPass() - for batch_id, data_batch in enumerate(train_reader()): + for batch_id, data_batch in enumerate(reader()): pass_type = updater.startBatch(len(data_batch)) gm.forwardBackward(feeder(data_batch), out_args, pass_type) for each_param in gm.getParameters(): @@ -122,13 +116,11 @@ class SGD(ITrainer): gm.finish() -def __check_train_args__(train_reader, topology, parameters, event_handler, - **kwargs): +def __check_train_args__(reader, topology, parameters, event_handler, **kwargs): """ Check train function's argument types """ - if not callable(train_reader) or not isinstance(train_reader(), - collections.Iterator): + if not callable(reader) or not isinstance(reader(), collections.Iterator): raise TypeError('train_data_reader should be a function, ' 'which can return a iterator') -- GitLab