diff --git a/paddle/fluid/framework/new_executor/workqueue/events_waiter.cc b/paddle/fluid/framework/new_executor/workqueue/events_waiter.cc index 163050ae5a65a7758fdcbe00f5e880acf4262f9d..346e20d811e84f70726b36b06b61c3d55b11a6ec 100644 --- a/paddle/fluid/framework/new_executor/workqueue/events_waiter.cc +++ b/paddle/fluid/framework/new_executor/workqueue/events_waiter.cc @@ -142,20 +142,24 @@ std::string EventsWaiter::WaitEvent() { } trigger_event_.store(kEmptyEventId, std::memory_order_relaxed); - waiting_.store(false, std::memory_order_relaxed); std::string evt_name = triggered == kEmptyEventId ? "NoEventNotifier" : GetEventName(triggered); VLOG(10) << "Consume event id:" << triggered << ", name:" << evt_name; // lazy deletion { + triggered = trigger_event_; std::lock_guard guard(events_lock_); if (deleted_events_.size() > 0) { for (auto evt : deleted_events_) { + if (evt == triggered) { + continue; + } events_.erase(evt); } deleted_events_.clear(); } } + waiting_.store(false, std::memory_order_relaxed); return evt_name; } diff --git a/paddle/fluid/framework/new_executor/workqueue/thread_data_registry.h b/paddle/fluid/framework/new_executor/workqueue/thread_data_registry.h index 21b2927b52eab653e20611e135a8c0f905057fcf..b46bf7a8d34821c43c4d455790c97486351147de 100644 --- a/paddle/fluid/framework/new_executor/workqueue/thread_data_registry.h +++ b/paddle/fluid/framework/new_executor/workqueue/thread_data_registry.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include @@ -23,13 +24,8 @@ namespace paddle { namespace framework { -static uint64_t main_tid = - std::hash()(std::this_thread::get_id()); - template class ThreadDataRegistry { - class ThreadDataHolder; - public: // Singleton static ThreadDataRegistry& GetInstance() { @@ -52,73 +48,92 @@ class ThreadDataRegistry { template ::value>> std::unordered_map GetAllThreadDataByValue() { - std::unordered_map data_copy; - std::lock_guard lock(lock_); - data_copy.reserve(tid_map_.size()); - for (auto& kv : tid_map_) { - data_copy.emplace(kv.first, kv.second->GetData()); - } - return data_copy; + return impl_->GetAllThreadDataByValue(); } // Returns current snapshot of all threads. Make sure there is no thread // create/destory when using it. std::unordered_map> GetAllThreadDataByRef() { - std::unordered_map> data_ref; - std::lock_guard lock(lock_); - data_ref.reserve(tid_map_.size()); - for (auto& kv : tid_map_) { - data_ref.emplace(kv.first, std::ref(kv.second->GetData())); - } - return data_ref; + return impl_->GetAllThreadDataByRef(); } - void RegisterData(uint64_t tid, ThreadDataHolder* tls_obj) { - std::lock_guard lock(lock_); - tid_map_[tid] = tls_obj; - } + private: + // types + class ThreadDataHolder; + class ThreadDataRegistryImpl { + public: + void RegisterData(uint64_t tid, ThreadDataHolder* tls_obj) { + std::lock_guard lock(lock_); + tid_map_[tid] = tls_obj; + } - void UnregisterData(uint64_t tid) { - if (tid == main_tid) { - return; + void UnregisterData(uint64_t tid) { + std::lock_guard lock(lock_); + tid_map_.erase(tid); } - std::lock_guard lock(lock_); - tid_map_.erase(tid); - } - private: + template ::value>> + std::unordered_map GetAllThreadDataByValue() { + std::unordered_map data_copy; + std::lock_guard lock(lock_); + data_copy.reserve(tid_map_.size()); + for (auto& kv : tid_map_) { + data_copy.emplace(kv.first, kv.second->GetData()); + } + return data_copy; + } + + std::unordered_map> + GetAllThreadDataByRef() { + std::unordered_map> data_ref; + std::lock_guard lock(lock_); + data_ref.reserve(tid_map_.size()); + for (auto& kv : tid_map_) { + data_ref.emplace(kv.first, std::ref(kv.second->GetData())); + } + return data_ref; + } + + private: + std::mutex lock_; + std::unordered_map tid_map_; // not owned + }; + class ThreadDataHolder { public: - ThreadDataHolder() { + explicit ThreadDataHolder( + std::shared_ptr registry) { + registry_ = std::move(registry); tid_ = std::hash()(std::this_thread::get_id()); - ThreadDataRegistry::GetInstance().RegisterData(tid_, this); + registry_->RegisterData(tid_, this); } - ~ThreadDataHolder() { - ThreadDataRegistry::GetInstance().UnregisterData(tid_); - } + ~ThreadDataHolder() { registry_->UnregisterData(tid_); } T& GetData() { return data_; } private: + std::shared_ptr registry_; uint64_t tid_; T data_; }; - ThreadDataRegistry() = default; + // methods + ThreadDataRegistry() { impl_ = std::make_shared(); } ThreadDataRegistry(const ThreadDataRegistry&) = delete; ThreadDataRegistry& operator=(const ThreadDataRegistry&) = delete; T& CurrentThreadData() { - static thread_local ThreadDataHolder thread_data; + static thread_local ThreadDataHolder thread_data(impl_); return thread_data.GetData(); } - std::mutex lock_; - std::unordered_map tid_map_; // not owned + // data + std::shared_ptr impl_; }; } // namespace framework