From 67f2c9f7555157bdfbb49954134e4597d666dcc8 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Tue, 22 Mar 2022 13:58:29 +0800 Subject: [PATCH] [Phi] Fix context pool sync init (#40787) * fix context pool sync init * add lock for insert --- paddle/phi/api/include/context_pool.h | 11 +++++--- paddle/phi/api/lib/context_pool.cc | 38 +++++++-------------------- 2 files changed, 17 insertions(+), 32 deletions(-) diff --git a/paddle/phi/api/include/context_pool.h b/paddle/phi/api/include/context_pool.h index 754833a2dda..a2983d9c2aa 100644 --- a/paddle/phi/api/include/context_pool.h +++ b/paddle/phi/api/include/context_pool.h @@ -14,6 +14,8 @@ limitations under the License. */ #pragma once +#include + #include "paddle/phi/common/place.h" #include "paddle/phi/core/macros.h" #include "paddle/utils/flat_hash_map.h" @@ -58,21 +60,22 @@ class DeviceContextPool { public: static DeviceContextPool& Instance(); - const phi::DeviceContext* Get(const Place& place) const; + const phi::DeviceContext* Get(const Place& place); phi::DeviceContext* GetMutable(const Place& place); template - const typename DefaultDeviceContextType::TYPE* Get( - const Place& place) const { + const typename DefaultDeviceContextType::TYPE* Get(const Place& place) { return reinterpret_cast::TYPE*>( Get(place)); } private: - DeviceContextPool(); + DeviceContextPool() = default; + paddle::flat_hash_map context_map_; + std::mutex mutex_; DISABLE_COPY_AND_ASSIGN(DeviceContextPool); }; diff --git a/paddle/phi/api/lib/context_pool.cc b/paddle/phi/api/lib/context_pool.cc index d1408a88d6f..07ac9822d33 100644 --- a/paddle/phi/api/lib/context_pool.cc +++ b/paddle/phi/api/lib/context_pool.cc @@ -25,12 +25,17 @@ DeviceContextPool& DeviceContextPool::Instance() { return g_device_context_pool; } -const phi::DeviceContext* DeviceContextPool::Get(const Place& place) const { +const phi::DeviceContext* DeviceContextPool::Get(const Place& place) { auto it = context_map_.find(place); - PADDLE_ENFORCE_NE( - it, - context_map_.end(), - phi::errors::NotFound("The DeviceContext of %s does not exists.", place)); + if (it == context_map_.end()) { + // only when we need the specific DeviceContext, get and cache it + auto* dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place); + { + std::lock_guard lock(mutex_); + context_map_[place] = dev_ctx; + } + return dev_ctx; + } return it->second; } @@ -38,28 +43,5 @@ phi::DeviceContext* DeviceContextPool::GetMutable(const Place& place) { return const_cast(Get(place)); } -DeviceContextPool::DeviceContextPool() { - // We need to make sure that the correct value exists - // whenever we get the DeviceContext from DeviceContextPool - const auto& device_contexts = - paddle::platform::DeviceContextPool::Instance().device_contexts(); - for (const auto& pair : device_contexts) { - // only get CPU and GPU DeviceContext now, add other DeviceContext type - // later if needed - if (platform::is_cpu_place(pair.first) -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - || - platform::is_gpu_place(pair.first)) { -#else - ) { -#endif - const phi::DeviceContext* dev_ctx = pair.second.get().get(); - VLOG(3) << "Init phi DeviceContextPool: insert {" << pair.first << ", " - << dev_ctx << "}"; - context_map_[pair.first] = dev_ctx; - } - } -} - } // namespace experimental } // namespace paddle -- GitLab