decorator.py 11.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

H
Helin Wang 已提交
15 16
# Names exported by ``from <this module> import *``.
__all__ = [
    'map_readers',
    'buffered',
    'compose',
    'chain',
    'shuffle',
    'ComposeNotAligned',
    'firstn',
    'xmap_readers',
    'PipeReader',
]
19

T
tangwei12 已提交
20 21 22
from threading import Thread
import subprocess

23
from six.moves.queue import Queue
24
from six.moves import zip_longest
25 26
from six.moves import map
from six.moves import zip
27 28
import itertools
import random
T
tangwei12 已提交
29
import zlib
M
minqiyang 已提交
30
import paddle.fluid.compat as cpt
31 32


H
Helin Wang 已提交
33 34 35 36 37
def map_readers(func, *readers):
    """
    Creates a data reader that yields the return value of func, called with
    one sample from each of the given readers as positional arguments.

    :param func: function to apply to the samples; it receives one sample
        per reader, in the order the readers were passed.
    :type func: callable
    :param readers: readers whose outputs will be used as arguments of func.
    :return: the created data reader.
    :rtype: callable
    """

    def reader():
        # Instantiate every underlying reader before mapping starts.
        streams = [make_stream() for make_stream in readers]
        for mapped in map(func, *streams):
            yield mapped

    return reader


H
Helin Wang 已提交
55
def shuffle(reader, buf_size):
    """
    Creates a data reader whose data output is shuffled.

    Samples produced by the original reader are collected into a buffer.
    Each time the buffer holds buf_size samples it is shuffled and emitted,
    so samples are only reordered within one buffer-sized window.

    :param reader: the original reader whose output will be shuffled.
    :type reader: callable
    :param buf_size: shuffle buffer size.
    :type buf_size: int

    :return: the new reader whose output is shuffled.
    :rtype: callable
    """

    def data_reader():
        pending = []
        for sample in reader():
            pending.append(sample)
            if len(pending) < buf_size:
                continue
            # The buffer is full: emit it in random order and start over.
            random.shuffle(pending)
            for shuffled in pending:
                yield shuffled
            pending = []

        # Emit whatever is left in a partially filled buffer.
        if pending:
            random.shuffle(pending)
            for shuffled in pending:
                yield shuffled

    return data_reader
88 89


H
Helin Wang 已提交
90
def chain(*readers):
    """
    Creates a data reader that outputs the samples of the input readers
    one reader after another.

    If input readers output following data entries:
    [0, 0, 0]
    [1, 1, 1]
    [2, 2, 2]
    The chained reader will output:
    [0, 0, 0, 1, 1, 1, 2, 2, 2]

    :param readers: input readers.
    :return: the new data reader.
    :rtype: callable
    """

    def reader():
        # All readers are instantiated first, then drained in order.
        streams = [make_stream() for make_stream in readers]
        for sample in itertools.chain.from_iterable(streams):
            yield sample

    return reader
116 117


H
Helin Wang 已提交
118
class ComposeNotAligned(ValueError):
    """Raised when the outputs of composed readers are not aligned."""


H
Helin Wang 已提交
122
def compose(*readers, **kwargs):
    """
    Creates a data reader whose output is the combination of input readers.

    If input readers output following data entries:
    (1, 2)    3    (4, 5)
    The composed reader will output:
    (1, 2, 3, 4, 5)

    :param readers: readers that will be composed together.
    :param check_alignment: if True, will check if input readers are aligned
        correctly. If False, will not check alignment and trailing outputs
        will be discarded. Defaults to True.
    :type check_alignment: bool

    :return: the new data reader.

    :raises ComposeNotAligned: outputs of readers are not aligned.
        Will not raise when check_alignment is set to False.
    """
    check_alignment = kwargs.pop('check_alignment', True)

    # Private fill value for exhausted readers. Using a dedicated object
    # instead of None lets readers legitimately yield None as a sample
    # without being reported as misaligned.
    missing = object()

    def make_tuple(x):
        # Tuples are spliced into the result; anything else is one field.
        return x if isinstance(x, tuple) else (x, )

    def flatten(outputs):
        # Concatenate the per-reader tuples into one flat tuple.
        return tuple(itertools.chain.from_iterable(map(make_tuple, outputs)))

    def reader():
        rs = [r() for r in readers]
        if not check_alignment:
            # zip stops at the shortest reader; trailing outputs are dropped.
            for outputs in zip(*rs):
                yield flatten(outputs)
            return

        for outputs in zip_longest(*rs, fillvalue=missing):
            # The fill value only shows up once some reader ran out of
            # samples before the others.
            if any(o is missing for o in outputs):
                raise ComposeNotAligned(
                    "outputs of readers are not aligned.")
            yield flatten(outputs)

    return reader
167 168


