提交 8bc07dee 编写于 作者: W wanghaoshuang

format code

上级 c9a76ebb
......@@ -251,6 +251,7 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
in_queue = Queue(buffer_size)
out_queue = Queue(buffer_size)
out_order = [0]
# define a worker to read samples from reader to in_queue
def read_worker(reader, in_queue):
for i in reader():
......@@ -261,8 +262,8 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
def order_read_worker(reader, in_queue):
in_order = 0
for i in reader():
in_queue.put((in_order,i))
in_order+=1
in_queue.put((in_order, i))
in_order += 1
in_queue.put(end)
# start a read worker in a thread
......@@ -299,11 +300,11 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
# start several handle_workers
target = order_handle_worker if order else handle_worker
args = (in_queue, out_queue, mapper, out_order) if order else (in_queue, out_queue, mapper)
args = (in_queue, out_queue, mapper, out_order) if order else (
in_queue, out_queue, mapper)
workers = []
for i in xrange(process_num):
worker = Thread(
target=target, args=args)
worker = Thread(target=target, args=args)
worker.daemon = True
workers.append(worker)
for w in workers:
......
......@@ -120,10 +120,12 @@ class TestShuffle(unittest.TestCase):
total += 1
self.assertEqual(total, 10)
class TestXmap(unittest.TestCase):
def test_xmap(self):
def mapper(x):
return (x + 1)
orders = (True, False)
thread_nums = (1, 2, 4, 8, 16)
buffered_size = (1, 2, 4, 8, 16)
......@@ -131,7 +133,9 @@ class TestXmap(unittest.TestCase):
for tNum in thread_nums:
for size in buffered_size:
result = []
for i in paddle.v2.reader.xmap_readers(mapper, reader_creator_10(), tNum, size, order)():
for i in paddle.v2.reader.xmap_readers(mapper,
reader_creator_10(),
tNum, size, order)():
result.append(i)
if not order:
result.sort()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册