Unverified commit 031debb7 authored by ShenLiang, committed by GitHub

fix memory leak (#44971)

Parent 18d5b44c
@@ -197,17 +197,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHCCL::Collective(
   SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
 
   auto task = CreateTask(places, rank_, op_type, inputs);
-  task->SetOutputs(outputs);
-
-  // if (FLAGS_use_stream_safe_npu_allocator) {
-  //   for (size_t i = 0; i < inputs.size(); ++i) {
-  //     platform::NPUDeviceGuard guard(places[i].GetDeviceId());
-  //     auto dense_tensor =
-  //         std::dynamic_pointer_cast<phi::DenseTensor>(inputs[i].impl());
-  //     memory::RecordStream(dense_tensor->Holder(),
-  //                          places_to_ctx_[key][i]->stream());
-  //   }
-  // }
 
   for (size_t i = 0; i < inputs.size(); ++i) {
     platform::NPUDeviceGuard guard(places[i].GetDeviceId());
...
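Both the HCCL hunk above and the NCCL hunk below drop the task->SetOutputs(outputs) call. The commit message only says "fix memory leak", but the pattern being deleted is a classic lifetime extension: a task object that the runtime keeps around holds shared ownership of the output tensors, so their buffers stay alive long after the caller has released them. A minimal standalone sketch of that retention pattern (hypothetical Tensor/Task types, not Paddle code):

#include <iostream>
#include <memory>
#include <vector>

struct Tensor {
  std::vector<float> data;
};

// Hypothetical stand-in for ProcessGroup::Task: it retains its outputs.
struct Task {
  std::vector<std::shared_ptr<Tensor>> outputs;
};

int main() {
  auto out = std::make_shared<Tensor>();
  out->data.resize(1 << 20);  // ~4 MB we expect to be freed after the op

  std::vector<std::shared_ptr<Task>> pending;  // tasks cached by the runtime
  {
    auto task = std::make_shared<Task>();
    task->outputs.push_back(out);  // analogue of task->SetOutputs(outputs)
    pending.push_back(task);
  }

  out.reset();  // the caller drops its reference...

  // ...but the buffer is still alive: the cached task co-owns it.
  std::cout << "use_count after caller reset: "
            << pending[0]->outputs[0].use_count() << "\n";  // prints 1
}

Dropping the SetOutputs call (and the dead, commented-out NPU RecordStream block) means the tensors' lifetimes are again governed by their callers, which is consistent with the "fix memory leak" title.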
@@ -244,25 +244,24 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
   SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]);
 
   auto task = CreateTask(places, rank_, op_type, inputs);
-  task->SetOutputs(outputs);
 
   // construct uninitialize guard for device
   platform::CUDADeviceGuard cuda_guard;
 
-  if (FLAGS_use_stream_safe_cuda_allocator) {
+  {
+    platform::NCCLGroupGuard nccl_guard;
     for (size_t i = 0; i < inputs.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
-      memory::RecordStream(inputs[i].Holder(),
-                           places_to_ctx_[key][i]->stream());
+      const auto& nccl_stream = places_to_ctx_[key][i]->stream();
+      fn(inputs[i], outputs[i], nccl_comms[i]->GetNcclComm(), nccl_stream);
     }
   }
 
-  {
-    platform::NCCLGroupGuard nccl_guard;
+  if (FLAGS_use_stream_safe_cuda_allocator) {
     for (size_t i = 0; i < inputs.size(); ++i) {
       cuda_guard.SetDevice(places[i]);
-      const auto& nccl_stream = places_to_ctx_[key][i]->stream();
-      fn(inputs[i], outputs[i], nccl_comms[i]->GetNcclComm(), nccl_stream);
+      memory::RecordStream(inputs[i].Holder(),
+                           places_to_ctx_[key][i]->stream());
     }
   }
...
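The NCCL hunk also swaps the two loops: the collective is now enqueued first (the fn(...) call under the NCCLGroupGuard), and memory::RecordStream runs afterwards when FLAGS_use_stream_safe_cuda_allocator is set. RecordStream marks an allocation as in use on a stream, so the stream-safe allocator defers recycling the buffer until that stream's pending work finishes. A conceptual sketch of that bookkeeping, with hypothetical Stream/Allocation types rather than the real Paddle allocator:

#include <cstdio>
#include <set>

struct Stream { int id; };

// Hypothetical model of a stream-safe allocation: it remembers every
// stream that may still read or write the buffer.
struct Allocation {
  std::set<int> recorded_streams;
};

// Analogue of memory::RecordStream(holder, stream).
void RecordStream(Allocation& a, const Stream& s) {
  a.recorded_streams.insert(s.id);
}

// On free, the allocator may only recycle the buffer once every recorded
// stream has drained its pending work.
bool CanReuse(const Allocation& a, const std::set<int>& idle_streams) {
  for (int id : a.recorded_streams) {
    if (idle_streams.count(id) == 0) return false;
  }
  return true;
}

int main() {
  Allocation buf;
  Stream comm_stream{7};

  // 1) enqueue the collective on the comm stream (fn(...) in the diff),
  // 2) then record that stream on the input allocation.
  RecordStream(buf, comm_stream);

  std::printf("reusable while stream busy: %d\n",
              CanReuse(buf, /*idle_streams=*/{}));   // 0: must wait
  std::printf("reusable after stream idle: %d\n",
              CanReuse(buf, /*idle_streams=*/{7}));  // 1: safe to recycle
}

Recording after the enqueue is still safe here because both calls happen on the same host thread before the buffer can be freed; the allocator only needs the stream registered by the time deallocation is requested.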