From e8b9ae20629a25e9019c24db92babc4e6fdc6bfc Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Sun, 25 Sep 2022 12:47:51 +0800 Subject: [PATCH] move some singleton to cc file (#46470) --- paddle/fluid/framework/op_version_registry.cc | 10 ++++++++ paddle/fluid/framework/op_version_registry.h | 12 +++------ paddle/fluid/platform/device_context.cc | 25 ++++++++++++++++++- paddle/fluid/platform/device_context.h | 19 +++----------- 4 files changed, 42 insertions(+), 24 deletions(-) diff --git a/paddle/fluid/framework/op_version_registry.cc b/paddle/fluid/framework/op_version_registry.cc index a32613812d8..198394619e0 100644 --- a/paddle/fluid/framework/op_version_registry.cc +++ b/paddle/fluid/framework/op_version_registry.cc @@ -70,6 +70,11 @@ OpVersionDesc&& OpVersionDesc::DeleteOutput(const std::string& name, return std::move(*this); } +OpVersionRegistrar& OpVersionRegistrar::GetInstance() { + static OpVersionRegistrar instance; + return instance; +} + OpVersion& OpVersionRegistrar::Register(const std::string& op_type) { PADDLE_ENFORCE_EQ( op_version_map_.find(op_type), @@ -89,6 +94,11 @@ uint32_t OpVersionRegistrar::version_id(const std::string& op_type) const { return op_version_map_.find(op_type)->second.version_id(); } +PassVersionCheckerRegistrar& PassVersionCheckerRegistrar::GetInstance() { + static PassVersionCheckerRegistrar instance; + return instance; +} + // Provide a fake registration item for pybind testing. #include "paddle/fluid/framework/op_version_registry.inl" diff --git a/paddle/fluid/framework/op_version_registry.h b/paddle/fluid/framework/op_version_registry.h index 579dd320d14..661b206be32 100644 --- a/paddle/fluid/framework/op_version_registry.h +++ b/paddle/fluid/framework/op_version_registry.h @@ -230,10 +230,8 @@ class OpVersion { class OpVersionRegistrar { public: - static OpVersionRegistrar& GetInstance() { - static OpVersionRegistrar instance; - return instance; - } + static OpVersionRegistrar& GetInstance(); + OpVersion& Register(const std::string& op_type); const std::unordered_map& GetVersionMap() { return op_version_map_; @@ -365,10 +363,8 @@ class PassVersionCheckers { class PassVersionCheckerRegistrar { public: - static PassVersionCheckerRegistrar& GetInstance() { - static PassVersionCheckerRegistrar instance; - return instance; - } + static PassVersionCheckerRegistrar& GetInstance(); + PassVersionCheckers& Register(const std::string& pass_name) { PADDLE_ENFORCE_EQ(pass_version_checkers_map_.find(pass_name), pass_version_checkers_map_.end(), diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index f7c715d7905..c39705b618f 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -73,7 +73,30 @@ DeviceType Place2DeviceType(const platform::Place& place) { } } -DeviceContextPool* DeviceContextPool::pool = nullptr; +static DeviceContextPool* pool = nullptr; + +DeviceContextPool& DeviceContextPool::Instance() { + PADDLE_ENFORCE_NOT_NULL(pool, + phi::errors::PreconditionNotMet( + "Need to Create DeviceContextPool firstly!")); + return *pool; +} + +/*! \brief Create should only called by Init function */ +DeviceContextPool& DeviceContextPool::Init( + const std::vector& places) { + if (pool == nullptr) { + pool = new DeviceContextPool(places); + } + return *pool; +} + +bool DeviceContextPool::IsInitialized() { return pool != nullptr; } + +void DeviceContextPool::SetPool(DeviceContextPool* dev_pool) { + pool = dev_pool; +} + thread_local const std::map>>* DeviceContextPool::external_device_contexts_ = nullptr; diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index d8ebb019fc6..f0119d1f839 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -339,24 +339,14 @@ void EmplaceDeviceContexts( /*! \brief device context pool singleton */ class DeviceContextPool { public: - static DeviceContextPool& Instance() { - PADDLE_ENFORCE_NOT_NULL(pool, - platform::errors::PreconditionNotMet( - "Need to Create DeviceContextPool firstly!")); - return *pool; - } + static DeviceContextPool& Instance(); /*! \brief Create should only called by Init function */ - static DeviceContextPool& Init(const std::vector& places) { - if (pool == nullptr) { - pool = new DeviceContextPool(places); - } - return *pool; - } + static DeviceContextPool& Init(const std::vector& places); - static bool IsInitialized() { return pool != nullptr; } + static bool IsInitialized(); - static void SetPool(DeviceContextPool* dev_pool) { pool = dev_pool; } + static void SetPool(DeviceContextPool* dev_pool); /*! \brief Return handle of single device context. */ platform::DeviceContext* Get(const platform::Place& place); @@ -380,7 +370,6 @@ class DeviceContextPool { private: explicit DeviceContextPool(const std::vector& places); - static DeviceContextPool* pool; std::map>> device_contexts_; static thread_local const std:: -- GitLab