提交 2da240c7 编写于 作者: H Helin Wang

fix local recordio reader

上级 f80fea8d
......@@ -57,7 +57,7 @@ def text_file(path):
return reader
def recordio_local(paths, buf_size=100):
def recordio(paths, buf_size=100):
"""
Creates a data reader from given RecordIO file paths separated by ",",
glob pattern is supported.
......@@ -67,15 +67,19 @@ def recordio_local(paths, buf_size=100):
import recordio as rec
import paddle.v2.reader.decorator as dec
import cPickle as pickle
def reader():
a = ','.join(paths)
f = rec.reader(a)
if isinstance(paths, basestring):
path = paths
else:
path = ",".join(paths)
f = rec.reader(path)
while True:
r = f.read()
if r is None:
break
yield r
yield pickle.loads(r)
f.close()
return dec.buffered(reader, buf_size)
......
......@@ -34,5 +34,27 @@ class TestTextFile(unittest.TestCase):
self.assertEqual(e, str(idx * 2) + " " + str(idx * 2 + 1))
class TestRecordIO(unittest.TestCase):
def do_test(self, path):
reader = paddle.v2.reader.creator.recordio(path)
idx = 0
for e in reader():
if idx == 0:
self.assertEqual(e, (1, 2, 3))
elif idx == 1:
self.assertEqual(e, (4, 5, 6))
idx += 1
self.assertEqual(idx, 2)
def test_recordIO(self):
self.do_test(
os.path.join(
os.path.dirname(__file__), "test_reader_recordio.dat"))
self.do_test([
os.path.join(
os.path.dirname(__file__), "test_reader_recordio.dat")
])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册