From ec38effccec853983c5152484585661d80b95564 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 8 Oct 2018 17:21:23 +0800 Subject: [PATCH] optimize fake, change it to a class instead a function. test=develop --- paddle/fluid/CMakeLists.txt | 2 +- python/paddle/reader/decorator.py | 36 +++++++++++++------- python/paddle/reader/tests/decorator_test.py | 2 +- 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/CMakeLists.txt b/paddle/fluid/CMakeLists.txt index 519a00fb073..6e3411f7a28 100644 --- a/paddle/fluid/CMakeLists.txt +++ b/paddle/fluid/CMakeLists.txt @@ -14,4 +14,4 @@ if(WITH_INFERENCE) add_subdirectory(inference) endif() -add_subdirectory(train) +#add_subdirectory(train) diff --git a/python/paddle/reader/decorator.py b/python/paddle/reader/decorator.py index 7b73a3a930c..2c1ae57472f 100644 --- a/python/paddle/reader/decorator.py +++ b/python/paddle/reader/decorator.py @@ -15,7 +15,7 @@ __all__ = [ 'map_readers', 'buffered', 'compose', 'chain', 'shuffle', 'ComposeNotAligned', 'firstn', 'xmap_readers', 'PipeReader', - 'multiprocess_reader', 'fake' + 'multiprocess_reader', 'Fake' ] from threading import Thread @@ -506,25 +506,37 @@ class PipeReader: break -def fake(reader, data_num): +class Fake(object): """ fake reader will cache the first data it read and yield it out for data_num times. It is used to cache a data from real reader and use it for speed testing. :param reader: the origin reader :param data_num: times that this reader will yield data. + :return: a fake reader. + + Examples: + .. code-block:: python + + def reader(): + for i in range(10): + yield i + + fake_reader = Fake()(reader, 100) """ - def fake_reader(): - if fake_reader.data is None: - fake_reader.data = reader().next() - while fake_reader.yield_num < data_num: - yield fake_reader.data - fake_reader.yield_num += 1 - fake_reader.yield_num = 0 + def __init__(self): + self.data = None + self.yield_num = 0 - fake_reader.data = None - fake_reader.yield_num = 0 + def __call__(self, reader, data_num): + def fake_reader(): + if self.data is None: + self.data = reader().next() + while self.yield_num < data_num: + yield self.data + self.yield_num += 1 + self.yield_num = 0 - return fake_reader + return fake_reader diff --git a/python/paddle/reader/tests/decorator_test.py b/python/paddle/reader/tests/decorator_test.py index e57f9cc29df..b9af8348e16 100644 --- a/python/paddle/reader/tests/decorator_test.py +++ b/python/paddle/reader/tests/decorator_test.py @@ -210,7 +210,7 @@ class TestFakeReader(unittest.TestCase): yield i data_num = 100 - fake_reader = paddle.reader.fake(reader, data_num) + fake_reader = paddle.reader.Fake()(reader, data_num) for _ in range(10): i = 0 for data in fake_reader(): -- GitLab