decorator.py 11.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

H
Helin Wang 已提交
15 16
# Public API of this module: the reader decorators plus the alignment error.
__all__ = [
    'map_readers',
    'buffered',
    'compose',
    'chain',
    'shuffle',
    'ComposeNotAligned',
    'firstn',
    'xmap_readers',
    'PipeReader',
]
19

T
tangwei12 已提交
20 21 22
from threading import Thread
import subprocess

23
from six.moves.queue import Queue
24
from six.moves import zip_longest
25 26
from six.moves import map
from six.moves import zip
27 28
import itertools
import random
T
tangwei12 已提交
29
import zlib
30 31


H
Helin Wang 已提交
32 33 34 35 36
def map_readers(func, *readers):
    """
    Build a data reader that yields ``func`` applied to the outputs of
    several readers.

    Every time the returned reader is called, each input reader is invoked
    once; one output from every reader is then passed, as positional
    arguments, to ``func`` and the return value is yielded.

    :param func: function applied to one output from each reader.
    :type func: callable
    :param readers: readers whose outputs will be used as arguments of func.
    :return: the created data reader.
    :rtype: callable
    """

    def reader():
        # Instantiate every source up front, in the order they were given.
        sources = [make_iter() for make_iter in readers]
        for mapped in map(func, *sources):
            yield mapped

    return reader


H
Helin Wang 已提交
54
def shuffle(reader, buf_size):
    """
    Build a data reader that yields the original reader's output in
    shuffled order.

    Samples are collected into a buffer. Whenever the buffer holds at least
    ``buf_size`` samples it is shuffled in place and emitted, then a fresh
    buffer is started. Whatever is left once the original reader is
    exhausted is shuffled and emitted as a final, possibly shorter, group.

    :param reader: the original reader whose output will be shuffled.
    :type reader: callable
    :param buf_size: shuffle buffer size.
    :type buf_size: int

    :return: the new reader whose output is shuffled.
    :rtype: callable
    """

    def data_reader():
        pending = []
        for sample in reader():
            pending.append(sample)
            if len(pending) < buf_size:
                continue
            # Buffer is full: emit it in random order and start over.
            random.shuffle(pending)
            for shuffled in pending:
                yield shuffled
            pending = []

        # Flush the partially filled buffer, if any.
        if pending:
            random.shuffle(pending)
            for shuffled in pending:
                yield shuffled

    return data_reader
87 88


H
Helin Wang 已提交
89
def chain(*readers):
    """
    Build a data reader that yields the outputs of the input readers one
    reader after another.

    If the input readers output the following data entries:
    [0, 0, 0]
    [1, 1, 1]
    [2, 2, 2]
    the chained reader will output:
    [0, 0, 0, 1, 1, 1, 2, 2, 2]

    :param readers: input readers.
    :return: the new data reader.
    :rtype: callable
    """

    def reader():
        # All readers are invoked before the first sample is produced.
        streams = [make_iter() for make_iter in readers]
        for stream in streams:
            for sample in stream:
                yield sample

    return reader
115 116


H
Helin Wang 已提交
117
class ComposeNotAligned(ValueError):
    """Error raised by compose() when the outputs of its readers are not
    aligned."""


H
Helin Wang 已提交
121
def compose(*readers, **kwargs):
    """
    Creates a data reader whose output is the combination of input readers.

    If input readers output following data entries:
    (1, 2)    3    (4, 5)
    The composed reader will output:
    (1, 2, 3, 4, 5)

    Tuple outputs are spliced into the result; any other output (including
    ``None``) contributes a single element.

    :param readers: readers that will be composed together.
    :param check_alignment: if True, will check if input readers are aligned
        correctly. If False, will not check alignment and trailing outputs
        will be discarded. Defaults to True.
    :type check_alignment: bool

    :return: the new data reader.
    :rtype: callable

    :raises ComposeNotAligned: outputs of readers are not aligned.
        Will not raise when check_alignment is set to False.
    """
    check_alignment = kwargs.pop('check_alignment', True)

    # Private fill value for exhausted readers. Using None here would make a
    # reader that legitimately yields None look like a misaligned one.
    missing = object()

    def make_tuple(x):
        if isinstance(x, tuple):
            return x
        return (x, )

    def flatten(outputs):
        # Concatenate the per-reader tuples into one flat tuple.
        return tuple(itertools.chain.from_iterable(map(make_tuple, outputs)))

    def reader():
        rs = [r() for r in readers]
        if not check_alignment:
            # zip stops at the shortest reader; the rest is discarded.
            for outputs in zip(*rs):
                yield flatten(outputs)
        else:
            for outputs in zip_longest(*rs, fillvalue=missing):
                # The fill value only shows up once some reader has run dry
                # while another one still produces output.
                if any(o is missing for o in outputs):
                    raise ComposeNotAligned(
                        "outputs of readers are not aligned.")
                yield flatten(outputs)

    return reader
166 167


