Unverified · Commit 0201ccc4 · Authored by Yuang Liu · Committed by GitHub

fix p2p comm memory release logic (#47497) (#47517)

Parent commit: 4b3589fb
@@ -448,7 +448,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
   platform::CUDADeviceGuard cuda_guard;
-  if (FLAGS_use_stream_safe_cuda_allocator) {
+  {
+    platform::NCCLGroupGuard nccl_guard;
     for (size_t i = 0; i < tensors.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
       gpuStream_t nccl_stream;
@@ -460,12 +461,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
       } else {
         nccl_stream = places_to_ctx_[key][i]->stream();
       }
-      memory::RecordStream(tensors[i].Holder(), nccl_stream);
+      fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
     }
   }
-  {
-    platform::NCCLGroupGuard nccl_guard;
+  if (FLAGS_use_stream_safe_cuda_allocator) {
     for (size_t i = 0; i < tensors.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
       gpuStream_t nccl_stream;
@@ -477,7 +477,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
       } else {
         nccl_stream = places_to_ctx_[key][i]->stream();
       }
-      fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
+      memory::RecordStream(tensors[i].Holder(), nccl_stream);
     }
   }
@@ -516,20 +516,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
   // construct uninitialize guard for device
   platform::CUDADeviceGuard cuda_guard;
-  if (FLAGS_use_stream_safe_cuda_allocator) {
+  {
+    platform::NCCLGroupGuard nccl_guard;
     for (size_t i = 0; i < tensors.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
-      memory::RecordStream(tensors[i].Holder(),
-                           places_to_ctx_[key][i]->stream());
+      const auto& nccl_stream = places_to_ctx_[key][i]->stream();
+      fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
     }
   }
-  {
-    platform::NCCLGroupGuard nccl_guard;
+  if (FLAGS_use_stream_safe_cuda_allocator) {
     for (size_t i = 0; i < tensors.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
-      const auto& nccl_stream = places_to_ctx_[key][i]->stream();
-      fn(tensors[i], nccl_comms[i]->GetNcclComm(), nccl_stream, dst_rank);
+      memory::RecordStream(tensors[i].Holder(),
+                           places_to_ctx_[key][i]->stream());
     }
   }
...
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.