decorator.py 11.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

H
Helin Wang 已提交
15 16
# Public reader decorators and helpers exported by this module.
__all__ = [
    'map_readers',
    'buffered',
    'compose',
    'chain',
    'shuffle',
    'ComposeNotAligned',
    'firstn',
    'xmap_readers',
    'PipeReader',
]
19

T
tangwei12 已提交
20 21 22
from threading import Thread
import subprocess

23
from six.moves.queue import Queue
24
from six.moves import zip_longest
25 26
import itertools
import random
T
tangwei12 已提交
27
import zlib
28 29


H
Helin Wang 已提交
30 31 32 33 34
def map_readers(func, *readers):
    """
    Creates a data reader that outputs the return values of ``func``,
    called with one output from each of the given readers as arguments.

    Every reader in ``readers`` is invoked once per pass to obtain its
    iterable; ``func`` is then mapped over those iterables in lockstep.

    :param func: function to use. The type of func should be (Sample) => Sample
    :type func: callable
    :param readers: readers whose outputs will be used as arguments of func.
    :return: the created data reader.
    :rtype: callable
    """

    def reader():
        iterators = [make_iter() for make_iter in readers]
        for mapped in map(func, *iterators):
            yield mapped

    return reader


H
Helin Wang 已提交
52
def shuffle(reader, buf_size):
    """
    Creates a data reader whose data output is shuffled.

    Samples produced by the original reader are collected into a buffer.
    Each time the buffer holds ``buf_size`` samples it is shuffled and
    emitted; whatever remains when the source is exhausted is shuffled
    and emitted as a final, smaller batch.

    :param reader: the original reader whose output will be shuffled.
    :type reader: callable
    :param buf_size: shuffle buffer size.
    :type buf_size: int

    :return: the new reader whose output is shuffled.
    :rtype: callable
    """

    def emit(window):
        # Shuffle one full (or final partial) window in place and yield it.
        random.shuffle(window)
        for item in window:
            yield item

    def data_reader():
        pending = []
        for sample in reader():
            pending.append(sample)
            if len(pending) >= buf_size:
                for item in emit(pending):
                    yield item
                pending = []

        if pending:
            for item in emit(pending):
                yield item

    return data_reader
85 86


H
Helin Wang 已提交
87
def chain(*readers):
    """
    Creates a data reader whose output is the outputs of the input data
    readers chained together, one after another.

    If the input readers output the following data entries:
    [0, 0, 0]
    [1, 1, 1]
    [2, 2, 2]
    the chained reader will output:
    [0, 0, 0, 1, 1, 1, 2, 2, 2]

    :param readers: input readers.
    :return: the new data reader.
    :rtype: callable
    """

    def reader():
        # Every source is instantiated before the first item is yielded.
        sources = [make_iter() for make_iter in readers]
        for item in itertools.chain.from_iterable(sources):
            yield item

    return reader
113 114


H
Helin Wang 已提交
115
class ComposeNotAligned(ValueError):
    """Raised by ``compose`` when alignment checking is enabled and the
    composed readers do not produce the same number of outputs."""


H
Helin Wang 已提交
119
def compose(*readers, **kwargs):
    """
    Creates a data reader whose output is the combination of input readers.

    If input readers output following data entries:
    (1, 2)    3    (4, 5)
    The composed reader will output:
    (1, 2, 3, 4, 5)

    Outputs that are not tuples are wrapped in a 1-tuple before being
    concatenated, so every composed item is a flat tuple.

    :param readers: readers that will be composed together.
    :param check_alignment: if True, will check if input readers are aligned
        correctly. If False, will not check alignment and trailing outputs
        will be discarded. Defaults to True.
    :type check_alignment: bool

    :return: the new data reader.

    :raises ComposeNotAligned: outputs of readers are not aligned.
        Will not raise when check_alignment is set to False.
    """
    check_alignment = kwargs.pop('check_alignment', True)
    # Private fill value for zip_longest. Unlike None, it can never be a
    # real sample, so readers that legitimately yield None are not
    # misreported as misaligned.
    missing = object()

    def make_tuple(x):
        if isinstance(x, tuple):
            return x
        else:
            return (x, )

    def flatten(outputs):
        # Linear-time concatenation; sum(..., ()) re-copies on every step.
        return tuple(itertools.chain.from_iterable(map(make_tuple, outputs)))

    def reader():
        rs = [r() for r in readers]
        if not check_alignment:
            for outputs in zip(*rs):
                yield flatten(outputs)
        else:
            for outputs in zip_longest(*rs, fillvalue=missing):
                if any(o is missing for o in outputs):
                    # Only a shorter reader can produce the fill value.
                    raise ComposeNotAligned(
                        "outputs of readers are not aligned.")
                yield flatten(outputs)

    return reader
164 165


