提交 a7866113 编写于 作者: F fengjiayi

Replace Channel in MultiFileReader with BlockingQueue

上级 1a25f3cd
......@@ -14,7 +14,7 @@
#include <thread> // NOLINT
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle {
......@@ -48,9 +48,9 @@ class MultiFileReader : public framework::ReaderBase {
std::thread scheduler_;
std::vector<std::thread> prefetchers_;
size_t buffer_size_;
framework::Channel<size_t>* waiting_file_idx_;
framework::Channel<size_t>* available_thread_idx_;
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
reader::BlockingQueue<size_t>* waiting_file_idx_;
reader::BlockingQueue<size_t>* available_thread_idx_;
reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_;
};
void MultiFileReader::ReadNext(std::vector<framework::LoDTensor>* out) {
......@@ -73,17 +73,17 @@ bool MultiFileReader::HasNext() {
void MultiFileReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size();
waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
buffer_ =
framework::MakeChannel<std::vector<framework::LoDTensor>>(buffer_size_);
waiting_file_idx_ = new reader::BlockingQueue<size_t>(file_names_.size());
available_thread_idx_ = new reader::BlockingQueue<size_t>(thread_num);
buffer_ = new reader::BlockingQueue<std::vector<framework::LoDTensor>>(
buffer_size_);
for (size_t i = 0; i < file_names_.size(); ++i) {
waiting_file_idx_->Send(&i);
waiting_file_idx_->Send(i);
}
waiting_file_idx_->Close();
for (size_t i = 0; i < thread_num; ++i) {
available_thread_idx_->Send(&i);
available_thread_idx_->Send(i);
}
scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
......@@ -149,7 +149,7 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
break;
}
try {
buffer_->Send(&ins);
buffer_->Send(std::move(ins));
} catch (paddle::platform::EnforceNotMet e) {
VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
"thread of file '"
......@@ -158,9 +158,7 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
}
}
try {
available_thread_idx_->Send(&thread_idx);
} catch (paddle::platform::EnforceNotMet e) {
if (!available_thread_idx_->Send(thread_idx)) {
VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
"Fail to send thread_idx.";
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册