diff --git a/paddle/fluid/distributed/collective/ProcessGroupHCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupHCCL.cc index 3b3b505ffb80cfb22dc4c415c30142eac65e47c1..718b33903af8b37510cf79bba600d2e6e5332873 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupHCCL.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupHCCL.cc @@ -197,17 +197,6 @@ std::shared_ptr ProcessGroupHCCL::Collective( SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]); auto task = CreateTask(places, rank_, op_type, inputs); - task->SetOutputs(outputs); - - // if (FLAGS_use_stream_safe_npu_allocator) { - // for (size_t i = 0; i < inputs.size(); ++i) { - // platform::NPUDeviceGuard guard(places[i].GetDeviceId()); - // auto dense_tensor = - // std::dynamic_pointer_cast(inputs[i].impl()); - // memory::RecordStream(dense_tensor->Holder(), - // places_to_ctx_[key][i]->stream()); - // } - // } for (size_t i = 0; i < inputs.size(); ++i) { platform::NPUDeviceGuard guard(places[i].GetDeviceId()); diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc index d776f62373e43bc672ba85fdb316a0dc28a43f88..168548cf9ba06f1584882c0baf78b49f34672d74 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc @@ -244,25 +244,24 @@ std::shared_ptr ProcessGroupNCCL::Collective( SyncDefaultStream(places, places_to_events_[key], places_to_ctx_[key]); auto task = CreateTask(places, rank_, op_type, inputs); - task->SetOutputs(outputs); // construct uninitialize guard for device platform::CUDADeviceGuard cuda_guard; - if (FLAGS_use_stream_safe_cuda_allocator) { + { + platform::NCCLGroupGuard nccl_guard; for (size_t i = 0; i < inputs.size(); ++i) { cuda_guard.SetDevice(places[i]); - memory::RecordStream(inputs[i].Holder(), - places_to_ctx_[key][i]->stream()); + const auto& nccl_stream = places_to_ctx_[key][i]->stream(); + fn(inputs[i], outputs[i], nccl_comms[i]->GetNcclComm(), nccl_stream); } } - { - platform::NCCLGroupGuard nccl_guard; + if (FLAGS_use_stream_safe_cuda_allocator) { for (size_t i = 0; i < inputs.size(); ++i) { cuda_guard.SetDevice(places[i]); - const auto& nccl_stream = places_to_ctx_[key][i]->stream(); - fn(inputs[i], outputs[i], nccl_comms[i]->GetNcclComm(), nccl_stream); + memory::RecordStream(inputs[i].Holder(), + places_to_ctx_[key][i]->stream()); } }