提交 2532b922 编写于 作者: F fengjiayi

Add more unittests and fix bugs

上级 f8638664
......@@ -122,6 +122,7 @@ void MultipleReader::ScheduleThreadFunc() {
// No more file to read.
++completed_thread_num;
if (completed_thread_num == prefetchers_.size()) {
buffer_->Close();
break;
}
}
......
mnist.recordio
mnist_0.recordio
mnist_1.recordio
mnist_2.recordio
......@@ -22,9 +22,10 @@ from shutil import copyfile
class TestMultipleReader(unittest.TestCase):
def setUp(self):
self.batch_size = 64
# Convert mnist to recordio file
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(mnist.train(), batch_size=32)
reader = paddle.batch(mnist.train(), batch_size=self.batch_size)
feeder = fluid.DataFeeder(
feed_list=[ # order is image and label
fluid.layers.data(
......@@ -37,9 +38,8 @@ class TestMultipleReader(unittest.TestCase):
'./mnist_0.recordio', reader, feeder)
copyfile('./mnist_0.recordio', './mnist_1.recordio')
copyfile('./mnist_0.recordio', './mnist_2.recordio')
print(self.num_batch)
def test_multiple_reader(self, thread_num=3):
def main(self, thread_num):
file_list = [
'./mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio'
]
......@@ -64,8 +64,11 @@ class TestMultipleReader(unittest.TestCase):
while not data_files.eof():
img_val, = exe.run(fetch_list=[img])
batch_count += 1
print(batch_count)
# data_files.reset()
print("FUCK")
self.assertLessEqual(img_val.shape[0], self.batch_size)
data_files.reset()
self.assertEqual(batch_count, self.num_batch * 3)
def test_main(self):
self.main(thread_num=3) # thread number equals to file number
self.main(thread_num=10) # thread number is larger than file number
self.main(thread_num=2) # thread number is less than file number
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册