From 97a594e7f8a43e3e7fadd436933f84fb0e835c36 Mon Sep 17 00:00:00 2001 From: Yancey Date: Fri, 2 Jun 2017 15:01:09 +0800 Subject: [PATCH] Split dataset into multiple files (#2320) cluster dataset split and reader --- python/paddle/v2/dataset/common.py | 77 ++++++++++++++++++- python/paddle/v2/dataset/tests/common_test.py | 25 ++++++ 2 files changed, 101 insertions(+), 1 deletion(-) diff --git a/python/paddle/v2/dataset/common.py b/python/paddle/v2/dataset/common.py index 2eb018b8d6..418b592a5a 100644 --- a/python/paddle/v2/dataset/common.py +++ b/python/paddle/v2/dataset/common.py @@ -19,8 +19,10 @@ import shutil import sys import importlib import paddle.v2.dataset +import cPickle +import glob -__all__ = ['DATA_HOME', 'download', 'md5file'] +__all__ = ['DATA_HOME', 'download', 'md5file', 'split', 'cluster_files_reader'] DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') @@ -74,3 +76,76 @@ def fetch_all(): getattr( importlib.import_module("paddle.v2.dataset.%s" % module_name), "fetch")() + + +def split(reader, line_count, suffix="%05d.pickle", dumper=cPickle.dump): + """ + you can call the function as: + + split(paddle.v2.dataset.cifar.train10(), line_count=1000, + suffix="imikolov-train-%05d.pickle") + + the output files as: + + |-imikolov-train-00000.pickle + |-imikolov-train-00001.pickle + |- ... + |-imikolov-train-00480.pickle + + :param reader: is a reader creator + :param line_count: line count for each file + :param suffix: the suffix for the output files, should contain "%d" + means the id for each file. Default is "%05d.pickle" + :param dumper: is a callable function that dump object to file, this + function will be called as dumper(obj, f) and obj is the object + will be dumped, f is a file object. Default is cPickle.dump. + """ + if not callable(dumper): + raise TypeError("dumper should be callable.") + lines = [] + indx_f = 0 + for i, d in enumerate(reader()): + lines.append(d) + if i >= line_count and i % line_count == 0: + with open(suffix % indx_f, "w") as f: + dumper(lines, f) + lines = [] + indx_f += 1 + if lines: + with open(suffix % indx_f, "w") as f: + dumper(lines, f) + + +def cluster_files_reader(files_pattern, + trainer_count, + trainer_id, + loader=cPickle.load): + """ + Create a reader that yield element from the given files, select + a file set according trainer count and trainer_id + + :param files_pattern: the files which generating by split(...) + :param trainer_count: total trainer count + :param trainer_id: the trainer rank id + :param loader: is a callable function that load object from file, this + function will be called as loader(f) and f is a file object. + Default is cPickle.load + """ + + def reader(): + if not callable(loader): + raise TypeError("loader should be callable.") + file_list = glob.glob(files_pattern) + file_list.sort() + my_file_list = [] + for idx, fn in enumerate(file_list): + if idx % trainer_count == trainer_id: + print "append file: %s" % fn + my_file_list.append(fn) + for fn in my_file_list: + with open(fn, "r") as f: + lines = loader(f) + for line in lines: + yield line + + return reader diff --git a/python/paddle/v2/dataset/tests/common_test.py b/python/paddle/v2/dataset/tests/common_test.py index 5babcef0eb..f9815d4f9e 100644 --- a/python/paddle/v2/dataset/tests/common_test.py +++ b/python/paddle/v2/dataset/tests/common_test.py @@ -15,6 +15,7 @@ import paddle.v2.dataset.common import unittest import tempfile +import glob class TestCommon(unittest.TestCase): @@ -32,6 +33,30 @@ class TestCommon(unittest.TestCase): paddle.v2.dataset.common.download( yi_avatar, 'test', 'f75287202d6622414c706c36c16f8e0d')) + def test_split(self): + def test_reader(): + def reader(): + for x in xrange(10): + yield x + + return reader + + _, temp_path = tempfile.mkstemp() + paddle.v2.dataset.common.split( + test_reader(), 4, suffix=temp_path + '/test-%05d.pickle') + files = glob.glob(temp_path + '/test-%05d.pickle') + self.assertEqual(len(files), 3) + + def test_cluster_file_reader(self): + _, temp_path = tempfile.mkstemp() + for x in xrange(5): + with open(temp_path + '/%05d.test' % x) as f: + f.write('%d\n' % x) + reader = paddle.v2.dataset.common.cluster_files_reader( + temp_path + '/*.test', 5, 0) + for idx, e in enumerate(reader()): + self.assertEqual(e, str("0")) + if __name__ == '__main__': unittest.main() -- GitLab