H
Helin Wang 已提交
169
def buffered(reader, size):
    """
    Creates a buffered data reader.

    The buffered data reader starts a background thread that reads data
    entries from the original reader into a bounded queue. Reading from
    the buffered data reader takes entries from that queue until the
    original reader is exhausted.

    If the original reader raises an exception, the exception is re-raised
    to the consumer of the buffered reader instead of leaving it blocked.

    :param reader: the data reader to read from.
    :type reader: callable
    :param size: max buffer size.
    :type size: int

    :returns: the buffered data reader.
    """

    # Unique end-of-stream marker. It is compared by identity so that
    # samples with custom equality operators can never be confused with it.
    end = object()

    class ReadFailure(object):
        """Carries an exception from the reading thread to the consumer."""

        def __init__(self, error):
            self.error = error

    def read_worker(r, q):
        try:
            for d in r:
                q.put(d)
        except Exception as error:
            # Hand the failure over; otherwise the consumer would wait
            # forever for an end marker that is never queued.
            q.put(ReadFailure(error))
        else:
            q.put(end)

    def data_reader():
        r = reader()
        q = Queue(maxsize=size)
        t = Thread(target=read_worker, args=(r, q))
        # Daemon thread: it must not keep the interpreter alive if the
        # consumer stops reading early.
        t.daemon = True
        t.start()
        while True:
            e = q.get()
            if e is end:
                break
            if isinstance(e, ReadFailure):
                raise e.error
            yield e

    return data_reader
Y
Yu Yang 已提交
210 211


Y
Yu Yang 已提交
212
def firstn(reader, n):
    """
    Limit the max number of samples that reader could return.

    :param reader: the data reader to read from.
    :type reader: callable
    :param n: the max number of samples to return.
    :type n: int
    :return: the decorated reader.
    :rtype: callable
    """

    # TODO(yuyang18): Check if just drop the reader, could clean the opened
    # resource or not?

    def firstn_reader():
        emitted = 0
        for item in reader():
            # Stop once n samples have been handed out.
            if emitted == n:
                return
            yield item
            emitted += 1

    return firstn_reader
234 235 236 237 238 239


class XmapEndSignal():
    """Type of the end-of-stream marker passed through xmap_readers queues."""


def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
    """
    Use multiple worker threads to map samples from reader by a mapper
    defined by user. Samples are passed through bounded queues, so this
    decorator also buffers the data.

    An exception raised by the reader or by the mapper is re-raised to the
    consumer of the returned reader.

    :param mapper:  a function to map sample.
    :type mapper: callable
    :param reader: the data reader to read from
    :type reader: callable
    :param process_num: number of worker threads that map samples
    :type process_num: int
    :param buffer_size: max buffer size
    :type buffer_size: int
    :param order: keep the order of reader
    :type order: bool
    :return: the decorated reader
    :rtype: callable
    :raises ValueError: if process_num is smaller than 1.
    """
    # Function-scope import: the module itself only imports Thread.
    from threading import Condition

    if process_num < 1:
        # Without any worker no end signal would ever be produced.
        raise ValueError("process_num must be at least 1")

    end = XmapEndSignal()

    class WorkerFailure(object):
        """Carries an exception from a background thread to the consumer."""

        def __init__(self, error):
            self.error = error

    # define a worker to read samples from reader to in_queue
    def read_worker(reader, in_queue, out_queue):
        try:
            for i in reader():
                in_queue.put(i)
        except Exception as error:
            # Report before the end signal so the consumer sees the failure
            # before the workers finish.
            out_queue.put(WorkerFailure(error))
        in_queue.put(end)

    # define a worker to read samples from reader to in_queue with order flag
    def order_read_worker(reader, in_queue, out_queue):
        try:
            for in_order, i in enumerate(reader()):
                in_queue.put((in_order, i))
        except Exception as error:
            out_queue.put(WorkerFailure(error))
        in_queue.put(end)

    # define a worker to handle samples from in_queue by mapper
    # and put mapped samples into out_queue
    def handle_worker(in_queue, out_queue, mapper):
        sample = in_queue.get()
        while not isinstance(sample, XmapEndSignal):
            try:
                r = mapper(sample)
            except Exception as error:
                out_queue.put(WorkerFailure(error))
                return
            out_queue.put(r)
            sample = in_queue.get()
        # Put the end signal back so the remaining workers see it as well.
        in_queue.put(end)
        out_queue.put(end)

    # define a worker to handle samples from in_queue by mapper
    # and put mapped samples into out_queue by order
    def order_handle_worker(in_queue, out_queue, mapper, out_order, turn):
        ins = in_queue.get()
        while not isinstance(ins, XmapEndSignal):
            order, sample = ins
            try:
                r = mapper(sample)
            except Exception as error:
                out_queue.put(WorkerFailure(error))
                return
            with turn:
                # Sleep until every earlier sample has been emitted instead
                # of spinning on the shared counter.
                while order != out_order[0]:
                    turn.wait()
                out_queue.put(r)
                out_order[0] += 1
                turn.notify_all()
            ins = in_queue.get()
        in_queue.put(end)
        out_queue.put(end)

    def xreader():
        in_queue = Queue(buffer_size)
        out_queue = Queue(buffer_size)
        # Index of the next sample allowed into out_queue (ordered mode).
        out_order = [0]
        turn = Condition()
        # start a read worker in a thread
        target = order_read_worker if order else read_worker
        t = Thread(target=target, args=(reader, in_queue, out_queue))
        t.daemon = True
        t.start()
        # start several handle_workers
        if order:
            target = order_handle_worker
            args = (in_queue, out_queue, mapper, out_order, turn)
        else:
            target = handle_worker
            args = (in_queue, out_queue, mapper)
        workers = []
        for _ in range(process_num):
            worker = Thread(target=target, args=args)
            worker.daemon = True
            workers.append(worker)
        for w in workers:
            w.start()

        # Every worker emits exactly one end signal; the stream is complete
        # once all of them have been seen.
        finished = 0
        while finished < process_num:
            sample = out_queue.get()
            if isinstance(sample, XmapEndSignal):
                finished += 1
            elif isinstance(sample, WorkerFailure):
                raise sample.error
            else:
                yield sample

    return xreader
