From 7fd42b8cce0585d2ff276f2cbd075a6ec5a502d0 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Tue, 21 Feb 2017 14:32:36 -0800 Subject: [PATCH] create reader creator decorators: shuffle, compose, chain --- python/paddle/reader/decorator.py | 139 +++++++++++++++++-- python/paddle/reader/tests/decorator_test.py | 71 +++++++++- 2 files changed, 193 insertions(+), 17 deletions(-) diff --git a/python/paddle/reader/decorator.py b/python/paddle/reader/decorator.py index f0ddb0ff812..dcfaf705870 100644 --- a/python/paddle/reader/decorator.py +++ b/python/paddle/reader/decorator.py @@ -12,25 +12,142 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ['buffered'] +__all__ = ['buffered', 'compose', 'chain', 'shuffle', 'ComposeNotAligned'] from Queue import Queue from threading import Thread +import itertools +import random -def buffered(reader, size): - """Creates a buffered data reader. +def shuffle(reader_creator, buf_size): + """Creates a data reader creator whose data output is suffled. - The buffered data reader will read and save data entries into a buffer. - Reading from the buffered data reader will proceed as long as the buffer - is not empty. + Output from the iterator that created by original reader creator will be + buffered into shuffle buffer, and then shuffled. The size of shuffle buffer + is determined by argument buf_size. + + Args: + reader_creator: the original reader creator whose output will be + shuffled. + buf_size: shuffle buffer size. + + Returns: + the new reader creator whose output is shuffled. + """ + + def create_reader_creator(): + buf = [] + for e in reader_creator(): + buf.append(e) + if len(buf) >= buf_size: + random.shuffle(buf) + for b in buf: + yield b + buf = [] + + if len(buf) > 0: + random.shuffle(buf) + for b in buf: + yield b + + return create_reader_creator + + +def chain(*reader_creators): + """Creates a data reader creator whose output is the outputs of input data + reader creators chained together. + + If input reader creators output following data entries: + [0, 0, 0] + [1, 1, 1] + [2, 2, 2] + The chained reader creator will output: + [0, 0, 0, 1, 1, 1, 2, 2, 2] + + Args: + readers_creators: input reader creators + + Returns: + the new data reader creator. + """ + + def create_reader_creator(): + rs = [] + for r in reader_creators: + rs.append(r()) + + for e in itertools.chain(*rs): + yield e + + return create_reader_creator + + +class ComposeNotAligned: + pass + + +def compose(*reader_creators, **kwargs): + """Creates a data reader creator whose output is the combination of input + readers creators. + + If input reader creators output following data entries: + (1, 2) 3 (4, 5) + The composed reader creator will output: + (1, 2, 3, 4, 5) + + Args: + *reader_creators: reader creators that will be composed together. + check_alignment: If True, will check if input reader creators are aligned + correctly. If False, will not check alignment and trailing outputs + will be discarded. Defaults to True. + + Returns: + the new data reader creator. + + Raises: + ComposeNotAligned: outputs of reader creators are not aligned. + Will not raise when check_alignment is set to False. + """ + check_alignment = kwargs.pop('check_alignment', True) + + def make_tuple(x): + if isinstance(x, tuple): + return x + else: + return (x, ) + + def create_reader_creator(): + rs = [] + for r in reader_creators: + rs.append(r()) + if not check_alignment: + for outputs in itertools.izip(*rs): + yield sum(map(make_tuple, outputs), ()) + else: + for outputs in itertools.izip_longest(*rs): + for o in outputs: + if o is None: + # None will be not be present if compose is aligned + raise ComposeNotAligned + yield sum(map(make_tuple, outputs), ()) + + return create_reader_creator + + +def buffered(reader_creator, size): + """Creates a buffered data reader creator. + + The buffered data reader creator will read and save data entries into a + buffer. Reading from the buffered data reader creator will proceed as long + as the buffer is not empty. Args: - reader: the data reader to read from. + reader_creator: the data reader creator to read from. size: max buffer size. Returns: - The buffered data reader. + The buffered data reader creator. """ class EndSignal(): @@ -43,8 +160,8 @@ def buffered(reader, size): q.put(d) q.put(end) - def create_reader(): - r = reader() + def create_reader_creator(): + r = reader_creator() q = Queue(maxsize=size) t = Thread( target=read_worker, args=( @@ -57,4 +174,4 @@ def buffered(reader, size): yield e e = q.get() - return create_reader + return create_reader_creator diff --git a/python/paddle/reader/tests/decorator_test.py b/python/paddle/reader/tests/decorator_test.py index 879d1d9c1d0..2830d41bf0a 100644 --- a/python/paddle/reader/tests/decorator_test.py +++ b/python/paddle/reader/tests/decorator_test.py @@ -17,15 +17,18 @@ import time def reader_10(dur): - for i in range(10): - time.sleep(dur) - yield i + def reader(): + for i in range(10): + time.sleep(dur) + yield i + + return reader class TestBuffered(unittest.TestCase): def test_read(self): for size in range(20): - b = paddle.reader.buffered(lambda: reader_10(0), size) + b = paddle.reader.buffered(reader_10(0), size) c = 0 for i in b(): self.assertEqual(i, c) @@ -34,7 +37,7 @@ class TestBuffered(unittest.TestCase): def test_buffering(self): # read have 30ms delay. - b = paddle.reader.buffered(lambda: reader_10(0.03), 10) + b = paddle.reader.buffered(reader_10(0.03), 10) last_time = time.time() for idx, i in enumerate(b()): elapsed_time = time.time() - last_time @@ -42,9 +45,65 @@ class TestBuffered(unittest.TestCase): time.sleep(0.3) else: # read time should be short, meaning already buffered. - self.assertLess(elapsed_time, 0.01) + self.assertLess(elapsed_time, 0.05) last_time = time.time() +class TestCompose(unittest.TestCase): + def test_compse(self): + a = reader_10(0) + b = reader_10(0) + c = paddle.reader.compose(a, b) + for idx, e in enumerate(c()): + self.assertEqual(e, (idx, idx)) + + def test_compose_not_aligned(self): + a = reader_10(0) + b = paddle.reader.chain(a, a) + c = paddle.reader.compose(a, b) + total = 0 + with self.assertRaises(paddle.reader.ComposeNotAligned): + for e in c(): + total += 1 + # expecting 10, not 20 + self.assertEqual(total, 10) + + def test_compose_not_aligned_no_check(self): + a = reader_10(0) + b = paddle.reader.chain(a, a) + c = paddle.reader.compose(a, b, check_alignment=False) + total = 0 + for e in c(): + total += 1 + # expecting 10, not 20 + self.assertEqual(total, 10) + + +class TestChain(unittest.TestCase): + def test_chain(self): + a = reader_10(0) + b = reader_10(0) + c = paddle.reader.chain(a, b) + idx = 0 + for e in c(): + self.assertEqual(e, idx % 10) + idx += 1 + self.assertEqual(idx, 20) + + +class TestShuffle(unittest.TestCase): + def test_shuffle(self): + case = [(0, True), (1, True), (10, False), (100, False)] + a = reader_10(0) + for size, checkEq in case: + s = paddle.reader.shuffle(a, size) + total = 0 + for idx, e in enumerate(s()): + if checkEq: + self.assertEqual(idx, e) + total += 1 + self.assertEqual(total, 10) + + if __name__ == '__main__': unittest.main() -- GitLab