提交 8bc07dee 编写于 作者: W wanghaoshuang

format code

上级 c9a76ebb
......@@ -251,18 +251,19 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
in_queue = Queue(buffer_size)
out_queue = Queue(buffer_size)
out_order = [0]
# define a worker to read samples from reader to in_queue
def read_worker(reader, in_queue):
for i in reader():
in_queue.put(i)
in_queue.put(end)
# define a worker to read samples from reader to in_queue with order flag
def order_read_worker(reader, in_queue):
in_order = 0
for i in reader():
in_queue.put((in_order,i))
in_order+=1
in_queue.put((in_order, i))
in_order += 1
in_queue.put(end)
# start a read worker in a thread
......@@ -281,7 +282,7 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
sample = in_queue.get()
in_queue.put(end)
out_queue.put(end)
# define a worker to handle samples from in_queue by mapper
# and put mapped samples into out_queue by order
def order_handle_worker(in_queue, out_queue, mapper, out_order):
......@@ -292,18 +293,18 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
while order != out_order[0]:
pass
out_queue.put(r)
out_order[0] += 1
out_order[0] += 1
ins = in_queue.get()
in_queue.put(end)
out_queue.put(end)
# start several handle_workers
target = order_handle_worker if order else handle_worker
args = (in_queue, out_queue, mapper, out_order) if order else (in_queue, out_queue, mapper)
args = (in_queue, out_queue, mapper, out_order) if order else (
in_queue, out_queue, mapper)
workers = []
for i in xrange(process_num):
worker = Thread(
target=target, args=args)
worker = Thread(target=target, args=args)
worker.daemon = True
workers.append(worker)
for w in workers:
......
......@@ -120,10 +120,12 @@ class TestShuffle(unittest.TestCase):
total += 1
self.assertEqual(total, 10)
class TestXmap(unittest.TestCase):
def test_xmap(self):
def mapper(x):
return (x + 1)
orders = (True, False)
thread_nums = (1, 2, 4, 8, 16)
buffered_size = (1, 2, 4, 8, 16)
......@@ -131,13 +133,15 @@ class TestXmap(unittest.TestCase):
for tNum in thread_nums:
for size in buffered_size:
result = []
for i in paddle.v2.reader.xmap_readers(mapper, reader_creator_10(), tNum, size, order)():
for i in paddle.v2.reader.xmap_readers(mapper,
reader_creator_10(),
tNum, size, order)():
result.append(i)
if not order:
result.sort()
for idx, e in enumerate(result):
self.assertEqual(e, mapper(idx))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册