未验证 提交 aeae81a7 编写于 作者: L liutiexing 提交者: GitHub

Thread data registry (#40912)

* add align for WorkQueue

* add spinlock

* merge develop

* merge

* Add EventsWaiter

* Revert "Add EventsWaiter"

This reverts commit e206173aa9be7401b83a53581627bfaf557c8fb2.

* Update ThreadDataRegistry
Co-authored-by: Nliutiexing <liutiexing@google.com>
上级 9ffedcfd
...@@ -142,20 +142,24 @@ std::string EventsWaiter::WaitEvent() { ...@@ -142,20 +142,24 @@ std::string EventsWaiter::WaitEvent() {
} }
trigger_event_.store(kEmptyEventId, std::memory_order_relaxed); trigger_event_.store(kEmptyEventId, std::memory_order_relaxed);
waiting_.store(false, std::memory_order_relaxed);
std::string evt_name = std::string evt_name =
triggered == kEmptyEventId ? "NoEventNotifier" : GetEventName(triggered); triggered == kEmptyEventId ? "NoEventNotifier" : GetEventName(triggered);
VLOG(10) << "Consume event id:" << triggered << ", name:" << evt_name; VLOG(10) << "Consume event id:" << triggered << ", name:" << evt_name;
// lazy deletion // lazy deletion
{ {
triggered = trigger_event_;
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_); std::lock_guard<paddle::memory::SpinLock> guard(events_lock_);
if (deleted_events_.size() > 0) { if (deleted_events_.size() > 0) {
for (auto evt : deleted_events_) { for (auto evt : deleted_events_) {
if (evt == triggered) {
continue;
}
events_.erase(evt); events_.erase(evt);
} }
deleted_events_.clear(); deleted_events_.clear();
} }
} }
waiting_.store(false, std::memory_order_relaxed);
return evt_name; return evt_name;
} }
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <functional> #include <functional>
#include <memory>
#include <mutex> #include <mutex>
#include <thread> #include <thread>
#include <type_traits> #include <type_traits>
...@@ -23,13 +24,8 @@ ...@@ -23,13 +24,8 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
static uint64_t main_tid =
std::hash<std::thread::id>()(std::this_thread::get_id());
template <typename T> template <typename T>
class ThreadDataRegistry { class ThreadDataRegistry {
class ThreadDataHolder;
public: public:
// Singleton // Singleton
static ThreadDataRegistry& GetInstance() { static ThreadDataRegistry& GetInstance() {
...@@ -52,73 +48,92 @@ class ThreadDataRegistry { ...@@ -52,73 +48,92 @@ class ThreadDataRegistry {
template <typename Alias = T, typename = std::enable_if_t< template <typename Alias = T, typename = std::enable_if_t<
std::is_copy_constructible<Alias>::value>> std::is_copy_constructible<Alias>::value>>
std::unordered_map<uint64_t, T> GetAllThreadDataByValue() { std::unordered_map<uint64_t, T> GetAllThreadDataByValue() {
std::unordered_map<uint64_t, T> data_copy; return impl_->GetAllThreadDataByValue();
std::lock_guard<std::mutex> lock(lock_);
data_copy.reserve(tid_map_.size());
for (auto& kv : tid_map_) {
data_copy.emplace(kv.first, kv.second->GetData());
}
return data_copy;
} }
// Returns current snapshot of all threads. Make sure there is no thread // Returns current snapshot of all threads. Make sure there is no thread
// create/destroy when using it. // create/destroy when using it.
std::unordered_map<uint64_t, std::reference_wrapper<T>> std::unordered_map<uint64_t, std::reference_wrapper<T>>
GetAllThreadDataByRef() { GetAllThreadDataByRef() {
std::unordered_map<uint64_t, std::reference_wrapper<T>> data_ref; return impl_->GetAllThreadDataByRef();
std::lock_guard<std::mutex> lock(lock_);
data_ref.reserve(tid_map_.size());
for (auto& kv : tid_map_) {
data_ref.emplace(kv.first, std::ref(kv.second->GetData()));
}
return data_ref;
} }
void RegisterData(uint64_t tid, ThreadDataHolder* tls_obj) { private:
std::lock_guard<std::mutex> lock(lock_); // types
tid_map_[tid] = tls_obj; class ThreadDataHolder;
} class ThreadDataRegistryImpl {
public:
void RegisterData(uint64_t tid, ThreadDataHolder* tls_obj) {
std::lock_guard<std::mutex> lock(lock_);
tid_map_[tid] = tls_obj;
}
void UnregisterData(uint64_t tid) { void UnregisterData(uint64_t tid) {
if (tid == main_tid) { std::lock_guard<std::mutex> lock(lock_);
return; tid_map_.erase(tid);
} }
std::lock_guard<std::mutex> lock(lock_);
tid_map_.erase(tid);
}
private: template <typename Alias = T, typename = std::enable_if_t<
std::is_copy_constructible<Alias>::value>>
std::unordered_map<uint64_t, T> GetAllThreadDataByValue() {
std::unordered_map<uint64_t, T> data_copy;
std::lock_guard<std::mutex> lock(lock_);
data_copy.reserve(tid_map_.size());
for (auto& kv : tid_map_) {
data_copy.emplace(kv.first, kv.second->GetData());
}
return data_copy;
}
std::unordered_map<uint64_t, std::reference_wrapper<T>>
GetAllThreadDataByRef() {
std::unordered_map<uint64_t, std::reference_wrapper<T>> data_ref;
std::lock_guard<std::mutex> lock(lock_);
data_ref.reserve(tid_map_.size());
for (auto& kv : tid_map_) {
data_ref.emplace(kv.first, std::ref(kv.second->GetData()));
}
return data_ref;
}
private:
std::mutex lock_;
std::unordered_map<uint64_t, ThreadDataHolder*> tid_map_; // not owned
};
class ThreadDataHolder { class ThreadDataHolder {
public: public:
ThreadDataHolder() { explicit ThreadDataHolder(
std::shared_ptr<ThreadDataRegistryImpl> registry) {
registry_ = std::move(registry);
tid_ = std::hash<std::thread::id>()(std::this_thread::get_id()); tid_ = std::hash<std::thread::id>()(std::this_thread::get_id());
ThreadDataRegistry::GetInstance().RegisterData(tid_, this); registry_->RegisterData(tid_, this);
} }
~ThreadDataHolder() { ~ThreadDataHolder() { registry_->UnregisterData(tid_); }
ThreadDataRegistry::GetInstance().UnregisterData(tid_);
}
T& GetData() { return data_; } T& GetData() { return data_; }
private: private:
std::shared_ptr<ThreadDataRegistryImpl> registry_;
uint64_t tid_; uint64_t tid_;
T data_; T data_;
}; };
ThreadDataRegistry() = default; // methods
ThreadDataRegistry() { impl_ = std::make_shared<ThreadDataRegistryImpl>(); }
ThreadDataRegistry(const ThreadDataRegistry&) = delete; ThreadDataRegistry(const ThreadDataRegistry&) = delete;
ThreadDataRegistry& operator=(const ThreadDataRegistry&) = delete; ThreadDataRegistry& operator=(const ThreadDataRegistry&) = delete;
T& CurrentThreadData() { T& CurrentThreadData() {
static thread_local ThreadDataHolder thread_data; static thread_local ThreadDataHolder thread_data(impl_);
return thread_data.GetData(); return thread_data.GetData();
} }
std::mutex lock_; // data
std::unordered_map<uint64_t, ThreadDataHolder*> tid_map_; // not owned std::shared_ptr<ThreadDataRegistryImpl> impl_;
}; };
} // namespace framework } // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册