# batch.py
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Public API of this module: only the ``batch`` reader decorator is exported.
__all__ = ['batch']


def batch(reader, batch_size, drop_last=False):
    """
    Create a batched reader that groups the items produced by ``reader``
    into lists of ``batch_size`` items.

    Args:
        reader(callable): a zero-argument callable (e.g. a generator
            function) whose return value is iterated to obtain the data.
        batch_size(int): size of each mini-batch. It is converted with
            ``int()`` and must be positive.
        drop_last(bool, optional): If true, the last batch is dropped when
            it holds fewer than ``batch_size`` items. If false, it is
            yielded as a shorter batch. Default: False.

    Returns:
        A zero-argument callable. Each call returns a generator that yields
        lists of up to ``batch_size`` items, in the order ``reader`` produced
        them.

    Raises:
        ValueError: if ``int(batch_size)`` is not positive. ``int()`` itself
            may also raise for non-numeric input.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            def reader():
                for i in range(10):
                    yield i
            batch_reader = fluid.io.batch(reader, batch_size=2)

            for data in batch_reader():
                print(data)

            # Output is
            # [0, 1]
            # [2, 3]
            # [4, 5]
            # [6, 7]
            # [8, 9]
    """
    # Validate eagerly so a bad batch_size fails when the batched reader is
    # built, not later when it is first iterated. The converted value is
    # what ``batch_reader`` reads through its closure.
    batch_size = int(batch_size)
    if batch_size <= 0:
        raise ValueError("batch_size should be a positive integral value, "
                         "but got batch_size={}".format(batch_size))

    def batch_reader():
        b = []
        for instance in reader():
            b.append(instance)
            if len(b) == batch_size:
                yield b
                # Start a fresh list so already-yielded batches are not mutated.
                b = []
        # Emit the trailing partial batch unless the caller asked to drop it.
        # Truthiness (not ``== False``) makes falsy values such as None behave
        # like False instead of silently dropping the last batch.
        if not drop_last and b:
            yield b

    return batch_reader