diff --git a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt index 0941b2075b8932de2602ff0969861d00f159d9ce..9d6b30b06334a917193cc0cfda5acbe3293f23db 100644 --- a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt @@ -16,6 +16,7 @@ cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc if(WITH_DISTRIBUTE) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") + set_source_files_properties(interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(message_bus.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(carrier.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index a2cde8bdd51ad3d65879f69fcf46bef3bb4da762..842695fdea60b26935223365ca083dbc451d38cc 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -46,10 +46,5 @@ std::shared_ptr FleetExecutor::GetCarrier() { return nullptr; } -std::shared_ptr FleetExecutor::GetMessageBus() { - // get message bus - return nullptr; -} - } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.h b/paddle/fluid/distributed/fleet_executor/fleet_executor.h index 613dacf5496f77cc58a5fc3a27c9574a92168428..e12629844933a7e69f0df9984e4345a79935ef9e 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.h +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.h @@ -37,14 +37,12 @@ class FleetExecutor final { void Run(); void Release(); static std::shared_ptr GetCarrier(); - static std::shared_ptr GetMessageBus(); private: DISABLE_COPY_AND_ASSIGN(FleetExecutor); FleetExecutorDesc exe_desc_; std::unique_ptr runtime_graph_; static std::shared_ptr global_carrier_; - static std::shared_ptr global_message_bus_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc index 0b3f3ff2de84ae5444cb447c1e2f228b17c81c9d..03f04d8340f0acf68f5488f7410ecfb7ef8645cc 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/distributed/fleet_executor/interceptor.h" +#include "paddle/fluid/distributed/fleet_executor/message_bus.h" namespace paddle { namespace distributed { @@ -27,9 +28,7 @@ Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node) Interceptor::~Interceptor() { interceptor_thread_.join(); } -void Interceptor::RegisterInterceptorHandle(InterceptorHandle handle) { - handle_ = handle; -} +void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; } void Interceptor::Handle(const InterceptorMessage& msg) { if (handle_) { @@ -61,7 +60,7 @@ void Interceptor::Send(int64_t dst_id, std::unique_ptr msg) { msg->set_src_id(interceptor_id_); msg->set_dst_id(dst_id); - // send interceptor msg + MessageBus::Instance().Send(*msg.get()); } void Interceptor::PoolTheMailbox() { diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h index 02696d8edd737bcf378988b45ff008c70466a0a3..7744ecbb11026cb30816516ac66939ddeda3857e 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor.h @@ -34,7 +34,7 @@ class TaskNode; class Interceptor { public: - using InterceptorHandle = std::function; + using MsgHandle = std::function; public: Interceptor() = delete; @@ -44,7 +44,7 @@ class Interceptor { virtual ~Interceptor(); // register interceptor handle - void RegisterInterceptorHandle(InterceptorHandle handle); + void RegisterMsgHandle(MsgHandle handle); void Handle(const InterceptorMessage& msg); @@ -77,7 +77,7 @@ class Interceptor { TaskNode* node_; // interceptor handle which process message - InterceptorHandle handle_{nullptr}; + MsgHandle handle_{nullptr}; // mutex to control read/write conflict for remote mailbox std::mutex remote_mailbox_mutex_; diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.cc b/paddle/fluid/distributed/fleet_executor/message_bus.cc index 0094dbd1f10a1acce1d31d94641e19fa416196c6..08cd100f108fd8efa3afcf590017196f27e651e9 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.cc +++ b/paddle/fluid/distributed/fleet_executor/message_bus.cc @@ -21,20 +21,28 @@ namespace paddle { namespace distributed { -MessageBus::MessageBus( +void MessageBus::Init( const std::unordered_map& interceptor_id_to_rank, const std::unordered_map& rank_to_addr, - const std::string& addr) - : interceptor_id_to_rank_(interceptor_id_to_rank), - rank_to_addr_(rank_to_addr), - addr_(addr) { + const std::string& addr) { + PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists( + "MessageBus is already init.")); + is_init_ = true; + interceptor_id_to_rank_ = interceptor_id_to_rank; + rank_to_addr_ = rank_to_addr; + addr_ = addr; + listen_port_thread_ = std::thread([this]() { VLOG(3) << "Start listen_port_thread_ for message bus"; ListenPort(); }); + + std::call_once(once_flag_, []() { + std::atexit([]() { MessageBus::Instance().Release(); }); + }); } -MessageBus::~MessageBus() { +void MessageBus::Release() { #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ !defined(PADDLE_WITH_ASCEND_CL) server_.Stop(1000); diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.h b/paddle/fluid/distributed/fleet_executor/message_bus.h index 86f34e203c5de589b06bbe769ba6811295b5b700..08e8d2e24abd8347c171d08a8c87cee5161884d5 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.h +++ b/paddle/fluid/distributed/fleet_executor/message_bus.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include #include @@ -35,15 +36,19 @@ namespace distributed { class Carrier; +// A singleton MessageBus class MessageBus final { public: - MessageBus() = delete; + static MessageBus& Instance() { + static MessageBus msg_bus; + return msg_bus; + } - MessageBus(const std::unordered_map& interceptor_id_to_rank, - const std::unordered_map& rank_to_addr, - const std::string& addr); + void Init(const std::unordered_map& interceptor_id_to_rank, + const std::unordered_map& rank_to_addr, + const std::string& addr); - ~MessageBus(); + void Release(); // called by Interceptor, send InterceptorMessage to dst bool Send(const InterceptorMessage& interceptor_message); @@ -51,6 +56,8 @@ class MessageBus final { DISABLE_COPY_AND_ASSIGN(MessageBus); private: + MessageBus() = default; + // function keep listen the port and handle the message void ListenPort(); @@ -66,6 +73,9 @@ class MessageBus final { // send the message intra rank (dst is the same rank with src) bool SendIntraRank(const InterceptorMessage& interceptor_message); + bool is_init_{false}; + std::once_flag once_flag_; + // handed by above layer, save the info mapping interceptor id to rank id std::unordered_map interceptor_id_to_rank_;