未验证 提交 aeae81a7 编写于 作者: L liutiexing 提交者: GitHub

Thread data registry (#40912)

* add align for WorkQueue

* add spinlock

* merge develop

* merge

* Add EventsWaiter

* Revert "Add EventsWaiter"

This reverts commit e206173aa9be7401b83a53581627bfaf557c8fb2.

* Update ThreadDataRegistry
Co-authored-by: Nliutiexing <liutiexing@google.com>
上级 9ffedcfd
......@@ -142,20 +142,24 @@ std::string EventsWaiter::WaitEvent() {
}
trigger_event_.store(kEmptyEventId, std::memory_order_relaxed);
waiting_.store(false, std::memory_order_relaxed);
std::string evt_name =
triggered == kEmptyEventId ? "NoEventNotifier" : GetEventName(triggered);
VLOG(10) << "Consume event id:" << triggered << ", name:" << evt_name;
// lazy deletion
{
triggered = trigger_event_;
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
if (deleted_events_.size() > 0) {
for (auto evt : deleted_events_) {
if (evt == triggered) {
continue;
}
events_.erase(evt);
}
deleted_events_.clear();
}
}
waiting_.store(false, std::memory_order_relaxed);
return evt_name;
}
......
......@@ -15,6 +15,7 @@
#pragma once
#include <functional>
#include <memory>
#include <mutex>
#include <thread>
#include <type_traits>
......@@ -23,13 +24,8 @@
namespace paddle {
namespace framework {
static uint64_t main_tid =
std::hash<std::thread::id>()(std::this_thread::get_id());
template <typename T>
class ThreadDataRegistry {
class ThreadDataHolder;
public:
// Singleton
static ThreadDataRegistry& GetInstance() {
......@@ -52,73 +48,92 @@ class ThreadDataRegistry {
template <typename Alias = T, typename = std::enable_if_t<
std::is_copy_constructible<Alias>::value>>
std::unordered_map<uint64_t, T> GetAllThreadDataByValue() {
std::unordered_map<uint64_t, T> data_copy;
std::lock_guard<std::mutex> lock(lock_);
data_copy.reserve(tid_map_.size());
for (auto& kv : tid_map_) {
data_copy.emplace(kv.first, kv.second->GetData());
}
return data_copy;
return impl_->GetAllThreadDataByValue();
}
// Returns current snapshot of all threads. Make sure there is no thread
// create/destory when using it.
std::unordered_map<uint64_t, std::reference_wrapper<T>>
GetAllThreadDataByRef() {
std::unordered_map<uint64_t, std::reference_wrapper<T>> data_ref;
std::lock_guard<std::mutex> lock(lock_);
data_ref.reserve(tid_map_.size());
for (auto& kv : tid_map_) {
data_ref.emplace(kv.first, std::ref(kv.second->GetData()));
}
return data_ref;
return impl_->GetAllThreadDataByRef();
}
void RegisterData(uint64_t tid, ThreadDataHolder* tls_obj) {
std::lock_guard<std::mutex> lock(lock_);
tid_map_[tid] = tls_obj;
}
private:
// types
class ThreadDataHolder;
class ThreadDataRegistryImpl {
public:
void RegisterData(uint64_t tid, ThreadDataHolder* tls_obj) {
std::lock_guard<std::mutex> lock(lock_);
tid_map_[tid] = tls_obj;
}
void UnregisterData(uint64_t tid) {
if (tid == main_tid) {
return;
void UnregisterData(uint64_t tid) {
std::lock_guard<std::mutex> lock(lock_);
tid_map_.erase(tid);
}
std::lock_guard<std::mutex> lock(lock_);
tid_map_.erase(tid);
}
private:
template <typename Alias = T, typename = std::enable_if_t<
std::is_copy_constructible<Alias>::value>>
std::unordered_map<uint64_t, T> GetAllThreadDataByValue() {
std::unordered_map<uint64_t, T> data_copy;
std::lock_guard<std::mutex> lock(lock_);
data_copy.reserve(tid_map_.size());
for (auto& kv : tid_map_) {
data_copy.emplace(kv.first, kv.second->GetData());
}
return data_copy;
}
std::unordered_map<uint64_t, std::reference_wrapper<T>>
GetAllThreadDataByRef() {
std::unordered_map<uint64_t, std::reference_wrapper<T>> data_ref;
std::lock_guard<std::mutex> lock(lock_);
data_ref.reserve(tid_map_.size());
for (auto& kv : tid_map_) {
data_ref.emplace(kv.first, std::ref(kv.second->GetData()));
}
return data_ref;
}
private:
std::mutex lock_;
std::unordered_map<uint64_t, ThreadDataHolder*> tid_map_; // not owned
};
class ThreadDataHolder {
public:
ThreadDataHolder() {
explicit ThreadDataHolder(
std::shared_ptr<ThreadDataRegistryImpl> registry) {
registry_ = std::move(registry);
tid_ = std::hash<std::thread::id>()(std::this_thread::get_id());
ThreadDataRegistry::GetInstance().RegisterData(tid_, this);
registry_->RegisterData(tid_, this);
}
~ThreadDataHolder() {
ThreadDataRegistry::GetInstance().UnregisterData(tid_);
}
~ThreadDataHolder() { registry_->UnregisterData(tid_); }
T& GetData() { return data_; }
private:
std::shared_ptr<ThreadDataRegistryImpl> registry_;
uint64_t tid_;
T data_;
};
ThreadDataRegistry() = default;
// methods
ThreadDataRegistry() { impl_ = std::make_shared<ThreadDataRegistryImpl>(); }
ThreadDataRegistry(const ThreadDataRegistry&) = delete;
ThreadDataRegistry& operator=(const ThreadDataRegistry&) = delete;
T& CurrentThreadData() {
static thread_local ThreadDataHolder thread_data;
static thread_local ThreadDataHolder thread_data(impl_);
return thread_data.GetData();
}
std::mutex lock_;
std::unordered_map<uint64_t, ThreadDataHolder*> tid_map_; // not owned
// data
std::shared_ptr<ThreadDataRegistryImpl> impl_;
};
} // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册