未验证 提交 9751bd0d 编写于 作者: R ronnywang 提交者: GitHub

add UpdateWaitChain for process_group_custom (#51491)

* add UpdateWaitChain for  process_group_custom

* add UpdateWaitChain for  process_group_custom
上级 524eeb17
......@@ -100,6 +100,18 @@ bool ProcessGroupCustom::CustomTask::Wait(std::chrono::milliseconds timeout) {
// Same as Wait
void ProcessGroupCustom::CustomTask::Synchronize() { Wait(kWaitTimeout); }
void ProcessGroupCustom::CustomTask::UpdateWaitChain(
const phi::DeviceContext& ctx) {
PADDLE_ENFORCE_NE(
std::find(places_.cbegin(), places_.cend(), ctx.GetPlace()),
places_.cend(),
phi::errors::NotFound("Cannot find the device context in this task."));
auto index = std::find(places_.cbegin(), places_.cend(), ctx.GetPlace()) -
places_.cbegin();
control_events_[index].Record(
reinterpret_cast<const phi::CustomContext&>(ctx));
}
ProcessGroupCustom::ProcessGroupCustom(
const std::shared_ptr<phi::distributed::Store>& store,
const std::string& device_type,
......
......@@ -46,10 +46,11 @@ class ProcessGroupCustom : public ProcessGroupWithoutStream {
CommType CommType,
const std::vector<phi::DenseTensor>& inputs);
bool IsCompleted();
bool IsCompleted() override;
void SynchronizeStreams();
bool Wait(std::chrono::milliseconds timeout = kWaitTimeout);
void Synchronize();
bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) override;
void Synchronize() override;
void UpdateWaitChain(const phi::DeviceContext& ctx) override;
void SetOutputs(std::vector<phi::DenseTensor>& outputs); // NOLINT
virtual ~CustomTask();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册