diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt index 744bd3b7ef71f83ad82979eb966369c2e9456a7d..1254783d69a87b8b13650449fdb84174f7aef91e 100644 --- a/paddle/fluid/operators/reader/CMakeLists.txt +++ b/paddle/fluid/operators/reader/CMakeLists.txt @@ -20,5 +20,6 @@ reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc) reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc) reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc) reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc) +reader_library(open_files_op SRCS open_files_op.cc) # Export local libraries to parent set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE) diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index d0de092947eb04a1b7d06dedea919f6b1094dd06..447fae10535c1b458ed7de24ad3659b3c48ecb4a 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -120,10 +120,13 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { }; void DoubleBufferReader::ReadNext(std::vector* out) { + if (!HasNext()) { + PADDLE_THROW("There is no next data!"); + } + if (local_buffer_.payloads_.empty()) { buffer_->Receive(&local_buffer_); } - *out = local_buffer_.payloads_; local_buffer_.payloads_.clear(); if (local_buffer_.ctx_) { diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..473c002e93a6db65c5a47943e8b5c820abd19b34 --- /dev/null +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -0,0 +1,199 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/channel.h" +#include "paddle/fluid/operators/reader/reader_op_registry.h" + +namespace paddle { +namespace operators { +namespace reader { + +class MultipleReader : public framework::ReaderBase { + public: + struct Quota {}; + + MultipleReader(const std::vector& file_names, + const std::vector& dims, size_t thread_num) + : file_names_(file_names), dims_(dims), thread_num_(thread_num) { + PADDLE_ENFORCE_GT(thread_num_, 0); + StartNewScheduler(); + } + + void ReadNext(std::vector* out) override; + bool HasNext() const override; + void ReInit() override; + + private: + void StartNewScheduler(); + void ScheduleThreadFunc(); + void PrefetchThreadFunc(std::string file_name); + + std::vector file_names_; + std::vector dims_; + size_t thread_num_; + framework::Channel* waiting_file_idx_; + framework::Channel* thread_quotas_; + framework::Channel>* buffer_; + mutable std::vector local_buffer_; +}; + +void MultipleReader::ReadNext(std::vector* out) { + if (!HasNext()) { + PADDLE_THROW("There is no next data!"); + } + + if (local_buffer_.empty()) { + buffer_->Receive(&local_buffer_); + } + *out = local_buffer_; + local_buffer_.clear(); +} + +bool MultipleReader::HasNext() const { + return local_buffer_.empty() ? buffer_->Receive(&local_buffer_) : true; +} + +void MultipleReader::ReInit() { + buffer_->Close(); + thread_quotas_->Close(); + waiting_file_idx_->Close(); + local_buffer_.clear(); + + StartNewScheduler(); +} + +void MultipleReader::StartNewScheduler() { + waiting_file_idx_ = framework::MakeChannel(file_names_.size()); + thread_quotas_ = framework::MakeChannel(thread_num_); + buffer_ = + framework::MakeChannel>(thread_num_); + + for (size_t i = 0; i < file_names_.size(); ++i) { + waiting_file_idx_->Send(&i); + } + waiting_file_idx_->Close(); + for (size_t i = 0; i < thread_num_; ++i) { + Quota quota; + thread_quotas_->Send("a); + } + + std::thread scheduler([this] { ScheduleThreadFunc(); }); + scheduler.detach(); +} + +void MultipleReader::ScheduleThreadFunc() { + VLOG(5) << "MultipleReader schedule thread starts."; + size_t completed_thread_num = 0; + Quota quota; + while (thread_quotas_->Receive("a)) { + size_t file_idx; + if (waiting_file_idx_->Receive(&file_idx)) { + // Still have files to read. Start a new prefetch thread. + std::string file_name = file_names_[file_idx]; + std::thread prefetcher( + [this, file_name] { PrefetchThreadFunc(file_name); }); + prefetcher.detach(); + } else { + // No more file to read. + ++completed_thread_num; + if (completed_thread_num == thread_num_) { + thread_quotas_->Close(); + buffer_->Close(); + break; + } + } + } + VLOG(5) << "MultipleReader schedule thread terminates."; +} + +void MultipleReader::PrefetchThreadFunc(std::string file_name) { + VLOG(5) << "The prefetch thread of file '" << file_name << "' starts."; + std::unique_ptr reader = + CreateReaderByFileName(file_name, dims_); + while (reader->HasNext()) { + std::vector ins; + reader->ReadNext(&ins); + if (!buffer_->Send(&ins)) { + VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch " + "thread of file '" + << file_name << "' will terminate."; + break; + } + } + Quota quota; + thread_quotas_->Send("a); + VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates."; +} + +class OpenFilesOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { + const auto& shape_concat = Attr>("shape_concat"); + const auto& ranks = Attr>("ranks"); + PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty()); + PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0), + int(shape_concat.size()), + "The accumulate of all ranks should be equal to the " + "shape concat's length."); + const auto& file_names = Attr>("file_names"); + PADDLE_ENFORCE(!file_names.empty(), "No file to be read!"); + const size_t thread_num = Attr("thread_num"); + + auto* out = scope.FindVar(Output("Out")) + ->template GetMutable(); + out->Reset(new MultipleReader( + file_names, RestoreShapes(shape_concat, ranks), thread_num)); + } +}; + +class OpenFilesOpMaker : public framework::OpProtoAndCheckerMaker { + public: + OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(op_proto, op_checker) { + AddComment(R"DOC( + OpenFiles Operator + + An OpenFilesOp creates a MultipleReader, which is able to + read data multi-threaded from multiple files. + )DOC"); + AddOutput("Out", "(ReaderHolder) The created MultipleReader."); + AddAttr>("shape_concat", + "The concat of all data's shapes."); + AddAttr>( + "ranks", + "The ranks of each data." + "e.g." + "shape_concat = [2,3,4,5,6]" + "ranks = [3,2]" + "It means the reader will generate two data each time," + "whose shapes are [2,3,4] and [5,6] respectively."); + AddAttr>("lod_levels", "The LoD levels of each data."); + AddAttr>("file_names", "Files to be read."); + AddAttr("thread_num", "The maximal concurrent prefetch thread number.") + .GreaterThan(0); + } +}; + +} // namespace reader +} // namespace operators +} // namespace paddle + +namespace reader = paddle::operators::reader; + +REGISTER_FILE_READER_OPERATOR(open_files, reader::OpenFilesOp, + reader::OpenFilesOpMaker); \ No newline at end of file diff --git a/paddle/fluid/operators/reader/reader_op_registry.h b/paddle/fluid/operators/reader/reader_op_registry.h index 58f9b4ba35546571fd3b1d0c3ce128f18e248f01..feab7c63a3eeea6da78dca4c752c33f76df25a80 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.h +++ b/paddle/fluid/operators/reader/reader_op_registry.h @@ -21,6 +21,8 @@ namespace paddle { namespace operators { namespace reader { +static constexpr char kFileFormatSeparator[] = ":"; + using FileReaderCreator = std::function&)>; @@ -29,12 +31,28 @@ std::unordered_map& FileReaderRegistry(); template int RegisterFileReader(const std::string& filetype) { FileReaderRegistry()[filetype] = []( - const std::string& fn, const std::vector& dim) { - return new Reader(fn, dim); + const std::string& fn, const std::vector& dims) { + return new Reader(fn, dims); }; return 0; } +std::unique_ptr CreateReaderByFileName( + const std::string& file_name, const std::vector& dims) { + size_t separator_pos = file_name.find(kFileFormatSeparator); + PADDLE_ENFORCE_NE(separator_pos, std::string::npos, + "File name illegal! A legal file name should be like: " + "[file_format]:[file_name] (e.g., 'recordio:data_file')."); + std::string filetype = file_name.substr(0, separator_pos); + std::string f_name = file_name.substr(separator_pos + 1); + + auto itor = FileReaderRegistry().find(filetype); + PADDLE_ENFORCE(itor != FileReaderRegistry().end(), + "No file reader registered for '%s' format.", filetype); + framework::ReaderBase* reader = (itor->second)(f_name, dims); + return std::unique_ptr(reader); +} + extern std::vector RestoreShapes( const std::vector& shape_concat, const std::vector& ranks);