From 9011f9e52c614c1f357a4220ffbb16ee5155f0df Mon Sep 17 00:00:00 2001
From: gongweibao
Date: Wed, 7 Jun 2017 15:52:06 +0800
Subject: [PATCH] add precommit

---
 python/paddle/v2/dataset/common.py            | 44 +++++++++++++++++++
 python/paddle/v2/dataset/tests/common_test.py | 16 +++++++
 2 files changed, 60 insertions(+)

diff --git a/python/paddle/v2/dataset/common.py b/python/paddle/v2/dataset/common.py
index 418b592a5ac..89675080e25 100644
--- a/python/paddle/v2/dataset/common.py
+++ b/python/paddle/v2/dataset/common.py
@@ -149,3 +149,47 @@ def cluster_files_reader(files_pattern,
                 yield line
 
     return reader
+
+
+def convert(output_path, reader, num_shards, name_prefix):
+    """
+    Convert data from reader to recordio format files.
+
+    :param output_path: directory in which output files will be saved.
+    :param reader: a data reader, from which the convert program will read data instances.
+    :param num_shards: the number of shards that the dataset will be partitioned into.
+    :param name_prefix: the name prefix of generated files.
+    """
+    import recordio
+    import cPickle as pickle
+
+    def open_needs(idx):
+        # Open the next shard: a recordio writer plus its backing file handle.
+        n = "%s/%s-%05d" % (output_path, name_prefix, idx)
+        w = recordio.writer(n)
+        f = open(n, "w")
+        idx += 1
+
+        return w, f, idx
+
+    def close_needs(w, f):
+        # Close writer/file if they were ever opened (reader may be empty).
+        if w is not None:
+            w.close()
+
+        if f is not None:
+            f.close()
+
+    idx = 0
+    w = None
+    f = None
+
+    for i, d in enumerate(reader()):
+        if w is None:
+            w, f, idx = open_needs(idx)
+
+        w.write(pickle.dumps(d, pickle.HIGHEST_PROTOCOL))
+
+        if i % num_shards == 0 and i >= num_shards:
+            close_needs(w, f)
+            w, f, idx = open_needs(idx)
+
+    close_needs(w, f)
diff --git a/python/paddle/v2/dataset/tests/common_test.py b/python/paddle/v2/dataset/tests/common_test.py
index f9815d4f9e1..3120026e1e6 100644
--- a/python/paddle/v2/dataset/tests/common_test.py
+++ b/python/paddle/v2/dataset/tests/common_test.py
@@ -57,6 +57,22 @@ class TestCommon(unittest.TestCase):
 
         for idx, e in enumerate(reader()):
             self.assertEqual(e, str("0"))
 
+    def test_convert(self):
+        def test_reader():
+            def reader():
+                for x in xrange(10):
+                    yield x
+
+            return reader
+
+        path = tempfile.mkdtemp()
+
+        paddle.v2.dataset.common.convert(path,
+                                         test_reader(), 4, 'random_images')
+
+        files = glob.glob(path + '/random_images-*')
+        self.assertEqual(len(files), 3)
+
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab