提交 a2981f5c 编写于 作者: F fengjiayi

fix a bug

上级 87ac675a
...@@ -21,12 +21,10 @@ namespace reader { ...@@ -21,12 +21,10 @@ namespace reader {
class MultipleReader : public framework::ReaderBase { class MultipleReader : public framework::ReaderBase {
public: public:
struct Quota {};
MultipleReader(const std::vector<std::string>& file_names, MultipleReader(const std::vector<std::string>& file_names,
const std::vector<framework::DDim>& dims, size_t thread_num) const std::vector<framework::DDim>& dims, size_t thread_num)
: file_names_(file_names), dims_(dims), thread_num_(thread_num) { : file_names_(file_names), dims_(dims) {
PADDLE_ENFORCE_GT(thread_num_, 0); prefetchers_.resize(thread_num);
StartNewScheduler(); StartNewScheduler();
} }
...@@ -34,16 +32,20 @@ class MultipleReader : public framework::ReaderBase { ...@@ -34,16 +32,20 @@ class MultipleReader : public framework::ReaderBase {
bool HasNext() const override; bool HasNext() const override;
void ReInit() override; void ReInit() override;
~MultipleReader() { EndScheduler(); }
private: private:
void StartNewScheduler(); void StartNewScheduler();
void EndScheduler();
void ScheduleThreadFunc(); void ScheduleThreadFunc();
void PrefetchThreadFunc(std::string file_name); void PrefetchThreadFunc(std::string file_name, size_t thread_idx);
std::vector<std::string> file_names_; std::vector<std::string> file_names_;
std::vector<framework::DDim> dims_; std::vector<framework::DDim> dims_;
size_t thread_num_; std::thread scheduler_;
std::vector<std::thread> prefetchers_;
framework::Channel<size_t>* waiting_file_idx_; framework::Channel<size_t>* waiting_file_idx_;
framework::Channel<Quota>* thread_quotas_; framework::Channel<size_t>* available_thread_idx_;
framework::Channel<std::vector<framework::LoDTensor>>* buffer_; framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
mutable std::vector<framework::LoDTensor> local_buffer_; mutable std::vector<framework::LoDTensor> local_buffer_;
}; };
...@@ -65,59 +67,76 @@ bool MultipleReader::HasNext() const { ...@@ -65,59 +67,76 @@ bool MultipleReader::HasNext() const {
} }
void MultipleReader::ReInit() { void MultipleReader::ReInit() {
buffer_->Close(); EndScheduler();
thread_quotas_->Close();
waiting_file_idx_->Close();
local_buffer_.clear(); local_buffer_.clear();
StartNewScheduler(); StartNewScheduler();
} }
void MultipleReader::StartNewScheduler() { void MultipleReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size();
waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size()); waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
thread_quotas_ = framework::MakeChannel<Quota>(thread_num_); available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
buffer_ = buffer_ =
framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num_); framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num);
for (size_t i = 0; i < file_names_.size(); ++i) { for (size_t i = 0; i < file_names_.size(); ++i) {
waiting_file_idx_->Send(&i); waiting_file_idx_->Send(&i);
} }
waiting_file_idx_->Close(); waiting_file_idx_->Close();
for (size_t i = 0; i < thread_num_; ++i) { for (size_t i = 0; i < thread_num; ++i) {
Quota quota; available_thread_idx_->Send(&i);
thread_quotas_->Send(&quota);
} }
std::thread scheduler([this] { ScheduleThreadFunc(); }); scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
scheduler.detach(); }
void MultipleReader::EndScheduler() {
available_thread_idx_->Close();
buffer_->Close();
waiting_file_idx_->Close();
scheduler_.join();
delete buffer_;
delete available_thread_idx_;
delete waiting_file_idx_;
} }
void MultipleReader::ScheduleThreadFunc() { void MultipleReader::ScheduleThreadFunc() {
VLOG(5) << "MultipleReader schedule thread starts."; VLOG(5) << "MultipleReader schedule thread starts.";
size_t completed_thread_num = 0; size_t completed_thread_num = 0;
Quota quota; size_t thread_idx;
while (thread_quotas_->Receive(&quota)) { while (available_thread_idx_->Receive(&thread_idx)) {
std::thread& prefetcher = prefetchers_[thread_idx];
if (prefetcher.joinable()) {
prefetcher.join();
}
size_t file_idx; size_t file_idx;
if (waiting_file_idx_->Receive(&file_idx)) { if (waiting_file_idx_->Receive(&file_idx)) {
// Still have files to read. Start a new prefetch thread. // Still have files to read. Start a new prefetch thread.
std::string file_name = file_names_[file_idx]; std::string file_name = file_names_[file_idx];
std::thread prefetcher( prefetcher = std::thread([this, file_name, thread_idx] {
[this, file_name] { PrefetchThreadFunc(file_name); }); PrefetchThreadFunc(file_name, thread_idx);
prefetcher.detach(); });
} else { } else {
// No more file to read. // No more file to read.
++completed_thread_num; ++completed_thread_num;
if (completed_thread_num == thread_num_) { if (completed_thread_num == prefetchers_.size()) {
thread_quotas_->Close();
buffer_->Close();
break; break;
} }
} }
} }
// If users invoke ReInit() while the scheduler is running, it closes
// 'available_thread_idx_', so prefetcher threads have no way to tell the
// scheduler to release their resources. So a check is needed before the scheduler ends.
for (auto& p : prefetchers_) {
if (p.joinable()) {
p.join();
}
}
VLOG(5) << "MultipleReader schedule thread terminates."; VLOG(5) << "MultipleReader schedule thread terminates.";
} }
void MultipleReader::PrefetchThreadFunc(std::string file_name) { void MultipleReader::PrefetchThreadFunc(std::string file_name,
size_t thread_idx) {
VLOG(5) << "The prefetch thread of file '" << file_name << "' starts."; VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
std::unique_ptr<framework::ReaderBase> reader = std::unique_ptr<framework::ReaderBase> reader =
CreateReaderByFileName(file_name, dims_); CreateReaderByFileName(file_name, dims_);
...@@ -131,8 +150,10 @@ void MultipleReader::PrefetchThreadFunc(std::string file_name) { ...@@ -131,8 +150,10 @@ void MultipleReader::PrefetchThreadFunc(std::string file_name) {
break; break;
} }
} }
Quota quota; if (!available_thread_idx_->Send(&thread_idx)) {
thread_quotas_->Send(&quota); VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
"Fail to send thread_idx.";
}
VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates."; VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates.";
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册