未验证 提交 8cdd5564 编写于 作者: W WangXi 提交者: GitHub

[fleet_executor] interceptor send message through message_bus (#37106)

上级 f5e7b02a
......@@ -16,6 +16,7 @@ cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc
if(WITH_DISTRIBUTE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_bus.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(carrier.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......
......@@ -46,10 +46,5 @@ std::shared_ptr<Carrier> FleetExecutor::GetCarrier() {
return nullptr;
}
std::shared_ptr<MessageBus> FleetExecutor::GetMessageBus() {
// get message bus
return nullptr;
}
} // namespace distributed
} // namespace paddle
......@@ -37,14 +37,12 @@ class FleetExecutor final {
void Run();
void Release();
static std::shared_ptr<Carrier> GetCarrier();
static std::shared_ptr<MessageBus> GetMessageBus();
private:
DISABLE_COPY_AND_ASSIGN(FleetExecutor);
FleetExecutorDesc exe_desc_;
std::unique_ptr<RuntimeGraph> runtime_graph_;
static std::shared_ptr<Carrier> global_carrier_;
static std::shared_ptr<MessageBus> global_message_bus_;
};
} // namespace distributed
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
namespace paddle {
namespace distributed {
......@@ -27,9 +28,7 @@ Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node)
Interceptor::~Interceptor() { interceptor_thread_.join(); }
void Interceptor::RegisterInterceptorHandle(InterceptorHandle handle) {
handle_ = handle;
}
void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; }
void Interceptor::Handle(const InterceptorMessage& msg) {
if (handle_) {
......@@ -61,7 +60,7 @@ void Interceptor::Send(int64_t dst_id,
std::unique_ptr<InterceptorMessage> msg) {
msg->set_src_id(interceptor_id_);
msg->set_dst_id(dst_id);
// send interceptor msg
MessageBus::Instance().Send(*msg.get());
}
void Interceptor::PoolTheMailbox() {
......
......@@ -34,7 +34,7 @@ class TaskNode;
class Interceptor {
public:
using InterceptorHandle = std::function<void(const InterceptorMessage&)>;
using MsgHandle = std::function<void(const InterceptorMessage&)>;
public:
Interceptor() = delete;
......@@ -44,7 +44,7 @@ class Interceptor {
virtual ~Interceptor();
// register interceptor handle
void RegisterInterceptorHandle(InterceptorHandle handle);
void RegisterMsgHandle(MsgHandle handle);
void Handle(const InterceptorMessage& msg);
......@@ -77,7 +77,7 @@ class Interceptor {
TaskNode* node_;
// interceptor handle which process message
InterceptorHandle handle_{nullptr};
MsgHandle handle_{nullptr};
// mutex to control read/write conflict for remote mailbox
std::mutex remote_mailbox_mutex_;
......
......@@ -21,20 +21,28 @@
namespace paddle {
namespace distributed {
MessageBus::MessageBus(
void MessageBus::Init(
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_map<int64_t, std::string>& rank_to_addr,
const std::string& addr)
: interceptor_id_to_rank_(interceptor_id_to_rank),
rank_to_addr_(rank_to_addr),
addr_(addr) {
const std::string& addr) {
PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
"MessageBus is already init."));
is_init_ = true;
interceptor_id_to_rank_ = interceptor_id_to_rank;
rank_to_addr_ = rank_to_addr;
addr_ = addr;
listen_port_thread_ = std::thread([this]() {
VLOG(3) << "Start listen_port_thread_ for message bus";
ListenPort();
});
std::call_once(once_flag_, []() {
std::atexit([]() { MessageBus::Instance().Release(); });
});
}
MessageBus::~MessageBus() {
void MessageBus::Release() {
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
server_.Stop(1000);
......
......@@ -14,6 +14,7 @@
#pragma once
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
......@@ -35,15 +36,19 @@ namespace distributed {
class Carrier;
// A singleton MessageBus
class MessageBus final {
public:
MessageBus() = delete;
static MessageBus& Instance() {
static MessageBus msg_bus;
return msg_bus;
}
MessageBus(const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_map<int64_t, std::string>& rank_to_addr,
const std::string& addr);
void Init(const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_map<int64_t, std::string>& rank_to_addr,
const std::string& addr);
~MessageBus();
void Release();
// called by Interceptor, send InterceptorMessage to dst
bool Send(const InterceptorMessage& interceptor_message);
......@@ -51,6 +56,8 @@ class MessageBus final {
DISABLE_COPY_AND_ASSIGN(MessageBus);
private:
MessageBus() = default;
// function keep listen the port and handle the message
void ListenPort();
......@@ -66,6 +73,9 @@ class MessageBus final {
// send the message intra rank (dst is the same rank with src)
bool SendIntraRank(const InterceptorMessage& interceptor_message);
bool is_init_{false};
std::once_flag once_flag_;
// handed by above layer, save the info mapping interceptor id to rank id
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册