From 0201ccc412556c447c872a290505528de1393c96 Mon Sep 17 00:00:00 2001
From: Yuang Liu
Date: Tue, 1 Nov 2022 11:13:39 +0800
Subject: [PATCH] fix p2p comm memory release logic (#47497) (#47517)

---
 .../collective/ProcessGroupNCCL.cc | 24 +++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
index 2e18dfcc3ba..f9429e1efa7 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
@@ -448,7 +448,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
   // construct uninitialize guard for device
   platform::CUDADeviceGuard cuda_guard;
 
-  if (FLAGS_use_stream_safe_cuda_allocator) {
+  {
+    platform::NCCLGroupGuard nccl_guard;
     for (size_t i = 0; i < tensors.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
       gpuStream_t nccl_stream;
@@ -460,12 +461,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
       } else {
         nccl_stream = places_to_ctx_[key][i]->stream();
       }
-      memory::RecordStream(tensors[i].Holder(), nccl_stream);
+      fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
     }
   }
 
-  {
-    platform::NCCLGroupGuard nccl_guard;
+  if (FLAGS_use_stream_safe_cuda_allocator) {
     for (size_t i = 0; i < tensors.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
       gpuStream_t nccl_stream;
@@ -477,7 +477,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
       } else {
         nccl_stream = places_to_ctx_[key][i]->stream();
       }
-      fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
+      memory::RecordStream(tensors[i].Holder(), nccl_stream);
     }
   }
 
@@ -516,20 +516,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
   // construct uninitialize guard for device
   platform::CUDADeviceGuard cuda_guard;
 
-  if (FLAGS_use_stream_safe_cuda_allocator) {
+  {
+    platform::NCCLGroupGuard nccl_guard;
     for (size_t i = 0; i < tensors.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
-      memory::RecordStream(tensors[i].Holder(),
-                           places_to_ctx_[key][i]->stream());
+      const auto& nccl_stream = places_to_ctx_[key][i]->stream();
+      fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
     }
   }
 
-  {
-    platform::NCCLGroupGuard nccl_guard;
+  if (FLAGS_use_stream_safe_cuda_allocator) {
     for (size_t i = 0; i < tensors.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
-      const auto& nccl_stream = places_to_ctx_[key][i]->stream();
-      fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
+      memory::RecordStream(tensors[i].Holder(),
+                           places_to_ctx_[key][i]->stream());
     }
   }
 
-- 
GitLab