提交 466935b4 编写于 作者: H Helin Wang

add decorator: map_readers

上级 963de16b
...@@ -12,7 +12,10 @@ ...@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
__all__ = ['buffered', 'compose', 'chain', 'shuffle', 'ComposeNotAligned'] __all__ = [
'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
'ComposeNotAligned'
]
from Queue import Queue from Queue import Queue
from threading import Thread from threading import Thread
...@@ -20,6 +23,26 @@ import itertools ...@@ -20,6 +23,26 @@ import itertools
import random import random
def map_readers(func, *readers):
"""
Creates a data reader that outputs return value of function using
output of each data readers as arguments.
:param func: function to use.
:param *readers: readers whose outputs will be used as arguments of func.
:returns: the created data reader.
"""
def reader():
rs = []
for r in readers:
rs.append(r())
for e in itertools.imap(func, *rs):
yield e
return reader
def shuffle(reader, buf_size): def shuffle(reader, buf_size):
""" """
Creates a data reader whose data output is suffled. Creates a data reader whose data output is suffled.
......
...@@ -26,6 +26,22 @@ def reader_creator_10(dur): ...@@ -26,6 +26,22 @@ def reader_creator_10(dur):
return reader return reader
class TestMap(unittest.TestCase):
def test_map(self):
d = {"h": 0, "i": 1}
def tokenize(x):
return d[x]
def read():
yield "h"
yield "i"
r = paddle.reader.map_readers(tokenize, read)
for i, e in enumerate(r()):
self.assertEqual(e, i)
class TestBuffered(unittest.TestCase): class TestBuffered(unittest.TestCase):
def test_read(self): def test_read(self):
for size in range(20): for size in range(20):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册