def fake(reader, data_num):
    """
    Create a fake reader that caches the first sample read from ``reader``
    and yields that same sample ``data_num`` times on every call.

    It is used to take one sample from a real reader and replay it for
    speed testing, so that the cost of the real data pipeline is excluded.

    The returned reader can be called repeatedly (for example once per
    pass); each call yields the cached sample ``data_num`` times. The
    origin reader is only consumed until its first sample has been cached.
    If the origin reader produces no data at all, the fake reader yields
    nothing.

    :param reader: the origin reader, a callable returning an iterator.
    :param data_num: times that this reader will yield data per call.
    :return: a fake reader.
    """
    # Private sentinel: distinguishes "origin reader is empty" from any
    # real sample value the origin reader might produce.
    empty = object()

    def fake_reader():
        if fake_reader.data is None:
            # Builtin next() works on both Python 2 and 3 (the iterator
            # method ``.next()`` only exists on Python 2). The default
            # keeps StopIteration from leaking out of this generator,
            # which would be turned into RuntimeError under PEP 479.
            first = next(reader(), empty)
            if first is empty:
                return
            fake_reader.data = first

        sample = fake_reader.data
        # The counter is local to each call so the fake reader can be
        # iterated more than once.
        yielded = 0
        while yielded < data_num:
            yield sample
            yielded += 1

    # Cache for the first sample, shared by all calls of fake_reader.
    fake_reader.data = None

    return fake_reader