decorator.py 18.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

H
Helin Wang 已提交
15
__all__ = [
    'cache',
    'map_readers',
    'buffered',
    'compose',
    'chain',
    'shuffle',
    'ComposeNotAligned',
    'firstn',
    'xmap_readers',
    'PipeReader',
    'multiprocess_reader',
    'Fake',
]
20

T
tangwei12 已提交
21 22
from threading import Thread
import subprocess
Q
Qiao Longfei 已提交
23
import multiprocessing
24
import six
Q
Qiao Longfei 已提交
25
import sys
T
tangwei12 已提交
26

27
from six.moves.queue import Queue
28
from six.moves import zip_longest
29 30
from six.moves import map
from six.moves import zip
31 32
import itertools
import random
T
tangwei12 已提交
33
import zlib
M
minqiyang 已提交
34
import paddle.compat as cpt
35 36


S
sneaxiy 已提交
37 38 39 40 41 42 43 44 45 46 47 48 49
def cache(reader):
    """
    Cache the reader data into memory.

    Be careful that this method may take long time to process,
    and consume lots of memory. :code:`reader()` would only
    call once.

    Args:
        reader (generator): a reader object which yields
            data each time.

    Returns:
        generator: a decorated reader object which yields data from cached memory.
    """
    # Drain the original reader exactly once; the tuple holds every sample.
    cached_samples = tuple(reader())

    def cache_reader():
        # Replay the cached samples on every invocation.
        for sample in cached_samples:
            yield sample

    return cache_reader


H
Helin Wang 已提交
61 62 63
def map_readers(func, *readers):
    """
    Creates a data reader that outputs return value of function using
    output of each data reader as arguments.

    If input readers output the following data entries: 2 3,
    and the input func is mul(x, y),
    the output of the resulted reader will be 6.


    Args:
        func: a function to read data and compute result, the output of this function
              will be set as the output of the resulted data reader.
        readers (Reader|list of Reader): list of readers whose outputs will be used as arguments of func.

    Returns:
        the resulted data reader (Reader)

    Examples:

        .. code-block:: python

         import paddle.reader
         d = {"h": 0, "i": 1}
         def func(x):
             return d[x]
         def reader():
             yield "h"
             yield "i"
         map_reader_result = paddle.reader.map_readers(func, reader)
    """

    def mapped_reader():
        # Open one iterator per input reader, then feed them to func in
        # lockstep via map (stops at the shortest reader).
        iterators = [r() for r in readers]
        for result in map(func, *iterators):
            yield result

    return mapped_reader


H
Helin Wang 已提交
103
def shuffle(reader, buf_size):
    """
    Creates a data reader whose data output is shuffled.

    Output from the iterator that created by original reader will be
    buffered into shuffle buffer, and then shuffled. The size of shuffle buffer
    is determined by argument buf_size.

    :param reader: the original reader whose output will be shuffled.
    :type reader: callable
    :param buf_size: shuffle buffer size.
    :type buf_size: int

    :return: the new reader whose output is shuffled.
    :rtype: callable
    """

    def shuffled_reader():
        buffer = []
        for sample in reader():
            buffer.append(sample)
            # Once the buffer is full, shuffle it and flush every sample.
            if len(buffer) >= buf_size:
                random.shuffle(buffer)
                for item in buffer:
                    yield item
                buffer = []

        # Flush the (shuffled) remainder that never filled the buffer.
        if buffer:
            random.shuffle(buffer)
            for item in buffer:
                yield item

    return shuffled_reader
136 137


