提交 cadea35a 编写于 作者: W wanghaoshuang

format code

上级 09cc4408
...@@ -251,6 +251,7 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): ...@@ -251,6 +251,7 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
in_queue = Queue(buffer_size) in_queue = Queue(buffer_size)
out_queue = Queue(buffer_size) out_queue = Queue(buffer_size)
out_order = [0] out_order = [0]
# define a worker to read samples from reader to in_queue # define a worker to read samples from reader to in_queue
def read_worker(reader, in_queue): def read_worker(reader, in_queue):
for i in reader(): for i in reader():
...@@ -261,8 +262,8 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): ...@@ -261,8 +262,8 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
def order_read_worker(reader, in_queue): def order_read_worker(reader, in_queue):
in_order = 0 in_order = 0
for i in reader(): for i in reader():
in_queue.put((in_order,i)) in_queue.put((in_order, i))
in_order+=1 in_order += 1
in_queue.put(end) in_queue.put(end)
# start a read worker in a thread # start a read worker in a thread
...@@ -299,11 +300,11 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): ...@@ -299,11 +300,11 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
# start several handle_workers # start several handle_workers
target = order_handle_worker if order else handle_worker target = order_handle_worker if order else handle_worker
args = (in_queue, out_queue, mapper, out_order) if order else (in_queue, out_queue, mapper) args = (in_queue, out_queue, mapper, out_order) if order else (
in_queue, out_queue, mapper)
workers = [] workers = []
for i in xrange(process_num): for i in xrange(process_num):
worker = Thread( worker = Thread(target=target, args=args)
target=target, args=args)
worker.daemon = True worker.daemon = True
workers.append(worker) workers.append(worker)
for w in workers: for w in workers:
......
...@@ -120,10 +120,12 @@ class TestShuffle(unittest.TestCase): ...@@ -120,10 +120,12 @@ class TestShuffle(unittest.TestCase):
total += 1 total += 1
self.assertEqual(total, 10) self.assertEqual(total, 10)
class TestXmap(unittest.TestCase): class TestXmap(unittest.TestCase):
def test_xmap(self): def test_xmap(self):
def mapper(x): def mapper(x):
return (x + 1) return (x + 1)
orders = (True, False) orders = (True, False)
thread_nums = (1, 2, 4, 8, 16) thread_nums = (1, 2, 4, 8, 16)
buffered_size = (1, 2, 4, 8, 16) buffered_size = (1, 2, 4, 8, 16)
...@@ -131,7 +133,9 @@ class TestXmap(unittest.TestCase): ...@@ -131,7 +133,9 @@ class TestXmap(unittest.TestCase):
for tNum in thread_nums: for tNum in thread_nums:
for size in buffered_size: for size in buffered_size:
result = [] result = []
for i in paddle.v2.reader.xmap_readers(mapper, reader_creator_10(), tNum, size, order)(): for i in paddle.v2.reader.xmap_readers(mapper,
reader_creator_10(),
tNum, size, order)():
result.append(i) result.append(i)
if not order: if not order:
result.sort() result.sort()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册