提交 f2e6b99a 编写于 作者: W whs 提交者: GitHub

Merge pull request #2527 from wanghaoshuang/order_xmap

Add an order switch to xmap_readers
...@@ -230,7 +230,7 @@ class XmapEndSignal(): ...@@ -230,7 +230,7 @@ class XmapEndSignal():
pass pass
def xmap_readers(mapper, reader, process_num, buffer_size): def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
""" """
Use multiprocess to map samples from reader by a mapper defined by user. Use multiprocess to map samples from reader by a mapper defined by user.
And this function contains a buffered decorator. And this function contains a buffered decorator.
...@@ -242,12 +242,15 @@ def xmap_readers(mapper, reader, process_num, buffer_size): ...@@ -242,12 +242,15 @@ def xmap_readers(mapper, reader, process_num, buffer_size):
:type process_num: int :type process_num: int
:param buffer_size: max buffer size :param buffer_size: max buffer size
:type buffer_size: int :type buffer_size: int
:param order: keep the order of reader
:type order: bool
:return: the decarated reader :return: the decarated reader
:rtype: callable :rtype: callable
""" """
end = XmapEndSignal() end = XmapEndSignal()
in_queue = Queue(buffer_size) in_queue = Queue(buffer_size)
out_queue = Queue(buffer_size) out_queue = Queue(buffer_size)
out_order = [0]
# define a worker to read samples from reader to in_queue # define a worker to read samples from reader to in_queue
def read_worker(reader, in_queue): def read_worker(reader, in_queue):
...@@ -255,8 +258,17 @@ def xmap_readers(mapper, reader, process_num, buffer_size): ...@@ -255,8 +258,17 @@ def xmap_readers(mapper, reader, process_num, buffer_size):
in_queue.put(i) in_queue.put(i)
in_queue.put(end) in_queue.put(end)
# define a worker to read samples from reader to in_queue with order flag
def order_read_worker(reader, in_queue):
in_order = 0
for i in reader():
in_queue.put((in_order, i))
in_order += 1
in_queue.put(end)
# start a read worker in a thread # start a read worker in a thread
t = Thread(target=read_worker, args=(reader, in_queue)) target = order_read_worker if order else read_worker
t = Thread(target=target, args=(reader, in_queue))
t.daemon = True t.daemon = True
t.start() t.start()
...@@ -271,11 +283,28 @@ def xmap_readers(mapper, reader, process_num, buffer_size): ...@@ -271,11 +283,28 @@ def xmap_readers(mapper, reader, process_num, buffer_size):
in_queue.put(end) in_queue.put(end)
out_queue.put(end) out_queue.put(end)
# define a worker to handle samples from in_queue by mapper
# and put mapped samples into out_queue by order
def order_handle_worker(in_queue, out_queue, mapper, out_order):
ins = in_queue.get()
while not isinstance(ins, XmapEndSignal):
order, sample = ins
r = mapper(sample)
while order != out_order[0]:
pass
out_queue.put(r)
out_order[0] += 1
ins = in_queue.get()
in_queue.put(end)
out_queue.put(end)
# start several handle_workers # start several handle_workers
target = order_handle_worker if order else handle_worker
args = (in_queue, out_queue, mapper, out_order) if order else (
in_queue, out_queue, mapper)
workers = [] workers = []
for i in xrange(process_num): for i in xrange(process_num):
worker = Thread( worker = Thread(target=target, args=args)
target=handle_worker, args=(in_queue, out_queue, mapper))
worker.daemon = True worker.daemon = True
workers.append(worker) workers.append(worker)
for w in workers: for w in workers:
......
...@@ -121,5 +121,27 @@ class TestShuffle(unittest.TestCase): ...@@ -121,5 +121,27 @@ class TestShuffle(unittest.TestCase):
self.assertEqual(total, 10) self.assertEqual(total, 10)
class TestXmap(unittest.TestCase):
def test_xmap(self):
def mapper(x):
return (x + 1)
orders = (True, False)
thread_nums = (1, 2, 4, 8, 16)
buffered_size = (1, 2, 4, 8, 16)
for order in orders:
for tNum in thread_nums:
for size in buffered_size:
result = []
for i in paddle.v2.reader.xmap_readers(mapper,
reader_creator_10(0),
tNum, size, order)():
result.append(i)
if not order:
result.sort()
for idx, e in enumerate(result):
self.assertEqual(e, mapper(idx))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册