未验证 提交 072e7801 编写于 作者: Y Yuang Liu 提交者: GitHub

[fleet_executor] Implementation of the message bus, the carrier and part of...

[fleet_executor] Implementation of the message bus, the carrier and part of the interceptor (#37049)
上级 f0c77378
......@@ -21,22 +21,57 @@ namespace paddle {
namespace distributed {
Carrier::Carrier(
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node) {
// init
}
Carrier::~Carrier() {
// destroy
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node)
: interceptor_id_to_node_(interceptor_id_to_node) {
CreateInterceptors();
}
bool Carrier::EnqueueInterceptorMessage(
const InterceptorMessage& interceptor_message) {
// enqueue message to interceptor
if (interceptor_message.ctrl_message()) {
// handle control message
return true;
} else {
int64_t dst_id = interceptor_message.dst_id();
Interceptor* dst_interceptor = GetInterceptor(dst_id);
bool rst =
dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message);
if (rst) {
std::condition_variable& interceptor_cond_var =
dst_interceptor->GetCondVar();
interceptor_cond_var.notify_all();
}
return rst;
}
}
Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
auto iter = interceptor_idx_to_interceptor_.find(interceptor_id);
PADDLE_ENFORCE_NE(iter, interceptor_idx_to_interceptor_.end(),
platform::errors::InvalidArgument(
"Cannot find interceptor instance for interceptor "
"id %lld. Wrong dst? Call before init?",
interceptor_id));
return iter->second.get();
}
void Carrier::CreateInterceptors() {
// create each Interceptor
for (const auto& item : interceptor_id_to_node_) {
int64_t interceptor_id = item.first;
TaskNode* task_node = item.second;
const auto& iter = interceptor_idx_to_interceptor_.find(interceptor_id);
PADDLE_ENFORCE_EQ(iter, interceptor_idx_to_interceptor_.end(),
platform::errors::AlreadyExists(
"The interceptor id %lld has already been created! "
"The interceptor is should be unique.",
interceptor_id));
interceptor_idx_to_interceptor_.insert(std::make_pair(
interceptor_id,
std::make_unique<Interceptor>(interceptor_id, task_node)));
VLOG(3) << "Create Interceptor for " << interceptor_id;
}
}
} // namespace distributed
......
......@@ -19,6 +19,8 @@
#include <unordered_map>
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
......@@ -32,9 +34,10 @@ class Carrier final {
public:
Carrier() = delete;
Carrier(const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node);
explicit Carrier(
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node);
~Carrier();
~Carrier() = default;
// Enqueue a message to corresponding interceptor id
bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message);
......
......@@ -17,28 +17,73 @@
namespace paddle {
namespace distributed {
Interceptor::Interceptor(int64_t interceptor_id_, TaskNode* node) {
// init
Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node)
: interceptor_id_(interceptor_id), node_(node) {
interceptor_thread_ = std::thread([this]() {
VLOG(3) << "Start pooling local mailbox's thread.";
PoolTheMailbox();
});
}
Interceptor::~Interceptor() { interceptor_thread_.join(); }
std::condition_variable& Interceptor::GetCondVar() {
// get the conditional var
return cond_var_;
}
int64_t Interceptor::GetInterceptorId() const {
// return the interceptor id
return 0;
return interceptor_id_;
}
bool Interceptor::EnqueueRemoteInterceptorMessage(
const InterceptorMessage& interceptor_message) {
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
VLOG(3) << "Enqueue message: " << interceptor_message.message_type()
<< " into " << interceptor_id_ << "'s remote mailbox.";
remote_mailbox_mutex_.lock();
remote_mailbox_.push(interceptor_message);
remote_mailbox_mutex_.unlock();
return true;
}
void Interceptor::PoolTheMailbox() {
// pool the local mailbox, parse the Message
while (true) {
if (local_mailbox_.empty()) {
// local mailbox is empty, fetch the remote mailbox
VLOG(3) << interceptor_id_ << "'s local mailbox is empty. "
<< "Fetch the remote mailbox.";
PADDLE_ENFORCE_EQ(FetchRemoteMailbox(), true,
platform::errors::InvalidArgument(
"Error encountered when fetch remote mailbox."));
}
const InterceptorMessage interceptor_message = local_mailbox_.front();
local_mailbox_.pop();
const MessageType message_type = interceptor_message.message_type();
VLOG(3) << interceptor_id_ << " has received a message: " << message_type
<< ".";
if (message_type == STOP) {
// break the pooling thread
break;
}
}
}
bool Interceptor::FetchRemoteMailbox() {
// fetch all Message from remote mailbox to local mailbox
// return true if remote mailbox not empty, otherwise return false
std::unique_lock<std::mutex> lock(remote_mailbox_mutex_);
cond_var_.wait(lock, [this]() { return !remote_mailbox_.empty(); });
if (remote_mailbox_.empty()) {
// the thread has been unblocked accidentally
return false;
}
while (!remote_mailbox_.empty()) {
local_mailbox_.push(std::move(remote_mailbox_.front()));
remote_mailbox_.pop();
}
return true;
}
......
......@@ -22,6 +22,8 @@
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
......@@ -33,13 +35,16 @@ class Interceptor {
public:
Interceptor() = delete;
Interceptor(int64_t interceptor_id_, TaskNode* node);
Interceptor(int64_t interceptor_id, TaskNode* node);
virtual ~Interceptor() = default;
virtual ~Interceptor();
// return the interceptor id
int64_t GetInterceptorId() const;
// return the conditional var
std::condition_variable& GetCondVar();
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
bool EnqueueRemoteInterceptorMessage(
const InterceptorMessage& interceptor_message);
......
......@@ -11,9 +11,12 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_ASCEND_CL
#ifdef PADDLE_WITH_DISTRIBUTE
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
namespace paddle {
namespace distributed {
......@@ -22,10 +25,18 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response,
google::protobuf::Closure* done) {
// receive msg
brpc::ClosureGuard done_guard(done);
VLOG(3) << "Interceptor Message Service receives a message from: "
<< request->src_id()
<< ", with the message: " << request->message_type();
response->set_rst(true);
// call interceptor manager's method to handle the message
std::shared_ptr<Carrier> carrier = FleetExecutor::GetCarrier();
if (carrier != nullptr) {
carrier->EnqueueInterceptorMessage(*request);
}
}
} // namespace distributed
} // namespace paddle
#endif
#endif
......@@ -11,8 +11,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_ASCEND_CL
#ifdef PADDLE_WITH_DISTRIBUTE
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
#pragma once
#include "brpc/server.h"
......@@ -34,4 +34,3 @@ class InterceptorMessageServiceImpl : public TheInterceptorMessageService {
} // namespace distributed
} // namespace paddle
#endif
#endif
......@@ -12,41 +12,160 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include <memory>
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
namespace paddle {
namespace distributed {
MessageBus::MessageBus(
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_map<int64_t, std::string>& rank_to_addr,
const std::string& addr)
: interceptor_id_to_rank_(interceptor_id_to_rank),
rank_to_addr_(rank_to_addr),
addr_(addr) {
listen_port_thread_ = std::thread([this]() {
VLOG(3) << "Start listen_port_thread_ for message bus";
ListenPort();
});
}
MessageBus::~MessageBus() {
// destroy
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
server_.Stop(1000);
server_.Join();
#endif
listen_port_thread_.join();
}
bool MessageBus::Send(const InterceptorMessage& interceptor_message) {
// called by Interceptor, send InterceptorMessage to dst
int64_t src_id = interceptor_message.src_id();
int64_t dst_id = interceptor_message.dst_id();
if (IsSameRank(src_id, dst_id)) {
VLOG(3) << "Send a message from: " << src_id << " to " << dst_id
<< " within a same rank.";
return SendIntraRank(interceptor_message);
} else {
VLOG(3) << "Send a message from: " << src_id << " to " << dst_id
<< " between different ranks.";
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
return SendInterRank(interceptor_message);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Fleet executor does not support sending message between different "
"ranks when Paddle is compiled with npu or "
"isn't compiled with distributed for now."));
#endif
}
return true;
}
void MessageBus::ListenPort() {
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
// function keep listen the port and handle the message
InterceptorMessageServiceImpl interceptor_message_service;
PADDLE_ENFORCE_EQ(server_.AddService(&interceptor_message_service,
brpc::SERVER_DOESNT_OWN_SERVICE),
0, platform::errors::Unavailable(
"Message bus: init brpc service error."));
// start the server
const char* ip_for_brpc = addr_.c_str();
brpc::ServerOptions options;
options.idle_timeout_sec = -1;
PADDLE_ENFORCE_EQ(
server_.Start(ip_for_brpc, &options), 0,
platform::errors::Unavailable("Message bus: start brpc service error."));
VLOG(3) << "Message bus's listen port thread starts successful.";
#else
VLOG(3) << "Fleet executor's ListenPort() is a fake function when Paddle is "
"compiled with npu or Paddle isn't compiled "
"with distributed for now.";
#endif
}
bool MessageBus::IsSameRank(int64_t src_id, int64_t dst_id) {
// check whether the dst is the same rank or different rank with src
return true;
const auto& src_rank = interceptor_id_to_rank_.find(src_id);
const auto& dst_rank = interceptor_id_to_rank_.find(dst_id);
PADDLE_ENFORCE_NE(
src_rank, interceptor_id_to_rank_.end(),
platform::errors::NotFound(
"Cannot find rank for src interceptor id %lld. Init error.", src_id));
PADDLE_ENFORCE_NE(
dst_rank, interceptor_id_to_rank_.end(),
platform::errors::NotFound(
"Cannot find rank for dst interceptor id %lld. Init error.", dst_id));
const auto& src_ip = rank_to_addr_.find(src_rank->second);
PADDLE_ENFORCE_NE(src_ip, rank_to_addr_.end(),
platform::errors::NotFound(
"Cannot find addr for src rank id %lld. Init error.",
src_rank->second));
PADDLE_ENFORCE_EQ(
src_ip->second, addr_,
platform::errors::Fatal("The src interceptor's addr is %s, while the "
"message bus's addr is %s, which are different. "
"Init error.",
src_ip->second, addr_));
return src_rank->second == dst_rank->second;
}
#ifndef PADDLE_WITH_ASCEND_CL
#ifdef PADDLE_WITH_DISTRIBUTE
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
// send the message inter rank (dst is different rank with src)
int64_t dst_id = interceptor_message.dst_id();
int64_t dst_rank = interceptor_id_to_rank_[dst_id];
auto dst_ip = rank_to_addr_.find(dst_rank);
PADDLE_ENFORCE_NE(dst_ip, rank_to_addr_.end(),
platform::errors::InvalidArgument(
"Cannot find rank for dst interceptor id %lld. "
"Init error.",
dst_id));
const char* dst_ip_for_brpc = dst_ip->second.c_str();
brpc::Channel channel;
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.timeout_ms = 1000;
options.max_retry = 5;
PADDLE_ENFORCE_EQ(
channel.Init(dst_ip_for_brpc, &options), 0,
platform::errors::Unavailable("Message bus: init brpc channel error."));
TheInterceptorMessageService_Stub stub(&channel);
InterceptorResponse response;
brpc::Controller ctrl;
ctrl.set_log_id(0);
stub.InterceptorMessageService(&ctrl, &interceptor_message, &response, NULL);
if (!ctrl.Failed()) {
if (response.rst()) {
VLOG(3) << "Message bus: brpc sends success.";
return true;
} else {
VLOG(3) << "Message bus: InterceptorMessageService error.";
return false;
}
} else {
VLOG(3) << "Message bus: brpc sends failed with error text: "
<< ctrl.ErrorText();
return false;
}
}
#endif
#endif
bool MessageBus::SendIntraRank(const InterceptorMessage& interceptor_message) {
// send the message intra rank (dst is the same rank with src)
std::shared_ptr<Carrier> carrier = FleetExecutor::GetCarrier();
if (carrier != nullptr) {
return carrier->EnqueueInterceptorMessage(interceptor_message);
}
return true;
}
......
......@@ -18,14 +18,16 @@
#include <thread>
#include <unordered_map>
#ifndef PADDLE_WITH_ASCEND_CL
#ifdef PADDLE_WITH_DISTRIBUTE
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
#include "brpc/channel.h"
#include "brpc/server.h"
#endif
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#endif
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
......@@ -37,13 +39,9 @@ class MessageBus final {
public:
MessageBus() = delete;
explicit MessageBus(
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
MessageBus(const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_map<int64_t, std::string>& rank_to_addr,
const std::string& addr)
: interceptor_id_to_rank_(interceptor_id_to_rank),
rank_to_addr_(rank_to_addr),
addr_(addr) {}
const std::string& addr);
~MessageBus();
......@@ -59,11 +57,10 @@ class MessageBus final {
// check whether the dst is the same rank or different rank with src
bool IsSameRank(int64_t src_id, int64_t dst_id);
#ifndef PADDLE_WITH_ASCEND_CL
#ifdef PADDLE_WITH_DISTRIBUTE
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
// send the message inter rank (dst is different rank with src)
bool SendInterRank(const InterceptorMessage& interceptor_message);
#endif
#endif
// send the message intra rank (dst is the same rank with src)
......@@ -78,11 +75,10 @@ class MessageBus final {
// the ip needs to be listened
std::string addr_;
#ifndef PADDLE_WITH_ASCEND_CL
#ifdef PADDLE_WITH_DISTRIBUTE
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
// brpc server
brpc::Server server_;
#endif
#endif
// thread keeps listening to the port to receive remote message
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册