提交 8bc07dee 编写于 作者: W wanghaoshuang

format code

上级 c9a76ebb
...@@ -251,18 +251,19 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): ...@@ -251,18 +251,19 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
in_queue = Queue(buffer_size) in_queue = Queue(buffer_size)
out_queue = Queue(buffer_size) out_queue = Queue(buffer_size)
out_order = [0] out_order = [0]
# define a worker to read samples from reader to in_queue # define a worker to read samples from reader to in_queue
def read_worker(reader, in_queue): def read_worker(reader, in_queue):
for i in reader(): for i in reader():
in_queue.put(i) in_queue.put(i)
in_queue.put(end) in_queue.put(end)
# define a worker to read samples from reader to in_queue with order flag # define a worker to read samples from reader to in_queue with order flag
def order_read_worker(reader, in_queue): def order_read_worker(reader, in_queue):
in_order = 0 in_order = 0
for i in reader(): for i in reader():
in_queue.put((in_order,i)) in_queue.put((in_order, i))
in_order+=1 in_order += 1
in_queue.put(end) in_queue.put(end)
# start a read worker in a thread # start a read worker in a thread
...@@ -281,7 +282,7 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): ...@@ -281,7 +282,7 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
sample = in_queue.get() sample = in_queue.get()
in_queue.put(end) in_queue.put(end)
out_queue.put(end) out_queue.put(end)
# define a worker to handle samples from in_queue by mapper # define a worker to handle samples from in_queue by mapper
# and put mapped samples into out_queue by order # and put mapped samples into out_queue by order
def order_handle_worker(in_queue, out_queue, mapper, out_order): def order_handle_worker(in_queue, out_queue, mapper, out_order):
...@@ -292,18 +293,18 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): ...@@ -292,18 +293,18 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
while order != out_order[0]: while order != out_order[0]:
pass pass
out_queue.put(r) out_queue.put(r)
out_order[0] += 1 out_order[0] += 1
ins = in_queue.get() ins = in_queue.get()
in_queue.put(end) in_queue.put(end)
out_queue.put(end) out_queue.put(end)
# start several handle_workers # start several handle_workers
target = order_handle_worker if order else handle_worker target = order_handle_worker if order else handle_worker
args = (in_queue, out_queue, mapper, out_order) if order else (in_queue, out_queue, mapper) args = (in_queue, out_queue, mapper, out_order) if order else (
in_queue, out_queue, mapper)
workers = [] workers = []
for i in xrange(process_num): for i in xrange(process_num):
worker = Thread( worker = Thread(target=target, args=args)
target=target, args=args)
worker.daemon = True worker.daemon = True
workers.append(worker) workers.append(worker)
for w in workers: for w in workers:
......
...@@ -120,10 +120,12 @@ class TestShuffle(unittest.TestCase): ...@@ -120,10 +120,12 @@ class TestShuffle(unittest.TestCase):
total += 1 total += 1
self.assertEqual(total, 10) self.assertEqual(total, 10)
class TestXmap(unittest.TestCase): class TestXmap(unittest.TestCase):
def test_xmap(self): def test_xmap(self):
def mapper(x): def mapper(x):
return (x + 1) return (x + 1)
orders = (True, False) orders = (True, False)
thread_nums = (1, 2, 4, 8, 16) thread_nums = (1, 2, 4, 8, 16)
buffered_size = (1, 2, 4, 8, 16) buffered_size = (1, 2, 4, 8, 16)
...@@ -131,13 +133,15 @@ class TestXmap(unittest.TestCase): ...@@ -131,13 +133,15 @@ class TestXmap(unittest.TestCase):
for tNum in thread_nums: for tNum in thread_nums:
for size in buffered_size: for size in buffered_size:
result = [] result = []
for i in paddle.v2.reader.xmap_readers(mapper, reader_creator_10(), tNum, size, order)(): for i in paddle.v2.reader.xmap_readers(mapper,
reader_creator_10(),
tNum, size, order)():
result.append(i) result.append(i)
if not order: if not order:
result.sort() result.sort()
for idx, e in enumerate(result): for idx, e in enumerate(result):
self.assertEqual(e, mapper(idx)) self.assertEqual(e, mapper(idx))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册