From cd2855b0626a4cce979ce58587c942b0b0304691 Mon Sep 17 00:00:00 2001 From: LiYuRio <63526175+LiYuRio@users.noreply.github.com> Date: Mon, 10 Jan 2022 10:54:24 +0800 Subject: [PATCH] [fleet_executor] Add barrier rpc (#38799) --- .../distributed/fleet_executor/CMakeLists.txt | 6 ++-- .../distributed/fleet_executor/carrier.cc | 1 - .../fleet_executor/fleet_executor.cc | 1 - .../fleet_executor/interceptor_message.proto | 5 +-- .../distributed/fleet_executor/message_bus.cc | 32 ++++++++----------- .../distributed/fleet_executor/message_bus.h | 4 +-- ..._message_service.cc => message_service.cc} | 17 ++++++++-- ...or_message_service.h => message_service.h} | 12 ++++--- 8 files changed, 44 insertions(+), 34 deletions(-) rename paddle/fluid/distributed/fleet_executor/{interceptor_message_service.cc => message_service.cc} (68%) rename paddle/fluid/distributed/fleet_executor/{interceptor_message_service.h => message_service.h} (75%) diff --git a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt index e9da55c417..d8372e1088 100644 --- a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt @@ -13,7 +13,7 @@ endif() cc_library(task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog) cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc - interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc + interceptor.cc compute_interceptor.cc amplifier_interceptor.cc message_service.cc message_bus.cc DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto task_loop_thread_pool collective_helper op_registry executor_gc_helper gflags glog ${BRPC_DEPS}) @@ -29,8 +29,8 @@ if(WITH_DISTRIBUTE) set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(fleet_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(carrier.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) - set_source_files_properties(interceptor_message_service.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) - set_source_files_properties(interceptor_message_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + set_source_files_properties(message_service.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + set_source_files_properties(message_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) add_subdirectory(test) endif() diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 79be1824b8..79ca6f467a 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -15,7 +15,6 @@ #include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/global.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" -#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/runtime_graph.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index e22d0945a2..d6c1e678ad 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -137,7 +137,6 @@ void FleetExecutor::Run(const std::string& carrier_id) { // Set current running carrier if (*GlobalVal::Get() != carrier_id) { GlobalVal::Set(new std::string(carrier_id)); - // TODO(liyurui): Move barrier to service GlobalVal::Get()->Barrier(); } carrier->Start(); diff --git a/paddle/fluid/distributed/fleet_executor/interceptor_message.proto b/paddle/fluid/distributed/fleet_executor/interceptor_message.proto index c9ab477183..ed38894641 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor_message.proto +++ b/paddle/fluid/distributed/fleet_executor/interceptor_message.proto @@ -34,7 +34,8 @@ message InterceptorMessage { message InterceptorResponse { optional bool rst = 1 [ default = false ]; } -service TheInterceptorMessageService { - rpc InterceptorMessageService(InterceptorMessage) +service MessageService { + rpc ReceiveInterceptorMessage(InterceptorMessage) returns (InterceptorResponse); + rpc IncreaseBarrierCount(InterceptorMessage) returns (InterceptorResponse); } diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.cc b/paddle/fluid/distributed/fleet_executor/message_bus.cc index 110c5feafc..8d2ec5c41d 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.cc +++ b/paddle/fluid/distributed/fleet_executor/message_bus.cc @@ -163,18 +163,9 @@ void MessageBus::Barrier() { bool MessageBus::DispatchMsgToCarrier( const InterceptorMessage& interceptor_message) { - if (interceptor_message.ctrl_message()) { - VLOG(3) << "Receiving control message from rank " - << interceptor_message.src_id() << " to rank " - << interceptor_message.dst_id(); - // for barrier - IncreaseBarrierCount(); - return true; - } else { - const std::string& carrier_id = *GlobalVal::Get(); - return GlobalMap::Get(carrier_id) - ->EnqueueInterceptorMessage(interceptor_message); - } + const std::string& carrier_id = *GlobalVal::Get(); + return GlobalMap::Get(carrier_id) + ->EnqueueInterceptorMessage(interceptor_message); } void MessageBus::ListenPort() { @@ -185,10 +176,9 @@ void MessageBus::ListenPort() { #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ !defined(PADDLE_WITH_ASCEND_CL) // function keep listen the port and handle the message - PADDLE_ENFORCE_EQ(server_.AddService(&interceptor_message_service_, - brpc::SERVER_DOESNT_OWN_SERVICE), - 0, platform::errors::Unavailable( - "Message bus: init brpc service error.")); + PADDLE_ENFORCE_EQ( + server_.AddService(&message_service_, brpc::SERVER_DOESNT_OWN_SERVICE), 0, + platform::errors::Unavailable("Message bus: init brpc service error.")); // start the server const char* ip_for_brpc = addr_.c_str(); @@ -229,11 +219,16 @@ bool MessageBus::SendInterRank(int64_t dst_rank, PADDLE_ENFORCE_EQ( channel.Init(dst_addr_for_brpc, &options), 0, platform::errors::Unavailable("Message bus: init brpc channel error.")); - TheInterceptorMessageService_Stub stub(&channel); + MessageService_Stub stub(&channel); InterceptorResponse response; brpc::Controller ctrl; ctrl.set_log_id(0); - stub.InterceptorMessageService(&ctrl, &interceptor_message, &response, NULL); + if (interceptor_message.ctrl_message()) { + stub.IncreaseBarrierCount(&ctrl, &interceptor_message, &response, NULL); + } else { + stub.ReceiveInterceptorMessage(&ctrl, &interceptor_message, &response, + NULL); + } if (!ctrl.Failed()) { if (response.rst()) { VLOG(3) << "Message bus: brpc sends success."; @@ -248,6 +243,7 @@ bool MessageBus::SendInterRank(int64_t dst_rank, return false; } } + #endif } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.h b/paddle/fluid/distributed/fleet_executor/message_bus.h index 456cd77e2d..d805ac8160 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.h +++ b/paddle/fluid/distributed/fleet_executor/message_bus.h @@ -24,7 +24,7 @@ !defined(PADDLE_WITH_ASCEND_CL) #include "brpc/channel.h" #include "brpc/server.h" -#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h" +#include "paddle/fluid/distributed/fleet_executor/message_service.h" #endif #include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h" @@ -83,7 +83,7 @@ class MessageBus final { #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ !defined(PADDLE_WITH_ASCEND_CL) - InterceptorMessageServiceImpl interceptor_message_service_; + MessageServiceImpl message_service_; // brpc server brpc::Server server_; #endif diff --git a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc b/paddle/fluid/distributed/fleet_executor/message_service.cc similarity index 68% rename from paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc rename to paddle/fluid/distributed/fleet_executor/message_service.cc index ce8a73602d..c3fff98f68 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc +++ b/paddle/fluid/distributed/fleet_executor/message_service.cc @@ -13,7 +13,7 @@ // limitations under the License. #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ !defined(PADDLE_WITH_ASCEND_CL) -#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h" +#include "paddle/fluid/distributed/fleet_executor/message_service.h" #include "brpc/server.h" #include "paddle/fluid/distributed/fleet_executor/global.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" @@ -21,18 +21,29 @@ namespace paddle { namespace distributed { -void InterceptorMessageServiceImpl::InterceptorMessageService( +void MessageServiceImpl::ReceiveInterceptorMessage( google::protobuf::RpcController* control_base, const InterceptorMessage* request, InterceptorResponse* response, google::protobuf::Closure* done) { brpc::ClosureGuard done_guard(done); - VLOG(3) << "Interceptor Message Service receives a message from interceptor " + VLOG(3) << "Message Service receives a message from interceptor " << request->src_id() << " to interceptor " << request->dst_id() << ", with the message: " << request->message_type(); bool flag = GlobalVal::Get()->DispatchMsgToCarrier(*request); response->set_rst(flag); } +void MessageServiceImpl::IncreaseBarrierCount( + google::protobuf::RpcController* control_base, + const InterceptorMessage* request, InterceptorResponse* response, + google::protobuf::Closure* done) { + brpc::ClosureGuard done_guard(done); + VLOG(3) << "Barrier Service receives a message from rank " + << request->src_id() << " to rank " << request->dst_id(); + GlobalVal::Get()->IncreaseBarrierCount(); + response->set_rst(true); +} + } // namespace distributed } // namespace paddle #endif diff --git a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.h b/paddle/fluid/distributed/fleet_executor/message_service.h similarity index 75% rename from paddle/fluid/distributed/fleet_executor/interceptor_message_service.h rename to paddle/fluid/distributed/fleet_executor/message_service.h index 0a8dfc861a..02f73471e3 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.h +++ b/paddle/fluid/distributed/fleet_executor/message_service.h @@ -21,11 +21,15 @@ namespace paddle { namespace distributed { -class InterceptorMessageServiceImpl : public TheInterceptorMessageService { +class MessageServiceImpl : public MessageService { public: - InterceptorMessageServiceImpl() {} - virtual ~InterceptorMessageServiceImpl() {} - virtual void InterceptorMessageService( + MessageServiceImpl() {} + virtual ~MessageServiceImpl() {} + virtual void ReceiveInterceptorMessage( + google::protobuf::RpcController* control_base, + const InterceptorMessage* request, InterceptorResponse* response, + google::protobuf::Closure* done); + virtual void IncreaseBarrierCount( google::protobuf::RpcController* control_base, const InterceptorMessage* request, InterceptorResponse* response, google::protobuf::Closure* done); -- GitLab