From 198689c3bb65a65db0d1a7d22e7906845e7ea3b3 Mon Sep 17 00:00:00 2001
From: qiaolongfei
Date: Fri, 5 Oct 2018 23:33:56 +0800
Subject: [PATCH] add a fake reader for speed test

---
 python/paddle/reader/decorator.py            | 34 +++++++++++++++++++-
 python/paddle/reader/tests/decorator_test.py | 17 ++++++++++
 2 files changed, 50 insertions(+), 1 deletion(-)

diff --git a/python/paddle/reader/decorator.py b/python/paddle/reader/decorator.py
index 5b9459b670a..e06c151d309 100644
--- a/python/paddle/reader/decorator.py
+++ b/python/paddle/reader/decorator.py
@@ -15,7 +15,7 @@
 __all__ = [
     'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
     'ComposeNotAligned', 'firstn', 'xmap_readers', 'PipeReader',
-    'multiprocess_reader'
+    'multiprocess_reader', 'fake'
 ]
 
 from threading import Thread
@@ -504,3 +504,35 @@ class PipeReader:
                     yield decomp_buff
             else:
                 break
+
+
+def fake(reader, data_num):
+    """
+    Create a reader that caches the first sample of ``reader`` and yields
+    that same sample ``data_num`` times on every pass.
+
+    It is meant for speed testing: the cost of the real reader is paid only
+    once, so the rest of the pipeline can be measured on its own.
+
+    :param reader: the original reader creator.
+    :param data_num: number of times the cached sample is yielded per pass.
+    :return: a fake reader. If ``reader`` yields nothing, so does the fake
+        reader.
+    """
+
+    def fake_reader():
+        if fake_reader.data is None:
+            try:
+                # next(it) works on both Python 2 and 3; it.next() is
+                # Python 2 only.
+                fake_reader.data = next(reader())
+            except StopIteration:
+                # An empty source reader gives an empty fake reader.
+                return
+        # Count per pass so the fake reader can be iterated repeatedly.
+        for _ in range(data_num):
+            yield fake_reader.data
+
+    fake_reader.data = None
+
+    return fake_reader
diff --git a/python/paddle/reader/tests/decorator_test.py b/python/paddle/reader/tests/decorator_test.py
index c324092f885..cd585403fb1 100644
--- a/python/paddle/reader/tests/decorator_test.py
+++ b/python/paddle/reader/tests/decorator_test.py
@@ -203,5 +203,22 @@ class TestMultiProcessReader(unittest.TestCase):
         self.reader_test(use_pipe=True)
 
 
+class TestFakeReader(unittest.TestCase):
+    def test_fake_reader(self):
+        def reader():
+            for i in range(10):
+                yield i
+
+        data_num = 100
+        fake_reader = paddle.reader.fake(reader, data_num)
+        # Every pass must replay the first sample exactly data_num times.
+        for _ in range(2):
+            i = 0
+            for data in fake_reader():
+                self.assertEqual(data, 0)
+                i += 1
+            self.assertEqual(i, data_num)
+
+
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab