提交 15193c9e 编写于 作者: Y yuyang18

Faster RecordIO Scanner

上级 86efecb9
......@@ -119,40 +119,56 @@ bool Chunk::Write(std::ostream& os, Compressor ct) const {
}
bool Chunk::Parse(std::istream& sin) {
Header hdr;
bool ok = hdr.Parse(sin);
ChunkParser parser(sin);
if (!parser.Init()) {
return false;
}
Clear();
while (parser.HasNext()) {
Add(parser.Next());
}
return true;
}
ChunkParser::ChunkParser(std::istream& sin) : in_(sin) {}
bool ChunkParser::Init() {
pos_ = 0;
bool ok = header_.Parse(in_);
if (!ok) {
return ok;
}
auto beg_pos = sin.tellg();
uint32_t crc = Crc32Stream(sin, hdr.CompressSize());
PADDLE_ENFORCE_EQ(hdr.Checksum(), crc);
Clear();
sin.seekg(beg_pos, sin.beg);
std::unique_ptr<std::istream> compressed_stream;
switch (hdr.CompressType()) {
auto beg_pos = in_.tellg();
uint32_t crc = Crc32Stream(in_, header_.CompressSize());
PADDLE_ENFORCE_EQ(header_.Checksum(), crc);
in_.seekg(beg_pos, in_.beg);
switch (header_.CompressType()) {
case Compressor::kNoCompress:
break;
case Compressor::kSnappy:
compressed_stream.reset(new snappy::iSnappyStream(sin));
compressed_stream_.reset(new snappy::iSnappyStream(in_));
break;
default:
PADDLE_THROW("Not implemented");
}
return true;
}
std::istream& stream = compressed_stream ? *compressed_stream : sin;
bool ChunkParser::HasNext() const { return pos_ < header_.NumRecords(); }
for (uint32_t i = 0; i < hdr.NumRecords(); ++i) {
std::string ChunkParser::Next() {
if (!HasNext()) {
return "";
}
++pos_;
std::istream& stream = compressed_stream_ ? *compressed_stream_ : in_;
uint32_t rec_len;
stream.read(reinterpret_cast<char*>(&rec_len), sizeof(uint32_t));
std::string buf;
buf.resize(rec_len);
stream.read(&buf[0], rec_len);
PADDLE_ENFORCE_EQ(rec_len, stream.gcount());
Add(buf);
}
return true;
return buf;
}
} // namespace recordio
} // namespace paddle
......@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
......@@ -53,9 +54,20 @@ class Chunk {
DISABLE_COPY_AND_ASSIGN(Chunk);
};
size_t CompressData(const char* in, size_t in_length, Compressor ct, char* out);
class ChunkParser {
public:
explicit ChunkParser(std::istream& sin);
bool Init();
std::string Next();
bool HasNext() const;
void DeflateData(const char* in, size_t in_length, Compressor ct, char* out);
private:
Header header_;
uint32_t pos_{0};
std::istream& in_;
std::unique_ptr<std::istream> compressed_stream_;
};
} // namespace recordio
} // namespace paddle
......@@ -22,35 +22,33 @@ namespace paddle {
namespace recordio {
Scanner::Scanner(std::unique_ptr<std::istream> &&stream)
: stream_(std::move(stream)) {
: stream_(std::move(stream)), parser_(*stream_) {
Reset();
}
Scanner::Scanner(const std::string &filename) {
stream_.reset(new std::ifstream(filename));
Scanner::Scanner(const std::string &filename)
: stream_(new std::ifstream(filename)), parser_(*stream_) {
Reset();
}
void Scanner::Reset() {
stream_->clear();
stream_->seekg(0, std::ios::beg);
ParseNextChunk();
parser_.Init();
}
std::string Scanner::Next() {
PADDLE_ENFORCE(!eof_, "StopIteration");
auto rec = cur_chunk_.Record(offset_++);
if (offset_ == cur_chunk_.NumRecords()) {
ParseNextChunk();
if (stream_->eof()) {
return "";
}
return rec;
}
void Scanner::ParseNextChunk() {
eof_ = !cur_chunk_.Parse(*stream_);
offset_ = 0;
auto res = parser_.Next();
if (!parser_.HasNext() && HasNext()) {
parser_.Init();
}
return res;
}
bool Scanner::HasNext() const { return !eof_; }
bool Scanner::HasNext() const { return !stream_->eof(); }
} // namespace recordio
} // namespace paddle
......@@ -37,11 +37,7 @@ class Scanner {
private:
std::unique_ptr<std::istream> stream_;
Chunk cur_chunk_;
size_t offset_;
bool eof_;
void ParseNextChunk();
ChunkParser parser_;
};
} // namespace recordio
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册