H
Helin Wang 已提交
138
def chain(*readers):
    """
    Use the input data readers to create a chained data reader. The new created reader
    chains the outputs of input readers together as its output.

    **Note**:
        ``paddle.reader.chain`` is the alias of ``paddle.fluid.io.chain``, and
        ``paddle.fluid.io.chain`` is recommended to use.

    For example, if three input readers' outputs are as follows:
    [0, 0, 0],
    [10, 10, 10],
    [20, 20, 20].
    The chained reader will output:
    [[0, 0, 0], [10, 10, 10], [20, 20, 20]].

    Args:
        readers(list): input data readers.

    Returns:
        callable: the new chained data reader.

    Examples:
        ..  code-block:: python

            import paddle

            def reader_creator_3(start):
                def reader():
                    for i in range(start, start + 3):
                        yield [i, i, i]
                return reader

            c = paddle.reader.chain(reader_creator_3(0), reader_creator_3(10), reader_creator_3(20))
            for e in c():
                print(e)
            # Output:
            # [0, 0, 0]
            # [1, 1, 1]
            # [2, 2, 2]
            # [10, 10, 10]
            # [11, 11, 11]
            # [12, 12, 12]
            # [20, 20, 20]
            # [21, 21, 21]
            # [22, 22, 22]

    """

    def chained_reader():
        # Open every reader up front, then drain them one after another.
        iterators = [r() for r in readers]
        for sample in itertools.chain(*iterators):
            yield sample

    return chained_reader
196 197


H
Helin Wang 已提交
198
class ComposeNotAligned(ValueError):
    """Raised by ``compose`` when alignment checking is enabled and the
    input readers yield different numbers of samples."""
    pass


H
Helin Wang 已提交
202
def compose(*readers, **kwargs):
    """
    Creates a data reader whose output is the combination of input readers.

    If input readers output following data entries:
    (1, 2)    3    (4, 5)
    The composed reader will output:
    (1, 2, 3, 4, 5)

    :param readers: readers that will be composed together.
    :param check_alignment: if True, will check if input readers are aligned
        correctly. If False, will not check alignment and trailing outputs
        will be discarded. Defaults to True.
    :type check_alignment: bool

    :return: the new data reader.

    :raises ComposeNotAligned: outputs of readers are not aligned.
        Will not raise when check_alignment is set to False.
    """
    check_alignment = kwargs.pop('check_alignment', True)

    def _as_tuple(value):
        # Normalize a sample so it can be concatenated with others.
        return value if isinstance(value, tuple) else (value, )

    def composed_reader():
        iterators = [r() for r in readers]
        if check_alignment:
            # zip_longest pads exhausted readers with None, which lets us
            # detect misalignment instead of silently truncating.
            for outputs in zip_longest(*iterators):
                if any(item is None for item in outputs):
                    # None will be not be present if compose is aligned
                    raise ComposeNotAligned(
                        "outputs of readers are not aligned.")
                yield sum((_as_tuple(item) for item in outputs), ())
        else:
            # Plain zip stops at the shortest reader, discarding the rest.
            for outputs in zip(*iterators):
                yield sum((_as_tuple(item) for item in outputs), ())

    return composed_reader
247 248


H
Helin Wang 已提交
249
def buffered(reader, size):
    """
    Creates a buffered data reader.

    The buffered data reader will read and save data entries into a
    buffer. Reading from the buffered data reader will proceed as long
    as the buffer is not empty.

    :param reader: the data reader to read from.
    :type reader: callable
    :param size: max buffer size.
    :type size: int

    :returns: the buffered data reader.
    """

    class EndSignal():
        # Sentinel type marking the end of the stream inside the queue.
        pass

    end = EndSignal()

    def read_worker(r, q):
        # Producer thread: push every sample, then the end sentinel.
        for d in r:
            q.put(d)
        q.put(end)

    def data_reader():
        r = reader()
        q = Queue(maxsize=size)
        t = Thread(target=read_worker, args=(r, q))
        t.daemon = True
        t.start()
        e = q.get()
        # BUGFIX: compare by identity, not equality. A sample whose
        # __eq__/__ne__ is permissive (e.g. returns True for any operand,
        # or a numpy array) must not be mistaken for the end sentinel,
        # which previously truncated or broke the stream.
        while e is not end:
            yield e
            e = q.get()

    return data_reader
