From 466935b4ecdc57e314ebec76a647d70c67f4494b Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Wed, 22 Feb 2017 15:57:24 -0800 Subject: [PATCH] add decorator: map_readers --- python/paddle/reader/decorator.py | 25 +++++++++++++++++++- python/paddle/reader/tests/decorator_test.py | 16 +++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/python/paddle/reader/decorator.py b/python/paddle/reader/decorator.py index d656d5feb4..9f4234358f 100644 --- a/python/paddle/reader/decorator.py +++ b/python/paddle/reader/decorator.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ['buffered', 'compose', 'chain', 'shuffle', 'ComposeNotAligned'] +__all__ = [ + 'map_readers', 'buffered', 'compose', 'chain', 'shuffle', + 'ComposeNotAligned' +] from Queue import Queue from threading import Thread @@ -20,6 +23,26 @@ import itertools import random +def map_readers(func, *readers): + """ + Creates a data reader that outputs return value of function using + output of each data readers as arguments. + + :param func: function to use. + :param *readers: readers whose outputs will be used as arguments of func. + :returns: the created data reader. + """ + + def reader(): + rs = [] + for r in readers: + rs.append(r()) + for e in itertools.imap(func, *rs): + yield e + + return reader + + def shuffle(reader, buf_size): """ Creates a data reader whose data output is suffled. diff --git a/python/paddle/reader/tests/decorator_test.py b/python/paddle/reader/tests/decorator_test.py index 46eec44158..0396a61786 100644 --- a/python/paddle/reader/tests/decorator_test.py +++ b/python/paddle/reader/tests/decorator_test.py @@ -26,6 +26,22 @@ def reader_creator_10(dur): return reader +class TestMap(unittest.TestCase): + def test_map(self): + d = {"h": 0, "i": 1} + + def tokenize(x): + return d[x] + + def read(): + yield "h" + yield "i" + + r = paddle.reader.map_readers(tokenize, read) + for i, e in enumerate(r()): + self.assertEqual(e, i) + + class TestBuffered(unittest.TestCase): def test_read(self): for size in range(20): -- GitLab