提交 55062252 编写于 作者: F fengjiayi

Add MultipleReader and open_files_op

上级 128adf53
...@@ -20,5 +20,6 @@ reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc) ...@@ -20,5 +20,6 @@ reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc) reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc) reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc) reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
reader_library(open_files_op SRCS open_files_op.cc)
# Export local libraries to parent # Export local libraries to parent
set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE) set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)
...@@ -120,10 +120,13 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -120,10 +120,13 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
}; };
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) { void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}
if (local_buffer_.payloads_.empty()) { if (local_buffer_.payloads_.empty()) {
buffer_->Receive(&local_buffer_); buffer_->Receive(&local_buffer_);
} }
*out = local_buffer_.payloads_; *out = local_buffer_.payloads_;
local_buffer_.payloads_.clear(); local_buffer_.payloads_.clear();
if (local_buffer_.ctx_) { if (local_buffer_.ctx_) {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <numeric>
#include <string>
#include <thread>  // NOLINT
#include <utility>
#include <vector>

#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle {
namespace operators {
namespace reader {
class MultipleReader : public framework::ReaderBase {
public:
struct Quota {};
MultipleReader(const std::vector<std::string>& file_names,
const std::vector<framework::DDim>& dims, size_t thread_num)
: file_names_(file_names), dims_(dims), thread_num_(thread_num) {
PADDLE_ENFORCE_GT(thread_num_, 0);
StartNewScheduler();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
bool HasNext() const override;
void ReInit() override;
private:
void StartNewScheduler();
void ScheduleThreadFunc();
void PrefetchThreadFunc(std::string file_name);
std::vector<std::string> file_names_;
std::vector<framework::DDim> dims_;
size_t thread_num_;
framework::Channel<size_t>* waiting_file_idx_;
framework::Channel<Quota>* thread_quotas_;
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
mutable std::vector<framework::LoDTensor> local_buffer_;
};
void MultipleReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}
if (local_buffer_.empty()) {
buffer_->Receive(&local_buffer_);
}
*out = local_buffer_;
local_buffer_.clear();
}
bool MultipleReader::HasNext() const {
return local_buffer_.empty() ? buffer_->Receive(&local_buffer_) : true;
}
void MultipleReader::ReInit() {
buffer_->Close();
thread_quotas_->Close();
waiting_file_idx_->Close();
local_buffer_.clear();
StartNewScheduler();
}
void MultipleReader::StartNewScheduler() {
waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
thread_quotas_ = framework::MakeChannel<Quota>(thread_num_);
buffer_ =
framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num_);
for (size_t i = 0; i < file_names_.size(); ++i) {
waiting_file_idx_->Send(&i);
}
waiting_file_idx_->Close();
for (size_t i = 0; i < thread_num_; ++i) {
Quota quota;
thread_quotas_->Send(&quota);
}
std::thread scheduler([this] { ScheduleThreadFunc(); });
scheduler.detach();
}
void MultipleReader::ScheduleThreadFunc() {
VLOG(5) << "MultipleReader schedule thread starts.";
size_t completed_thread_num = 0;
Quota quota;
while (thread_quotas_->Receive(&quota)) {
size_t file_idx;
if (waiting_file_idx_->Receive(&file_idx)) {
// Still have files to read. Start a new prefetch thread.
std::string file_name = file_names_[file_idx];
std::thread prefetcher(
[this, file_name] { PrefetchThreadFunc(file_name); });
prefetcher.detach();
} else {
// No more file to read.
++completed_thread_num;
if (completed_thread_num == thread_num_) {
thread_quotas_->Close();
buffer_->Close();
break;
}
}
}
VLOG(5) << "MultipleReader schedule thread terminates.";
}
void MultipleReader::PrefetchThreadFunc(std::string file_name) {
VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
std::unique_ptr<framework::ReaderBase> reader =
CreateReaderByFileName(file_name, dims_);
while (reader->HasNext()) {
std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins);
if (!buffer_->Send(&ins)) {
VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
"thread of file '"
<< file_name << "' will terminate.";
break;
}
}
Quota quota;
thread_quotas_->Send(&quota);
VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates.";
}
// Operator that constructs a MultipleReader from its attributes and installs
// it into the ReaderHolder stored in the "Out" output variable.
class OpenFilesOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  // Validates "shape_concat"/"ranks"/"file_names", then resets the output
  // ReaderHolder to a new MultipleReader. `dev_place` is not used here.
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    // The ranks must add up to the number of entries in the flattened shapes.
    const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
    const auto& ranks = Attr<std::vector<int>>("ranks");
    PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
    PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
                      int(shape_concat.size()),
                      "The accumulate of all ranks should be equal to the "
                      "shape concat's length.");

    const auto& file_names = Attr<std::vector<std::string>>("file_names");
    PADDLE_ENFORCE(!file_names.empty(), "No file to be read!");

    const auto thread_num = static_cast<size_t>(Attr<int>("thread_num"));

    auto* out_var = scope.FindVar(Output("Out"));
    auto* reader_holder = out_var->GetMutable<framework::ReaderHolder>();

    // Split the flattened shape list back into one DDim per data slot.
    const auto data_dims = RestoreShapes(shape_concat, ranks);
    reader_holder->Reset(
        new MultipleReader(file_names, data_dims, thread_num));
  }
};
// Declares the proto of the open_files operator: its documentation, its
// single output and the attributes that configure the created reader.
class OpenFilesOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(op_proto, op_checker) {
    AddComment(R"DOC(
OpenFiles Operator
An OpenFilesOp creates a MultipleReader, which is able to
read data multi-threaded from multiple files.
)DOC");
    AddOutput("Out", "(ReaderHolder) The created MultipleReader.");
    AddAttr<std::vector<int>>("shape_concat",
                              "The concat of all data's shapes.");
    // Adjacent string literals are concatenated verbatim, so each piece ends
    // with explicit punctuation/space to keep the help text readable.
    AddAttr<std::vector<int>>(
        "ranks",
        "The ranks of each data. "
        "e.g. "
        "shape_concat = [2,3,4,5,6], "
        "ranks = [3,2]. "
        "It means the reader will generate two data each time, "
        "whose shapes are [2,3,4] and [5,6] respectively.");
    // NOTE(review): "lod_levels" is not read by OpenFilesOp::RunImpl in this
    // file; presumably it is consumed elsewhere -- confirm.
    AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data.");
    AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
    AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
        .GreaterThan(0);
  }
};
} // namespace reader
} // namespace operators
} // namespace paddle
// Short alias so the registration below can name the operator classes without
// spelling out paddle::operators::reader each time.
namespace reader = paddle::operators::reader;
// Registers the operator under the name `open_files`, pairing the
// implementation class (OpenFilesOp) with its proto maker (OpenFilesOpMaker).
// NOTE(review): the macro itself is not defined in this file -- presumably it
// comes from reader_op_registry.h.
REGISTER_FILE_READER_OPERATOR(open_files, reader::OpenFilesOp,
reader::OpenFilesOpMaker);
\ No newline at end of file
...@@ -21,6 +21,8 @@ namespace paddle { ...@@ -21,6 +21,8 @@ namespace paddle {
namespace operators { namespace operators {
namespace reader { namespace reader {
// Separator between the format prefix and the actual path in a reader file
// name of the form "[file_format]:[file_name]"; CreateReaderByFileName splits
// its argument on the first occurrence of this string.
static constexpr char kFileFormatSeparator[] = ":";
using FileReaderCreator = std::function<framework::ReaderBase*( using FileReaderCreator = std::function<framework::ReaderBase*(
const std::string&, const std::vector<framework::DDim>&)>; const std::string&, const std::vector<framework::DDim>&)>;
...@@ -29,12 +31,28 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry(); ...@@ -29,12 +31,28 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry();
template <typename Reader> template <typename Reader>
int RegisterFileReader(const std::string& filetype) { int RegisterFileReader(const std::string& filetype) {
FileReaderRegistry()[filetype] = []( FileReaderRegistry()[filetype] = [](
const std::string& fn, const std::vector<paddle::framework::DDim>& dim) { const std::string& fn, const std::vector<framework::DDim>& dims) {
return new Reader(fn, dim); return new Reader(fn, dims);
}; };
return 0; return 0;
} }
// Creates the reader registered for the format named in `file_name`.
//
// `file_name` must look like "[file_format]:[file_name]" (e.g.
// "recordio:data_file"). The part before the first separator selects the
// creator in FileReaderRegistry(); the part after it, together with `dims`,
// is passed to that creator.
//
// Fails (PADDLE_ENFORCE) when the separator is missing or when no reader is
// registered for the format. The returned unique_ptr owns the new reader.
//
// NOTE(review): marked `inline` because this definition appears to live in a
// header shared by several translation units; without it every includer would
// emit its own external definition and linking would fail.
inline std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
    const std::string& file_name, const std::vector<framework::DDim>& dims) {
  const size_t separator_pos = file_name.find(kFileFormatSeparator);
  PADDLE_ENFORCE_NE(separator_pos, std::string::npos,
                    "File name illegal! A legal file name should be like: "
                    "[file_format]:[file_name] (e.g., 'recordio:data_file').");
  // Skip exactly the separator (minus its terminating NUL) rather than
  // assuming it is one character long.
  constexpr size_t kSeparatorLen = sizeof(kFileFormatSeparator) - 1;
  const std::string filetype = file_name.substr(0, separator_pos);
  const std::string f_name = file_name.substr(separator_pos + kSeparatorLen);

  auto& registry = FileReaderRegistry();
  auto itor = registry.find(filetype);
  PADDLE_ENFORCE(itor != registry.end(),
                 "No file reader registered for '%s' format.", filetype);
  return std::unique_ptr<framework::ReaderBase>(
      (itor->second)(f_name, dims));
}
extern std::vector<framework::DDim> RestoreShapes( extern std::vector<framework::DDim> RestoreShapes(
const std::vector<int>& shape_concat, const std::vector<int>& ranks); const std::vector<int>& shape_concat, const std::vector<int>& ranks);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册