未验证 提交 87a5590b 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #11151 from JiayiFeng/dev_update_open_files_op

Update open files op
......@@ -26,7 +26,11 @@ class MultiFileReader : public framework::ReaderBase {
MultiFileReader(const std::vector<std::string>& file_names,
const std::vector<framework::DDim>& dims, size_t thread_num,
size_t buffer_size)
: file_names_(file_names), dims_(dims), buffer_size_(buffer_size) {
: buffer_size_(buffer_size) {
readers_.reserve(file_names.size());
for (const std::string& f_name : file_names) {
readers_.emplace_back(CreateReaderByFileName(f_name, dims));
}
prefetchers_.resize(thread_num);
StartNewScheduler();
}
......@@ -40,14 +44,13 @@ class MultiFileReader : public framework::ReaderBase {
void StartNewScheduler();
void EndScheduler();
void ScheduleThreadFunc();
void PrefetchThreadFunc(std::string file_name, size_t thread_idx);
void PrefetchThreadFunc(size_t reader_idx, size_t thread_idx);
std::vector<std::string> file_names_;
std::vector<framework::DDim> dims_;
std::vector<std::unique_ptr<framework::ReaderBase>> readers_;
std::thread scheduler_;
std::vector<std::thread> prefetchers_;
size_t buffer_size_;
reader::BlockingQueue<size_t>* waiting_file_idx_;
reader::BlockingQueue<size_t>* waiting_reader_idx_;
reader::BlockingQueue<size_t>* available_thread_idx_;
reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_;
};
......@@ -65,15 +68,15 @@ void MultiFileReader::ReInit() {
void MultiFileReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size();
waiting_file_idx_ = new reader::BlockingQueue<size_t>(file_names_.size());
waiting_reader_idx_ = new reader::BlockingQueue<size_t>(readers_.size());
available_thread_idx_ = new reader::BlockingQueue<size_t>(thread_num);
buffer_ = new reader::BlockingQueue<std::vector<framework::LoDTensor>>(
buffer_size_);
for (size_t i = 0; i < file_names_.size(); ++i) {
waiting_file_idx_->Send(i);
for (size_t i = 0; i < readers_.size(); ++i) {
waiting_reader_idx_->Send(i);
}
waiting_file_idx_->Close();
waiting_reader_idx_->Close();
for (size_t i = 0; i < thread_num; ++i) {
available_thread_idx_->Send(i);
}
......@@ -84,13 +87,13 @@ void MultiFileReader::StartNewScheduler() {
// Tears down the current scheduling round: closes the queues, waits for the
// scheduler thread to finish, and then frees the queue objects.
//
// NOTE(review): this text looks like a rendered diff without +/- markers. The
// waiting_file_idx_ statements appear to be the pre-change form and the
// waiting_reader_idx_ statements the post-change form of the same line —
// confirm against the real file before relying on either.
void MultiFileReader::EndScheduler() {
// Close every queue before joining. PrefetchThreadFunc treats a failed
// buffer_->Send on a closed queue as its signal to terminate.
available_thread_idx_->Close();
buffer_->Close();
waiting_file_idx_->Close();
waiting_reader_idx_->Close();
// Wait for the scheduler thread only if it is joinable (join() on a
// non-joinable std::thread throws).
if (scheduler_.joinable()) {
scheduler_.join();
}
// The queues are allocated with `new` in StartNewScheduler, so they are
// released here once the scheduler thread has been joined.
// NOTE(review): the pointers are not reset after delete — verify no caller
// can reach them before StartNewScheduler reassigns them.
delete buffer_;
delete available_thread_idx_;
delete waiting_file_idx_;
delete waiting_reader_idx_;
}
void MultiFileReader::ScheduleThreadFunc() {
......@@ -102,12 +105,11 @@ void MultiFileReader::ScheduleThreadFunc() {
if (prefetcher.joinable()) {
prefetcher.join();
}
size_t file_idx;
if (waiting_file_idx_->Receive(&file_idx)) {
size_t reader_idx;
if (waiting_reader_idx_->Receive(&reader_idx)) {
// Still have files to read. Start a new prefetch thread.
std::string file_name = file_names_[file_idx];
prefetcher = std::thread([this, file_name, thread_idx] {
PrefetchThreadFunc(file_name, thread_idx);
prefetcher = std::thread([this, reader_idx, thread_idx] {
PrefetchThreadFunc(reader_idx, thread_idx);
});
} else {
// No more files to read.
......@@ -129,23 +131,22 @@ void MultiFileReader::ScheduleThreadFunc() {
VLOG(5) << "MultiFileReader schedule thread terminates.";
}
void MultiFileReader::PrefetchThreadFunc(std::string file_name,
size_t thread_idx) {
VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
std::unique_ptr<framework::ReaderBase> reader =
CreateReaderByFileName(file_name, dims_);
void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) {
VLOG(5) << "The prefetch thread of file idx '" << reader_idx << "' starts.";
std::unique_ptr<framework::ReaderBase>& reader = readers_[reader_idx];
while (true) {
std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins);
if (ins.empty()) {
reader->ReInit();
break;
}
try {
buffer_->Send(std::move(ins));
} catch (paddle::platform::EnforceNotMet e) {
VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
"thread of file '"
<< file_name << "' will terminate.";
"thread of file idx '"
<< reader_idx << "' will terminate.";
break;
}
}
......@@ -154,7 +155,8 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
"Fail to send thread_idx.";
}
VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates.";
VLOG(5) << "The prefetch thread of file idx '" << reader_idx
<< "' terminates.";
}
class OpenFilesOp : public framework::OperatorBase {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册