diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 6b62e1db49076008255c15845bdd1d0dd27d297f..49cdf5365c996485ea31b9c5db3adbf3e1ba18d1 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -21,12 +21,10 @@ namespace reader { class MultipleReader : public framework::ReaderBase { public: - struct Quota {}; - MultipleReader(const std::vector& file_names, const std::vector& dims, size_t thread_num) - : file_names_(file_names), dims_(dims), thread_num_(thread_num) { - PADDLE_ENFORCE_GT(thread_num_, 0); + : file_names_(file_names), dims_(dims) { + prefetchers_.resize(thread_num); StartNewScheduler(); } @@ -34,16 +32,20 @@ class MultipleReader : public framework::ReaderBase { bool HasNext() const override; void ReInit() override; + ~MultipleReader() { EndScheduler(); } + private: void StartNewScheduler(); + void EndScheduler(); void ScheduleThreadFunc(); - void PrefetchThreadFunc(std::string file_name); + void PrefetchThreadFunc(std::string file_name, size_t thread_idx); std::vector file_names_; std::vector dims_; - size_t thread_num_; + std::thread scheduler_; + std::vector prefetchers_; framework::Channel* waiting_file_idx_; - framework::Channel* thread_quotas_; + framework::Channel* available_thread_idx_; framework::Channel>* buffer_; mutable std::vector local_buffer_; }; @@ -65,59 +67,76 @@ bool MultipleReader::HasNext() const { } void MultipleReader::ReInit() { - buffer_->Close(); - thread_quotas_->Close(); - waiting_file_idx_->Close(); + EndScheduler(); local_buffer_.clear(); - StartNewScheduler(); } void MultipleReader::StartNewScheduler() { + size_t thread_num = prefetchers_.size(); waiting_file_idx_ = framework::MakeChannel(file_names_.size()); - thread_quotas_ = framework::MakeChannel(thread_num_); + available_thread_idx_ = framework::MakeChannel(thread_num); buffer_ = - framework::MakeChannel>(thread_num_); + 
framework::MakeChannel>(thread_num); for (size_t i = 0; i < file_names_.size(); ++i) { waiting_file_idx_->Send(&i); } waiting_file_idx_->Close(); - for (size_t i = 0; i < thread_num_; ++i) { - Quota quota; - thread_quotas_->Send(&quota); + for (size_t i = 0; i < thread_num; ++i) { + available_thread_idx_->Send(&i); } - std::thread scheduler([this] { ScheduleThreadFunc(); }); - scheduler.detach(); + scheduler_ = std::thread([this] { ScheduleThreadFunc(); }); +} + +void MultipleReader::EndScheduler() { + available_thread_idx_->Close(); + buffer_->Close(); + waiting_file_idx_->Close(); + scheduler_.join(); + delete buffer_; + delete available_thread_idx_; + delete waiting_file_idx_; } void MultipleReader::ScheduleThreadFunc() { VLOG(5) << "MultipleReader schedule thread starts."; size_t completed_thread_num = 0; - Quota quota; - while (thread_quotas_->Receive(&quota)) { + size_t thread_idx; + while (available_thread_idx_->Receive(&thread_idx)) { + std::thread& prefetcher = prefetchers_[thread_idx]; + if (prefetcher.joinable()) { + prefetcher.join(); + } size_t file_idx; if (waiting_file_idx_->Receive(&file_idx)) { // Still have files to read. Start a new prefetch thread. std::string file_name = file_names_[file_idx]; - std::thread prefetcher( - [this, file_name] { PrefetchThreadFunc(file_name); }); - prefetcher.detach(); + prefetcher = std::thread([this, file_name, thread_idx] { + PrefetchThreadFunc(file_name, thread_idx); + }); } else { // No more file to read. ++completed_thread_num; - if (completed_thread_num == thread_num_) { - thread_quotas_->Close(); - buffer_->Close(); + if (completed_thread_num == prefetchers_.size()) { break; } } } + // If users invoke ReInit() when scheduler is running, it will close the + // 'available_thread_idx_' and prefetcher threads have no way to tell scheduler + // to release their resources. So a check is needed before scheduler ends. 
+ for (auto& p : prefetchers_) { + if (p.joinable()) { + p.join(); + } + } VLOG(5) << "MultipleReader schedule thread terminates."; } -void MultipleReader::PrefetchThreadFunc(std::string file_name) { +void MultipleReader::PrefetchThreadFunc(std::string file_name, + size_t thread_idx) { VLOG(5) << "The prefetch thread of file '" << file_name << "' starts."; std::unique_ptr reader = CreateReaderByFileName(file_name, dims_); @@ -131,8 +150,10 @@ void MultipleReader::PrefetchThreadFunc(std::string file_name) { break; } } - Quota quota; - thread_quotas_->Send(&quota); + if (!available_thread_idx_->Send(&thread_idx)) { + VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. " + "Fail to send thread_idx."; + } VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates."; }