未验证 提交 8cdd5564 编写于 作者: W WangXi 提交者: GitHub

[fleet_executor] interceptor send message through message_bus (#37106)

上级 f5e7b02a
...@@ -16,6 +16,7 @@ cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc ...@@ -16,6 +16,7 @@ cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_bus.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(message_bus.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(carrier.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(carrier.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......
...@@ -46,10 +46,5 @@ std::shared_ptr<Carrier> FleetExecutor::GetCarrier() { ...@@ -46,10 +46,5 @@ std::shared_ptr<Carrier> FleetExecutor::GetCarrier() {
return nullptr; return nullptr;
} }
std::shared_ptr<MessageBus> FleetExecutor::GetMessageBus() {
// get message bus
return nullptr;
}
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -37,14 +37,12 @@ class FleetExecutor final { ...@@ -37,14 +37,12 @@ class FleetExecutor final {
void Run(); void Run();
void Release(); void Release();
static std::shared_ptr<Carrier> GetCarrier(); static std::shared_ptr<Carrier> GetCarrier();
static std::shared_ptr<MessageBus> GetMessageBus();
private: private:
DISABLE_COPY_AND_ASSIGN(FleetExecutor); DISABLE_COPY_AND_ASSIGN(FleetExecutor);
FleetExecutorDesc exe_desc_; FleetExecutorDesc exe_desc_;
std::unique_ptr<RuntimeGraph> runtime_graph_; std::unique_ptr<RuntimeGraph> runtime_graph_;
static std::shared_ptr<Carrier> global_carrier_; static std::shared_ptr<Carrier> global_carrier_;
static std::shared_ptr<MessageBus> global_message_bus_;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
...@@ -27,9 +28,7 @@ Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node) ...@@ -27,9 +28,7 @@ Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node)
Interceptor::~Interceptor() { interceptor_thread_.join(); } Interceptor::~Interceptor() { interceptor_thread_.join(); }
void Interceptor::RegisterInterceptorHandle(InterceptorHandle handle) { void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; }
handle_ = handle;
}
void Interceptor::Handle(const InterceptorMessage& msg) { void Interceptor::Handle(const InterceptorMessage& msg) {
if (handle_) { if (handle_) {
...@@ -61,7 +60,7 @@ void Interceptor::Send(int64_t dst_id, ...@@ -61,7 +60,7 @@ void Interceptor::Send(int64_t dst_id,
std::unique_ptr<InterceptorMessage> msg) { std::unique_ptr<InterceptorMessage> msg) {
msg->set_src_id(interceptor_id_); msg->set_src_id(interceptor_id_);
msg->set_dst_id(dst_id); msg->set_dst_id(dst_id);
// send interceptor msg MessageBus::Instance().Send(*msg.get());
} }
void Interceptor::PoolTheMailbox() { void Interceptor::PoolTheMailbox() {
......
...@@ -34,7 +34,7 @@ class TaskNode; ...@@ -34,7 +34,7 @@ class TaskNode;
class Interceptor { class Interceptor {
public: public:
using InterceptorHandle = std::function<void(const InterceptorMessage&)>; using MsgHandle = std::function<void(const InterceptorMessage&)>;
public: public:
Interceptor() = delete; Interceptor() = delete;
...@@ -44,7 +44,7 @@ class Interceptor { ...@@ -44,7 +44,7 @@ class Interceptor {
virtual ~Interceptor(); virtual ~Interceptor();
// register interceptor handle // register interceptor handle
void RegisterInterceptorHandle(InterceptorHandle handle); void RegisterMsgHandle(MsgHandle handle);
void Handle(const InterceptorMessage& msg); void Handle(const InterceptorMessage& msg);
...@@ -77,7 +77,7 @@ class Interceptor { ...@@ -77,7 +77,7 @@ class Interceptor {
TaskNode* node_; TaskNode* node_;
// interceptor handle which process message // interceptor handle which process message
InterceptorHandle handle_{nullptr}; MsgHandle handle_{nullptr};
// mutex to control read/write conflict for remote mailbox // mutex to control read/write conflict for remote mailbox
std::mutex remote_mailbox_mutex_; std::mutex remote_mailbox_mutex_;
......
...@@ -21,20 +21,28 @@ ...@@ -21,20 +21,28 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
MessageBus::MessageBus( void MessageBus::Init(
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank, const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_map<int64_t, std::string>& rank_to_addr, const std::unordered_map<int64_t, std::string>& rank_to_addr,
const std::string& addr) const std::string& addr) {
: interceptor_id_to_rank_(interceptor_id_to_rank), PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
rank_to_addr_(rank_to_addr), "MessageBus is already init."));
addr_(addr) { is_init_ = true;
interceptor_id_to_rank_ = interceptor_id_to_rank;
rank_to_addr_ = rank_to_addr;
addr_ = addr;
listen_port_thread_ = std::thread([this]() { listen_port_thread_ = std::thread([this]() {
VLOG(3) << "Start listen_port_thread_ for message bus"; VLOG(3) << "Start listen_port_thread_ for message bus";
ListenPort(); ListenPort();
}); });
std::call_once(once_flag_, []() {
std::atexit([]() { MessageBus::Instance().Release(); });
});
} }
MessageBus::~MessageBus() { void MessageBus::Release() {
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL) !defined(PADDLE_WITH_ASCEND_CL)
server_.Stop(1000); server_.Stop(1000);
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <mutex>
#include <string> #include <string>
#include <thread> #include <thread>
#include <unordered_map> #include <unordered_map>
...@@ -35,15 +36,19 @@ namespace distributed { ...@@ -35,15 +36,19 @@ namespace distributed {
class Carrier; class Carrier;
// A singleton MessageBus
class MessageBus final { class MessageBus final {
public: public:
MessageBus() = delete; static MessageBus& Instance() {
static MessageBus msg_bus;
return msg_bus;
}
MessageBus(const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank, void Init(const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_map<int64_t, std::string>& rank_to_addr, const std::unordered_map<int64_t, std::string>& rank_to_addr,
const std::string& addr); const std::string& addr);
~MessageBus(); void Release();
// called by Interceptor, send InterceptorMessage to dst // called by Interceptor, send InterceptorMessage to dst
bool Send(const InterceptorMessage& interceptor_message); bool Send(const InterceptorMessage& interceptor_message);
...@@ -51,6 +56,8 @@ class MessageBus final { ...@@ -51,6 +56,8 @@ class MessageBus final {
DISABLE_COPY_AND_ASSIGN(MessageBus); DISABLE_COPY_AND_ASSIGN(MessageBus);
private: private:
MessageBus() = default;
// function keep listen the port and handle the message // function keep listen the port and handle the message
void ListenPort(); void ListenPort();
...@@ -66,6 +73,9 @@ class MessageBus final { ...@@ -66,6 +73,9 @@ class MessageBus final {
// send the message intra rank (dst is the same rank with src) // send the message intra rank (dst is the same rank with src)
bool SendIntraRank(const InterceptorMessage& interceptor_message); bool SendIntraRank(const InterceptorMessage& interceptor_message);
bool is_init_{false};
std::once_flag once_flag_;
// handed by above layer, save the info mapping interceptor id to rank id // handed by above layer, save the info mapping interceptor id to rank id
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_; std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册