未验证 提交 e13aec60 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #8830 from reyoung/feature/recordio_file_reader

Feature/recordio file reader
......@@ -21,7 +21,7 @@ endif()
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
nv_test(mixed_vector_test SRCS mixed_vector_test.cu DEPS place paddle_memory device_context init)
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto)
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor init)
......
......@@ -19,6 +19,9 @@ limitations under the License. */
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
#include <stdint.h>
#include <string.h>
#include <algorithm>
......@@ -291,6 +294,31 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
TensorFromStream(is, static_cast<Tensor *>(tensor), dev_ctx);
}
void WriteToRecordIO(recordio::Writer &writer,
const std::vector<LoDTensor> &tensor,
const platform::DeviceContext &dev_ctx) {
std::stringstream buffer;
size_t sz = tensor.size();
buffer.write(reinterpret_cast<const char *>(&sz), sizeof(uint32_t));
for (auto &each : tensor) {
SerializeToStream(buffer, each, dev_ctx);
}
writer.Write(buffer.str());
}
std::vector<LoDTensor> ReadFromRecordIO(
recordio::Scanner &scanner, const platform::DeviceContext &dev_ctx) {
std::istringstream sin(scanner.Next());
uint32_t sz;
sin.read(reinterpret_cast<char *>(&sz), sizeof(uint32_t));
std::vector<LoDTensor> result;
result.resize(sz);
for (uint32_t i = 0; i < sz; ++i) {
DeserializeFromStream(sin, &result[i], dev_ctx);
}
return result;
}
std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
const std::vector<platform::Place> places) const {
check_memory_size();
......
......@@ -29,6 +29,12 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace recordio {
class Writer;
class Scanner;
}
namespace framework {
/*
......@@ -209,5 +215,12 @@ void SerializeToStream(std::ostream& os, const LoDTensor& tensor,
void DeserializeFromStream(std::istream& is, LoDTensor* tensor,
const platform::DeviceContext& dev_ctx);
extern void WriteToRecordIO(recordio::Writer& writer,
const std::vector<LoDTensor>& tensor,
const platform::DeviceContext& dev_ctx);
extern std::vector<LoDTensor> ReadFromRecordIO(
recordio::Scanner& scanner, const platform::DeviceContext& dev_ctx);
} // namespace framework
} // namespace paddle
......@@ -14,6 +14,9 @@
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <algorithm>
......@@ -224,5 +227,43 @@ TEST(LoD, CheckAbsLoD) {
abs_lod0.push_back(std::vector<size_t>({0}));
ASSERT_FALSE(CheckAbsLoD(abs_lod0));
}
TEST(LoDTensor, RecordIO) {
LoDTensor tensor;
int* tmp = tensor.mutable_data<int>(make_ddim({4, 5}), platform::CPUPlace());
for (int i = 0; i < 20; ++i) {
tmp[i] = i;
}
std::stringstream* stream = new std::stringstream();
auto& ctx =
*platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
{
recordio::Writer writer(stream, recordio::Compressor::kSnappy);
WriteToRecordIO(writer, {tensor, tensor}, ctx);
WriteToRecordIO(writer, {tensor, tensor}, ctx);
writer.Flush();
}
auto assert_tensor_ok = [](const LoDTensor& tensor) {
for (int i = 0; i < 20; ++i) {
ASSERT_EQ(tensor.data<int>()[i], i);
}
};
{
std::unique_ptr<std::istream> stream_ptr(stream);
recordio::Scanner scanner(std::move(stream_ptr));
auto tensors = ReadFromRecordIO(scanner, ctx);
ASSERT_EQ(tensors.size(), 2);
assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]);
tensors = ReadFromRecordIO(scanner, ctx);
ASSERT_EQ(tensors.size(), 2);
assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]);
}
}
} // namespace framework
} // namespace paddle
......@@ -33,6 +33,8 @@ class ReaderBase {
std::vector<DDim> shapes() const { return shapes_; }
void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; }
virtual bool HasNext() const = 0;
virtual ~ReaderBase() {}
protected:
......@@ -53,6 +55,8 @@ class DecoratedReader : public ReaderBase {
void ReInit() override { reader_->ReInit(); }
bool HasNext() const override { return reader_->HasNext(); }
protected:
ReaderBase* reader_;
};
......@@ -87,6 +91,8 @@ class ReaderHolder {
reader_->set_shapes(shapes);
}
bool HasNext() const { return reader_->HasNext(); }
private:
std::unique_ptr<ReaderBase> reader_;
};
......
......@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace detail {
......
cc_library(reader_op_registry SRCS reader_op_registry.cc DEPS operator op_registry reader)
op_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc DEPS reader_op_registry)
op_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc DEPS reader_op_registry)
op_library(create_batch_reader_op SRCS create_batch_reader_op.cc DEPS reader_op_registry)
op_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc DEPS reader_op_registry)
set(READER_LIBRARY create_random_data_generator_op create_shuffle_reader_op create_batch_reader_op create_double_buffer_reader_op PARENT_SCOPE)
set(LOCAL_READER_LIBS)
function(reader_library TARGET_NAME)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
set(options "")
set(common_deps reader_op_registry)
cmake_parse_arguments(reader_library "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN})
op_library(${TARGET_NAME} SRCS ${reader_library_SRCS} DEPS ${common_deps} ${reader_library_DEPS})
set(LOCAL_READER_LIBS
${TARGET_NAME}
${LOCAL_READER_LIBS}
PARENT_SCOPE)
endfunction()
reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)
reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
# Export local libraries to parent
set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)
......@@ -37,6 +37,8 @@ class DoubleBufferReader : public framework::DecoratedReader {
~DoubleBufferReader() { buffer_->Close(); }
bool HasNext() const override;
private:
void PrefetchThreadFunc();
......@@ -106,6 +108,8 @@ void DoubleBufferReader::PrefetchThreadFunc() {
}
}
bool DoubleBufferReader::HasNext() const { PADDLE_THROW("Not Implemented"); }
} // namespace reader
} // namespace operators
} // namespace paddle
......
......@@ -52,6 +52,8 @@ class RandomDataGenerator : public framework::FileReader {
void ReInit() override { return; }
bool HasNext() const override { return true; }
private:
float min_;
float max_;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reader/reader_op_registry.h"
#include "paddle/fluid/recordio/scanner.h"
namespace paddle {
namespace operators {
namespace reader {
class RecordIOFileReader : public framework::FileReader {
public:
RecordIOFileReader(const std::string& filename,
const std::vector<framework::DDim>& shapes)
: FileReader(shapes),
scanner_(filename),
dev_ctx_(*platform::DeviceContextPool::Instance().Get(
platform::CPUPlace())) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
}
bool HasNext() const override { return scanner_.HasNext(); }
void ReInit() override { scanner_.Reset(); }
private:
recordio::Scanner scanner_;
const platform::DeviceContext& dev_ctx_;
};
class CreateRecordIOReaderOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
int(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
std::string filename = Attr<std::string>("filename");
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new RecordIOFileReader(filename, shapes));
}
};
class CreateRecordIOReaderOpMaker : public FileReaderMakerBase {
public:
CreateRecordIOReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: FileReaderMakerBase(op_proto, op_checker) {
AddAttr<std::string>("filename", "The filename of record io reader");
AddComment(R"DOC(
CreateRecordIOReader Operator
Create a reader from a record io file
)DOC");
}
};
} // namespace reader
} // namespace operators
} // namespace paddle
namespace reader = paddle::operators::reader;
REGISTER_FILE_READER_OPERATOR(create_recordio_file_reader,
reader::CreateRecordIOReaderOp,
reader::CreateRecordIOReaderOpMaker);
......@@ -35,7 +35,7 @@ FileReaderMakerBase::FileReaderMakerBase(
framework::OpProtoAndCheckerMaker::OpProto* op_proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) {
AddOutput("Out", "(ReaderHolder) The created random reader.");
AddOutput("Out", "(ReaderHolder) The created random reader.").AsDuplicable();
AddAttr<std::vector<int>>("shape_concat", "The concat of all data's shapes.");
AddAttr<std::vector<int>>(
"ranks",
......
if(WITH_PYTHON)
cc_library(paddle_pybind SHARED
SRCS pybind.cc exception.cc protobuf.cc const_value.cc
SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc
DEPS pybind python backward proto_desc paddle_memory executor prune init profiler feed_fetch_method
${GLOB_OP_LIB})
if(NOT APPLE AND NOT ANDROID)
......
......@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/prune.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/cond_op.h"
#include "paddle/fluid/operators/net_op.h"
......@@ -35,7 +36,9 @@ limitations under the License. */
#include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/pybind/recordio.h"
#include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/fluid/string/to_string.h"
#ifdef PADDLE_WITH_CUDA
......@@ -217,8 +220,18 @@ All parameter, weight, gradient are variables in Paddle.
[](Variable &self) -> operators::NetOp * {
return self.GetMutable<operators::NetOp>();
},
py::return_value_policy::reference)
.def("get_reader",
[](Variable &self) -> framework::ReaderHolder * {
PADDLE_ENFORCE(self.IsType<framework::ReaderHolder>());
return self.GetMutable<framework::ReaderHolder>();
},
py::return_value_policy::reference);
py::class_<framework::ReaderHolder>(m, "Reader", "")
.def("has_next", &framework::ReaderHolder::HasNext)
.def("reset", &framework::ReaderHolder::ReInit);
py::class_<Scope>(m, "Scope", "")
.def("var",
[](Scope &self, const std::string &name) -> Variable * {
......@@ -474,6 +487,8 @@ All parameter, weight, gradient are variables in Paddle.
m.def("enable_profiler", platform::EnableProfiler);
m.def("disable_profiler", platform::DisableProfiler);
m.def("reset_profiler", platform::ResetProfiler);
BindRecordIOWriter(m);
return m.ptr();
}
} // namespace pybind
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/pybind/recordio.h"
#include <fstream>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/writer.h"
namespace paddle {
namespace pybind {
class RecordIOWriter {
public:
RecordIOWriter(const std::string& filename, recordio::Compressor compressor,
size_t max_num_record)
: stream_(filename), writer_(&stream_, compressor, max_num_record) {}
void AppendTensor(const framework::LoDTensor& tensor) {
tensors_.push_back(tensor);
}
void CompleteAppendTensor() {
auto& ctx =
*platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
framework::WriteToRecordIO(writer_, tensors_, ctx);
tensors_.clear();
}
void Close() {
PADDLE_ENFORCE(tensors_.empty());
writer_.Flush();
stream_.close();
}
private:
std::vector<framework::LoDTensor> tensors_;
std::ofstream stream_;
recordio::Writer writer_;
};
void BindRecordIOWriter(py::module& m) {
py::class_<RecordIOWriter> writer(m, "RecordIOWriter", "");
py::enum_<recordio::Compressor>(writer, "Compressor", "")
.value("Snappy", recordio::Compressor::kSnappy)
.value("NoCompress", recordio::Compressor::kNoCompress);
writer
.def("__init__",
[](RecordIOWriter& self, const std::string& filename,
recordio::Compressor compressor, size_t max_num_record) {
new (&self) RecordIOWriter(filename, compressor, max_num_record);
})
.def("append_tensor", &RecordIOWriter::AppendTensor)
.def("complete_append_tensor", &RecordIOWriter::CompleteAppendTensor)
.def("close", &RecordIOWriter::Close);
}
} // namespace pybind
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
extern void BindRecordIOWriter(py::module& m);
} // namespace pybind
} // namespace paddle
......@@ -3,4 +3,7 @@ cc_library(header SRCS header.cc)
cc_test(header_test SRCS header_test.cc DEPS header)
cc_library(chunk SRCS chunk.cc DEPS snappystream snappy header zlib)
cc_test(chunk_test SRCS chunk_test.cc DEPS chunk)
cc_library(recordio DEPS chunk header)
cc_library(writer SRCS writer.cc DEPS chunk)
cc_library(scanner SRCS scanner.cc DEPS chunk)
cc_test(writer_scanner_test SRCS writer_scanner_test.cc DEPS writer scanner)
cc_library(recordio DEPS chunk header writer scanner)
......@@ -24,33 +24,52 @@ namespace paddle {
namespace recordio {
constexpr size_t kMaxBufSize = 1024;
/**
* Read Stream by a fixed sized buffer.
* @param in input stream
* @param limit read at most `limit` bytes from input stream. 0 means no limit
* @param callback A function object with (const char* buf, size_t size) -> void
* as its type.
*/
template <typename Callback>
static void ReadStreamByBuf(std::istream& in, int limit, Callback callback) {
static void ReadStreamByBuf(std::istream& in, size_t limit, Callback callback) {
char buf[kMaxBufSize];
std::streamsize actual_size;
size_t counter = 0;
do {
auto actual_max =
limit > 0 ? std::min(limit - counter, kMaxBufSize) : kMaxBufSize;
actual_size = in.readsome(buf, actual_max);
size_t actual_max;
while (!in.eof() ||
(limit != 0 && counter >= limit)) { // End of file or reach limit
actual_max =
limit != 0 ? std::min(limit - counter, kMaxBufSize) : kMaxBufSize;
in.read(buf, actual_max);
actual_size = in.gcount();
if (actual_size == 0) {
break;
}
callback(buf, actual_size);
if (limit > 0) {
if (limit != 0) {
counter += actual_size;
}
} while (actual_size == kMaxBufSize);
}
in.clear(); // unset eof state
}
/**
* Copy stream in to another stream
*/
static void PipeStream(std::istream& in, std::ostream& os) {
ReadStreamByBuf(
in, -1, [&os](const char* buf, size_t len) { os.write(buf, len); });
in, 0, [&os](const char* buf, size_t len) { os.write(buf, len); });
}
static uint32_t Crc32Stream(std::istream& in, int limit = -1) {
auto crc = crc32(0, nullptr, 0);
/**
* Calculate CRC32 from an input stream.
*/
static uint32_t Crc32Stream(std::istream& in, size_t limit = 0) {
uint32_t crc = static_cast<uint32_t>(crc32(0, nullptr, 0));
ReadStreamByBuf(in, limit, [&crc](const char* buf, size_t len) {
crc = crc32(crc, reinterpret_cast<const Bytef*>(buf), len);
crc = static_cast<uint32_t>(crc32(
crc, reinterpret_cast<const Bytef*>(buf), static_cast<uInt>(len)));
});
return crc;
}
......@@ -85,28 +104,29 @@ bool Chunk::Write(std::ostream& os, Compressor ct) const {
compressed_stream.reset();
}
auto end_pos = sout.tellg();
sout.seekg(0, std::ios::end);
uint32_t len = static_cast<uint32_t>(sout.tellg());
sout.seekg(0, std::ios::beg);
uint32_t len = static_cast<uint32_t>(end_pos - sout.tellg());
uint32_t crc = Crc32Stream(sout);
sout.seekg(0, std::ios::beg);
Header hdr(static_cast<uint32_t>(records_.size()), crc, ct, len);
hdr.Write(os);
sout.seekg(0, std::ios::beg);
sout.clear();
PipeStream(sout, os);
return true;
}
void Chunk::Parse(std::istream& sin) {
bool Chunk::Parse(std::istream& sin) {
Header hdr;
hdr.Parse(sin);
bool ok = hdr.Parse(sin);
if (!ok) {
return ok;
}
auto beg_pos = sin.tellg();
auto crc = Crc32Stream(sin, hdr.CompressSize());
uint32_t crc = Crc32Stream(sin, hdr.CompressSize());
PADDLE_ENFORCE_EQ(hdr.Checksum(), crc);
Clear();
sin.seekg(beg_pos, std::ios::beg);
sin.seekg(beg_pos, sin.beg);
std::unique_ptr<std::istream> compressed_stream;
switch (hdr.CompressType()) {
case Compressor::kNoCompress:
......@@ -126,8 +146,10 @@ void Chunk::Parse(std::istream& sin) {
std::string buf;
buf.resize(rec_len);
stream.read(&buf[0], rec_len);
PADDLE_ENFORCE_EQ(rec_len, stream.gcount());
Add(buf);
}
return true;
}
} // namespace recordio
......
......@@ -26,9 +26,9 @@ namespace recordio {
class Chunk {
public:
Chunk() : num_bytes_(0) {}
void Add(std::string buf) {
records_.push_back(buf);
void Add(const std::string& buf) {
num_bytes_ += buf.size();
records_.emplace_back(buf);
}
// dump the chunk into w, and clears the chunk and makes it ready for
// the next add invocation.
......@@ -37,10 +37,15 @@ public:
records_.clear();
num_bytes_ = 0;
}
void Parse(std::istream& sin);
size_t NumBytes() { return num_bytes_; }
// returns true if ok, false if eof
bool Parse(std::istream& sin);
size_t NumBytes() const { return num_bytes_; }
size_t NumRecords() const { return records_.size(); }
const std::string& Record(int i) const { return records_[i]; }
bool Empty() const { return records_.empty(); }
private:
std::vector<std::string> records_;
// sum of record lengths in bytes.
......
......@@ -26,7 +26,7 @@ TEST(Chunk, SaveLoad) {
ch.Add(std::string("123", 4));
std::stringstream ss;
ch.Write(ss, Compressor::kNoCompress);
ch.Clear();
ss.seekg(0);
ch.Parse(ss);
ASSERT_EQ(ch.NumBytes(), 10U);
}
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/recordio/header.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace recordio {
......@@ -26,23 +27,33 @@ Header::Header()
Header::Header(uint32_t num, uint32_t sum, Compressor c, uint32_t cs)
: num_records_(num), checksum_(sum), compressor_(c), compress_size_(cs) {}
void Header::Parse(std::istream& is) {
bool Header::Parse(std::istream& is) {
uint32_t magic;
size_t read_size =
is.readsome(reinterpret_cast<char*>(&magic), sizeof(uint32_t));
if (read_size < sizeof(uint32_t)) {
return false;
}
PADDLE_ENFORCE_EQ(magic, kMagicNumber);
is.read(reinterpret_cast<char*>(&num_records_), sizeof(uint32_t))
.read(reinterpret_cast<char*>(&checksum_), sizeof(uint32_t))
.read(reinterpret_cast<char*>(&compressor_), sizeof(uint32_t))
.read(reinterpret_cast<char*>(&compress_size_), sizeof(uint32_t));
return true;
}
void Header::Write(std::ostream& os) const {
os.write(reinterpret_cast<const char*>(&num_records_), sizeof(uint32_t))
os.write(reinterpret_cast<const char*>(&kMagicNumber), sizeof(uint32_t))
.write(reinterpret_cast<const char*>(&num_records_), sizeof(uint32_t))
.write(reinterpret_cast<const char*>(&checksum_), sizeof(uint32_t))
.write(reinterpret_cast<const char*>(&compressor_), sizeof(uint32_t))
.write(reinterpret_cast<const char*>(&compress_size_), sizeof(uint32_t));
}
std::ostream& operator<<(std::ostream& os, Header h) {
os << h.NumRecords() << h.Checksum()
<< static_cast<uint32_t>(h.CompressType()) << h.CompressSize();
os << "Header: " << h.NumRecords() << ", " << h.Checksum() << ", "
<< static_cast<uint32_t>(h.CompressType()) << ", " << h.CompressSize();
return os;
}
......
......@@ -19,8 +19,6 @@
namespace paddle {
namespace recordio {
// Default ChunkSize
constexpr size_t kDefaultMaxChunkSize = 32 * 1024 * 1024;
// MagicNumber for memory checking
constexpr uint32_t kMagicNumber = 0x01020304;
......@@ -44,7 +42,9 @@ public:
Header(uint32_t num, uint32_t sum, Compressor ct, uint32_t cs);
void Write(std::ostream& os) const;
void Parse(std::istream& is);
// returns true if OK, false if eof
bool Parse(std::istream& is);
uint32_t NumRecords() const { return num_records_; }
uint32_t Checksum() const { return checksum_; }
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace recordio {
Scanner::Scanner(std::unique_ptr<std::istream> &&stream)
: stream_(std::move(stream)) {
Reset();
}
Scanner::Scanner(const std::string &filename) {
stream_.reset(new std::ifstream(filename));
Reset();
}
void Scanner::Reset() {
stream_->seekg(0, std::ios::beg);
ParseNextChunk();
}
std::string Scanner::Next() {
PADDLE_ENFORCE(!eof_, "StopIteration");
auto rec = cur_chunk_.Record(offset_++);
if (offset_ == cur_chunk_.NumRecords()) {
ParseNextChunk();
}
return rec;
}
void Scanner::ParseNextChunk() {
eof_ = !cur_chunk_.Parse(*stream_);
offset_ = 0;
}
bool Scanner::HasNext() const { return !eof_; }
} // namespace recordio
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fstream>
#include <memory>
#include "paddle/fluid/recordio/chunk.h"
namespace paddle {
namespace recordio {
class Scanner {
public:
explicit Scanner(std::unique_ptr<std::istream>&& stream);
explicit Scanner(const std::string& filename);
void Reset();
std::string Next();
bool HasNext() const;
private:
std::unique_ptr<std::istream> stream_;
Chunk cur_chunk_;
size_t offset_;
bool eof_;
void ParseNextChunk();
};
} // namespace recordio
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/recordio/writer.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace recordio {
void Writer::Write(const std::string& record) {
cur_chunk_.Add(record);
if (cur_chunk_.NumRecords() >= max_num_records_in_chunk_) {
Flush();
}
}
void Writer::Flush() {
cur_chunk_.Write(stream_, compressor_);
cur_chunk_.Clear();
}
Writer::~Writer() {
PADDLE_ENFORCE(cur_chunk_.Empty(), "Writer must be flushed when destroy.");
}
} // namespace recordio
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/recordio/chunk.h"
namespace paddle {
namespace recordio {
class Writer {
public:
Writer(std::ostream* sout,
Compressor compressor,
size_t max_num_records_in_chunk = 1000)
: stream_(*sout),
max_num_records_in_chunk_(max_num_records_in_chunk),
compressor_(compressor) {}
void Write(const std::string& record);
void Flush();
~Writer();
private:
std::ostream& stream_;
size_t max_num_records_in_chunk_;
Chunk cur_chunk_;
Compressor compressor_;
};
} // namespace recordio
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gtest/gtest.h"
#include <sstream>
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
TEST(WriterScanner, Normal) {
std::stringstream* stream = new std::stringstream();
{
paddle::recordio::Writer writer(stream,
paddle::recordio::Compressor::kSnappy);
writer.Write("ABC");
writer.Write("BCD");
writer.Write("CDE");
writer.Flush();
}
{
stream->seekg(0, std::ios::beg);
std::unique_ptr<std::istream> stream_ptr(stream);
paddle::recordio::Scanner scanner(std::move(stream_ptr));
ASSERT_TRUE(scanner.HasNext());
ASSERT_EQ(scanner.Next(), "ABC");
ASSERT_EQ("BCD", scanner.Next());
ASSERT_TRUE(scanner.HasNext());
ASSERT_EQ("CDE", scanner.Next());
ASSERT_FALSE(scanner.HasNext());
}
}
TEST(WriterScanner, TinyChunk) {
std::stringstream* stream = new std::stringstream();
{
paddle::recordio::Writer writer(
stream, paddle::recordio::Compressor::kNoCompress, 2 /*max chunk num*/);
writer.Write("ABC");
writer.Write("BCD");
writer.Write("CDE");
writer.Write("DEFG");
writer.Flush();
}
{
stream->seekg(0, std::ios::beg);
std::unique_ptr<std::istream> stream_ptr(stream);
paddle::recordio::Scanner scanner(std::move(stream_ptr));
ASSERT_TRUE(scanner.HasNext());
ASSERT_EQ(scanner.Next(), "ABC");
ASSERT_EQ(scanner.Next(), "BCD");
ASSERT_EQ(scanner.Next(), "CDE");
ASSERT_EQ(scanner.Next(), "DEFG");
ASSERT_FALSE(scanner.HasNext());
}
}
\ No newline at end of file
......@@ -40,6 +40,7 @@ import clip
from memory_optimization_transpiler import memory_optimize, release_memory
import profiler
import unique_name
import recordio_writer
Tensor = LoDTensor
......@@ -66,6 +67,7 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + [
'release_memory',
'profiler',
'unique_name',
'recordio_writer',
]
......
......@@ -47,7 +47,7 @@ def is_parameter(var):
def is_persistable(var):
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST:
var.desc.type() == core.VarDesc.VarType.FETCH_LIST:
return False
return var.persistable
......
......@@ -13,11 +13,16 @@
# limitations under the License.
from .. import core
from ..layer_helper import LayerHelper
from ..framework import convert_np_dtype_to_dtype_, default_main_program, default_startup_program
from ..unique_name import generate as unique_name
from control_flow import BlockGuard
from ..layer_helper import LayerHelper
from ..executor import global_scope
__all__ = ['data', 'BlockGuardServ', 'ListenAndServ', 'Send']
__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'read_file'
]
def data(name,
......@@ -224,3 +229,72 @@ def Recv(endpoints, get_vars):
outputs={"Out": get_vars},
attrs={"endpoints": endpoints,
"epmap": epmap})
def monkey_patch_reader_methods(reader):
def __get_reader__():
scope = global_scope()
var = scope.find_var(reader.name)
return var.get_reader()
def eof():
return not __get_reader__().has_next()
def reset():
return __get_reader__().reset()
reader.eof = eof
reader.reset = reset
return reader
def _copy_reader_var_(block, var):
new_var = block.create_var(name=var.name, type=core.VarDesc.VarType.READER)
new_var.desc.set_shapes(var.desc.shapes())
new_var.desc.set_dtypes(var.desc.dtypes())
new_var.persistable = True
return monkey_patch_reader_methods(new_var)
def open_recordio_file(filename, shapes, lod_levels, dtypes):
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
shape_concat = []
ranks = []
for shape in shapes:
shape_concat.extend(shape)
ranks.append(len(shape))
var_name = unique_name('open_recordio_file')
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name)
startup_blk.append_op(
type='create_recordio_file_reader',
outputs={'Out': [startup_var]},
attrs={
'shape_concat': shape_concat,
'lod_levels': lod_levels,
'filename': filename,
'ranks': ranks
})
startup_var.desc.set_dtypes(dtypes)
startup_var.persistable = True
return _copy_reader_var_(default_main_program().current_block(),
startup_var)
def read_file(file_obj):
helper = LayerHelper('read_file')
out = [
helper.create_tmp_variable(
stop_gradient=True, dtype='float32')
for _ in range(len(file_obj.desc.shapes()))
]
helper.append_op(
type='read', inputs={'Reader': [file_obj]}, outputs={'Out': out})
if len(out) == 1:
return out[0]
else:
return out
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import core
import contextlib
__all__ = ['convert_reader_to_recordio_file']
@contextlib.contextmanager
def create_recordio_writer(filename,
compressor=core.RecordIOWriter.Compressor.Snappy,
max_num_records=1000):
writer = core.RecordIOWriter(filename, compressor, max_num_records)
yield writer
writer.close()
def convert_reader_to_recordio_file(
filename,
reader_creator,
feeder,
compressor=core.RecordIOWriter.Compressor.Snappy,
max_num_records=1000,
feed_order=None):
if feed_order is None:
feed_order = feeder.feed_names
with create_recordio_writer(filename, compressor,
max_num_records) as writer:
for batch in reader_creator():
res = feeder.feed(batch)
for each in feed_order:
writer.append_tensor(res[each])
writer.complete_append_tensor()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
import paddle.v2.dataset.mnist as mnist
import paddle.v2 as paddle
class TestRecordIO(unittest.TestCase):
def setUp(self):
# Convert mnist to recordio file
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(mnist.train(), batch_size=32)
feeder = fluid.DataFeeder(
feed_list=[ # order is image and label
fluid.layers.data(
name='image', shape=[784]),
fluid.layers.data(
name='label', shape=[1], dtype='int64'),
],
place=fluid.CPUPlace())
fluid.recordio_writer.convert_reader_to_recordio_file(
'./mnist.recordio', reader, feeder)
def test_main(self):
# use new program
with fluid.program_guard(fluid.Program(), fluid.Program()):
data_file = fluid.layers.open_recordio_file(
'./mnist.recordio',
shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
img, label = fluid.layers.read_file(data_file)
hidden = fluid.layers.fc(input=img, size=100, act='tanh')
prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
avg_loss = fluid.layers.mean(loss)
fluid.optimizer.Adam(learning_rate=1e-3).minimize(avg_loss)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
avg_loss_np = []
# train a pass
while not data_file.eof():
tmp, = exe.run(fetch_list=[avg_loss])
avg_loss_np.append(tmp)
data_file.reset()
self.assertLess(avg_loss_np[-1], avg_loss_np[0])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册