未验证 提交 6af85a81 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] fix resource_pool release bug (#55229)

上级 b8f265d2
......@@ -30,39 +30,47 @@ CustomDeviceStreamResourcePool::CustomDeviceStreamResourcePool(
int dev_cnt = phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
pool_.reserve(dev_cnt);
for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) {
auto creator = [place, dev_idx] {
auto creator = [place, dev_idx, this] {
auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx);
phi::DeviceManager::SetDevice(place_);
phi::stream::Stream* stream = new phi::stream::Stream(place_, nullptr);
phi::DeviceManager::GetDeviceWithPlace(place_)->CreateStream(stream);
phi::stream::Stream* stream = new phi::stream::Stream;
stream->Init(place_);
this->streams_.push_back(stream);
return stream;
};
auto deleter = [place, dev_idx](phi::stream::Stream* stream) {
auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx);
phi::DeviceManager::SetDevice(place_);
phi::DeviceManager::GetDeviceWithPlace(place_)->DestroyStream(stream);
delete stream;
};
pool_.emplace_back(
ResourcePool<CustomDeviceStreamObject>::Create(creator, deleter));
pool_.emplace_back(ResourcePool<CustomDeviceStreamObject>::Create(
creator, [](phi::stream::Stream* stream) {}));
}
}
std::unordered_map<
std::string,
std::vector<std::shared_ptr<CustomDeviceStreamResourcePool>>>&
std::unordered_map<std::string, std::vector<CustomDeviceStreamResourcePool*>>&
CustomDeviceStreamResourcePool::GetMap() {
static std::unordered_map<
std::string,
std::vector<std::shared_ptr<CustomDeviceStreamResourcePool>>>
static std::unordered_map<std::string,
std::vector<CustomDeviceStreamResourcePool*>>
pool;
return pool;
}
CustomDeviceStreamResourcePool::~CustomDeviceStreamResourcePool() {
for (auto* p : streams_) {
delete p;
}
pool_.clear();
}
void CustomDeviceStreamResourcePool::Release() {
auto& pool = GetMap();
for (auto& item : pool) {
for (auto& p : item.second) {
delete p;
}
item.second.clear();
}
pool.clear();
}
CustomDeviceStreamResourcePool& CustomDeviceStreamResourcePool::Instance(
const paddle::Place& place) {
auto& pool = GetMap();
......@@ -72,9 +80,8 @@ CustomDeviceStreamResourcePool& CustomDeviceStreamResourcePool::Instance(
platform::errors::PreconditionNotMet(
"Required device shall be CustomPlace, but received %d. ", place));
if (pool.find(place.GetDeviceType()) == pool.end()) {
pool.insert(
{place.GetDeviceType(),
std::vector<std::shared_ptr<CustomDeviceStreamResourcePool>>()});
pool.insert({place.GetDeviceType(),
std::vector<CustomDeviceStreamResourcePool*>()});
for (size_t i = 0;
i < phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
++i) {
......@@ -121,43 +128,50 @@ CustomDeviceEventResourcePool::CustomDeviceEventResourcePool(
int dev_cnt = phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
pool_.reserve(dev_cnt);
for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) {
auto creator = [place, dev_idx] {
auto creator = [place, dev_idx, this] {
auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx);
phi::DeviceManager::SetDevice(place_);
phi::event::Event* event = new phi::event::Event(place_, nullptr);
phi::DeviceManager::GetDeviceWithPlace(place_)->CreateEvent(event);
phi::event::Event* event = new phi::event::Event;
event->Init(place_);
this->events_.push_back(event);
return event;
};
auto deleter = [place, dev_idx](phi::event::Event* event) {
auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx);
phi::DeviceManager::SetDevice(place_);
phi::DeviceManager::GetDeviceWithPlace(place_)->DestroyEvent(event);
};
pool_.emplace_back(
ResourcePool<CustomDeviceEventObject>::Create(creator, deleter));
pool_.emplace_back(ResourcePool<CustomDeviceEventObject>::Create(
creator, [](phi::event::Event* event) {}));
}
}
std::unordered_map<std::string,
std::vector<std::shared_ptr<CustomDeviceEventResourcePool>>>&
std::unordered_map<std::string, std::vector<CustomDeviceEventResourcePool*>>&
CustomDeviceEventResourcePool::GetMap() {
static std::unordered_map<
std::string,
std::vector<std::shared_ptr<CustomDeviceEventResourcePool>>>
static std::unordered_map<std::string,
std::vector<CustomDeviceEventResourcePool*>>
pool;
return pool;
}
CustomDeviceEventResourcePool::~CustomDeviceEventResourcePool() {
for (auto* p : events_) {
delete p;
}
pool_.clear();
}
void CustomDeviceEventResourcePool::Release() {
auto& pool = GetMap();
for (auto& item : pool) {
for (auto& p : item.second) {
delete p;
}
item.second.clear();
}
pool.clear();
}
CustomDeviceEventResourcePool& CustomDeviceEventResourcePool::Instance(
const phi::Place& place) {
static std::unordered_map<
std::string,
std::vector<std::shared_ptr<CustomDeviceEventResourcePool>>>
pool;
auto& pool = GetMap();
PADDLE_ENFORCE_EQ(
platform::is_custom_place(place),
true,
......@@ -165,8 +179,7 @@ CustomDeviceEventResourcePool& CustomDeviceEventResourcePool::Instance(
"Required device shall be CustomPlace, but received %d. ", place));
if (pool.find(place.GetDeviceType()) == pool.end()) {
pool.insert(
{place.GetDeviceType(),
std::vector<std::shared_ptr<CustomDeviceEventResourcePool>>()});
{place.GetDeviceType(), std::vector<CustomDeviceEventResourcePool*>()});
for (size_t i = 0;
i < phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
++i) {
......
......@@ -31,15 +31,18 @@ using CustomDeviceEventObject = phi::event::Event;
class CustomDeviceStreamResourcePool {
public:
static std::unordered_map<
std::string,
std::vector<std::shared_ptr<CustomDeviceStreamResourcePool>>>&
static std::unordered_map<std::string,
std::vector<CustomDeviceStreamResourcePool*>>&
GetMap();
static void Release();
std::shared_ptr<CustomDeviceStreamObject> New(int dev_idx);
static CustomDeviceStreamResourcePool& Instance(const paddle::Place& place);
~CustomDeviceStreamResourcePool();
private:
explicit CustomDeviceStreamResourcePool(const paddle::Place& place);
......@@ -47,19 +50,23 @@ class CustomDeviceStreamResourcePool {
private:
std::vector<std::shared_ptr<ResourcePool<CustomDeviceStreamObject>>> pool_;
std::vector<phi::stream::Stream*> streams_;
};
class CustomDeviceEventResourcePool {
public:
std::shared_ptr<CustomDeviceEventObject> New(int dev_idx);
static std::unordered_map<
std::string,
std::vector<std::shared_ptr<CustomDeviceEventResourcePool>>>&
static std::unordered_map<std::string,
std::vector<CustomDeviceEventResourcePool*>>&
GetMap();
static void Release();
static CustomDeviceEventResourcePool& Instance(const paddle::Place& place);
~CustomDeviceEventResourcePool();
private:
explicit CustomDeviceEventResourcePool(const paddle::Place& place);
......@@ -67,6 +74,7 @@ class CustomDeviceEventResourcePool {
private:
std::vector<std::shared_ptr<ResourcePool<CustomDeviceEventObject>>> pool_;
std::vector<phi::event::Event*> events_;
};
} // namespace platform
......
......@@ -38,6 +38,21 @@ CustomTracer::~CustomTracer() {
#endif
}
std::unordered_map<std::string, std::unique_ptr<CustomTracer>>&
CustomTracer::GetMap() {
static std::unordered_map<std::string, std::unique_ptr<CustomTracer>>
instance;
return instance;
}
void CustomTracer::Release() {
auto& pool = GetMap();
for (auto& item : pool) {
item.second.reset();
}
pool.clear();
}
void CustomTracer::PrepareTracing() {
PADDLE_ENFORCE_EQ(
state_ == TracerState::UNINITED || state_ == TracerState::STOPED,
......
......@@ -26,18 +26,16 @@ namespace platform {
class CustomTracer : public TracerBase {
public:
static std::unordered_map<std::string, std::shared_ptr<CustomTracer>>&
GetMap() {
static std::unordered_map<std::string, std::shared_ptr<CustomTracer>>
instance;
return instance;
}
static std::unordered_map<std::string, std::unique_ptr<CustomTracer>>&
GetMap();
static void Release();
static CustomTracer& GetInstance(const std::string& device_type) {
auto& instance = GetMap();
if (instance.find(device_type) == instance.cend()) {
instance.insert(
{device_type, std::make_shared<CustomTracer>(device_type)});
{device_type, std::make_unique<CustomTracer>(device_type)});
}
return *instance[device_type];
}
......
......@@ -1009,9 +1009,9 @@ PYBIND11_MODULE(libpaddle, m) {
m.def("clear_device_manager", []() {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
platform::XCCLCommContext::Release();
platform::CustomTracer::GetMap().clear();
platform::CustomDeviceEventResourcePool::GetMap().clear();
platform::CustomDeviceStreamResourcePool::GetMap().clear();
platform::CustomTracer::Release();
platform::CustomDeviceEventResourcePool::Release();
platform::CustomDeviceStreamResourcePool::Release();
phi::DeviceManager::Clear();
#endif
});
......
......@@ -671,8 +671,8 @@ DeviceManager& DeviceManager::Instance() {
}
void DeviceManager::Clear() {
// Instance().device_map_.clear();
// Instance().device_impl_map_.clear();
Instance().device_map_.clear();
Instance().device_impl_map_.clear();
}
std::vector<std::string> ListAllLibraries(const std::string& library_dir) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册