diff --git a/paddle/fluid/recordio/chunk.cc b/paddle/fluid/recordio/chunk.cc index 82d9aa601cf450b8f90573d6c582bb12ced7a48a..6c65d9160c059ac143ee258b2bdaed5915a1dca1 100644 --- a/paddle/fluid/recordio/chunk.cc +++ b/paddle/fluid/recordio/chunk.cc @@ -119,40 +119,56 @@ bool Chunk::Write(std::ostream& os, Compressor ct) const { } bool Chunk::Parse(std::istream& sin) { - Header hdr; - bool ok = hdr.Parse(sin); + ChunkParser parser(sin); + if (!parser.Init()) { + return false; + } + Clear(); + while (parser.HasNext()) { + Add(parser.Next()); + } + return true; +} + +ChunkParser::ChunkParser(std::istream& sin) : in_(sin) {} +bool ChunkParser::Init() { + pos_ = 0; + bool ok = header_.Parse(in_); if (!ok) { return ok; } - auto beg_pos = sin.tellg(); - uint32_t crc = Crc32Stream(sin, hdr.CompressSize()); - PADDLE_ENFORCE_EQ(hdr.Checksum(), crc); - Clear(); - sin.seekg(beg_pos, sin.beg); - std::unique_ptr compressed_stream; - switch (hdr.CompressType()) { + auto beg_pos = in_.tellg(); + uint32_t crc = Crc32Stream(in_, header_.CompressSize()); + PADDLE_ENFORCE_EQ(header_.Checksum(), crc); + in_.seekg(beg_pos, in_.beg); + + switch (header_.CompressType()) { case Compressor::kNoCompress: break; case Compressor::kSnappy: - compressed_stream.reset(new snappy::iSnappyStream(sin)); + compressed_stream_.reset(new snappy::iSnappyStream(in_)); break; default: PADDLE_THROW("Not implemented"); } + return true; +} - std::istream& stream = compressed_stream ? *compressed_stream : sin; +bool ChunkParser::HasNext() const { return pos_ < header_.NumRecords(); } - for (uint32_t i = 0; i < hdr.NumRecords(); ++i) { - uint32_t rec_len; - stream.read(reinterpret_cast(&rec_len), sizeof(uint32_t)); - std::string buf; - buf.resize(rec_len); - stream.read(&buf[0], rec_len); - PADDLE_ENFORCE_EQ(rec_len, stream.gcount()); - Add(buf); +std::string ChunkParser::Next() { + if (!HasNext()) { + return ""; } - return true; + ++pos_; + std::istream& stream = compressed_stream_ ? *compressed_stream_ : in_; + uint32_t rec_len; + stream.read(reinterpret_cast(&rec_len), sizeof(uint32_t)); + std::string buf; + buf.resize(rec_len); + stream.read(&buf[0], rec_len); + PADDLE_ENFORCE_EQ(rec_len, stream.gcount()); + return buf; } - } // namespace recordio } // namespace paddle diff --git a/paddle/fluid/recordio/chunk.h b/paddle/fluid/recordio/chunk.h index 71a1556a33bfa5c937d6a799d2818cd5a5ef2094..cfb954a591679c2d2c4f42ecd99ca0c8bd1084cf 100644 --- a/paddle/fluid/recordio/chunk.h +++ b/paddle/fluid/recordio/chunk.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include #include #include @@ -53,9 +54,20 @@ class Chunk { DISABLE_COPY_AND_ASSIGN(Chunk); }; -size_t CompressData(const char* in, size_t in_length, Compressor ct, char* out); +class ChunkParser { + public: + explicit ChunkParser(std::istream& sin); + + bool Init(); + std::string Next(); + bool HasNext() const; -void DeflateData(const char* in, size_t in_length, Compressor ct, char* out); + private: + Header header_; + uint32_t pos_{0}; + std::istream& in_; + std::unique_ptr compressed_stream_; +}; } // namespace recordio } // namespace paddle diff --git a/paddle/fluid/recordio/scanner.cc b/paddle/fluid/recordio/scanner.cc index 88b4d4001bc1b6dc935a9aabc2db5edfb55a60e4..06a13e6c5b6ea76456e231e3f7b1eb33492b16ea 100644 --- a/paddle/fluid/recordio/scanner.cc +++ b/paddle/fluid/recordio/scanner.cc @@ -22,35 +22,33 @@ namespace paddle { namespace recordio { Scanner::Scanner(std::unique_ptr &&stream) - : stream_(std::move(stream)) { + : stream_(std::move(stream)), parser_(*stream_) { Reset(); } -Scanner::Scanner(const std::string &filename) { - stream_.reset(new std::ifstream(filename)); +Scanner::Scanner(const std::string &filename) + : stream_(new std::ifstream(filename)), parser_(*stream_) { Reset(); } void Scanner::Reset() { stream_->clear(); stream_->seekg(0, std::ios::beg); - ParseNextChunk(); + parser_.Init(); } std::string Scanner::Next() { - PADDLE_ENFORCE(!eof_, "StopIteration"); - auto rec = cur_chunk_.Record(offset_++); - if (offset_ == cur_chunk_.NumRecords()) { - ParseNextChunk(); + if (stream_->eof()) { + return ""; } - return rec; -} -void Scanner::ParseNextChunk() { - eof_ = !cur_chunk_.Parse(*stream_); - offset_ = 0; + auto res = parser_.Next(); + if (!parser_.HasNext() && HasNext()) { + parser_.Init(); + } + return res; } -bool Scanner::HasNext() const { return !eof_; } +bool Scanner::HasNext() const { return !stream_->eof(); } } // namespace recordio } // namespace paddle diff --git a/paddle/fluid/recordio/scanner.h b/paddle/fluid/recordio/scanner.h index 34f1b0c78d6b5af6072a993579e1866d38c6d009..0d885dd87a2f819ba1d9f76259196f6cfff0b2a0 100644 --- a/paddle/fluid/recordio/scanner.h +++ b/paddle/fluid/recordio/scanner.h @@ -37,11 +37,7 @@ class Scanner { private: std::unique_ptr stream_; - Chunk cur_chunk_; - size_t offset_; - bool eof_; - - void ParseNextChunk(); + ChunkParser parser_; }; } // namespace recordio } // namespace paddle