# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Public API of this module: reader decorators that wrap a "reader"
# (a callable returning an iterable of samples) into a new reader.
__all__ = [
    'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
    'ComposeNotAligned', 'firstn', 'xmap_readers', 'pipe_reader'
]

T
tangwei12 已提交
20 21 22 23
from threading import Thread
import subprocess

from Queue import Queue
24 25
import itertools
import random
T
tangwei12 已提交
26
import zlib
27 28


def map_readers(func, *readers):
    """
    Creates a data reader that outputs return value of function using
    output of each data readers as arguments.

    :param func: function to use. The type of func should be (Sample) => Sample
    :type: callable
    :param readers: readers whose outputs will be used as arguments of func.
    :return: the created data reader.
    :rtype: callable
    """

    def reader():
        rs = []
        for r in readers:
            rs.append(r())
        # itertools.imap was removed in Python 3; the builtin map there is
        # already lazy and stops at the shortest iterable, exactly matching
        # imap's semantics, so fall back to it when imap is missing.
        lazy_map = getattr(itertools, 'imap', map)
        for e in lazy_map(func, *rs):
            yield e

    return reader


def shuffle(reader, buf_size):
    """
    Creates a data reader whose data output is shuffled.

    Output from the iterator that created by original reader will be
    buffered into shuffle buffer, and then shuffled. The size of shuffle buffer
    is determined by argument buf_size.

    :param reader: the original reader whose output will be shuffled.
    :type reader: callable
    :param buf_size: shuffle buffer size.
    :type buf_size: int

    :return: the new reader whose output is shuffled.
    :rtype: callable
    """

    def data_reader():
        pool = []
        for sample in reader():
            pool.append(sample)
            if len(pool) < buf_size:
                continue
            # Buffer is full: emit one shuffled batch and start over.
            random.shuffle(pool)
            for shuffled in pool:
                yield shuffled
            pool = []
        # Flush whatever is left at the end of the stream.
        if pool:
            random.shuffle(pool)
            for shuffled in pool:
                yield shuffled

    return data_reader


def chain(*readers):
    """
    Creates a data reader whose output is the outputs of input data
    readers chained together.

    If input readers output following data entries:
    [0, 0, 0]
    [1, 1, 1]
    [2, 2, 2]
    The chained reader will output:
    [0, 0, 0, 1, 1, 1, 2, 2, 2]

    :param readers: input readers.
    :return: the new data reader.
    :rtype: callable
    """

    def reader():
        # Open every reader up-front, then drain them one after another.
        iterators = [r() for r in readers]
        for entry in itertools.chain(*iterators):
            yield entry

    return reader


class ComposeNotAligned(ValueError):
    """
    Raised by a reader created via compose() (with check_alignment=True)
    when the composed readers yield different numbers of samples.
    """
    pass


def compose(*readers, **kwargs):
    """
    Creates a data reader whose output is the combination of input readers.

    If input readers output following data entries:
    (1, 2)    3    (4, 5)
    The composed reader will output:
    (1, 2, 3, 4, 5)

    :param readers: readers that will be composed together.
    :param check_alignment: if True, will check if input readers are aligned
        correctly. If False, will not check alignment and trailing outputs
        will be discarded. Defaults to True.
    :type check_alignment: bool

    :return: the new data reader.

    :raises ComposeNotAligned: outputs of readers are not aligned.
        Will not raise when check_alignment is set to False.
    """
    check_alignment = kwargs.pop('check_alignment', True)

    # itertools.izip / izip_longest were removed in Python 3, where the
    # builtin zip and itertools.zip_longest are the lazy equivalents.
    try:
        izip = itertools.izip
        izip_longest = itertools.izip_longest
    except AttributeError:
        izip = zip
        izip_longest = itertools.zip_longest

    def make_tuple(x):
        # Wrap scalar outputs so every reader contributes a tuple to sum().
        if isinstance(x, tuple):
            return x
        else:
            return (x, )

    def reader():
        rs = []
        for r in readers:
            rs.append(r())
        if not check_alignment:
            # zip stops at the shortest reader, silently dropping extras.
            for outputs in izip(*rs):
                yield sum(map(make_tuple, outputs), ())
        else:
            for outputs in izip_longest(*rs):
                for o in outputs:
                    if o is None:
                        # None will be not be present if compose is aligned
                        raise ComposeNotAligned(
                            "outputs of readers are not aligned.")
                yield sum(map(make_tuple, outputs), ())

    return reader


