diff --git a/paddle/fluid/operators/reader/create_threaded_reader_op.cc b/paddle/fluid/operators/reader/create_threaded_reader_op.cc index 565cbe4d9f3bcd59d3e7410012f28439a592a287..854381e0eee163cd9193452da34f34f7797e191d 100644 --- a/paddle/fluid/operators/reader/create_threaded_reader_op.cc +++ b/paddle/fluid/operators/reader/create_threaded_reader_op.cc @@ -124,7 +124,7 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase { can enable them by setting 'unsafe_mode' true. In this case, 'HasNext()' returning true only guarantees the safety of invoking 'ReadNext()' in the same thread. Each thread must - invoke 'HasNext()' and 'ReadNext()' in pair. + invoke 'HasNext()' and 'ReadNext()' in pairs. )DOC"); } }; diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index db4e619e7b1cee2c934aa5da38a95c4a99c4ead3..45db94e7808d1d434b6a12e1c3813e55fd64a395 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -19,27 +19,11 @@ namespace paddle { namespace operators { namespace reader { -class MultipleReader : public framework::ReaderBase { +class MultiFileReader : public framework::ReaderBase { public: - class ThreadBufferMap { - public: - std::vector& operator[]( - const std::thread::id& thread_id) { - std::lock_guard lock(mutex_); - return buffer_[thread_id]; - } - - void Clear() { buffer_.clear(); } - - private: - std::mutex mutex_; - std::unordered_map> - buffer_; - }; - - MultipleReader(const std::vector& file_names, - const std::vector& dims, size_t thread_num, - size_t buffer_size) + MultiFileReader(const std::vector& file_names, + const std::vector& dims, size_t thread_num, + size_t buffer_size) : file_names_(file_names), dims_(dims), buffer_size_(buffer_size) { prefetchers_.resize(thread_num); StartNewScheduler(); @@ -49,7 +33,7 @@ class MultipleReader : public framework::ReaderBase { bool HasNext() const override; void ReInit() override; - ~MultipleReader() { EndScheduler(); } + ~MultiFileReader() { EndScheduler(); } private: void StartNewScheduler(); @@ -65,31 +49,27 @@ class MultipleReader : public framework::ReaderBase { framework::Channel* waiting_file_idx_; framework::Channel* available_thread_idx_; framework::Channel>* buffer_; - mutable ThreadBufferMap thread_buffer_map_; }; -void MultipleReader::ReadNext(std::vector* out) { +void MultiFileReader::ReadNext(std::vector* out) { if (!HasNext()) { PADDLE_THROW("There is no next data!"); } - auto& thread_local_buffer = thread_buffer_map_[std::this_thread::get_id()]; - *out = thread_local_buffer; - thread_local_buffer.clear(); + buffer_->Receive(out); } -bool MultipleReader::HasNext() const { - auto& thread_local_buffer = thread_buffer_map_[std::this_thread::get_id()]; - return thread_local_buffer.empty() ? buffer_->Receive(&thread_local_buffer) - : true; +bool MultiFileReader::HasNext() const { + while (!buffer_->IsClosed() && !buffer_->CanReceive()) { + } + return buffer_->CanReceive(); } -void MultipleReader::ReInit() { +void MultiFileReader::ReInit() { EndScheduler(); - thread_buffer_map_.Clear(); StartNewScheduler(); } -void MultipleReader::StartNewScheduler() { +void MultiFileReader::StartNewScheduler() { size_t thread_num = prefetchers_.size(); waiting_file_idx_ = framework::MakeChannel(file_names_.size()); available_thread_idx_ = framework::MakeChannel(thread_num); @@ -107,7 +87,7 @@ void MultipleReader::StartNewScheduler() { scheduler_ = std::thread([this] { ScheduleThreadFunc(); }); } -void MultipleReader::EndScheduler() { +void MultiFileReader::EndScheduler() { available_thread_idx_->Close(); buffer_->Close(); waiting_file_idx_->Close(); @@ -119,8 +99,8 @@ void MultipleReader::EndScheduler() { delete waiting_file_idx_; } -void MultipleReader::ScheduleThreadFunc() { - VLOG(5) << "MultipleReader schedule thread starts."; +void MultiFileReader::ScheduleThreadFunc() { + VLOG(5) << "MultiFileReader schedule thread starts."; size_t completed_thread_num = 0; size_t thread_idx; while (available_thread_idx_->Receive(&thread_idx)) { @@ -152,11 +132,11 @@ void MultipleReader::ScheduleThreadFunc() { p.join(); } } - VLOG(5) << "MultipleReader schedule thread terminates."; + VLOG(5) << "MultiFileReader schedule thread terminates."; } -void MultipleReader::PrefetchThreadFunc(std::string file_name, - size_t thread_idx) { +void MultiFileReader::PrefetchThreadFunc(std::string file_name, + size_t thread_idx) { VLOG(5) << "The prefetch thread of file '" << file_name << "' starts."; std::unique_ptr reader = CreateReaderByFileName(file_name, dims_); @@ -203,9 +183,9 @@ class OpenFilesOp : public framework::OperatorBase { auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); - out->Reset(new MultipleReader(file_names, - RestoreShapes(shape_concat, ranks), - thread_num, buffer_size)); + out->Reset(new MultiFileReader(file_names, + RestoreShapes(shape_concat, ranks), + thread_num, buffer_size)); } }; @@ -221,7 +201,7 @@ class OpenFilesOpMaker : public FileReaderMakerBase { AddComment(R"DOC( OpenFiles Operator - An OpenFilesOp creates a MultipleReader, which is able to + An OpenFilesOp creates a MultiFileReader, which is able to read data multi-threaded from multiple files. )DOC"); }