import itertools
import random

# Python 2 exposes the lazy variants as izip/izip_longest; Python 3 renamed
# them to the builtin zip and itertools.zip_longest.
try:
    _izip, _izip_longest = itertools.izip, itertools.izip_longest
except AttributeError:
    _izip, _izip_longest = zip, itertools.zip_longest


def shuffle(reader, buf_size):
    """Creates a data reader whose data output is shuffled.

    Output from the iterator created by the original reader is collected
    into a shuffle buffer. Each time the buffer holds buf_size entries it
    is shuffled and emitted, so entries are only shuffled within
    consecutive windows of buf_size entries. A buf_size of 0 or 1 flushes
    after every entry and therefore preserves the original order.

    Args:
        reader: the original reader whose output will be shuffled.
        buf_size: shuffle buffer size.

    Returns:
        the new reader whose output is shuffled.
    """

    def data_reader():
        buf = []
        for e in reader():
            buf.append(e)
            if len(buf) >= buf_size:
                random.shuffle(buf)
                for b in buf:
                    yield b
                buf = []

        # Flush the final, partially filled window.
        if buf:
            random.shuffle(buf)
            for b in buf:
                yield b

    return data_reader


def chain(*readers):
    """Creates a data reader whose output is the outputs of the input data
    readers chained together.

    If three input readers output the following data entries respectively:
    [0, 0, 0]
    [1, 1, 1]
    [2, 2, 2]
    The chained reader will output:
    [0, 0, 0, 1, 1, 1, 2, 2, 2]

    Args:
        readers: input readers.

    Returns:
        the new data reader.
    """

    def reader():
        rs = [r() for r in readers]
        for e in itertools.chain(*rs):
            yield e

    return reader


class ComposeNotAligned(ValueError):
    """Raised by a composed reader when its input readers differ in length."""
    pass


def compose(*readers, **kwargs):
    """Creates a data reader whose output is the combination of input readers.

    If the input readers output the following data entries respectively:
    (1, 2)    3    (4, 5)
    The composed reader will output:
    (1, 2, 3, 4, 5)

    Tuple entries are flattened one level; any other entry (including
    None) is treated as a single field.

    Args:
        *readers: readers that will be composed together.
        check_alignment: If True, will check if input readers are aligned
            correctly. If False, will not check alignment and trailing outputs
            will be discarded. Defaults to True.

    Returns:
        the new data reader.

    Raises:
        ComposeNotAligned: outputs of readers are not aligned.
            Will not raise when check_alignment is set to False.
    """
    check_alignment = kwargs.pop('check_alignment', True)

    def make_tuple(x):
        return x if isinstance(x, tuple) else (x, )

    def combine(outputs):
        return tuple(itertools.chain.from_iterable(map(make_tuple, outputs)))

    def reader():
        rs = [r() for r in readers]
        if not check_alignment:
            for outputs in _izip(*rs):
                yield combine(outputs)
        else:
            # A private sentinel marks exhausted readers, so that a reader
            # legitimately yielding None is not mistaken for misalignment.
            missing = object()
            for outputs in _izip_longest(*rs, fillvalue=missing):
                if any(o is missing for o in outputs):
                    raise ComposeNotAligned(
                        "outputs of readers are not aligned.")
                yield combine(outputs)

    return reader
import time


def reader_creator_10(dur):
    """Creates a reader that yields the integers 0 through 9.

    Args:
        dur: seconds to sleep before each entry is yielded.

    Returns:
        a reader; every call to it starts a fresh pass over 0..9.
    """

    def reader():
        value = 0
        while value < 10:
            # The pause emulates a slow data source, which is what the
            # buffered-reader tests rely on.
            time.sleep(dur)
            yield value
            value += 1

    return reader
class TestCompose(unittest.TestCase):
    """Tests for paddle.reader.compose."""

    def test_compose(self):
        reader = paddle.reader.compose(
            reader_creator_10(0), reader_creator_10(0))
        total = 0
        for idx, e in enumerate(reader()):
            self.assertEqual(e, (idx, idx))
            total += 1
        # Without this check the test would pass even if nothing was yielded.
        self.assertEqual(total, 10)

    def test_compose_not_aligned(self):
        total = 0
        reader = paddle.reader.compose(
            paddle.reader.chain(reader_creator_10(0), reader_creator_10(0)),
            reader_creator_10(0))
        with self.assertRaises(paddle.reader.ComposeNotAligned):
            for _ in reader():
                total += 1
        # expecting 10, not 20: the error is raised once the shorter
        # reader is exhausted
        self.assertEqual(total, 10)

    def test_compose_not_aligned_no_check(self):
        total = 0
        reader = paddle.reader.compose(
            paddle.reader.chain(reader_creator_10(0), reader_creator_10(0)),
            reader_creator_10(0),
            check_alignment=False)
        for _ in reader():
            total += 1
        # expecting 10, not 20: trailing outputs are discarded
        self.assertEqual(total, 10)


class TestChain(unittest.TestCase):
    """Tests for paddle.reader.chain."""

    def test_chain(self):
        c = paddle.reader.chain(reader_creator_10(0), reader_creator_10(0))
        idx = 0
        for e in c():
            self.assertEqual(e, idx % 10)
            idx += 1
        self.assertEqual(idx, 20)


class TestShuffle(unittest.TestCase):
    """Tests for paddle.reader.shuffle."""

    def test_shuffle(self):
        # (buffer size, whether the original order must be preserved)
        case = [(0, True), (1, True), (10, False), (100, False)]
        a = reader_creator_10(0)
        for size, checkEq in case:
            s = paddle.reader.shuffle(a, size)
            out = list(s())
            if checkEq:
                self.assertEqual(out, list(range(10)))
            # Shuffling may reorder entries but must never drop or
            # duplicate them.
            self.assertEqual(len(out), 10)
            self.assertEqual(sorted(out), list(range(10)))


if __name__ == '__main__':
    unittest.main()