From 283bdc5062be0ba14b0ae3ca6cc211ddaf25fd1c Mon Sep 17 00:00:00 2001 From: gongweibao Date: Mon, 12 Jun 2017 10:29:35 +0800 Subject: [PATCH] fix by helin's comments --- paddle/parameter/tests/test_argument.cpp | 2 +- python/paddle/v2/dataset/common.py | 58 +++++++++++-------- python/paddle/v2/dataset/tests/common_test.py | 26 +++++++-- 3 files changed, 56 insertions(+), 30 deletions(-) diff --git a/paddle/parameter/tests/test_argument.cpp b/paddle/parameter/tests/test_argument.cpp index 81fe4ee397..98ab013548 100644 --- a/paddle/parameter/tests/test_argument.cpp +++ b/paddle/parameter/tests/test_argument.cpp @@ -42,7 +42,7 @@ TEST(Argument, poolSequenceWithStride) { CHECK_EQ(outStart[3], 4); CHECK_EQ(outStart[4], 7); - CHECK_EQ(stridePositions->getSize(), 8); + CHECK_EQ(stridePositions->getSize(), 8UL); auto result = reversed ? strideResultReversed : strideResult; for (int i = 0; i < 8; i++) { CHECK_EQ(stridePositions->getData()[i], result[i]); diff --git a/python/paddle/v2/dataset/common.py b/python/paddle/v2/dataset/common.py index 89675080e2..8023fa3cf8 100644 --- a/python/paddle/v2/dataset/common.py +++ b/python/paddle/v2/dataset/common.py @@ -151,9 +151,14 @@ def cluster_files_reader(files_pattern, return reader -def convert(output_path, eader, num_shards, name_prefix): +def convert(output_path, + reader, + num_shards, + name_prefix, + max_lines_to_shuffle=10000): import recordio import cPickle as pickle + import random """ Convert data from reader to recordio format files. @@ -161,35 +166,40 @@ def convert(output_path, eader, num_shards, name_prefix): :param reader: a data reader, from which the convert program will read data instances. :param num_shards: the number of shards that the dataset will be partitioned into. :param name_prefix: the name prefix of generated files. + :param max_lines_to_shuffle: the max lines numbers to shuffle before writing. 
""" - def open_needs(idx): - n = "%s/%s-%05d" % (output_path, name_prefix, idx) - w = recordio.writer(n) - f = open(n, "w") - idx += 1 + assert num_shards >= 1 + assert max_lines_to_shuffle >= 1 - return w, f, idx + def open_writers(): + w = [] + for i in range(0, num_shards): + n = "%s/%s-%05d-of-%05d" % (output_path, name_prefix, i, + num_shards - 1) + w.append(recordio.writer(n)) - def close_needs(w, f): - if w is not None: - w.close() + return w - if f is not None: - f.close() + def close_writers(w): + for i in range(0, num_shards): + w[i].close() - idx = 0 - w = None - f = None + def write_data(w, lines): + random.shuffle(lines) + for i, d in enumerate(lines): + d = pickle.dumps(d, pickle.HIGHEST_PROTOCOL) + w[i % num_shards].write(d) - for i, d in enumerate(reader()): - if w is None: - w, f, idx = open_needs(idx) - - w.write(pickle.dumps(d, pickle.HIGHEST_PROTOCOL)) + w = open_writers() + lines = [] - if i % num_shards == 0 and i >= num_shards: - close_needs(w, f) - w, f, idx = open_needs(idx) + for i, d in enumerate(reader()): + lines.append(d) + if i % max_lines_to_shuffle == 0 and i >= max_lines_to_shuffle: + write_data(w, lines) + lines = [] + continue - close_needs(w, f) + write_data(w, lines) + close_writers(w) diff --git a/python/paddle/v2/dataset/tests/common_test.py b/python/paddle/v2/dataset/tests/common_test.py index 3120026e1e..cfa194eba3 100644 --- a/python/paddle/v2/dataset/tests/common_test.py +++ b/python/paddle/v2/dataset/tests/common_test.py @@ -58,20 +58,36 @@ class TestCommon(unittest.TestCase): self.assertEqual(e, str("0")) def test_convert(self): + record_num = 10 + num_shards = 4 + def test_reader(): def reader(): - for x in xrange(10): + for x in xrange(record_num): yield x return reader path = tempfile.mkdtemp() - paddle.v2.dataset.common.convert(path, - test_reader(), 4, 'random_images') + test_reader(), num_shards, + 'random_images') - files = glob.glob(temp_path + '/random_images-*') - self.assertEqual(len(files), 3) + files = 
glob.glob(path + '/random_images-*') + self.assertEqual(len(files), num_shards) + + recs = [] + for i in range(0, num_shards): + n = "%s/random_images-%05d-of-%05d" % (path, i, num_shards - 1) + r = recordio.reader(n) + while True: + d = r.read() + if d is None: + break + recs.append(d) + + recs.sort() + self.assertEqual(len(recs), record_num) if __name__ == '__main__': -- GitLab