未验证 提交 f2c0b886 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #9550 from JiayiFeng/make_MultipleReader_thread-safe

Make MultipleReader thread-safe
...@@ -21,6 +21,22 @@ namespace reader { ...@@ -21,6 +21,22 @@ namespace reader {
class MultipleReader : public framework::ReaderBase { class MultipleReader : public framework::ReaderBase {
public: public:
// Maps a thread id to that thread's own buffer of tensors.
//
// mutex_ serializes every access to the map structure itself (lookup,
// insertion and clearing). It does NOT protect the contents of a returned
// buffer: operator[] hands out a reference that is used after the lock has
// been released. References to std::unordered_map elements stay valid when
// other keys are inserted, but Clear() destroys all elements, so a reference
// obtained before Clear() must not be used afterwards.
class ThreadBufferMap {
 public:
  // Returns the buffer registered for `thread_id`, creating an empty one the
  // first time that id is seen.
  std::vector<framework::LoDTensor>& operator[](
      const std::thread::id& thread_id) {
    std::lock_guard<std::mutex> lock(mutex_);
    return buffer_[thread_id];
  }

  // Drops the buffers of all threads.
  // Takes the same lock as operator[] so that clearing cannot race with a
  // concurrent lookup/insert on the underlying map.
  void Clear() {
    std::lock_guard<std::mutex> lock(mutex_);
    buffer_.clear();
  }

 private:
  std::mutex mutex_;  // guards buffer_ (the map, not the vectors inside it)
  std::unordered_map<std::thread::id, std::vector<framework::LoDTensor>>
      buffer_;
};
MultipleReader(const std::vector<std::string>& file_names, MultipleReader(const std::vector<std::string>& file_names,
const std::vector<framework::DDim>& dims, size_t thread_num) const std::vector<framework::DDim>& dims, size_t thread_num)
: file_names_(file_names), dims_(dims) { : file_names_(file_names), dims_(dims) {
...@@ -47,28 +63,27 @@ class MultipleReader : public framework::ReaderBase { ...@@ -47,28 +63,27 @@ class MultipleReader : public framework::ReaderBase {
framework::Channel<size_t>* waiting_file_idx_; framework::Channel<size_t>* waiting_file_idx_;
framework::Channel<size_t>* available_thread_idx_; framework::Channel<size_t>* available_thread_idx_;
framework::Channel<std::vector<framework::LoDTensor>>* buffer_; framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
mutable std::vector<framework::LoDTensor> local_buffer_; mutable ThreadBufferMap thread_buffer_map_;
}; };
// Hands the calling thread's buffered batch to `out`.
//
// HasNext() is called first; it fills this thread's buffer from the channel
// when the buffer is empty. If it reports that nothing is available, the
// error "There is no next data!" is raised through PADDLE_THROW.
//
// out: receives the buffered tensors; its previous contents are discarded.
void MultipleReader::ReadNext(std::vector<framework::LoDTensor>* out) {
  if (!HasNext()) {
    PADDLE_THROW("There is no next data!");
  }
  auto& thread_local_buffer = thread_buffer_map_[std::this_thread::get_id()];
  // Transfer the contents instead of copying them: the buffer is emptied
  // right afterwards anyway, so a copy would only be thrown away.
  out->swap(thread_local_buffer);
  // After the swap the buffer holds whatever `out` held before; drop it so
  // the next HasNext() sees an empty buffer and pulls a fresh batch.
  thread_local_buffer.clear();
}
bool MultipleReader::HasNext() const { bool MultipleReader::HasNext() const {
return local_buffer_.empty() ? buffer_->Receive(&local_buffer_) : true; auto& thread_local_buffer = thread_buffer_map_[std::this_thread::get_id()];
return thread_local_buffer.empty() ? buffer_->Receive(&thread_local_buffer)
: true;
} }
void MultipleReader::ReInit() { void MultipleReader::ReInit() {
EndScheduler(); EndScheduler();
local_buffer_.clear(); thread_buffer_map_.Clear();
StartNewScheduler(); StartNewScheduler();
} }
...@@ -176,7 +191,7 @@ class OpenFilesOp : public framework::OperatorBase { ...@@ -176,7 +191,7 @@ class OpenFilesOp : public framework::OperatorBase {
const auto& ranks = Attr<std::vector<int>>("ranks"); const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty()); PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0), PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
int(shape_concat.size()), static_cast<int>(shape_concat.size()),
"The accumulate of all ranks should be equal to the " "The accumulate of all ranks should be equal to the "
"shape concat's length."); "shape concat's length.");
const auto& file_names = Attr<std::vector<std::string>>("file_names"); const auto& file_names = Attr<std::vector<std::string>>("file_names");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册