diff --git a/python/paddle/v2/reader/decorator.py b/python/paddle/v2/reader/decorator.py index c76faa596c9fb9079cab3456b721c18ef9768e95..e432003129d2b8dea60138d08f13ec5e9d29a7ad 100644 --- a/python/paddle/v2/reader/decorator.py +++ b/python/paddle/v2/reader/decorator.py @@ -230,7 +230,7 @@ class XmapEndSignal(): pass -def xmap_readers(mapper, reader, process_num, buffer_size): +def xmap_readers(mapper, reader, process_num, buffer_size, order=False): """ Use multiprocess to map samples from reader by a mapper defined by user. And this function contains a buffered decorator. @@ -242,12 +242,15 @@ def xmap_readers(mapper, reader, process_num, buffer_size): :type process_num: int :param buffer_size: max buffer size :type buffer_size: int + :param order: keep the order of reader + :type order: bool :return: the decarated reader :rtype: callable """ end = XmapEndSignal() in_queue = Queue(buffer_size) out_queue = Queue(buffer_size) + out_order = [0] # define a worker to read samples from reader to in_queue def read_worker(reader, in_queue): @@ -255,8 +258,17 @@ def xmap_readers(mapper, reader, process_num, buffer_size): in_queue.put(i) in_queue.put(end) + # define a worker to read samples from reader to in_queue with order flag + def order_read_worker(reader, in_queue): + in_order = 0 + for i in reader(): + in_queue.put((in_order, i)) + in_order += 1 + in_queue.put(end) + # start a read worker in a thread - t = Thread(target=read_worker, args=(reader, in_queue)) + target = order_read_worker if order else read_worker + t = Thread(target=target, args=(reader, in_queue)) t.daemon = True t.start() @@ -271,11 +283,28 @@ def xmap_readers(mapper, reader, process_num, buffer_size): in_queue.put(end) out_queue.put(end) + # define a worker to handle samples from in_queue by mapper + # and put mapped samples into out_queue by order + def order_handle_worker(in_queue, out_queue, mapper, out_order): + ins = in_queue.get() + while not isinstance(ins, XmapEndSignal): + order, sample = ins + r = mapper(sample) + while order != out_order[0]: + pass + out_queue.put(r) + out_order[0] += 1 + ins = in_queue.get() + in_queue.put(end) + out_queue.put(end) + # start several handle_workers + target = order_handle_worker if order else handle_worker + args = (in_queue, out_queue, mapper, out_order) if order else ( + in_queue, out_queue, mapper) workers = [] for i in xrange(process_num): - worker = Thread( - target=handle_worker, args=(in_queue, out_queue, mapper)) + worker = Thread(target=target, args=args) worker.daemon = True workers.append(worker) for w in workers: diff --git a/python/paddle/v2/reader/tests/decorator_test.py b/python/paddle/v2/reader/tests/decorator_test.py index 734154b9790a4dc118d11992343648364c907305..bb3c5d220b9ce1552d2fc429abb1863930cd4d17 100644 --- a/python/paddle/v2/reader/tests/decorator_test.py +++ b/python/paddle/v2/reader/tests/decorator_test.py @@ -121,5 +121,27 @@ class TestShuffle(unittest.TestCase): self.assertEqual(total, 10) +class TestXmap(unittest.TestCase): + def test_xmap(self): + def mapper(x): + return (x + 1) + + orders = (True, False) + thread_nums = (1, 2, 4, 8, 16) + buffered_size = (1, 2, 4, 8, 16) + for order in orders: + for tNum in thread_nums: + for size in buffered_size: + result = [] + for i in paddle.v2.reader.xmap_readers(mapper, + reader_creator_10(0), + tNum, size, order)(): + result.append(i) + if not order: + result.sort() + for idx, e in enumerate(result): + self.assertEqual(e, mapper(idx)) + + if __name__ == '__main__': unittest.main()