diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 48713f2c2ac62a37b7b7a4602f7f6a325aecb0b8..15e5574ecfd406b87db8370948352b7e736937ea 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -21,7 +21,7 @@ endif() cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) nv_test(mixed_vector_test SRCS mixed_vector_test.cu DEPS place paddle_memory device_context init) -cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto) +cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio) cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor init) diff --git a/paddle/fluid/framework/lod_tensor.cc b/paddle/fluid/framework/lod_tensor.cc index e2f4e9cad1996578b7c51257785e1273d126f80f..8155cb55a468a09320b1196b49fc3e34cea261b1 100644 --- a/paddle/fluid/framework/lod_tensor.cc +++ b/paddle/fluid/framework/lod_tensor.cc @@ -19,6 +19,9 @@ limitations under the License. */ #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/recordio/scanner.h" +#include "paddle/fluid/recordio/writer.h" + #include #include #include @@ -291,6 +294,31 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor, TensorFromStream(is, static_cast(tensor), dev_ctx); } +void WriteToRecordIO(recordio::Writer &writer, + const std::vector &tensor, + const platform::DeviceContext &dev_ctx) { + std::stringstream buffer; + size_t sz = tensor.size(); + buffer.write(reinterpret_cast(&sz), sizeof(uint32_t)); + for (auto &each : tensor) { + SerializeToStream(buffer, each, dev_ctx); + } + writer.Write(buffer.str()); +} + +std::vector ReadFromRecordIO( + recordio::Scanner &scanner, const platform::DeviceContext &dev_ctx) { + std::istringstream sin(scanner.Next()); + uint32_t sz; + sin.read(reinterpret_cast(&sz), sizeof(uint32_t)); + std::vector result; + result.resize(sz); + for (uint32_t i = 0; i < sz; ++i) { + DeserializeFromStream(sin, &result[i], dev_ctx); + } + return result; +} + std::vector LoDTensor::SplitLoDTensor( const std::vector places) const { check_memory_size(); diff --git a/paddle/fluid/framework/lod_tensor.h b/paddle/fluid/framework/lod_tensor.h index 94d5a6e9fd9b68d3d8230a8c258316efadda5a05..dee505fee0dccd8d60bb290a8bec4df243e504a2 100644 --- a/paddle/fluid/framework/lod_tensor.h +++ b/paddle/fluid/framework/lod_tensor.h @@ -29,6 +29,12 @@ limitations under the License. */ #include "paddle/fluid/platform/place.h" namespace paddle { + +namespace recordio { +class Writer; +class Scanner; +} + namespace framework { /* @@ -209,5 +215,12 @@ void SerializeToStream(std::ostream& os, const LoDTensor& tensor, void DeserializeFromStream(std::istream& is, LoDTensor* tensor, const platform::DeviceContext& dev_ctx); +extern void WriteToRecordIO(recordio::Writer& writer, + const std::vector& tensor, + const platform::DeviceContext& dev_ctx); + +extern std::vector ReadFromRecordIO( + recordio::Scanner& scanner, const platform::DeviceContext& dev_ctx); + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/lod_tensor_test.cc b/paddle/fluid/framework/lod_tensor_test.cc index 5e135192ce774ab5c351b89164be9d7600ae3640..e691e29383d4842b80769021e0e494967d38e9bb 100644 --- a/paddle/fluid/framework/lod_tensor_test.cc +++ b/paddle/fluid/framework/lod_tensor_test.cc @@ -14,6 +14,9 @@ #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/recordio/scanner.h" +#include "paddle/fluid/recordio/writer.h" + #include #include #include @@ -224,5 +227,43 @@ TEST(LoD, CheckAbsLoD) { abs_lod0.push_back(std::vector({0})); ASSERT_FALSE(CheckAbsLoD(abs_lod0)); } + +TEST(LoDTensor, RecordIO) { + LoDTensor tensor; + int* tmp = tensor.mutable_data(make_ddim({4, 5}), platform::CPUPlace()); + for (int i = 0; i < 20; ++i) { + tmp[i] = i; + } + + std::stringstream* stream = new std::stringstream(); + auto& ctx = + *platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); + { + recordio::Writer writer(stream, recordio::Compressor::kSnappy); + WriteToRecordIO(writer, {tensor, tensor}, ctx); + WriteToRecordIO(writer, {tensor, tensor}, ctx); + writer.Flush(); + } + + auto assert_tensor_ok = [](const LoDTensor& tensor) { + for (int i = 0; i < 20; ++i) { + ASSERT_EQ(tensor.data()[i], i); + } + }; + + { + std::unique_ptr stream_ptr(stream); + recordio::Scanner scanner(std::move(stream_ptr)); + auto tensors = ReadFromRecordIO(scanner, ctx); + ASSERT_EQ(tensors.size(), 2); + assert_tensor_ok(tensors[0]); + assert_tensor_ok(tensors[1]); + tensors = ReadFromRecordIO(scanner, ctx); + ASSERT_EQ(tensors.size(), 2); + assert_tensor_ok(tensors[0]); + assert_tensor_ok(tensors[1]); + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/recordio/CMakeLists.txt b/paddle/fluid/recordio/CMakeLists.txt index e1e7c2cdb3d0c960d5cd408420b5aaead73e70d7..92e97a6c85d7c8f01c8473feb9772f2285d49673 100644 --- a/paddle/fluid/recordio/CMakeLists.txt +++ b/paddle/fluid/recordio/CMakeLists.txt @@ -3,4 +3,7 @@ cc_library(header SRCS header.cc) cc_test(header_test SRCS header_test.cc DEPS header) cc_library(chunk SRCS chunk.cc DEPS snappystream snappy header zlib) cc_test(chunk_test SRCS chunk_test.cc DEPS chunk) -cc_library(recordio DEPS chunk header) +cc_library(writer SRCS writer.cc DEPS chunk) +cc_library(scanner SRCS scanner.cc DEPS chunk) +cc_test(writer_scanner_test SRCS writer_scanner_test.cc DEPS writer scanner) +cc_library(recordio DEPS chunk header writer scanner) diff --git a/paddle/fluid/recordio/chunk.cc b/paddle/fluid/recordio/chunk.cc index 587fd375c38ca83e1c65cb3ccc20b3509b6348c7..c504aa6859c0a0e3831461fc34bff80aaadb7ef5 100644 --- a/paddle/fluid/recordio/chunk.cc +++ b/paddle/fluid/recordio/chunk.cc @@ -97,9 +97,12 @@ bool Chunk::Write(std::ostream& os, Compressor ct) const { return true; } -void Chunk::Parse(std::istream& sin) { +bool Chunk::Parse(std::istream& sin) { Header hdr; - hdr.Parse(sin); + bool ok = hdr.Parse(sin); + if (!ok) { + return ok; + } auto beg_pos = sin.tellg(); auto crc = Crc32Stream(sin, hdr.CompressSize()); PADDLE_ENFORCE_EQ(hdr.Checksum(), crc); @@ -128,6 +131,7 @@ void Chunk::Parse(std::istream& sin) { stream.read(&buf[0], rec_len); Add(buf); } + return true; } } // namespace recordio diff --git a/paddle/fluid/recordio/chunk.h b/paddle/fluid/recordio/chunk.h index 0ba9c63abbe72e7a51ddb1af5f0d206aa9f6cc5b..bf20ebd455c26ddeebeeea8db04cf7103b0c085f 100644 --- a/paddle/fluid/recordio/chunk.h +++ b/paddle/fluid/recordio/chunk.h @@ -26,9 +26,9 @@ namespace recordio { class Chunk { public: Chunk() : num_bytes_(0) {} - void Add(std::string buf) { - records_.push_back(buf); + void Add(const std::string& buf) { num_bytes_ += buf.size(); + records_.emplace_back(buf); } // dump the chunk into w, and clears the chunk and makes it ready for // the next add invocation. @@ -37,10 +37,15 @@ public: records_.clear(); num_bytes_ = 0; } - void Parse(std::istream& sin); - size_t NumBytes() { return num_bytes_; } + + // returns true if ok, false if eof + bool Parse(std::istream& sin); + size_t NumBytes() const { return num_bytes_; } + size_t NumRecords() const { return records_.size(); } const std::string& Record(int i) const { return records_[i]; } + bool Empty() const { return records_.empty(); } + private: std::vector records_; // sum of record lengths in bytes. diff --git a/paddle/fluid/recordio/header.cc b/paddle/fluid/recordio/header.cc index 71d64a62a1efcd2989e9a4e1536ec1250df79bc5..1d96bcb9efedf67343c616d87cce3307d0e36c47 100644 --- a/paddle/fluid/recordio/header.cc +++ b/paddle/fluid/recordio/header.cc @@ -27,15 +27,20 @@ Header::Header() Header::Header(uint32_t num, uint32_t sum, Compressor c, uint32_t cs) : num_records_(num), checksum_(sum), compressor_(c), compress_size_(cs) {} -void Header::Parse(std::istream& is) { +bool Header::Parse(std::istream& is) { uint32_t magic; - is.read(reinterpret_cast(&magic), sizeof(uint32_t)); + size_t read_size = + is.readsome(reinterpret_cast(&magic), sizeof(uint32_t)); + if (read_size < sizeof(uint32_t)) { + return false; + } PADDLE_ENFORCE_EQ(magic, kMagicNumber); is.read(reinterpret_cast(&num_records_), sizeof(uint32_t)) .read(reinterpret_cast(&checksum_), sizeof(uint32_t)) .read(reinterpret_cast(&compressor_), sizeof(uint32_t)) .read(reinterpret_cast(&compress_size_), sizeof(uint32_t)); + return true; } void Header::Write(std::ostream& os) const { diff --git a/paddle/fluid/recordio/header.h b/paddle/fluid/recordio/header.h index 77f2f3a597afe4ec39c5432935409477930805f4..9200ac090de4514bef3704ac502039222eef2284 100644 --- a/paddle/fluid/recordio/header.h +++ b/paddle/fluid/recordio/header.h @@ -42,7 +42,9 @@ public: Header(uint32_t num, uint32_t sum, Compressor ct, uint32_t cs); void Write(std::ostream& os) const; - void Parse(std::istream& is); + + // returns true if OK, false if eof + bool Parse(std::istream& is); uint32_t NumRecords() const { return num_records_; } uint32_t Checksum() const { return checksum_; } diff --git a/paddle/fluid/recordio/scanner.cc b/paddle/fluid/recordio/scanner.cc new file mode 100644 index 0000000000000000000000000000000000000000..7f19c46e7e0be29ef11c76b21dce751141da36bc --- /dev/null +++ b/paddle/fluid/recordio/scanner.cc @@ -0,0 +1,51 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/recordio/scanner.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace recordio { +Scanner::Scanner(std::unique_ptr &&stream) + : stream_(std::move(stream)) { + Reset(); +} + +Scanner::Scanner(const std::string &filename) { + stream_.reset(new std::ifstream(filename)); + Reset(); +} + +void Scanner::Reset() { + stream_->seekg(0, std::ios::beg); + ParseNextChunk(); +} + +const std::string &Scanner::Next() { + PADDLE_ENFORCE(!eof_, "StopIteration"); + auto &rec = cur_chunk_.Record(offset_++); + if (offset_ == cur_chunk_.NumRecords()) { + ParseNextChunk(); + } + return rec; +} + +void Scanner::ParseNextChunk() { + eof_ = !cur_chunk_.Parse(*stream_); + offset_ = 0; +} + +bool Scanner::HasNext() const { return !eof_; } +} // namespace recordio +} // namespace paddle diff --git a/paddle/fluid/recordio/scanner.h b/paddle/fluid/recordio/scanner.h new file mode 100644 index 0000000000000000000000000000000000000000..3073d0c5c872502f4567fcbadd6b4129f865de10 --- /dev/null +++ b/paddle/fluid/recordio/scanner.h @@ -0,0 +1,44 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/recordio/chunk.h" +namespace paddle { +namespace recordio { + +class Scanner { +public: + explicit Scanner(std::unique_ptr&& stream); + + explicit Scanner(const std::string& filename); + + void Reset(); + + const std::string& Next(); + + bool HasNext() const; + +private: + std::unique_ptr stream_; + Chunk cur_chunk_; + size_t offset_; + bool eof_; + + void ParseNextChunk(); +}; +} // namespace recordio +} // namespace paddle diff --git a/paddle/fluid/recordio/writer.cc b/paddle/fluid/recordio/writer.cc new file mode 100644 index 0000000000000000000000000000000000000000..7a4143c662350158051b767db1f7eb820c7436c4 --- /dev/null +++ b/paddle/fluid/recordio/writer.cc @@ -0,0 +1,35 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/fluid/recordio/writer.h" + +namespace paddle { +namespace recordio { +void Writer::Write(const std::string& record) { + cur_chunk_.Add(record); + if (cur_chunk_.NumRecords() >= max_num_records_in_chunk_) { + Flush(); + } +} + +void Writer::Flush() { + cur_chunk_.Write(stream_, compressor_); + cur_chunk_.Clear(); +} + +Writer::~Writer() { + PADDLE_ENFORCE(cur_chunk_.Empty(), "Writer must be flushed when destroy."); +} + +} // namespace recordio +} // namespace paddle diff --git a/paddle/fluid/recordio/writer.h b/paddle/fluid/recordio/writer.h new file mode 100644 index 0000000000000000000000000000000000000000..2db6f60f41acbe6b52dc47b7670fea39065be4d1 --- /dev/null +++ b/paddle/fluid/recordio/writer.h @@ -0,0 +1,44 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/recordio/chunk.h" +namespace paddle { +namespace recordio { + +class Writer { +public: + Writer(std::ostream* sout, + Compressor compressor, + size_t max_num_records_in_chunk = 1000) + : stream_(*sout), + max_num_records_in_chunk_(max_num_records_in_chunk), + compressor_(compressor) {} + + void Write(const std::string& record); + + void Flush(); + + ~Writer(); + +private: + std::ostream& stream_; + size_t max_num_records_in_chunk_; + Chunk cur_chunk_; + Compressor compressor_; +}; + +} // namespace recordio +} // namespace paddle diff --git a/paddle/fluid/recordio/writer_scanner_test.cc b/paddle/fluid/recordio/writer_scanner_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a14d3bc3b2b18d27c6b65ad606098f2e4cae7245 --- /dev/null +++ b/paddle/fluid/recordio/writer_scanner_test.cc @@ -0,0 +1,44 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" + +#include +#include "paddle/fluid/recordio/scanner.h" +#include "paddle/fluid/recordio/writer.h" + +TEST(WriterScanner, Normal) { + std::stringstream* stream = new std::stringstream(); + + { + paddle::recordio::Writer writer(stream, + paddle::recordio::Compressor::kSnappy); + writer.Write("ABC"); + writer.Write("BCD"); + writer.Write("CDE"); + writer.Flush(); + } + + { + stream->seekg(0, std::ios::beg); + std::unique_ptr stream_ptr(stream); + paddle::recordio::Scanner scanner(std::move(stream_ptr)); + ASSERT_TRUE(scanner.HasNext()); + ASSERT_EQ(scanner.Next(), "ABC"); + ASSERT_EQ("BCD", scanner.Next()); + ASSERT_TRUE(scanner.HasNext()); + ASSERT_EQ("CDE", scanner.Next()); + ASSERT_FALSE(scanner.HasNext()); + } +} \ No newline at end of file