diff --git a/python/paddle/v2/reader/decorator.py b/python/paddle/v2/reader/decorator.py index 68ffbd6f3d459b219aef7eb52749671e86025711..e432003129d2b8dea60138d08f13ec5e9d29a7ad 100644 --- a/python/paddle/v2/reader/decorator.py +++ b/python/paddle/v2/reader/decorator.py @@ -251,18 +251,19 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): in_queue = Queue(buffer_size) out_queue = Queue(buffer_size) out_order = [0] + # define a worker to read samples from reader to in_queue def read_worker(reader, in_queue): for i in reader(): in_queue.put(i) in_queue.put(end) - + # define a worker to read samples from reader to in_queue with order flag def order_read_worker(reader, in_queue): in_order = 0 for i in reader(): - in_queue.put((in_order,i)) - in_order+=1 + in_queue.put((in_order, i)) + in_order += 1 in_queue.put(end) # start a read worker in a thread @@ -281,7 +282,7 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): sample = in_queue.get() in_queue.put(end) out_queue.put(end) - + # define a worker to handle samples from in_queue by mapper # and put mapped samples into out_queue by order def order_handle_worker(in_queue, out_queue, mapper, out_order): @@ -292,18 +293,18 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): while order != out_order[0]: pass out_queue.put(r) - out_order[0] += 1 + out_order[0] += 1 ins = in_queue.get() in_queue.put(end) out_queue.put(end) # start several handle_workers target = order_handle_worker if order else handle_worker - args = (in_queue, out_queue, mapper, out_order) if order else (in_queue, out_queue, mapper) + args = (in_queue, out_queue, mapper, out_order) if order else ( + in_queue, out_queue, mapper) workers = [] for i in xrange(process_num): - worker = Thread( - target=target, args=args) + worker = Thread(target=target, args=args) worker.daemon = True workers.append(worker) for w in workers: diff --git a/python/paddle/v2/reader/tests/decorator_test.py b/python/paddle/v2/reader/tests/decorator_test.py index 76db91a44b8d75f1af7271d5f4f94ea21c51d943..0bd773395507c074d962084cca8842a2026e798a 100644 --- a/python/paddle/v2/reader/tests/decorator_test.py +++ b/python/paddle/v2/reader/tests/decorator_test.py @@ -120,10 +120,12 @@ class TestShuffle(unittest.TestCase): total += 1 self.assertEqual(total, 10) + class TestXmap(unittest.TestCase): def test_xmap(self): def mapper(x): return (x + 1) + orders = (True, False) thread_nums = (1, 2, 4, 8, 16) buffered_size = (1, 2, 4, 8, 16) @@ -131,13 +133,15 @@ class TestXmap(unittest.TestCase): for tNum in thread_nums: for size in buffered_size: result = [] - for i in paddle.v2.reader.xmap_readers(mapper, reader_creator_10(), tNum, size, order)(): + for i in paddle.v2.reader.xmap_readers(mapper, + reader_creator_10(), + tNum, size, order)(): result.append(i) if not order: result.sort() for idx, e in enumerate(result): self.assertEqual(e, mapper(idx)) - + if __name__ == '__main__': unittest.main()