diff --git a/python/paddle/v2/reader/decorator.py b/python/paddle/v2/reader/decorator.py index c76faa596c9fb9079cab3456b721c18ef9768e95..68ffbd6f3d459b219aef7eb52749671e86025711 100644 --- a/python/paddle/v2/reader/decorator.py +++ b/python/paddle/v2/reader/decorator.py @@ -230,7 +230,7 @@ class XmapEndSignal(): pass -def xmap_readers(mapper, reader, process_num, buffer_size): +def xmap_readers(mapper, reader, process_num, buffer_size, order=False): """ Use multiprocess to map samples from reader by a mapper defined by user. And this function contains a buffered decorator. @@ -242,21 +242,32 @@ def xmap_readers(mapper, reader, process_num, buffer_size): :type process_num: int :param buffer_size: max buffer size :type buffer_size: int + :param order: keep the order of reader + :type order: bool :return: the decarated reader :rtype: callable """ end = XmapEndSignal() in_queue = Queue(buffer_size) out_queue = Queue(buffer_size) - + out_order = [0] # define a worker to read samples from reader to in_queue def read_worker(reader, in_queue): for i in reader(): in_queue.put(i) in_queue.put(end) + + # define a worker to read samples from reader to in_queue with order flag + def order_read_worker(reader, in_queue): + in_order = 0 + for i in reader(): + in_queue.put((in_order,i)) + in_order+=1 + in_queue.put(end) # start a read worker in a thread - t = Thread(target=read_worker, args=(reader, in_queue)) + target = order_read_worker if order else read_worker + t = Thread(target=target, args=(reader, in_queue)) t.daemon = True t.start() @@ -270,12 +281,29 @@ def xmap_readers(mapper, reader, process_num, buffer_size): sample = in_queue.get() in_queue.put(end) out_queue.put(end) + + # define a worker to handle samples from in_queue by mapper + # and put mapped samples into out_queue by order + def order_handle_worker(in_queue, out_queue, mapper, out_order): + ins = in_queue.get() + while not isinstance(ins, XmapEndSignal): + order, sample = ins + r = mapper(sample) + while order != out_order[0]: + pass + out_queue.put(r) + out_order[0] += 1 + ins = in_queue.get() + in_queue.put(end) + out_queue.put(end) # start several handle_workers + target = order_handle_worker if order else handle_worker + args = (in_queue, out_queue, mapper, out_order) if order else (in_queue, out_queue, mapper) workers = [] for i in xrange(process_num): worker = Thread( - target=handle_worker, args=(in_queue, out_queue, mapper)) + target=target, args=args) worker.daemon = True workers.append(worker) for w in workers: diff --git a/python/paddle/v2/reader/tests/decorator_test.py b/python/paddle/v2/reader/tests/decorator_test.py index 734154b9790a4dc118d11992343648364c907305..76db91a44b8d75f1af7271d5f4f94ea21c51d943 100644 --- a/python/paddle/v2/reader/tests/decorator_test.py +++ b/python/paddle/v2/reader/tests/decorator_test.py @@ -120,6 +120,24 @@ class TestShuffle(unittest.TestCase): total += 1 self.assertEqual(total, 10) +class TestXmap(unittest.TestCase): + def test_xmap(self): + def mapper(x): + return (x + 1) + orders = (True, False) + thread_nums = (1, 2, 4, 8, 16) + buffered_size = (1, 2, 4, 8, 16) + for order in orders: + for tNum in thread_nums: + for size in buffered_size: + result = [] + for i in paddle.v2.reader.xmap_readers(mapper, reader_creator_10(), tNum, size, order)(): + result.append(i) + if not order: + result.sort() + for idx, e in enumerate(result): + self.assertEqual(e, mapper(idx)) + if __name__ == '__main__': unittest.main()