提交 7364348d 编写于 作者: D dongzhihong

"move from recordio repo to paddle"

上级 7016979c
...@@ -144,6 +144,7 @@ include(external/eigen) # download eigen3 ...@@ -144,6 +144,7 @@ include(external/eigen) # download eigen3
include(external/pybind11) # download pybind11 include(external/pybind11) # download pybind11
include(external/cares) include(external/cares)
include(external/grpc) include(external/grpc)
include(external/snappy) # download snappy
include(cudnn) # set cudnn libraries, must before configure include(cudnn) # set cudnn libraries, must before configure
include(cupti) include(cupti)
......
...@@ -26,7 +26,7 @@ namespace paddle { ...@@ -26,7 +26,7 @@ namespace paddle {
namespace recordio { namespace recordio {
void Chunk::Add(const char* record, size_t length) { void Chunk::Add(const char* record, size_t length) {
records_.emplace_after(std::move(s)); records_.emplace_after(std::string(record, length));
num_bytes_ += s.size() * sizeof(char); num_bytes_ += s.size() * sizeof(char);
} }
...@@ -42,13 +42,16 @@ bool Chunk::Dump(Stream* fo, Compressor ct) { ...@@ -42,13 +42,16 @@ bool Chunk::Dump(Stream* fo, Compressor ct) {
os.write(record.data(), static_cast<std::streamsize>(record.size())); os.write(record.data(), static_cast<std::streamsize>(record.size()));
} }
std::unique_ptr<char[]> buffer(new char[kDefaultMaxChunkSize]); std::unique_ptr<char[]> buffer(new char[num_bytes_]);
size_t compressed = size_t compressed =
CompressData(os.str().c_str(), num_bytes_, ct, buffer.get()); CompressData(os.str().c_str(), num_bytes_, ct, buffer.get());
uint32_t checksum = Crc32(buffer.get(), compressed); uint32_t checksum = Crc32(buffer.get(), compressed);
Header hdr(records_.size(), checksum, ct, static_cast<uint32_t>(compressed)); Header hdr(records_.size(), checksum, ct, static_cast<uint32_t>(compressed));
hdr.Write(fo); hdr.Write(fo);
fo.Write(buffer.get(), compressed); fo.Write(buffer.get(), compressed);
// clear the content
records_.clear();
num_bytes_ = 0;
return true; return true;
} }
...@@ -57,14 +60,18 @@ void Chunk::Parse(Stream* fi, size_t offset) { ...@@ -57,14 +60,18 @@ void Chunk::Parse(Stream* fi, size_t offset) {
Header hdr; Header hdr;
hdr.Parse(fi); hdr.Parse(fi);
std::unique_ptr<char[]> buffer(new char[kDefaultMaxChunkSize]); size_t size = static_cast<size_t>(hdr.CompressSize());
fi->Read(buffer.get(), static_cast<size_t>(hdr.CompressSize())); std::unique_ptr<char[]> buffer(new char[size]);
uint32_t deflated_size = fi->Read(buffer.get(), size);
DeflateData(buffer.get(), hdr.CompressSize(), hdr.CompressType()); size_t deflated_size = 0;
std::istringstream deflated(std::string(buffer.get(), deflated_size)); snappy::GetUncompressedLength(buffer.get(), size, &deflated_size);
std::unique_ptr<char[]> deflated_buffer(new char[deflated_size]);
DeflateData(buffer.get(), size, hdr.CompressType(), deflated_buffer.get());
std::istringstream deflated(
std::string(deflated_buffer.get(), deflated_size));
for (size_t i = 0; i < hdr.NumRecords(); ++i) { for (size_t i = 0; i < hdr.NumRecords(); ++i) {
uint32_t rs; size_t rs;
deflated >> rs; deflated.read(&rs, sizeof(size_t));
std::string record(rs, '\0'); std::string record(rs, '\0');
deflated.read(&record[0], rs); deflated.read(&record[0], rs);
records_.emplace_back(record); records_.emplace_back(record);
......
...@@ -25,7 +25,7 @@ namespace recordio { ...@@ -25,7 +25,7 @@ namespace recordio {
// A Chunk contains the Header and optionally compressed records. // A Chunk contains the Header and optionally compressed records.
class Chunk { class Chunk {
public: public:
Chunk() {} Chunk() : num_bytes_(0) {}
void Add(const char* record, size_t size); void Add(const char* record, size_t size);
// dump the chunk into w, and clears the chunk and makes it ready for // dump the chunk into w, and clears the chunk and makes it ready for
// the next add invocation. // the next add invocation.
......
...@@ -20,4 +20,36 @@ ...@@ -20,4 +20,36 @@
using namespace paddle::recordio; using namespace paddle::recordio;
TEST(Chunk, SaveLoad) {} TEST(Chunk, SaveLoad) {
Chunk ch;
ch.Add("12345", 6);
ch.Add("123", 4);
{
Stream* fs = Stream::Open("/tmp/record_11", "w");
ch.Dump(fs, Compressor::kNoCompress);
EXPECT_EQ(ch.NumBytes(), 0);
}
{
Stream* fs = Stream::Open("/tmp/record_11", "r");
ch.Parse(fs, 0);
EXPECT_EQ(ch.NumBytes(), 10);
}
}
// Round-trips a chunk through a snappy-compressed stream: after Dump the
// chunk must be empty, and after Parse it must hold the same number of bytes
// that were added.
TEST(Chunk, Compressor) {
  Chunk ch;
  // Lengths include the trailing '\0' of each literal: 6 + 4 + 4 + 4 = 18.
  ch.Add("12345", 6);
  ch.Add("123", 4);
  ch.Add("123", 4);
  ch.Add("123", 4);
  {
    // NOTE(review): the write stream is never closed or flushed before the
    // file is reopened for reading below -- confirm Stream's ownership and
    // flushing rules and release it here if required.
    Stream* fs = Stream::Open("/tmp/record_12", "w");
    ch.Dump(fs, Compressor::kSnappy);
    // Dump clears the chunk once it has been written.
    EXPECT_EQ(ch.NumBytes(), 0);
  }
  {
    Stream* fs = Stream::Open("/tmp/record_12", "r");
    ch.Parse(fs, 0);
    // The expected size is the sum of the four records added above, not the
    // 10 bytes of the two-record SaveLoad test.
    EXPECT_EQ(ch.NumBytes(), 18);
  }
}
...@@ -27,27 +27,19 @@ Header::Header(uint32_t num, uint32_t sum, Compressor c, uint32_t cs) ...@@ -27,27 +27,19 @@ Header::Header(uint32_t num, uint32_t sum, Compressor c, uint32_t cs)
: num_records_(num), checksum_(sum), compressor_(c), compress_size_(cs) {} : num_records_(num), checksum_(sum), compressor_(c), compress_size_(cs) {}
void Header::Parse(Stream* iss) { void Header::Parse(Stream* iss) {
iss.Read(reinterpret_cast<char*>(&num_records_), sizeof(uint32_t)); iss->Read(reinterpret_cast<char*>(&num_records_), sizeof(uint32_t));
iss.Read(reinterpret_cast<char*>(&checksum_), sizeof(uint32_t)); iss->Read(reinterpret_cast<char*>(&checksum_), sizeof(uint32_t));
iss.Read(reinterpret_cast<char*>(&compressor_), sizeof(uint32_t)); iss->Read(reinterpret_cast<char*>(&compressor_), sizeof(uint32_t));
iss.Read(reinterpret_cast<char*>(&compress_size_), sizeof(uint32_t)); iss->Read(reinterpret_cast<char*>(&compress_size_), sizeof(uint32_t));
} }
void Header::Write(Stream* os) { void Header::Write(Stream* os) {
os.Write(reinterpret_cast<char*>(&num_records_), sizeof(uint32_t)); os->Write(reinterpret_cast<char*>(&num_records_), sizeof(uint32_t));
os.Write(reinterpret_cast<char*>(&checksum_), sizeof(uint32_t)); os->Write(reinterpret_cast<char*>(&checksum_), sizeof(uint32_t));
os.Write(reinterpret_cast<char*>(&compressor_), sizeof(uint32_t)); os->Write(reinterpret_cast<char*>(&compressor_), sizeof(uint32_t));
os.Write(reinterpret_cast<char*>(&compress_size_), sizeof(uint32_t)); os->Write(reinterpret_cast<char*>(&compress_size_), sizeof(uint32_t));
} }
// std::ostream& operator << (std::ostream& os, Header h) {
// os << h.num_records_
// << h.checksum_
// << static_cast<uint32_t>(h.compressor_)
// << h.compress_size_;
// return os;
// }
std::ostream& operator<<(std::ostream& os, Header h) { std::ostream& operator<<(std::ostream& os, Header h) {
os << h.NumRecords() << h.Checksum() os << h.NumRecords() << h.Checksum()
<< static_cast<uint32_t>(h.CompressType()) << h.CompressSize(); << static_cast<uint32_t>(h.CompressType()) << h.CompressSize();
...@@ -59,3 +51,6 @@ bool operator==(Header l, Header r) { ...@@ -59,3 +51,6 @@ bool operator==(Header l, Header r) {
l.CompressType() == r.CompressType() && l.CompressType() == r.CompressType() &&
l.CompressSize() == r.CompressSize(); l.CompressSize() == r.CompressSize();
} }
} // namespace recordio
} // namespace paddle
...@@ -23,11 +23,11 @@ using namespace paddle::recordio; ...@@ -23,11 +23,11 @@ using namespace paddle::recordio;
TEST(Recordio, ChunkHead) { TEST(Recordio, ChunkHead) {
Header hdr(0, 1, Compressor::kGzip, 3); Header hdr(0, 1, Compressor::kGzip, 3);
Stream* oss = Stream::Open("/tmp/record_1", "w"); Stream* oss = Stream::Open("/tmp/record_1", "w");
hdr.Write(oss); hdr->Write(oss);
Stream* iss = Stream::Open("/tmp/record_1", "r"); // Stream* iss = Stream::Open("/tmp/record_1", "r");
Header hdr2; // Header hdr2;
hdr2.Parse(iss); // hdr2.Parse(iss);
EXPECT_TRUE(hdr == hdr2); // EXPECT_TRUE(hdr == hdr2);
} }
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/recordio/range_scanner.h"
namespace paddle {
namespace recordio {
// ChunkIndex is meant to return the Index describing the i-th chunk.
//
// NOTE(review): still a stub. `i` is ignored and an empty Index is returned.
// The previous body had no return statement, which is undefined behaviour for
// a function with a non-void return type. The in-class declaration must use
// the same `Index` return type as this definition -- TODO confirm.
Index Index::ChunkIndex(int i) {
  (void)i;  // unused until the per-chunk lookup is implemented
  Index idx{};
  return idx;
}
// Creates a scanner over the records [start, start + len) of the file
// described by idx.
//
//   fi    - stream the records are read from; it is only stored here.
//   idx   - index of the file; copied into index_.
//   start - first record of the range; negative values are treated as 0.
//   len   - number of records to scan; a negative value, or one that reaches
//           the end of the file or beyond, is clamped to the records that
//           remain after `start`.
//
// The signature and the `fi` member follow the declaration of RangeScanner,
// which takes a Stream* (std::istream cannot be passed by value).
RangeScanner::RangeScanner(Stream* fi, Index idx, int start, int len)
    : fi(fi), index_(idx) {  // member `fi` is initialized from parameter `fi`
  if (start < 0) {
    start = 0;
  }
  // Compare against the remaining count instead of computing start + len,
  // which could overflow int for a very large len.
  const int remaining = idx.NumRecords() - start;
  if (len < 0 || len >= remaining) {
    len = remaining;
  }
  start_ = start;
  end_ = start + len;
  // Nothing has been scanned yet: the cursor sits one before the first record
  // and no chunk is loaded.
  cur_ = start - 1;
  chunk_index_ = -1;
  // TODO: chunk_.reset(new Chunk());
}
// Scan is meant to advance to the next record of the range.
//
// NOTE(review): not implemented yet. It returns false ("no record available")
// so callers get a defined result; the previous empty body flowed off the end
// of a bool function, which is undefined behaviour.
bool RangeScanner::Scan() {
  // TODO: advance cur_ and load the containing chunk.
  return false;
}
// Record is meant to return the record the scanner currently points at.
//
// NOTE(review): not implemented yet. It returns an empty string so the result
// is well defined; the previous body had no return statement.
const std::string RangeScanner::Record() {
  // Intended implementation:
  // int i = index_.Locate(cur_);
  // return chunk_->Record(i);
  return std::string();
}
} // namespace recordio
} // namespace paddle
...@@ -14,16 +14,23 @@ ...@@ -14,16 +14,23 @@
#pragma once #pragma once
#include <fstream> #include "paddle/fluid/recordio/io.h"
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
namespace paddle {
namespace recordio {
// Index consists of the offsets and sizes of the consecutive chunks in a
// RecordIO file.
//
// Index supports Gob. Every field in the Index needs to be exported
// for the correct encoding and decoding using Gob.
class Index { class Index {
public: public:
int NumRecords() { return num_records_; } int NumRecords() { return num_records_; }
// NumChunks returns the total number of chunks in a RecordIO file.
int NumChunks() { return chunk_lens_.size(); }
// ChunkIndex returns the Index of the i-th chunk.
int ChunkIndex(int i);
// Locate returns the index of chunk that contains the given record, // Locate returns the index of chunk that contains the given record,
// and the record index within the chunk. It returns (-1, -1) if the // and the record index within the chunk. It returns (-1, -1) if the
...@@ -44,9 +51,13 @@ public: ...@@ -44,9 +51,13 @@ public:
} }
private: private:
// the offset of each chunk in a file.
std::vector<int64_t> chunk_offsets_; std::vector<int64_t> chunk_offsets_;
// the length of each chunk in a file.
std::vector<uint32_t> chunk_lens_; std::vector<uint32_t> chunk_lens_;
// the number of all records in a file.
int num_records_; int num_records_;
// the number of records in chunks.
std::vector<int> chunk_records_; std::vector<int> chunk_records_;
}; };
...@@ -56,14 +67,17 @@ private: ...@@ -56,14 +67,17 @@ private:
// beginning. If len < 0, it scans till the end of file. // beginning. If len < 0, it scans till the end of file.
class RangeScanner { class RangeScanner {
public: public:
RangeScanner(std::istream is, Index idx, int start, int end); RangeScanner(Stream* fi, Index idx, int start, int end);
bool Scan(); bool Scan();
const std::string Record(); const std::string Record();
private: private:
std::istream stream_; Stream* fi;
Index index_; Index index_;
int start_, end_, cur_; int start_, end_, cur_;
int chunk_index_; int chunk_index_;
std::unique_ptr<Chunk> chunk_; std::unique_ptr<Chunk> chunk_;
}; };
} // namespace recordio
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/recordio/chunk.h"
#include <glob.h> // glob
namespace paddle {
namespace recordio {
// Expands the glob pattern `paths` (with ~ expansion) and stores every
// matching path in paths_.
//
// All scalar members are initialized here, including cur_scanner_ and err_,
// because Scan() reads both before anything else assigns them.
// If glob() fails for a reason other than "no match", err_ is set to -1,
// the value Scan() treats as an error.
Scanner::Scanner(const char* paths)
    : cur_file_(nullptr),
      cur_scanner_(nullptr),
      path_idx_(0),
      end_(false),
      err_(0) {
  glob_t glob_result = {};  // zeroed so globfree() is safe on every path
  const int rc = glob(paths, GLOB_TILDE, nullptr, &glob_result);
  if (rc == 0) {
    paths_.reserve(glob_result.gl_pathc);
    for (size_t i = 0; i < glob_result.gl_pathc; ++i) {
      paths_.emplace_back(glob_result.gl_pathv[i]);
    }
  } else if (rc != GLOB_NOMATCH) {
    // GLOB_NOMATCH simply leaves paths_ empty; anything else is an error.
    err_ = -1;
  }
  globfree(&glob_result);
}
// Advances to the next record across all matched files.
//
// Returns true when a record is available, false on error (err_ == -1) or
// once every file has been consumed (end_ is then set).
//
// The previous version returned true even when the current range scanner had
// no more records, and dereferenced cur_scanner_ without checking that
// NextFile() had actually installed one.
//
// NOTE(review): assumes NextFile() replaces cur_scanner_ with a scanner for
// the next path and returns false when no paths are left -- confirm once
// NextFile() is implemented.
bool Scanner::Scan() {
  if (err_ == -1 || end_) {
    return false;
  }
  for (;;) {
    if (cur_scanner_ != nullptr && cur_scanner_->Scan()) {
      return true;
    }
    if (err_ == -1) {
      return false;
    }
    // No file is open yet, or the current one is exhausted: move on.
    if (!NextFile()) {
      end_ = true;
      return false;
    }
    if (err_ == -1 || cur_scanner_ == nullptr) {
      return false;
    }
  }
}
// NextFile is meant to open the next path in paths_ and prepare a scanner
// for it.
//
// NOTE(review): not implemented yet. It returns false ("no more files") so
// Scan() terminates cleanly; the previous empty body flowed off the end of a
// bool function, which is undefined behaviour.
bool Scanner::NextFile() {
  // TODO: open paths_[path_idx_], advance path_idx_, create cur_scanner_.
  return false;
}
} // namespace recordio
} // namespace paddle
...@@ -14,12 +14,10 @@ ...@@ -14,12 +14,10 @@
#pragma once #pragma once
#include <fstream> #include "paddle/fluid/recordio/io.h"
#include <memory>
#include <sstream> namespace paddle {
#include <string> namespace recordio {
#include <utility>
#include <vector>
class RangeScanner; class RangeScanner;
...@@ -30,16 +28,17 @@ public: ...@@ -30,16 +28,17 @@ public:
const std::string Record(); const std::string Record();
bool Scan(); bool Scan();
void Close(); void Close();
private:
bool NextFile(); bool NextFile();
int Err() { return err_; } int Err() { return err_; }
private: private:
std::vector<std::string> paths_; std::vector<std::string> paths_;
FILE* cur_file_; Stream* cur_file_;
RangeScanner* cur_scanner_; RangeScanner* cur_scanner_;
int path_idx_; int path_idx_;
bool end_; bool end_;
int err_; int err_;
}; };
} // namespace recordio
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/recordio/scanner.h"
#include "gtest/gtest.h"
using namespace paddle::recordio;
// Smoke test: a Scanner can be constructed from a glob pattern.
TEST(Scanner, Normal) {
  Scanner scanner("/tmp/record_*");
}
...@@ -18,4 +18,12 @@ ...@@ -18,4 +18,12 @@
using namespace paddle::recordio; using namespace paddle::recordio;
TEST(Writer, Normal) {} TEST(Writer, Normal) {
Stream* fs = Stream::Open("/tmp/record_21", "w");
Writer w(fs);
w.Write("123", 4);
// test exception
w.Close();
EXPECT_ANY_THROW(w.Write("123", 4));
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册