diff --git a/paddle/fluid/CMakeLists.txt b/paddle/fluid/CMakeLists.txt index 519a00fb073b08f6c88de8186de187476b548fd3..6e3411f7a2861ea050ebe09de14abd054ca2bb12 100644 --- a/paddle/fluid/CMakeLists.txt +++ b/paddle/fluid/CMakeLists.txt @@ -14,4 +14,4 @@ if(WITH_INFERENCE) add_subdirectory(inference) endif() -add_subdirectory(train) +#add_subdirectory(train) diff --git a/python/paddle/reader/decorator.py b/python/paddle/reader/decorator.py index 7b73a3a930cc2c53a54bf06bb16007e01ed5c9c3..2c1ae57472faada9634bd41f66b0254cb395d9d0 100644 --- a/python/paddle/reader/decorator.py +++ b/python/paddle/reader/decorator.py @@ -15,7 +15,7 @@ __all__ = [ 'map_readers', 'buffered', 'compose', 'chain', 'shuffle', 'ComposeNotAligned', 'firstn', 'xmap_readers', 'PipeReader', - 'multiprocess_reader', 'fake' + 'multiprocess_reader', 'Fake' ] from threading import Thread @@ -506,25 +506,37 @@ class PipeReader: break -def fake(reader, data_num): +class Fake(object): """ fake reader will cache the first data it read and yield it out for data_num times. It is used to cache a data from real reader and use it for speed testing. :param reader: the origin reader :param data_num: times that this reader will yield data. + :return: a fake reader. + + Examples: + .. code-block:: python + + def reader(): + for i in range(10): + yield i + + fake_reader = Fake()(reader, 100) """ - def fake_reader(): - if fake_reader.data is None: - fake_reader.data = reader().next() - while fake_reader.yield_num < data_num: - yield fake_reader.data - fake_reader.yield_num += 1 - fake_reader.yield_num = 0 + def __init__(self): + self.data = None + self.yield_num = 0 - fake_reader.data = None - fake_reader.yield_num = 0 + def __call__(self, reader, data_num): + def fake_reader(): + if self.data is None: + self.data = reader().next() + while self.yield_num < data_num: + yield self.data + self.yield_num += 1 + self.yield_num = 0 - return fake_reader + return fake_reader diff --git a/python/paddle/reader/tests/decorator_test.py b/python/paddle/reader/tests/decorator_test.py index e57f9cc29dfe8cf925a418ee26d57afe4746b195..b9af8348e16c051db64d57a9594aee303d83aef2 100644 --- a/python/paddle/reader/tests/decorator_test.py +++ b/python/paddle/reader/tests/decorator_test.py @@ -210,7 +210,7 @@ class TestFakeReader(unittest.TestCase): yield i data_num = 100 - fake_reader = paddle.reader.fake(reader, data_num) + fake_reader = paddle.reader.Fake()(reader, data_num) for _ in range(10): i = 0 for data in fake_reader():