提交 2da240c7 编写于 作者: H Helin Wang

fix local recordio reader

上级 f80fea8d
...@@ -57,7 +57,7 @@ def text_file(path): ...@@ -57,7 +57,7 @@ def text_file(path):
return reader return reader
def recordio_local(paths, buf_size=100): def recordio(paths, buf_size=100):
""" """
Creates a data reader from given RecordIO file paths separated by ",", Creates a data reader from given RecordIO file paths separated by ",",
glob pattern is supported. glob pattern is supported.
...@@ -67,15 +67,19 @@ def recordio_local(paths, buf_size=100): ...@@ -67,15 +67,19 @@ def recordio_local(paths, buf_size=100):
import recordio as rec import recordio as rec
import paddle.v2.reader.decorator as dec import paddle.v2.reader.decorator as dec
import cPickle as pickle
def reader(): def reader():
a = ','.join(paths) if isinstance(paths, basestring):
f = rec.reader(a) path = paths
else:
path = ",".join(paths)
f = rec.reader(path)
while True: while True:
r = f.read() r = f.read()
if r is None: if r is None:
break break
yield r yield pickle.loads(r)
f.close() f.close()
return dec.buffered(reader, buf_size) return dec.buffered(reader, buf_size)
......
...@@ -34,5 +34,27 @@ class TestTextFile(unittest.TestCase): ...@@ -34,5 +34,27 @@ class TestTextFile(unittest.TestCase):
self.assertEqual(e, str(idx * 2) + " " + str(idx * 2 + 1)) self.assertEqual(e, str(idx * 2) + " " + str(idx * 2 + 1))
class TestRecordIO(unittest.TestCase):
def do_test(self, path):
reader = paddle.v2.reader.creator.recordio(path)
idx = 0
for e in reader():
if idx == 0:
self.assertEqual(e, (1, 2, 3))
elif idx == 1:
self.assertEqual(e, (4, 5, 6))
idx += 1
self.assertEqual(idx, 2)
def test_recordIO(self):
self.do_test(
os.path.join(
os.path.dirname(__file__), "test_reader_recordio.dat"))
self.do_test([
os.path.join(
os.path.dirname(__file__), "test_reader_recordio.dat")
])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册