未验证 提交 cd2855b0 编写于 作者: L LiYuRio 提交者: GitHub

[fleet_executor] Add barrier rpc (#38799)

上级 492e6dd0
......@@ -13,7 +13,7 @@ endif()
cc_library(task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog)
cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto task_loop_thread_pool collective_helper
op_registry executor_gc_helper gflags glog ${BRPC_DEPS})
......@@ -29,8 +29,8 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(fleet_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(carrier.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(interceptor_message_service.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(interceptor_message_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_service.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
add_subdirectory(test)
endif()
......@@ -15,7 +15,6 @@
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
......
......@@ -137,7 +137,6 @@ void FleetExecutor::Run(const std::string& carrier_id) {
// Set current running carrier
if (*GlobalVal<std::string>::Get() != carrier_id) {
GlobalVal<std::string>::Set(new std::string(carrier_id));
// TODO(liyurui): Move barrier to service
GlobalVal<MessageBus>::Get()->Barrier();
}
carrier->Start();
......
......@@ -34,7 +34,8 @@ message InterceptorMessage {
message InterceptorResponse { optional bool rst = 1 [ default = false ]; }
service TheInterceptorMessageService {
rpc InterceptorMessageService(InterceptorMessage)
service MessageService {
rpc ReceiveInterceptorMessage(InterceptorMessage)
returns (InterceptorResponse);
rpc IncreaseBarrierCount(InterceptorMessage) returns (InterceptorResponse);
}
......@@ -163,18 +163,9 @@ void MessageBus::Barrier() {
bool MessageBus::DispatchMsgToCarrier(
const InterceptorMessage& interceptor_message) {
if (interceptor_message.ctrl_message()) {
VLOG(3) << "Receiving control message from rank "
<< interceptor_message.src_id() << " to rank "
<< interceptor_message.dst_id();
// for barrier
IncreaseBarrierCount();
return true;
} else {
const std::string& carrier_id = *GlobalVal<std::string>::Get();
return GlobalMap<std::string, Carrier>::Get(carrier_id)
->EnqueueInterceptorMessage(interceptor_message);
}
const std::string& carrier_id = *GlobalVal<std::string>::Get();
return GlobalMap<std::string, Carrier>::Get(carrier_id)
->EnqueueInterceptorMessage(interceptor_message);
}
void MessageBus::ListenPort() {
......@@ -185,10 +176,9 @@ void MessageBus::ListenPort() {
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
// function keep listen the port and handle the message
PADDLE_ENFORCE_EQ(server_.AddService(&interceptor_message_service_,
brpc::SERVER_DOESNT_OWN_SERVICE),
0, platform::errors::Unavailable(
"Message bus: init brpc service error."));
PADDLE_ENFORCE_EQ(
server_.AddService(&message_service_, brpc::SERVER_DOESNT_OWN_SERVICE), 0,
platform::errors::Unavailable("Message bus: init brpc service error."));
// start the server
const char* ip_for_brpc = addr_.c_str();
......@@ -229,11 +219,16 @@ bool MessageBus::SendInterRank(int64_t dst_rank,
PADDLE_ENFORCE_EQ(
channel.Init(dst_addr_for_brpc, &options), 0,
platform::errors::Unavailable("Message bus: init brpc channel error."));
TheInterceptorMessageService_Stub stub(&channel);
MessageService_Stub stub(&channel);
InterceptorResponse response;
brpc::Controller ctrl;
ctrl.set_log_id(0);
stub.InterceptorMessageService(&ctrl, &interceptor_message, &response, NULL);
if (interceptor_message.ctrl_message()) {
stub.IncreaseBarrierCount(&ctrl, &interceptor_message, &response, NULL);
} else {
stub.ReceiveInterceptorMessage(&ctrl, &interceptor_message, &response,
NULL);
}
if (!ctrl.Failed()) {
if (response.rst()) {
VLOG(3) << "Message bus: brpc sends success.";
......@@ -248,6 +243,7 @@ bool MessageBus::SendInterRank(int64_t dst_rank,
return false;
}
}
#endif
} // namespace distributed
......
......@@ -24,7 +24,7 @@
!defined(PADDLE_WITH_ASCEND_CL)
#include "brpc/channel.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_service.h"
#endif
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
......@@ -83,7 +83,7 @@ class MessageBus final {
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
InterceptorMessageServiceImpl interceptor_message_service_;
MessageServiceImpl message_service_;
// brpc server
brpc::Server server_;
#endif
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_service.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
......@@ -21,18 +21,29 @@
namespace paddle {
namespace distributed {
void InterceptorMessageServiceImpl::InterceptorMessageService(
void MessageServiceImpl::ReceiveInterceptorMessage(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response,
google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
VLOG(3) << "Interceptor Message Service receives a message from interceptor "
VLOG(3) << "Message Service receives a message from interceptor "
<< request->src_id() << " to interceptor " << request->dst_id()
<< ", with the message: " << request->message_type();
bool flag = GlobalVal<MessageBus>::Get()->DispatchMsgToCarrier(*request);
response->set_rst(flag);
}
void MessageServiceImpl::IncreaseBarrierCount(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response,
google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
VLOG(3) << "Barrier Service receives a message from rank "
<< request->src_id() << " to rank " << request->dst_id();
GlobalVal<MessageBus>::Get()->IncreaseBarrierCount();
response->set_rst(true);
}
} // namespace distributed
} // namespace paddle
#endif
......@@ -21,11 +21,15 @@
namespace paddle {
namespace distributed {
class InterceptorMessageServiceImpl : public TheInterceptorMessageService {
class MessageServiceImpl : public MessageService {
public:
InterceptorMessageServiceImpl() {}
virtual ~InterceptorMessageServiceImpl() {}
virtual void InterceptorMessageService(
MessageServiceImpl() {}
virtual ~MessageServiceImpl() {}
virtual void ReceiveInterceptorMessage(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response,
google::protobuf::Closure* done);
virtual void IncreaseBarrierCount(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response,
google::protobuf::Closure* done);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册