H
Helin Wang 已提交
166
def buffered(reader, size):
    """
    Creates a buffered data reader.

    The buffered data reader will read and save data entries into a
    buffer. Reading from the buffered data reader will proceed as long
    as the buffer is not empty.

    A daemon thread iterates ``reader()`` and puts entries into a queue
    of at most ``size`` items; the returned generator takes them out. If
    iterating the source raises, the exception is re-raised in the
    consumer rather than leaving it blocked on the queue forever.

    :param reader: the data reader to read from.
    :type reader: callable
    :param size: max buffer size.
    :type size: int

    :returns: the buffered data reader.
    """

    class EndSignal():
        pass

    end = EndSignal()

    class WorkerError(object):
        # Carries an exception raised while reading over to the consumer.
        def __init__(self, exc):
            self.exc = exc

    def read_worker(r, q):
        try:
            for d in r:
                q.put(d)
        except Exception as exc:  # forwarded to the consumer, not swallowed
            q.put(WorkerError(exc))
            return
        q.put(end)

    def data_reader():
        r = reader()
        q = Queue(maxsize=size)
        t = Thread(target=read_worker, args=(r, q))
        t.daemon = True
        t.start()
        while True:
            e = q.get()
            # Identity checks only: `e != end` would call the sample's own
            # comparison operators (e.g. elementwise for numpy arrays).
            if e is end:
                break
            if isinstance(e, WorkerError):
                raise e.exc
            yield e

    return data_reader
Y
Yu Yang 已提交
207 208


Y
Yu Yang 已提交
209
def firstn(reader, n):
    """
    Limit the max number of samples that reader could return.

    The source is advanced only as far as needed. After the n-th sample the
    decorated reader stops without pulling another item from ``reader()``;
    with ``n == 0`` the source is not read at all. A negative ``n`` never
    matches the counter, so every sample is returned.

    :param reader: the data reader to read from.
    :type reader: callable
    :param n: the max number of samples that return.
    :type n: int
    :return: the decorated reader.
    :rtype: callable
    """

    # TODO(yuyang18): Check if just drop the reader, could clean the opened
    # resource or not?

    def firstn_reader():
        if n == 0:
            return
        for count, item in enumerate(reader(), 1):
            yield item
            # Stop right after the n-th item. Checking only when the next
            # item arrives would block on, and then discard, an extra sample.
            if count == n:
                break

    return firstn_reader
231 232 233 234 235 236


class XmapEndSignal():
    """Sentinel type marking end-of-stream in the internal queues used by
    ``xmap_readers``; workers detect it with ``isinstance``."""


237
def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
    """
    Use multiple threads to map samples from reader by a mapper defined by
    user. And this function contains a buffered decorator.

    :param mapper:  a function to map sample.
    :type mapper: callable
    :param reader: the data reader to read from
    :type reader: callable
    :param process_num: number of worker threads that apply mapper
    :type process_num: int
    :param buffer_size: max size of the internal input and output queues
    :type buffer_size: int
    :param order: keep the order of reader
    :type order: bool
    :return: the decorated reader
    :rtype: callable
    """
    # Local import: the file-level imports only bring in Thread.
    from threading import Condition

    end = XmapEndSignal()

    # define a worker to read samples from reader to in_queue
    def read_worker(reader, in_queue):
        for i in reader():
            in_queue.put(i)
        in_queue.put(end)

    # define a worker to read samples from reader to in_queue, tagging each
    # sample with its position so the output can be put back in order
    def order_read_worker(reader, in_queue):
        for in_order, i in enumerate(reader()):
            in_queue.put((in_order, i))
        in_queue.put(end)

    # define a worker to handle samples from in_queue by mapper
    # and put mapped samples into out_queue
    def handle_worker(in_queue, out_queue, mapper):
        sample = in_queue.get()
        while not isinstance(sample, XmapEndSignal):
            r = mapper(sample)
            out_queue.put(r)
            sample = in_queue.get()
        # re-queue the end signal so the other workers stop too
        in_queue.put(end)
        out_queue.put(end)

    # define a worker to handle samples from in_queue by mapper
    # and put mapped samples into out_queue by order
    def order_handle_worker(in_queue, out_queue, mapper, out_order, turn):
        ins = in_queue.get()
        while not isinstance(ins, XmapEndSignal):
            seq, sample = ins
            r = mapper(sample)
            with turn:
                # Sleep on the condition (instead of spinning) until every
                # earlier sample has been emitted.
                while seq != out_order[0]:
                    turn.wait()
                out_queue.put(r)
                out_order[0] += 1
                turn.notify_all()
            ins = in_queue.get()
        in_queue.put(end)
        out_queue.put(end)

    def xreader():
        in_queue = Queue(buffer_size)
        out_queue = Queue(buffer_size)
        out_order = [0]
        turn = Condition()
        # start a read worker in a thread
        target = order_read_worker if order else read_worker
        t = Thread(target=target, args=(reader, in_queue))
        t.daemon = True
        t.start()
        # start several handle_workers
        target = order_handle_worker if order else handle_worker
        args = (in_queue, out_queue, mapper, out_order, turn) if order else (
            in_queue, out_queue, mapper)
        workers = []
        for i in range(process_num):
            worker = Thread(target=target, args=args)
            worker.daemon = True
            workers.append(worker)
        for w in workers:
            w.start()

        # Each worker emits exactly one end signal; keep yielding mapped
        # samples until all of them have finished.
        finished = 0
        while finished < process_num:
            sample = out_queue.get()
            if isinstance(sample, XmapEndSignal):
                finished += 1
            else:
                yield sample

    return xreader
