提交 15193c9e 编写于 作者: Y yuyang18

Faster RecordIO Scanner

上级 86efecb9
...@@ -119,40 +119,56 @@ bool Chunk::Write(std::ostream& os, Compressor ct) const { ...@@ -119,40 +119,56 @@ bool Chunk::Write(std::ostream& os, Compressor ct) const {
} }
bool Chunk::Parse(std::istream& sin) { bool Chunk::Parse(std::istream& sin) {
Header hdr; ChunkParser parser(sin);
bool ok = hdr.Parse(sin); if (!parser.Init()) {
return false;
}
Clear();
while (parser.HasNext()) {
Add(parser.Next());
}
return true;
}
ChunkParser::ChunkParser(std::istream& sin) : in_(sin) {}
bool ChunkParser::Init() {
pos_ = 0;
bool ok = header_.Parse(in_);
if (!ok) { if (!ok) {
return ok; return ok;
} }
auto beg_pos = sin.tellg(); auto beg_pos = in_.tellg();
uint32_t crc = Crc32Stream(sin, hdr.CompressSize()); uint32_t crc = Crc32Stream(in_, header_.CompressSize());
PADDLE_ENFORCE_EQ(hdr.Checksum(), crc); PADDLE_ENFORCE_EQ(header_.Checksum(), crc);
Clear(); in_.seekg(beg_pos, in_.beg);
sin.seekg(beg_pos, sin.beg);
std::unique_ptr<std::istream> compressed_stream; switch (header_.CompressType()) {
switch (hdr.CompressType()) {
case Compressor::kNoCompress: case Compressor::kNoCompress:
break; break;
case Compressor::kSnappy: case Compressor::kSnappy:
compressed_stream.reset(new snappy::iSnappyStream(sin)); compressed_stream_.reset(new snappy::iSnappyStream(in_));
break; break;
default: default:
PADDLE_THROW("Not implemented"); PADDLE_THROW("Not implemented");
} }
return true;
}
std::istream& stream = compressed_stream ? *compressed_stream : sin; bool ChunkParser::HasNext() const { return pos_ < header_.NumRecords(); }
for (uint32_t i = 0; i < hdr.NumRecords(); ++i) { std::string ChunkParser::Next() {
uint32_t rec_len; if (!HasNext()) {
stream.read(reinterpret_cast<char*>(&rec_len), sizeof(uint32_t)); return "";
std::string buf;
buf.resize(rec_len);
stream.read(&buf[0], rec_len);
PADDLE_ENFORCE_EQ(rec_len, stream.gcount());
Add(buf);
} }
return true; ++pos_;
std::istream& stream = compressed_stream_ ? *compressed_stream_ : in_;
uint32_t rec_len;
stream.read(reinterpret_cast<char*>(&rec_len), sizeof(uint32_t));
std::string buf;
buf.resize(rec_len);
stream.read(&buf[0], rec_len);
PADDLE_ENFORCE_EQ(rec_len, stream.gcount());
return buf;
} }
} // namespace recordio } // namespace recordio
} // namespace paddle } // namespace paddle
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -53,9 +54,20 @@ class Chunk { ...@@ -53,9 +54,20 @@ class Chunk {
DISABLE_COPY_AND_ASSIGN(Chunk); DISABLE_COPY_AND_ASSIGN(Chunk);
}; };
size_t CompressData(const char* in, size_t in_length, Compressor ct, char* out); class ChunkParser {
public:
explicit ChunkParser(std::istream& sin);
bool Init();
std::string Next();
bool HasNext() const;
void DeflateData(const char* in, size_t in_length, Compressor ct, char* out); private:
Header header_;
uint32_t pos_{0};
std::istream& in_;
std::unique_ptr<std::istream> compressed_stream_;
};
} // namespace recordio } // namespace recordio
} // namespace paddle } // namespace paddle
...@@ -22,35 +22,33 @@ namespace paddle { ...@@ -22,35 +22,33 @@ namespace paddle {
namespace recordio { namespace recordio {
Scanner::Scanner(std::unique_ptr<std::istream> &&stream) Scanner::Scanner(std::unique_ptr<std::istream> &&stream)
: stream_(std::move(stream)) { : stream_(std::move(stream)), parser_(*stream_) {
Reset(); Reset();
} }
Scanner::Scanner(const std::string &filename) { Scanner::Scanner(const std::string &filename)
stream_.reset(new std::ifstream(filename)); : stream_(new std::ifstream(filename)), parser_(*stream_) {
Reset(); Reset();
} }
void Scanner::Reset() { void Scanner::Reset() {
stream_->clear(); stream_->clear();
stream_->seekg(0, std::ios::beg); stream_->seekg(0, std::ios::beg);
ParseNextChunk(); parser_.Init();
} }
std::string Scanner::Next() { std::string Scanner::Next() {
PADDLE_ENFORCE(!eof_, "StopIteration"); if (stream_->eof()) {
auto rec = cur_chunk_.Record(offset_++); return "";
if (offset_ == cur_chunk_.NumRecords()) {
ParseNextChunk();
} }
return rec;
}
void Scanner::ParseNextChunk() { auto res = parser_.Next();
eof_ = !cur_chunk_.Parse(*stream_); if (!parser_.HasNext() && HasNext()) {
offset_ = 0; parser_.Init();
}
return res;
} }
bool Scanner::HasNext() const { return !eof_; } bool Scanner::HasNext() const { return !stream_->eof(); }
} // namespace recordio } // namespace recordio
} // namespace paddle } // namespace paddle
...@@ -37,11 +37,7 @@ class Scanner { ...@@ -37,11 +37,7 @@ class Scanner {
private: private:
std::unique_ptr<std::istream> stream_; std::unique_ptr<std::istream> stream_;
Chunk cur_chunk_; ChunkParser parser_;
size_t offset_;
bool eof_;
void ParseNextChunk();
}; };
} // namespace recordio } // namespace recordio
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册