提交 ee4e567d 编写于 作者: F fengjiayi

Creating readers before training begins

上级 e0a8c584
...@@ -26,7 +26,11 @@ class MultiFileReader : public framework::ReaderBase { ...@@ -26,7 +26,11 @@ class MultiFileReader : public framework::ReaderBase {
// Constructs a reader that multiplexes several input files.
//
// One underlying reader is created per entry of `file_names` (via
// CreateReaderByFileName, using `dims`) before any scheduling starts, so
// readers_[i] always corresponds to file_names[i]. `thread_num` fixes the
// number of prefetch-thread slots and `buffer_size` is stored for the
// output buffer queue created by StartNewScheduler().
MultiFileReader(const std::vector<std::string>& file_names,
                const std::vector<framework::DDim>& dims, size_t thread_num,
                size_t buffer_size)
    : buffer_size_(buffer_size) {
  // reserve(), not resize(): resize() would first insert file_names.size()
  // null unique_ptrs and emplace_back() would then append the real readers
  // after them, leaving 2N entries whose first half is null and would be
  // dereferenced by ReInit() / PrefetchThreadFunc().
  readers_.reserve(file_names.size());
  for (const std::string& f_name : file_names) {
    readers_.emplace_back(CreateReaderByFileName(f_name, dims));
  }
  prefetchers_.resize(thread_num);
  StartNewScheduler();
}
...@@ -40,14 +44,13 @@ class MultiFileReader : public framework::ReaderBase { ...@@ -40,14 +44,13 @@ class MultiFileReader : public framework::ReaderBase {
void StartNewScheduler(); void StartNewScheduler();
void EndScheduler(); void EndScheduler();
void ScheduleThreadFunc(); void ScheduleThreadFunc();
void PrefetchThreadFunc(std::string file_name, size_t thread_idx); void PrefetchThreadFunc(size_t reader_idx, size_t thread_idx);
std::vector<std::string> file_names_; std::vector<std::unique_ptr<framework::ReaderBase>> readers_;
std::vector<framework::DDim> dims_;
std::thread scheduler_; std::thread scheduler_;
std::vector<std::thread> prefetchers_; std::vector<std::thread> prefetchers_;
size_t buffer_size_; size_t buffer_size_;
reader::BlockingQueue<size_t>* waiting_file_idx_; reader::BlockingQueue<size_t>* waiting_reader_idx_;
reader::BlockingQueue<size_t>* available_thread_idx_; reader::BlockingQueue<size_t>* available_thread_idx_;
reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_; reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_;
}; };
...@@ -60,20 +63,23 @@ void MultiFileReader::ReadNext(std::vector<framework::LoDTensor>* out) { ...@@ -60,20 +63,23 @@ void MultiFileReader::ReadNext(std::vector<framework::LoDTensor>* out) {
void MultiFileReader::ReInit() { void MultiFileReader::ReInit() {
EndScheduler(); EndScheduler();
for (auto& reader : readers_) {
reader->ReInit();
}
StartNewScheduler(); StartNewScheduler();
} }
void MultiFileReader::StartNewScheduler() { void MultiFileReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size(); size_t thread_num = prefetchers_.size();
waiting_file_idx_ = new reader::BlockingQueue<size_t>(file_names_.size()); waiting_reader_idx_ = new reader::BlockingQueue<size_t>(readers_.size());
available_thread_idx_ = new reader::BlockingQueue<size_t>(thread_num); available_thread_idx_ = new reader::BlockingQueue<size_t>(thread_num);
buffer_ = new reader::BlockingQueue<std::vector<framework::LoDTensor>>( buffer_ = new reader::BlockingQueue<std::vector<framework::LoDTensor>>(
buffer_size_); buffer_size_);
for (size_t i = 0; i < file_names_.size(); ++i) { for (size_t i = 0; i < readers_.size(); ++i) {
waiting_file_idx_->Send(i); waiting_reader_idx_->Send(i);
} }
waiting_file_idx_->Close(); waiting_reader_idx_->Close();
for (size_t i = 0; i < thread_num; ++i) { for (size_t i = 0; i < thread_num; ++i) {
available_thread_idx_->Send(i); available_thread_idx_->Send(i);
} }
...@@ -84,13 +90,13 @@ void MultiFileReader::StartNewScheduler() { ...@@ -84,13 +90,13 @@ void MultiFileReader::StartNewScheduler() {
void MultiFileReader::EndScheduler() { void MultiFileReader::EndScheduler() {
available_thread_idx_->Close(); available_thread_idx_->Close();
buffer_->Close(); buffer_->Close();
waiting_file_idx_->Close(); waiting_reader_idx_->Close();
if (scheduler_.joinable()) { if (scheduler_.joinable()) {
scheduler_.join(); scheduler_.join();
} }
delete buffer_; delete buffer_;
delete available_thread_idx_; delete available_thread_idx_;
delete waiting_file_idx_; delete waiting_reader_idx_;
} }
void MultiFileReader::ScheduleThreadFunc() { void MultiFileReader::ScheduleThreadFunc() {
...@@ -102,12 +108,11 @@ void MultiFileReader::ScheduleThreadFunc() { ...@@ -102,12 +108,11 @@ void MultiFileReader::ScheduleThreadFunc() {
if (prefetcher.joinable()) { if (prefetcher.joinable()) {
prefetcher.join(); prefetcher.join();
} }
size_t file_idx; size_t reader_idx;
if (waiting_file_idx_->Receive(&file_idx)) { if (waiting_reader_idx_->Receive(&reader_idx)) {
// Still have files to read. Start a new prefetch thread. // Still have files to read. Start a new prefetch thread.
std::string file_name = file_names_[file_idx]; prefetcher = std::thread([this, reader_idx, thread_idx] {
prefetcher = std::thread([this, file_name, thread_idx] { PrefetchThreadFunc(reader_idx, thread_idx);
PrefetchThreadFunc(file_name, thread_idx);
}); });
} else { } else {
// No more file to read. // No more file to read.
...@@ -129,11 +134,9 @@ void MultiFileReader::ScheduleThreadFunc() { ...@@ -129,11 +134,9 @@ void MultiFileReader::ScheduleThreadFunc() {
VLOG(5) << "MultiFileReader schedule thread terminates."; VLOG(5) << "MultiFileReader schedule thread terminates.";
} }
void MultiFileReader::PrefetchThreadFunc(std::string file_name, void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) {
size_t thread_idx) { VLOG(5) << "The prefetch thread of file idx '" << reader_idx << "' starts.";
VLOG(5) << "The prefetch thread of file '" << file_name << "' starts."; std::unique_ptr<framework::ReaderBase>& reader = readers_[reader_idx];
std::unique_ptr<framework::ReaderBase> reader =
CreateReaderByFileName(file_name, dims_);
while (true) { while (true) {
std::vector<framework::LoDTensor> ins; std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins); reader->ReadNext(&ins);
...@@ -144,8 +147,8 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name, ...@@ -144,8 +147,8 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
buffer_->Send(std::move(ins)); buffer_->Send(std::move(ins));
} catch (paddle::platform::EnforceNotMet e) { } catch (paddle::platform::EnforceNotMet e) {
VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch " VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
"thread of file '" "thread of file idx '"
<< file_name << "' will terminate."; << reader_idx << "' will terminate.";
break; break;
} }
} }
...@@ -154,7 +157,8 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name, ...@@ -154,7 +157,8 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. " VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
"Fail to send thread_idx."; "Fail to send thread_idx.";
} }
VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates."; VLOG(5) << "The prefetch thread of file idx '" << reader_idx
<< "' terminates.";
} }
class OpenFilesOp : public framework::OperatorBase { class OpenFilesOp : public framework::OperatorBase {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册