未验证 提交 f49c2c23 编写于 作者: Y Yuang Liu 提交者: GitHub

[fleet_executor] Add sync method (#37167)

上级 1e598f1a
...@@ -33,6 +33,13 @@ bool Carrier::EnqueueInterceptorMessage( ...@@ -33,6 +33,13 @@ bool Carrier::EnqueueInterceptorMessage(
// handle control message // handle control message
return true; return true;
} else { } else {
if (creating_interceptors_) {
// Cannot handle the message to interceptor since interceptors
// are still under creating. Will enqueue into a tmp stack.
VLOG(3) << "Receiving message while creating interceptors.";
message_tmp_.emplace_back(interceptor_message);
return true;
}
int64_t dst_id = interceptor_message.dst_id(); int64_t dst_id = interceptor_message.dst_id();
Interceptor* dst_interceptor = GetInterceptor(dst_id); Interceptor* dst_interceptor = GetInterceptor(dst_id);
bool rst = bool rst =
...@@ -70,16 +77,45 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, ...@@ -70,16 +77,45 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
return ptr; return ptr;
} }
// Updates the flag that records whether interceptors are still being created.
// Clearing the flag (flag == false) additionally hands the messages buffered
// in message_tmp_ over to HandleTmpMessages().
void Carrier::SetCreatingFlag(bool flag) {
  const bool previous = creating_interceptors_;
  VLOG(3) << "Carrier is set the creating flag from " << previous << " to "
          << flag << ".";
  creating_interceptors_ = flag;
  if (flag) {
    // Interceptors are still being created; nothing to flush yet.
    return;
  }
  // Interceptors have been created outside; handle the temporary messages.
  HandleTmpMessages();
}
void Carrier::HandleTmpMessages() {
VLOG(3) << "Carrier has received " << message_tmp_.size()
<< " messages during creating interceptors.";
for (const auto& msg : message_tmp_) {
EnqueueInterceptorMessage(msg);
}
message_tmp_.clear();
}
void Carrier::CreateInterceptors() { void Carrier::CreateInterceptors() {
// create each Interceptor // create each Interceptor
if (!interceptor_id_to_node_.empty()) {
// no auto init since there is no config
for (const auto& item : interceptor_id_to_node_) { for (const auto& item : interceptor_id_to_node_) {
int64_t interceptor_id = item.first; int64_t interceptor_id = item.first;
TaskNode* task_node = item.second; TaskNode* task_node = item.second;
// TODO(wangxi): use node_type to select different Interceptor // TODO(wangxi): use node_type to select different Interceptor
auto interceptor = std::make_unique<Interceptor>(interceptor_id, task_node); auto interceptor =
std::make_unique<Interceptor>(interceptor_id, task_node);
SetInterceptor(interceptor_id, std::move(interceptor)); SetInterceptor(interceptor_id, std::move(interceptor));
VLOG(3) << "Create Interceptor for " << interceptor_id; VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id
<< ".";
}
// If no interceptor is created during auto init, the carrier keeps waiting
// for an outside initializer. Interceptors were created above, so stop waiting.
creating_interceptors_ = false;
HandleTmpMessages();
} }
} }
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h" #include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
...@@ -53,6 +54,8 @@ class Carrier final { ...@@ -53,6 +54,8 @@ class Carrier final {
Interceptor* SetInterceptor(int64_t interceptor_id, Interceptor* SetInterceptor(int64_t interceptor_id,
std::unique_ptr<Interceptor>); std::unique_ptr<Interceptor>);
void SetCreatingFlag(bool flag);
DISABLE_COPY_AND_ASSIGN(Carrier); DISABLE_COPY_AND_ASSIGN(Carrier);
private: private:
...@@ -61,12 +64,17 @@ class Carrier final { ...@@ -61,12 +64,17 @@ class Carrier final {
// create each Interceptor // create each Interceptor
void CreateInterceptors(); void CreateInterceptors();
void HandleTmpMessages();
// interceptor logic id to the Nodes info // interceptor logic id to the Nodes info
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_; std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
// interceptor logic id to actually interceptor // interceptor logic id to actually interceptor
std::unordered_map<int64_t, std::unique_ptr<Interceptor>> std::unordered_map<int64_t, std::unique_ptr<Interceptor>>
interceptor_idx_to_interceptor_; interceptor_idx_to_interceptor_;
std::vector<InterceptorMessage> message_tmp_{};
bool creating_interceptors_{true};
}; };
} // namespace distributed } // namespace distributed
......
...@@ -56,10 +56,10 @@ bool Interceptor::EnqueueRemoteInterceptorMessage( ...@@ -56,10 +56,10 @@ bool Interceptor::EnqueueRemoteInterceptorMessage(
return true; return true;
} }
void Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) { bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
msg.set_src_id(interceptor_id_); msg.set_src_id(interceptor_id_);
msg.set_dst_id(dst_id); msg.set_dst_id(dst_id);
MessageBus::Instance().Send(msg); return MessageBus::Instance().Send(msg);
} }
void Interceptor::PoolTheMailbox() { void Interceptor::PoolTheMailbox() {
...@@ -76,10 +76,12 @@ void Interceptor::PoolTheMailbox() { ...@@ -76,10 +76,12 @@ void Interceptor::PoolTheMailbox() {
const InterceptorMessage interceptor_message = local_mailbox_.front(); const InterceptorMessage interceptor_message = local_mailbox_.front();
local_mailbox_.pop(); local_mailbox_.pop();
const MessageType message_type = interceptor_message.message_type(); const MessageType message_type = interceptor_message.message_type();
VLOG(3) << interceptor_id_ << " has received a message: " << message_type VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message"
<< "."; << " from interceptor " << interceptor_message.src_id()
<< " with message: " << message_type << ".";
if (message_type == STOP) { if (message_type == STOP) {
// break the pooling thread // break the pooling thread
VLOG(3) << "Interceptor " << interceptor_id_ << " is quiting.";
break; break;
} }
......
...@@ -58,7 +58,7 @@ class Interceptor { ...@@ -58,7 +58,7 @@ class Interceptor {
bool EnqueueRemoteInterceptorMessage( bool EnqueueRemoteInterceptorMessage(
const InterceptorMessage& interceptor_message); const InterceptorMessage& interceptor_message);
void Send(int64_t dst_id, InterceptorMessage& msg); // NOLINT bool Send(int64_t dst_id, InterceptorMessage& msg); // NOLINT
DISABLE_COPY_AND_ASSIGN(Interceptor); DISABLE_COPY_AND_ASSIGN(Interceptor);
......
...@@ -26,8 +26,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService( ...@@ -26,8 +26,8 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
const InterceptorMessage* request, InterceptorResponse* response, const InterceptorMessage* request, InterceptorResponse* response,
google::protobuf::Closure* done) { google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done); brpc::ClosureGuard done_guard(done);
VLOG(3) << "Interceptor Message Service receives a message from: " VLOG(3) << "Interceptor Message Service receives a message from interceptor "
<< request->src_id() << request->src_id() << " to interceptor " << request->dst_id()
<< ", with the message: " << request->message_type(); << ", with the message: " << request->message_type();
response->set_rst(true); response->set_rst(true);
// call interceptor manager's method to handle the message // call interceptor manager's method to handle the message
......
...@@ -55,12 +55,13 @@ bool MessageBus::Send(const InterceptorMessage& interceptor_message) { ...@@ -55,12 +55,13 @@ bool MessageBus::Send(const InterceptorMessage& interceptor_message) {
int64_t src_id = interceptor_message.src_id(); int64_t src_id = interceptor_message.src_id();
int64_t dst_id = interceptor_message.dst_id(); int64_t dst_id = interceptor_message.dst_id();
if (IsSameRank(src_id, dst_id)) { if (IsSameRank(src_id, dst_id)) {
VLOG(3) << "Send a message from rank " << src_id << " to rank " << dst_id VLOG(3) << "Send a message from interceptor " << src_id
<< ", which are same ranks."; << " to interceptor " << dst_id << ", which are in the same ranks.";
return SendIntraRank(interceptor_message); return SendIntraRank(interceptor_message);
} else { } else {
VLOG(3) << "Send a message from rank " << src_id << " to rank " << dst_id VLOG(3) << "Send a message from interceptor " << src_id
<< ", which are different ranks."; << " to interceptor " << dst_id
<< ", which are in different ranks.";
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL) !defined(PADDLE_WITH_ASCEND_CL)
int retry_time = 0; // message bus will retry sending for 10 times int retry_time = 0; // message bus will retry sending for 10 times
...@@ -155,6 +156,7 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) { ...@@ -155,6 +156,7 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) {
"Cannot find rank for dst interceptor id %lld. " "Cannot find rank for dst interceptor id %lld. "
"Init error.", "Init error.",
dst_id)); dst_id));
VLOG(3) << "Message bus sending to addr: " << dst_ip->second;
const char* dst_ip_for_brpc = dst_ip->second.c_str(); const char* dst_ip_for_brpc = dst_ip->second.c_str();
brpc::Channel channel; brpc::Channel channel;
brpc::ChannelOptions options; brpc::ChannelOptions options;
......
...@@ -59,8 +59,8 @@ TEST(InterceptorTest, PingPong) { ...@@ -59,8 +59,8 @@ TEST(InterceptorTest, PingPong) {
Interceptor* a = carrier.SetInterceptor( Interceptor* a = carrier.SetInterceptor(
0, std::make_unique<PingPongInterceptor>(0, nullptr)); 0, std::make_unique<PingPongInterceptor>(0, nullptr));
carrier.SetInterceptor(1, std::make_unique<PingPongInterceptor>(1, nullptr)); carrier.SetInterceptor(1, std::make_unique<PingPongInterceptor>(1, nullptr));
carrier.SetCreatingFlag(false);
InterceptorMessage msg; InterceptorMessage msg;
a->Send(1, msg); a->Send(1, msg);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册