diff --git a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt index 4114cb08119dba636b1504292a49c6cb426c3baf..adcf3c5e41defa5b890949bc25bf5f3d99932388 100644 --- a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt @@ -5,7 +5,7 @@ endif() proto_library(interceptor_message_proto SRCS interceptor_message.proto) if(WITH_DISTRIBUTE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL)) - set(BRPC_DEPS brpc ssl crypto) + set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog) else() set(BRPC_DEPS "") endif() @@ -23,4 +23,6 @@ if(WITH_DISTRIBUTE) set_source_files_properties(carrier.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(interceptor_message_service.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(interceptor_message_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + + add_subdirectory(test) endif() diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 53a3af22c45e7e6495b0b3dac249bf97bc52ed45..0e79656edea091b9dea48d4415d228ab88270758 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -20,9 +20,9 @@ namespace paddle { namespace distributed { -Carrier::Carrier( - const std::unordered_map& interceptor_id_to_node) - : interceptor_id_to_node_(interceptor_id_to_node) { +void Carrier::Init( + const std::unordered_map& interceptor_id_to_node) { + interceptor_id_to_node_ = interceptor_id_to_node; CreateInterceptors(); } @@ -56,20 +56,29 @@ Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) { return iter->second.get(); } +Interceptor* Carrier::SetInterceptor(int64_t interceptor_id, + std::unique_ptr interceptor) { + auto iter = interceptor_idx_to_interceptor_.find(interceptor_id); + 
PADDLE_ENFORCE_EQ(iter, interceptor_idx_to_interceptor_.end(), + platform::errors::AlreadyExists( + "The interceptor id %lld has already been created! " + "The interceptor id should be unique.", + interceptor_id)); + auto* ptr = interceptor.get(); + interceptor_idx_to_interceptor_.insert( + std::make_pair(interceptor_id, std::move(interceptor))); + return ptr; +} + void Carrier::CreateInterceptors() { // create each Interceptor for (const auto& item : interceptor_id_to_node_) { int64_t interceptor_id = item.first; TaskNode* task_node = item.second; - const auto& iter = interceptor_idx_to_interceptor_.find(interceptor_id); - PADDLE_ENFORCE_EQ(iter, interceptor_idx_to_interceptor_.end(), - platform::errors::AlreadyExists( - "The interceptor id %lld has already been created! " - "The interceptor is should be unique.", - interceptor_id)); - interceptor_idx_to_interceptor_.insert(std::make_pair( - interceptor_id, - std::make_unique(interceptor_id, task_node))); + + // TODO(wangxi): use node_type to select different Interceptor + auto interceptor = std::make_unique(interceptor_id, task_node); + SetInterceptor(interceptor_id, std::move(interceptor)); VLOG(3) << "Create Interceptor for " << interceptor_id; } } diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index bac836deaaaf7f184dfcfe824ace1d97469d059e..64974714f7b1c779c8a078f51c412c798090650b 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -18,6 +18,7 @@ #include #include +#include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/errors.h" @@ -26,15 +27,18 @@ namespace paddle { namespace distributed { -class Interceptor; class TaskNode; class InterceptorMessageServiceImpl; +// A singleton Carrier class Carrier final { 
public: - Carrier() = delete; + static Carrier& Instance() { + static Carrier carrier; + return carrier; + } - explicit Carrier( + void Init( const std::unordered_map& interceptor_id_to_node); ~Carrier() = default; @@ -42,15 +46,21 @@ class Carrier final { // Enqueue a message to corresponding interceptor id bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message); + // get interceptor based on the interceptor id + Interceptor* GetInterceptor(int64_t interceptor_id); + + // set interceptor with interceptor id + Interceptor* SetInterceptor(int64_t interceptor_id, + std::unique_ptr); + DISABLE_COPY_AND_ASSIGN(Carrier); private: + Carrier() = default; + // create each Interceptor void CreateInterceptors(); - // get interceptor based on the interceptor id - Interceptor* GetInterceptor(int64_t interceptor_id); - // interceptor logic id to the Nodes info std::unordered_map interceptor_id_to_node_; diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index eed6d6ef7e47e106a4ccb222d7a120f6fd0956cf..47d0c526c03ab3c08e31c92463d30fc6dcc4bef1 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -76,10 +76,5 @@ void FleetExecutor::Release() { // Release } -std::shared_ptr FleetExecutor::GetCarrier() { - // get carrier - return nullptr; -} - } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.h b/paddle/fluid/distributed/fleet_executor/fleet_executor.h index 242e1a74fc489d270c86da1fcd01582723c41e9e..779d2f91221dfa4cb97f86315ac54c4e7a304353 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.h +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.h @@ -43,7 +43,6 @@ class FleetExecutor final { FleetExecutorDesc exe_desc_; std::unique_ptr runtime_graph_; void InitMessageBus(); - static std::shared_ptr 
global_carrier_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc index 03f04d8340f0acf68f5488f7410ecfb7ef8645cc..dbee46afcf86fac3f2b47f7d6d729afe2968407b 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc @@ -56,11 +56,10 @@ bool Interceptor::EnqueueRemoteInterceptorMessage( return true; } -void Interceptor::Send(int64_t dst_id, - std::unique_ptr msg) { - msg->set_src_id(interceptor_id_); - msg->set_dst_id(dst_id); - MessageBus::Instance().Send(*msg.get()); +void Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) { + msg.set_src_id(interceptor_id_); + msg.set_dst_id(dst_id); + MessageBus::Instance().Send(msg); } void Interceptor::PoolTheMailbox() { diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h index 7744ecbb11026cb30816516ac66939ddeda3857e..24fad8331863e29099c53d7e536d2dffb9df3d12 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor.h @@ -58,7 +58,7 @@ class Interceptor { bool EnqueueRemoteInterceptorMessage( const InterceptorMessage& interceptor_message); - void Send(int64_t dst_id, std::unique_ptr msg); + void Send(int64_t dst_id, InterceptorMessage& msg); // NOLINT DISABLE_COPY_AND_ASSIGN(Interceptor); diff --git a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc b/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc index d30d356e4ff28f54728efc6ded5a6a59eb0f23a4..2205c6e5544bb5128f329edc0844c795e66462ea 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc +++ b/paddle/fluid/distributed/fleet_executor/interceptor_message_service.cc @@ -31,10 +31,7 @@ void InterceptorMessageServiceImpl::InterceptorMessageService( << ", with the message: " << 
request->message_type(); response->set_rst(true); // call interceptor manager's method to handle the message - std::shared_ptr carrier = FleetExecutor::GetCarrier(); - if (carrier != nullptr) { - carrier->EnqueueInterceptorMessage(*request); - } + Carrier::Instance().EnqueueInterceptorMessage(*request); } } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.cc b/paddle/fluid/distributed/fleet_executor/message_bus.cc index 75e7b2fcb3dc5b5493d38067226abcce13a1352c..2a8afb99ba7a9034ee6d0afde2f8b53779b87197 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.cc +++ b/paddle/fluid/distributed/fleet_executor/message_bus.cc @@ -32,10 +32,7 @@ void MessageBus::Init( rank_to_addr_ = rank_to_addr; addr_ = addr; - listen_port_thread_ = std::thread([this]() { - VLOG(3) << "Start listen_port_thread_ for message bus"; - ListenPort(); - }); + ListenPort(); std::call_once(once_flag_, []() { std::atexit([]() { MessageBus::Instance().Release(); }); @@ -51,7 +48,6 @@ void MessageBus::Release() { server_.Stop(1000); server_.Join(); #endif - listen_port_thread_.join(); } bool MessageBus::Send(const InterceptorMessage& interceptor_message) { @@ -184,11 +180,7 @@ bool MessageBus::SendInterRank(const InterceptorMessage& interceptor_message) { bool MessageBus::SendIntraRank(const InterceptorMessage& interceptor_message) { // send the message intra rank (dst is the same rank with src) - std::shared_ptr carrier = FleetExecutor::GetCarrier(); - if (carrier != nullptr) { - return carrier->EnqueueInterceptorMessage(interceptor_message); - } - return true; + return Carrier::Instance().EnqueueInterceptorMessage(interceptor_message); } } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.h b/paddle/fluid/distributed/fleet_executor/message_bus.h index e45f2e3c7125955c310aae6f43b207b7bbc44382..9212a93df425fe7e279b7ef83a6eb4ed96b773ca 100644 --- a/paddle/fluid/distributed/fleet_executor/message_bus.h +++ 
b/paddle/fluid/distributed/fleet_executor/message_bus.h @@ -92,10 +92,6 @@ class MessageBus final { // brpc server brpc::Server server_; #endif - - // thread keeps listening to the port to receive remote message - // this thread runs ListenPort() function - std::thread listen_port_thread_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..524aebe3b959f5f5833376765291e4694145ca39 --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt @@ -0,0 +1,2 @@ +set_source_files_properties(interceptor_ping_pong_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS fleet_executor ${BRPC_DEPS}) diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..856bbb4754738efd4bde8a569e38df1716711d23 --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc @@ -0,0 +1,70 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include + +#include "gtest/gtest.h" + +#include "paddle/fluid/distributed/fleet_executor/carrier.h" +#include "paddle/fluid/distributed/fleet_executor/interceptor.h" +#include "paddle/fluid/distributed/fleet_executor/message_bus.h" + +namespace paddle { +namespace distributed { + +class PingPongInterceptor : public Interceptor { + public: + PingPongInterceptor(int64_t interceptor_id, TaskNode* node) + : Interceptor(interceptor_id, node) { + RegisterMsgHandle([this](const InterceptorMessage& msg) { PingPong(msg); }); + } + + void PingPong(const InterceptorMessage& msg) { + std::cout << GetInterceptorId() << " recv msg, count=" << count_ + << std::endl; + ++count_; + if (count_ == 20) { + InterceptorMessage stop; + stop.set_message_type(STOP); + Send(0, stop); + Send(1, stop); + return; + } + + InterceptorMessage resp; + Send(msg.src_id(), resp); + } + + private: + int count_{0}; +}; + +TEST(InterceptorTest, PingPong) { + MessageBus& msg_bus = MessageBus::Instance(); + msg_bus.Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); + + Carrier& carrier = Carrier::Instance(); + + Interceptor* a = carrier.SetInterceptor( + 0, std::make_unique(0, nullptr)); + + carrier.SetInterceptor(1, std::make_unique(1, nullptr)); + + InterceptorMessage msg; + a->Send(1, msg); +} + +} // namespace distributed +} // namespace paddle