提交 283bdc50 编写于 作者: G gongweibao

fix by helin's comments

上级 96a56b96
......@@ -42,7 +42,7 @@ TEST(Argument, poolSequenceWithStride) {
CHECK_EQ(outStart[3], 4);
CHECK_EQ(outStart[4], 7);
CHECK_EQ(stridePositions->getSize(), 8);
CHECK_EQ(stridePositions->getSize(), 8UL);
auto result = reversed ? strideResultReversed : strideResult;
for (int i = 0; i < 8; i++) {
CHECK_EQ(stridePositions->getData()[i], result[i]);
......
......@@ -151,9 +151,14 @@ def cluster_files_reader(files_pattern,
return reader
def convert(output_path, eader, num_shards, name_prefix):
def convert(output_path,
reader,
num_shards,
name_prefix,
max_lines_to_shuffle=10000):
import recordio
import cPickle as pickle
import random
"""
Convert data from reader to recordio format files.
......@@ -161,35 +166,40 @@ def convert(output_path, eader, num_shards, name_prefix):
:param reader: a data reader, from which the convert program will read data instances.
:param num_shards: the number of shards that the dataset will be partitioned into.
:param name_prefix: the name prefix of generated files.
:param max_lines_to_shuffle: the max lines numbers to shuffle before writing.
"""
def open_needs(idx):
n = "%s/%s-%05d" % (output_path, name_prefix, idx)
w = recordio.writer(n)
f = open(n, "w")
idx += 1
assert num_shards >= 1
assert max_lines_to_shuffle >= 1
return w, f, idx
def open_writers():
w = []
for i in range(0, num_shards):
n = "%s/%s-%05d-of-%05d" % (output_path, name_prefix, i,
num_shards - 1)
w.append(recordio.writer(n))
def close_needs(w, f):
if w is not None:
w.close()
return w
if f is not None:
f.close()
def close_writers(w):
for i in range(0, num_shards):
w[i].close()
idx = 0
w = None
f = None
def write_data(w, lines):
random.shuffle(lines)
for i, d in enumerate(lines):
d = pickle.dumps(d, pickle.HIGHEST_PROTOCOL)
w[i % num_shards].write(d)
for i, d in enumerate(reader()):
if w is None:
w, f, idx = open_needs(idx)
w.write(pickle.dumps(d, pickle.HIGHEST_PROTOCOL))
w = open_writers()
lines = []
if i % num_shards == 0 and i >= num_shards:
close_needs(w, f)
w, f, idx = open_needs(idx)
for i, d in enumerate(reader()):
lines.append(d)
if i % max_lines_to_shuffle == 0 and i >= max_lines_to_shuffle:
write_data(w, lines)
lines = []
continue
close_needs(w, f)
write_data(w, lines)
close_writers(w)
......@@ -58,20 +58,36 @@ class TestCommon(unittest.TestCase):
self.assertEqual(e, str("0"))
def test_convert(self):
record_num = 10
num_shards = 4
def test_reader():
def reader():
for x in xrange(10):
for x in xrange(record_num):
yield x
return reader
path = tempfile.mkdtemp()
paddle.v2.dataset.common.convert(path,
test_reader(), 4, 'random_images')
test_reader(), num_shards,
'random_images')
files = glob.glob(temp_path + '/random_images-*')
self.assertEqual(len(files), 3)
files = glob.glob(path + '/random_images-*')
self.assertEqual(len(files), num_shards)
recs = []
for i in range(0, num_shards):
n = "%s/random_images-%05d-of-%05d" % (path, i, num_shards - 1)
r = recordio.reader(n)
while True:
d = r.read()
if d is None:
break
recs.append(d)
recs.sort()
self.assertEqual(total, record_num)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册