未验证 提交 e8b9ae20 编写于 作者: S sneaxiy 提交者: GitHub

move some singleton to cc file (#46470)

上级 991ec7d3
...@@ -70,6 +70,11 @@ OpVersionDesc&& OpVersionDesc::DeleteOutput(const std::string& name, ...@@ -70,6 +70,11 @@ OpVersionDesc&& OpVersionDesc::DeleteOutput(const std::string& name,
return std::move(*this); return std::move(*this);
} }
OpVersionRegistrar& OpVersionRegistrar::GetInstance() {
static OpVersionRegistrar instance;
return instance;
}
OpVersion& OpVersionRegistrar::Register(const std::string& op_type) { OpVersion& OpVersionRegistrar::Register(const std::string& op_type) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
op_version_map_.find(op_type), op_version_map_.find(op_type),
...@@ -89,6 +94,11 @@ uint32_t OpVersionRegistrar::version_id(const std::string& op_type) const { ...@@ -89,6 +94,11 @@ uint32_t OpVersionRegistrar::version_id(const std::string& op_type) const {
return op_version_map_.find(op_type)->second.version_id(); return op_version_map_.find(op_type)->second.version_id();
} }
PassVersionCheckerRegistrar& PassVersionCheckerRegistrar::GetInstance() {
static PassVersionCheckerRegistrar instance;
return instance;
}
// Provide a fake registration item for pybind testing. // Provide a fake registration item for pybind testing.
#include "paddle/fluid/framework/op_version_registry.inl" #include "paddle/fluid/framework/op_version_registry.inl"
......
...@@ -230,10 +230,8 @@ class OpVersion { ...@@ -230,10 +230,8 @@ class OpVersion {
class OpVersionRegistrar { class OpVersionRegistrar {
public: public:
static OpVersionRegistrar& GetInstance() { static OpVersionRegistrar& GetInstance();
static OpVersionRegistrar instance;
return instance;
}
OpVersion& Register(const std::string& op_type); OpVersion& Register(const std::string& op_type);
const std::unordered_map<std::string, OpVersion>& GetVersionMap() { const std::unordered_map<std::string, OpVersion>& GetVersionMap() {
return op_version_map_; return op_version_map_;
...@@ -365,10 +363,8 @@ class PassVersionCheckers { ...@@ -365,10 +363,8 @@ class PassVersionCheckers {
class PassVersionCheckerRegistrar { class PassVersionCheckerRegistrar {
public: public:
static PassVersionCheckerRegistrar& GetInstance() { static PassVersionCheckerRegistrar& GetInstance();
static PassVersionCheckerRegistrar instance;
return instance;
}
PassVersionCheckers& Register(const std::string& pass_name) { PassVersionCheckers& Register(const std::string& pass_name) {
PADDLE_ENFORCE_EQ(pass_version_checkers_map_.find(pass_name), PADDLE_ENFORCE_EQ(pass_version_checkers_map_.find(pass_name),
pass_version_checkers_map_.end(), pass_version_checkers_map_.end(),
......
...@@ -73,7 +73,30 @@ DeviceType Place2DeviceType(const platform::Place& place) { ...@@ -73,7 +73,30 @@ DeviceType Place2DeviceType(const platform::Place& place) {
} }
} }
DeviceContextPool* DeviceContextPool::pool = nullptr; static DeviceContextPool* pool = nullptr;
DeviceContextPool& DeviceContextPool::Instance() {
PADDLE_ENFORCE_NOT_NULL(pool,
phi::errors::PreconditionNotMet(
"Need to Create DeviceContextPool firstly!"));
return *pool;
}
/*! \brief Create should only called by Init function */
DeviceContextPool& DeviceContextPool::Init(
const std::vector<platform::Place>& places) {
if (pool == nullptr) {
pool = new DeviceContextPool(places);
}
return *pool;
}
bool DeviceContextPool::IsInitialized() { return pool != nullptr; }
void DeviceContextPool::SetPool(DeviceContextPool* dev_pool) {
pool = dev_pool;
}
thread_local const std::map<Place, thread_local const std::map<Place,
std::shared_future<std::unique_ptr<DeviceContext>>>* std::shared_future<std::unique_ptr<DeviceContext>>>*
DeviceContextPool::external_device_contexts_ = nullptr; DeviceContextPool::external_device_contexts_ = nullptr;
......
...@@ -339,24 +339,14 @@ void EmplaceDeviceContexts( ...@@ -339,24 +339,14 @@ void EmplaceDeviceContexts(
/*! \brief device context pool singleton */ /*! \brief device context pool singleton */
class DeviceContextPool { class DeviceContextPool {
public: public:
static DeviceContextPool& Instance() { static DeviceContextPool& Instance();
PADDLE_ENFORCE_NOT_NULL(pool,
platform::errors::PreconditionNotMet(
"Need to Create DeviceContextPool firstly!"));
return *pool;
}
/*! \brief Create should only called by Init function */ /*! \brief Create should only called by Init function */
static DeviceContextPool& Init(const std::vector<platform::Place>& places) { static DeviceContextPool& Init(const std::vector<platform::Place>& places);
if (pool == nullptr) {
pool = new DeviceContextPool(places);
}
return *pool;
}
static bool IsInitialized() { return pool != nullptr; } static bool IsInitialized();
static void SetPool(DeviceContextPool* dev_pool) { pool = dev_pool; } static void SetPool(DeviceContextPool* dev_pool);
/*! \brief Return handle of single device context. */ /*! \brief Return handle of single device context. */
platform::DeviceContext* Get(const platform::Place& place); platform::DeviceContext* Get(const platform::Place& place);
...@@ -380,7 +370,6 @@ class DeviceContextPool { ...@@ -380,7 +370,6 @@ class DeviceContextPool {
private: private:
explicit DeviceContextPool(const std::vector<platform::Place>& places); explicit DeviceContextPool(const std::vector<platform::Place>& places);
static DeviceContextPool* pool;
std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>> std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>
device_contexts_; device_contexts_;
static thread_local const std:: static thread_local const std::
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册