H
Helin Wang 已提交
168
def buffered(reader, size):
    """
    Creates a buffered data reader.

    The buffered data reader will read and save data entries into a
    buffer. Reading from the buffered data reader will proceed as long
    as the buffer is not empty.

    A daemon thread iterates the original reader and pushes each sample into
    a bounded queue; the returned reader pops samples from that queue. If the
    original reader raises, the exception is re-raised in the consumer once
    the samples queued before it have been delivered.

    :param reader: the data reader to read from.
    :type reader: callable
    :param size: max buffer size (passed to Queue as maxsize).
    :type size: int

    :returns: the buffered data reader.
    :rtype: callable
    """

    class EndSignal(object):
        """Marks the end of the source stream inside the queue."""

    class ReaderError(object):
        """Carries an exception from the worker thread to the consumer."""

        def __init__(self, error):
            self.error = error

    end = EndSignal()

    def read_worker(r, q):
        try:
            for d in r:
                q.put(d)
        except Exception as error:
            # Without this the consumer would wait forever for an end marker
            # that is never queued.
            q.put(ReaderError(error))
        else:
            q.put(end)

    def data_reader():
        r = reader()
        q = Queue(maxsize=size)
        t = Thread(target=read_worker, args=(r, q))
        t.daemon = True
        t.start()
        while True:
            e = q.get()
            # Identity check: comparing with != would invoke the sample's own
            # __ne__, which may return a non-bool or a misleading value.
            if e is end:
                break
            if isinstance(e, ReaderError):
                raise e.error
            yield e

    return data_reader
Y
Yu Yang 已提交
209 210


Y
Yu Yang 已提交
211
def firstn(reader, n):
    """
    Limit the max number of samples that reader could return.

    The decorated reader stops right after yielding its n-th sample, so no
    more than n samples are pulled from the original reader. When ``n`` is 0
    the original reader is not invoked at all. A value of ``n`` that the
    running count never equals (for example a negative number) places no
    limit on the output.

    :param reader: the data reader to read from.
    :type reader: callable
    :param n: the max number of samples that return.
    :type n: int
    :return: the decorated reader.
    :rtype: callable
    """

    # TODO(yuyang18): Check if just drop the reader, could clean the opened
    # resource or not?

    def firstn_reader():
        if n == 0:
            return
        # Count from 1 so the limit is checked after each yield; this avoids
        # fetching (and discarding) one sample beyond the limit.
        for count, item in enumerate(reader(), 1):
            yield item
            if count == n:
                break

    return firstn_reader
233 234 235 236 237 238


class XmapEndSignal():
    """Sentinel type queued by xmap_readers to mark the end of a stream."""


def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
    """
    Map samples from reader through a user-defined mapper using worker
    threads. This function also buffers: samples travel through bounded
    input and output queues.

    One daemon thread feeds samples from ``reader`` into an input queue,
    ``process_num`` daemon threads apply ``mapper`` and push the results into
    an output queue, and the returned reader yields from that output queue.

    NOTE: an exception raised by ``reader`` or ``mapper`` ends the thread it
    occurs in without queuing that thread's end marker, so the consumer can
    block waiting for it.

    :param mapper:  a function to map sample.
    :type mapper: callable
    :param reader: the data reader to read from
    :type reader: callable
    :param process_num: number of worker threads that apply mapper
    :type process_num: int
    :param buffer_size: max size of the input queue and of the output queue
    :type buffer_size: int
    :param order: keep the order of reader
    :type order: bool
    :return: the decorated reader
    :rtype: callable
    """
    # Imported locally: the module header only imports Thread from threading.
    from threading import Condition

    end = XmapEndSignal()

    # define a worker to read samples from reader to in_queue
    def read_worker(reader, in_queue):
        for i in reader():
            in_queue.put(i)
        in_queue.put(end)

    # define a worker to read samples from reader to in_queue with order flag
    def order_read_worker(reader, in_queue):
        in_order = 0
        for i in reader():
            in_queue.put((in_order, i))
            in_order += 1
        in_queue.put(end)

    # define a worker to handle samples from in_queue by mapper
    # and put mapped samples into out_queue
    def handle_worker(in_queue, out_queue, mapper):
        sample = in_queue.get()
        while not isinstance(sample, XmapEndSignal):
            r = mapper(sample)
            out_queue.put(r)
            sample = in_queue.get()
        # Hand the end marker on to the next worker, then report completion.
        in_queue.put(end)
        out_queue.put(end)

    # define a worker to handle samples from in_queue by mapper
    # and put mapped samples into out_queue by order
    def order_handle_worker(in_queue, out_queue, mapper, out_order, turn):
        ins = in_queue.get()
        while not isinstance(ins, XmapEndSignal):
            order, sample = ins
            r = mapper(sample)
            # Sleep until it is this sample's turn instead of spinning.
            with turn:
                while order != out_order[0]:
                    turn.wait()
            # Only the worker holding the current sequence number gets here,
            # so the put needs no lock and may block without stalling others.
            out_queue.put(r)
            with turn:
                out_order[0] += 1
                turn.notify_all()
            ins = in_queue.get()
        # Hand the end marker on to the next worker, then report completion.
        in_queue.put(end)
        out_queue.put(end)

    def xreader():
        in_queue = Queue(buffer_size)
        out_queue = Queue(buffer_size)
        # Sequence number of the next result allowed into out_queue, and the
        # condition the ordered workers use to wait for their turn.
        out_order = [0]
        turn = Condition()
        # start a read worker in a thread
        target = order_read_worker if order else read_worker
        t = Thread(target=target, args=(reader, in_queue))
        t.daemon = True
        t.start()
        # start several handle_workers
        if order:
            target = order_handle_worker
            args = (in_queue, out_queue, mapper, out_order, turn)
        else:
            target = handle_worker
            args = (in_queue, out_queue, mapper)
        workers = []
        for i in range(process_num):
            worker = Thread(target=target, args=args)
            worker.daemon = True
            workers.append(worker)
        for w in workers:
            w.start()

        # Yield until the first worker reports completion ...
        sample = out_queue.get()
        while not isinstance(sample, XmapEndSignal):
            yield sample
            sample = out_queue.get()
        # ... then keep draining until every worker has reported completion.
        finish = 1
        while finish < process_num:
            sample = out_queue.get()
            if isinstance(sample, XmapEndSignal):
                finish += 1
            else:
                yield sample

    return xreader
