diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index ba4a9add89fd66a4d2c64ec9030787956e2ddf1a..c6cc29d9ca1c83542168be055b529cc1547f0704 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -906,6 +906,8 @@ class DeviceContextPool { return *pool; } + static bool IsInitialized() { return pool != nullptr; } + static void SetPool(DeviceContextPool* dev_pool) { pool = dev_pool; } /*! \brief Return handle of single device context. */ diff --git a/paddle/phi/api/lib/CMakeLists.txt b/paddle/phi/api/lib/CMakeLists.txt index d01d6273dd09d74a3e957514b04058319c63cb65..9ff21d48420e88c0522493c0830f3fb7c278b3e4 100644 --- a/paddle/phi/api/lib/CMakeLists.txt +++ b/paddle/phi/api/lib/CMakeLists.txt @@ -333,7 +333,7 @@ cc_library( cc_library( context_pool SRCS context_pool.cc - DEPS phi_context phi_enforce place) + DEPS phi_context phi_enforce place init) cc_library( kernel_dispatch diff --git a/paddle/phi/api/lib/context_pool.cc b/paddle/phi/api/lib/context_pool.cc index 07ac9822d3310e2c3976296168b2c4527e082274..f3b148fb7bc9ddb4b5b3a4b7b5ec6a464b254f9f 100644 --- a/paddle/phi/api/lib/context_pool.cc +++ b/paddle/phi/api/lib/context_pool.cc @@ -17,6 +17,8 @@ limitations under the License. */ #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/enforce.h" +#include "paddle/fluid/platform/init.h" + namespace paddle { namespace experimental { @@ -28,6 +30,9 @@ DeviceContextPool& DeviceContextPool::Instance() { const phi::DeviceContext* DeviceContextPool::Get(const Place& place) { auto it = context_map_.find(place); if (it == context_map_.end()) { + if (!paddle::platform::DeviceContextPool::IsInitialized()) { + paddle::framework::InitDevices(); + } // only when we need the specific DeviceContext, get and cache it auto* dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place); {