From 7eedced82a91f6edcaf1ce4b41c20ab443d3cfc2 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 12 Mar 2018 12:51:19 +0800 Subject: [PATCH] Polish RecordIO --- paddle/fluid/framework/reader.h | 6 +++++ .../reader/create_double_buffer_reader_op.cc | 4 +++ .../reader/create_random_data_generator_op.cc | 2 ++ paddle/fluid/recordio/chunk.cc | 1 + paddle/fluid/recordio/scanner.cc | 4 +-- paddle/fluid/recordio/scanner.h | 2 +- paddle/fluid/recordio/writer_scanner_test.cc | 25 +++++++++++++++++++ .../tests/unittests/test_recordio_reader.py | 17 +++++-------- 8 files changed, 47 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index 1be3f4ef1f..e281c9b13f 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -33,6 +33,8 @@ class ReaderBase { std::vector shapes() const { return shapes_; } void set_shapes(const std::vector& shapes) { shapes_ = shapes; } + virtual bool HasNext() const = 0; + virtual ~ReaderBase() {} protected: @@ -53,6 +55,8 @@ class DecoratedReader : public ReaderBase { void ReInit() override { reader_->ReInit(); } + bool HasNext() const override { return reader_->HasNext(); } + protected: ReaderBase* reader_; }; @@ -74,6 +78,8 @@ class ReaderHolder { reader_->set_shapes(shapes); } + bool HasNext() const { return reader_->HasNext(); } + private: std::unique_ptr reader_; }; diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index b6a0609a1e..ba08ea12e2 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -37,6 +37,8 @@ class DoubleBufferReader : public framework::DecoratedReader { ~DoubleBufferReader() { buffer_->Close(); } + bool HasNext() const override; + private: void PrefetchThreadFunc(); @@ -106,6 +108,8 @@ void DoubleBufferReader::PrefetchThreadFunc() { } } +bool DoubleBufferReader::HasNext() const { PADDLE_THROW("Not Implemented"); } + } // namespace reader } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/reader/create_random_data_generator_op.cc b/paddle/fluid/operators/reader/create_random_data_generator_op.cc index 73c39b5da4..e62f952d0e 100644 --- a/paddle/fluid/operators/reader/create_random_data_generator_op.cc +++ b/paddle/fluid/operators/reader/create_random_data_generator_op.cc @@ -52,6 +52,8 @@ class RandomDataGenerator : public framework::FileReader { void ReInit() override { return; } + bool HasNext() const override { return true; } + private: float min_; float max_; diff --git a/paddle/fluid/recordio/chunk.cc b/paddle/fluid/recordio/chunk.cc index 13d059f844..187a6a4ea7 100644 --- a/paddle/fluid/recordio/chunk.cc +++ b/paddle/fluid/recordio/chunk.cc @@ -146,6 +146,7 @@ bool Chunk::Parse(std::istream& sin) { std::string buf; buf.resize(rec_len); stream.read(&buf[0], rec_len); + PADDLE_ENFORCE_EQ(rec_len, stream.gcount()); Add(buf); } return true; diff --git a/paddle/fluid/recordio/scanner.cc b/paddle/fluid/recordio/scanner.cc index 7f19c46e7e..d842f8fe5a 100644 --- a/paddle/fluid/recordio/scanner.cc +++ b/paddle/fluid/recordio/scanner.cc @@ -32,9 +32,9 @@ void Scanner::Reset() { ParseNextChunk(); } -const std::string &Scanner::Next() { +std::string Scanner::Next() { PADDLE_ENFORCE(!eof_, "StopIteration"); - auto &rec = cur_chunk_.Record(offset_++); + auto rec = cur_chunk_.Record(offset_++); if (offset_ == cur_chunk_.NumRecords()) { ParseNextChunk(); } diff --git a/paddle/fluid/recordio/scanner.h b/paddle/fluid/recordio/scanner.h index 3073d0c5c8..f3f17b69f1 100644 --- a/paddle/fluid/recordio/scanner.h +++ b/paddle/fluid/recordio/scanner.h @@ -28,7 +28,7 @@ public: void Reset(); - const std::string& Next(); + std::string Next(); bool HasNext() const; diff --git a/paddle/fluid/recordio/writer_scanner_test.cc b/paddle/fluid/recordio/writer_scanner_test.cc index a14d3bc3b2..7e764f0d94 100644 --- a/paddle/fluid/recordio/writer_scanner_test.cc +++ b/paddle/fluid/recordio/writer_scanner_test.cc @@ -41,4 +41,29 @@ TEST(WriterScanner, Normal) { ASSERT_EQ("CDE", scanner.Next()); ASSERT_FALSE(scanner.HasNext()); } +} + +TEST(WriterScanner, TinyChunk) { + std::stringstream* stream = new std::stringstream(); + { + paddle::recordio::Writer writer( + stream, paddle::recordio::Compressor::kNoCompress, 2 /*max chunk num*/); + writer.Write("ABC"); + writer.Write("BCD"); + writer.Write("CDE"); + writer.Write("DEFG"); + writer.Flush(); + } + + { + stream->seekg(0, std::ios::beg); + std::unique_ptr stream_ptr(stream); + paddle::recordio::Scanner scanner(std::move(stream_ptr)); + ASSERT_TRUE(scanner.HasNext()); + ASSERT_EQ(scanner.Next(), "ABC"); + ASSERT_EQ(scanner.Next(), "BCD"); + ASSERT_EQ(scanner.Next(), "CDE"); + ASSERT_EQ(scanner.Next(), "DEFG"); + ASSERT_FALSE(scanner.HasNext()); + } } \ No newline at end of file diff --git a/python/paddle/fluid/tests/unittests/test_recordio_reader.py b/python/paddle/fluid/tests/unittests/test_recordio_reader.py index 7844d46320..d249742bd3 100644 --- a/python/paddle/fluid/tests/unittests/test_recordio_reader.py +++ b/python/paddle/fluid/tests/unittests/test_recordio_reader.py @@ -55,15 +55,10 @@ class TestRecordIO(unittest.TestCase): exe.run(fluid.default_startup_program()) avg_loss_np = [] - for i in xrange(2): # 2 pass - batch_id = 0 - while not data_file.eof(): - try: - batch_id += 1 - tmp, = exe.run(fetch_list=[avg_loss]) - avg_loss_np.append(tmp) - except: - print batch_id - break - data_file.reset() + # train a pass + while not data_file.eof(): + tmp, = exe.run(fetch_list=[avg_loss]) + avg_loss_np.append(tmp) + data_file.reset() + self.assertLess(avg_loss_np[-1], avg_loss_np[0]) -- GitLab