Y
Yu Yang 已提交
290 291


Y
Yu Yang 已提交
292
def firstn(reader, n):
    """
    Limit the max number of samples that reader could return.

    :param reader: the data reader to read from.
    :type reader: callable
    :param n: the max number of samples that return.
    :type n: int
    :return: the decorated reader.
    :rtype: callable
    """

    # TODO(yuyang18): Check if just drop the reader, could clean the opened
    # resource or not?

    def limited_reader():
        for index, sample in enumerate(reader()):
            # Stop as soon as n samples have been produced.
            if index == n:
                break
            yield sample

    return limited_reader
314 315 316 317 318 319


class XmapEndSignal():
    """Sentinel type used by ``xmap_readers`` to signal end-of-stream
    between its reader, worker, and consumer threads."""
    pass


320
def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
    """
    Use multi-threads to map samples from reader by a mapper defined by user.

    Args:
        mapper (callable): a function to map the data from reader.
        reader (callable): a data reader which yields the data.
        process_num (int): thread number to handle original sample.
        buffer_size (int): size of the queue to read data in.
        order (bool): whether to keep the data order from original reader.
            Default False.

    Returns:
        callable: a decorated reader with data mapping.
    """
    # Shared end-of-stream sentinel; workers detect it via isinstance.
    end = XmapEndSignal()

    # define a worker to read samples from reader to in_queue
    def read_worker(reader, in_queue):
        for i in reader():
            in_queue.put(i)
        in_queue.put(end)

    # define a worker to read samples from reader to in_queue with order flag
    # (each sample is tagged with its sequence number so handlers can
    # reproduce the original order)
    def order_read_worker(reader, in_queue):
        in_order = 0
        for i in reader():
            in_queue.put((in_order, i))
            in_order += 1
        in_queue.put(end)

    # define a worker to handle samples from in_queue by mapper
    # and put mapped samples into out_queue
    def handle_worker(in_queue, out_queue, mapper):
        sample = in_queue.get()
        while not isinstance(sample, XmapEndSignal):
            r = mapper(sample)
            out_queue.put(r)
            sample = in_queue.get()
        # Re-enqueue the sentinel so sibling handler threads also stop,
        # and emit one sentinel downstream per handler thread.
        in_queue.put(end)
        out_queue.put(end)

    # define a worker to handle samples from in_queue by mapper
    # and put mapped samples into out_queue by order
    def order_handle_worker(in_queue, out_queue, mapper, out_order):
        ins = in_queue.get()
        while not isinstance(ins, XmapEndSignal):
            order, sample = ins
            r = mapper(sample)
            # NOTE(review): busy-wait spin until it is this sample's turn;
            # out_order is a single-element list shared by all handlers.
            while order != out_order[0]:
                pass
            out_queue.put(r)
            out_order[0] += 1
            ins = in_queue.get()
        in_queue.put(end)
        out_queue.put(end)

    def xreader():
        in_queue = Queue(buffer_size)
        out_queue = Queue(buffer_size)
        out_order = [0]
        # start a read worker in a thread
        target = order_read_worker if order else read_worker
        t = Thread(target=target, args=(reader, in_queue))
        t.daemon = True
        t.start()
        # start several handle_workers
        target = order_handle_worker if order else handle_worker
        args = (in_queue, out_queue, mapper, out_order) if order else (
            in_queue, out_queue, mapper)
        workers = []
        for i in range(process_num):
            worker = Thread(target=target, args=args)
            worker.daemon = True
            workers.append(worker)
        for w in workers:
            w.start()

        # Consume until every handler thread has delivered its sentinel:
        # the first sentinel ends the main loop, the remaining
        # process_num - 1 are drained below (samples may still be
        # interleaved with them).
        sample = out_queue.get()
        while not isinstance(sample, XmapEndSignal):
            yield sample
            sample = out_queue.get()
        finish = 1
        while finish < process_num:
            sample = out_queue.get()
            if isinstance(sample, XmapEndSignal):
                finish += 1
            else:
                yield sample

    return xreader
411 412


Q
Qiao Longfei 已提交
413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445
def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
    """
    multiprocess_reader use python multi process to read data from readers
    and then use multiprocess.Queue or multiprocess.Pipe to merge all
    data. The process number is equal to the number of input readers, each
    process call one reader.

    Multiprocess.Queue require the rw access right to /dev/shm, some
    platform does not support.

    you need to create multiple readers first, these readers should be independent
    to each other so that each process can work independently.

    An example:

    .. code-block:: python

        reader0 = reader(["file01", "file02"])
        reader1 = reader(["file11", "file12"])
        reader1 = reader(["file21", "file22"])
        reader = multiprocess_reader([reader0, reader1, reader2],
            queue_size=100, use_pipe=False)
    """

    # Prefer ujson for speed when available; fall back to stdlib json.
    try:
        import ujson as json
    except Exception as e:
        sys.stderr.write("import ujson error: " + str(e) + " use json\n")
        import json

    assert type(readers) is list and len(readers) > 0

    def _read_into_queue(reader, queue):
        # Child-process body: forward samples into the shared queue.
        # Sentinels: None means "this reader finished", "" means "this
        # reader raised" (hence samples themselves must not be None).
        try:
            for sample in reader():
                if sample is None:
                    raise ValueError("sample has None")
                queue.put(sample)
            queue.put(None)
        except:
            # Bare except is deliberate: signal the parent, then re-raise
            # the original exception (with traceback) in the child.
            queue.put("")
            six.reraise(*sys.exc_info())

    def queue_reader():
        queue = multiprocessing.Queue(queue_size)
        for reader in readers:
            p = multiprocessing.Process(
                target=_read_into_queue, args=(reader, queue))
            p.start()

        # Consume until every child has delivered its None sentinel.
        reader_num = len(readers)
        finish_num = 0
        while finish_num < reader_num:
            sample = queue.get()
            if sample is None:
                finish_num += 1
            elif sample == "":
                raise ValueError("multiprocess reader raises an exception")
            else:
                yield sample

    def _read_into_pipe(reader, conn):
        # Same protocol as _read_into_queue, but samples are JSON-encoded
        # over a Pipe connection (so samples must be JSON-serializable).
        try:
            for sample in reader():
                if sample is None:
                    raise ValueError("sample has None!")
                conn.send(json.dumps(sample))
            conn.send(json.dumps(None))
            conn.close()
        except:
            conn.send(json.dumps(""))
            conn.close()
            six.reraise(*sys.exc_info())

    def pipe_reader():
        conns = []
        for reader in readers:
            parent_conn, child_conn = multiprocessing.Pipe()
            conns.append(parent_conn)
            p = multiprocessing.Process(
                target=_read_into_pipe, args=(reader, child_conn))
            p.start()

        reader_num = len(readers)
        finish_num = 0
        conn_to_remove = []
        # Round-robin over live connections; finished/errored connections
        # are removed at the start of the next sweep (not mid-iteration).
        while finish_num < reader_num:
            for conn in conn_to_remove:
                conns.remove(conn)
            conn_to_remove = []
            for conn in conns:
                sample = json.loads(conn.recv())
                if sample is None:
                    finish_num += 1
                    conn.close()
                    conn_to_remove.append(conn)
                elif sample == "":
                    conn.close()
                    conn_to_remove.append(conn)
                    raise ValueError("multiprocess reader raises an exception")
                else:
                    yield sample

    if use_pipe:
        return pipe_reader
    else:
        return queue_reader


522 523 524 525 526 527
def _buf2lines(buf, line_break="\n"):
    # FIXME: line_break should be automatically configured.
    lines = buf.split(line_break)
    return lines[:-1], lines[-1]


