From 2532b922dc4897478589d7b4064cde40113f943b Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 20 Mar 2018 19:20:58 +0800 Subject: [PATCH] Add more unittests and fix bugs --- paddle/fluid/operators/reader/open_files_op.cc | 1 + python/paddle/fluid/tests/unittests/.gitignore | 3 +++ .../tests/unittests/test_multiple_reader.py | 17 ++++++++++------- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 1ab4111efe8..414c76fea0b 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -122,6 +122,7 @@ void MultipleReader::ScheduleThreadFunc() { // No more file to read. ++completed_thread_num; if (completed_thread_num == prefetchers_.size()) { + buffer_->Close(); break; } } diff --git a/python/paddle/fluid/tests/unittests/.gitignore b/python/paddle/fluid/tests/unittests/.gitignore index 6b3fc2a83c6..ad02bdecf43 100644 --- a/python/paddle/fluid/tests/unittests/.gitignore +++ b/python/paddle/fluid/tests/unittests/.gitignore @@ -1 +1,4 @@ mnist.recordio +mnist_0.recordio +mnist_1.recordio +mnist_2.recordio diff --git a/python/paddle/fluid/tests/unittests/test_multiple_reader.py b/python/paddle/fluid/tests/unittests/test_multiple_reader.py index cb1aaaae5a7..69f8acf81ef 100644 --- a/python/paddle/fluid/tests/unittests/test_multiple_reader.py +++ b/python/paddle/fluid/tests/unittests/test_multiple_reader.py @@ -22,9 +22,10 @@ from shutil import copyfile class TestMultipleReader(unittest.TestCase): def setUp(self): + self.batch_size = 64 # Convert mnist to recordio file with fluid.program_guard(fluid.Program(), fluid.Program()): - reader = paddle.batch(mnist.train(), batch_size=32) + reader = paddle.batch(mnist.train(), batch_size=self.batch_size) feeder = fluid.DataFeeder( feed_list=[ # order is image and label fluid.layers.data( @@ -37,9 +38,8 @@ class TestMultipleReader(unittest.TestCase): './mnist_0.recordio', reader, feeder) copyfile('./mnist_0.recordio', './mnist_1.recordio') copyfile('./mnist_0.recordio', './mnist_2.recordio') - print(self.num_batch) - def test_multiple_reader(self, thread_num=3): + def main(self, thread_num): file_list = [ './mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio' ] @@ -64,8 +64,11 @@ class TestMultipleReader(unittest.TestCase): while not data_files.eof(): img_val, = exe.run(fetch_list=[img]) batch_count += 1 - print(batch_count) - # data_files.reset() - print("FUCK") - + self.assertLessEqual(img_val.shape[0], self.batch_size) + data_files.reset() self.assertEqual(batch_count, self.num_batch * 3) + + def test_main(self): + self.main(thread_num=3) # thread number equals to file number + self.main(thread_num=10) # thread number is larger than file number + self.main(thread_num=2) # thread number is less than file number -- GitLab