332 333 334 335 336 337 338 339


def _buf2lines(buf, line_break="\n"):
    """
    Split a buffer into its complete lines and the trailing partial line.

    :param buf: text to split
    :param line_break: separator between lines
    :return: ``(lines, remainder)`` where ``remainder`` is whatever follows
        the last line break (empty if the buffer ends with one)
    """
    # FIXME: line_break should be automatically configured.
    lines = buf.split(line_break)
    return lines[:-1], lines[-1]


class PipeReader:
    """
        PipeReader read data by stream from a command, take its
        stdout into a pipe buffer and redirect it to the parser to
        parse, then yield data as your desired format.

        You can use standard linux command or call another program
        to read data, from HDFS, Ceph, URL, AWS S3 etc:

        .. code-block:: python
           cmd = "hadoop fs -cat /path/to/some/file"
           cmd = "cat sample_file.tar.gz"
           cmd = "curl http://someurl"
           cmd = "python print_s3_bucket.py"

        An example:

        .. code-block:: python

           def example_reader():
               for f in myfiles:
                   pr = PipeReader("cat %s"%f)
                   for l in pr.get_line():
                       sample = l.split(" ")
                       yield sample
    """

    def __init__(self, command, bufsize=8192, file_type="plain"):
        """
        Start ``command`` as a subprocess with its stdout captured.

        :param command: command line to run; it is split on single spaces,
            so quoted arguments are not supported
        :type command: str
        :param bufsize: pipe buffer size, also the chunk size used by
            get_line
        :type bufsize: int
        :param file_type: "plain" or "gzip"
        :type file_type: str

        :raises TypeError: command is not a string
        """
        if not isinstance(command, str):
            raise TypeError("command must be a string")
        if file_type == "gzip":
            self.dec = zlib.decompressobj(
                32 + zlib.MAX_WBITS)  # offset 32 to skip the header
        self.file_type = file_type
        self.bufsize = bufsize
        self.process = subprocess.Popen(
            command.split(" "), bufsize=bufsize, stdout=subprocess.PIPE)

    def get_line(self, cut_lines=True, line_break="\n"):
        """
        Read the command's stdout chunk by chunk and yield its content.

        With ``cut_lines`` set, complete lines are yielded without their
        line break; data following the last line break is yielded as a final
        line once the stream ends. When the pipe delivers bytes that are not
        ``str`` (Python 3), they are decoded as UTF-8 before splitting.
        With ``cut_lines`` unset, each (decompressed) chunk is yielded as
        read.

        :param cut_lines: cut buffer to lines
        :type cut_lines: bool
        :param line_break: line break of the file, like ``\\n`` or ``\\r``
        :type line_break: string

        :return: one line or a buffer of bytes
        :rtype: string

        :raises TypeError: file_type is neither "plain" nor "gzip"
        """
        import codecs

        # Created lazily, only if the pipe hands back non-str bytes.
        decoder = None
        remained = ""
        while True:
            buff = self.process.stdout.read(self.bufsize)
            if not buff:
                break

            if self.file_type == "gzip":
                decomp_buff = self.dec.decompress(buff)
            elif self.file_type == "plain":
                decomp_buff = buff
            else:
                raise TypeError("file_type %s is not allowed" %
                                self.file_type)

            if not cut_lines:
                yield decomp_buff
                continue

            if not isinstance(decomp_buff, str):
                # An incremental decoder copes with a multi-byte character
                # that is split across two chunks.
                if decoder is None:
                    decoder = codecs.getincrementaldecoder("utf-8")()
                decomp_buff = decoder.decode(decomp_buff)

            lines, remained = _buf2lines(remained + decomp_buff, line_break)
            for line in lines:
                yield line

        if cut_lines:
            if decoder is not None:
                # Raises UnicodeDecodeError if the stream ended in the middle
                # of a character.
                remained += decoder.decode(b"", final=True)
            # Do not lose a last line that has no trailing line break.
            if remained:
                yield remained