def buffered(reader, size):
    """
    Creates a buffered data reader.

    The buffered data reader will read and save data entries into a
    buffer. Reading from the buffered data reader will proceed as long
    as the buffer is not empty.

    :param reader: the data reader to read from.
    :type reader: callable
    :param size: max buffer size.
    :type size: int

    :returns: the buffered data reader.
    """

    class EndSignal():
        pass

    end = EndSignal()

    def read_worker(r, q):
        # Producer: drain the reader into the queue, then mark the end.
        for d in r:
            q.put(d)
        q.put(end)

    def data_reader():
        q = Queue(maxsize=size)
        t = Thread(target=read_worker, args=(reader(), q))
        t.daemon = True
        t.start()
        # Consumer: yield until the end-of-stream sentinel arrives.
        while True:
            e = q.get()
            if e == end:
                break
            yield e

    return data_reader
Y
Yu Yang 已提交
206 207


def firstn(reader, n):
    """
    Limit the max number of samples that reader could return.

    :param reader: the data reader to read from.
    :type reader: callable
    :param n: the max number of samples that return.
    :type n: int
    :return: the decorated reader.
    :rtype: callable
    """

    # TODO(yuyang18): Check if just drop the reader, could clean the opened
    # resource or not?

    def firstn_reader():
        emitted = 0
        for item in reader():
            if emitted == n:
                break
            yield item
            emitted += 1

    return firstn_reader
230 231 232 233 234 235


class XmapEndSignal():
    """Sentinel placed on xmap_readers queues to mark the end of a stream."""
    pass


def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
    """
    Use multiple worker threads to map samples from reader by a mapper
    defined by user. And this function contains a buffered decorator.

    :param mapper:  a function to map sample.
    :type mapper: callable
    :param reader: the data reader to read from
    :type reader: callable
    :param process_num: number of worker threads to handle original sample
    :type process_num: int
    :param buffer_size: max buffer size
    :type buffer_size: int
    :param order: keep the order of reader
    :type order: bool
    :return: the decorated reader
    :rtype: callable
    """
    end = XmapEndSignal()

    # define a worker to read samples from reader to in_queue
    def read_worker(reader, in_queue):
        for i in reader():
            in_queue.put(i)
        in_queue.put(end)

    # define a worker to read samples from reader to in_queue with order flag
    def order_read_worker(reader, in_queue):
        in_order = 0
        for i in reader():
            in_queue.put((in_order, i))
            in_order += 1
        in_queue.put(end)

    # define a worker to handle samples from in_queue by mapper
    # and put mapped samples into out_queue
    def handle_worker(in_queue, out_queue, mapper):
        sample = in_queue.get()
        while not isinstance(sample, XmapEndSignal):
            r = mapper(sample)
            out_queue.put(r)
            sample = in_queue.get()
        # re-queue the sentinel so sibling workers can also shut down
        in_queue.put(end)
        out_queue.put(end)

    # define a worker to handle samples from in_queue by mapper
    # and put mapped samples into out_queue by order
    def order_handle_worker(in_queue, out_queue, mapper, out_order):
        ins = in_queue.get()
        while not isinstance(ins, XmapEndSignal):
            order, sample = ins
            r = mapper(sample)
            # busy-wait until it is this sample's turn to be emitted
            while order != out_order[0]:
                pass
            out_queue.put(r)
            out_order[0] += 1
            ins = in_queue.get()
        in_queue.put(end)
        out_queue.put(end)

    def xreader():
        in_queue = Queue(buffer_size)
        out_queue = Queue(buffer_size)
        out_order = [0]
        # start a read worker in a thread
        target = order_read_worker if order else read_worker
        t = Thread(target=target, args=(reader, in_queue))
        t.daemon = True
        t.start()
        # start several handle_workers
        target = order_handle_worker if order else handle_worker
        args = (in_queue, out_queue, mapper, out_order) if order else (
            in_queue, out_queue, mapper)
        workers = []
        # range instead of the Python-2-only xrange: iterating range(n)
        # behaves identically on both Python 2 and 3.
        for i in range(process_num):
            worker = Thread(target=target, args=args)
            worker.daemon = True
            workers.append(worker)
        for w in workers:
            w.start()

        # yield mapped samples until every worker has sent its end signal
        sample = out_queue.get()
        while not isinstance(sample, XmapEndSignal):
            yield sample
            sample = out_queue.get()
        finish = 1
        while finish < process_num:
            sample = out_queue.get()
            if isinstance(sample, XmapEndSignal):
                finish += 1
            else:
                yield sample

    return xreader


