未验证 提交 b681537e 编写于 作者: Q Qiao Longfei 提交者: GitHub

Add multiprocess reader (#13311)

* add multiprocess_reader

* add multiprocess_reader to reader decorator

* support piped multi process reader

* revert v2 decorator

* add comment to multiprocess_reader

* optimize code

* use ujson to speed up json serialize/deserialize

* add assert to multiprocess_reader

* update comment of multiprocess_reader

* optimize ujson import, handle error case

* optimize import ujson

* remove ujson from requirements.txt

* add import sys to decorator.py
上级 b4dd5c24
......@@ -14,11 +14,14 @@
__all__ = [
'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
'ComposeNotAligned', 'firstn', 'xmap_readers', 'PipeReader'
'ComposeNotAligned', 'firstn', 'xmap_readers', 'PipeReader',
'multiprocess_reader'
]
from threading import Thread
import subprocess
import multiprocessing
import sys
from six.moves.queue import Queue
from six.moves import zip_longest
......@@ -332,6 +335,100 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
return xreader
def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
"""
multiprocess_reader use python multi process to read data from readers
and then use multiprocess.Queue or multiprocess.Pipe to merge all
data. The process number is equal to the number of input readers, each
process call one reader.
Multiprocess.Queue require the rw access right to /dev/shm, some
platform does not support.
you need to create multiple readers first, these readers should be independent
to each other so that each process can work independently.
An example:
.. code-block:: python
reader0 = reader(["file01", "file02"])
reader1 = reader(["file11", "file12"])
reader1 = reader(["file21", "file22"])
reader = multiprocess_reader([reader0, reader1, reader2],
queue_size=100, use_pipe=False)
"""
try:
import ujson as json
except Exception as e:
sys.stderr.write("import ujson error: " + str(e) + " use json\n")
import json
assert type(readers) is list and len(readers) > 0
def _read_into_queue(reader, queue):
for sample in reader():
if sample is None:
raise ValueError("sample has None")
queue.put(sample)
queue.put(None)
def queue_reader():
queue = multiprocessing.Queue(queue_size)
for reader in readers:
p = multiprocessing.Process(
target=_read_into_queue, args=(reader, queue))
p.start()
reader_num = len(readers)
finish_num = 0
while finish_num < reader_num:
sample = queue.get()
if sample is None:
finish_num += 1
else:
yield sample
def _read_into_pipe(reader, conn):
for sample in reader():
if sample is None:
raise ValueError("sample has None!")
conn.send(json.dumps(sample))
conn.send(json.dumps(None))
conn.close()
def pipe_reader():
conns = []
for reader in readers:
parent_conn, child_conn = multiprocessing.Pipe()
conns.append(parent_conn)
p = multiprocessing.Process(
target=_read_into_pipe, args=(reader, child_conn))
p.start()
reader_num = len(readers)
finish_num = 0
conn_to_remove = []
while finish_num < reader_num:
for conn in conn_to_remove:
conns.remove(conn)
conn_to_remove = []
for conn in conns:
sample = json.loads(conn.recv())
if sample is None:
finish_num += 1
conn.close()
conn_to_remove.append(conn)
else:
yield sample
if use_pipe:
return pipe_reader
else:
return queue_reader
def _buf2lines(buf, line_break="\n"):
# FIXME: line_break should be automatically configured.
lines = buf.split(line_break)
......
......@@ -14,6 +14,7 @@
import time
import unittest
import functools
import paddle.reader
......@@ -174,5 +175,33 @@ class TestPipeReader(unittest.TestCase):
temp.close()
class TestMultiProcessReader(unittest.TestCase):
def setup(self):
self.samples = []
for i in range(1000):
self.samples.append([[i], [i + 1, i + 2], i + 3])
def reader(index):
for i in range(len(self.samples)):
if i % 3 == index:
yield self.samples[i]
self.reader0 = functools.partial(reader, 0)
self.reader1 = functools.partial(reader, 1)
self.reader2 = functools.partial(reader, 2)
def reader_test(self, use_pipe):
self.setup()
results = []
for data in paddle.reader.multiprocess_reader(
[self.reader0, self.reader1, self.reader2], 100, use_pipe)():
results.append(data)
self.assertEqual(sorted(self.samples), sorted(results))
def test_multi_process_reader(self):
self.reader_test(use_pipe=False)
self.reader_test(use_pipe=True)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册