diff --git a/paddle/fluid/distributed/collective/process_group_custom.cc b/paddle/fluid/distributed/collective/process_group_custom.cc
index 070b1f217094bc401ffcedd443cca6fa6f402c2e..ecf88d93ebb5e90656c4a966fdd93cc221ff1453 100644
--- a/paddle/fluid/distributed/collective/process_group_custom.cc
+++ b/paddle/fluid/distributed/collective/process_group_custom.cc
@@ -100,6 +100,29 @@ bool ProcessGroupCustom::CustomTask::Wait(std::chrono::milliseconds timeout) {
 // Same as Wait
 void ProcessGroupCustom::CustomTask::Synchronize() { Wait(kWaitTimeout); }
 
+// Records an event for `ctx` on the control event of this task that sits at
+// the same position as the context's place.
+//
+// Looks up ctx.GetPlace() in places_ and uses its position to index
+// control_events_. Raises phi::errors::NotFound (via PADDLE_ENFORCE_NE) when
+// the place is not present in places_.
+void ProcessGroupCustom::CustomTask::UpdateWaitChain(
+    const phi::DeviceContext& ctx) {
+  // Search once and reuse the iterator for both the check and the index.
+  const auto place_it =
+      std::find(places_.cbegin(), places_.cend(), ctx.GetPlace());
+  PADDLE_ENFORCE_NE(
+      place_it,
+      places_.cend(),
+      phi::errors::NotFound("Cannot find the device context in this task."));
+  const auto index = place_it - places_.cbegin();
+  // NOTE(review): the cast's target type was missing from the patch text as
+  // received; `const phi::CustomContext&` is assumed here -- confirm against
+  // the parameter type of Record().
+  control_events_[index].Record(
+      reinterpret_cast<const phi::CustomContext&>(ctx));
+}
+
 ProcessGroupCustom::ProcessGroupCustom(
     const std::shared_ptr& store,
     const std::string& device_type,
diff --git a/paddle/fluid/distributed/collective/process_group_custom.h b/paddle/fluid/distributed/collective/process_group_custom.h
index 794d0b0ef8f437f0c0ab4ab94f7d810978e09be8..d8db5e9f9083afab8ef607e5c3bab98c56a968f5 100644
--- a/paddle/fluid/distributed/collective/process_group_custom.h
+++ b/paddle/fluid/distributed/collective/process_group_custom.h
@@ -46,10 +46,11 @@ class ProcessGroupCustom : public ProcessGroupWithoutStream {
                CommType CommType,
                const std::vector& inputs);
 
-    bool IsCompleted();
+    bool IsCompleted() override;
     void SynchronizeStreams();
-    bool Wait(std::chrono::milliseconds timeout = kWaitTimeout);
-    void Synchronize();
+    bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) override;
+    void Synchronize() override;
+    void UpdateWaitChain(const phi::DeviceContext& ctx) override;
     void SetOutputs(std::vector& outputs);  // NOLINT
     virtual ~CustomTask();
 