def _buf2lines(buf, line_break="\n"):
    # FIXME: line_break should be automatically configured.
    lines = buf.split(line_break)
    return lines[:-1], lines[-1]


def pipe_reader(left_cmd,
                parser,
                bufsize=8192,
                file_type="plain",
                cut_lines=True,
                line_break="\n"):
    """
    pipe_reader reads data by stream from a command, takes its
    stdout into a pipe buffer and redirects it to the parser to
    parse, then yields data in your desired format.

    You can use a standard linux command or call another program
    to read data, from HDFS, Ceph, URL, AWS S3 etc:

    cmd = "hadoop fs -cat /path/to/some/file"
    cmd = "cat sample_file.tar.gz"
    cmd = "curl http://someurl"
    cmd = "python print_s3_bucket.py"

    A sample parser:

    def sample_parser(lines):
        # parse each line as one sample data,
        # return a list of samples as batches.
        ret = []
        for l in lines:
            ret.append(l.split(" ")[1:5])
        return ret

    :param left_cmd: command to execute to get stdout from.
    :type left_cmd: string
    :param parser: parser function to parse lines of data.
                   if cut_lines is True, parser will receive list
                   of lines.
                   if cut_lines is False, parser will receive a
                   raw buffer each time.
                   parser should return a list of parsed values.
    :type parser: callable
    :param bufsize: the buffer size used for the stdout pipe.
    :type bufsize: int
    :param file_type: can be plain/gzip, stream buffer data type.
    :type file_type: string
    :param cut_lines: whether to pass lines instead of raw buffer
                      to the parser
    :type cut_lines: bool
    :param line_break: line break of the file, like \n or \r
    :type line_break: string

    :return: the reader generator.
    :rtype: callable
    """
    if not isinstance(left_cmd, str):
        raise TypeError("left_cmd must be a string")
    if not callable(parser):
        raise TypeError("parser must be a callable object")

    # TODO(typhoonzero): add a thread to read stderr

    # Always init a decompress object is better than
    # create in the loop.
    # NOTE(review): `dec` is shared by every generator created from the
    # returned reader, so a second gzip run would reuse stale decompressor
    # state — confirm each returned reader is only iterated once.
    dec = zlib.decompressobj(
        32 + zlib.MAX_WBITS)  # offset 32 to skip the header

    def reader():
        # Spawn the command; its stdout becomes our data pipe.
        process = subprocess.Popen(
            left_cmd.split(" "), bufsize=bufsize, stdout=subprocess.PIPE)
        # NOTE(review): written for Python 2, where process.stdout.read()
        # returns str; under Python 3 it returns bytes and the ''.join with
        # the str `remained` below would raise — confirm before porting.
        remained = ""
        while True:
            buff = process.stdout.read(bufsize)
            if buff:
                if file_type == "gzip":
                    decomp_buff = dec.decompress(buff)
                elif file_type == "plain":
                    decomp_buff = buff
                else:
                    raise TypeError("file_type %s is not allowed" % file_type)

                if cut_lines:
                    # Keep the trailing partial line for the next chunk.
                    lines, remained = _buf2lines(''.join(
                        [remained, decomp_buff]), line_break)
                    parsed_list = parser(lines)
                    for ret in parsed_list:
                        yield ret
                else:
                    for ret in parser(decomp_buff):
                        yield ret
            else:
                # EOF on the pipe: the command has finished.
                break

    return reader