diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index 1be3f4ef1f46bd8d72a285afa69b52d6f519ccf5..e281c9b13fb50fba64185b6db2949fc76f5c4834 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -33,6 +33,8 @@ class ReaderBase { std::vector shapes() const { return shapes_; } void set_shapes(const std::vector& shapes) { shapes_ = shapes; } + virtual bool HasNext() const = 0; + virtual ~ReaderBase() {} protected: @@ -53,6 +55,8 @@ class DecoratedReader : public ReaderBase { void ReInit() override { reader_->ReInit(); } + bool HasNext() const override { return reader_->HasNext(); } + protected: ReaderBase* reader_; }; @@ -74,6 +78,8 @@ class ReaderHolder { reader_->set_shapes(shapes); } + bool HasNext() const { return reader_->HasNext(); } + private: std::unique_ptr reader_; }; diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index b6a0609a1e23195ececee0f16a69daa1c1c46ed8..ba08ea12e2486aaba8c57a9fe23592bd1738592d 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -37,6 +37,8 @@ class DoubleBufferReader : public framework::DecoratedReader { ~DoubleBufferReader() { buffer_->Close(); } + bool HasNext() const override; + private: void PrefetchThreadFunc(); @@ -106,6 +108,8 @@ void DoubleBufferReader::PrefetchThreadFunc() { } } +bool DoubleBufferReader::HasNext() const { PADDLE_THROW("Not Implemented"); } + } // namespace reader } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/reader/create_random_data_generator_op.cc b/paddle/fluid/operators/reader/create_random_data_generator_op.cc index 73c39b5da4484b27f75aeba3c8171c5ffed2398f..e62f952d0e89561c3eed56112dc9d1d78801b59e 100644 --- a/paddle/fluid/operators/reader/create_random_data_generator_op.cc +++ b/paddle/fluid/operators/reader/create_random_data_generator_op.cc @@ -52,6 +52,8 @@ class RandomDataGenerator : public framework::FileReader { void ReInit() override { return; } + bool HasNext() const override { return true; } + private: float min_; float max_; diff --git a/paddle/fluid/recordio/chunk.cc b/paddle/fluid/recordio/chunk.cc index 13d059f844aebb847282a80a43e4cbdad71d7fa0..187a6a4ea7bd9d3a8ae48fa262e18f71b0f7d20d 100644 --- a/paddle/fluid/recordio/chunk.cc +++ b/paddle/fluid/recordio/chunk.cc @@ -146,6 +146,7 @@ bool Chunk::Parse(std::istream& sin) { std::string buf; buf.resize(rec_len); stream.read(&buf[0], rec_len); + PADDLE_ENFORCE_EQ(rec_len, stream.gcount()); Add(buf); } return true; diff --git a/paddle/fluid/recordio/scanner.cc b/paddle/fluid/recordio/scanner.cc index 7f19c46e7e0be29ef11c76b21dce751141da36bc..d842f8fe5a4c9d1a2b564c738d97fffb02f3ccb5 100644 --- a/paddle/fluid/recordio/scanner.cc +++ b/paddle/fluid/recordio/scanner.cc @@ -32,9 +32,9 @@ void Scanner::Reset() { ParseNextChunk(); } -const std::string &Scanner::Next() { +std::string Scanner::Next() { PADDLE_ENFORCE(!eof_, "StopIteration"); - auto &rec = cur_chunk_.Record(offset_++); + auto rec = cur_chunk_.Record(offset_++); if (offset_ == cur_chunk_.NumRecords()) { ParseNextChunk(); } diff --git a/paddle/fluid/recordio/scanner.h b/paddle/fluid/recordio/scanner.h index 3073d0c5c872502f4567fcbadd6b4129f865de10..f3f17b69f195ddd92f5a39ead9755a7b8e2dd329 100644 --- a/paddle/fluid/recordio/scanner.h +++ b/paddle/fluid/recordio/scanner.h @@ -28,7 +28,7 @@ public: void Reset(); - const std::string& Next(); + std::string Next(); bool HasNext() const; diff --git a/paddle/fluid/recordio/writer_scanner_test.cc b/paddle/fluid/recordio/writer_scanner_test.cc index a14d3bc3b2b18d27c6b65ad606098f2e4cae7245..7e764f0d9439709ad101af2b8864dc0158bd359b 100644 --- a/paddle/fluid/recordio/writer_scanner_test.cc +++ b/paddle/fluid/recordio/writer_scanner_test.cc @@ -41,4 +41,29 @@ TEST(WriterScanner, Normal) { ASSERT_EQ("CDE", scanner.Next()); ASSERT_FALSE(scanner.HasNext()); } +} + +TEST(WriterScanner, TinyChunk) { + std::stringstream* stream = new std::stringstream(); + { + paddle::recordio::Writer writer( + stream, paddle::recordio::Compressor::kNoCompress, 2 /*max chunk num*/); + writer.Write("ABC"); + writer.Write("BCD"); + writer.Write("CDE"); + writer.Write("DEFG"); + writer.Flush(); + } + + { + stream->seekg(0, std::ios::beg); + std::unique_ptr stream_ptr(stream); + paddle::recordio::Scanner scanner(std::move(stream_ptr)); + ASSERT_TRUE(scanner.HasNext()); + ASSERT_EQ(scanner.Next(), "ABC"); + ASSERT_EQ(scanner.Next(), "BCD"); + ASSERT_EQ(scanner.Next(), "CDE"); + ASSERT_EQ(scanner.Next(), "DEFG"); + ASSERT_FALSE(scanner.HasNext()); + } } \ No newline at end of file diff --git a/python/paddle/fluid/tests/unittests/test_recordio_reader.py b/python/paddle/fluid/tests/unittests/test_recordio_reader.py index 7844d46320e2e94cc06c6bb58c2221b91ae5e508..d249742bd30ec41749f16beaa7076f7c6e8f063c 100644 --- a/python/paddle/fluid/tests/unittests/test_recordio_reader.py +++ b/python/paddle/fluid/tests/unittests/test_recordio_reader.py @@ -55,15 +55,10 @@ class TestRecordIO(unittest.TestCase): exe.run(fluid.default_startup_program()) avg_loss_np = [] - for i in xrange(2): # 2 pass - batch_id = 0 - while not data_file.eof(): - try: - batch_id += 1 - tmp, = exe.run(fetch_list=[avg_loss]) - avg_loss_np.append(tmp) - except: - print batch_id - break - data_file.reset() + # train a pass + while not data_file.eof(): + tmp, = exe.run(fetch_list=[avg_loss]) + avg_loss_np.append(tmp) + data_file.reset() + self.assertLess(avg_loss_np[-1], avg_loss_np[0])