提交 7b5a9d75 编写于 作者: S sneaxiy

add cache reader

test=develop
上级 69b1ebdf
...@@ -511,6 +511,7 @@ paddle.fluid.unique_name.guard ArgSpec(args=['new_generator'], varargs=None, key ...@@ -511,6 +511,7 @@ paddle.fluid.unique_name.guard ArgSpec(args=['new_generator'], varargs=None, key
paddle.fluid.recordio_writer.convert_reader_to_recordio_file ArgSpec(args=['filename', 'reader_creator', 'feeder', 'compressor', 'max_num_records', 'feed_order'], varargs=None, keywords=None, defaults=(Compressor.Snappy, 1000, None)) paddle.fluid.recordio_writer.convert_reader_to_recordio_file ArgSpec(args=['filename', 'reader_creator', 'feeder', 'compressor', 'max_num_records', 'feed_order'], varargs=None, keywords=None, defaults=(Compressor.Snappy, 1000, None))
paddle.fluid.recordio_writer.convert_reader_to_recordio_files ArgSpec(args=['filename', 'batch_per_file', 'reader_creator', 'feeder', 'compressor', 'max_num_records', 'feed_order'], varargs=None, keywords=None, defaults=(Compressor.Snappy, 1000, None)) paddle.fluid.recordio_writer.convert_reader_to_recordio_files ArgSpec(args=['filename', 'batch_per_file', 'reader_creator', 'feeder', 'compressor', 'max_num_records', 'feed_order'], varargs=None, keywords=None, defaults=(Compressor.Snappy, 1000, None))
paddle.fluid.Scope Scope() -> paddle.fluid.core._Scope paddle.fluid.Scope Scope() -> paddle.fluid.core._Scope
paddle.reader.cache ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None)
paddle.reader.map_readers ArgSpec(args=['func'], varargs='readers', keywords=None, defaults=None) paddle.reader.map_readers ArgSpec(args=['func'], varargs='readers', keywords=None, defaults=None)
paddle.reader.buffered ArgSpec(args=['reader', 'size'], varargs=None, keywords=None, defaults=None) paddle.reader.buffered ArgSpec(args=['reader', 'size'], varargs=None, keywords=None, defaults=None)
paddle.reader.compose ArgSpec(args=[], varargs='readers', keywords='kwargs', defaults=None) paddle.reader.compose ArgSpec(args=[], varargs='readers', keywords='kwargs', defaults=None)
......
...@@ -123,7 +123,6 @@ class PyReader(object): ...@@ -123,7 +123,6 @@ class PyReader(object):
self._use_double_buffer = use_double_buffer self._use_double_buffer = use_double_buffer
self._capacity = capacity self._capacity = capacity
self._feed_list = feed_list self._feed_list = feed_list
self._scope = global_scope()
if not self._iterable: if not self._iterable:
self._init_non_iterable() self._init_non_iterable()
...@@ -153,7 +152,7 @@ class PyReader(object): ...@@ -153,7 +152,7 @@ class PyReader(object):
reader_name = PyReader.unique_name_generator('create_py_reader') reader_name = PyReader.unique_name_generator('create_py_reader')
double_buffer_name = PyReader.unique_name_generator('double_buffer') double_buffer_name = PyReader.unique_name_generator('double_buffer')
var = self._scope.var(queue_name) var = global_scope().var(queue_name)
self._queue = core.init_lod_tensor_blocking_queue(var, self._capacity) self._queue = core.init_lod_tensor_blocking_queue(var, self._capacity)
startup_blk = default_startup_program().current_block() startup_blk = default_startup_program().current_block()
...@@ -215,6 +214,9 @@ class PyReader(object): ...@@ -215,6 +214,9 @@ class PyReader(object):
def __iter__(self): def __iter__(self):
return self return self
def __next__(self):
return self.next()
def next(self): def next(self):
ret = self._reader.read_next() ret = self._reader.read_next()
if ret: if ret:
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
__all__ = [ __all__ = [
'map_readers', 'buffered', 'compose', 'chain', 'shuffle', 'cache', 'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
'ComposeNotAligned', 'firstn', 'xmap_readers', 'PipeReader', 'ComposeNotAligned', 'firstn', 'xmap_readers', 'PipeReader',
'multiprocess_reader', 'Fake' 'multiprocess_reader', 'Fake'
] ]
...@@ -33,6 +33,31 @@ import zlib ...@@ -33,6 +33,31 @@ import zlib
import paddle.compat as cpt import paddle.compat as cpt
def cache(reader):
    """
    Cache the reader data into memory.

    Be careful that this method may take a long time to process
    and consume lots of memory. :code:`reader()` is called only
    once, at the moment :code:`cache` itself is called; every later
    pass over the returned reader replays the stored items.

    Args:
        reader (callable): a reader creator; calling it returns an
            iterable which yields one data item each time.

    Returns:
        callable: a decorated reader creator. Each call returns a
            fresh generator which yields the data held in memory.
    """
    # Drain the source reader eagerly and exactly once. A tuple keeps the
    # cached sequence immutable, so every pass sees the same items.
    all_data = tuple(reader())

    def __impl__():
        for item in all_data:
            yield item

    return __impl__
def map_readers(func, *readers): def map_readers(func, *readers):
""" """
Creates a data reader that outputs return value of function using Creates a data reader that outputs return value of function using
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册