diff --git a/paddle/phi/api/include/context_pool.h b/paddle/phi/api/include/context_pool.h index 754833a2ddab3601f61069a916aea05181425c8f..a2983d9c2aa656e072b7cef010e220201ae3857f 100644 --- a/paddle/phi/api/include/context_pool.h +++ b/paddle/phi/api/include/context_pool.h @@ -14,6 +14,8 @@ limitations under the License. */ #pragma once +#include + #include "paddle/phi/common/place.h" #include "paddle/phi/core/macros.h" #include "paddle/utils/flat_hash_map.h" @@ -58,21 +60,22 @@ class DeviceContextPool { public: static DeviceContextPool& Instance(); - const phi::DeviceContext* Get(const Place& place) const; + const phi::DeviceContext* Get(const Place& place); phi::DeviceContext* GetMutable(const Place& place); template - const typename DefaultDeviceContextType::TYPE* Get( - const Place& place) const { + const typename DefaultDeviceContextType::TYPE* Get(const Place& place) { return reinterpret_cast::TYPE*>( Get(place)); } private: - DeviceContextPool(); + DeviceContextPool() = default; + paddle::flat_hash_map context_map_; + std::mutex mutex_; DISABLE_COPY_AND_ASSIGN(DeviceContextPool); }; diff --git a/paddle/phi/api/lib/context_pool.cc b/paddle/phi/api/lib/context_pool.cc index d1408a88d6ff784039f9e45393d9aec9ff37df2a..07ac9822d3310e2c3976296168b2c4527e082274 100644 --- a/paddle/phi/api/lib/context_pool.cc +++ b/paddle/phi/api/lib/context_pool.cc @@ -25,12 +25,17 @@ DeviceContextPool& DeviceContextPool::Instance() { return g_device_context_pool; } -const phi::DeviceContext* DeviceContextPool::Get(const Place& place) const { +const phi::DeviceContext* DeviceContextPool::Get(const Place& place) { auto it = context_map_.find(place); - PADDLE_ENFORCE_NE( - it, - context_map_.end(), - phi::errors::NotFound("The DeviceContext of %s does not exists.", place)); + if (it == context_map_.end()) { + // only when we need the specific DeviceContext, get and cache it + auto* dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place); + { + std::lock_guard lock(mutex_); + context_map_[place] = dev_ctx; + } + return dev_ctx; + } return it->second; } @@ -38,28 +43,5 @@ phi::DeviceContext* DeviceContextPool::GetMutable(const Place& place) { return const_cast(Get(place)); } -DeviceContextPool::DeviceContextPool() { - // We need to make sure that the correct value exists - // whenever we get the DeviceContext from DeviceContextPool - const auto& device_contexts = - paddle::platform::DeviceContextPool::Instance().device_contexts(); - for (const auto& pair : device_contexts) { - // only get CPU and GPU DeviceContext now, add other DeviceContext type - // later if needed - if (platform::is_cpu_place(pair.first) -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - || - platform::is_gpu_place(pair.first)) { -#else - ) { -#endif - const phi::DeviceContext* dev_ctx = pair.second.get().get(); - VLOG(3) << "Init phi DeviceContextPool: insert {" << pair.first << ", " - << dev_ctx << "}"; - context_map_[pair.first] = dev_ctx; - } - } -} - } // namespace experimental } // namespace paddle