diff --git a/paddle/fluid/platform/device/custom/custom_device_resource_pool.cc b/paddle/fluid/platform/device/custom/custom_device_resource_pool.cc index 5d0fb95d8ea44272bc9c3275237256280e6a7876..c7afa71671563b48d4a84f4c34b4aa26fcc9a871 100644 --- a/paddle/fluid/platform/device/custom/custom_device_resource_pool.cc +++ b/paddle/fluid/platform/device/custom/custom_device_resource_pool.cc @@ -30,39 +30,47 @@ CustomDeviceStreamResourcePool::CustomDeviceStreamResourcePool( int dev_cnt = phi::DeviceManager::GetDeviceCount(place.GetDeviceType()); pool_.reserve(dev_cnt); for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) { - auto creator = [place, dev_idx] { + auto creator = [place, dev_idx, this] { auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx); phi::DeviceManager::SetDevice(place_); - phi::stream::Stream* stream = new phi::stream::Stream(place_, nullptr); - phi::DeviceManager::GetDeviceWithPlace(place_)->CreateStream(stream); + phi::stream::Stream* stream = new phi::stream::Stream; + stream->Init(place_); + this->streams_.push_back(stream); return stream; }; - auto deleter = [place, dev_idx](phi::stream::Stream* stream) { - auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx); - phi::DeviceManager::SetDevice(place_); - - phi::DeviceManager::GetDeviceWithPlace(place_)->DestroyStream(stream); - delete stream; - }; - - pool_.emplace_back( - ResourcePool::Create(creator, deleter)); + pool_.emplace_back(ResourcePool::Create( + creator, [](phi::stream::Stream* stream) {})); } } -std::unordered_map< - std::string, - std::vector>>& +std::unordered_map>& CustomDeviceStreamResourcePool::GetMap() { - static std::unordered_map< - std::string, - std::vector>> + static std::unordered_map> pool; return pool; } +CustomDeviceStreamResourcePool::~CustomDeviceStreamResourcePool() { + for (auto* p : streams_) { + delete p; + } + pool_.clear(); +} + +void CustomDeviceStreamResourcePool::Release() { + auto& pool = GetMap(); + for (auto& item : pool) { + for (auto& p : item.second) { + delete p; + } + item.second.clear(); + } + pool.clear(); +} + CustomDeviceStreamResourcePool& CustomDeviceStreamResourcePool::Instance( const paddle::Place& place) { auto& pool = GetMap(); @@ -72,9 +80,8 @@ CustomDeviceStreamResourcePool& CustomDeviceStreamResourcePool::Instance( platform::errors::PreconditionNotMet( "Required device shall be CustomPlace, but received %d. ", place)); if (pool.find(place.GetDeviceType()) == pool.end()) { - pool.insert( - {place.GetDeviceType(), - std::vector>()}); + pool.insert({place.GetDeviceType(), + std::vector()}); for (size_t i = 0; i < phi::DeviceManager::GetDeviceCount(place.GetDeviceType()); ++i) { @@ -121,43 +128,50 @@ CustomDeviceEventResourcePool::CustomDeviceEventResourcePool( int dev_cnt = phi::DeviceManager::GetDeviceCount(place.GetDeviceType()); pool_.reserve(dev_cnt); for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) { - auto creator = [place, dev_idx] { + auto creator = [place, dev_idx, this] { auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx); phi::DeviceManager::SetDevice(place_); - phi::event::Event* event = new phi::event::Event(place_, nullptr); - phi::DeviceManager::GetDeviceWithPlace(place_)->CreateEvent(event); + phi::event::Event* event = new phi::event::Event; + event->Init(place_); + this->events_.push_back(event); return event; }; - auto deleter = [place, dev_idx](phi::event::Event* event) { - auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx); - phi::DeviceManager::SetDevice(place_); - - phi::DeviceManager::GetDeviceWithPlace(place_)->DestroyEvent(event); - }; - - pool_.emplace_back( - ResourcePool::Create(creator, deleter)); + pool_.emplace_back(ResourcePool::Create( + creator, [](phi::event::Event* event) {})); } } -std::unordered_map>>& +std::unordered_map>& CustomDeviceEventResourcePool::GetMap() { - static std::unordered_map< - std::string, - std::vector>> + static std::unordered_map> pool; return pool; } +CustomDeviceEventResourcePool::~CustomDeviceEventResourcePool() { + for (auto* p : events_) { + delete p; + } + pool_.clear(); +} + +void CustomDeviceEventResourcePool::Release() { + auto& pool = GetMap(); + for (auto& item : pool) { + for (auto& p : item.second) { + delete p; + } + item.second.clear(); + } + pool.clear(); +} + CustomDeviceEventResourcePool& CustomDeviceEventResourcePool::Instance( const phi::Place& place) { - static std::unordered_map< - std::string, - std::vector>> - pool; + auto& pool = GetMap(); PADDLE_ENFORCE_EQ( platform::is_custom_place(place), true, @@ -165,8 +179,7 @@ CustomDeviceEventResourcePool& CustomDeviceEventResourcePool::Instance( "Required device shall be CustomPlace, but received %d. ", place)); if (pool.find(place.GetDeviceType()) == pool.end()) { pool.insert( - {place.GetDeviceType(), - std::vector>()}); + {place.GetDeviceType(), std::vector()}); for (size_t i = 0; i < phi::DeviceManager::GetDeviceCount(place.GetDeviceType()); ++i) { diff --git a/paddle/fluid/platform/device/custom/custom_device_resource_pool.h b/paddle/fluid/platform/device/custom/custom_device_resource_pool.h index 749e696215a06f24bbc06c24c7de26c446c96a1e..629877e3a4be828dc14afd8112b8fad4ac01495f 100644 --- a/paddle/fluid/platform/device/custom/custom_device_resource_pool.h +++ b/paddle/fluid/platform/device/custom/custom_device_resource_pool.h @@ -31,15 +31,18 @@ using CustomDeviceEventObject = phi::event::Event; class CustomDeviceStreamResourcePool { public: - static std::unordered_map< - std::string, - std::vector>>& + static std::unordered_map>& GetMap(); + static void Release(); + std::shared_ptr New(int dev_idx); static CustomDeviceStreamResourcePool& Instance(const paddle::Place& place); + ~CustomDeviceStreamResourcePool(); + private: explicit CustomDeviceStreamResourcePool(const paddle::Place& place); @@ -47,19 +50,23 @@ class CustomDeviceStreamResourcePool { private: std::vector>> pool_; + std::vector streams_; }; class CustomDeviceEventResourcePool { public: std::shared_ptr New(int dev_idx); - static std::unordered_map< - std::string, - std::vector>>& + static std::unordered_map>& GetMap(); + static void Release(); + static CustomDeviceEventResourcePool& Instance(const paddle::Place& place); + ~CustomDeviceEventResourcePool(); + private: explicit CustomDeviceEventResourcePool(const paddle::Place& place); @@ -67,6 +74,7 @@ class CustomDeviceEventResourcePool { private: std::vector>> pool_; + std::vector events_; }; } // namespace platform diff --git a/paddle/fluid/platform/profiler/custom_device/custom_tracer.cc b/paddle/fluid/platform/profiler/custom_device/custom_tracer.cc index 70c0ed02a7cf5a0c8c852f32fc8b693a59b0df16..740f27f2922a3b3978f1de1560f943bda8b9ae4d 100644 --- a/paddle/fluid/platform/profiler/custom_device/custom_tracer.cc +++ b/paddle/fluid/platform/profiler/custom_device/custom_tracer.cc @@ -38,6 +38,21 @@ CustomTracer::~CustomTracer() { #endif } +std::unordered_map>& +CustomTracer::GetMap() { + static std::unordered_map> + instance; + return instance; +} + +void CustomTracer::Release() { + auto& pool = GetMap(); + for (auto& item : pool) { + item.second.reset(); + } + pool.clear(); +} + void CustomTracer::PrepareTracing() { PADDLE_ENFORCE_EQ( state_ == TracerState::UNINITED || state_ == TracerState::STOPED, diff --git a/paddle/fluid/platform/profiler/custom_device/custom_tracer.h b/paddle/fluid/platform/profiler/custom_device/custom_tracer.h index 5d5874eabe50ea5ae45e6374decda6afe5153f70..2c3df98a1bd5b63bda873d173dea50ada880e82c 100644 --- a/paddle/fluid/platform/profiler/custom_device/custom_tracer.h +++ b/paddle/fluid/platform/profiler/custom_device/custom_tracer.h @@ -26,18 +26,16 @@ namespace platform { class CustomTracer : public TracerBase { public: - static std::unordered_map>& - GetMap() { - static std::unordered_map> - instance; - return instance; - } + static std::unordered_map>& + GetMap(); + + static void Release(); static CustomTracer& GetInstance(const std::string& device_type) { auto& instance = GetMap(); if (instance.find(device_type) == instance.cend()) { instance.insert( - {device_type, std::make_shared(device_type)}); + {device_type, std::make_unique(device_type)}); } return *instance[device_type]; } diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 60ade1f9875fd6990c0800d5ed9624036d2e8f38..2f66391a9bca8bb65962dde338b08bbda2ebce14 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1009,9 +1009,9 @@ PYBIND11_MODULE(libpaddle, m) { m.def("clear_device_manager", []() { #ifdef PADDLE_WITH_CUSTOM_DEVICE platform::XCCLCommContext::Release(); - platform::CustomTracer::GetMap().clear(); - platform::CustomDeviceEventResourcePool::GetMap().clear(); - platform::CustomDeviceStreamResourcePool::GetMap().clear(); + platform::CustomTracer::Release(); + platform::CustomDeviceEventResourcePool::Release(); + platform::CustomDeviceStreamResourcePool::Release(); phi::DeviceManager::Clear(); #endif }); diff --git a/paddle/phi/backends/device_manager.cc b/paddle/phi/backends/device_manager.cc index 902395a162045bd061a535ceddd4d3d93712a77e..c02ec9912417d40970c165259be9bf9248e469ca 100644 --- a/paddle/phi/backends/device_manager.cc +++ b/paddle/phi/backends/device_manager.cc @@ -671,8 +671,8 @@ DeviceManager& DeviceManager::Instance() { } void DeviceManager::Clear() { - // Instance().device_map_.clear(); - // Instance().device_impl_map_.clear(); + Instance().device_map_.clear(); + Instance().device_impl_map_.clear(); } std::vector ListAllLibraries(const std::string& library_dir) {