提交 ec38effc 编写于 作者: Q qiaolongfei

optimize fake, change it to a class instead a function. test=develop

上级 ebd9f7fc
...@@ -14,4 +14,4 @@ if(WITH_INFERENCE) ...@@ -14,4 +14,4 @@ if(WITH_INFERENCE)
add_subdirectory(inference) add_subdirectory(inference)
endif() endif()
add_subdirectory(train) #add_subdirectory(train)
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
__all__ = [ __all__ = [
'map_readers', 'buffered', 'compose', 'chain', 'shuffle', 'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
'ComposeNotAligned', 'firstn', 'xmap_readers', 'PipeReader', 'ComposeNotAligned', 'firstn', 'xmap_readers', 'PipeReader',
'multiprocess_reader', 'fake' 'multiprocess_reader', 'Fake'
] ]
from threading import Thread from threading import Thread
...@@ -506,25 +506,37 @@ class PipeReader: ...@@ -506,25 +506,37 @@ class PipeReader:
break break
def fake(reader, data_num): class Fake(object):
""" """
fake reader will cache the first data it read and yield it out for data_num times. fake reader will cache the first data it read and yield it out for data_num times.
It is used to cache a data from real reader and use it for speed testing. It is used to cache a data from real reader and use it for speed testing.
:param reader: the origin reader :param reader: the origin reader
:param data_num: times that this reader will yield data. :param data_num: times that this reader will yield data.
:return: a fake reader. :return: a fake reader.
Examples:
.. code-block:: python
def reader():
for i in range(10):
yield i
fake_reader = Fake()(reader, 100)
""" """
def __init__(self):
self.data = None
self.yield_num = 0
def __call__(self, reader, data_num):
def fake_reader(): def fake_reader():
if fake_reader.data is None: if self.data is None:
fake_reader.data = reader().next() self.data = reader().next()
while fake_reader.yield_num < data_num: while self.yield_num < data_num:
yield fake_reader.data yield self.data
fake_reader.yield_num += 1 self.yield_num += 1
fake_reader.yield_num = 0 self.yield_num = 0
fake_reader.data = None
fake_reader.yield_num = 0
return fake_reader return fake_reader
...@@ -210,7 +210,7 @@ class TestFakeReader(unittest.TestCase): ...@@ -210,7 +210,7 @@ class TestFakeReader(unittest.TestCase):
yield i yield i
data_num = 100 data_num = 100
fake_reader = paddle.reader.fake(reader, data_num) fake_reader = paddle.reader.Fake()(reader, data_num)
for _ in range(10): for _ in range(10):
i = 0 i = 0
for data in fake_reader(): for data in fake_reader():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册