// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include // NOLINT #include "ThreadPool.h" #include "paddle/fluid/framework/blocking_queue.h" #include "paddle/fluid/operators/reader/blocking_queue.h" #include "paddle/fluid/operators/reader/buffered_reader.h" #include "paddle/fluid/operators/reader/reader_op_registry.h" namespace paddle { namespace operators { namespace reader { class IReaderContainer { public: virtual ~IReaderContainer() {} virtual void AppendReader( std::unique_ptr&& readers) = 0; virtual void Stop() = 0; virtual void Start() = 0; virtual void ReadNext(std::vector* out) = 0; }; class OrderedReaderContainer : public IReaderContainer { public: void AppendReader(std::unique_ptr&& reader) override { pending_.emplace(std::move(reader)); } void Stop() override { while (!pending_.empty()) { MoveFrontPendingToDone(); } } void Start() override { std::swap(done_, pending_); } void ReadNext(std::vector* out) override { if (!pending_.empty()) { pending_.front()->ReadNext(out); if (out->empty()) { MoveFrontPendingToDone(); ReadNext(out); } } else { out->clear(); } } private: void MoveFrontPendingToDone() { pending_.front()->Shutdown(); pending_.front()->Start(); done_.emplace(move(pending_.front())); pending_.pop(); } std::queue> pending_; std::queue> done_; }; class PreemptiveReaderContainer : public IReaderContainer { using ReaderList = std::list>; struct FutureItem { std::vector data_; ReaderList::iterator reader_it_; std::exception_ptr exception_; }; using FutureList = std::list>; public: explicit PreemptiveReaderContainer(size_t thread_num) : pool_(thread_num) {} void Stop() override { if (!pending_.empty()) { for (auto& reader : pending_) { reader->Shutdown(); } for (auto& fu : futures_) { fu.wait(); } futures_.clear(); for (auto& reader : pending_) { reader->Start(); done_.emplace_back(std::move(reader)); } pending_.clear(); bool timeout; complete_queue_.PopAll(1000, &timeout); PADDLE_ENFORCE(!timeout); } } void Start() override { for (auto& reader : done_) { AppendReader(std::move(reader)); } done_.clear(); } void ReadNext(std::vector* out) override { VLOG(1) << "flag"; if (!pending_.empty()) { auto future_it = complete_queue_.Pop(); FutureItem item = future_it->get(); if (item.exception_) { VLOG(1) << "item has exception!!!"; for (auto it = futures_.begin(); it != futures_.end(); ++it) { if (it != future_it) { it->wait(); // Wait all other threads complete. } } std::rethrow_exception(item.exception_); } else if (item.data_.empty()) { // reader done. done_.emplace_back(std::move(*item.reader_it_)); pending_.erase(item.reader_it_); futures_.erase(future_it); ReadNext(out); } else { *out = item.data_; // continue read async ReadAsync(item.reader_it_, &future_it); } } else { out->clear(); } } private: void AppendReader(std::unique_ptr&& reader) override { pending_.emplace_back(std::move(reader)); auto reader_it = pending_.end(); --reader_it; futures_.emplace_back(); auto future_it = futures_.end(); --future_it; ReadAsync(reader_it, &future_it); } void ReadAsync(const ReaderList::iterator& reader_it, FutureList::iterator* future_it_ptr) { auto& future_it = *future_it_ptr; *future_it = pool_.enqueue([reader_it, future_it, this] { try { FutureItem item; item.reader_it_ = reader_it; (*reader_it)->ReadNext(&item.data_); if (item.data_.empty()) { (*reader_it)->Shutdown(); (*reader_it)->Start(); } complete_queue_.Push(future_it); return item; } catch (...) { FutureItem item; item.exception_ = std::current_exception(); complete_queue_.Push(future_it); return item; } }); } FutureList futures_; ThreadPool pool_; framework::BlockingQueue complete_queue_; std::list> pending_; std::list> done_; }; class MultiFileReader : public framework::ReaderBase { public: MultiFileReader(const std::vector& file_names, std::unique_ptr&& container) : container_(std::move(container)) { for (auto& fn : file_names) { container_->AppendReader(CreateReaderByFileName(fn)); } } ~MultiFileReader() { container_->Stop(); } protected: void ReadNextImpl(std::vector* out) override { container_->ReadNext(out); } void ShutdownImpl() override { container_->Stop(); } void StartImpl() override { container_->Start(); } private: std::unique_ptr container_; }; class OpenFilesOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; private: void RunImpl(const framework::Scope& scope, const platform::Place& dev_place) const override { const auto& shape_concat = Attr>("shape_concat"); const auto& ranks = Attr>("ranks"); PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty()); PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0), static_cast(shape_concat.size()), "The accumulate of all ranks should be equal to the " "shape concat's length."); const auto& file_names = Attr>("file_names"); PADDLE_ENFORCE(!file_names.empty(), "No file to be read!"); bool is_test = Attr("is_test"); auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); std::unique_ptr container; if (is_test) { container.reset(new OrderedReaderContainer()); } else { container.reset(new PreemptiveReaderContainer( static_cast(Attr("thread_num")))); } std::shared_ptr reader( new MultiFileReader(file_names, std::move(container))); auto buffer_size = Attr("buffer_size"); if (buffer_size > 1) { reader = framework::MakeDecoratedReader( reader, platform::CPUPlace(), buffer_size); } out->Reset(reader); } }; class OpenFilesOpMaker : public FileReaderMakerBase { protected: void Apply() override { AddAttr>("file_names", "Files to be read."); AddAttr("is_test", "Used for testing data.").SetDefault(false); AddComment(R"DOC( OpenFiles Operator An OpenFilesOp creates a MultiFileReader, which is able to read data multi-threaded from multiple files. )DOC"); AddAttr("thread_num", "The maximal concurrent prefetch thread number. Used only " "when is_test = False"); AddAttr("buffer_size", "The reading buffer of these files.") .GreaterThan(0); } }; } // namespace reader } // namespace operators } // namespace paddle namespace reader = paddle::operators::reader; REGISTER_FILE_READER_OPERATOR(open_files, reader::OpenFilesOp, reader::OpenFilesOpMaker);