def convert(output_path, reader, num_shards, name_prefix):
    """
    Convert data from reader to recordio format files.

    Every instance yielded by ``reader()`` is pickled with
    ``pickle.HIGHEST_PROTOCOL`` and appended to the current output file.
    Output files are named ``<output_path>/<name_prefix>-<index>``, where
    ``<index>`` is a zero-padded five-digit counter starting at 0. A file is
    only created once there is an instance to put into it, so an empty reader
    produces no files.

    :param output_path: directory in which output files will be saved.
    :param reader: a data reader, from which the convert program will read
        data instances. It is called with no arguments and must return an
        iterable.
    :param num_shards: a new output file is started every ``num_shards``
        instances, i.e. each file holds at most ``num_shards`` instances.
        NOTE(review): the name suggests "total number of shards", but the
        rotation logic counts instances per file -- confirm the intent.
    :param name_prefix: the name prefix of generated files.
    :raises ValueError: if ``num_shards`` is smaller than 1.
    """
    # Fail fast: a non-positive value would otherwise surface as a
    # ZeroDivisionError (or nonsensical rotation) in the middle of the run.
    if num_shards < 1:
        raise ValueError("num_shards must be >= 1, got %r" % (num_shards, ))

    # Imported here so that importing this module does not require recordio.
    import recordio
    try:
        import cPickle as pickle
    except ImportError:  # cPickle does not exist on Python 3
        import pickle

    writer = None
    shard_idx = 0

    try:
        for i, d in enumerate(reader()):
            # Rotate *before* writing so that every file, including the
            # first one, receives at most num_shards instances.
            if i % num_shards == 0:
                if writer is not None:
                    writer.close()
                    # Reset so the finally clause cannot close it twice if
                    # opening the next writer fails.
                    writer = None

                # Only the recordio writer touches the file; opening the same
                # path again in "w" mode would truncate it.
                # Assumes recordio.writer(path) creates the file -- TODO confirm.
                shard_path = "%s/%s-%05d" % (output_path, name_prefix,
                                             shard_idx)
                writer = recordio.writer(shard_path)
                shard_idx += 1

            writer.write(pickle.dumps(d, pickle.HIGHEST_PROTOCOL))
    finally:
        # Close the last writer even when the reader or a write raised.
        if writer is not None:
            writer.close()