T
typhoonzero 已提交
528
class PipeReader:
    """
        PipeReader reads data as a stream from a command, taking its
        stdout into a pipe buffer and redirecting it to the parser to
        parse, then yields data in your desired format.

        You can use a standard linux command or call another program
        to read data, from HDFS, Ceph, URL, AWS S3 etc:

        .. code-block:: python
           cmd = "hadoop fs -cat /path/to/some/file"
           cmd = "cat sample_file.tar.gz"
           cmd = "curl http://someurl"
           cmd = "python print_s3_bucket.py"

        An example:

        .. code-block:: python

           def example_reader():
               for f in myfiles:
                   pr = PipeReader("cat %s"%f)
                   for l in pr.get_line():
                       sample = l.split(" ")
                       yield sample
    """

    def __init__(self, command, bufsize=8192, file_type="plain"):
        # :param command: shell-less command line, split on single spaces.
        # :param bufsize: stdout pipe buffer size and per-read chunk size.
        # :param file_type: "plain" or "gzip" (gzip is decompressed on read).
        if not isinstance(command, str):
            raise TypeError("left_cmd must be a string")
        if file_type == "gzip":
            self.dec = zlib.decompressobj(
                32 + zlib.MAX_WBITS)  # offset 32 to skip the header
        self.file_type = file_type
        self.bufsize = bufsize
        # NOTE(review): command.split(" ") cannot handle quoted arguments
        # containing spaces; callers must pass simple commands.
        self.process = subprocess.Popen(
            command.split(" "), bufsize=bufsize, stdout=subprocess.PIPE)

    def get_line(self, cut_lines=True, line_break="\n"):
        """
        :param cut_lines: cut buffer to lines
        :type cut_lines: bool
        :param line_break: line break of the file, like '\\\\n' or '\\\\r'
        :type line_break: string

        :return: one line or a buffer of bytes
        :rtype: string
        """
        # Holds the trailing partial line between chunk reads.
        remained = ""
        while True:
            buff = self.process.stdout.read(self.bufsize)
            if buff:
                # Decode the raw chunk to text, decompressing first for gzip.
                if self.file_type == "gzip":
                    decomp_buff = cpt.to_text(self.dec.decompress(buff))
                elif self.file_type == "plain":
                    decomp_buff = cpt.to_text(buff)
                else:
                    raise TypeError("file_type %s is not allowed" %
                                    self.file_type)

                if cut_lines:
                    # Prepend the leftover from the previous chunk and yield
                    # only complete lines; keep the new leftover for later.
                    lines, remained = _buf2lines(''.join(
                        [remained, decomp_buff]), line_break)
                    for line in lines:
                        yield line
                else:
                    yield decomp_buff
            else:
                # EOF on the child's stdout ends the generator; note that any
                # final `remained` text without a trailing line break is
                # dropped in cut_lines mode.
                break
Q
qiaolongfei 已提交
597 598


599
class Fake(object):
    """
    fake reader will cache the first data it read and yield it out for data_num times.
    It is used to cache a data from real reader and use it for speed testing.

    :param reader: the origin reader
    :param data_num: times that this reader will yield data.

    :return: a fake reader.

    Examples:
        .. code-block:: python

            def reader():
                for i in range(10):
                    yield i

            fake_reader = Fake()(reader, 100)
    """

    def __init__(self):
        # data: the single cached sample; yield_num: copies produced so far.
        self.data = None
        self.yield_num = 0

    def __call__(self, reader, data_num):
        def fake_reader():
            # Pull exactly one sample from the real reader, only on the
            # very first iteration.
            if self.data is None:
                self.data = next(reader())
            # Replay the cached sample until data_num copies were produced,
            # then reset the counter so the reader can be iterated again.
            while self.yield_num < data_num:
                yield self.data
                self.yield_num += 1
            self.yield_num = 0

        return fake_reader