提交 2532b922 编写于 作者: F fengjiayi

Add more unittests and fix bugs

上级 f8638664
...@@ -122,6 +122,7 @@ void MultipleReader::ScheduleThreadFunc() { ...@@ -122,6 +122,7 @@ void MultipleReader::ScheduleThreadFunc() {
// No more file to read. // No more file to read.
++completed_thread_num; ++completed_thread_num;
if (completed_thread_num == prefetchers_.size()) { if (completed_thread_num == prefetchers_.size()) {
buffer_->Close();
break; break;
} }
} }
......
mnist.recordio mnist.recordio
mnist_0.recordio
mnist_1.recordio
mnist_2.recordio
...@@ -22,9 +22,10 @@ from shutil import copyfile ...@@ -22,9 +22,10 @@ from shutil import copyfile
class TestMultipleReader(unittest.TestCase): class TestMultipleReader(unittest.TestCase):
def setUp(self): def setUp(self):
self.batch_size = 64
# Convert mnist to recordio file # Convert mnist to recordio file
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(mnist.train(), batch_size=32) reader = paddle.batch(mnist.train(), batch_size=self.batch_size)
feeder = fluid.DataFeeder( feeder = fluid.DataFeeder(
feed_list=[ # order is image and label feed_list=[ # order is image and label
fluid.layers.data( fluid.layers.data(
...@@ -37,9 +38,8 @@ class TestMultipleReader(unittest.TestCase): ...@@ -37,9 +38,8 @@ class TestMultipleReader(unittest.TestCase):
'./mnist_0.recordio', reader, feeder) './mnist_0.recordio', reader, feeder)
copyfile('./mnist_0.recordio', './mnist_1.recordio') copyfile('./mnist_0.recordio', './mnist_1.recordio')
copyfile('./mnist_0.recordio', './mnist_2.recordio') copyfile('./mnist_0.recordio', './mnist_2.recordio')
print(self.num_batch)
def test_multiple_reader(self, thread_num=3): def main(self, thread_num):
file_list = [ file_list = [
'./mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio' './mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio'
] ]
...@@ -64,8 +64,11 @@ class TestMultipleReader(unittest.TestCase): ...@@ -64,8 +64,11 @@ class TestMultipleReader(unittest.TestCase):
while not data_files.eof(): while not data_files.eof():
img_val, = exe.run(fetch_list=[img]) img_val, = exe.run(fetch_list=[img])
batch_count += 1 batch_count += 1
print(batch_count) self.assertLessEqual(img_val.shape[0], self.batch_size)
# data_files.reset() data_files.reset()
print("FUCK")
self.assertEqual(batch_count, self.num_batch * 3) self.assertEqual(batch_count, self.num_batch * 3)
def test_main(self):
self.main(thread_num=3) # thread number equals to file number
self.main(thread_num=10) # thread number is larger than file number
self.main(thread_num=2) # thread number is less than file number
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册