From a811256d416abe4134ac16269ef2247bfcd7e4eb Mon Sep 17 00:00:00 2001 From: Yi Liu Date: Wed, 15 Apr 2020 11:22:25 +0800 Subject: [PATCH] Fix CUDAHandleHolder destruction problem. (#23772) (#23830) eagerly release cuda resources before cuda environment destroying test=develop --- paddle/fluid/platform/collective_helper.cc | 10 ++++------ paddle/fluid/platform/collective_helper.h | 2 -- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/platform/collective_helper.cc b/paddle/fluid/platform/collective_helper.cc index 4c4a653344..3732fce580 100644 --- a/paddle/fluid/platform/collective_helper.cc +++ b/paddle/fluid/platform/collective_helper.cc @@ -67,7 +67,6 @@ NCCLComm* NCCLCommContext::CreateNCCLComm(ncclUniqueId* nccl_id, int nranks, PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(dev_id)); PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::ncclCommInitRank(&comm, nranks, *nccl_id, rank)); - comm_vec_.push_back(comm); auto* comm_wrapper = AssignNCCLComm(comm, nranks, rank, dev_id, ring_id); @@ -89,7 +88,6 @@ void NCCLCommContext::CreateAllNCCLComms(const std::vector<int>& dev_ids, ncclComm_t comms[kDevices]; PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclCommInitAll( comms, dev_ids.size(), dev_ids.data())); - comm_vec_.insert(comm_vec_.end(), comms, comms + kDevices); PADDLE_ENFORCE_EQ(comm_map_.count(ring_id), 0); for (size_t i = 0; i < dev_ids.size(); ++i) { @@ -135,10 +133,10 @@ NCCLComm* NCCLCommContext::AssignNCCLComm(ncclComm_t comm, int nranks, int rank, } void NCCLCommContext::ReleaseNCCLComms() { - for (auto comm : comm_vec_) { - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::ncclCommDestroy(comm), - platform::errors::External("Fail to destroy nccl comm")); + for (auto& p : comm_map_) { + for (auto& q : p.second) { + q.second.reset(); + } } } diff --git a/paddle/fluid/platform/collective_helper.h b/paddle/fluid/platform/collective_helper.h index 8816703d45..154d3133c1 100644 --- a/paddle/fluid/platform/collective_helper.h +++ 
b/paddle/fluid/platform/collective_helper.h @@ -110,8 +110,6 @@ class NCCLCommContext { // ring id to dev-NCCLComm std::map<int, std::map<int, std::unique_ptr<NCCLComm>>> comm_map_; - std::vector<ncclComm_t> comm_vec_; - void ReleaseNCCLComms(); NCCLCommContext() = default; -- GitLab