From d498174c9aa6439207c9789a140f8d610352be74 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Fri, 3 Mar 2017 00:18:00 +0000 Subject: [PATCH] move paddle.reader.batch to paddle.batch --- demo/image_classification/api_v2_train.py | 4 +-- demo/mnist/api_train_v2.py | 4 +-- python/paddle/v2/__init__.py | 2 ++ python/paddle/v2/batch.py | 35 +++++++++++++++++++++++ python/paddle/v2/reader/decorator.py | 24 +--------------- 5 files changed, 42 insertions(+), 27 deletions(-) create mode 100644 python/paddle/v2/batch.py diff --git a/demo/image_classification/api_v2_train.py b/demo/image_classification/api_v2_train.py index 585f61c6fa..e0fc0e04bb 100644 --- a/demo/image_classification/api_v2_train.py +++ b/demo/image_classification/api_v2_train.py @@ -66,7 +66,7 @@ def main(): sys.stdout.flush() if isinstance(event, paddle.event.EndPass): result = trainer.test( - reader=paddle.reader.batched( + reader=paddle.batch( paddle.dataset.cifar.test10(), batch_size=128), reader_dict={'image': 0, 'label': 1}) @@ -77,7 +77,7 @@ def main(): parameters=parameters, update_equation=momentum_optimizer) trainer.train( - reader=paddle.reader.batched( + reader=paddle.batch( paddle.reader.shuffle( paddle.dataset.cifar.train10(), buf_size=50000), batch_size=128), diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index 9b7ebde500..4fb1808ca1 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -98,7 +98,7 @@ def main(): result.metrics['classification_error_evaluator'])) trainer.train( - reader=paddle.reader.batched( + reader=paddle.batch( paddle.reader.shuffle( paddle.dataset.mnist.train(), buf_size=8192), batch_size=128), @@ -115,7 +115,7 @@ def main(): probs = paddle.infer( output=predict, parameters=parameters, - reader=paddle.reader.batched( + reader=paddle.batch( paddle.reader.firstn( paddle.reader.map_readers(lambda item: (item[0], ), paddle.dataset.mnist.test()), diff --git a/python/paddle/v2/__init__.py b/python/paddle/v2/__init__.py index 8ab8cd2f85..a023e3ea06 100644 --- a/python/paddle/v2/__init__.py +++ b/python/paddle/v2/__init__.py @@ -28,6 +28,7 @@ import pooling import inferencer import networks import py_paddle.swig_paddle as api +import batch __all__ = [ 'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer', @@ -45,3 +46,4 @@ def init(**kwargs): infer = inferencer.infer +batch = batch.batch diff --git a/python/paddle/v2/batch.py b/python/paddle/v2/batch.py new file mode 100644 index 0000000000..f01815a0ce --- /dev/null +++ b/python/paddle/v2/batch.py @@ -0,0 +1,35 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def batch(reader, batch_size): + """ + Create a batch reader. + :param reader: the data reader to read from. + :param batch_size: batch_size + :return: the batch reader. + """ + + def batch_reader(): + r = reader() + batch = [] + for instance in r: + batch.append(instance) + if len(batch) == batch_size: + yield batch + batch = [] + if batch: + yield batch + + return batch_reader diff --git a/python/paddle/v2/reader/decorator.py b/python/paddle/v2/reader/decorator.py index b7657e2776..c4ba110205 100644 --- a/python/paddle/v2/reader/decorator.py +++ b/python/paddle/v2/reader/decorator.py @@ -14,7 +14,7 @@ __all__ = [ 'map_readers', 'buffered', 'compose', 'chain', 'shuffle', - 'ComposeNotAligned', 'batched', 'firstn' + 'ComposeNotAligned', 'firstn' ] import itertools @@ -193,28 +193,6 @@ def buffered(reader, size): return data_reader -def batched(reader, batch_size): - """ - Create a batched reader. - :param reader: the data reader to read from. - :param batch_size: batch_size - :return: the batched reader. - """ - - def batched_reader(): - r = reader() - batch = [] - for instance in r: - batch.append(instance) - if len(batch) == batch_size: - yield batch - batch = [] - if batch: - yield batch - - return batched_reader - - def firstn(reader, n): """ Limit the max number of samples that reader could return. -- GitLab