330 331 332 333 334 335 336 337


def _buf2lines(buf, line_break="\n"):
    # FIXME: line_break should be automatically configured.
    lines = buf.split(line_break)
    return lines[:-1], lines[-1]


T
typhoonzero 已提交
338
class PipeReader:
    """
        PipeReader reads data as a stream from a command: it takes the
        command's stdout through a pipe buffer, passes it to the parser,
        and then yields data in your desired format.

        You can use a standard Linux command or call another program
        to read data from HDFS, Ceph, a URL, AWS S3, etc.:

        .. code-block:: python
           cmd = "hadoop fs -cat /path/to/some/file"
           cmd = "cat sample_file.tar.gz"
           cmd = "curl http://someurl"
           cmd = "python print_s3_bucket.py"

        An example:

        .. code-block:: python

           def example_reader():
               for f in myfiles:
                   pr = PipeReader("cat %s"%f)
                   for l in pr.get_line():
                       sample = l.split(" ")
                       yield sample
    """

    def __init__(self, command, bufsize=8192, file_type="plain"):
        """
        Start ``command`` with its stdout connected to a pipe.

        :param command: command line to run. It is split on single spaces
            and executed without a shell.
        :type command: str
        :param bufsize: number of bytes requested per read from the pipe.
        :type bufsize: int
        :param file_type: "plain" or "gzip"; any other value raises
            TypeError when data is read.
        :type file_type: str
        :raises TypeError: if command is not a str.
        """
        if not isinstance(command, str):
            raise TypeError("command must be a string")
        if file_type == "gzip":
            # wbits = 32 + MAX_WBITS: zlib auto-detects a gzip or zlib header.
            self.dec = zlib.decompressobj(32 + zlib.MAX_WBITS)
        self.file_type = file_type
        self.bufsize = bufsize
        self.process = subprocess.Popen(
            command.split(" "), bufsize=bufsize, stdout=subprocess.PIPE)

    def _read_chunks(self):
        """Yield decompressed chunks from the command's stdout until EOF."""
        while True:
            buff = self.process.stdout.read(self.bufsize)
            if not buff:
                break
            if self.file_type == "gzip":
                yield self.dec.decompress(buff)
            elif self.file_type == "plain":
                yield buff
            else:
                raise TypeError("file_type %s is not allowed" %
                                self.file_type)
        if self.file_type == "gzip":
            # Emit anything the decompressor still buffers at end of stream.
            leftover = self.dec.flush()
            if leftover:
                yield leftover

    def get_line(self, cut_lines=True, line_break="\n", encoding="utf-8"):
        """
        Read the command's output until EOF.

        :param cut_lines: cut buffer to lines
        :type cut_lines: bool
        :param line_break: line break of the file, like "\\n" or "\\r"
        :type line_break: string
        :param encoding: codec used to turn pipe bytes into text before
            cutting lines. It is applied only when the pipe returns bytes
            rather than str (i.e. on Python 3).
        :type encoding: string

        :return: one line (cut_lines=True), including a final line that has
            no trailing line break, or a raw decompressed buffer of bytes
            (cut_lines=False).
        :rtype: string
        :raises TypeError: if file_type is neither "plain" nor "gzip".
        """
        import codecs  # only needed here, for incremental text decoding

        if not cut_lines:
            for chunk in self._read_chunks():
                yield chunk
            return

        decoder = None
        remained = ""
        for chunk in self._read_chunks():
            if not isinstance(chunk, str):
                # Bytes on Python 3: joining them with str would raise.
                # Decode incrementally so a multi-byte character split
                # across two reads is reassembled correctly.
                if decoder is None:
                    decoder = codecs.getincrementaldecoder(encoding)()
                chunk = decoder.decode(chunk)
            lines, remained = _buf2lines(''.join([remained, chunk]),
                                         line_break)
            for line in lines:
                yield line
        if decoder is not None:
            remained += decoder.decode(b"", final=True)
        # A last line without a trailing line break is still data.
        if remained:
            yield remained