提交 283bdc50 编写于 作者: G gongweibao

fix by helin's comments

上级 96a56b96
...@@ -42,7 +42,7 @@ TEST(Argument, poolSequenceWithStride) { ...@@ -42,7 +42,7 @@ TEST(Argument, poolSequenceWithStride) {
CHECK_EQ(outStart[3], 4); CHECK_EQ(outStart[3], 4);
CHECK_EQ(outStart[4], 7); CHECK_EQ(outStart[4], 7);
CHECK_EQ(stridePositions->getSize(), 8); CHECK_EQ(stridePositions->getSize(), 8UL);
auto result = reversed ? strideResultReversed : strideResult; auto result = reversed ? strideResultReversed : strideResult;
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
CHECK_EQ(stridePositions->getData()[i], result[i]); CHECK_EQ(stridePositions->getData()[i], result[i]);
......
...@@ -151,9 +151,14 @@ def cluster_files_reader(files_pattern, ...@@ -151,9 +151,14 @@ def cluster_files_reader(files_pattern,
return reader return reader
def convert(output_path, eader, num_shards, name_prefix): def convert(output_path,
reader,
num_shards,
name_prefix,
max_lines_to_shuffle=10000):
import recordio import recordio
import cPickle as pickle import cPickle as pickle
import random
""" """
Convert data from reader to recordio format files. Convert data from reader to recordio format files.
...@@ -161,35 +166,40 @@ def convert(output_path, eader, num_shards, name_prefix): ...@@ -161,35 +166,40 @@ def convert(output_path, eader, num_shards, name_prefix):
:param reader: a data reader, from which the convert program will read data instances. :param reader: a data reader, from which the convert program will read data instances.
:param num_shards: the number of shards that the dataset will be partitioned into. :param num_shards: the number of shards that the dataset will be partitioned into.
:param name_prefix: the name prefix of generated files. :param name_prefix: the name prefix of generated files.
:param max_lines_to_shuffle: the max lines numbers to shuffle before writing.
""" """
def open_needs(idx): assert num_shards >= 1
n = "%s/%s-%05d" % (output_path, name_prefix, idx) assert max_lines_to_shuffle >= 1
w = recordio.writer(n)
f = open(n, "w")
idx += 1
return w, f, idx def open_writers():
w = []
for i in range(0, num_shards):
n = "%s/%s-%05d-of-%05d" % (output_path, name_prefix, i,
num_shards - 1)
w.append(recordio.writer(n))
def close_needs(w, f): return w
if w is not None:
w.close()
if f is not None: def close_writers(w):
f.close() for i in range(0, num_shards):
w[i].close()
idx = 0
w = None
f = None
for i, d in enumerate(reader()): def write_data(w, lines):
if w is None: random.shuffle(lines)
w, f, idx = open_needs(idx) for i, d in enumerate(lines):
d = pickle.dumps(d, pickle.HIGHEST_PROTOCOL)
w[i % num_shards].write(d)
w.write(pickle.dumps(d, pickle.HIGHEST_PROTOCOL)) w = open_writers()
lines = []
if i % num_shards == 0 and i >= num_shards: for i, d in enumerate(reader()):
close_needs(w, f) lines.append(d)
w, f, idx = open_needs(idx) if i % max_lines_to_shuffle == 0 and i >= max_lines_to_shuffle:
write_data(w, lines)
lines = []
continue
close_needs(w, f) write_data(w, lines)
close_writers(w)
...@@ -58,20 +58,36 @@ class TestCommon(unittest.TestCase): ...@@ -58,20 +58,36 @@ class TestCommon(unittest.TestCase):
self.assertEqual(e, str("0")) self.assertEqual(e, str("0"))
def test_convert(self): def test_convert(self):
record_num = 10
num_shards = 4
def test_reader(): def test_reader():
def reader(): def reader():
for x in xrange(10): for x in xrange(record_num):
yield x yield x
return reader return reader
path = tempfile.mkdtemp() path = tempfile.mkdtemp()
paddle.v2.dataset.common.convert(path, paddle.v2.dataset.common.convert(path,
test_reader(), 4, 'random_images') test_reader(), num_shards,
'random_images')
files = glob.glob(temp_path + '/random_images-*') files = glob.glob(path + '/random_images-*')
self.assertEqual(len(files), 3) self.assertEqual(len(files), num_shards)
recs = []
for i in range(0, num_shards):
n = "%s/random_images-%05d-of-%05d" % (path, i, num_shards - 1)
r = recordio.reader(n)
while True:
d = r.read()
if d is None:
break
recs.append(d)
recs.sort()
self.assertEqual(total, record_num)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册