未验证 提交 a811256d 编写于 作者: Y Yi Liu 提交者: GitHub

Fix CUDAHandleHolder destruction problem. (#23772) (#23830)

Eagerly release CUDA resources before the CUDA environment is destroyed
test=develop
上级 52d0967a
...@@ -67,7 +67,6 @@ NCCLComm* NCCLCommContext::CreateNCCLComm(ncclUniqueId* nccl_id, int nranks, ...@@ -67,7 +67,6 @@ NCCLComm* NCCLCommContext::CreateNCCLComm(ncclUniqueId* nccl_id, int nranks,
PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(dev_id)); PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(dev_id));
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::ncclCommInitRank(&comm, nranks, *nccl_id, rank)); platform::dynload::ncclCommInitRank(&comm, nranks, *nccl_id, rank));
comm_vec_.push_back(comm);
auto* comm_wrapper = AssignNCCLComm(comm, nranks, rank, dev_id, ring_id); auto* comm_wrapper = AssignNCCLComm(comm, nranks, rank, dev_id, ring_id);
...@@ -89,7 +88,6 @@ void NCCLCommContext::CreateAllNCCLComms(const std::vector<int>& dev_ids, ...@@ -89,7 +88,6 @@ void NCCLCommContext::CreateAllNCCLComms(const std::vector<int>& dev_ids,
ncclComm_t comms[kDevices]; ncclComm_t comms[kDevices];
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclCommInitAll( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclCommInitAll(
comms, dev_ids.size(), dev_ids.data())); comms, dev_ids.size(), dev_ids.data()));
comm_vec_.insert(comm_vec_.end(), comms, comms + kDevices);
PADDLE_ENFORCE_EQ(comm_map_.count(ring_id), 0); PADDLE_ENFORCE_EQ(comm_map_.count(ring_id), 0);
for (size_t i = 0; i < dev_ids.size(); ++i) { for (size_t i = 0; i < dev_ids.size(); ++i) {
...@@ -135,10 +133,10 @@ NCCLComm* NCCLCommContext::AssignNCCLComm(ncclComm_t comm, int nranks, int rank, ...@@ -135,10 +133,10 @@ NCCLComm* NCCLCommContext::AssignNCCLComm(ncclComm_t comm, int nranks, int rank,
} }
void NCCLCommContext::ReleaseNCCLComms() { void NCCLCommContext::ReleaseNCCLComms() {
for (auto comm : comm_vec_) { for (auto& p : comm_map_) {
PADDLE_ENFORCE_CUDA_SUCCESS( for (auto& q : p.second) {
platform::dynload::ncclCommDestroy(comm), q.second.reset();
platform::errors::External("Fail to destroy nccl comm")); }
} }
} }
......
...@@ -110,8 +110,6 @@ class NCCLCommContext { ...@@ -110,8 +110,6 @@ class NCCLCommContext {
// ring id to dev-NCCLComm // ring id to dev-NCCLComm
std::map<int, std::map<int, std::unique_ptr<NCCLComm>>> comm_map_; std::map<int, std::map<int, std::unique_ptr<NCCLComm>>> comm_map_;
std::vector<ncclComm_t> comm_vec_;
void ReleaseNCCLComms(); void ReleaseNCCLComms();
NCCLCommContext() = default; NCCLCommContext() = default;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册