diff --git a/paddle/fluid/distributed/collective/process_group_custom.cc b/paddle/fluid/distributed/collective/process_group_custom.cc index ecf88d93ebb5e90656c4a966fdd93cc221ff1453..9acf961cdc45950ea421caa4ccdcde8a44ce0fbb 100644 --- a/paddle/fluid/distributed/collective/process_group_custom.cc +++ b/paddle/fluid/distributed/collective/process_group_custom.cc @@ -38,8 +38,8 @@ void SyncDefaultStream( for (size_t i = 0; i < places.size(); ++i) { auto* default_ctx = static_cast( platform::DeviceContextPool::Instance().Get(places[i])); - cclEvents[i].Record(*dev_ctx[i]); - cclEvents[i].Block(*default_ctx); + cclEvents[i].Record(*default_ctx); + cclEvents[i].Block(*dev_ctx[i]); } } @@ -74,8 +74,7 @@ void ProcessGroupCustom::CustomTask::SynchronizeStreams() { auto* default_ctx = static_cast( platform::DeviceContextPool::Instance().Get(places_[i])); phi::DeviceGuard guard(default_ctx->GetPlace()); - phi::stream::Stream stream(default_ctx->GetPlace(), default_ctx->stream()); - stream.WaitEvent(control_events_[i].GetCustomEvent()); + control_events_[i].Block(*default_ctx); } } @@ -118,7 +117,7 @@ ProcessGroupCustom::ProcessGroupCustom( int rank, int size, int gid) - : ProcessGroupWithoutStream(rank, size, gid), + : ProcessGroupWithStream(rank, size, gid), store_(store), device_type_(device_type) {} diff --git a/paddle/fluid/distributed/collective/process_group_custom.h b/paddle/fluid/distributed/collective/process_group_custom.h index d8db5e9f9083afab8ef607e5c3bab98c56a968f5..4d95ef0ae8e1a8ebdd076e66814c575526272ccd 100644 --- a/paddle/fluid/distributed/collective/process_group_custom.h +++ b/paddle/fluid/distributed/collective/process_group_custom.h @@ -23,7 +23,7 @@ #include "paddle/fluid/distributed/collective/custom_ccl_tools.h" #include "paddle/fluid/distributed/collective/process_group.h" -#include "paddle/fluid/distributed/collective/process_group_without_stream.h" +#include "paddle/fluid/distributed/collective/process_group_with_stream.h" #include "paddle/fluid/platform/device/npu/npu_stream.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" @@ -36,7 +36,7 @@ namespace distributed { using Place = paddle::platform::Place; using CustomDeviceContext = paddle::platform::CustomDeviceContext; -class ProcessGroupCustom : public ProcessGroupWithoutStream { +class ProcessGroupCustom : public ProcessGroupWithStream { public: class CustomTask : public ProcessGroup::Task, public std::enable_shared_from_this {