decorator.py 4.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
# Public API of this module: the reader decorators and the error type
# raised by compose() on misaligned readers.
__all__ = ['buffered', 'compose', 'chain', 'shuffle', 'ComposeNotAligned']
16 17 18

from Queue import Queue
from threading import Thread
19 20
import itertools
import random
21 22


H
Helin Wang 已提交
23
def shuffle(reader, buf_size):
    """
    Creates a data reader whose data output is shuffled.

    Output from the iterator created by the original reader will be
    buffered into a shuffle buffer, and then shuffled. The size of the
    shuffle buffer is determined by argument buf_size. Entries are only
    reordered within their own chunk of buf_size consecutive entries; the
    final, possibly partial, chunk is shuffled as well.

    :param reader: the original reader whose output will be shuffled.
    :param buf_size: shuffle buffer size.

    :returns: the new reader whose output is shuffled.
    """

    def data_reader():
        buf = []
        for e in reader():
            buf.append(e)
            if len(buf) >= buf_size:
                random.shuffle(buf)
                for b in buf:
                    yield b
                buf = []

        # Flush the trailing entries that did not fill a whole buffer.
        if buf:
            random.shuffle(buf)
            for b in buf:
                yield b

    return data_reader
53 54


H
Helin Wang 已提交
55
def chain(*readers):
    """
    Creates a data reader that yields every entry of the first input
    reader, then every entry of the second, and so on.

    For example, if the input readers yield 0, 0, 0 and 1, 1, 1 and
    2, 2, 2 respectively, the chained reader yields:
    0, 0, 0, 1, 1, 1, 2, 2, 2

    Entries are passed through unchanged (a list entry stays one list
    entry). All input readers are called as soon as iteration of the
    chained reader starts, before the first entry is produced.

    :param readers: input readers.
    :returns: the new data reader.
    """

    def chained_reader():
        # Invoke every reader up front, then drain the iterators in order.
        iterators = [make_iter() for make_iter in readers]
        for source in iterators:
            for entry in source:
                yield entry

    return chained_reader
80 81


H
Helin Wang 已提交
82
class ComposeNotAligned(ValueError):
    """Raised by compose() when alignment checking is enabled and its input
    readers do not produce the same number of entries."""


H
Helin Wang 已提交
86
def compose(*readers, **kwargs):
    """
    Creates a data reader whose output is the combination of input readers.

    If input readers output following data entries:
    (1, 2)    3    (4, 5)
    The composed reader will output:
    (1, 2, 3, 4, 5)

    Tuple entries are spliced into the result; any other entry (including
    None) contributes a one-element tuple.

    :param readers: readers that will be composed together.
    :param check_alignment: (keyword) if True, will check if input readers
        are aligned correctly. If False, will not check alignment and
        trailing outputs will be discarded. Defaults to True.

    :returns: the new data reader.

    :raises ComposeNotAligned: outputs of readers are not aligned.
        Will not raise when check_alignment is set to False.
    """
    check_alignment = kwargs.pop('check_alignment', True)

    # Prefer the lazy Python 2 iterators; fall back to the Python 3 names
    # (where the builtin zip is already lazy).
    zip_shortest = getattr(itertools, 'izip', zip)
    zip_longest = getattr(itertools, 'izip_longest',
                          getattr(itertools, 'zip_longest', None))

    # Private fill value for zip_longest. Unlike None, no reader can ever
    # yield it, so a reader that legitimately outputs None is not mistaken
    # for an exhausted one.
    missing = object()

    def make_tuple(x):
        if isinstance(x, tuple):
            return x
        else:
            return (x, )

    def combine(outputs):
        # Linear-time concatenation of the per-reader tuples.
        return tuple(itertools.chain.from_iterable(map(make_tuple, outputs)))

    def reader():
        rs = [r() for r in readers]
        if not check_alignment:
            for outputs in zip_shortest(*rs):
                yield combine(outputs)
        else:
            for outputs in zip_longest(*rs, fillvalue=missing):
                # The sentinel only appears once some reader is exhausted
                # while another still has entries.
                if any(o is missing for o in outputs):
                    raise ComposeNotAligned(
                        "outputs of readers are not aligned.")
                yield combine(outputs)

    return reader
130 131


H
Helin Wang 已提交
132
def buffered(reader, size):
    """
    Creates a buffered data reader.

    The buffered data reader will read and save data entries into a
    buffer. Reading from the buffered data reader will proceed as long
    as the buffer is not empty.

    :param reader: the data reader to read from.
    :param size: max buffer size. Passed to ``Queue(maxsize=...)``, so a
        value <= 0 means the buffer is unbounded.

    :returns: the buffered data reader.

    :raises: an exception raised by the wrapped reader's iterator is
        re-raised by the buffered reader once it reaches that point in
        the stream.
    """

    # Unique end-of-stream marker, compared by identity so that data
    # entries with unusual __eq__/__ne__ (e.g. arrays) are never compared
    # against it.
    end = object()

    class _ReaderFailure(object):
        """Carries an exception from the worker thread to the consumer."""

        def __init__(self, error):
            self.error = error

    def read_worker(r, q):
        try:
            for d in r:
                q.put(d)
        except Exception as err:
            # Forward the failure; otherwise `end` is never enqueued and
            # the consumer would block on q.get() forever.
            q.put(_ReaderFailure(err))
        else:
            q.put(end)

    def data_reader():
        r = reader()
        q = Queue(maxsize=size)
        t = Thread(target=read_worker, args=(r, q))
        # Daemon thread: it does not keep the process alive at exit.
        # NOTE(review): if the consumer stops early, the worker may stay
        # blocked on a full queue until process exit.
        t.daemon = True
        t.start()
        while True:
            e = q.get()
            if e is end:
                return
            if isinstance(e, _ReaderFailure):
                raise e.error
            yield e

    return data_reader