未验证 提交 7a705727 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] fix release error in process_group_custom (#55293)

* [CustomDevice] fix release error for process_group_custom

* update
上级 766fcdf0
......@@ -148,7 +148,7 @@ class CustomCCLCommManager {
~CustomCCLCommManager() noexcept {
std::unique_lock<std::mutex> lock(mutex_);
if (ccl_comm_) {
if (phi::DeviceManager::HasDeviceType(device_type_) && ccl_comm_) {
phi::DeviceManager::CCLDestroyComm(device_type_, ccl_comm_);
}
}
......
......@@ -1021,7 +1021,7 @@ PYBIND11_MODULE(libpaddle, m) {
platform::CustomTracer::Release();
platform::CustomDeviceEventResourcePool::Release();
platform::CustomDeviceStreamResourcePool::Release();
phi::DeviceManager::Clear();
phi::DeviceManager::Release();
#endif
});
......
......@@ -76,7 +76,7 @@ CustomContext::CustomContext(const CustomPlace& place)
impl_->Init();
}
CustomContext::~CustomContext() { impl_->Init(); }
CustomContext::~CustomContext() { impl_.reset(); }
phi::ccl::CCLComm CustomContext::xccl_comm() const {
return impl_->xccl_comm();
......
......@@ -343,8 +343,9 @@ std::vector<std::string> DeviceManager::GetAllCustomDeviceList() {
}
bool DeviceManager::HasDeviceType(const std::string& device_type) {
auto dev_impl = GetDeviceInterfaceWithType(device_type);
return dev_impl != nullptr;
phi::AutoRDLock lock(&_global_device_manager_rw_lock);
auto& dev_impl_map = Instance().device_impl_map_;
return dev_impl_map.find(device_type) != dev_impl_map.end();
}
bool DeviceManager::IsCustom(const std::string& device_type) {
......@@ -670,7 +671,9 @@ DeviceManager& DeviceManager::Instance() {
return platform_manager;
}
void DeviceManager::Clear() {
void DeviceManager::Release() {
stream::Stream::ReleaseAll();
event::Event::ReleaseAll();
Instance().device_map_.clear();
Instance().device_impl_map_.clear();
}
......
......@@ -273,7 +273,7 @@ class DeviceManager {
uint64_t start_ns,
void* context);
static void Clear();
static void Release();
private:
DISABLE_COPY_AND_ASSIGN(DeviceManager);
......
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <list>
#include "paddle/phi/backends/event.h"
#include "glog/logging.h"
......@@ -22,6 +24,14 @@
namespace phi {
namespace event {
std::list<Event*> g_events;
void Event::ReleaseAll() {
for (auto* event : g_events) {
event->Destroy();
}
}
event_t Event::raw_event() const { return event_; }
void Event::set_event(event_t event) { event_ = event; }
......@@ -32,7 +42,10 @@ Event::Event(const Place& place, event_t event)
event_(event),
own_data_(false) {}
Event::~Event() { Destroy(); }
Event::~Event() {
g_events.remove(this);
Destroy();
}
bool Event::Init(const Place& place, Flag flags) {
place_ = place;
......@@ -45,14 +58,19 @@ bool Event::Init(const Place& place, Flag flags) {
VLOG(3) << "Init Event: " << event_ << ", place: " << place_
<< ", flag:" << static_cast<int>(flags);
own_data_ = true;
g_events.push_back(this);
return true;
}
void Event::Destroy() {
if (device_) {
if (own_data_) {
phi::DeviceManager::SetDevice(place_);
device_->DestroyEvent(this);
}
own_data_ = false;
event_ = nullptr;
device_ = nullptr;
}
}
......
......@@ -49,6 +49,8 @@ class Event {
void Synchronize() const;
const Place& GetPlace() const;
static void ReleaseAll();
private:
DISABLE_COPY_AND_ASSIGN(Event);
Place place_;
......
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <list>
#include "paddle/phi/backends/stream.h"
#include "glog/logging.h"
......@@ -22,7 +24,18 @@
namespace phi {
namespace stream {
Stream::~Stream() { Destroy(); }
std::list<Stream*> g_streams;
void Stream::ReleaseAll() {
for (auto* stream : g_streams) {
stream->Destroy();
}
}
Stream::~Stream() {
g_streams.remove(this);
Destroy();
}
const stream_t& Stream::raw_stream() const { return stream_; }
......@@ -52,6 +65,7 @@ bool Stream::Init(const Place& place,
<< ", priority: " << static_cast<int>(priority)
<< ", flag:" << static_cast<int>(flag);
own_data_ = true;
g_streams.push_back(this);
return true;
}
......@@ -83,11 +97,14 @@ void Stream::Wait() const {
void Stream::WaitCallback() const { callback_manager_->Wait(); }
void Stream::Destroy() {
if (own_data_ && stream_ != nullptr) {
if (device_) {
if (own_data_) {
phi::DeviceManager::SetDevice(place_);
device_->DestroyStream(this);
}
own_data_ = false;
stream_ = nullptr;
device_ = nullptr;
}
}
......
......@@ -66,6 +66,8 @@ class Stream {
void Synchronize() const;
const Place& GetPlace() const;
static void ReleaseAll();
private:
DISABLE_COPY_AND_ASSIGN(Stream);
Place place_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册