未验证 提交 f49c2c23 编写于 作者: Y Yuang Liu 提交者: GitHub

[fleet_executor] Add sync method (#37167)

上级 1e598f1a
......@@ -33,6 +33,13 @@ bool Carrier::EnqueueInterceptorMessage(
// handle control message
return true;
} else {
if (creating_interceptors_) {
// Cannot deliver the message yet because the interceptors are still
// being created. Buffer it in a temporary list until creation finishes.
VLOG(3) << "Receiving message while creating interceptors.";
message_tmp_.emplace_back(interceptor_message);
return true;
}
int64_t dst_id = interceptor_message.dst_id();
Interceptor* dst_interceptor = GetInterceptor(dst_id);
bool rst =
......@@ -70,16 +77,45 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
return ptr;
}
// Records whether interceptors are still being created. Once the flag is
// cleared, every message buffered in the meantime is replayed through
// HandleTmpMessages().
void Carrier::SetCreatingFlag(bool flag) {
  const bool previous = creating_interceptors_;
  VLOG(3) << "Carrier is set the creating flag from " << previous << " to "
          << flag << ".";
  creating_interceptors_ = flag;
  if (flag) {
    // Creation is (still) in progress: keep buffering, nothing to replay.
    return;
  }
  // Interceptors were created outside; deliver the buffered messages now.
  HandleTmpMessages();
}
void Carrier::HandleTmpMessages() {
VLOG(3) << "Carrier has received " << message_tmp_.size()
<< " messages during creating interceptors.";
for (const auto& msg : message_tmp_) {
EnqueueInterceptorMessage(msg);
}
message_tmp_.clear();
}
// Builds one Interceptor per entry of interceptor_id_to_node_ and registers it
// through SetInterceptor(). When at least one entry exists, the creating flag
// is cleared afterwards and the messages buffered so far are replayed.
void Carrier::CreateInterceptors() {
// Create each Interceptor.
if (!interceptor_id_to_node_.empty()) {
// Auto init only runs when task nodes were configured; with an empty map
// nothing is created here.
for (const auto& item : interceptor_id_to_node_) {
int64_t interceptor_id = item.first;
TaskNode* task_node = item.second;
// TODO(wangxi): use node_type to select different Interceptor
// NOTE(review): the next two declarations look like the before/after
// versions of one reformatted line from a diff view; confirm that only one
// of them belongs in the file.
auto interceptor = std::make_unique<Interceptor>(interceptor_id, task_node);
auto interceptor =
std::make_unique<Interceptor>(interceptor_id, task_node);
SetInterceptor(interceptor_id, std::move(interceptor));
// NOTE(review): the two log statements below likewise look like the old and
// new wording of the same message; confirm which one is kept.
VLOG(3) << "Create Interceptor for " << interceptor_id;
VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id
<< ".";
}
// Auto init is done: clear the creating flag and replay buffered messages.
// When the map is empty this branch is skipped and the flag is left
// untouched here, so an outside caller has to clear it via SetCreatingFlag().
creating_interceptors_ = false;
HandleTmpMessages();
}
}
......
......@@ -17,6 +17,7 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
......@@ -53,6 +54,8 @@ class Carrier final {
Interceptor* SetInterceptor(int64_t interceptor_id,
std::unique_ptr<Interceptor>);
void SetCreatingFlag(bool flag);
DISABLE_COPY_AND_ASSIGN(Carrier);
private:
......@@ -61,12 +64,17 @@ class Carrier final {
// create each Interceptor
void CreateInterceptors();
void HandleTmpMessages();
// interceptor logic id to the Nodes info
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
// interceptor logic id to the actual interceptor
std::unordered_map<int64_t, std::unique_ptr<Interceptor>>
interceptor_idx_to_interceptor_;
std::vector<InterceptorMessage> message_tmp_{};
bool creating_interceptors_{true};
};
} // namespace distributed
......
......@@ -56,10 +56,10 @@ bool Interceptor::EnqueueRemoteInterceptorMessage(
return true;
}
// Stamps `msg` with this interceptor's id as the source and `dst_id` as the
// destination, then passes it to MessageBus::Instance().Send().
// NOTE(review): the paired signature lines and the paired Send calls below
// look like the before/after versions from a diff view (void -> bool, with the
// bus result returned); confirm that only the newer pair is kept.
void Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
msg.set_src_id(interceptor_id_);
msg.set_dst_id(dst_id);
MessageBus::Instance().Send(msg);
return MessageBus::Instance().Send(msg);
}
void Interceptor::PoolTheMailbox() {
......@@ -76,10 +76,12 @@ void Interceptor::PoolTheMailbox() {
const InterceptorMessage interceptor_message = local_mailbox_.front();
local_mailbox_.pop();
const MessageType message_type = interceptor_message.message_type();
VLOG(3) << interceptor_id_ << " has received a message: " << message_type
<< ".";
VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message"
<< " from interceptor " << interceptor_message.src_id()
<< " with message: " << message_type << ".";
if (message_type == STOP) {
// break the pooling thread
VLOG(3) << "Interceptor " << interceptor_id_ << " is quiting.";
break;
}
......
......@@ -58,7 +58,7 @@ class Interceptor {
bool EnqueueRemoteInterceptorMessage(
const InterceptorMessage& interceptor_message);
void Send(int64_t dst_id, InterceptorMessage& msg); // NOLINT
bool Send(int64_t dst_id, InterceptorMessage& msg); // NOLINT
DISABLE_COPY_AND_ASSIGN(Interceptor);
......
......@@ -26,8 +26,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
const InterceptorMessage* request, InterceptorResponse* response,
google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
VLOG(3) << "Interceptor Message Service receives a message from: "
<< request->src_id()
VLOG(3) << "Interceptor Message Service receives a message from interceptor "
<< request->src_id() << " to interceptor " << request->dst_id()
<< ", with the message: " << request->message_type();
response->set_rst(true);
// call interceptor manager's method to handle the message
......
......@@ -55,12 +55,13 @@ bool MessageBus::Send(const InterceptorMessage& interceptor_message) {
int64_t src_id = interceptor_message.src_id();
int64_t dst_id = interceptor_message.dst_id();
if (IsSameRank(src_id, dst_id)) {
VLOG(3) << "Send a message from rank " << src_id << " to rank " << dst_id
<< ", which are same ranks.";
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id << ", which are in the same ranks.";
return SendIntraRank(interceptor_message);
} else {
VLOG(3) << "Send a message from rank " << src_id << " to rank " << dst_id
<< ", which are different ranks.";
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id
<< ", which are in different ranks.";
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
int retry_time = 0; // message bus will retry sending for 10 times
......@@ -155,6 +156,7 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
"Cannot find rank for dst interceptor id %lld. "
"Init error.",
dst_id));
VLOG(3) << "Message bus sending to addr: " << dst_ip->second;
const char* dst_ip_for_brpc = dst_ip->second.c_str();
brpc::Channel channel;
brpc::ChannelOptions options;
......
......@@ -59,8 +59,8 @@ TEST(InterceptorTest, PingPong) {
Interceptor* a = carrier.SetInterceptor(
0, std::make_unique<PingPongInterceptor>(0, nullptr));
carrier.SetInterceptor(1, std::make_unique<PingPongInterceptor>(1, nullptr));
carrier.SetCreatingFlag(false);
InterceptorMessage msg;
a->Send(1, msg);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册