提交 49ab52d6 编写于 作者: F fengjiayi

Modify MultipleReader

1. Removes MultipleReader's multi-thread support, for we have got
ThreadedReader.
2. Rename MultipleReader to MultiFileReader
上级 03ff0e58
...@@ -124,7 +124,7 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -124,7 +124,7 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
can enable them by setting 'unsafe_mode' true. In this case, can enable them by setting 'unsafe_mode' true. In this case,
'HasNext()' returning true only guarantees the safety of 'HasNext()' returning true only guarantees the safety of
invoking 'ReadNext()' in the same thread. Each thread must invoking 'ReadNext()' in the same thread. Each thread must
invoke 'HasNext()' and 'ReadNext()' in pair. invoke 'HasNext()' and 'ReadNext()' in pairs.
)DOC"); )DOC");
} }
}; };
......
...@@ -19,27 +19,11 @@ namespace paddle { ...@@ -19,27 +19,11 @@ namespace paddle {
namespace operators { namespace operators {
namespace reader { namespace reader {
class MultipleReader : public framework::ReaderBase { class MultiFileReader : public framework::ReaderBase {
public: public:
class ThreadBufferMap { MultiFileReader(const std::vector<std::string>& file_names,
public: const std::vector<framework::DDim>& dims, size_t thread_num,
std::vector<framework::LoDTensor>& operator[]( size_t buffer_size)
const std::thread::id& thread_id) {
std::lock_guard<std::mutex> lock(mutex_);
return buffer_[thread_id];
}
void Clear() { buffer_.clear(); }
private:
std::mutex mutex_;
std::unordered_map<std::thread::id, std::vector<framework::LoDTensor>>
buffer_;
};
MultipleReader(const std::vector<std::string>& file_names,
const std::vector<framework::DDim>& dims, size_t thread_num,
size_t buffer_size)
: file_names_(file_names), dims_(dims), buffer_size_(buffer_size) { : file_names_(file_names), dims_(dims), buffer_size_(buffer_size) {
prefetchers_.resize(thread_num); prefetchers_.resize(thread_num);
StartNewScheduler(); StartNewScheduler();
...@@ -49,7 +33,7 @@ class MultipleReader : public framework::ReaderBase { ...@@ -49,7 +33,7 @@ class MultipleReader : public framework::ReaderBase {
bool HasNext() const override; bool HasNext() const override;
void ReInit() override; void ReInit() override;
~MultipleReader() { EndScheduler(); } ~MultiFileReader() { EndScheduler(); }
private: private:
void StartNewScheduler(); void StartNewScheduler();
...@@ -65,31 +49,27 @@ class MultipleReader : public framework::ReaderBase { ...@@ -65,31 +49,27 @@ class MultipleReader : public framework::ReaderBase {
framework::Channel<size_t>* waiting_file_idx_; framework::Channel<size_t>* waiting_file_idx_;
framework::Channel<size_t>* available_thread_idx_; framework::Channel<size_t>* available_thread_idx_;
framework::Channel<std::vector<framework::LoDTensor>>* buffer_; framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
mutable ThreadBufferMap thread_buffer_map_;
}; };
void MultipleReader::ReadNext(std::vector<framework::LoDTensor>* out) { void MultiFileReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) { if (!HasNext()) {
PADDLE_THROW("There is no next data!"); PADDLE_THROW("There is no next data!");
} }
auto& thread_local_buffer = thread_buffer_map_[std::this_thread::get_id()]; buffer_->Receive(out);
*out = thread_local_buffer;
thread_local_buffer.clear();
} }
bool MultipleReader::HasNext() const { bool MultiFileReader::HasNext() const {
auto& thread_local_buffer = thread_buffer_map_[std::this_thread::get_id()]; while (!buffer_->IsClosed() && !buffer_->CanReceive()) {
return thread_local_buffer.empty() ? buffer_->Receive(&thread_local_buffer) }
: true; return buffer_->CanReceive();
} }
void MultipleReader::ReInit() { void MultiFileReader::ReInit() {
EndScheduler(); EndScheduler();
thread_buffer_map_.Clear();
StartNewScheduler(); StartNewScheduler();
} }
void MultipleReader::StartNewScheduler() { void MultiFileReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size(); size_t thread_num = prefetchers_.size();
waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size()); waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
available_thread_idx_ = framework::MakeChannel<size_t>(thread_num); available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
...@@ -107,7 +87,7 @@ void MultipleReader::StartNewScheduler() { ...@@ -107,7 +87,7 @@ void MultipleReader::StartNewScheduler() {
scheduler_ = std::thread([this] { ScheduleThreadFunc(); }); scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
} }
void MultipleReader::EndScheduler() { void MultiFileReader::EndScheduler() {
available_thread_idx_->Close(); available_thread_idx_->Close();
buffer_->Close(); buffer_->Close();
waiting_file_idx_->Close(); waiting_file_idx_->Close();
...@@ -119,8 +99,8 @@ void MultipleReader::EndScheduler() { ...@@ -119,8 +99,8 @@ void MultipleReader::EndScheduler() {
delete waiting_file_idx_; delete waiting_file_idx_;
} }
void MultipleReader::ScheduleThreadFunc() { void MultiFileReader::ScheduleThreadFunc() {
VLOG(5) << "MultipleReader schedule thread starts."; VLOG(5) << "MultiFileReader schedule thread starts.";
size_t completed_thread_num = 0; size_t completed_thread_num = 0;
size_t thread_idx; size_t thread_idx;
while (available_thread_idx_->Receive(&thread_idx)) { while (available_thread_idx_->Receive(&thread_idx)) {
...@@ -152,11 +132,11 @@ void MultipleReader::ScheduleThreadFunc() { ...@@ -152,11 +132,11 @@ void MultipleReader::ScheduleThreadFunc() {
p.join(); p.join();
} }
} }
VLOG(5) << "MultipleReader schedule thread terminates."; VLOG(5) << "MultiFileReader schedule thread terminates.";
} }
void MultipleReader::PrefetchThreadFunc(std::string file_name, void MultiFileReader::PrefetchThreadFunc(std::string file_name,
size_t thread_idx) { size_t thread_idx) {
VLOG(5) << "The prefetch thread of file '" << file_name << "' starts."; VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
std::unique_ptr<framework::ReaderBase> reader = std::unique_ptr<framework::ReaderBase> reader =
CreateReaderByFileName(file_name, dims_); CreateReaderByFileName(file_name, dims_);
...@@ -203,9 +183,9 @@ class OpenFilesOp : public framework::OperatorBase { ...@@ -203,9 +183,9 @@ class OpenFilesOp : public framework::OperatorBase {
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset(new MultipleReader(file_names, out->Reset(new MultiFileReader(file_names,
RestoreShapes(shape_concat, ranks), RestoreShapes(shape_concat, ranks),
thread_num, buffer_size)); thread_num, buffer_size));
} }
}; };
...@@ -221,7 +201,7 @@ class OpenFilesOpMaker : public FileReaderMakerBase { ...@@ -221,7 +201,7 @@ class OpenFilesOpMaker : public FileReaderMakerBase {
AddComment(R"DOC( AddComment(R"DOC(
OpenFiles Operator OpenFiles Operator
An OpenFilesOp creates a MultipleReader, which is able to An OpenFilesOp creates a MultiFileReader, which is able to
read data multi-threaded from multiple files. read data multi-threaded from multiple files.
)DOC"); )DOC");
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册