diff --git a/paddle/fluid/framework/reader.cc b/paddle/fluid/framework/reader.cc index 56bf00e5f91700f0cffa917aad8608caaab0a7fe..76126f3dc64d71770d13f9d66bb30f176c112629 100644 --- a/paddle/fluid/framework/reader.cc +++ b/paddle/fluid/framework/reader.cc @@ -22,7 +22,9 @@ FileReader::FileReader(const std::vector &dims) : dims_(dims) {} void FileReader::ReadNext(std::vector *out) { ReadNextImpl(out); - PADDLE_ENFORCE_EQ(out->size(), dims_.size()); + if (out->empty()) { + return; + } for (size_t i = 0; i < dims_.size(); ++i) { auto &actual = out->at(i).dims(); auto &expect = dims_[i]; diff --git a/python/paddle/fluid/tests/unittests/test_recordio_reader.py b/python/paddle/fluid/tests/unittests/test_recordio_reader.py index 096d99a3f3fbb3fae080d39d6cdad629727797a0..2982cb8ceb50c31e98a8f272baf0db76e4e288c5 100644 --- a/python/paddle/fluid/tests/unittests/test_recordio_reader.py +++ b/python/paddle/fluid/tests/unittests/test_recordio_reader.py @@ -65,8 +65,14 @@ class TestRecordIO(unittest.TestCase): # train a pass batch_id = 0 - while not data_file.eof(): - tmp, = exe.run(fetch_list=[avg_loss]) + while True: + ex = None + try: + tmp, = exe.run(fetch_list=[avg_loss]) + except fluid.core.EnforceNotMet as ex: + self.assertIn("There is no next data.", ex.message) + break + avg_loss_np.append(tmp) batch_id += 1 data_file.reset()