提交 a2981f5c 编写于 作者: F fengjiayi

fix a bug

上级 87ac675a
......@@ -21,12 +21,10 @@ namespace reader {
class MultipleReader : public framework::ReaderBase {
public:
struct Quota {};
MultipleReader(const std::vector<std::string>& file_names,
const std::vector<framework::DDim>& dims, size_t thread_num)
: file_names_(file_names), dims_(dims), thread_num_(thread_num) {
PADDLE_ENFORCE_GT(thread_num_, 0);
: file_names_(file_names), dims_(dims) {
prefetchers_.resize(thread_num);
StartNewScheduler();
}
......@@ -34,16 +32,20 @@ class MultipleReader : public framework::ReaderBase {
bool HasNext() const override;
void ReInit() override;
~MultipleReader() { EndScheduler(); }
private:
void StartNewScheduler();
void EndScheduler();
void ScheduleThreadFunc();
void PrefetchThreadFunc(std::string file_name);
void PrefetchThreadFunc(std::string file_name, size_t thread_idx);
std::vector<std::string> file_names_;
std::vector<framework::DDim> dims_;
size_t thread_num_;
std::thread scheduler_;
std::vector<std::thread> prefetchers_;
framework::Channel<size_t>* waiting_file_idx_;
framework::Channel<Quota>* thread_quotas_;
framework::Channel<size_t>* available_thread_idx_;
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
mutable std::vector<framework::LoDTensor> local_buffer_;
};
......@@ -65,59 +67,76 @@ bool MultipleReader::HasNext() const {
}
void MultipleReader::ReInit() {
buffer_->Close();
thread_quotas_->Close();
waiting_file_idx_->Close();
EndScheduler();
local_buffer_.clear();
StartNewScheduler();
}
void MultipleReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size();
waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
thread_quotas_ = framework::MakeChannel<Quota>(thread_num_);
available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
buffer_ =
framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num_);
framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num);
for (size_t i = 0; i < file_names_.size(); ++i) {
waiting_file_idx_->Send(&i);
}
waiting_file_idx_->Close();
for (size_t i = 0; i < thread_num_; ++i) {
Quota quota;
thread_quotas_->Send(&quota);
for (size_t i = 0; i < thread_num; ++i) {
available_thread_idx_->Send(&i);
}
std::thread scheduler([this] { ScheduleThreadFunc(); });
scheduler.detach();
scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
}
void MultipleReader::EndScheduler() {
available_thread_idx_->Close();
buffer_->Close();
waiting_file_idx_->Close();
scheduler_.join();
delete buffer_;
delete available_thread_idx_;
delete waiting_file_idx_;
}
void MultipleReader::ScheduleThreadFunc() {
VLOG(5) << "MultipleReader schedule thread starts.";
size_t completed_thread_num = 0;
Quota quota;
while (thread_quotas_->Receive(&quota)) {
size_t thread_idx;
while (available_thread_idx_->Receive(&thread_idx)) {
std::thread& prefetcher = prefetchers_[thread_idx];
if (prefetcher.joinable()) {
prefetcher.join();
}
size_t file_idx;
if (waiting_file_idx_->Receive(&file_idx)) {
// Still have files to read. Start a new prefetch thread.
std::string file_name = file_names_[file_idx];
std::thread prefetcher(
[this, file_name] { PrefetchThreadFunc(file_name); });
prefetcher.detach();
prefetcher = std::thread([this, file_name, thread_idx] {
PrefetchThreadFunc(file_name, thread_idx);
});
} else {
// No more file to read.
++completed_thread_num;
if (completed_thread_num == thread_num_) {
thread_quotas_->Close();
buffer_->Close();
if (completed_thread_num == prefetchers_.size()) {
break;
}
}
}
// If users invoke ReInit() when scheduler is running, it will close the
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
// to release their resource. So a check is needed before scheduler ends.
for (auto& p : prefetchers_) {
if (p.joinable()) {
p.join();
}
}
VLOG(5) << "MultipleReader schedule thread terminates.";
}
void MultipleReader::PrefetchThreadFunc(std::string file_name) {
void MultipleReader::PrefetchThreadFunc(std::string file_name,
size_t thread_idx) {
VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
std::unique_ptr<framework::ReaderBase> reader =
CreateReaderByFileName(file_name, dims_);
......@@ -131,8 +150,10 @@ void MultipleReader::PrefetchThreadFunc(std::string file_name) {
break;
}
}
Quota quota;
thread_quotas_->Send(&quota);
if (!available_thread_idx_->Send(&thread_idx)) {
VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
"Fail to send thread_idx.";
}
VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates.";
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册