提交 bcb80756 编写于 作者: Y Yu Yang

Add Writer/Scanner

Make vec<Tensor> can be serialized to RecordIO
上级 10343123
...@@ -21,7 +21,7 @@ endif() ...@@ -21,7 +21,7 @@ endif()
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
nv_test(mixed_vector_test SRCS mixed_vector_test.cu DEPS place paddle_memory device_context init) nv_test(mixed_vector_test SRCS mixed_vector_test.cu DEPS place paddle_memory device_context init)
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto) cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory) cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor init) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor init)
......
...@@ -19,6 +19,9 @@ limitations under the License. */ ...@@ -19,6 +19,9 @@ limitations under the License. */
#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h" #include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
#include <stdint.h> #include <stdint.h>
#include <string.h> #include <string.h>
#include <algorithm> #include <algorithm>
...@@ -291,6 +294,31 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor, ...@@ -291,6 +294,31 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
TensorFromStream(is, static_cast<Tensor *>(tensor), dev_ctx); TensorFromStream(is, static_cast<Tensor *>(tensor), dev_ctx);
} }
void WriteToRecordIO(recordio::Writer &writer,
const std::vector<LoDTensor> &tensor,
const platform::DeviceContext &dev_ctx) {
std::stringstream buffer;
size_t sz = tensor.size();
buffer.write(reinterpret_cast<const char *>(&sz), sizeof(uint32_t));
for (auto &each : tensor) {
SerializeToStream(buffer, each, dev_ctx);
}
writer.Write(buffer.str());
}
std::vector<LoDTensor> ReadFromRecordIO(
recordio::Scanner &scanner, const platform::DeviceContext &dev_ctx) {
std::istringstream sin(scanner.Next());
uint32_t sz;
sin.read(reinterpret_cast<char *>(&sz), sizeof(uint32_t));
std::vector<LoDTensor> result;
result.resize(sz);
for (uint32_t i = 0; i < sz; ++i) {
DeserializeFromStream(sin, &result[i], dev_ctx);
}
return result;
}
std::vector<LoDTensor> LoDTensor::SplitLoDTensor( std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
const std::vector<platform::Place> places) const { const std::vector<platform::Place> places) const {
check_memory_size(); check_memory_size();
......
...@@ -29,6 +29,12 @@ limitations under the License. */ ...@@ -29,6 +29,12 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
namespace paddle { namespace paddle {
namespace recordio {
class Writer;
class Scanner;
}
namespace framework { namespace framework {
/* /*
...@@ -209,5 +215,12 @@ void SerializeToStream(std::ostream& os, const LoDTensor& tensor, ...@@ -209,5 +215,12 @@ void SerializeToStream(std::ostream& os, const LoDTensor& tensor,
void DeserializeFromStream(std::istream& is, LoDTensor* tensor, void DeserializeFromStream(std::istream& is, LoDTensor* tensor,
const platform::DeviceContext& dev_ctx); const platform::DeviceContext& dev_ctx);
extern void WriteToRecordIO(recordio::Writer& writer,
const std::vector<LoDTensor>& tensor,
const platform::DeviceContext& dev_ctx);
extern std::vector<LoDTensor> ReadFromRecordIO(
recordio::Scanner& scanner, const platform::DeviceContext& dev_ctx);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
#include <glog/logging.h> #include <glog/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <algorithm> #include <algorithm>
...@@ -224,5 +227,43 @@ TEST(LoD, CheckAbsLoD) { ...@@ -224,5 +227,43 @@ TEST(LoD, CheckAbsLoD) {
abs_lod0.push_back(std::vector<size_t>({0})); abs_lod0.push_back(std::vector<size_t>({0}));
ASSERT_FALSE(CheckAbsLoD(abs_lod0)); ASSERT_FALSE(CheckAbsLoD(abs_lod0));
} }
TEST(LoDTensor, RecordIO) {
LoDTensor tensor;
int* tmp = tensor.mutable_data<int>(make_ddim({4, 5}), platform::CPUPlace());
for (int i = 0; i < 20; ++i) {
tmp[i] = i;
}
std::stringstream* stream = new std::stringstream();
auto& ctx =
*platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
{
recordio::Writer writer(stream, recordio::Compressor::kSnappy);
WriteToRecordIO(writer, {tensor, tensor}, ctx);
WriteToRecordIO(writer, {tensor, tensor}, ctx);
writer.Flush();
}
auto assert_tensor_ok = [](const LoDTensor& tensor) {
for (int i = 0; i < 20; ++i) {
ASSERT_EQ(tensor.data<int>()[i], i);
}
};
{
std::unique_ptr<std::istream> stream_ptr(stream);
recordio::Scanner scanner(std::move(stream_ptr));
auto tensors = ReadFromRecordIO(scanner, ctx);
ASSERT_EQ(tensors.size(), 2);
assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]);
tensors = ReadFromRecordIO(scanner, ctx);
ASSERT_EQ(tensors.size(), 2);
assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]);
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -3,4 +3,7 @@ cc_library(header SRCS header.cc) ...@@ -3,4 +3,7 @@ cc_library(header SRCS header.cc)
cc_test(header_test SRCS header_test.cc DEPS header) cc_test(header_test SRCS header_test.cc DEPS header)
cc_library(chunk SRCS chunk.cc DEPS snappystream snappy header zlib) cc_library(chunk SRCS chunk.cc DEPS snappystream snappy header zlib)
cc_test(chunk_test SRCS chunk_test.cc DEPS chunk) cc_test(chunk_test SRCS chunk_test.cc DEPS chunk)
cc_library(recordio DEPS chunk header) cc_library(writer SRCS writer.cc DEPS chunk)
cc_library(scanner SRCS scanner.cc DEPS chunk)
cc_test(writer_scanner_test SRCS writer_scanner_test.cc DEPS writer scanner)
cc_library(recordio DEPS chunk header writer scanner)
...@@ -97,9 +97,12 @@ bool Chunk::Write(std::ostream& os, Compressor ct) const { ...@@ -97,9 +97,12 @@ bool Chunk::Write(std::ostream& os, Compressor ct) const {
return true; return true;
} }
void Chunk::Parse(std::istream& sin) { bool Chunk::Parse(std::istream& sin) {
Header hdr; Header hdr;
hdr.Parse(sin); bool ok = hdr.Parse(sin);
if (!ok) {
return ok;
}
auto beg_pos = sin.tellg(); auto beg_pos = sin.tellg();
auto crc = Crc32Stream(sin, hdr.CompressSize()); auto crc = Crc32Stream(sin, hdr.CompressSize());
PADDLE_ENFORCE_EQ(hdr.Checksum(), crc); PADDLE_ENFORCE_EQ(hdr.Checksum(), crc);
...@@ -128,6 +131,7 @@ void Chunk::Parse(std::istream& sin) { ...@@ -128,6 +131,7 @@ void Chunk::Parse(std::istream& sin) {
stream.read(&buf[0], rec_len); stream.read(&buf[0], rec_len);
Add(buf); Add(buf);
} }
return true;
} }
} // namespace recordio } // namespace recordio
......
...@@ -26,9 +26,9 @@ namespace recordio { ...@@ -26,9 +26,9 @@ namespace recordio {
class Chunk { class Chunk {
public: public:
Chunk() : num_bytes_(0) {} Chunk() : num_bytes_(0) {}
void Add(std::string buf) { void Add(const std::string& buf) {
records_.push_back(buf);
num_bytes_ += buf.size(); num_bytes_ += buf.size();
records_.emplace_back(buf);
} }
// dump the chunk into w, and clears the chunk and makes it ready for // dump the chunk into w, and clears the chunk and makes it ready for
// the next add invocation. // the next add invocation.
...@@ -37,10 +37,15 @@ public: ...@@ -37,10 +37,15 @@ public:
records_.clear(); records_.clear();
num_bytes_ = 0; num_bytes_ = 0;
} }
void Parse(std::istream& sin);
size_t NumBytes() { return num_bytes_; } // returns true if ok, false if eof
bool Parse(std::istream& sin);
size_t NumBytes() const { return num_bytes_; }
size_t NumRecords() const { return records_.size(); }
const std::string& Record(int i) const { return records_[i]; } const std::string& Record(int i) const { return records_[i]; }
bool Empty() const { return records_.empty(); }
private: private:
std::vector<std::string> records_; std::vector<std::string> records_;
// sum of record lengths in bytes. // sum of record lengths in bytes.
......
...@@ -27,15 +27,20 @@ Header::Header() ...@@ -27,15 +27,20 @@ Header::Header()
Header::Header(uint32_t num, uint32_t sum, Compressor c, uint32_t cs) Header::Header(uint32_t num, uint32_t sum, Compressor c, uint32_t cs)
: num_records_(num), checksum_(sum), compressor_(c), compress_size_(cs) {} : num_records_(num), checksum_(sum), compressor_(c), compress_size_(cs) {}
void Header::Parse(std::istream& is) { bool Header::Parse(std::istream& is) {
uint32_t magic; uint32_t magic;
is.read(reinterpret_cast<char*>(&magic), sizeof(uint32_t)); size_t read_size =
is.readsome(reinterpret_cast<char*>(&magic), sizeof(uint32_t));
if (read_size < sizeof(uint32_t)) {
return false;
}
PADDLE_ENFORCE_EQ(magic, kMagicNumber); PADDLE_ENFORCE_EQ(magic, kMagicNumber);
is.read(reinterpret_cast<char*>(&num_records_), sizeof(uint32_t)) is.read(reinterpret_cast<char*>(&num_records_), sizeof(uint32_t))
.read(reinterpret_cast<char*>(&checksum_), sizeof(uint32_t)) .read(reinterpret_cast<char*>(&checksum_), sizeof(uint32_t))
.read(reinterpret_cast<char*>(&compressor_), sizeof(uint32_t)) .read(reinterpret_cast<char*>(&compressor_), sizeof(uint32_t))
.read(reinterpret_cast<char*>(&compress_size_), sizeof(uint32_t)); .read(reinterpret_cast<char*>(&compress_size_), sizeof(uint32_t));
return true;
} }
void Header::Write(std::ostream& os) const { void Header::Write(std::ostream& os) const {
......
...@@ -42,7 +42,9 @@ public: ...@@ -42,7 +42,9 @@ public:
Header(uint32_t num, uint32_t sum, Compressor ct, uint32_t cs); Header(uint32_t num, uint32_t sum, Compressor ct, uint32_t cs);
void Write(std::ostream& os) const; void Write(std::ostream& os) const;
void Parse(std::istream& is);
// returns true if OK, false if eof
bool Parse(std::istream& is);
uint32_t NumRecords() const { return num_records_; } uint32_t NumRecords() const { return num_records_; }
uint32_t Checksum() const { return checksum_; } uint32_t Checksum() const { return checksum_; }
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace recordio {
Scanner::Scanner(std::unique_ptr<std::istream> &&stream)
: stream_(std::move(stream)) {
Reset();
}
Scanner::Scanner(const std::string &filename) {
stream_.reset(new std::ifstream(filename));
Reset();
}
void Scanner::Reset() {
stream_->seekg(0, std::ios::beg);
ParseNextChunk();
}
const std::string &Scanner::Next() {
PADDLE_ENFORCE(!eof_, "StopIteration");
auto &rec = cur_chunk_.Record(offset_++);
if (offset_ == cur_chunk_.NumRecords()) {
ParseNextChunk();
}
return rec;
}
void Scanner::ParseNextChunk() {
eof_ = !cur_chunk_.Parse(*stream_);
offset_ = 0;
}
bool Scanner::HasNext() const { return !eof_; }
} // namespace recordio
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fstream>
#include <memory>
#include "paddle/fluid/recordio/chunk.h"
namespace paddle {
namespace recordio {
class Scanner {
public:
explicit Scanner(std::unique_ptr<std::istream>&& stream);
explicit Scanner(const std::string& filename);
void Reset();
const std::string& Next();
bool HasNext() const;
private:
std::unique_ptr<std::istream> stream_;
Chunk cur_chunk_;
size_t offset_;
bool eof_;
void ParseNextChunk();
};
} // namespace recordio
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/recordio/writer.h"
namespace paddle {
namespace recordio {
void Writer::Write(const std::string& record) {
cur_chunk_.Add(record);
if (cur_chunk_.NumRecords() >= max_num_records_in_chunk_) {
Flush();
}
}
void Writer::Flush() {
cur_chunk_.Write(stream_, compressor_);
cur_chunk_.Clear();
}
Writer::~Writer() {
PADDLE_ENFORCE(cur_chunk_.Empty(), "Writer must be flushed when destroy.");
}
} // namespace recordio
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/chunk.h"
namespace paddle {
namespace recordio {
class Writer {
public:
Writer(std::ostream* sout,
Compressor compressor,
size_t max_num_records_in_chunk = 1000)
: stream_(*sout),
max_num_records_in_chunk_(max_num_records_in_chunk),
compressor_(compressor) {}
void Write(const std::string& record);
void Flush();
~Writer();
private:
std::ostream& stream_;
size_t max_num_records_in_chunk_;
Chunk cur_chunk_;
Compressor compressor_;
};
} // namespace recordio
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gtest/gtest.h"
#include <sstream>
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
TEST(WriterScanner, Normal) {
std::stringstream* stream = new std::stringstream();
{
paddle::recordio::Writer writer(stream,
paddle::recordio::Compressor::kSnappy);
writer.Write("ABC");
writer.Write("BCD");
writer.Write("CDE");
writer.Flush();
}
{
stream->seekg(0, std::ios::beg);
std::unique_ptr<std::istream> stream_ptr(stream);
paddle::recordio::Scanner scanner(std::move(stream_ptr));
ASSERT_TRUE(scanner.HasNext());
ASSERT_EQ(scanner.Next(), "ABC");
ASSERT_EQ("BCD", scanner.Next());
ASSERT_TRUE(scanner.HasNext());
ASSERT_EQ("CDE", scanner.Next());
ASSERT_FALSE(scanner.HasNext());
}
}
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册