From 2a21681c00b740627d6f5c37b89550f2990704f4 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Tue, 21 Feb 2017 16:21:57 -0800 Subject: [PATCH] fix according to comments --- python/paddle/reader/decorator.py | 77 ++++++++++---------- python/paddle/reader/tests/decorator_test.py | 37 +++++----- 2 files changed, 56 insertions(+), 58 deletions(-) diff --git a/python/paddle/reader/decorator.py b/python/paddle/reader/decorator.py index dcfaf705870..1192b5de90f 100644 --- a/python/paddle/reader/decorator.py +++ b/python/paddle/reader/decorator.py @@ -20,25 +20,25 @@ import itertools import random -def shuffle(reader_creator, buf_size): - """Creates a data reader creator whose data output is suffled. +def shuffle(reader, buf_size): + """Creates a data reader whose data output is suffled. - Output from the iterator that created by original reader creator will be + Output from the iterator that created by original reader will be buffered into shuffle buffer, and then shuffled. The size of shuffle buffer is determined by argument buf_size. Args: - reader_creator: the original reader creator whose output will be + reader: the original reader whose output will be shuffled. buf_size: shuffle buffer size. Returns: - the new reader creator whose output is shuffled. + the new reader whose output is shuffled. """ - def create_reader_creator(): + def data_reader(): buf = [] - for e in reader_creator(): + for e in reader(): buf.append(e) if len(buf) >= buf_size: random.shuffle(buf) @@ -51,62 +51,61 @@ def shuffle(reader_creator, buf_size): for b in buf: yield b - return create_reader_creator + return data_reader -def chain(*reader_creators): - """Creates a data reader creator whose output is the outputs of input data - reader creators chained together. +def chain(*readers): + """Creates a data reader whose output is the outputs of input data + readers chained together. - If input reader creators output following data entries: + If input readers output following data entries: [0, 0, 0] [1, 1, 1] [2, 2, 2] - The chained reader creator will output: + The chained reader will output: [0, 0, 0, 1, 1, 1, 2, 2, 2] Args: - readers_creators: input reader creators + readerss: input readers. Returns: - the new data reader creator. + the new data reader. """ - def create_reader_creator(): + def reader(): rs = [] - for r in reader_creators: + for r in readers: rs.append(r()) for e in itertools.chain(*rs): yield e - return create_reader_creator + return reader class ComposeNotAligned: pass -def compose(*reader_creators, **kwargs): - """Creates a data reader creator whose output is the combination of input - readers creators. +def compose(*readers, **kwargs): + """Creates a data reader whose output is the combination of input readers. - If input reader creators output following data entries: + If input readers output following data entries: (1, 2) 3 (4, 5) - The composed reader creator will output: + The composed reader will output: (1, 2, 3, 4, 5) Args: - *reader_creators: reader creators that will be composed together. - check_alignment: If True, will check if input reader creators are aligned + *readers: readers that will be composed together. + check_alignment: If True, will check if input readers are aligned correctly. If False, will not check alignment and trailing outputs will be discarded. Defaults to True. Returns: - the new data reader creator. + the new data reader. Raises: - ComposeNotAligned: outputs of reader creators are not aligned. + ComposeNotAligned: outputs of readers are not aligned. Will not raise when check_alignment is set to False. """ check_alignment = kwargs.pop('check_alignment', True) @@ -117,9 +116,9 @@ def compose(*reader_creators, **kwargs): else: return (x, ) - def create_reader_creator(): + def reader(): rs = [] - for r in reader_creators: + for r in readers: rs.append(r()) if not check_alignment: for outputs in itertools.izip(*rs): @@ -132,22 +131,22 @@ def compose(*reader_creators, **kwargs): raise ComposeNotAligned yield sum(map(make_tuple, outputs), ()) - return create_reader_creator + return reader -def buffered(reader_creator, size): - """Creates a buffered data reader creator. +def buffered(reader, size): + """Creates a buffered data reader. - The buffered data reader creator will read and save data entries into a - buffer. Reading from the buffered data reader creator will proceed as long + The buffered data reader will read and save data entries into a + buffer. Reading from the buffered data reader will proceed as long as the buffer is not empty. Args: - reader_creator: the data reader creator to read from. + reader: the data reader to read from. size: max buffer size. Returns: - The buffered data reader creator. + The buffered data reader. """ class EndSignal(): @@ -160,8 +159,8 @@ def buffered(reader_creator, size): q.put(d) q.put(end) - def create_reader_creator(): - r = reader_creator() + def data_reader(): + r = reader() q = Queue(maxsize=size) t = Thread( target=read_worker, args=( @@ -174,4 +173,4 @@ def buffered(reader_creator, size): yield e e = q.get() - return create_reader_creator + return data_reader diff --git a/python/paddle/reader/tests/decorator_test.py b/python/paddle/reader/tests/decorator_test.py index 2830d41bf0a..46eec44158c 100644 --- a/python/paddle/reader/tests/decorator_test.py +++ b/python/paddle/reader/tests/decorator_test.py @@ -16,9 +16,10 @@ import paddle.reader import time -def reader_10(dur): +def reader_creator_10(dur): def reader(): for i in range(10): + # this invocation helps testing paddle.reader.buffer time.sleep(dur) yield i @@ -28,7 +29,7 @@ def reader_10(dur): class TestBuffered(unittest.TestCase): def test_read(self): for size in range(20): - b = paddle.reader.buffered(reader_10(0), size) + b = paddle.reader.buffered(reader_creator_10(0), size) c = 0 for i in b(): self.assertEqual(i, c) @@ -37,7 +38,7 @@ class TestBuffered(unittest.TestCase): def test_buffering(self): # read have 30ms delay. - b = paddle.reader.buffered(reader_10(0.03), 10) + b = paddle.reader.buffered(reader_creator_10(0.03), 10) last_time = time.time() for idx, i in enumerate(b()): elapsed_time = time.time() - last_time @@ -51,29 +52,29 @@ class TestBuffered(unittest.TestCase): class TestCompose(unittest.TestCase): def test_compse(self): - a = reader_10(0) - b = reader_10(0) - c = paddle.reader.compose(a, b) - for idx, e in enumerate(c()): + reader = paddle.reader.compose( + reader_creator_10(0), reader_creator_10(0)) + for idx, e in enumerate(reader()): self.assertEqual(e, (idx, idx)) def test_compose_not_aligned(self): - a = reader_10(0) - b = paddle.reader.chain(a, a) - c = paddle.reader.compose(a, b) total = 0 + reader = paddle.reader.compose( + paddle.reader.chain(reader_creator_10(0), reader_creator_10(0)), + reader_creator_10(0)) with self.assertRaises(paddle.reader.ComposeNotAligned): - for e in c(): + for e in reader(): total += 1 # expecting 10, not 20 self.assertEqual(total, 10) def test_compose_not_aligned_no_check(self): - a = reader_10(0) - b = paddle.reader.chain(a, a) - c = paddle.reader.compose(a, b, check_alignment=False) total = 0 - for e in c(): + reader = paddle.reader.compose( + paddle.reader.chain(reader_creator_10(0), reader_creator_10(0)), + reader_creator_10(0), + check_alignment=False) + for e in reader(): total += 1 # expecting 10, not 20 self.assertEqual(total, 10) @@ -81,9 +82,7 @@ class TestCompose(unittest.TestCase): class TestChain(unittest.TestCase): def test_chain(self): - a = reader_10(0) - b = reader_10(0) - c = paddle.reader.chain(a, b) + c = paddle.reader.chain(reader_creator_10(0), reader_creator_10(0)) idx = 0 for e in c(): self.assertEqual(e, idx % 10) @@ -94,7 +93,7 @@ class TestChain(unittest.TestCase): class TestShuffle(unittest.TestCase): def test_shuffle(self): case = [(0, True), (1, True), (10, False), (100, False)] - a = reader_10(0) + a = reader_creator_10(0) for size, checkEq in case: s = paddle.reader.shuffle(a, size) total = 0 -- GitLab