未验证 提交 cd2855b0 编写于 作者: L LiYuRio 提交者: GitHub

[fleet_executor] Add barrier rpc (#38799)

上级 492e6dd0
...@@ -13,7 +13,7 @@ endif() ...@@ -13,7 +13,7 @@ endif()
cc_library(task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog) cc_library(task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog)
cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc interceptor.cc compute_interceptor.cc amplifier_interceptor.cc message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto task_loop_thread_pool collective_helper DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto task_loop_thread_pool collective_helper
op_registry executor_gc_helper gflags glog ${BRPC_DEPS}) op_registry executor_gc_helper gflags glog ${BRPC_DEPS})
...@@ -29,8 +29,8 @@ if(WITH_DISTRIBUTE) ...@@ -29,8 +29,8 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(fleet_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(fleet_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(carrier.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(carrier.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(interceptor_message_service.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(message_service.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(interceptor_message_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(message_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
add_subdirectory(test) add_subdirectory(test)
endif() endif()
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h" #include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h" #include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
......
...@@ -137,7 +137,6 @@ void FleetExecutor::Run(const std::string& carrier_id) { ...@@ -137,7 +137,6 @@ void FleetExecutor::Run(const std::string& carrier_id) {
// Set current running carrier // Set current running carrier
if (*GlobalVal<std::string>::Get() != carrier_id) { if (*GlobalVal<std::string>::Get() != carrier_id) {
GlobalVal<std::string>::Set(new std::string(carrier_id)); GlobalVal<std::string>::Set(new std::string(carrier_id));
// TODO(liyurui): Move barrier to service
GlobalVal<MessageBus>::Get()->Barrier(); GlobalVal<MessageBus>::Get()->Barrier();
} }
carrier->Start(); carrier->Start();
......
...@@ -34,7 +34,8 @@ message InterceptorMessage { ...@@ -34,7 +34,8 @@ message InterceptorMessage {
message InterceptorResponse { optional bool rst = 1 [ default = false ]; } message InterceptorResponse { optional bool rst = 1 [ default = false ]; }
service TheInterceptorMessageService { service MessageService {
rpc InterceptorMessageService(InterceptorMessage) rpc ReceiveInterceptorMessage(InterceptorMessage)
returns (InterceptorResponse); returns (InterceptorResponse);
rpc IncreaseBarrierCount(InterceptorMessage) returns (InterceptorResponse);
} }
...@@ -163,18 +163,9 @@ void MessageBus::Barrier() { ...@@ -163,18 +163,9 @@ void MessageBus::Barrier() {
bool MessageBus::DispatchMsgToCarrier( bool MessageBus::DispatchMsgToCarrier(
const InterceptorMessage& interceptor_message) { const InterceptorMessage& interceptor_message) {
if (interceptor_message.ctrl_message()) {
VLOG(3) << "Receiving control message from rank "
<< interceptor_message.src_id() << " to rank "
<< interceptor_message.dst_id();
// for barrier
IncreaseBarrierCount();
return true;
} else {
const std::string& carrier_id = *GlobalVal<std::string>::Get(); const std::string& carrier_id = *GlobalVal<std::string>::Get();
return GlobalMap<std::string, Carrier>::Get(carrier_id) return GlobalMap<std::string, Carrier>::Get(carrier_id)
->EnqueueInterceptorMessage(interceptor_message); ->EnqueueInterceptorMessage(interceptor_message);
}
} }
void MessageBus::ListenPort() { void MessageBus::ListenPort() {
...@@ -185,10 +176,9 @@ void MessageBus::ListenPort() { ...@@ -185,10 +176,9 @@ void MessageBus::ListenPort() {
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL) !defined(PADDLE_WITH_ASCEND_CL)
// function keep listen the port and handle the message // function keep listen the port and handle the message
PADDLE_ENFORCE_EQ(server_.AddService(&interceptor_message_service_, PADDLE_ENFORCE_EQ(
brpc::SERVER_DOESNT_OWN_SERVICE), server_.AddService(&message_service_, brpc::SERVER_DOESNT_OWN_SERVICE), 0,
0, platform::errors::Unavailable( platform::errors::Unavailable("Message bus: init brpc service error."));
"Message bus: init brpc service error."));
// start the server // start the server
const char* ip_for_brpc = addr_.c_str(); const char* ip_for_brpc = addr_.c_str();
...@@ -229,11 +219,16 @@ bool MessageBus::SendInterRank(int64_t dst_rank, ...@@ -229,11 +219,16 @@ bool MessageBus::SendInterRank(int64_t dst_rank,
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
channel.Init(dst_addr_for_brpc, &options), 0, channel.Init(dst_addr_for_brpc, &options), 0,
platform::errors::Unavailable("Message bus: init brpc channel error.")); platform::errors::Unavailable("Message bus: init brpc channel error."));
TheInterceptorMessageService_Stub stub(&channel); MessageService_Stub stub(&channel);
InterceptorResponse response; InterceptorResponse response;
brpc::Controller ctrl; brpc::Controller ctrl;
ctrl.set_log_id(0); ctrl.set_log_id(0);
stub.InterceptorMessageService(&ctrl, &interceptor_message, &response, NULL); if (interceptor_message.ctrl_message()) {
stub.IncreaseBarrierCount(&ctrl, &interceptor_message, &response, NULL);
} else {
stub.ReceiveInterceptorMessage(&ctrl, &interceptor_message, &response,
NULL);
}
if (!ctrl.Failed()) { if (!ctrl.Failed()) {
if (response.rst()) { if (response.rst()) {
VLOG(3) << "Message bus: brpc sends success."; VLOG(3) << "Message bus: brpc sends success.";
...@@ -248,6 +243,7 @@ bool MessageBus::SendInterRank(int64_t dst_rank, ...@@ -248,6 +243,7 @@ bool MessageBus::SendInterRank(int64_t dst_rank,
return false; return false;
} }
} }
#endif #endif
} // namespace distributed } // namespace distributed
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
!defined(PADDLE_WITH_ASCEND_CL) !defined(PADDLE_WITH_ASCEND_CL)
#include "brpc/channel.h" #include "brpc/channel.h"
#include "brpc/server.h" #include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h" #include "paddle/fluid/distributed/fleet_executor/message_service.h"
#endif #endif
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h" #include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
...@@ -83,7 +83,7 @@ class MessageBus final { ...@@ -83,7 +83,7 @@ class MessageBus final {
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL) !defined(PADDLE_WITH_ASCEND_CL)
InterceptorMessageServiceImpl interceptor_message_service_; MessageServiceImpl message_service_;
// brpc server // brpc server
brpc::Server server_; brpc::Server server_;
#endif #endif
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL) !defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h" #include "paddle/fluid/distributed/fleet_executor/message_service.h"
#include "brpc/server.h" #include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/global.h" #include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
...@@ -21,18 +21,29 @@ ...@@ -21,18 +21,29 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
void InterceptorMessageServiceImpl::InterceptorMessageService( void MessageServiceImpl::ReceiveInterceptorMessage(
google::protobuf::RpcController* control_base, google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response, const InterceptorMessage* request, InterceptorResponse* response,
google::protobuf::Closure* done) { google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done); brpc::ClosureGuard done_guard(done);
VLOG(3) << "Interceptor Message Service receives a message from interceptor " VLOG(3) << "Message Service receives a message from interceptor "
<< request->src_id() << " to interceptor " << request->dst_id() << request->src_id() << " to interceptor " << request->dst_id()
<< ", with the message: " << request->message_type(); << ", with the message: " << request->message_type();
bool flag = GlobalVal<MessageBus>::Get()->DispatchMsgToCarrier(*request); bool flag = GlobalVal<MessageBus>::Get()->DispatchMsgToCarrier(*request);
response->set_rst(flag); response->set_rst(flag);
} }
void MessageServiceImpl::IncreaseBarrierCount(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response,
google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
VLOG(3) << "Barrier Service receives a message from rank "
<< request->src_id() << " to rank " << request->dst_id();
GlobalVal<MessageBus>::Get()->IncreaseBarrierCount();
response->set_rst(true);
}
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
#endif #endif
...@@ -21,11 +21,15 @@ ...@@ -21,11 +21,15 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
class InterceptorMessageServiceImpl : public TheInterceptorMessageService { class MessageServiceImpl : public MessageService {
public: public:
InterceptorMessageServiceImpl() {} MessageServiceImpl() {}
virtual ~InterceptorMessageServiceImpl() {} virtual ~MessageServiceImpl() {}
virtual void InterceptorMessageService( virtual void ReceiveInterceptorMessage(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response,
google::protobuf::Closure* done);
virtual void IncreaseBarrierCount(
google::protobuf::RpcController* control_base, google::protobuf::RpcController* control_base,
const InterceptorMessage* request, InterceptorResponse* response, const InterceptorMessage* request, InterceptorResponse* response,
google::protobuf::Closure* done); google::protobuf::Closure* done);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册