diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc index 2e18dfcc3ba1208f47a4ceeb2529826d46b44c34..f9429e1efa774d6f9edb3bf21588cacb09d16725 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc @@ -448,7 +448,8 @@ std::shared_ptr ProcessGroupNCCL::PointToPoint( platform::CUDADeviceGuard cuda_guard; - if (FLAGS_use_stream_safe_cuda_allocator) { + { + platform::NCCLGroupGuard nccl_guard; for (size_t i = 0; i < tensors.size(); ++i) { cuda_guard.SetDevice(places[i]); gpuStream_t nccl_stream; @@ -460,12 +461,11 @@ std::shared_ptr ProcessGroupNCCL::PointToPoint( } else { nccl_stream = places_to_ctx_[key][i]->stream(); } - memory::RecordStream(tensors[i].Holder(), nccl_stream); + fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank); } } - { - platform::NCCLGroupGuard nccl_guard; + if (FLAGS_use_stream_safe_cuda_allocator) { for (size_t i = 0; i < tensors.size(); ++i) { cuda_guard.SetDevice(places[i]); gpuStream_t nccl_stream; @@ -477,7 +477,7 @@ std::shared_ptr ProcessGroupNCCL::PointToPoint( } else { nccl_stream = places_to_ctx_[key][i]->stream(); } - fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank); + memory::RecordStream(tensors[i].Holder(), nccl_stream); } } @@ -516,20 +516,20 @@ std::shared_ptr ProcessGroupNCCL::PointToPoint( // construct uninitialize guard for device platform::CUDADeviceGuard cuda_guard; - if (FLAGS_use_stream_safe_cuda_allocator) { + { + platform::NCCLGroupGuard nccl_guard; for (size_t i = 0; i < tensors.size(); ++i) { cuda_guard.SetDevice(places[i]); - memory::RecordStream(tensors[i].Holder(), - places_to_ctx_[key][i]->stream()); + const auto& nccl_stream = places_to_ctx_[key][i]->stream(); + fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank); } } - { - platform::NCCLGroupGuard nccl_guard; + if (FLAGS_use_stream_safe_cuda_allocator) { for (size_t i = 0; i < tensors.size(); ++i) { cuda_guard.SetDevice(places[i]); - const auto& nccl_stream = places_to_ctx_[key][i]->stream(); - fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank); + memory::RecordStream(tensors[i].Holder(), + places_to_ctx_[key][i]->stream()); } }