diff --git a/paddle/fluid/framework/new_executor/workqueue/thread_data_registry.h b/paddle/fluid/framework/new_executor/workqueue/thread_data_registry.h index ffdddc39a31e39b0e5f16438ff52a6613fd7bdfe..98ed2c1ffc4b371408fcf1bb91df82d20dff91c3 100644 --- a/paddle/fluid/framework/new_executor/workqueue/thread_data_registry.h +++ b/paddle/fluid/framework/new_executor/workqueue/thread_data_registry.h @@ -60,18 +60,51 @@ class ThreadDataRegistry { } private: - // types +// types +// Lock types +#if defined(__clang__) || defined(__GNUC__) // CLANG or GCC +#ifndef __APPLE__ +#if __cplusplus >= 201703L + using LockType = std::shared_mutex; + using SharedLockGuardType = std::shared_lock; +#elif __cplusplus >= 201402L using LockType = std::shared_timed_mutex; + using SharedLockGuardType = std::shared_lock; +#else + using LockType = std::mutex; + using SharedLockGuardType = std::lock_guard; +#endif +// Special case : mac. https://github.com/facebook/react-native/issues/31250 +#else + using LockType = std::mutex; + using SharedLockGuardType = std::lock_guard; +#endif +#elif defined(_MSC_VER) // MSVC +#if _MSVC_LANG >= 201703L + using LockType = std::shared_mutex; + using SharedLockGuardType = std::shared_lock; +#elif _MSVC_LANG >= 201402L + using LockType = std::shared_timed_mutex; + using SharedLockGuardType = std::shared_lock; +#else + using LockType = std::mutex; + using SharedLockGuardType = std::lock_guard; +#endif +#else // other compilers + using LockType = std::mutex; + using SharedLockGuardType = std::lock_guard; +#endif + class ThreadDataHolder; class ThreadDataRegistryImpl { public: void RegisterData(uint64_t tid, ThreadDataHolder* tls_obj) { - std::lock_guard lock(lock_); + std::lock_guard guard(lock_); tid_map_[tid] = tls_obj; } void UnregisterData(uint64_t tid) { - std::lock_guard lock(lock_); + std::lock_guard guard(lock_); tid_map_.erase(tid); } @@ -79,7 +112,7 @@ class ThreadDataRegistry { std::is_copy_constructible::value>> std::unordered_map GetAllThreadDataByValue() { std::unordered_map data_copy; - std::shared_lock lock(lock_); + SharedLockGuardType guard(lock_); data_copy.reserve(tid_map_.size()); for (auto& kv : tid_map_) { data_copy.emplace(kv.first, kv.second->GetData()); @@ -90,7 +123,7 @@ class ThreadDataRegistry { std::unordered_map> GetAllThreadDataByRef() { std::unordered_map> data_ref; - std::shared_lock lock(lock_); + SharedLockGuardType guard(lock_); data_ref.reserve(tid_map_.size()); for (auto& kv : tid_map_) { data_ref.emplace(kv.first, std::ref(kv.second->GetData()));