diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 1ab4111efe80f573d12552d7de4e11707c23ff33..414c76fea0bb916dfeafe38c0448a7a800889e03 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -122,6 +122,7 @@ void MultipleReader::ScheduleThreadFunc() { // No more file to read. ++completed_thread_num; if (completed_thread_num == prefetchers_.size()) { + buffer_->Close(); break; } } diff --git a/python/paddle/fluid/tests/unittests/.gitignore b/python/paddle/fluid/tests/unittests/.gitignore index 6b3fc2a83c649c28d21c9a8a0b35c2f2fa04f269..ad02bdecf436bba925e2e3b7efb20c878df70dfd 100644 --- a/python/paddle/fluid/tests/unittests/.gitignore +++ b/python/paddle/fluid/tests/unittests/.gitignore @@ -1 +1,4 @@ mnist.recordio +mnist_0.recordio +mnist_1.recordio +mnist_2.recordio diff --git a/python/paddle/fluid/tests/unittests/test_multiple_reader.py b/python/paddle/fluid/tests/unittests/test_multiple_reader.py index cb1aaaae5a7a459ae7a545cc65209dcfd7e80d89..69f8acf81efaba8fc0f3df4cfe3a42dc4e477df2 100644 --- a/python/paddle/fluid/tests/unittests/test_multiple_reader.py +++ b/python/paddle/fluid/tests/unittests/test_multiple_reader.py @@ -22,9 +22,10 @@ from shutil import copyfile class TestMultipleReader(unittest.TestCase): def setUp(self): + self.batch_size = 64 # Convert mnist to recordio file with fluid.program_guard(fluid.Program(), fluid.Program()): - reader = paddle.batch(mnist.train(), batch_size=32) + reader = paddle.batch(mnist.train(), batch_size=self.batch_size) feeder = fluid.DataFeeder( feed_list=[ # order is image and label fluid.layers.data( @@ -37,9 +38,8 @@ class TestMultipleReader(unittest.TestCase): './mnist_0.recordio', reader, feeder) copyfile('./mnist_0.recordio', './mnist_1.recordio') copyfile('./mnist_0.recordio', './mnist_2.recordio') - print(self.num_batch) - def test_multiple_reader(self, thread_num=3): + def main(self, thread_num): file_list = [ './mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio' ] @@ -64,8 +64,11 @@ class TestMultipleReader(unittest.TestCase): while not data_files.eof(): img_val, = exe.run(fetch_list=[img]) batch_count += 1 - print(batch_count) - # data_files.reset() - print("FUCK") - + self.assertLessEqual(img_val.shape[0], self.batch_size) + data_files.reset() self.assertEqual(batch_count, self.num_batch * 3) + + def test_main(self): + self.main(thread_num=3) # thread number equals to file number + self.main(thread_num=10) # thread number is larger than file number + self.main(thread_num=2) # thread number is less than file number