提交 198689c3 编写于 作者: Q qiaolongfei

add a fake reader for speed test

上级 8cd17c04
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
__all__ = [ __all__ = [
'map_readers', 'buffered', 'compose', 'chain', 'shuffle', 'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
'ComposeNotAligned', 'firstn', 'xmap_readers', 'PipeReader', 'ComposeNotAligned', 'firstn', 'xmap_readers', 'PipeReader',
'multiprocess_reader' 'multiprocess_reader', 'fake'
] ]
from threading import Thread from threading import Thread
...@@ -504,3 +504,26 @@ class PipeReader: ...@@ -504,3 +504,26 @@ class PipeReader:
yield decomp_buff yield decomp_buff
else: else:
break break
def fake(reader, data_num):
"""
fake reader will cache the first data it read and yield it out for data_num times.
It is used to cache a data from real reader and use it for speed testing.
:param reader: the origin reader
:param data_num: times that this reader will yield data.
:return: a fake reader.
"""
def fake_reader():
if fake_reader.data is None:
fake_reader.data = reader().next()
while fake_reader.yield_num < data_num:
yield fake_reader.data
fake_reader.yield_num += 1
fake_reader.data = None
fake_reader.yield_num = 0
return fake_reader
...@@ -203,5 +203,20 @@ class TestMultiProcessReader(unittest.TestCase): ...@@ -203,5 +203,20 @@ class TestMultiProcessReader(unittest.TestCase):
self.reader_test(use_pipe=True) self.reader_test(use_pipe=True)
class TestFakeReader(unittest.TestCase):
def test_fake_reader(self):
def reader():
for i in range(10):
yield i
data_num = 100
fake_reader = paddle.reader.fake(reader, data_num)
i = 0
for data in fake_reader():
self.assertEqual(data, 0)
i += 1
self.assertEqual(i, data_num)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册