提交 aa2bcf51 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #1537 from helinwang/batch

move paddle.reader.batch to paddle.batch
......@@ -66,7 +66,7 @@ def main():
sys.stdout.flush()
if isinstance(event, paddle.event.EndPass):
result = trainer.test(
reader=paddle.reader.batched(
reader=paddle.batch(
paddle.dataset.cifar.test10(), batch_size=128),
reader_dict={'image': 0,
'label': 1})
......@@ -77,7 +77,7 @@ def main():
parameters=parameters,
update_equation=momentum_optimizer)
trainer.train(
reader=paddle.reader.batched(
reader=paddle.batch(
paddle.reader.shuffle(
paddle.dataset.cifar.train10(), buf_size=50000),
batch_size=128),
......
......@@ -98,7 +98,7 @@ def main():
result.metrics['classification_error_evaluator']))
trainer.train(
reader=paddle.reader.batched(
reader=paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=8192),
batch_size=128),
......@@ -115,7 +115,7 @@ def main():
probs = paddle.infer(
output=predict,
parameters=parameters,
reader=paddle.reader.batched(
reader=paddle.batch(
paddle.reader.firstn(
paddle.reader.map_readers(lambda item: (item[0], ),
paddle.dataset.mnist.test()),
......
......@@ -28,6 +28,7 @@ import pooling
import inference
import networks
import py_paddle.swig_paddle as api
import minibatch
__all__ = [
'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer',
......@@ -45,3 +46,4 @@ def init(**kwargs):
infer = inference.infer
batch = minibatch.batch
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def batch(reader, batch_size):
"""
Create a batch reader.
:param reader: the data reader to read from.
:param batch_size: batch_size
:return: the batch reader.
"""
def batch_reader():
r = reader()
batch = []
for instance in r:
batch.append(instance)
if len(batch) == batch_size:
yield batch
batch = []
if batch:
yield batch
return batch_reader
......@@ -14,7 +14,7 @@
__all__ = [
'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
'ComposeNotAligned', 'batched', 'firstn'
'ComposeNotAligned', 'firstn'
]
import itertools
......@@ -193,28 +193,6 @@ def buffered(reader, size):
return data_reader
def batched(reader, batch_size):
"""
Create a batched reader.
:param reader: the data reader to read from.
:param batch_size: batch_size
:return: the batched reader.
"""
def batched_reader():
r = reader()
batch = []
for instance in r:
batch.append(instance)
if len(batch) == batch_size:
yield batch
batch = []
if batch:
yield batch
return batched_reader
def firstn(reader, n):
"""
Limit the max number of samples that reader could return.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册