diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 8c0dac65dd691954b112bfa61622d399b2b9c3e5..31e5d81e55ed9703eb3a9ef2595fa2a280f1a734 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -26,7 +26,11 @@ class MultiFileReader : public framework::ReaderBase { MultiFileReader(const std::vector& file_names, const std::vector& dims, size_t thread_num, size_t buffer_size) - : file_names_(file_names), dims_(dims), buffer_size_(buffer_size) { + : buffer_size_(buffer_size) { + readers_.reserve(file_names.size()); + for (const std::string& f_name : file_names) { + readers_.emplace_back(CreateReaderByFileName(f_name, dims)); + } prefetchers_.resize(thread_num); StartNewScheduler(); } @@ -40,14 +44,13 @@ class MultiFileReader : public framework::ReaderBase { void StartNewScheduler(); void EndScheduler(); void ScheduleThreadFunc(); - void PrefetchThreadFunc(std::string file_name, size_t thread_idx); + void PrefetchThreadFunc(size_t reader_idx, size_t thread_idx); - std::vector file_names_; - std::vector dims_; + std::vector> readers_; std::thread scheduler_; std::vector prefetchers_; size_t buffer_size_; - reader::BlockingQueue* waiting_file_idx_; + reader::BlockingQueue* waiting_reader_idx_; reader::BlockingQueue* available_thread_idx_; reader::BlockingQueue>* buffer_; }; @@ -65,15 +68,15 @@ void MultiFileReader::ReInit() { void MultiFileReader::StartNewScheduler() { size_t thread_num = prefetchers_.size(); - waiting_file_idx_ = new reader::BlockingQueue(file_names_.size()); + waiting_reader_idx_ = new reader::BlockingQueue(readers_.size()); available_thread_idx_ = new reader::BlockingQueue(thread_num); buffer_ = new reader::BlockingQueue>( buffer_size_); - for (size_t i = 0; i < file_names_.size(); ++i) { - waiting_file_idx_->Send(i); + for (size_t i = 0; i < readers_.size(); ++i) { + waiting_reader_idx_->Send(i); } - waiting_file_idx_->Close(); + waiting_reader_idx_->Close(); for (size_t i = 0; i < thread_num; ++i) { available_thread_idx_->Send(i); } @@ -84,13 +87,13 @@ void MultiFileReader::StartNewScheduler() { void MultiFileReader::EndScheduler() { available_thread_idx_->Close(); buffer_->Close(); - waiting_file_idx_->Close(); + waiting_reader_idx_->Close(); if (scheduler_.joinable()) { scheduler_.join(); } delete buffer_; delete available_thread_idx_; - delete waiting_file_idx_; + delete waiting_reader_idx_; } void MultiFileReader::ScheduleThreadFunc() { @@ -102,12 +105,11 @@ void MultiFileReader::ScheduleThreadFunc() { if (prefetcher.joinable()) { prefetcher.join(); } - size_t file_idx; - if (waiting_file_idx_->Receive(&file_idx)) { + size_t reader_idx; + if (waiting_reader_idx_->Receive(&reader_idx)) { // Still have files to read. Start a new prefetch thread. - std::string file_name = file_names_[file_idx]; - prefetcher = std::thread([this, file_name, thread_idx] { - PrefetchThreadFunc(file_name, thread_idx); + prefetcher = std::thread([this, reader_idx, thread_idx] { + PrefetchThreadFunc(reader_idx, thread_idx); }); } else { // No more file to read. @@ -129,23 +131,22 @@ void MultiFileReader::ScheduleThreadFunc() { VLOG(5) << "MultiFileReader schedule thread terminates."; } -void MultiFileReader::PrefetchThreadFunc(std::string file_name, - size_t thread_idx) { - VLOG(5) << "The prefetch thread of file '" << file_name << "' starts."; - std::unique_ptr reader = - CreateReaderByFileName(file_name, dims_); +void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) { + VLOG(5) << "The prefetch thread of file idx '" << reader_idx << "' starts."; + std::unique_ptr& reader = readers_[reader_idx]; while (true) { std::vector ins; reader->ReadNext(&ins); if (ins.empty()) { + reader->ReInit(); break; } try { buffer_->Send(std::move(ins)); } catch (paddle::platform::EnforceNotMet e) { VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch " - "thread of file '" - << file_name << "' will terminate."; + "thread of file idx '" + << reader_idx << "' will terminate."; break; } } @@ -154,7 +155,8 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name, VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. " "Fail to send thread_idx."; } - VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates."; + VLOG(5) << "The prefetch thread of file idx '" << reader_idx + << "' terminates."; } class OpenFilesOp : public framework::OperatorBase {