From 9751bd0d1c36372d4b25b1abc3ccea7a9688a1a8 Mon Sep 17 00:00:00 2001 From: ronnywang Date: Mon, 13 Mar 2023 17:03:43 +0800 Subject: [PATCH] add UpdateWaitChain for process_group_custom (#51491) * add UpdateWaitChain for process_group_custom * add UpdateWaitChain for process_group_custom --- .../distributed/collective/process_group_custom.cc | 12 ++++++++++++ .../distributed/collective/process_group_custom.h | 7 ++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/distributed/collective/process_group_custom.cc b/paddle/fluid/distributed/collective/process_group_custom.cc index 070b1f21709..ecf88d93ebb 100644 --- a/paddle/fluid/distributed/collective/process_group_custom.cc +++ b/paddle/fluid/distributed/collective/process_group_custom.cc @@ -100,6 +100,18 @@ bool ProcessGroupCustom::CustomTask::Wait(std::chrono::milliseconds timeout) { // Same as Wait void ProcessGroupCustom::CustomTask::Synchronize() { Wait(kWaitTimeout); } +void ProcessGroupCustom::CustomTask::UpdateWaitChain( + const phi::DeviceContext& ctx) { + PADDLE_ENFORCE_NE( + std::find(places_.cbegin(), places_.cend(), ctx.GetPlace()), + places_.cend(), + phi::errors::NotFound("Cannot find the device context in this task.")); + auto index = std::find(places_.cbegin(), places_.cend(), ctx.GetPlace()) - + places_.cbegin(); + control_events_[index].Record( + reinterpret_cast(ctx)); +} + ProcessGroupCustom::ProcessGroupCustom( const std::shared_ptr& store, const std::string& device_type, diff --git a/paddle/fluid/distributed/collective/process_group_custom.h b/paddle/fluid/distributed/collective/process_group_custom.h index 794d0b0ef8f..d8db5e9f908 100644 --- a/paddle/fluid/distributed/collective/process_group_custom.h +++ b/paddle/fluid/distributed/collective/process_group_custom.h @@ -46,10 +46,11 @@ class ProcessGroupCustom : public ProcessGroupWithoutStream { CommType CommType, const std::vector& inputs); - bool IsCompleted(); + bool IsCompleted() override; void SynchronizeStreams(); - bool Wait(std::chrono::milliseconds timeout = kWaitTimeout); - void Synchronize(); + bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) override; + void Synchronize() override; + void UpdateWaitChain(const phi::DeviceContext& ctx) override; void SetOutputs(std::vector& outputs); // NOLINT virtual ~CustomTask(); -- GitLab