提交 97a594e7 编写于 作者: Y Yancey 提交者: GitHub

Split dataset into multiple files (#2320)

cluster dataset split and reader
上级 4ac9a6fa
......@@ -19,8 +19,10 @@ import shutil
import sys
import importlib
import paddle.v2.dataset
import cPickle
import glob
__all__ = ['DATA_HOME', 'download', 'md5file']
__all__ = ['DATA_HOME', 'download', 'md5file', 'split', 'cluster_files_reader']
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
......@@ -74,3 +76,76 @@ def fetch_all():
getattr(
importlib.import_module("paddle.v2.dataset.%s" % module_name),
"fetch")()
def split(reader, line_count, suffix="%05d.pickle", dumper=cPickle.dump):
"""
you can call the function as:
split(paddle.v2.dataset.cifar.train10(), line_count=1000,
suffix="imikolov-train-%05d.pickle")
the output files as:
|-imikolov-train-00000.pickle
|-imikolov-train-00001.pickle
|- ...
|-imikolov-train-00480.pickle
:param reader: is a reader creator
:param line_count: line count for each file
:param suffix: the suffix for the output files, should contain "%d"
means the id for each file. Default is "%05d.pickle"
:param dumper: is a callable function that dump object to file, this
function will be called as dumper(obj, f) and obj is the object
will be dumped, f is a file object. Default is cPickle.dump.
"""
if not callable(dumper):
raise TypeError("dumper should be callable.")
lines = []
indx_f = 0
for i, d in enumerate(reader()):
lines.append(d)
if i >= line_count and i % line_count == 0:
with open(suffix % indx_f, "w") as f:
dumper(lines, f)
lines = []
indx_f += 1
if lines:
with open(suffix % indx_f, "w") as f:
dumper(lines, f)
def cluster_files_reader(files_pattern,
trainer_count,
trainer_id,
loader=cPickle.load):
"""
Create a reader that yield element from the given files, select
a file set according trainer count and trainer_id
:param files_pattern: the files which generating by split(...)
:param trainer_count: total trainer count
:param trainer_id: the trainer rank id
:param loader: is a callable function that load object from file, this
function will be called as loader(f) and f is a file object.
Default is cPickle.load
"""
def reader():
if not callable(loader):
raise TypeError("loader should be callable.")
file_list = glob.glob(files_pattern)
file_list.sort()
my_file_list = []
for idx, fn in enumerate(file_list):
if idx % trainer_count == trainer_id:
print "append file: %s" % fn
my_file_list.append(fn)
for fn in my_file_list:
with open(fn, "r") as f:
lines = loader(f)
for line in lines:
yield line
return reader
......@@ -15,6 +15,7 @@
import paddle.v2.dataset.common
import unittest
import tempfile
import glob
class TestCommon(unittest.TestCase):
......@@ -32,6 +33,30 @@ class TestCommon(unittest.TestCase):
paddle.v2.dataset.common.download(
yi_avatar, 'test', 'f75287202d6622414c706c36c16f8e0d'))
def test_split(self):
def test_reader():
def reader():
for x in xrange(10):
yield x
return reader
_, temp_path = tempfile.mkstemp()
paddle.v2.dataset.common.split(
test_reader(), 4, suffix=temp_path + '/test-%05d.pickle')
files = glob.glob(temp_path + '/test-%05d.pickle')
self.assertEqual(len(files), 3)
def test_cluster_file_reader(self):
_, temp_path = tempfile.mkstemp()
for x in xrange(5):
with open(temp_path + '/%05d.test' % x) as f:
f.write('%d\n' % x)
reader = paddle.v2.dataset.common.cluster_files_reader(
temp_path + '/*.test', 5, 0)
for idx, e in enumerate(reader()):
self.assertEqual(e, str("0"))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册