From 643fd2f4053b41eb0d356f361e2b26657aa0bf06 Mon Sep 17 00:00:00 2001 From: WangXi Date: Wed, 10 Nov 2021 19:38:27 +0800 Subject: [PATCH] [FleetExecutor]Add interceptor message handle (#37093) --- .../distributed/fleet_executor/interceptor.cc | 19 +++++++++++++++++++ .../distributed/fleet_executor/interceptor.h | 14 ++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc index 7a87e3e6a0..0b3f3ff2de 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc @@ -27,6 +27,16 @@ Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node) Interceptor::~Interceptor() { interceptor_thread_.join(); } +void Interceptor::RegisterInterceptorHandle(InterceptorHandle handle) { + handle_ = handle; +} + +void Interceptor::Handle(const InterceptorMessage& msg) { + if (handle_) { + handle_(msg); + } +} + std::condition_variable& Interceptor::GetCondVar() { // get the conditional var return cond_var_; @@ -47,6 +57,13 @@ bool Interceptor::EnqueueRemoteInterceptorMessage( return true; } +void Interceptor::Send(int64_t dst_id, + std::unique_ptr msg) { + msg->set_src_id(interceptor_id_); + msg->set_dst_id(dst_id); + // send interceptor msg +} + void Interceptor::PoolTheMailbox() { // pool the local mailbox, parse the Message while (true) { @@ -67,6 +84,8 @@ void Interceptor::PoolTheMailbox() { // break the pooling thread break; } + + Handle(interceptor_message); } } diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h index 85b1d2351f..02696d8edd 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include @@ -32,6 +33,9 @@ namespace distributed { class TaskNode; class Interceptor { + public: + using InterceptorHandle = std::function; + public: Interceptor() = delete; @@ -39,6 +43,11 @@ class Interceptor { virtual ~Interceptor(); + // register interceptor handle + void RegisterInterceptorHandle(InterceptorHandle handle); + + void Handle(const InterceptorMessage& msg); + // return the interceptor id int64_t GetInterceptorId() const; @@ -49,6 +58,8 @@ class Interceptor { bool EnqueueRemoteInterceptorMessage( const InterceptorMessage& interceptor_message); + void Send(int64_t dst_id, std::unique_ptr msg); + DISABLE_COPY_AND_ASSIGN(Interceptor); private: @@ -65,6 +76,9 @@ class Interceptor { // node need to be handled by this interceptor TaskNode* node_; + // interceptor handle which process message + InterceptorHandle handle_{nullptr}; + // mutex to control read/write conflict for remote mailbox std::mutex remote_mailbox_mutex_; -- GitLab