From f82d7e3cb2c24be244704e7aa0f61d4afd9a7be7 Mon Sep 17 00:00:00 2001
From: Yuang Liu
Date: Tue, 1 Nov 2022 09:19:25 +0800
Subject: [PATCH] fix p2p comm memory release logic (#47497)

---
 .../collective/ProcessGroupNCCL.cc           | 24 +++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
index 76d1d42c7d..db713ac304 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
@@ -453,7 +453,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
 
   platform::CUDADeviceGuard cuda_guard;
 
-  if (FLAGS_use_stream_safe_cuda_allocator) {
+  {
+    platform::NCCLGroupGuard nccl_guard;
     for (size_t i = 0; i < tensors.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
       gpuStream_t nccl_stream;
@@ -465,12 +466,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
       } else {
         nccl_stream = places_to_ctx_[key][i]->stream();
       }
-      memory::RecordStream(tensors[i].Holder(), nccl_stream);
+      fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
     }
   }
 
-  {
-    platform::NCCLGroupGuard nccl_guard;
+  if (FLAGS_use_stream_safe_cuda_allocator) {
     for (size_t i = 0; i < tensors.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
       gpuStream_t nccl_stream;
@@ -482,7 +482,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
       } else {
         nccl_stream = places_to_ctx_[key][i]->stream();
       }
-      fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
+      memory::RecordStream(tensors[i].Holder(), nccl_stream);
     }
   }
 
@@ -521,20 +521,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
   // construct uninitialize guard for device
   platform::CUDADeviceGuard cuda_guard;
 
-  if (FLAGS_use_stream_safe_cuda_allocator) {
+  {
+    platform::NCCLGroupGuard nccl_guard;
     for (size_t i = 0; i < tensors.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
-      memory::RecordStream(tensors[i].Holder(),
-                           places_to_ctx_[key][i]->stream());
+      const auto& nccl_stream = places_to_ctx_[key][i]->stream();
+      fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
     }
   }
 
-  {
-    platform::NCCLGroupGuard nccl_guard;
+  if (FLAGS_use_stream_safe_cuda_allocator) {
     for (size_t i = 0; i < tensors.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
-      const auto& nccl_stream = places_to_ctx_[key][i]->stream();
-      fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
+      memory::RecordStream(tensors[i].Holder(),
+                           places_to_ctx_[key][i]->stream());
     }
   }
-- 
GitLab
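
Note (not part of the patch): in both PointToPoint overloads the change swaps the two loops, so the p2p kernels are enqueued first under platform::NCCLGroupGuard and memory::RecordStream is called on each tensor holder only afterwards, which is the "memory release logic" fix named in the subject. Below is a minimal, self-contained C++ sketch of that ordering only; Tensor, Stream, GroupGuard, LaunchP2P, and RecordStream are hypothetical stand-ins for the real Paddle/NCCL types, not the library API.

// Sketch of the post-fix ordering in PointToPoint(); all types are illustrative stubs.
#include <cstdio>
#include <vector>

struct Tensor { int id; };
struct Stream { int id; };

// Stand-in for platform::NCCLGroupGuard: batches the per-device launches
// between ncclGroupStart()/ncclGroupEnd().
struct GroupGuard {
  GroupGuard() { std::printf("ncclGroupStart()\n"); }
  ~GroupGuard() { std::printf("ncclGroupEnd()\n"); }
};

void LaunchP2P(const Tensor& t, const Stream& s) {
  std::printf("enqueue p2p kernel: tensor %d on stream %d\n", t.id, s.id);
}

void RecordStream(const Tensor& t, const Stream& s) {
  std::printf("record stream %d for tensor %d holder\n", s.id, t.id);
}

int main() {
  std::vector<Tensor> tensors = {{0}, {1}};
  std::vector<Stream> streams = {{10}, {11}};
  bool use_stream_safe_cuda_allocator = true;  // mirrors FLAGS_use_stream_safe_cuda_allocator

  // Post-fix order: first enqueue the communication kernels inside the group guard...
  {
    GroupGuard nccl_guard;
    for (size_t i = 0; i < tensors.size(); ++i) {
      LaunchP2P(tensors[i], streams[i]);
    }
  }

  // ...then record the stream on each holder so the stream-safe allocator keeps
  // the buffer alive until the enqueued communication has finished.
  if (use_stream_safe_cuda_allocator) {
    for (size_t i = 0; i < tensors.size(); ++i) {
      RecordStream(tensors[i], streams[i]);
    }
  }
  return 0;
}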