diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc index 76d1d42c7d653f908227f423b36b9a583105ad3f..db713ac304e291712ce046760d1f95f40eb5a7b6 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc @@ -453,7 +453,8 @@ std::shared_ptr ProcessGroupNCCL::PointToPoint( platform::CUDADeviceGuard cuda_guard; - if (FLAGS_use_stream_safe_cuda_allocator) { + { + platform::NCCLGroupGuard nccl_guard; for (size_t i = 0; i < tensors.size(); ++i) { cuda_guard.SetDevice(places[i]); gpuStream_t nccl_stream; @@ -465,12 +466,11 @@ std::shared_ptr ProcessGroupNCCL::PointToPoint( } else { nccl_stream = places_to_ctx_[key][i]->stream(); } - memory::RecordStream(tensors[i].Holder(), nccl_stream); + fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank); } } - { - platform::NCCLGroupGuard nccl_guard; + if (FLAGS_use_stream_safe_cuda_allocator) { for (size_t i = 0; i < tensors.size(); ++i) { cuda_guard.SetDevice(places[i]); gpuStream_t nccl_stream; @@ -482,7 +482,7 @@ std::shared_ptr ProcessGroupNCCL::PointToPoint( } else { nccl_stream = places_to_ctx_[key][i]->stream(); } - fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank); + memory::RecordStream(tensors[i].Holder(), nccl_stream); } } @@ -521,20 +521,20 @@ std::shared_ptr ProcessGroupNCCL::PointToPoint( // construct uninitialize guard for device platform::CUDADeviceGuard cuda_guard; - if (FLAGS_use_stream_safe_cuda_allocator) { + { + platform::NCCLGroupGuard nccl_guard; for (size_t i = 0; i < tensors.size(); ++i) { cuda_guard.SetDevice(places[i]); - memory::RecordStream(tensors[i].Holder(), - places_to_ctx_[key][i]->stream()); + const auto& nccl_stream = places_to_ctx_[key][i]->stream(); + fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank); } } - { - platform::NCCLGroupGuard nccl_guard; + if (FLAGS_use_stream_safe_cuda_allocator) { for (size_t i = 0; i < tensors.size(); ++i) { cuda_guard.SetDevice(places[i]); - const auto& nccl_stream = places_to_ctx_[key][i]->stream(); - fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank); + memory::RecordStream(tensors[i].Holder(), + places_to_ctx_[key][i]->stream()); } }