333 334 335 336 337 338 339 340


def _buf2lines(buf, line_break="\n"):
    # FIXME: line_break should be automatically configured.
    lines = buf.split(line_break)
    return lines[:-1], lines[-1]


T
typhoonzero 已提交
341
class PipeReader:
    """
        PipeReader reads data as a stream from a command: it takes the
        command's stdout into a pipe buffer and redirects it to the parser
        to parse, then yields data in your desired format.

        You can use a standard linux command or call another program
        to read data, from HDFS, Ceph, URL, AWS S3 etc:

        .. code-block:: python
           cmd = "hadoop fs -cat /path/to/some/file"
           cmd = "cat sample_file.tar.gz"
           cmd = "curl http://someurl"
           cmd = "python print_s3_bucket.py"

        An example:

        .. code-block:: python

           def example_reader():
               for f in myfiles:
                   pr = PipeReader("cat %s"%f)
                   for l in pr.get_line():
                       sample = l.split(" ")
                       yield sample
    """

    def __init__(self, command, bufsize=8192, file_type="plain"):
        """
        Start the command and attach a pipe to its stdout.

        :param command: command line to run; it is split on single spaces
            and executed without a shell.
        :type command: str
        :param bufsize: pipe buffer size, also the chunk size of each read.
        :type bufsize: int
        :param file_type: "plain" for text output or "gzip" for compressed
            output that is decompressed while reading.
        :type file_type: str
        :raises TypeError: if command is not a string.
        """
        if not isinstance(command, str):
            raise TypeError("command must be a string")
        if file_type == "gzip":
            # wbits of 32 + MAX_WBITS lets zlib detect a zlib or gzip
            # header automatically.
            self.dec = zlib.decompressobj(32 + zlib.MAX_WBITS)
        self.file_type = file_type
        self.bufsize = bufsize
        # NOTE: the command is split on single spaces only, so arguments
        # that contain spaces cannot be expressed.
        self.process = subprocess.Popen(
            command.split(" "), bufsize=bufsize, stdout=subprocess.PIPE)

    def get_line(self, cut_lines=True, line_break="\n"):
        """
        Read the command's output until it is exhausted.

        :param cut_lines: cut buffer to lines
        :type cut_lines: bool
        :param line_break: line break of the file, like \\n or \\r
        :type line_break: string

        :return: one line or a buffer of bytes
        :rtype: string
        :raises TypeError: if file_type is neither "plain" nor "gzip".
        """
        remained = ""
        while True:
            buff = self.process.stdout.read(self.bufsize)
            if not buff:
                break

            # NOTE(review): each chunk is converted on its own; presumably
            # a multi-byte character split across two chunks is not
            # handled - verify against cpt.to_literal_str.
            if self.file_type == "gzip":
                decomp_buff = cpt.to_literal_str(self.dec.decompress(buff))
            elif self.file_type == "plain":
                decomp_buff = cpt.to_literal_str(buff)
            else:
                raise TypeError("file_type %s is not allowed" %
                                self.file_type)

            if cut_lines:
                # Text after the last line break is carried over and
                # prepended to the next chunk.
                lines, remained = _buf2lines(''.join(
                    [remained, decomp_buff]), line_break)
                for line in lines:
                    yield line
            else:
                yield decomp_buff

        # stdout reached EOF: reap the child so it does not stay around as
        # a zombie process.
        self.process.wait()

        # The output did not end with a line break: emit the final,
        # unterminated line instead of dropping it.
        if remained:
            yield remained